Commit 359bf2b9 authored by Sonia Zorba's avatar Sonia Zorba
Browse files

Added BasicAuth/JWT service endpoints

parent fa4bf6f1
Loading
Loading
Loading
Loading
+2 −2
Original line number Original line Diff line number Diff line
@@ -13,8 +13,8 @@ public class CustomIdTokenConverter extends DefaultUserAuthenticationConverter {


    private final JwkTokenStore jwkTokenStore;
    private final JwkTokenStore jwkTokenStore;


    public CustomIdTokenConverter(String keySetUri) {
    public CustomIdTokenConverter(JwkTokenStore jwkTokenStore) {
        this.jwkTokenStore = new JwkTokenStore(keySetUri);
        this.jwkTokenStore = jwkTokenStore;
    }
    }


    @Override
    @Override
+73 −0
Original line number Original line Diff line number Diff line
package it.inaf.ia2.gms.authn;

import java.io.IOException;
import java.security.Principal;
import java.util.Map;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.provider.token.store.jwk.JwkTokenStore;

public class JWTFilter implements Filter {

    private final JwkTokenStore jwkTokenStore;

    public JWTFilter(JwkTokenStore jwkTokenStore) {
        this.jwkTokenStore = jwkTokenStore;
    }

    @Override
    public void doFilter(ServletRequest req, ServletResponse res, FilterChain fc) throws IOException, ServletException {

        HttpServletRequest request = (HttpServletRequest) req;
        HttpServletResponse response = (HttpServletResponse) res;

        String authHeader = request.getHeader("Authorization");
        if (authHeader == null) {
            response.sendError(HttpServletResponse.SC_UNAUTHORIZED, "Missing Authorization token");
            return;
        }

        authHeader = authHeader.replace("Bearer", "").trim();

        OAuth2AccessToken accessToken = jwkTokenStore.readAccessToken(authHeader);
        if (accessToken.isExpired()) {
            response.sendError(HttpServletResponse.SC_UNAUTHORIZED, "Access token is expired");
            return;
        }

        Map<String, Object> claims = accessToken.getAdditionalInformation();

        String principal = (String) claims.get("sub");
        if (principal == null) {
            response.sendError(HttpServletResponse.SC_UNAUTHORIZED, "Invalid access token: missing sub claim");
            return;
        }

        ServletRequest wrappedRequest = new ServletRequestWithJWTPrincipal(request, principal);

        fc.doFilter(wrappedRequest, res);
    }

    private static class ServletRequestWithJWTPrincipal extends HttpServletRequestWrapper {

        private final String principal;

        public ServletRequestWithJWTPrincipal(HttpServletRequest request, String principal) {
            super(request);
            this.principal = principal;
        }

        @Override
        public Principal getUserPrincipal() {
            return new UsernamePasswordAuthenticationToken(principal, null);
        }
    }
}
+3 −5
Original line number Original line Diff line number Diff line
@@ -15,6 +15,7 @@ import org.springframework.security.oauth2.provider.ClientDetailsService;
import org.springframework.security.oauth2.provider.client.InMemoryClientDetailsService;
import org.springframework.security.oauth2.provider.client.InMemoryClientDetailsService;
import org.springframework.security.oauth2.provider.token.DefaultAccessTokenConverter;
import org.springframework.security.oauth2.provider.token.DefaultAccessTokenConverter;
import org.springframework.security.oauth2.provider.token.RemoteTokenServices;
import org.springframework.security.oauth2.provider.token.RemoteTokenServices;
import org.springframework.security.oauth2.provider.token.store.jwk.JwkTokenStore;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.client.RestTemplate;


/**
/**
@@ -24,9 +25,6 @@ import org.springframework.web.client.RestTemplate;
@Configuration
@Configuration
public class OAuth2Config extends AuthorizationServerEndpointsConfiguration {
public class OAuth2Config extends AuthorizationServerEndpointsConfiguration {


    @Value("${security.oauth2.resource.jwk.key-set-uri}")
    private String keySetUri;

    @Value("${security.oauth2.resource.token-info-uri}")
    @Value("${security.oauth2.resource.token-info-uri}")
    private String checkTokenEndpointUrl;
    private String checkTokenEndpointUrl;


@@ -37,11 +35,11 @@ public class OAuth2Config extends AuthorizationServerEndpointsConfiguration {
    private String clientSecret;
    private String clientSecret;


    @Bean
    @Bean
    public RemoteTokenServices resourceServerTokenServices() {
    public RemoteTokenServices resourceServerTokenServices(JwkTokenStore jwkTokenStore) {
        RemoteTokenServices tokenService = new RemoteTokenServices();
        RemoteTokenServices tokenService = new RemoteTokenServices();


        DefaultAccessTokenConverter accessTokenConverter = new DefaultAccessTokenConverter();
        DefaultAccessTokenConverter accessTokenConverter = new DefaultAccessTokenConverter();
        accessTokenConverter.setUserTokenConverter(new CustomIdTokenConverter(keySetUri));
        accessTokenConverter.setUserTokenConverter(new CustomIdTokenConverter(jwkTokenStore));
        tokenService.setAccessTokenConverter(accessTokenConverter);
        tokenService.setAccessTokenConverter(accessTokenConverter);


        tokenService.setCheckTokenEndpointUrl(checkTokenEndpointUrl);
        tokenService.setCheckTokenEndpointUrl(checkTokenEndpointUrl);
+29 −7
Original line number Original line Diff line number Diff line
@@ -14,6 +14,7 @@ import org.springframework.http.HttpMethod;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.builders.WebSecurity;
import org.springframework.security.config.annotation.web.builders.WebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.oauth2.provider.token.store.jwk.JwkTokenStore;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.UrlBasedCorsConfigurationSource;
import org.springframework.web.cors.UrlBasedCorsConfigurationSource;
import org.springframework.web.filter.CorsFilter;
import org.springframework.web.filter.CorsFilter;
@@ -28,6 +29,14 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter {
    @Value("${cors.allowed.origin}")
    @Value("${cors.allowed.origin}")
    private String corsAllowedOrigin;
    private String corsAllowedOrigin;


    @Value("${security.oauth2.resource.jwk.key-set-uri}")
    private String keySetUri;

    @Bean
    public JwkTokenStore jwkTokenStore() {
        return new JwkTokenStore(keySetUri);
    }

    @Override
    @Override
    public void configure(HttpSecurity http) throws Exception {
    public void configure(HttpSecurity http) throws Exception {


@@ -43,23 +52,36 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter {
    }
    }


    /**
    /**
     * The authentication is ignored for these endpoints. The "/ws" endpoints
     * The authentication is ignored for these endpoints. The "/ws/basic"
     * (web service API for programmatic access) are protected by the custom
     * endpoints (web service API for programmatic access) are protected by the
     * WebServiceAuthorizationFilter that checks BasicAuth for GMS clients.
     * custom ServiceBasicAuthFilter that checks BasicAuth for GMS clients,
     * while the "/ws/jwt" endpoints are protected by the JWTFilter.
     */
     */
    @Override
    @Override
    public void configure(WebSecurity web) throws Exception {
    public void configure(WebSecurity web) throws Exception {
        web.ignoring().antMatchers("/ws/**", "/error");
        web.ignoring().antMatchers("/ws/basic/**", "/ws/jwt/**", "/error");
    }
    }


    /**
    /**
     * Checks the BasicAuth for GMS clients.
     * Checks the BasicAuth for GMS clients.
     */
     */
    @Bean
    @Bean
    public FilterRegistrationBean webServiceAuthorizationFilter() {
    public FilterRegistrationBean serviceBasicAuthFilter() {
        FilterRegistrationBean bean = new FilterRegistrationBean();
        bean.setFilter(new ServiceBasicAuthFilter());
        bean.addUrlPatterns("/ws/basic/*");
        bean.setOrder(Ordered.HIGHEST_PRECEDENCE);
        return bean;
    }

    /**
     * Checks JWT for web services.
     */
    @Bean
    public FilterRegistrationBean serviceJWTFilter(JwkTokenStore jwkTokenStore) {
        FilterRegistrationBean bean = new FilterRegistrationBean();
        FilterRegistrationBean bean = new FilterRegistrationBean();
        bean.setFilter(new WebServiceAuthorizationFilter());
        bean.setFilter(new JWTFilter(jwkTokenStore));
        bean.addUrlPatterns("/ws/*");
        bean.addUrlPatterns("/ws/jwt/*");
        bean.setOrder(Ordered.HIGHEST_PRECEDENCE);
        bean.setOrder(Ordered.HIGHEST_PRECEDENCE);
        return bean;
        return bean;
    }
    }
+1 −1
Original line number Original line Diff line number Diff line
@@ -20,7 +20,7 @@ import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.WebApplicationContextUtils;
import org.springframework.web.context.support.WebApplicationContextUtils;


public class WebServiceAuthorizationFilter implements Filter {
public class ServiceBasicAuthFilter implements Filter {


    @Override
    @Override
    public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {
    public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {
Loading