package it.inaf.ia2.gms.authn;

import it.inaf.ia2.gms.persistence.ClientsDAO;
import it.inaf.ia2.gms.persistence.LoggingDAO;
import it.inaf.ia2.gms.persistence.model.ClientEntity;
import java.util.Collections;
import java.util.Optional;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
 * Unit tests for {@link ServiceBasicAuthFilter}.
 *
 * <p>The filter under test is a Mockito spy whose {@code getClientsDAO} method
 * is stubbed to return a mocked {@link ClientsDAO} that knows a single client
 * ({@code "test"}). The servlet request, response and filter chain are plain
 * mocks, so each test only configures the servlet path and, where needed, the
 * {@code Authorization} header before invoking {@code doFilter}.
 */
@RunWith(JUnit4.class)
public class ServiceBasicAuthFilterTest {

    /** Servlet path for which the tests expect credentials to be checked. */
    private static final String PROTECTED_PATH = "/ws/basic/group";

    /** Servlet path for which the tests expect the request to pass through untouched. */
    private static final String UNPROTECTED_PATH = "/other/path";

    /** Name of the HTTP header stubbed on the mocked request. */
    private static final String AUTHORIZATION_HEADER = "Authorization";

    private ServiceBasicAuthFilter filter;
    private HttpServletRequest request;
    private HttpServletResponse response;
    private FilterChain chain;

    /**
     * Builds a fresh spy filter and fresh request/response/chain mocks before
     * every test.
     */
    @Before
    public void setUp() {
        ClientEntity client = new ClientEntity();
        client.setId("test");
        /* The secret below is the SHA-256 hex digest of "password". */
        client.setSecret("5e884898da28047151d0e56f8dc6292773603d0d6aabbdd62a11ef721d1542d8");
        client.setAllowedActions(Collections.singletonList("*"));

        ClientsDAO clientsDAO = mock(ClientsDAO.class);
        when(clientsDAO.findClientById("test")).thenReturn(Optional.of(client));

        LoggingDAO loggingDAO = mock(LoggingDAO.class);

        // doReturn(..).when(spy) is used so that the real getClientsDAO is
        // never invoked while the stub is being registered.
        filter = spy(new ServiceBasicAuthFilter(loggingDAO));
        doReturn(clientsDAO).when(filter).getClientsDAO(any());

        request = mock(HttpServletRequest.class);
        response = mock(HttpServletResponse.class);
        chain = mock(FilterChain.class);
    }

    /**
     * A request on the protected path carrying the known client's credentials
     * must be forwarded to the filter chain exactly once.
     */
    @Test
    public void testValidCredentials() throws Exception {
        when(request.getServletPath()).thenReturn(PROTECTED_PATH);
        /* Base64 of "test:password". */
        when(request.getHeader(AUTHORIZATION_HEADER)).thenReturn("Basic dGVzdDpwYXNzd29yZA==");

        filter.doFilter(request, response, chain);

        verify(chain, times(1)).doFilter(any(), any());
    }

    /**
     * A request on the protected path with a wrong password must be answered
     * with 401 and must not reach the filter chain.
     */
    @Test
    public void testInvalidCredentials() throws Exception {
        when(request.getServletPath()).thenReturn(PROTECTED_PATH);
        /* Base64 of "test:test". */
        when(request.getHeader(AUTHORIZATION_HEADER)).thenReturn("Basic dGVzdDp0ZXN0");

        filter.doFilter(request, response, chain);

        verify(response, times(1)).sendError(eq(HttpServletResponse.SC_UNAUTHORIZED), any());
        verify(chain, never()).doFilter(any(), any());
    }

    /**
     * A request on the protected path without an {@code Authorization} header
     * (the unstubbed mock returns {@code null}) must be answered with 401 and
     * must not reach the filter chain.
     */
    @Test
    public void testMissingHeader() throws Exception {
        when(request.getServletPath()).thenReturn(PROTECTED_PATH);

        filter.doFilter(request, response, chain);

        verify(response, times(1)).sendError(eq(HttpServletResponse.SC_UNAUTHORIZED), any());
        verify(chain, never()).doFilter(any(), any());
    }

    /**
     * A request on a path other than the protected one must be forwarded to
     * the filter chain without any error being sent, even with no credentials.
     */
    @Test
    public void testOutsidePath() throws Exception {
        when(request.getServletPath()).thenReturn(UNPROTECTED_PATH);

        filter.doFilter(request, response, chain);

        verify(response, never()).sendError(anyInt(), any());
        verify(chain, times(1)).doFilter(any(), any());
    }
}