Skip to content
ServiceBasicAuthFilter.java 3.43 KiB
Newer Older
package it.inaf.ia2.gms.authn;

import it.inaf.ia2.gms.exception.UnauthorizedException;
import it.inaf.ia2.gms.persistence.ClientsDAO;
import it.inaf.ia2.gms.persistence.model.ClientEntity;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.xml.bind.DatatypeConverter;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.WebApplicationContextUtils;

/**
 * Servlet filter that protects the web-service endpoints (requests whose servlet path
 * starts with {@code /ws/}) with HTTP Basic authentication.
 * <p>
 * The decoded Basic credentials are split at the first {@code ':'} into a client id and a
 * client secret. The client is looked up through {@link ClientsDAO}. The lowercase hex
 * SHA-256 digest of the supplied secret is compared with the value returned by
 * {@code ClientEntity#getSecret()}.
 * <p>
 * Every authentication failure is reported as an {@link UnauthorizedException} and turned
 * into a {@code 401 Unauthorized} response. Requests outside {@code /ws/} pass through
 * unchanged.
 */
public class ServiceBasicAuthFilter implements Filter {

    /** Authorization scheme prefix; matched case-insensitively. */
    private static final String BASIC_PREFIX = "Basic ";

    @Override
    public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {

        HttpServletRequest request = (HttpServletRequest) req;

        if (request.getServletPath().startsWith("/ws/")) {
            try {
                validateBasicAuth(request);
            } catch (UnauthorizedException ex) {
                ((HttpServletResponse) res).sendError(HttpServletResponse.SC_UNAUTHORIZED, ex.getMessage());
                return;
            }
        }

        chain.doFilter(req, res);
    }

    /**
     * Validates the Basic credentials carried by the request.
     *
     * @param request the incoming request
     * @throws UnauthorizedException if the header is missing or malformed, the client is
     *         unknown, or the secret does not match. All failures use this single exception
     *         type so that {@link #doFilter} maps them to 401 instead of letting them escape
     *         as a 500.
     */
    private void validateBasicAuth(HttpServletRequest request) {

        String token = getBasicAuthToken(request);

        int delim = token.indexOf(':');
        if (delim == -1) {
            throw new UnauthorizedException("Invalid basic authentication token");
        }

        String clientId = token.substring(0, delim);
        String clientSecret = token.substring(delim + 1);

        ClientsDAO clientsDAO = getClientsDAO(request);

        ClientEntity client = clientsDAO.findClientById(clientId)
                .orElseThrow(() -> new UnauthorizedException("Client " + clientId + " not found"));

        if (!secretMatches(clientSecret, client.getSecret())) {
            throw new UnauthorizedException("Wrong secret");
        }
    }

    /**
     * Compares the SHA-256 digest of the supplied secret with the stored digest in constant
     * time, so response timing does not reveal how many leading characters matched.
     *
     * @param clientSecret the plain secret sent by the client
     * @param storedSha256 the stored lowercase hex SHA-256 digest; may be {@code null}
     * @return {@code true} only if both digests are present and identical
     */
    private static boolean secretMatches(String clientSecret, String storedSha256) {
        if (storedSha256 == null) {
            return false;
        }
        byte[] expected = storedSha256.getBytes(StandardCharsets.UTF_8);
        byte[] actual = getSha256(clientSecret).getBytes(StandardCharsets.UTF_8);
        return MessageDigest.isEqual(expected, actual);
    }

    /**
     * Extracts and Base64-decodes the credentials from the {@code Authorization} header.
     *
     * @param request the incoming request
     * @return the decoded {@code clientId:clientSecret} string (UTF-8)
     * @throws UnauthorizedException if the header is absent, does not use the Basic scheme,
     *         or its payload is not valid Base64
     */
    private String getBasicAuthToken(HttpServletRequest request) {

        String header = request.getHeader("Authorization");

        // regionMatches(ignoreCase=true) is locale-independent, unlike a bare toLowerCase()
        // (which e.g. turns "BASIC" into "basıc" under a Turkish default locale).
        if (header == null || !header.regionMatches(true, 0, BASIC_PREFIX, 0, BASIC_PREFIX.length())) {
            throw new UnauthorizedException("Missing Authorization header");
        }

        String encoded = header.substring(BASIC_PREFIX.length()).trim();

        byte[] decoded;
        try {
            decoded = Base64.getDecoder().decode(encoded);
        } catch (IllegalArgumentException ex) {
            // Malformed Base64 is a client error: report 401, not an uncaught 500.
            throw new UnauthorizedException("Invalid basic authentication token");
        }

        return new String(decoded, StandardCharsets.UTF_8);
    }

    /**
     * Resolves the {@link ClientsDAO} bean from the root web application context.
     * Protected so that it can be overridden (e.g. in tests).
     *
     * @param request the incoming request, used to reach the servlet context
     * @return the clients DAO bean
     * @throws IllegalStateException if no root web application context is registered
     */
    protected ClientsDAO getClientsDAO(HttpServletRequest request) {
        WebApplicationContext webApplicationContext = WebApplicationContextUtils.getRequiredWebApplicationContext(request.getServletContext());
        return webApplicationContext.getBean(ClientsDAO.class);
    }

    /**
     * Computes the lowercase hexadecimal SHA-256 digest of the UTF-8 bytes of a string.
     *
     * @param secret the value to hash
     * @return the 64-character lowercase hex digest
     */
    private static String getSha256(String secret) {
        try {
            MessageDigest md = MessageDigest.getInstance("SHA-256");
            byte[] sha = md.digest(secret.getBytes(StandardCharsets.UTF_8));
            return DatatypeConverter.printHexBinary(sha).toLowerCase();
        } catch (NoSuchAlgorithmException e) {
            // Every Java platform is required to support SHA-256.
            throw new IllegalStateException(e);
        }
    }
}