Commit ba6b02eb authored by Sonia Zorba's avatar Sonia Zorba
Browse files

Added persistence layer for invited registration

parent a989ae7c
Loading
Loading
Loading
Loading
+0 −12
Original line number Original line Diff line number Diff line
@@ -67,18 +67,6 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter {
        web.ignoring().antMatchers("/ws/basic/**", "/ws/jwt/**", "/error", "/logout");
        web.ignoring().antMatchers("/ws/basic/**", "/ws/jwt/**", "/error", "/logout");
    }
    }


    /**
     * Checks the BasicAuth for GMS clients.
     */
    @Bean
    public FilterRegistrationBean serviceBasicAuthFilter(LoggingDAO loggingDAO) {
        FilterRegistrationBean bean = new FilterRegistrationBean();
        bean.setFilter(new ServiceBasicAuthFilter(loggingDAO));
        bean.addUrlPatterns("/ws/basic/*");
        bean.setOrder(Ordered.HIGHEST_PRECEDENCE);
        return bean;
    }

    /**
    /**
     * Checks JWT for web services.
     * Checks JWT for web services.
     */
     */
+0 −104
Original line number Original line Diff line number Diff line
package it.inaf.ia2.gms.authn;

import it.inaf.ia2.gms.exception.UnauthorizedException;
import it.inaf.ia2.gms.persistence.ClientsDAO;
import it.inaf.ia2.gms.persistence.LoggingDAO;
import it.inaf.ia2.gms.persistence.model.ClientEntity;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.xml.bind.DatatypeConverter;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.WebApplicationContextUtils;

public class ServiceBasicAuthFilter implements Filter {

    private final LoggingDAO loggingDAO;

    public ServiceBasicAuthFilter(LoggingDAO loggingDAO) {
        this.loggingDAO = loggingDAO;
    }

    @Override
    public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {

        HttpServletRequest request = (HttpServletRequest) req;

        if (request.getServletPath().startsWith("/ws/")) {
            try {
                validateBasicAuth(request);
            } catch (UnauthorizedException ex) {
                loggingDAO.logAction("Unauthorized BasicAuth WS request");
                ((HttpServletResponse) res).sendError(HttpServletResponse.SC_UNAUTHORIZED, ex.getMessage());
                return;
            }
        }

        loggingDAO.logAction("BasicAuth WS request");

        chain.doFilter(req, res);
    }

    private void validateBasicAuth(HttpServletRequest request) {

        String token = getBasicAuthToken(request);

        int delim = token.indexOf(":");

        if (delim == -1) {
            throw new BadCredentialsException("Invalid basic authentication token");
        }

        String clientId = token.substring(0, delim);
        String clientSecret = token.substring(delim + 1);

        ClientsDAO clientsDAO = getClientsDAO(request);

        ClientEntity client = clientsDAO.findClientById(clientId)
                .orElseThrow(() -> new BadCredentialsException("Client " + clientId + " not found"));

        String shaSecret = getSha256(clientSecret);
        if (!shaSecret.equals(client.getSecret())) {
            throw new UnauthorizedException("Wrong secret");
        }
    }

    private String getBasicAuthToken(HttpServletRequest request) {

        String header = request.getHeader("Authorization");

        if (header == null || !header.toLowerCase().startsWith("basic ")) {
            throw new UnauthorizedException("Missing Authorization header");
        }

        byte[] base64Token = header.substring(6).getBytes(StandardCharsets.UTF_8);
        byte[] decoded = Base64.getDecoder().decode(base64Token);

        return new String(decoded, StandardCharsets.UTF_8);
    }

    protected ClientsDAO getClientsDAO(HttpServletRequest request) {
        WebApplicationContext webApplicationContext = WebApplicationContextUtils.getWebApplicationContext(request.getServletContext());
        return webApplicationContext.getBean(ClientsDAO.class);
    }

    private static String getSha256(String secret) {
        try {
            MessageDigest md = MessageDigest.getInstance("SHA-256");
            byte[] sha = md.digest(secret.getBytes(StandardCharsets.UTF_8));
            return DatatypeConverter.printHexBinary(sha).toLowerCase();
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException(e);
        }
    }
}
+0 −58
Original line number Original line Diff line number Diff line
package it.inaf.ia2.gms.persistence;

import it.inaf.ia2.gms.persistence.model.ClientEntity;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import javax.sql.DataSource;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Component;

@Component
public class ClientsDAO {

    private final JdbcTemplate jdbcTemplate;

    @Autowired
    public ClientsDAO(DataSource dataSource) {
        jdbcTemplate = new JdbcTemplate(dataSource);
    }

    public Optional<ClientEntity> findClientById(String clientId) {

        String sql = "SELECT client_secret, allowed_actions, ip_filter FROM gms_client WHERE client_id = ?";

        return jdbcTemplate.query(conn -> {
            PreparedStatement ps = conn.prepareStatement(sql);
            ps.setString(1, clientId);
            return ps;
        }, resultSet -> {
            if (resultSet.next()) {
                ClientEntity client = new ClientEntity();
                client.setId(clientId);
                client.setSecret(resultSet.getString("client_secret"));
                client.setAllowedActions(getAllowedActions(resultSet));
                client.setIpFilter(resultSet.getString("ip_filter"));
                return Optional.of(client);
            }
            return Optional.empty();
        });
    }

    private List<String> getAllowedActions(ResultSet resultSet) throws SQLException {

        List<String> actions = new ArrayList<>();

        ResultSet items = resultSet.getArray("allowed_actions").getResultSet();
        while (items.next()) {
            String action = items.getString(1);
            actions.add(action);
        }

        return actions;
    }
}
+99 −0
Original line number Original line Diff line number Diff line
package it.inaf.ia2.gms.persistence;

import it.inaf.ia2.gms.persistence.model.InvitedRegistration;
import java.sql.PreparedStatement;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import javax.sql.DataSource;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Component;
import org.springframework.transaction.annotation.Transactional;

@Component
public class InvitedRegistrationDAO {

    private final JdbcTemplate jdbcTemplate;

    @Autowired
    public InvitedRegistrationDAO(DataSource dataSource) {
        jdbcTemplate = new JdbcTemplate(dataSource);
    }

    @Transactional
    public void addInvitedRegistration(InvitedRegistration invitedRegistration) {

        String sqlReq = "INSERT INTO invited_registration_request (id, token_hash, email) VALUES (?, ?, ?)";

        jdbcTemplate.update(conn -> {
            PreparedStatement ps = conn.prepareStatement(sqlReq);
            ps.setString(1, invitedRegistration.getId());
            ps.setString(2, invitedRegistration.getTokenHash());
            ps.setObject(3, invitedRegistration.getEmail());
            return ps;
        });

        for (String groupId : invitedRegistration.getGroupIds()) {
            String sqlReqGroup = "INSERT INTO invited_registration_request_group (request_id, group_id) VALUES (?, ?)";

            jdbcTemplate.update(conn -> {
                PreparedStatement ps = conn.prepareStatement(sqlReqGroup);
                ps.setString(1, invitedRegistration.getId());
                ps.setString(2, groupId);
                return ps;
            });
        }
    }

    public Optional<InvitedRegistration> getInvitedRegistrationFromToken(String tokenHash) {

        String sqlReq = "SELECT id, email FROM invited_registration_request WHERE token_hash = ? AND !done";

        InvitedRegistration registration = jdbcTemplate.query(conn -> {
            PreparedStatement ps = conn.prepareStatement(sqlReq);
            ps.setString(1, tokenHash);
            return ps;
        }, resultSet -> {
            if (resultSet.next()) {
                InvitedRegistration reg = new InvitedRegistration();
                reg.setId(resultSet.getString("id"));
                reg.setEmail(resultSet.getString("email"));
                return reg;
            }
            return null;
        });

        if (registration != null) {

            String sqlReqGroup = "SELECT group_id FROM invited_registration_request_group WHERE request_id = ?";

            List<String> groupIds = jdbcTemplate.query(conn -> {
                PreparedStatement ps = conn.prepareStatement(sqlReqGroup);
                ps.setString(1, registration.getId());
                return ps;
            }, resultSet -> {
                List<String> groups = new ArrayList<>();
                while (resultSet.next()) {
                    groups.add(resultSet.getString("group_id"));
                }
                return groups;
            });

            registration.setGroupIds(groupIds);
        }

        return Optional.ofNullable(registration);
    }

    public void setRegistrationDone(String tokenHash) {

        String sql = "UPDATE invited_registration_request SET done = true WHERE token_hash = ?";

        jdbcTemplate.update(conn -> {
            PreparedStatement ps = conn.prepareStatement(sql);
            ps.setString(1, tokenHash);
            return ps;
        });
    }
}
+52 −0
Original line number Original line Diff line number Diff line
package it.inaf.ia2.gms.persistence.model;

import java.util.List;

public class InvitedRegistration {

    private String id;
    private String tokenHash;
    private String email;
    private boolean done;
    private List<String> groupIds;

    public String getId() {
        return id;
    }

    public void setId(String id) {
        this.id = id;
    }

    public String getTokenHash() {
        return tokenHash;
    }

    public void setTokenHash(String tokenHash) {
        this.tokenHash = tokenHash;
    }

    public String getEmail() {
        return email;
    }

    public void setEmail(String email) {
        this.email = email;
    }

    public boolean isDone() {
        return done;
    }

    public void setDone(boolean done) {
        this.done = done;
    }

    public List<String> getGroupIds() {
        return groupIds;
    }

    public void setGroupIds(List<String> groupIds) {
        this.groupIds = groupIds;
    }
}
Loading