Commit ba6b02eb authored by Sonia Zorba's avatar Sonia Zorba
Browse files

Added persistence layer for invited registration

parent a989ae7c
......@@ -67,18 +67,6 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter {
web.ignoring().antMatchers("/ws/basic/**", "/ws/jwt/**", "/error", "/logout");
}
/**
* Checks the BasicAuth for GMS clients.
*/
@Bean
public FilterRegistrationBean serviceBasicAuthFilter(LoggingDAO loggingDAO) {
FilterRegistrationBean bean = new FilterRegistrationBean();
bean.setFilter(new ServiceBasicAuthFilter(loggingDAO));
bean.addUrlPatterns("/ws/basic/*");
bean.setOrder(Ordered.HIGHEST_PRECEDENCE);
return bean;
}
/**
* Checks JWT for web services.
*/
......
package it.inaf.ia2.gms.authn;
import it.inaf.ia2.gms.exception.UnauthorizedException;
import it.inaf.ia2.gms.persistence.ClientsDAO;
import it.inaf.ia2.gms.persistence.LoggingDAO;
import it.inaf.ia2.gms.persistence.model.ClientEntity;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.xml.bind.DatatypeConverter;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.WebApplicationContextUtils;
public class ServiceBasicAuthFilter implements Filter {
private final LoggingDAO loggingDAO;
public ServiceBasicAuthFilter(LoggingDAO loggingDAO) {
this.loggingDAO = loggingDAO;
}
@Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) req;
if (request.getServletPath().startsWith("/ws/")) {
try {
validateBasicAuth(request);
} catch (UnauthorizedException ex) {
loggingDAO.logAction("Unauthorized BasicAuth WS request");
((HttpServletResponse) res).sendError(HttpServletResponse.SC_UNAUTHORIZED, ex.getMessage());
return;
}
}
loggingDAO.logAction("BasicAuth WS request");
chain.doFilter(req, res);
}
private void validateBasicAuth(HttpServletRequest request) {
String token = getBasicAuthToken(request);
int delim = token.indexOf(":");
if (delim == -1) {
throw new BadCredentialsException("Invalid basic authentication token");
}
String clientId = token.substring(0, delim);
String clientSecret = token.substring(delim + 1);
ClientsDAO clientsDAO = getClientsDAO(request);
ClientEntity client = clientsDAO.findClientById(clientId)
.orElseThrow(() -> new BadCredentialsException("Client " + clientId + " not found"));
String shaSecret = getSha256(clientSecret);
if (!shaSecret.equals(client.getSecret())) {
throw new UnauthorizedException("Wrong secret");
}
}
private String getBasicAuthToken(HttpServletRequest request) {
String header = request.getHeader("Authorization");
if (header == null || !header.toLowerCase().startsWith("basic ")) {
throw new UnauthorizedException("Missing Authorization header");
}
byte[] base64Token = header.substring(6).getBytes(StandardCharsets.UTF_8);
byte[] decoded = Base64.getDecoder().decode(base64Token);
return new String(decoded, StandardCharsets.UTF_8);
}
protected ClientsDAO getClientsDAO(HttpServletRequest request) {
WebApplicationContext webApplicationContext = WebApplicationContextUtils.getWebApplicationContext(request.getServletContext());
return webApplicationContext.getBean(ClientsDAO.class);
}
private static String getSha256(String secret) {
try {
MessageDigest md = MessageDigest.getInstance("SHA-256");
byte[] sha = md.digest(secret.getBytes(StandardCharsets.UTF_8));
return DatatypeConverter.printHexBinary(sha).toLowerCase();
} catch (NoSuchAlgorithmException e) {
throw new IllegalStateException(e);
}
}
}
package it.inaf.ia2.gms.persistence;
import it.inaf.ia2.gms.persistence.model.ClientEntity;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import javax.sql.DataSource;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Component;
@Component
public class ClientsDAO {
private final JdbcTemplate jdbcTemplate;
@Autowired
public ClientsDAO(DataSource dataSource) {
jdbcTemplate = new JdbcTemplate(dataSource);
}
public Optional<ClientEntity> findClientById(String clientId) {
String sql = "SELECT client_secret, allowed_actions, ip_filter FROM gms_client WHERE client_id = ?";
return jdbcTemplate.query(conn -> {
PreparedStatement ps = conn.prepareStatement(sql);
ps.setString(1, clientId);
return ps;
}, resultSet -> {
if (resultSet.next()) {
ClientEntity client = new ClientEntity();
client.setId(clientId);
client.setSecret(resultSet.getString("client_secret"));
client.setAllowedActions(getAllowedActions(resultSet));
client.setIpFilter(resultSet.getString("ip_filter"));
return Optional.of(client);
}
return Optional.empty();
});
}
private List<String> getAllowedActions(ResultSet resultSet) throws SQLException {
List<String> actions = new ArrayList<>();
ResultSet items = resultSet.getArray("allowed_actions").getResultSet();
while (items.next()) {
String action = items.getString(1);
actions.add(action);
}
return actions;
}
}
package it.inaf.ia2.gms.persistence;
import it.inaf.ia2.gms.persistence.model.InvitedRegistration;
import java.sql.PreparedStatement;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import javax.sql.DataSource;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Component;
import org.springframework.transaction.annotation.Transactional;
@Component
public class InvitedRegistrationDAO {
private final JdbcTemplate jdbcTemplate;
@Autowired
public InvitedRegistrationDAO(DataSource dataSource) {
jdbcTemplate = new JdbcTemplate(dataSource);
}
@Transactional
public void addInvitedRegistration(InvitedRegistration invitedRegistration) {
String sqlReq = "INSERT INTO invited_registration_request (id, token_hash, email) VALUES (?, ?, ?)";
jdbcTemplate.update(conn -> {
PreparedStatement ps = conn.prepareStatement(sqlReq);
ps.setString(1, invitedRegistration.getId());
ps.setString(2, invitedRegistration.getTokenHash());
ps.setObject(3, invitedRegistration.getEmail());
return ps;
});
for (String groupId : invitedRegistration.getGroupIds()) {
String sqlReqGroup = "INSERT INTO invited_registration_request_group (request_id, group_id) VALUES (?, ?)";
jdbcTemplate.update(conn -> {
PreparedStatement ps = conn.prepareStatement(sqlReqGroup);
ps.setString(1, invitedRegistration.getId());
ps.setString(2, groupId);
return ps;
});
}
}
public Optional<InvitedRegistration> getInvitedRegistrationFromToken(String tokenHash) {
String sqlReq = "SELECT id, email FROM invited_registration_request WHERE token_hash = ? AND !done";
InvitedRegistration registration = jdbcTemplate.query(conn -> {
PreparedStatement ps = conn.prepareStatement(sqlReq);
ps.setString(1, tokenHash);
return ps;
}, resultSet -> {
if (resultSet.next()) {
InvitedRegistration reg = new InvitedRegistration();
reg.setId(resultSet.getString("id"));
reg.setEmail(resultSet.getString("email"));
return reg;
}
return null;
});
if (registration != null) {
String sqlReqGroup = "SELECT group_id FROM invited_registration_request_group WHERE request_id = ?";
List<String> groupIds = jdbcTemplate.query(conn -> {
PreparedStatement ps = conn.prepareStatement(sqlReqGroup);
ps.setString(1, registration.getId());
return ps;
}, resultSet -> {
List<String> groups = new ArrayList<>();
while (resultSet.next()) {
groups.add(resultSet.getString("group_id"));
}
return groups;
});
registration.setGroupIds(groupIds);
}
return Optional.ofNullable(registration);
}
public void setRegistrationDone(String tokenHash) {
String sql = "UPDATE invited_registration_request SET done = true WHERE token_hash = ?";
jdbcTemplate.update(conn -> {
PreparedStatement ps = conn.prepareStatement(sql);
ps.setString(1, tokenHash);
return ps;
});
}
}
package it.inaf.ia2.gms.persistence.model;
import java.util.List;
public class InvitedRegistration {
private String id;
private String tokenHash;
private String email;
private boolean done;
private List<String> groupIds;
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public String getTokenHash() {
return tokenHash;
}
public void setTokenHash(String tokenHash) {
this.tokenHash = tokenHash;
}
public String getEmail() {
return email;
}
public void setEmail(String email) {
this.email = email;
}
public boolean isDone() {
return done;
}
public void setDone(boolean done) {
this.done = done;
}
public List<String> getGroupIds() {
return groupIds;
}
public void setGroupIds(List<String> groupIds) {
this.groupIds = groupIds;
}
}
......@@ -31,14 +31,6 @@ CREATE TABLE gms_permission (
foreign key (group_path) references gms_group(path)
);
CREATE TABLE gms_client (
client_id varchar NOT NULL,
client_secret varchar NOT NULL,
allowed_actions text[] NOT NULL,
ip_filter text NULL,
primary key (client_id)
);
CREATE TABLE error_log (
"date" timestamp DEFAULT NOW(),
"exception_class" varchar,
......@@ -51,3 +43,19 @@ CREATE TABLE audit_log (
"ip_address" varchar,
"action" TEXT
);
CREATE TABLE invited_registration_request (
id varchar NOT NULL,
token_hash varchar NOT NULL,
email varchar NOT NULL,
creation_time timestamp DEFAULT NOW(),
done boolean,
PRIMARY KEY(id)
);
CREATE TABLE invited_registration_request_group (
request_id varchar NOT NULL,
group_id varchar NOT NULL,
PRIMARY KEY (request_id, group_id),
FOREIGN KEY (request_id) REFERENCES invited_registration_request(id)
);
package it.inaf.ia2.gms.authn;
import it.inaf.ia2.gms.persistence.ClientsDAO;
import it.inaf.ia2.gms.persistence.LoggingDAO;
import it.inaf.ia2.gms.persistence.model.ClientEntity;
import java.util.Collections;
import java.util.Optional;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@RunWith(JUnit4.class)
public class ServiceBasicAuthFilterTest {
private ServiceBasicAuthFilter filter;
private HttpServletRequest request;
private HttpServletResponse response;
private FilterChain chain;
@Before
public void setUp() {
ClientsDAO clientsDAO = mock(ClientsDAO.class);
ClientEntity client = new ClientEntity();
client.setId("test");
client.setSecret("5e884898da28047151d0e56f8dc6292773603d0d6aabbdd62a11ef721d1542d8"); // sha256 of "password"
client.setAllowedActions(Collections.singletonList("*"));
when(clientsDAO.findClientById("test")).thenReturn(Optional.of(client));
LoggingDAO loggingDAO = mock(LoggingDAO.class);
filter = spy(new ServiceBasicAuthFilter(loggingDAO));
doReturn(clientsDAO).when(filter).getClientsDAO(any());
request = mock(HttpServletRequest.class);
response = mock(HttpServletResponse.class);
chain = mock(FilterChain.class);
}
@Test
public void testValidCredentials() throws Exception {
when(request.getServletPath()).thenReturn("/ws/basic/group");
when(request.getHeader("Authorization")).thenReturn("Basic dGVzdDpwYXNzd29yZA=="); // test:password
filter.doFilter(request, response, chain);
verify(chain, times(1)).doFilter(any(), any());
}
@Test
public void testInvalidCredentials() throws Exception {
when(request.getServletPath()).thenReturn("/ws/basic/group");
when(request.getHeader("Authorization")).thenReturn("Basic dGVzdDp0ZXN0"); // test:test
filter.doFilter(request, response, chain);
verify(response, times(1)).sendError(eq(HttpServletResponse.SC_UNAUTHORIZED), any());
verify(chain, never()).doFilter(any(), any());
}
@Test
public void testMissingHeader() throws Exception {
when(request.getServletPath()).thenReturn("/ws/basic/group");
filter.doFilter(request, response, chain);
verify(response, times(1)).sendError(eq(HttpServletResponse.SC_UNAUTHORIZED), any());
verify(chain, never()).doFilter(any(), any());
}
@Test
public void testOutsidePath() throws Exception {
when(request.getServletPath()).thenReturn("/other/path");
filter.doFilter(request, response, chain);
verify(response, never()).sendError(anyInt(), any());
verify(chain, times(1)).doFilter(any(), any());
}
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment