package it.inaf.ia2.gms.authn; import it.inaf.ia2.aa.UserManager; import it.inaf.ia2.aa.data.User; import it.inaf.ia2.gms.persistence.LoggingDAO; import it.inaf.ia2.gms.persistence.model.ActionType; import java.io.IOException; import java.security.Principal; import java.util.Map; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequestWrapper; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; public class JWTFilter implements Filter { private final LoggingDAO loggingDAO; private final UserManager userManager; public JWTFilter(LoggingDAO loggingDAO, UserManager userManager) { this.loggingDAO = loggingDAO; this.userManager = userManager; } @Override public void doFilter(ServletRequest req, ServletResponse res, FilterChain fc) throws IOException, ServletException { HttpServletRequest request = (HttpServletRequest) req; HttpServletResponse response = (HttpServletResponse) res; String authHeader = request.getHeader("Authorization"); if (authHeader == null) { if (request.isRequestedSessionIdValid()) { HttpSession session = request.getSession(false); User user = (User) session.getAttribute("user_data"); if (user != null) { ServletRequestWithSessionPrincipal wrappedRequest = new ServletRequestWithSessionPrincipal(request, user); fc.doFilter(wrappedRequest, res); return; } } fc.doFilter(req, res); return; } String token = authHeader.replace("Bearer", "").trim(); Map claims = userManager.parseIdTokenClaims(token); if (claims.get("sub") == null) { loggingDAO.logAction(ActionType.UNAUTHORIZED_ACCESS_ATTEMPT, "Attempt to access API with invalid token " + request.getRequestURI(), request); response.sendError(HttpServletResponse.SC_UNAUTHORIZED, "Invalid access token: missing sub claim"); return; } ServletRequestWithJWTPrincipal wrappedRequest = new ServletRequestWithJWTPrincipal(request, token, claims); loggingDAO.logAction(ActionType.API_CALL, request.getRequestURI() + " called by " + wrappedRequest.getUserPrincipal().getName(), request); fc.doFilter(wrappedRequest, res); } private static class ServletRequestWithSessionPrincipal extends HttpServletRequestWrapper { private final User principal; public ServletRequestWithSessionPrincipal(HttpServletRequest request, User user) { super(request); this.principal = user; } @Override public Principal getUserPrincipal() { return principal; } } private static class ServletRequestWithJWTPrincipal extends HttpServletRequestWrapper { private final RapPrincipal principal; public ServletRequestWithJWTPrincipal(HttpServletRequest request, String token, Map jwtClaims) { super(request); this.principal = new RapPrincipal(token, jwtClaims); } @Override public Principal getUserPrincipal() { return principal; } } }