Commit fd37936e authored by Sonia Zorba's avatar Sonia Zorba
Browse files

RapClient error handling improvements

parent dd8cfb0f
package it.inaf.ia2.gms.rap; package it.inaf.ia2.gms.rap;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import it.inaf.ia2.gms.authn.SessionData; import it.inaf.ia2.gms.authn.SessionData;
import it.inaf.ia2.gms.model.RapUser; import it.inaf.ia2.gms.model.RapUser;
import java.util.ArrayList; import java.util.ArrayList;
...@@ -21,6 +23,8 @@ import org.springframework.stereotype.Component; ...@@ -21,6 +23,8 @@ import org.springframework.stereotype.Component;
import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap; import org.springframework.util.MultiValueMap;
import org.springframework.web.client.HttpClientErrorException; import org.springframework.web.client.HttpClientErrorException;
import org.springframework.web.client.HttpServerErrorException;
import org.springframework.web.client.HttpStatusCodeException;
import org.springframework.web.client.RestTemplate; import org.springframework.web.client.RestTemplate;
@Component @Component
...@@ -51,6 +55,8 @@ public class RapClient { ...@@ -51,6 +55,8 @@ public class RapClient {
private final RestTemplate refreshTokenRestTemplate; private final RestTemplate refreshTokenRestTemplate;
private final ObjectMapper objectMapper = new ObjectMapper();
@Autowired @Autowired
public RapClient(RestTemplate rapRestTemplate) { public RapClient(RestTemplate rapRestTemplate) {
this.rapRestTemplate = rapRestTemplate; this.rapRestTemplate = rapRestTemplate;
...@@ -101,14 +107,30 @@ public class RapClient { ...@@ -101,14 +107,30 @@ public class RapClient {
private <R, T> R httpCall(Function<HttpEntity<?>, R> function, T body) { private <R, T> R httpCall(Function<HttpEntity<?>, R> function, T body) {
try { try {
return function.apply(getEntity(body)); try {
} catch (HttpClientErrorException.Unauthorized ex) { return function.apply(getEntity(body));
if (request.getSession(false) == null) { } catch (HttpClientErrorException.Unauthorized ex) {
// we can't refresh the token without a session if (request.getSession(false) == null || sessionData.getExpiresIn() > 0) {
throw ex; // we can't refresh the token without a session
throw ex;
}
refreshToken();
return function.apply(getEntity(body));
}
} catch (HttpStatusCodeException ex) {
try {
Map<String, String> map = objectMapper.readValue(ex.getResponseBodyAsString(), Map.class);
if (map.containsKey("error")) {
String error = map.get("error");
if (ex instanceof HttpClientErrorException) {
throw new HttpClientErrorException(ex.getStatusCode(), error);
} else if (ex instanceof HttpServerErrorException) {
throw new HttpServerErrorException(ex.getStatusCode(), error);
}
}
} catch (JsonProcessingException ignore) {
} }
refreshToken(); throw ex;
return function.apply(getEntity(body));
} }
} }
......
package it.inaf.ia2.gms.rap;
import it.inaf.ia2.gms.authn.SessionData;
import it.inaf.ia2.gms.model.RapUser;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import org.mockito.Mock;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import org.mockito.junit.MockitoJUnitRunner;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.web.client.HttpClientErrorException;
import org.springframework.web.client.HttpClientErrorException.Unauthorized;
import org.springframework.web.client.HttpServerErrorException;
import org.springframework.web.client.HttpServerErrorException.InternalServerError;
import org.springframework.web.client.RestTemplate;
@RunWith(MockitoJUnitRunner.class)
public class RapClientTest {
@Mock
private HttpServletRequest request;
@Mock
private SessionData sessionData;
@Mock
private RestTemplate restTemplate;
@Mock
private RestTemplate refreshTokenRestTemplate;
private RapClient rapClient;
@Before
public void init() {
rapClient = new RapClient(restTemplate);
ReflectionTestUtils.setField(rapClient, "request", request);
ReflectionTestUtils.setField(rapClient, "refreshTokenRestTemplate", refreshTokenRestTemplate);
ReflectionTestUtils.setField(rapClient, "scope", "openid");
}
@Test
public void testUnauthorizedNoRefreshJsonMsg() {
String jsonError = "{\"error\":\"Unauthorized: foo\"}";
HttpClientErrorException exception = Unauthorized
.create(HttpStatus.UNAUTHORIZED, "401", HttpHeaders.EMPTY, jsonError.getBytes(), StandardCharsets.UTF_8);
when(restTemplate.exchange(anyString(), eq(HttpMethod.GET), any(HttpEntity.class), eq(new ParameterizedTypeReference<RapUser>() {
}))).thenThrow(exception);
try {
rapClient.getUser("123");
} catch (HttpClientErrorException ex) {
assertEquals("401 Unauthorized: foo", ex.getMessage());
}
}
@Test
public void testUnauthorizedNoRefreshNotJsonMsg() {
String errorMessage = "THIS IS NOT A JSON";
HttpClientErrorException exception = Unauthorized
.create(HttpStatus.UNAUTHORIZED, "401", HttpHeaders.EMPTY, errorMessage.getBytes(), StandardCharsets.UTF_8);
when(restTemplate.exchange(anyString(), eq(HttpMethod.GET), any(HttpEntity.class), eq(new ParameterizedTypeReference<RapUser>() {
}))).thenThrow(exception);
try {
rapClient.getUser("123");
} catch (HttpClientErrorException ex) {
assertNotNull(ex.getMessage());
}
}
@Test
public void testServerErrorJsonMsg() {
String jsonError = "{\"error\":\"Fatal error\"}";
HttpServerErrorException exception = InternalServerError
.create(HttpStatus.INTERNAL_SERVER_ERROR, "500", HttpHeaders.EMPTY, jsonError.getBytes(), StandardCharsets.UTF_8);
when(restTemplate.exchange(anyString(), eq(HttpMethod.GET), any(HttpEntity.class), eq(new ParameterizedTypeReference<RapUser>() {
}))).thenThrow(exception);
try {
rapClient.getUser("123");
} catch (HttpServerErrorException ex) {
assertEquals("500 Fatal error", ex.getMessage());
}
}
@Test
public void testRefreshToken() {
when(request.getSession(eq(false))).thenReturn(mock(HttpSession.class));
when(sessionData.getExpiresIn()).thenReturn(-100l);
ReflectionTestUtils.setField(rapClient, "sessionData", sessionData);
ReflectionTestUtils.setField(rapClient, "clientId", "clientId");
ReflectionTestUtils.setField(rapClient, "clientSecret", "clientSecret");
ReflectionTestUtils.setField(rapClient, "accessTokenUri", "https://sso.ia2.inaf.it");
String jsonError = "{\"error\":\"Unauthorized: token expired\"}";
HttpClientErrorException exception = Unauthorized
.create(HttpStatus.UNAUTHORIZED, "401", HttpHeaders.EMPTY, jsonError.getBytes(), StandardCharsets.UTF_8);
when(restTemplate.exchange(anyString(), eq(HttpMethod.GET), any(HttpEntity.class), eq(new ParameterizedTypeReference<RapUser>() {
}))).thenThrow(exception)
.thenReturn(ResponseEntity.ok(new RapUser()));
ResponseEntity refreshTokenResponse = mock(ResponseEntity.class);
Map<String, Object> mockedBody = new HashMap<>();
mockedBody.put("access_token", "<access_token>");
mockedBody.put("refresh_token", "<refresh_token>");
mockedBody.put("expires_in", 3600);
when(refreshTokenResponse.getBody()).thenReturn(mockedBody);
when(refreshTokenRestTemplate.postForEntity(anyString(), any(HttpEntity.class), any()))
.thenReturn(refreshTokenResponse);
RapUser user = rapClient.getUser("123");
assertNotNull(user);
// verifies that token is refreshed
verify(sessionData, times(1)).setAccessToken(eq("<access_token>"));
verify(sessionData, times(1)).setExpiresIn(eq(3600));
}
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment