// ClientDbFilter.java
package it.inaf.ia2.gms.authn;

import it.inaf.ia2.aa.AuthConfig;
import it.inaf.ia2.aa.UserManager;
import java.io.IOException;
import java.net.URI;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;

public class ClientDbFilter implements Filter {

Sonia Zorba's avatar
Sonia Zorba committed
    public static final String CLIENT_DB = "client_db";

Sonia Zorba's avatar
Sonia Zorba committed
    private final UserManager userManager;
    private final String defaultJwksUri;
Sonia Zorba's avatar
Sonia Zorba committed
    public ClientDbFilter(AuthConfig authConfig, UserManager userManager) {
        this.userManager = userManager;
        defaultJwksUri = URI.create(authConfig.getRapBaseUri()).resolve(authConfig.getJwksEndpoint()).toString();
Sonia Zorba's avatar
Sonia Zorba committed
    }
Sonia Zorba's avatar
Sonia Zorba committed

    @Override
    public void doFilter(ServletRequest req, ServletResponse res, FilterChain fc) throws IOException, ServletException {
Sonia Zorba's avatar
Sonia Zorba committed
        HttpServletRequest request = (HttpServletRequest) req;
Sonia Zorba's avatar
Sonia Zorba committed
        String clientDb = request.getParameter(CLIENT_DB);
        if (clientDb != null) {
            request.getSession().setAttribute(CLIENT_DB, clientDb);
Sonia Zorba's avatar
Sonia Zorba committed
            String newUrl = defaultJwksUri.replaceAll("\\?client_name=(.*)", "?client_name=" + clientDb);
Sonia Zorba's avatar
Sonia Zorba committed
            userManager.addJwksUri(URI.create(newUrl));
Sonia Zorba's avatar
Sonia Zorba committed
        }
Sonia Zorba's avatar
Sonia Zorba committed
        fc.doFilter(req, res);
    }
}