源码分析RestTemplate及自定义请求和拦截器

star2017 1年前 ⋅ 403 阅读

SaaS 项目采用分布式微服务架构,服务间调用使用的是 RestTemplate,并且对 RestTemplate 的 Request 请求进行了自定义,在此做个记录。

自定义 Request 有很多作用。例如自定义请求实现安全认证,自定义请求拦截器实现负载均衡或请求代理等,可以非常灵活地做一些定制化。

RestTemplate 相关文章:Spring Boot 2实践系列(二十一):RestTemplate 远程调用 REST 服务;Spring Cloud系列(四):客户端负载均衡 Ribbon

源码分析

RestTemplate 不是一个独立的类,直接父类是 InterceptingHttpAccessor,顶级父类是 HttpAccessor

  • InterceptingHttpAccessor 是 RestTemplate 和其它 HTTP 访问网关辅助类的基类,作用是在父类 HttpAccessor 公共属性的基础上增加拦截器相关的属性。
  • HttpAccessor 是 InterceptingHttpAccessor 的直接父类,同样是 RestTemplate 和其它 HTTP 访问网关辅助类的基类,作用是定义用于操作的公共属性(例如 ClientHttpRequestFactory)。

ClientHttpRequest

RestTemplate操作

RestTemplate 中所有执行请求的操作都会调用 execute(),最终调用的都是 doExecute() 方法,在该方法里创建 HTTP Request。

/**
 * Performs an HTTP GET against the URL template and converts the response body into
 * {@code responseType}. Delegates to {@code execute(...)}, which ends up in {@code doExecute(...)}.
 */
@Override
@Nullable
public <T> T getForObject(String url, Class<T> responseType, Object... uriVariables) throws RestClientException {
    // Sets the Accept header according to the requested type
    RequestCallback acceptCallback = acceptHeaderRequestCallback(responseType);
    // Converts the response body using the registered message converters
    HttpMessageConverterExtractor<T> bodyExtractor =
        new HttpMessageConverterExtractor<>(responseType, getMessageConverters(), logger);
    return execute(url, HttpMethod.GET, acceptCallback, bodyExtractor, uriVariables);
}

/**
 * Expands the URL template with {@code uriVariables} via the URI template handler and
 * hands the resulting URI to {@code doExecute(...)}.
 */
@Override
@Nullable
public <T> T execute(String url, HttpMethod method, @Nullable RequestCallback requestCallback, @Nullable ResponseExtractor<T> responseExtractor, Object... uriVariables) throws RestClientException {
    return doExecute(getUriTemplateHandler().expand(url, uriVariables), method, requestCallback, responseExtractor);
}

/**
 * Central execution path used by all request operations. It creates the request, lets the
 * optional callback prepare it, executes it, passes the response to handleResponse(...) and
 * finally extracts the result.
 *
 * @param url               the fully expanded URI to call (must not be null)
 * @param method            the HTTP method (must not be null)
 * @param requestCallback   optional callback that prepares the request before it is sent
 * @param responseExtractor optional extractor; when null this method returns null
 * @return the extracted response data, or null when no extractor is given
 * @throws RestClientException an IOException is rethrown as ResourceAccessException
 */
@Nullable
protected <T> T doExecute(URI url, @Nullable HttpMethod method, @Nullable RequestCallback requestCallback,
                          @Nullable ResponseExtractor<T> responseExtractor) throws RestClientException {

    Assert.notNull(url, "URI is required");
    Assert.notNull(method, "HttpMethod is required");
    ClientHttpResponse response = null;
    try {
        // Create the request through createRequest(), i.e. via the configured request factory
        ClientHttpRequest request = createRequest(url, method);
        if (requestCallback != null) {
            // Let the callback prepare the request
            requestCallback.doWithRequest(request);
        }
        // Execute the request
        response = request.execute();
        // Post-process the response
        handleResponse(url, method, response);
        // Return the extracted response data (null when there is no extractor)
        return (responseExtractor != null ? responseExtractor.extractData(response) : null);
    }
    catch (IOException ex) {
        // Report the URL without its query string in the error message
        String resource = url.toString();
        String query = url.getRawQuery();
        resource = (query != null ? resource.substring(0, resource.indexOf('?')) : resource);
        throw new ResourceAccessException("I/O error on " + method.name() +
                                          " request for \"" + resource + "\": " + ex.getMessage(), ex);
    }
    finally {
        // Always close the response, also when handling or extraction fails
        if (response != null) {
            response.close();
        }
    }
}

ClientHttpRequest

在 RestTemplate 的 doExecute() 方法中创建了 request(ClientHttpRequest request = createRequest(url, method))。

创建 request 调用的是顶级抽象父类 HttpAccessor 的 createRequest() 方法。

/**
 * Top-level base class of RestTemplate. It holds the ClientHttpRequestFactory used to create
 * requests (SimpleClientHttpRequestFactory by default).
 */
public abstract class HttpAccessor {

    protected final Log logger = HttpLogging.forLogName(getClass());

    // Default request factory; replaced through setRequestFactory(...)
    private ClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();

    /**
     * Sets the request factory. This is the hook used when plugging in a custom request factory.
     *
     * @param requestFactory the factory to use; null is rejected by Assert.notNull
     */
    public void setRequestFactory(ClientHttpRequestFactory requestFactory) {
        Assert.notNull(requestFactory, "ClientHttpRequestFactory must not be null");
        this.requestFactory = requestFactory;
    }

    /**
     * Returns the request factory stored in this class. Subclasses may override this method.
     */
    public ClientHttpRequestFactory getRequestFactory() {
        return this.requestFactory;
    }

    /**
     * Creates a ClientHttpRequest through the factory returned by getRequestFactory().
     *
     * @param url    the target URI
     * @param method the HTTP method
     * @return the new request
     * @throws IOException if the factory fails to create the request
     */
    protected ClientHttpRequest createRequest(URI url, HttpMethod method) throws IOException {
        // getRequestFactory() is dynamically dispatched. For RestTemplate it resolves to
        // InterceptingHttpAccessor#getRequestFactory(). When interceptors are registered, that
        // method returns an InterceptingClientHttpRequestFactory wrapping the factory stored
        // here; only without interceptors is this class's factory used directly.
        ClientHttpRequest request = getRequestFactory().createRequest(url, method);
        if (logger.isDebugEnabled()) {
            logger.debug("HTTP " + method.name() + " " + url);
        }
        return request;
    }
}

InterceptingHttpAccessor

InterceptingHttpAccessor 是 RestTemplate 的直接父类,提供拦截器相关的设置。

/**
 * Direct superclass of RestTemplate. It adds request-interceptor support on top of HttpAccessor.
 * With interceptors registered, requests are created through an InterceptingClientHttpRequestFactory
 * that wraps the factory held by HttpAccessor.
 */
public abstract class InterceptingHttpAccessor extends HttpAccessor {

    // Live interceptor list; getInterceptors() returns this very list
    private final List<ClientHttpRequestInterceptor> interceptors = new ArrayList<>();

    // Lazily created intercepting factory, cached; cleared by setRequestFactory()
    @Nullable
    private volatile ClientHttpRequestFactory interceptingRequestFactory;

    /**
     * Sets the interceptors. The internal list's contents are replaced and sorted with
     * AnnotationAwareOrderComparator. If the caller passes the list returned by getInterceptors(),
     * nothing is done, because clear() would otherwise wipe the source list.
     */
    public void setInterceptors(List<ClientHttpRequestInterceptor> interceptors) {
        if (this.interceptors != interceptors) {
            this.interceptors.clear();
            this.interceptors.addAll(interceptors);
            AnnotationAwareOrderComparator.sort(this.interceptors);
        }
    }

    /**
     * Returns the interceptors. This is the live, mutable internal list; changes to it are not re-sorted.
     */
    public List<ClientHttpRequestInterceptor> getInterceptors() {
        return this.interceptors;
    }

    /**
     * Sets the request factory and drops the cached intercepting factory. The next
     * getRequestFactory() call then wraps the new factory.
     */
    @Override
    public void setRequestFactory(ClientHttpRequestFactory requestFactory) {
        super.setRequestFactory(requestFactory);
        this.interceptingRequestFactory = null;
    }

    /**
     * Overrides HttpAccessor#getRequestFactory(). When interceptors are present, it returns a cached
     * InterceptingClientHttpRequestFactory wrapping super.getRequestFactory(). Otherwise it returns
     * HttpAccessor's factory unchanged.
     */
    @Override
    public ClientHttpRequestFactory getRequestFactory() {
        List<ClientHttpRequestInterceptor> interceptors = getInterceptors();
        if (!CollectionUtils.isEmpty(interceptors)) {
            // Interceptors are present
            ClientHttpRequestFactory factory = this.interceptingRequestFactory;
            if (factory == null) {
                // The wrapper keeps a reference to the live interceptor list, so interceptors
                // added later via getInterceptors() are still applied by the cached wrapper
                factory = new InterceptingClientHttpRequestFactory(super.getRequestFactory(), interceptors);
                this.interceptingRequestFactory = factory;
            }
            return factory;
        }
        else {
            // No interceptors: use HttpAccessor's factory directly
            return super.getRequestFactory();
        }
    }
}

InterceptingClientHttpRequestFactory

/**
 * Request factory decorator that creates InterceptingClientHttpRequest instances. Every request
 * therefore runs through the configured interceptors before the wrapped factory is used.
 */
public class InterceptingClientHttpRequestFactory extends AbstractClientHttpRequestFactoryWrapper {

    // Never null: a null constructor argument is replaced by an empty list
    private final List<ClientHttpRequestInterceptor> interceptors;

    /**
     * @param requestFactory the factory that creates the real requests once all interceptors ran
     * @param interceptors   the interceptors to apply, or {@code null} for none
     */
    public InterceptingClientHttpRequestFactory(ClientHttpRequestFactory requestFactory,
            @Nullable List<ClientHttpRequestInterceptor> interceptors) {
        super(requestFactory);
        if (interceptors == null) {
            this.interceptors = Collections.emptyList();
        }
        else {
            this.interceptors = interceptors;
        }
    }

    /**
     * Creates a request that carries this factory's interceptors. The target factory is only
     * used at execution time, after the interceptor chain has been exhausted.
     */
    @Override
    protected ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod, ClientHttpRequestFactory requestFactory) {
        InterceptingClientHttpRequest interceptingRequest =
                new InterceptingClientHttpRequest(requestFactory, this.interceptors, uri, httpMethod);
        return interceptingRequest;
    }
}

SimpleClientHttpRequestFactory

SimpleClientHttpRequestFactory 是 ClientHttpRequestFactory 默认的简单实现。

/**
 * Default ClientHttpRequestFactory implementation based on java.net.HttpURLConnection.
 * It also implements AsyncClientHttpRequestFactory; asynchronous requests require a task executor.
 */
public class SimpleClientHttpRequestFactory implements ClientHttpRequestFactory, AsyncClientHttpRequestFactory {

    // Default chunk size in bytes for non-buffered (streaming) request bodies
    private static final int DEFAULT_CHUNK_SIZE = 4096;


    @Nullable
    private Proxy proxy;

    private boolean bufferRequestBody = true;

    private int chunkSize = DEFAULT_CHUNK_SIZE;

    // -1 = not set: prepareConnection() only applies values >= 0
    private int connectTimeout = -1;

    // Same convention as connectTimeout
    private int readTimeout = -1;

    private boolean outputStreaming = true;

    @Nullable
    private AsyncListenableTaskExecutor taskExecutor;


    /**
     * Sets the proxy used by openConnection(). When no proxy is set, URL#openConnection() is
     * called without a proxy argument.
     */
    public void setProxy(Proxy proxy) {
        this.proxy = proxy;
    }

    /**
     * Sets whether the factory buffers the request body (ClientHttpRequest#getBody()) internally.
     * Defaults to true. When sending large amounts of data via POST or PUT, setting it to false
     * (streaming in chunkSize chunks) is recommended to avoid running out of memory.
     */
    public void setBufferRequestBody(boolean bufferRequestBody) {
        this.bufferRequestBody = bufferRequestBody;
    }

    /**
     * Sets the number of bytes written per chunk when the request body is not buffered locally.
     */
    public void setChunkSize(int chunkSize) {
        this.chunkSize = chunkSize;
    }

    /**
     * Sets the connect timeout in milliseconds (passed to HttpURLConnection#setConnectTimeout).
     * Negative values leave the connection's default untouched.
     */
    public void setConnectTimeout(int connectTimeout) {
        this.connectTimeout = connectTimeout;
    }

    /**
     * Sets the read timeout in milliseconds (passed to HttpURLConnection#setReadTimeout).
     * Negative values leave the connection's default untouched.
     */
    public void setReadTimeout(int readTimeout) {
        this.readTimeout = readTimeout;
    }

    /**
     * Sets whether output streaming mode is enabled; defaults to true. The flag is passed on to
     * the created request objects. NOTE(review): its exact effect is implemented in those request
     * classes, which are not shown here.
     */
    public void setOutputStreaming(boolean outputStreaming) {
        this.outputStreaming = outputStreaming;
    }

    /**
     * Sets the task executor used by this request factory; it is required for createAsyncRequest().
     */
    public void setTaskExecutor(AsyncListenableTaskExecutor taskExecutor) {
        this.taskExecutor = taskExecutor;
    }

    /**
     * Creates a request backed by java.net.HttpURLConnection (a URLConnection subclass).
     * It returns a buffering request when bufferRequestBody is true, otherwise a streaming one.
     */
    @Override
    public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException {
        HttpURLConnection connection = openConnection(uri.toURL(), this.proxy);
        prepareConnection(connection, httpMethod.name());

        if (this.bufferRequestBody) {
            return new SimpleBufferingClientHttpRequest(connection, this.outputStreaming);
        }
        else {
            return new SimpleStreamingClientHttpRequest(connection, this.chunkSize, this.outputStreaming);
        }
    }

    /**
     * Creates an asynchronous request; fails via Assert.state when no task executor is set.
     */
    @Override
    public AsyncClientHttpRequest createAsyncRequest(URI uri, HttpMethod httpMethod) throws IOException {
        Assert.state(this.taskExecutor != null, "Asynchronous execution requires TaskExecutor to be set");

        HttpURLConnection connection = openConnection(uri.toURL(), this.proxy);
        prepareConnection(connection, httpMethod.name());

        if (this.bufferRequestBody) {
            return new SimpleBufferingAsyncClientHttpRequest(
                    connection, this.outputStreaming, this.taskExecutor);
        }
        else {
            return new SimpleStreamingAsyncClientHttpRequest(
                    connection, this.chunkSize, this.outputStreaming, this.taskExecutor);
        }
    }

    /**
     * Opens and returns a connection to the given URL, using the proxy from
     * setProxy(java.net.Proxy) when one is set.
     *
     * @throws IllegalStateException if the URL does not yield an HttpURLConnection
     */
    protected HttpURLConnection openConnection(URL url, @Nullable Proxy proxy) throws IOException {
        URLConnection urlConnection = (proxy != null ? url.openConnection(proxy) : url.openConnection());
        if (!HttpURLConnection.class.isInstance(urlConnection)) {
            throw new IllegalStateException("HttpURLConnection required for [" + url + "] but got: " + urlConnection);
        }
        return (HttpURLConnection) urlConnection;
    }

    /**
     * Prepares the connection. It applies connectTimeout/readTimeout (when >= 0) and the HTTP
     * method, and configures the input, output and redirect flags.
     */
    protected void prepareConnection(HttpURLConnection connection, String httpMethod) throws IOException {
        if (this.connectTimeout >= 0) {
            // Connect timeout
            connection.setConnectTimeout(this.connectTimeout);
        }
        if (this.readTimeout >= 0) {
            // Read timeout
            connection.setReadTimeout(this.readTimeout);
        }
        // Always allow reading the response from the connection
        connection.setDoInput(true);

        if ("GET".equals(httpMethod)) {
            // Follow redirects automatically, but only for GET
            connection.setInstanceFollowRedirects(true);
        }
        else {
            connection.setInstanceFollowRedirects(false);
        }

        if ("POST".equals(httpMethod) || "PUT".equals(httpMethod) ||
                "PATCH".equals(httpMethod) || "DELETE".equals(httpMethod)) {
            // Allow writing a request body for these methods only
            connection.setDoOutput(true);
        }
        else {
            connection.setDoOutput(false);
        }
        // Set the request method
        connection.setRequestMethod(httpMethod);
    }
}

ClientHttpRequestFactory

ClientHttpRequestFactory 是个 request 抽象工厂接口,支持多种实现和自定义实现,如 默认的 SimpleClientHttpRequestFactory,OkHttp3ClientHttpRequestFactory,HttpComponentsClientHttpRequestFactory 等。

  1. RestTemplate 提供了传入 ClientHttpRequestFactory 类型参数的构造方法,为自定义 Request 创造了条件。

    即自定义 ClientHttpRequestFactory 实现,重写 createRequest() 方法。

    /**
     * Creates a RestTemplate that uses the given request factory instead of the default
     * SimpleClientHttpRequestFactory. This is the entry point for custom requests.
     *
     * @param requestFactory the factory used to create HTTP requests
     */
    public RestTemplate(ClientHttpRequestFactory requestFactory) {
        this();
        // Dispatches to InterceptingHttpAccessor#setRequestFactory, which calls HttpAccessor's
        // setRequestFactory() and thereby replaces the default new SimpleClientHttpRequestFactory()
        setRequestFactory(requestFactory);
    }
    

    RestTemplate 的直接父类是 InterceptingHttpAccessor,为请求拦截器提供支持,顶级父类是 HttpAccessor。

    InterceptingHttpAccessor 重写了 HttpAccessor 的 setRequestFactory() 方法,内部仍调用父类的方法,并清空缓存的拦截器请求工厂(interceptingRequestFactory)。

    /**
     * Replaces the request factory and invalidates the cached intercepting factory, so that
     * the next getRequestFactory() call wraps the new factory.
     */
    @Override
    public void setRequestFactory(ClientHttpRequestFactory requestFactory) {
        // Delegate to the parent class (validates and stores the factory)
        super.setRequestFactory(requestFactory);
        // Drop the wrapper built around the previous factory
        this.interceptingRequestFactory = null;
    }
    

    由 HttpAccessor 的 setRequestFactory() 方法覆盖默认的实现(SimpleClientHttpRequestFactory):

    /**
     * Excerpt of HttpAccessor: the factory defaults to SimpleClientHttpRequestFactory and is
     * overwritten through setRequestFactory().
     */
    public abstract class HttpAccessor {
        protected final Log logger = HttpLogging.forLogName(getClass());
    
        // Default request factory
        private ClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
    
        // Replaces the default factory; null is rejected by Assert.notNull
        public void setRequestFactory(ClientHttpRequestFactory requestFactory) {
            Assert.notNull(requestFactory, "ClientHttpRequestFactory must not be null");
            this.requestFactory = requestFactory;
        }
    
        //........ (remaining members omitted) .........
    }
    
  2. RestTemplateBuilder 链式调用提供了 requestFactory() 方法来传入特定的 ClientHttpRequestFactory。

    /**
     * Configures the request factory by type. Each configure() call obtains a fresh instance,
     * created reflectively through the class's no-arg constructor.
     */
    public RestTemplateBuilder requestFactory(Class<? extends ClientHttpRequestFactory> requestFactory) {
        Assert.notNull(requestFactory, "RequestFactory must not be null");
        Supplier<ClientHttpRequestFactory> reflectiveSupplier = () -> createRequestFactory(requestFactory);
        return requestFactory(reflectiveSupplier);
    }
    
    /**
     * Instantiates the given ClientHttpRequestFactory class through its no-arg constructor,
     * which may be non-public. Any failure is rethrown as an IllegalStateException.
     */
    private ClientHttpRequestFactory createRequestFactory(Class<? extends ClientHttpRequestFactory> requestFactory) {
        try {
            // Reflection: look up the declared no-arg constructor and open it up
            Constructor<?> noArgConstructor = requestFactory.getDeclaredConstructor();
            noArgConstructor.setAccessible(true);
            Object created = noArgConstructor.newInstance();
            return (ClientHttpRequestFactory) created;
        }
        catch (Exception reflectionFailure) {
            throw new IllegalStateException(reflectionFailure);
        }
    }
    
    /**
     * Returns a new builder that obtains its request factory from {@code requestFactorySupplier}
     * (a Supplier, i.e. a function that supplies a value). Every other setting is copied over;
     * this builder itself is not modified.
     */
    public RestTemplateBuilder requestFactory(Supplier<ClientHttpRequestFactory> requestFactorySupplier) {
        Assert.notNull(requestFactorySupplier, "RequestFactory Supplier must not be null");
        return new RestTemplateBuilder(
                this.detectRequestFactory,
                this.rootUri,
                this.messageConverters,
                requestFactorySupplier,
                this.uriTemplateHandler,
                this.errorHandler,
                this.basicAuthentication,
                this.restTemplateCustomizers,
                this.requestFactoryCustomizer,
                this.interceptors);
    }
    
    /**
     * Builds and configures a plain {@link RestTemplate}.
     */
    public RestTemplate build() {
        Class<RestTemplate> templateType = RestTemplate.class;
        return build(templateType);
    }
    /**
     * Instantiates the requested RestTemplate subtype via BeanUtils and applies this
     * builder's configuration to it.
     */
    public <T extends RestTemplate> T build(Class<T> restTemplateClass) {
        T freshTemplate = BeanUtils.instantiateClass(restTemplateClass);
        return configure(freshTemplate);
    }
    
    /**
     * Applies this builder's settings to the given RestTemplate, in this order: request factory,
     * message converters, URI template handler, error handler, root URI, Basic auth, interceptors.
     * The RestTemplateCustomizers run last.
     *
     * @param restTemplate the template to configure
     * @return the same template instance
     */
    public <T extends RestTemplate> T configure(T restTemplate) {
        // Configure the ClientHttpRequestFactory
        configureRequestFactory(restTemplate);
        if (!CollectionUtils.isEmpty(this.messageConverters)) {
            restTemplate.setMessageConverters(new ArrayList<>(this.messageConverters));
        }
        if (this.uriTemplateHandler != null) {
            restTemplate.setUriTemplateHandler(this.uriTemplateHandler);
        }
        if (this.errorHandler != null) {
            restTemplate.setErrorHandler(this.errorHandler);
        }
        if (this.rootUri != null) {
            RootUriTemplateHandler.addTo(restTemplate, this.rootUri);
        }
        if (this.basicAuthentication != null) {
            // Basic auth goes in before the builder's other interceptors
            restTemplate.getInterceptors().add(this.basicAuthentication);
        }
        // Add the interceptors. They are appended to the live list returned by getInterceptors(),
        // so they are not re-sorted (sorting only happens in setInterceptors())
        restTemplate.getInterceptors().addAll(this.interceptors);
        if (!CollectionUtils.isEmpty(this.restTemplateCustomizers)) {
            for (RestTemplateCustomizer customizer : this.restTemplateCustomizers) {
                customizer.customize(restTemplate);
            }
        }
        return restTemplate;
    }
    
    /**
     * Resolves the request factory and installs it on the template. An explicit supplier wins;
     * otherwise auto-detection is used when enabled. The optional customizer is applied first.
     * When neither source yields a factory, the template keeps its current one.
     */
    private void configureRequestFactory(RestTemplate restTemplate) {
        ClientHttpRequestFactory resolved;
        if (this.requestFactorySupplier != null) {
            // Use the factory supplied by the caller
            resolved = this.requestFactorySupplier.get();
        }
        else if (this.detectRequestFactory) {
            resolved = new ClientHttpRequestFactorySupplier().get();
        }
        else {
            resolved = null;
        }
        if (resolved == null) {
            return;
        }
        if (this.requestFactoryCustomizer != null) {
            this.requestFactoryCustomizer.accept(resolved);
        }
        restTemplate.setRequestFactory(resolved);
    }
    

ClientHttpRequestInterceptor

ClientHttpRequestInterceptor 是函数式接口,用于拦截客户端的 HTTP 请求。此接口的实现可以通过 RestTemplate#setInterceptors 方法注册,用于修改请求和响应。

/**
 * Intercepts client-side HTTP requests. Implementations are registered via
 * {@code RestTemplate#setInterceptors} and may inspect or modify the request before passing
 * it on through the chain.
 */
@FunctionalInterface
public interface ClientHttpRequestInterceptor {

    /**
     * Intercepts the given request and returns a response.
     *
     * @param request   the request (headers may be modified, or the request may be wrapped)
     * @param body      the request body
     * @param execution passes the request and body on to the next entity in the chain
     * @return the response
     * @throws IOException on I/O errors
     */
    ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) throws IOException;
}

InterceptingHttpAccessor

RestTemplate 的抽象父类 InterceptingHttpAccessor 提供了 setInterceptors() 方法(参数类型为 List<ClientHttpRequestInterceptor>)用于设置自定义的拦截器,即自定义 HTTP 请求拦截器需实现 ClientHttpRequestInterceptor 接口。

/**
 * Abstract superclass of RestTemplate. It adds interceptor support to HttpAccessor: custom
 * HTTP request interceptors implement ClientHttpRequestInterceptor and are registered
 * through setInterceptors().
 */
public abstract class InterceptingHttpAccessor extends HttpAccessor {

    // Live interceptor list; getInterceptors() returns this very list
    private final List<ClientHttpRequestInterceptor> interceptors = new ArrayList<>();

    // Lazily created intercepting factory, cached; cleared by setRequestFactory()
    @Nullable
    private volatile ClientHttpRequestFactory interceptingRequestFactory;


    /**
     * Sets the ClientHttpRequestInterceptor interceptors that getRequestFactory() applies.
     * The internal list's contents are replaced and sorted with AnnotationAwareOrderComparator.
     */
    public void setInterceptors(List<ClientHttpRequestInterceptor> interceptors) {
        // Take getInterceptors() List as-is when passed in here
        if (this.interceptors != interceptors) {
            this.interceptors.clear();
            this.interceptors.addAll(interceptors);
            AnnotationAwareOrderComparator.sort(this.interceptors);
        }
    }

    /**
     * Returns the interceptors (the live, mutable internal list).
     */
    public List<ClientHttpRequestInterceptor> getInterceptors() {
        return this.interceptors;
    }

    /**
     * Sets the request factory and drops the cached intercepting factory.
     */
    @Override
    public void setRequestFactory(ClientHttpRequestFactory requestFactory) {
        super.setRequestFactory(requestFactory);
        this.interceptingRequestFactory = null;
    }

    /**
     * Returns the ClientHttpRequestFactory to use. When interceptors exist, it returns a cached
     * InterceptingClientHttpRequestFactory that wraps HttpAccessor's factory
     * (super.getRequestFactory()). Without interceptors, HttpAccessor's factory is returned unchanged.
     */
    @Override
    public ClientHttpRequestFactory getRequestFactory() {
        List<ClientHttpRequestInterceptor> interceptors = getInterceptors();
        if (!CollectionUtils.isEmpty(interceptors)) {
            ClientHttpRequestFactory factory = this.interceptingRequestFactory;
            if (factory == null) {
                // Wrap the parent's factory, passing the live interceptor list
                factory = new InterceptingClientHttpRequestFactory(super.getRequestFactory(), interceptors);
                this.interceptingRequestFactory = factory;
            }
            return factory;
        }
        else {
            return super.getRequestFactory();
        }
    }

}

InterceptingClientHttpRequestFactory

带拦截器的 HTTP Request Factory,创建包含拦截器的请求。

/**
 * HTTP request factory with interceptor support. Its requests run the registered interceptors
 * before the wrapped factory creates and executes the real request.
 */
public class InterceptingClientHttpRequestFactory extends AbstractClientHttpRequestFactoryWrapper {

    // Never null: a null constructor argument becomes an empty list
    private final List<ClientHttpRequestInterceptor> interceptors;

    /**
     * Creates an HTTP request factory with interceptors.
     *
     * @param requestFactory the factory to delegate to once all interceptors have run
     * @param interceptors   the interceptors to apply, or {@code null} for none
     */
    public InterceptingClientHttpRequestFactory(ClientHttpRequestFactory requestFactory,
            @Nullable List<ClientHttpRequestInterceptor> interceptors) {
        super(requestFactory);
        List<ClientHttpRequestInterceptor> effective = Collections.emptyList();
        if (interceptors != null) {
            effective = interceptors;
        }
        this.interceptors = effective;
    }

    /**
     * Creates an HTTP request that carries this factory's interceptors.
     */
    @Override
    protected ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod, ClientHttpRequestFactory requestFactory) {
        return new InterceptingClientHttpRequest(requestFactory, this.interceptors, uri, httpMethod);
    }
}

InterceptingClientHttpRequest

这是一个支持请求拦截器的 Request

/**
 * Request that supports request interceptors. The body is buffered (this class extends
 * AbstractBufferingClientHttpRequest) and handed to the interceptors as a byte array. Once all
 * interceptors have run, the real request is created through the wrapped request factory.
 */
class InterceptingClientHttpRequest extends AbstractBufferingClientHttpRequest {

    // Factory that creates the real (delegate) request after the interceptor chain
    private final ClientHttpRequestFactory requestFactory;
    // Interceptors, applied in list order
    private final List<ClientHttpRequestInterceptor> interceptors;

    private HttpMethod method;

    private URI uri;


    protected InterceptingClientHttpRequest(ClientHttpRequestFactory requestFactory,
            List<ClientHttpRequestInterceptor> interceptors, URI uri, HttpMethod method) {

        this.requestFactory = requestFactory;
        this.interceptors = interceptors;
        this.method = method;
        this.uri = uri;
    }


    @Override
    public HttpMethod getMethod() {
        return this.method;
    }

    @Override
    public String getMethodValue() {
        return this.method.name();
    }

    @Override
    public URI getURI() {
        return this.uri;
    }

    /**
     * Starts the interceptor chain with this request and the buffered body. A new
     * InterceptingRequestExecution, and thus a fresh iterator, is created per execution.
     */
    @Override
    protected final ClientHttpResponse executeInternal(HttpHeaders headers, byte[] bufferedOutput) throws IOException {
        // Inner class that holds the chain's iteration state
        InterceptingRequestExecution requestExecution = new InterceptingRequestExecution();
        // Run the chain
        return requestExecution.execute(this, bufferedOutput);
    }

    /**
     * Inner class whose execute() method drives the chain. Each call invokes the next interceptor;
     * once all interceptors have run, it creates and executes the real request.
     */
    private class InterceptingRequestExecution implements ClientHttpRequestExecution {
        // Iterator over the enclosing request's interceptors
        private final Iterator<ClientHttpRequestInterceptor> iterator;

        public InterceptingRequestExecution() {
            this.iterator = interceptors.iterator();
        }

        @Override
        public ClientHttpResponse execute(HttpRequest request, byte[] body) throws IOException {
            if (this.iterator.hasNext()) {
                // Invoke the next interceptor's intercept() method, passing this execution object.
                // When the interceptor calls execution.execute(...), control re-enters this method
                // and moves on to the following interceptor, until all have run and the else branch
                // executes. An interceptor that does not call it ends the chain with its own response.
                ClientHttpRequestInterceptor nextInterceptor = this.iterator.next();
                return nextInterceptor.intercept(request, body, this);
            }
            else {
                // All interceptors done. Method and URI come from the request passed in, so an
                // interceptor that wraps the request (e.g. overriding getURI()) changes the target
                HttpMethod method = request.getMethod();
                Assert.state(method != null, "No standard HTTP method");
                // Create the real request
                ClientHttpRequest delegate = requestFactory.createRequest(request.getURI(), method);
                // Copy all headers, including those added by interceptors
                request.getHeaders().forEach((key, value) -> delegate.getHeaders().addAll(key, value));
                if (body.length > 0) {
                    // Write the body: via a setBody() callback for streaming requests,
                    // otherwise by copying into getBody()
                    if (delegate instanceof StreamingHttpOutputMessage) {
                        StreamingHttpOutputMessage streamingOutputMessage = (StreamingHttpOutputMessage) delegate;
                        streamingOutputMessage.setBody(outputStream -> StreamUtils.copy(body, outputStream));
                    }
                    else {
                        StreamUtils.copy(body, delegate.getBody());
                    }
                }
                // Execute the real request
                return delegate.execute();
            }
        }
    }
}

示例:Spring 提供的 BasicAuthenticationInterceptor 实现认证拦截,BasicAuthenticationInterceptor 实现 ClientHttpRequestInterceptor 接口,重写 intercept() 方法。

/**
 * Interceptor that adds an HTTP Basic Authorization header to each request, unless the
 * request already carries an Authorization header.
 */
public class BasicAuthenticationInterceptor implements ClientHttpRequestInterceptor {

    private final String username;

    private final String password;

    // Charset passed to HttpHeaders#setBasicAuth; may be null
    @Nullable
    private final Charset charset;

    /**
     * Creates the interceptor without an explicit charset.
     */
    public BasicAuthenticationInterceptor(String username, String password) {
        this(username, password, null);
    }

    /**
     * Creates the interceptor. A username containing ':' is rejected via Assert.doesNotContain.
     */
    public BasicAuthenticationInterceptor(String username, String password, @Nullable Charset charset) {
        Assert.doesNotContain(username, ":", "Username must not contain a colon");
        this.username = username;
        this.password = password;
        this.charset = charset;
    }

    /**
     * Adds the Basic credentials when absent, then continues the chain.
     */
    @Override
    public ClientHttpResponse intercept(
            HttpRequest request, byte[] body, ClientHttpRequestExecution execution) throws IOException {
        HttpHeaders requestHeaders = request.getHeaders();
        boolean alreadyAuthorized = requestHeaders.containsKey(HttpHeaders.AUTHORIZATION);
        if (!alreadyAuthorized) {
            requestHeaders.setBasicAuth(this.username, this.password, this.charset);
        }
        // Hand over to the next interceptor, or to the real request once the chain is exhausted
        return execution.execute(request, body);
    }
}

RestTemplateBuilder

RestTemplateBuilder 提供了可传入请求拦截器的 interceptors() 方法,在调用 build() 方法创建 RestTemplate 实例并执行配置时,会将请求拦截器设置到 restTemplate 的 interceptors 属性中(实际是父类 InterceptingHttpAccessor 中的属性)。

/**
 * Excerpt of RestTemplateBuilder. Each configuration method returns a new builder, and
 * build() creates the RestTemplate and applies the collected settings.
 */
public class RestTemplateBuilder {
    //..... other fields omitted ......

    // Interceptors added to the RestTemplate in configure()
    private final Set<ClientHttpRequestInterceptor> interceptors;

    //..... constructors omitted ......

    /**
     * Sets the message converters; this is how custom message converters are plugged in.
     */
    public RestTemplateBuilder messageConverters(
            HttpMessageConverter<?>... messageConverters) {
        Assert.notNull(messageConverters, "MessageConverters must not be null");
        return messageConverters(Arrays.asList(messageConverters));
    }

    /**
     * Uses the default message converters, i.e. those registered by a new RestTemplate().
     */
    public RestTemplateBuilder defaultMessageConverters() {
        return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri,
                Collections.unmodifiableSet(
                        new LinkedHashSet<>(new RestTemplate().getMessageConverters())),
                this.requestFactorySupplier, this.uriTemplateHandler, this.errorHandler,
                this.basicAuthentication, this.restTemplateCustomizers,
                this.requestFactoryCustomizer, this.interceptors);
    }

    /**
     * Sets the interceptors, replacing all previously configured ones.
     */
    public RestTemplateBuilder interceptors(
            ClientHttpRequestInterceptor... interceptors) {
        Assert.notNull(interceptors, "interceptors must not be null");
        return interceptors(Arrays.asList(interceptors));
    }

    /**
     * Collection variant, called by the varargs overload above. It keeps insertion order
     * (LinkedHashSet) and drops duplicates.
     */
    public RestTemplateBuilder interceptors(
            Collection<ClientHttpRequestInterceptor> interceptors) {
        Assert.notNull(interceptors, "interceptors must not be null");
        return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri,
                this.messageConverters, this.requestFactorySupplier,
                this.uriTemplateHandler, this.errorHandler, this.basicAuthentication,
                this.restTemplateCustomizers, this.requestFactoryCustomizer,
                Collections.unmodifiableSet(new LinkedHashSet<>(interceptors)));
    }

    /**
     * Builds a RestTemplate of the given type and configures it.
     */
    public <T extends RestTemplate> T build(Class<T> restTemplateClass) {
        // Instantiate the RestTemplate reflectively via BeanUtils
        return configure(BeanUtils.instantiateClass(restTemplateClass));
    }

    /**
     * Applies this builder's settings; called by build(Class) above.
     */
    public <T extends RestTemplate> T configure(T restTemplate) {
        // Configure the request factory
        configureRequestFactory(restTemplate);
        if (!CollectionUtils.isEmpty(this.messageConverters)) {
            restTemplate.setMessageConverters(new ArrayList<>(this.messageConverters));
        }
        if (this.uriTemplateHandler != null) {
            restTemplate.setUriTemplateHandler(this.uriTemplateHandler);
        }
        if (this.errorHandler != null) {
            restTemplate.setErrorHandler(this.errorHandler);
        }
        if (this.rootUri != null) {
            RootUriTemplateHandler.addTo(restTemplate, this.rootUri);
        }
        if (this.basicAuthentication != null) {
            restTemplate.getInterceptors().add(this.basicAuthentication);
        }
        // Add the request interceptors to the template's (live) interceptor list
        restTemplate.getInterceptors().addAll(this.interceptors);
        if (!CollectionUtils.isEmpty(this.restTemplateCustomizers)) {
            for (RestTemplateCustomizer customizer : this.restTemplateCustomizers) {
                customizer.customize(restTemplate);
            }
        }
        return restTemplate;
    }
}
  1. 使用通过 RestTemplateBuilder 创建的 RestTemplate 实例执行 HTTP 请求时,需要先获取请求工厂。

  2. 在创建请求工厂时会判断是否存在拦截器,如果存在,则创建 InterceptingClientHttpRequestFactory 类型的请求工厂并传入拦截器。

  3. 使用 InterceptingClientHttpRequestFactory 创建 Request 请求时,会将传入的拦截器加入请求实例,在执行 execute() 方法时依次执行各个拦截器。

    备注:后面两步的处理和 RestTemplate 传入拦截器调的是同样的代码逻辑。

自定义 RestTemplate

定义RestTemplate

  1. RestTemplateBuilder

    如果需要调用远程服务,可以使用 Spring Framework 提供的 RestTemplate。RestTemplate 在使用前通常需要自定义,Spring Boot 没有提供自动配置 RestTemplate bean,但自动配置了RestTemplateBuilder bean,用于构建 RestTemplate 实例。自动配置的 RestTemplateBuilder 会确保将合理的 HttpMessageConverters 应用到 RestTemplate 实例中。

    根据上面思路,创建 RestTemplate 对象应该是如下操作:

    // Option 1: pass a different ClientHttpRequestFactory (here Apache HttpComponents)
    RestTemplate template = new RestTemplate(new HttpComponentsClientHttpRequestFactory());
    // Option 2: configure timeouts through RestTemplateBuilder
    // (a distinct variable name, so both options also compile within the same scope)
    RestTemplate builtTemplate = new RestTemplateBuilder()
            .setConnectTimeout(Duration.ofMillis(1000))
            .setReadTimeout(Duration.ofMillis(1000))
            .build();
    

    RestTemplateBuilder 提供了很多有用的方法,可以快速地配置 RestTemplate。例如,添加 BASIC auth 认证支持,可以使用如下方式:

    // Basic auth credentials are registered as an interceptor when the builder configures the template
    new RestTemplateBuilder()
            .basicAuthentication("user", "123456")
            .build();
    

    另一种创建 RestTemplate 实例的方式是直接 new 这个对象,如下所示。

  2. RestTemplate

    /**
     * RestTemplate bean using the custom request factory and the custom request interceptor,
     * both configured from {@code properties}.
     */
    @Bean
    public RestTemplate restTemplate(DiscoveryClient discoveryClient, CloudTemplateProperties properties) {
        // Custom HTTP request factory
        CloudClientHttpRequestFactory cloudRequestFactory = new CloudClientHttpRequestFactory(properties);
        RestTemplate restTemplate = new RestTemplate(cloudRequestFactory);

        // Custom request interceptor; setInterceptors copies the list into the template's own list
        ArrayList<ClientHttpRequestInterceptor> interceptorList = new ArrayList<>(1);
        interceptorList.add(new CloudHttpRequestInterceptor(discoveryClient, properties));
        restTemplate.setInterceptors(interceptorList);
        return restTemplate;
    }
    

定义Http Request

  1. 默认的 ClientHttpRequestFactory

    RestTemplate 默认使用的是 java.net.HttpURLConnection 来执行请求,可以切换成不同的实现了ClientHttpRequestFactory 接口的 HTTP 库,如:Apache HttpComponents,Netty,OkHttp。

    示例:使用 Apache HttpComponents

    // Replace the default HttpURLConnection-based factory with Apache HttpComponents
    HttpComponentsClientHttpRequestFactory httpComponentsFactory = new HttpComponentsClientHttpRequestFactory();
    RestTemplate template = new RestTemplate(httpComponentsFactory);
    
  2. 自定义请求工厂示例

    根据业务需要,继承 SimpleClientHttpRequestFactory 或实现 ClientHttpRequestFactory 接口,重写 createRequest() 和 prepareConnection() 方法。

    /**
     * HttpURLConnection-based request factory. It applies the timeouts and default headers from
     * {@link CloudTemplateProperties} and propagates the caller's credentials (X-Auth-Token or
     * Cookie) to the outgoing request.
     */
    public class CloudClientHttpRequestFactory extends SimpleClientHttpRequestFactory {
    
        private static final Logger logger = LoggerFactory.getLogger(CloudClientHttpRequestFactory.class);
    
        // Template settings; may be null, in which case only the superclass defaults apply
        private final CloudTemplateProperties propt;
    
        /**
         * @param tmpProp template settings, may be null. Only strictly positive timeouts are
         *                stored; SimpleClientHttpRequestFactory#prepareConnection then applies them.
         */
        public CloudClientHttpRequestFactory(CloudTemplateProperties tmpProp) {
            this.propt = tmpProp;
            if (tmpProp != null) {
                logger.info("ConnectTimeout = {},ReadTimeout={}", tmpProp.getConnectTimeout(), tmpProp.getReadTimeout());
                if (tmpProp.getConnectTimeout() > 0) {
                    logger.info("ConnectTimeout setted");
                    super.setConnectTimeout(tmpProp.getConnectTimeout());
                }
                if (tmpProp.getReadTimeout() > 0) {
                    logger.info("ReadTimeout setted");
                    super.setReadTimeout(tmpProp.getReadTimeout());
                }
            }
        }
    
        /**
         * Re-parses the URI from its US-ASCII form (URI#toASCIIString) before delegating to the
         * superclass. If re-parsing fails, the original URI is used.
         */
        @Override
        public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException {
            URI asciiUri = uri;
            try {
                asciiUri = new URI(uri.toASCIIString());
            } catch (URISyntaxException ex) {
                // Best effort: log and fall back to the original URI
                logger.error("system error: ", ex);
            }
            return super.createRequest(asciiUri, httpMethod);
        }
    
        /**
         * Applies the superclass settings (method, configured timeouts, redirect/output flags),
         * then the default headers from the properties, then the caller's credentials.
         */
        @Override
        protected void prepareConnection(HttpURLConnection conn, String method) throws IOException {
            // Timeouts are applied here by the superclass, using the values the constructor
            // accepted (> 0). Re-applying the raw property values would turn 0 into an infinite
            // timeout and make negative values throw IllegalArgumentException; with null
            // properties it would throw a NullPointerException on every request.
            super.prepareConnection(conn, method);
            if (propt != null) {
                applyDefaultHeaders(conn, method);
            }
            propagateCredentials(conn);
        }
    
        // Sets User-Agent, Accept and (POST only) Content-Type when configured in the properties
        private void applyDefaultHeaders(HttpURLConnection conn, String method) {
            if (StringUtils.isNotEmpty(propt.getUserAgent())) {
                conn.setRequestProperty(HttpHeaders.USER_AGENT, propt.getUserAgent());
            }
            if (StringUtils.isNotEmpty(propt.getAccept())) {
                conn.setRequestProperty(HttpHeaders.ACCEPT, propt.getAccept());
            }
            if (StringUtils.isNotEmpty(propt.getContentType()) && HttpMethod.POST.matches(method)) {
                conn.setRequestProperty(HttpHeaders.CONTENT_TYPE, propt.getContentType());
            }
        }
    
        /*
         * Copies credentials onto the connection, first match wins:
         * 1. TemplateXAuthHolder value: "notoken" sends nothing, any other non-null value is sent as X-Auth-Token;
         * 2. the X-Auth-Token header of UserContext.getHttpRequest() (presumably the current inbound request);
         * 3. the Cookie header of that request.
         * NOTE(review): credentials are forwarded to whatever URL is being called; make sure only
         * trusted internal services are reachable through this template.
         */
        private void propagateCredentials(HttpURLConnection conn) {
            Object xAuth = TemplateXAuthHolder.getXAuth();
            if ("notoken".equals(xAuth)) {
                return;
            }
            if (null != xAuth) {
                conn.setRequestProperty("X-Auth-Token", xAuth.toString());
                logger.info(" - request url[X-Auth-Token] from app: {}", conn.getURL());
                return;
            }
    
            HttpServletRequest req = UserContext.getHttpRequest();
            if (null == req) {
                return;
            }
            String xAuthToken = req.getHeader("X-Auth-Token");
            if (StringUtils.isNotEmpty(xAuthToken)) {
                conn.setRequestProperty("X-Auth-Token", xAuthToken);
                logger.info(" - request url[X-Auth-Token]: {}", conn.getURL());
                return;
            }
            String cookie = req.getHeader("Cookie");
            if (StringUtils.isNotEmpty(cookie)) {
                conn.setRequestProperty("Cookie", cookie);
                logger.debug(" - request url[Cookie]: {}", conn.getURL());
            }
        }
    }
    

定义Request Interceptor

自定义请求拦截器需要实现 ClientHttpRequestInterceptor 接口中的 intercept() 方法。这块在客户端侧的负载均衡会用到,如 Ribbon。

/**
 * Client-side request interceptor that rewrites the target host of each request and retries failed calls.
 *
 * <p>The host part of the request URI is treated as a logical service name and resolved in this order:
 * <ol>
 *   <li>a statically configured {@code host[:port]} entry from the properties' services map;</li>
 *   <li>an instance from the {@link DiscoveryClient}, chosen round-robin (when a client is present);</li>
 *   <li>otherwise the original URI is used unchanged.</li>
 * </ol>
 * Scheme, user info, path, query and fragment are always preserved.
 */
public class CloudHttpRequestInterceptor implements ClientHttpRequestInterceptor {

    private static final Logger logger = LoggerFactory.getLogger(CloudHttpRequestInterceptor.class);

    /** Port used for a static service entry that has no ":port" suffix. */
    private static final int DEFAULT_PORT = 8080;

    /** Round-robin cursor shared by every request passing through this interceptor. */
    private final AtomicInteger nextServerCyclicCounter = new AtomicInteger(0);

    /** Registry client; may be {@code null}, in which case only static entries are consulted. */
    private final DiscoveryClient discoveryClient;

    /** Configured attempts per request; values below 1 are treated as 1. */
    private final int maxAttempts;

    /** Static overrides: logical host -> {host (String), port (Integer)}. Populated only in the constructor. */
    private final Map<String, Object[]> services = new HashMap<String, Object[]>();

    /**
     * @param discoveryClient registry client used to look up service instances; may be {@code null}
     * @param propt           supplies the max attempt count and optional {@code name -> "host[:port]"} entries
     * @throws NumberFormatException if a configured port is not an integer
     */
    public CloudHttpRequestInterceptor(DiscoveryClient discoveryClient, CloudTemplateProperties propt) {
        this.discoveryClient = discoveryClient;
        this.maxAttempts = propt.getMaxAttempts();
        Map<String, String> svrs = propt.getServices();
        if (null != svrs) {
            for (Entry<String, String> item : svrs.entrySet()) {
                String[] service = item.getValue().split(":");  // host:port
                int port = service.length > 1 ? Integer.parseInt(service[1]) : DEFAULT_PORT;
                services.put(item.getKey(), new Object[] { service[0], port });
            }
        }
    }

    /**
     * Executes the request against the resolved host, making up to {@code maxAttempts} attempts.
     *
     * <p>A 2xx response is returned immediately. A non-2xx response triggers another attempt while
     * attempts remain; the discarded response is closed first so its connection is released. The last
     * attempt's response is returned whatever its status. The thread-bound value in
     * {@code TemplateXAuthHolder} is cleared on every exit path, including exceptions, so it cannot leak
     * into a later request on the same thread.
     *
     * @throws IOException if the final attempt fails with an I/O error (earlier failure is the cause)
     */
    @Override
    public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) throws IOException {
        HttpRequestWrapper requestWrapper = new HttpRequestWrapper(request) {
            @Override
            public URI getURI() {
                return resolveUri(request.getURI());
            }
        };

        // Always try at least once, even if maxAttempts is misconfigured (<= 0).
        int attempts = Math.max(1, maxAttempts);
        try {
            for (int attempt = 1; ; attempt++) {
                ClientHttpResponse resp = null;
                try {
                    resp = execution.execute(requestWrapper, body);
                    if (resp.getStatusCode().is2xxSuccessful() || attempt >= attempts) {
                        return resp;
                    }
                } catch (IOException ex) {
                    if (resp != null) {
                        resp.close();
                    }
                    logger.error(" Http Request error: ", ex);
                    if (attempt >= attempts) {
                        throw new IOException(attempts + "次请求服务失败", ex);
                    }
                    continue;
                }
                // Non-2xx with attempts remaining: release this response's connection before retrying.
                resp.close();
            }
        } finally {
            TemplateXAuthHolder.remove();
        }
    }

    /**
     * Maps a logical-service URI to a concrete one. Uses the static entry first, then a discovery-client
     * instance, otherwise returns the URI unchanged. Also returns the original URI if rebuilding fails.
     */
    private URI resolveUri(URI uri) {
        String host = uri.getHost();
        try {
            Object[] service = services.get(host);
            if (service != null) {
                return new URI(uri.getScheme(), uri.getUserInfo(), (String) service[0], (int) service[1],
                        uri.getPath(), uri.getQuery(), uri.getFragment());
            }
            if (null == discoveryClient) {
                return uri;
            }
            ServiceInstance instance = getInstances(host);
            if (null == instance) {
                // No registered instance: fall back to the original URL.
                return uri;
            }
            // Route the call to the instance obtained from the service registry.
            return new URI(uri.getScheme(), uri.getUserInfo(), instance.getHost(), instance.getPort(),
                    uri.getPath(), uri.getQuery(), uri.getFragment());
        } catch (URISyntaxException e) {
            logger.error("URISyntaxException", e);
            return uri;
        }
    }

    /**
     * Looks up the instances registered under a service name and picks one round-robin.
     *
     * @param serviceId logical service name (the request URI's host)
     * @return the selected instance, or {@code null} if the registry returned none
     */
    private ServiceInstance getInstances(String serviceId) {
        List<ServiceInstance> instances = discoveryClient.getInstances(serviceId);
        if (CollectionUtils.isEmpty(instances)) {
            logger.warn("注册中心获取服务为空! 服务名:{}", serviceId);
            return null;
        }
        logger.info("注册中心获取服务名称:{},在线服务数量:{}", serviceId, instances.size());
        int nextServerIndex = incrementAndGetModulo(instances.size());
        return instances.get(nextServerIndex);
    }

    /**
     * Round-robin strategy: atomically advances the shared cursor and returns its new value modulo
     * {@code modulo}.
     *
     * <p>The loop retries on a failed compare-and-set. Without it, concurrent callers would fall back to
     * index 0 and skew load toward the first instance.
     *
     * @param modulo number of candidate instances; must be positive
     * @return an index in {@code [0, modulo)}
     */
    private int incrementAndGetModulo(int modulo) {
        for (;;) {
            int current = nextServerCyclicCounter.get();
            int next = (current + 1) % modulo;
            if (nextServerCyclicCounter.compareAndSet(current, next)) {
                return next;
            }
        }
    }
}
更多内容请访问:IT源点

全部评论: 0

    我有话说: