上篇回顾

我们已经知道了Spring Security的核心过滤器的创建和原理,本文主要介绍核心过滤器FilterChainProxy是如何在tomcat的ServletContext中生效的。

ServletContext如何拿到FilterChainProxy的过滤器对象

我们都知道,Bean都是存在Spring的Bean工厂里的,
而且在Web项目中,Servlet、Filter、Listener都要放入ServletContext中。

看下面这张图,ServletContainerInitializer接口提供了一个onStartup()方法,用于在Servlet容器启动时动态注册一些对象到ServletContext中。

ServletContainerInitializer

官方的解释是:为了支持不使用web.xml的部署方式,Servlet规范提供了ServletContainerInitializer。它基于SPI机制:Web容器启动时,会自动在各个jar包的META-INF/services目录下查找以ServletContainerInitializer的全限定名命名的文件,文件内容是ServletContainerInitializer实现类的全限定名,容器会将这些实现类实例化。

通过下图可知,Spring框架通过META-INF配置了SpringServletContainerInitializer

META-INF配置

SpringServletContainerInitializer实现了ServletContainerInitializer接口。

请注意该类上的@HandlesTypes(WebApplicationInitializer.class)注解.

根据Servlet 3.0规范,Servlet容器在调用onStartup()方法时,会以Set集合的方式传入WebApplicationInitializer的子类型(包括接口、抽象类)。然后会依次调用WebApplicationInitializer实现类的onStartup()方法,从而起到与web.xml相同的作用(添加Servlet、Listener实例到ServletContext中)。

SpringServletContainerInitializer

Spring Security中的AbstractSecurityWebApplicationInitializer就是WebApplicationInitializer的抽象子类.

当执行到下面的onStartup()方法时,会调用insertSpringSecurityFilterChain()

将类型为FilterChainProxy名称为springSecurityFilterChain的过滤器对象用DelegatingFilterProxy包装,然后注入ServletContext

AbstractSecurityWebApplicationInitializer

运行过程

请求到达的时候,在FilterChainProxy的doFilter()方法内部,会依次遍历所有的SecurityFilterChain,找到第一个与当前请求匹配的SecurityFilterChain,然后逐一调用该SecurityFilterChain中的Filter做认证或授权。

/**
 * Servlet filter that hands each request to the filters of the first
 * {@link SecurityFilterChain} whose {@code matches} call accepts the request,
 * and then continues with the original container filter chain.
 *
 * <p>Requests and responses are passed through the configured
 * {@link HttpFirewall} before any chain lookup takes place.
 */
public class FilterChainProxy extends GenericFilterBean {

    /** Request attribute set while this proxy is handling a request. */
    private static final String FILTER_APPLIED = FilterChainProxy.class.getName().concat(".APPLIED");

    /** Candidate chains, consulted in list order; the first match wins. */
    private List<SecurityFilterChain> filterChains;
    private FilterChainValidator filterChainValidator = new NullFilterChainValidator();
    private HttpFirewall firewall = new StrictHttpFirewall();

    /**
     * Creates a proxy without any filter chains.
     *
     * <p>NOTE(review): nothing in this class assigns {@code filterChains}
     * afterwards, so an instance built this way throws
     * {@link NullPointerException} from the filter look-ups and from
     * {@link #getFilterChains()}. Prefer one of the other constructors.
     */
    public FilterChainProxy() {
    }

    /**
     * Creates a proxy backed by a single chain.
     *
     * @param chain the only chain to consult
     */
    public FilterChainProxy(SecurityFilterChain chain) {
        this(Arrays.asList(chain));
    }

    /**
     * Creates a proxy backed by the given chains.
     *
     * @param filterChains chains to consult, in order; the list is stored as-is (not copied)
     */
    public FilterChainProxy(List<SecurityFilterChain> filterChains) {
        this.filterChains = filterChains;
    }

    /** Runs the configured {@link FilterChainValidator} against this proxy. */
    @Override
    public void afterPropertiesSet() {
        filterChainValidator.validate(this);
    }

    /**
     * Entry point called by the container for every request.
     *
     * <p>The first invocation for a request marks it with {@code FILTER_APPLIED}
     * and, once processing ends (normally or exceptionally), clears the
     * {@code SecurityContextHolder} and removes the marker. An invocation that
     * finds the marker already present only delegates and leaves the clean-up
     * to the invocation that set the marker.
     *
     * @param request  the incoming request
     * @param response the outgoing response
     * @param chain    the container's filter chain to continue with afterwards
     * @throws IOException      if a delegated filter reports an I/O failure
     * @throws ServletException if a delegated filter reports a servlet failure
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response,
            FilterChain chain) throws IOException, ServletException {
        boolean clearContext = request.getAttribute(FILTER_APPLIED) == null;
        if (clearContext) {
            try {
                request.setAttribute(FILTER_APPLIED, Boolean.TRUE);
                doFilterInternal(request, response, chain);
            }
            finally {
                SecurityContextHolder.clearContext();
                request.removeAttribute(FILTER_APPLIED);
            }
        }
        else {
            doFilterInternal(request, response, chain);
        }
    }

    /**
     * Applies the firewall, looks up the filters for the request and runs them.
     *
     * <p>When no chain matches, or the matching chain has no filters, the
     * firewalled request is reset and handed straight to the original chain.
     */
    private void doFilterInternal(ServletRequest request, ServletResponse response,
            FilterChain chain) throws IOException, ServletException {

        FirewalledRequest fwRequest = firewall.getFirewalledRequest((HttpServletRequest) request);
        HttpServletResponse fwResponse = firewall.getFirewalledResponse((HttpServletResponse) response);

        // Resolve the list of filters that applies to the current request.
        List<Filter> filters = getFilters(fwRequest);

        if (filters == null || filters.isEmpty()) {
            if (logger.isDebugEnabled()) {
                logger.debug(UrlUtils.buildRequestUrl(fwRequest)
                        + (filters == null ? " has no matching filters": " has an empty filter list"));
            }
            fwRequest.reset();
            chain.doFilter(fwRequest, fwResponse);
            return;
        }

        VirtualFilterChain vfc = new VirtualFilterChain(fwRequest, chain, filters);
        // Send the request through each of the resolved filters in turn.
        vfc.doFilter(fwRequest, fwResponse);
    }

    /**
     * Returns the filters of the first chain that matches the request.
     *
     * @param request the request to match against the configured chains
     * @return the filters of the first matching chain, or {@code null} when no chain matches
     */
    private List<Filter> getFilters(HttpServletRequest request) {
        for (SecurityFilterChain chain : filterChains) {
            if (chain.matches(request)) {
                return chain.getFilters();
            }
        }
        return null;
    }

    /**
     * Returns the filters that would apply to the given URL.
     *
     * <p>A request is built from the URL via {@code FilterInvocation}, passed
     * through the firewall, and then matched like a regular request.
     *
     * @param url the URL to look up
     * @return the filters of the first matching chain, or {@code null} when no chain matches
     */
    public List<Filter> getFilters(String url) {
        return getFilters(firewall.getFirewalledRequest((new FilterInvocation(url, null)
                .getRequest())));
    }

    /**
     * Returns all configured filter chains.
     *
     * @return an unmodifiable view of the configured chains
     */
    public List<SecurityFilterChain> getFilterChains() {
        return Collections.unmodifiableList(filterChains);
    }

    /**
     * Filter chain that walks the additional security filters first and then
     * continues with the original container chain.
     *
     * <p>Deliberately a non-static inner class: this class declares no
     * {@code logger} of its own, so the {@code logger} used below is the
     * instance member inherited by the enclosing proxy, which a static nested
     * class cannot reference.
     */
    private class VirtualFilterChain implements FilterChain {
        private final FilterChain originalChain;
        private final List<Filter> additionalFilters;
        private final FirewalledRequest firewalledRequest;
        private final int size;
        /** Number of additional filters that have been started so far. */
        private int currentPosition = 0;

        private VirtualFilterChain(FirewalledRequest firewalledRequest,
                FilterChain chain, List<Filter> additionalFilters) {
            this.originalChain = chain;
            this.additionalFilters = additionalFilters;
            this.size = additionalFilters.size();
            this.firewalledRequest = firewalledRequest;
        }

        /**
         * Invokes the next additional filter, or the original chain once all
         * additional filters have been started.
         *
         * <p>Each additional filter receives this object as its chain, so a
         * filter that calls {@code chain.doFilter} re-enters this method and
         * advances to the next position.
         */
        @Override
        public void doFilter(ServletRequest request, ServletResponse response)
                throws IOException, ServletException {
            if (currentPosition == size) {
                if (logger.isDebugEnabled()) {
                    logger.debug(UrlUtils.buildRequestUrl(firewalledRequest)
                            + " reached end of additional filter chain; proceeding with original chain");
                }

                // Deactivate path stripping as we exit the security filter chain
                this.firewalledRequest.reset();

                originalChain.doFilter(request, response);
            }
            else {
                // Advance before invoking, so a re-entrant call sees the next position.
                currentPosition++;

                Filter nextFilter = additionalFilters.get(currentPosition - 1);

                if (logger.isDebugEnabled()) {
                    logger.debug(UrlUtils.buildRequestUrl(firewalledRequest)
                            + " at position " + currentPosition + " of " + size
                            + " in additional filter chain; firing Filter: '"
                            + nextFilter.getClass().getSimpleName() + "'");
                }

                nextFilter.doFilter(request, response, this);
            }
        }
    }

    /** Strategy for validating a proxy's configuration; invoked from {@link #afterPropertiesSet()}. */
    public interface FilterChainValidator {
        /**
         * Validates the given proxy.
         *
         * @param filterChainProxy the proxy to validate
         */
        void validate(FilterChainProxy filterChainProxy);
    }

    /** Default validator that performs no checks. */
    private static class NullFilterChainValidator implements FilterChainValidator {
        @Override
        public void validate(FilterChainProxy filterChainProxy) {
        }
    }

}