Commit fd37936e authored by Sonia Zorba's avatar Sonia Zorba
Browse files

RapClient error handling improvements

parent dd8cfb0f
Loading
Loading
Loading
Loading
+29 −7
Original line number Diff line number Diff line
package it.inaf.ia2.gms.rap;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import it.inaf.ia2.gms.authn.SessionData;
import it.inaf.ia2.gms.model.RapUser;
import java.util.ArrayList;
@@ -21,6 +23,8 @@ import org.springframework.stereotype.Component;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.HttpClientErrorException;
import org.springframework.web.client.HttpServerErrorException;
import org.springframework.web.client.HttpStatusCodeException;
import org.springframework.web.client.RestTemplate;

@Component
@@ -51,6 +55,8 @@ public class RapClient {

    private final RestTemplate refreshTokenRestTemplate;

    private final ObjectMapper objectMapper = new ObjectMapper();

    @Autowired
    public RapClient(RestTemplate rapRestTemplate) {
        this.rapRestTemplate = rapRestTemplate;
@@ -100,16 +106,32 @@ public class RapClient {
    }

    private <R, T> R httpCall(Function<HttpEntity<?>, R> function, T body) {
        try {
            try {
                return function.apply(getEntity(body));
            } catch (HttpClientErrorException.Unauthorized ex) {
            if (request.getSession(false) == null) {
                if (request.getSession(false) == null || sessionData.getExpiresIn() > 0) {
                    // we can't refresh the token without a session
                    throw ex;
                }
                refreshToken();
                return function.apply(getEntity(body));
            }
        } catch (HttpStatusCodeException ex) {
            try {
                Map<String, String> map = objectMapper.readValue(ex.getResponseBodyAsString(), Map.class);
                if (map.containsKey("error")) {
                    String error = map.get("error");
                    if (ex instanceof HttpClientErrorException) {
                        throw new HttpClientErrorException(ex.getStatusCode(), error);
                    } else if (ex instanceof HttpServerErrorException) {
                        throw new HttpServerErrorException(ex.getStatusCode(), error);
                    }
                }
            } catch (JsonProcessingException ignore) {
            }
            throw ex;
        }
    }

    private <T> HttpEntity<T> getEntity(T body) {
+153 −0
Original line number Diff line number Diff line
package it.inaf.ia2.gms.rap;

import it.inaf.ia2.gms.authn.SessionData;
import it.inaf.ia2.gms.model.RapUser;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import org.mockito.Mock;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.web.client.HttpClientErrorException;
import org.springframework.web.client.HttpClientErrorException.Unauthorized;
import org.springframework.web.client.HttpServerErrorException;
import org.springframework.web.client.HttpServerErrorException.InternalServerError;
import org.springframework.web.client.RestTemplate;

@RunWith(MockitoJUnitRunner.class)
public class RapClientTest {

    @Mock
    private HttpServletRequest request;

    @Mock
    private SessionData sessionData;

    @Mock
    private RestTemplate restTemplate;

    @Mock
    private RestTemplate refreshTokenRestTemplate;

    private RapClient rapClient;

    @Before
    public void init() {
        rapClient = new RapClient(restTemplate);
        ReflectionTestUtils.setField(rapClient, "request", request);
        ReflectionTestUtils.setField(rapClient, "refreshTokenRestTemplate", refreshTokenRestTemplate);
        ReflectionTestUtils.setField(rapClient, "scope", "openid");
    }

    @Test
    public void testUnauthorizedNoRefreshJsonMsg() {

        String jsonError = "{\"error\":\"Unauthorized: foo\"}";

        HttpClientErrorException exception = Unauthorized
                .create(HttpStatus.UNAUTHORIZED, "401", HttpHeaders.EMPTY, jsonError.getBytes(), StandardCharsets.UTF_8);

        when(restTemplate.exchange(anyString(), eq(HttpMethod.GET), any(HttpEntity.class), eq(new ParameterizedTypeReference<RapUser>() {
        }))).thenThrow(exception);

        try {
            rapClient.getUser("123");
        } catch (HttpClientErrorException ex) {
            assertEquals("401 Unauthorized: foo", ex.getMessage());
        }
    }

    @Test
    public void testUnauthorizedNoRefreshNotJsonMsg() {

        String errorMessage = "THIS IS NOT A JSON";

        HttpClientErrorException exception = Unauthorized
                .create(HttpStatus.UNAUTHORIZED, "401", HttpHeaders.EMPTY, errorMessage.getBytes(), StandardCharsets.UTF_8);

        when(restTemplate.exchange(anyString(), eq(HttpMethod.GET), any(HttpEntity.class), eq(new ParameterizedTypeReference<RapUser>() {
        }))).thenThrow(exception);

        try {
            rapClient.getUser("123");
        } catch (HttpClientErrorException ex) {
            assertNotNull(ex.getMessage());
        }
    }

    @Test
    public void testServerErrorJsonMsg() {

        String jsonError = "{\"error\":\"Fatal error\"}";

        HttpServerErrorException exception = InternalServerError
                .create(HttpStatus.INTERNAL_SERVER_ERROR, "500", HttpHeaders.EMPTY, jsonError.getBytes(), StandardCharsets.UTF_8);

        when(restTemplate.exchange(anyString(), eq(HttpMethod.GET), any(HttpEntity.class), eq(new ParameterizedTypeReference<RapUser>() {
        }))).thenThrow(exception);

        try {
            rapClient.getUser("123");
        } catch (HttpServerErrorException ex) {
            assertEquals("500 Fatal error", ex.getMessage());
        }
    }

    @Test
    public void testRefreshToken() {

        when(request.getSession(eq(false))).thenReturn(mock(HttpSession.class));
        when(sessionData.getExpiresIn()).thenReturn(-100l);

        ReflectionTestUtils.setField(rapClient, "sessionData", sessionData);
        ReflectionTestUtils.setField(rapClient, "clientId", "clientId");
        ReflectionTestUtils.setField(rapClient, "clientSecret", "clientSecret");
        ReflectionTestUtils.setField(rapClient, "accessTokenUri", "https://sso.ia2.inaf.it");

        String jsonError = "{\"error\":\"Unauthorized: token expired\"}";

        HttpClientErrorException exception = Unauthorized
                .create(HttpStatus.UNAUTHORIZED, "401", HttpHeaders.EMPTY, jsonError.getBytes(), StandardCharsets.UTF_8);

        when(restTemplate.exchange(anyString(), eq(HttpMethod.GET), any(HttpEntity.class), eq(new ParameterizedTypeReference<RapUser>() {
        }))).thenThrow(exception)
                .thenReturn(ResponseEntity.ok(new RapUser()));

        ResponseEntity refreshTokenResponse = mock(ResponseEntity.class);
        Map<String, Object> mockedBody = new HashMap<>();
        mockedBody.put("access_token", "<access_token>");
        mockedBody.put("refresh_token", "<refresh_token>");
        mockedBody.put("expires_in", 3600);
        when(refreshTokenResponse.getBody()).thenReturn(mockedBody);

        when(refreshTokenRestTemplate.postForEntity(anyString(), any(HttpEntity.class), any()))
                .thenReturn(refreshTokenResponse);

        RapUser user = rapClient.getUser("123");
        assertNotNull(user);

        // verifies that token is refreshed
        verify(sessionData, times(1)).setAccessToken(eq("<access_token>"));
        verify(sessionData, times(1)).setExpiresIn(eq(3600));
    }
}