Newer
Older
package it.inaf.ia2.gms.authn;
import it.inaf.ia2.gms.exception.UnauthorizedException;
import it.inaf.ia2.gms.persistence.ClientsDAO;
import it.inaf.ia2.gms.persistence.model.ClientEntity;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.xml.bind.DatatypeConverter;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.WebApplicationContextUtils;
public class ServiceBasicAuthFilter implements Filter {
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
/**
 * Enforces HTTP Basic client authentication on web-service endpoints.
 *
 * <p>Requests whose servlet path does not start with {@code /ws/} are forwarded unchanged.
 * For {@code /ws/} requests the credentials are validated first; if validation raises an
 * {@link UnauthorizedException}, a 401 error carrying the exception message is sent and the
 * chain is not continued.
 */
@Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {
    HttpServletRequest httpRequest = (HttpServletRequest) req;

    // Only the web-service API is protected; everything else passes straight through.
    if (!httpRequest.getServletPath().startsWith("/ws/")) {
        chain.doFilter(req, res);
        return;
    }

    try {
        validateBasicAuth(httpRequest);
    } catch (UnauthorizedException authFailure) {
        HttpServletResponse httpResponse = (HttpServletResponse) res;
        httpResponse.sendError(HttpServletResponse.SC_UNAUTHORIZED, authFailure.getMessage());
        return;
    }

    chain.doFilter(req, res);
}
/**
 * Validates the HTTP Basic credentials ({@code clientId:clientSecret}) of a request against
 * the clients known to {@link ClientsDAO}.
 *
 * <p>The SHA-256 hex digest of the supplied secret is compared with the value returned by
 * {@code ClientEntity.getSecret()}.
 *
 * @param request the incoming HTTP request
 * @throws UnauthorizedException if the Authorization header is missing or malformed, the
 *         client id is unknown, or the secret does not match
 */
private void validateBasicAuth(HttpServletRequest request) {
    String token = getBasicAuthToken(request);

    int delim = token.indexOf(':');
    if (delim == -1) {
        // Must be UnauthorizedException: doFilter only catches that type, so any other
        // exception here would surface as a 500 instead of a 401.
        throw new UnauthorizedException("Invalid basic authentication token");
    }
    String clientId = token.substring(0, delim);
    String clientSecret = token.substring(delim + 1);

    ClientsDAO clientsDAO = getClientsDAO(request);
    ClientEntity client = clientsDAO.findClientById(clientId)
            .orElseThrow(() -> new UnauthorizedException("Client " + clientId + " not found"));

    String storedSecret = client.getSecret();
    // MessageDigest.isEqual compares in constant time, so response timing does not reveal
    // how many leading characters of the digest matched.
    if (storedSecret == null || !MessageDigest.isEqual(
            getSha256(clientSecret).getBytes(StandardCharsets.UTF_8),
            storedSecret.getBytes(StandardCharsets.UTF_8))) {
        throw new UnauthorizedException("Wrong secret");
    }
}
/**
 * Extracts and Base64-decodes the credentials of an HTTP Basic {@code Authorization} header.
 *
 * @param request the incoming HTTP request
 * @return the decoded credentials string, expected to be {@code clientId:clientSecret}
 * @throws UnauthorizedException if the header is absent, does not use the Basic scheme, or
 *         does not contain valid Base64
 */
private String getBasicAuthToken(HttpServletRequest request) {
    String header = request.getHeader("Authorization");
    // Case-insensitive scheme check without String.toLowerCase(), which depends on the
    // default locale (e.g. "BASIC" lower-cases to "basıc" under a Turkish locale).
    if (header == null || !header.regionMatches(true, 0, "basic ", 0, 6)) {
        throw new UnauthorizedException("Missing Authorization header");
    }

    String encoded = header.substring(6).trim();
    byte[] decoded;
    try {
        decoded = Base64.getDecoder().decode(encoded);
    } catch (IllegalArgumentException ex) {
        // Malformed Base64 is a client error. Without this mapping the IllegalArgumentException
        // would escape the filter as an HTTP 500.
        throw new UnauthorizedException("Invalid basic authentication token");
    }
    return new String(decoded, StandardCharsets.UTF_8);
}
/**
 * Looks up the {@link ClientsDAO} bean from the root Spring web application context.
 *
 * <p>Protected so that subclasses (e.g. tests) can supply a different DAO.
 *
 * @param request the current request, used only to reach its {@code ServletContext}
 * @return the {@code ClientsDAO} bean
 * @throws IllegalStateException if no root web application context is registered
 */
protected ClientsDAO getClientsDAO(HttpServletRequest request) {
    // The "required" variant fails with a descriptive IllegalStateException instead of
    // returning null and causing an opaque NullPointerException on getBean().
    WebApplicationContext webApplicationContext =
            WebApplicationContextUtils.getRequiredWebApplicationContext(request.getServletContext());
    return webApplicationContext.getBean(ClientsDAO.class);
}
/**
 * Computes the SHA-256 digest of a string's UTF-8 bytes as a lowercase hex string.
 *
 * <p>Hex encoding is done with the standard library only. {@code javax.xml.bind.DatatypeConverter}
 * belongs to JAXB, which was removed from the JDK in Java 11, so relying on it adds a runtime
 * dependency.
 *
 * @param secret the value to hash
 * @return 64 lowercase hexadecimal characters
 * @throws IllegalStateException if the JVM provides no SHA-256 implementation
 */
private static String getSha256(String secret) {
    try {
        MessageDigest md = MessageDigest.getInstance("SHA-256");
        byte[] sha = md.digest(secret.getBytes(StandardCharsets.UTF_8));
        StringBuilder hex = new StringBuilder(sha.length * 2);
        for (byte b : sha) {
            // Character.forDigit yields lowercase letters, matching the previous toLowerCase() output.
            hex.append(Character.forDigit((b >> 4) & 0xF, 16));
            hex.append(Character.forDigit(b & 0xF, 16));
        }
        return hex.toString();
    } catch (NoSuchAlgorithmException e) {
        throw new IllegalStateException(e);
    }
}
}