package it.inaf.ia2.gms.authn;

import it.inaf.ia2.gms.persistence.LoggingDAO;
import java.io.IOException;
import java.security.Principal;
import java.util.Map;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.provider.token.store.jwk.JwkTokenStore;

/**
 * Servlet filter that requires an access token in the {@code Authorization}
 * header of every request it handles.
 *
 * <p>The token is read through the configured {@link JwkTokenStore}. Requests
 * with a missing, unreadable or expired token, or with a token lacking the
 * {@code sub} claim, are answered with {@code 401 Unauthorized} and are not
 * passed down the chain. Accepted requests are forwarded wrapped so that
 * {@link HttpServletRequest#getUserPrincipal()} returns a {@code RapPrincipal}
 * built from the token claims. Every outcome is recorded through the
 * {@link LoggingDAO}.
 */
public class JWTFilter implements Filter {

    /** Authentication scheme that may precede the token in the header value. */
    private static final String BEARER_SCHEME = "Bearer";

    private final JwkTokenStore jwkTokenStore;
    private final LoggingDAO loggingDAO;

    /**
     * @param jwkTokenStore store used to read the access token from its string form
     * @param loggingDAO    sink for the audit messages emitted by this filter
     */
    public JWTFilter(JwkTokenStore jwkTokenStore, LoggingDAO loggingDAO) {
        this.jwkTokenStore = jwkTokenStore;
        this.loggingDAO = loggingDAO;
    }

    /**
     * Validates the access token of the request and, on success, continues the
     * chain with a request exposing the token-derived principal.
     *
     * @param req incoming request; cast to {@link HttpServletRequest}
     * @param res outgoing response; cast to {@link HttpServletResponse}
     * @param fc  remaining filter chain, invoked only for accepted requests
     * @throws IOException      if sending the error response or the chain fails
     * @throws ServletException if the rest of the chain fails
     */
    @Override
    public void doFilter(ServletRequest req, ServletResponse res, FilterChain fc)
            throws IOException, ServletException {

        HttpServletRequest request = (HttpServletRequest) req;
        HttpServletResponse response = (HttpServletResponse) res;

        String authHeader = request.getHeader("Authorization");
        if (authHeader == null) {
            reject(request, response,
                    "Attempt to access WS without token", "Missing Authorization token");
            return;
        }

        String token = extractToken(authHeader);
        if (token.isEmpty()) {
            // Header present but carrying no credentials (e.g. just "Bearer").
            reject(request, response,
                    "Attempt to access WS without token", "Missing Authorization token");
            return;
        }

        OAuth2AccessToken accessToken;
        try {
            accessToken = jwkTokenStore.readAccessToken(token);
        } catch (RuntimeException ex) {
            // The token store signals unreadable tokens with unchecked
            // exceptions; report them to the client as an authentication
            // failure rather than letting them surface as a server error.
            // NOTE(review): narrow this to the specific Spring exception types
            // if failures unrelated to the token must be told apart.
            reject(request, response,
                    "Attempt to access WS with invalid token", "Invalid access token");
            return;
        }

        if (accessToken == null) {
            reject(request, response,
                    "Attempt to access WS with invalid token", "Invalid access token");
            return;
        }

        if (accessToken.isExpired()) {
            reject(request, response,
                    "Attempt to access WS with expired token", "Access token is expired");
            return;
        }

        Map<String, Object> claims = accessToken.getAdditionalInformation();
        if (claims == null || claims.get("sub") == null) {
            reject(request, response,
                    "Attempt to access WS with invalid token",
                    "Invalid access token: missing sub claim");
            return;
        }

        ServletRequestWithJWTPrincipal wrappedRequest
                = new ServletRequestWithJWTPrincipal(request, claims);
        loggingDAO.logAction(
                "WS access from " + wrappedRequest.getUserPrincipal().getName(), request);

        fc.doFilter(wrappedRequest, res);
    }

    /**
     * Returns the token carried by an {@code Authorization} header value.
     *
     * <p>Only a leading {@code Bearer} scheme (matched case-insensitively) is
     * removed, so a token that happens to contain that word elsewhere is left
     * intact. A value without the scheme is returned trimmed, unchanged.
     */
    private static String extractToken(String authHeader) {
        String value = authHeader.trim();
        if (value.regionMatches(true, 0, BEARER_SCHEME, 0, BEARER_SCHEME.length())) {
            return value.substring(BEARER_SCHEME.length()).trim();
        }
        return value;
    }

    /** Records the failed attempt and answers the request with 401 Unauthorized. */
    private void reject(HttpServletRequest request, HttpServletResponse response,
            String logMessage, String errorMessage) throws IOException {
        loggingDAO.logAction(logMessage, request);
        response.sendError(HttpServletResponse.SC_UNAUTHORIZED, errorMessage);
    }

    /**
     * Request wrapper whose user principal is a {@code RapPrincipal} built from
     * the JWT claims.
     */
    private static class ServletRequestWithJWTPrincipal extends HttpServletRequestWrapper {

        private final Principal principal;

        // The claims parameter stays a raw Map because the parameter type of
        // the RapPrincipal constructor is not visible from this file.
        public ServletRequestWithJWTPrincipal(HttpServletRequest request, Map jwtClaims) {
            super(request);
            this.principal = new RapPrincipal(jwtClaims);
        }

        @Override
        public Principal getUserPrincipal() {
            return principal;
        }
    }
}