package it.inaf.ia2.gms.persistence;

import it.inaf.ia2.gms.persistence.model.GroupEntity;
import it.inaf.ia2.gms.persistence.model.MembershipEntity;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.stream.Collectors;
import javax.sql.DataSource;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Component;

/**
 * Data-access object for group memberships stored in the {@code gms_membership} table.
 * <p>
 * All statements are executed through a {@link JdbcTemplate} built from the injected
 * {@link DataSource} and use bound parameters only. The SQL relies on PostgreSQL syntax
 * ({@code ON CONFLICT}, the {@code <@} operator and {@code nlevel}).
 */
@Component
public class MembershipsDAO {

    private final JdbcTemplate jdbcTemplate;

    @Autowired
    public MembershipsDAO(DataSource dataSource) {
        jdbcTemplate = new JdbcTemplate(dataSource);
    }

    /**
     * Returns all memberships of a single group.
     *
     * @param groupId id of the group
     * @return the group's memberships; empty if it has none
     */
    public List<MembershipEntity> findByGroup(String groupId) {

        String sql = "SELECT user_id, group_id, created_by, creation_time FROM gms_membership WHERE group_id = ?";

        return jdbcTemplate.query(conn -> {
            PreparedStatement ps = conn.prepareStatement(sql);
            ps.setString(1, groupId);
            return ps;
        }, resultSet -> {
            // Block body with explicit return keeps the lambda value-only, so the
            // ResultSetExtractor overload of query(...) is selected unambiguously.
            return getMembershipsFromResultSet(resultSet);
        });
    }

    /**
     * Returns all memberships belonging to any of the given groups.
     *
     * @param groupIds ids of the groups; must not be null
     * @return the memberships of those groups; an empty mutable list if {@code groupIds} is empty
     */
    public List<MembershipEntity> findByGroups(List<String> groupIds) {

        if (groupIds.isEmpty()) {
            // "IN ()" is not valid SQL, so skip the query entirely.
            return new ArrayList<>();
        }

        String sql = "SELECT user_id, group_id, created_by, creation_time FROM gms_membership WHERE group_id IN ("
                + placeholders(groupIds.size()) + ")";

        return jdbcTemplate.query(conn -> {
            PreparedStatement ps = conn.prepareStatement(sql);
            bindStrings(ps, groupIds);
            return ps;
        }, resultSet -> {
            return getMembershipsFromResultSet(resultSet);
        });
    }

    /** Maps every remaining row of {@code resultSet} to a {@link MembershipEntity}. */
    private List<MembershipEntity> getMembershipsFromResultSet(ResultSet resultSet) throws SQLException {
        List<MembershipEntity> members = new ArrayList<>();
        while (resultSet.next()) {
            members.add(getMembershipEntityFromResultSet(resultSet));
        }
        return members;
    }

    /** Maps the current row of {@code resultSet} to a {@link MembershipEntity}. */
    private MembershipEntity getMembershipEntityFromResultSet(ResultSet resultSet) throws SQLException {
        MembershipEntity membership = new MembershipEntity();
        membership.setGroupId(resultSet.getString("group_id"));
        membership.setUserId(resultSet.getString("user_id"));
        membership.setCreatedBy(resultSet.getString("created_by"));
        membership.setCreationTime(getCreationTime(resultSet));
        return membership;
    }

    /** Maps the current row of {@code resultSet} (a {@code gms_group} row) to a {@link GroupEntity}. */
    private GroupEntity getGroupEntityFromResultSet(ResultSet resultSet) throws SQLException {
        GroupEntity group = new GroupEntity();
        group.setId(resultSet.getString("id"));
        group.setName(resultSet.getString("name"));
        group.setPath(resultSet.getString("path"));
        group.setLeaf(resultSet.getBoolean("is_leaf"));
        group.setCreationTime(getCreationTime(resultSet));
        group.setCreatedBy(resultSet.getString("created_by"));
        return group;
    }

    /**
     * Reads the {@code creation_time} column of the current row, keeping the time of day.
     * <p>
     * {@link ResultSet#getTimestamp} is used rather than {@link ResultSet#getDate}: a
     * {@code java.sql.Date} is normalized to midnight, which silently dropped the
     * hours/minutes/seconds of the creation time.
     *
     * @return the creation time, or {@code null} if the column is SQL NULL
     */
    private static Date getCreationTime(ResultSet resultSet) throws SQLException {
        // java.sql.Timestamp is a subclass of java.util.Date
        Date timestamp = resultSet.getTimestamp("creation_time");
        return timestamp == null ? null : new Date(timestamp.getTime());
    }

    /**
     * Returns every group the user is a direct member of.
     *
     * @param userId id of the user
     * @return the groups joined via {@code gms_membership}, in database order
     */
    public List<GroupEntity> getUserMemberships(String userId) {
        return getUserMemberships(userId, null);
    }

    /**
     * Returns the groups the user is a member of, optionally restricted to a subtree.
     * <p>
     * When {@code parentPath} is non-null, only groups whose {@code path} satisfies
     * {@code path <@ parentPath} and differs from {@code parentPath} are returned, ordered by
     * {@code nlevel(path)} descending. This presumably means the strict descendants of
     * {@code parentPath} with the deepest first; it assumes {@code path} is a PostgreSQL
     * ltree (TODO confirm against the schema). When {@code parentPath} is null, no
     * ordering is applied.
     *
     * @param userId     id of the user
     * @param parentPath optional path used to filter the groups; may be null
     * @return the matching groups; empty if none
     */
    public List<GroupEntity> getUserMemberships(String userId, String parentPath) {

        String sql = "SELECT g.id, g.name, g.path, g.is_leaf, g.creation_time, g.created_by "
                + " FROM gms_membership m "
                + " JOIN gms_group g ON g.id = m.group_id"
                + " WHERE m.user_id = ?";
        if (parentPath != null) {
            sql += " AND g.path <@ ? AND g.path <> ? ORDER BY nlevel(g.path) DESC";
        }

        String query = sql;
        return jdbcTemplate.query(conn -> {
            PreparedStatement ps = conn.prepareStatement(query);
            int i = 0;
            ps.setString(++i, userId);
            if (parentPath != null) {
                // Types.OTHER lets the driver pass the string to the non-varchar path type.
                ps.setObject(++i, parentPath, Types.OTHER);
                ps.setObject(++i, parentPath, Types.OTHER);
            }
            return ps;
        }, resultSet -> {
            List<GroupEntity> memberships = new ArrayList<>();
            while (resultSet.next()) {
                memberships.add(getGroupEntityFromResultSet(resultSet));
            }
            return memberships;
        });
    }

    /**
     * Tells whether the user is a member of the group.
     *
     * @param userId  id of the user
     * @param groupId id of the group
     * @return true if at least one matching membership row exists
     */
    public boolean isMemberOf(String userId, String groupId) {

        String sql = "SELECT COUNT(*) FROM gms_membership "
                + " WHERE user_id = ? AND group_id = ?";

        return jdbcTemplate.query(conn -> {
            PreparedStatement ps = conn.prepareStatement(sql);
            ps.setString(1, userId);
            ps.setString(2, groupId);
            return ps;
        }, resultSet -> {
            // COUNT(*) yields one row; any positive count means "member" (not only exactly 1).
            return resultSet.next() && resultSet.getInt(1) > 0;
        });
    }

    /**
     * Inserts a membership; an existing (group_id, user_id) pair is left untouched.
     * <p>
     * {@code creation_time} is not inserted here and is not populated on the returned
     * object.
     *
     * @param membership the membership to add (group id, user id, created-by)
     * @return the same {@code membership} instance that was passed in
     */
    public MembershipEntity addMember(MembershipEntity membership) {

        String sql = "INSERT INTO gms_membership (group_id, user_id, created_by) VALUES (?, ?, ?)\n"
                + "ON CONFLICT (group_id, user_id) DO NOTHING";

        jdbcTemplate.update(conn -> {
            PreparedStatement ps = conn.prepareStatement(sql);
            ps.setString(1, membership.getGroupId());
            ps.setString(2, membership.getUserId());
            ps.setString(3, membership.getCreatedBy());
            return ps;
        });

        return membership;
    }

    /**
     * Deletes the membership of the user in the group, if present.
     *
     * @param groupId id of the group
     * @param userId  id of the user
     */
    public void removeMembership(String groupId, String userId) {

        String sql = "DELETE FROM gms_membership WHERE group_id = ? AND user_id = ?";

        jdbcTemplate.update(conn -> {
            PreparedStatement ps = conn.prepareStatement(sql);
            ps.setString(1, groupId);
            ps.setString(2, userId);
            return ps;
        });
    }

    /**
     * Deletes every membership of the given groups. Does nothing if the list is empty.
     *
     * @param groupIds ids of the groups; must not be null
     */
    public void deleteAllGroupsMembership(List<String> groupIds) {

        if (groupIds.isEmpty()) {
            // "IN ()" is not valid SQL, so skip the statement entirely.
            return;
        }

        String sql = "DELETE FROM gms_membership WHERE group_id IN ("
                + placeholders(groupIds.size())
                + ")";

        jdbcTemplate.update(conn -> {
            PreparedStatement ps = conn.prepareStatement(sql);
            bindStrings(ps, groupIds);
            return ps;
        });
    }

    /** Builds a comma-separated list of {@code count} JDBC placeholders, e.g. {@code "?,?,?"}. */
    private static String placeholders(int count) {
        return String.join(",", Collections.nCopies(count, "?"));
    }

    /** Binds {@code values} as strings to parameters 1..n of {@code ps}, in iteration order. */
    private static void bindStrings(PreparedStatement ps, List<String> values) throws SQLException {
        int i = 0;
        for (String value : values) {
            ps.setString(++i, value);
        }
    }
}
