Newer
Older
import it.inaf.ia2.gms.authn.SessionData;
import it.inaf.ia2.gms.model.RapUser;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import org.apache.commons.codec.binary.Base64;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.HttpClientErrorException;
import org.springframework.web.client.RestTemplate;
/** Base URL of the RAP web service; the user endpoints ("/user?...") are appended to it. */
@Value("${rap.ws-url}")
private String rapBaseUrl;
/** Token endpoint that receives the refresh_token grant in refreshToken(). */
@Value("${security.oauth2.client.access-token-uri}")
private String accessTokenUri;
/** OAuth2 client id; also used as the HTTP Basic username. */
@Value("${security.oauth2.client.client-id}")
private String clientId;
/** OAuth2 client secret; also used as the HTTP Basic password. */
@Value("${security.oauth2.client.client-secret}")
private String clientSecret;
/** Requested scopes; commas are replaced by spaces when sent with the refresh request. */
@Value("${security.oauth2.client.scope}")
private String scope;
/*
 * Use basic auth instead of JWT when asking for users: when true, RAP calls carry the
 * client id/secret as HTTP Basic credentials instead of the session's bearer access token.
 */
@Value("${rap.ws.basic-auth}")
private boolean basicAuth;
/** Source of the bearer access token and refresh token; updated after a token refresh. */
private final SessionData sessionData;
/** Injected template used for the calls to the RAP user endpoints. */
private final RestTemplate rapRestTemplate;
/** Plain RestTemplate created by the constructor, used only for the token refresh request. */
private final RestTemplate refreshTokenRestTemplate;
/**
 * Creates the RAP client.
 *
 * <p>The injected template is used for every RAP user call. Token refreshes go through a
 * separate, default-configured {@link RestTemplate}, so the refresh request is independent of
 * whatever customization the injected template carries.
 *
 * @param sessionData holder of the access/refresh tokens used and updated by this client
 * @param rapRestTemplate template used for the RAP user endpoints
 */
@Autowired
public RapClient(SessionData sessionData, RestTemplate rapRestTemplate) {
    this.refreshTokenRestTemplate = new RestTemplate();
    this.rapRestTemplate = rapRestTemplate;
    this.sessionData = sessionData;
}
public List<RapUser> getUsers(Set<String> identifiers) {
if (identifiers.isEmpty()) {
return new ArrayList<>();
}
String url = rapBaseUrl + "/user?identifiers=" + String.join(",", identifiers);
return httpCall(entity -> {
return rapRestTemplate.exchange(url, HttpMethod.GET, entity, new ParameterizedTypeReference<List<RapUser>>() {
}).getBody();
});
/**
 * Searches RAP users by free text.
 *
 * <p>{@code searchText} comes from the caller and is sent as the {@code search} query parameter.
 * It is passed as a URI template variable rather than concatenated into the URL, so it is
 * encoded as a query-parameter value. Without this, curly braces in the text make RestTemplate
 * fail during template expansion, and ampersand, equals or hash would inject parameters or
 * truncate the query.
 *
 * <p>NOTE(review): with RestTemplate's default URI_COMPONENT encoding, a literal '+' is still
 * sent unencoded, and servers commonly decode it as a space. Confirm whether RAP needs it
 * escaped.
 *
 * @param searchText text to search for; {@code null} or blank yields no call
 * @return the matching users, or a new empty list when {@code searchText} is null or blank
 */
public List<RapUser> searchUsers(String searchText) {
    if (searchText == null || searchText.trim().isEmpty()) {
        return new ArrayList<>();
    }
    String url = rapBaseUrl + "/user?search={search}";
    return httpCall(entity -> rapRestTemplate.exchange(url, HttpMethod.GET, entity,
            new ParameterizedTypeReference<List<RapUser>>() {
            }, searchText).getBody());
}
private <R> R httpCall(Function<HttpEntity<?>, R> function) {
return httpCall(function, null);
/**
 * Builds an authenticated entity wrapping {@code body}, hands it to {@code function} and
 * returns the result.
 *
 * <p>If the call fails with HTTP 401 while the session's bearer token is in use, the token is
 * refreshed once and the call is retried with a newly built entity. In basic-auth mode the 401
 * is rethrown unchanged. The request carries the static client credentials, so refreshing the
 * session token cannot change the outcome. The refresh attempt could also replace the
 * original 401 with an error from the token endpoint.
 *
 * @param function performs the HTTP call with the prepared entity
 * @param body request body, or {@code null} for none
 * @return whatever {@code function} returns
 * @throws HttpClientErrorException.Unauthorized if basic-auth credentials are rejected, or if
 *     the retry after a token refresh is rejected as well
 */
private <R, T> R httpCall(Function<HttpEntity<?>, R> function, T body) {
    try {
        return function.apply(getEntity(body));
    } catch (HttpClientErrorException.Unauthorized ex) {
        if (basicAuth) {
            // Refreshing the session token cannot fix rejected client credentials.
            throw ex;
        }
        refreshToken();
        return function.apply(getEntity(body));
    }
}
/**
 * Wraps {@code body} in an entity that accepts JSON and carries the RAP credentials.
 *
 * <p>In basic-auth mode the client id and secret are set with
 * {@link HttpHeaders#setBasicAuth(String, String)}. That is the same encoding the token-refresh
 * request uses, so both requests send identical credential bytes. The previous hand-rolled
 * {@code getBytes()} call depended on the JVM's default charset. Otherwise the session's access
 * token is sent as a bearer token.
 *
 * @param body request body, or {@code null} for none
 * @return a new entity holding {@code body} and the prepared headers
 */
private <T> HttpEntity<T> getEntity(T body) {
    HttpHeaders headers = new HttpHeaders();
    headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON));
    if (basicAuth) {
        headers.setBasicAuth(clientId, clientSecret);
    } else {
        headers.setBearerAuth(sessionData.getAccessToken());
    }
    return new HttpEntity<>(body, headers);
}
/**
 * Exchanges the session's refresh token for new tokens at {@code accessTokenUri}.
 * The request authenticates with the client id/secret. The access token, refresh token and
 * expires_in values from the response are stored back into {@code sessionData}.
 */
public void refreshToken() {
HttpHeaders headers = new HttpHeaders();
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON));
// Client credentials authenticate the refresh request itself.
headers.setBasicAuth(clientId, clientSecret);
headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
MultiValueMap<String, String> map = new LinkedMultiValueMap<>();
map.add("grant_type", "refresh_token");
map.add("refresh_token", sessionData.getRefreshToken());
// Configured scopes contain commas; the form parameter uses space-separated scopes.
map.add("scope", scope.replace(",", " "));
HttpEntity<MultiValueMap<String, String>> request = new HttpEntity<>(map, headers);
ResponseEntity<Map> response = refreshTokenRestTemplate.postForEntity(accessTokenUri, request, Map.class);
// NOTE(review): getBody() may be null for an empty response body, which would NPE below.
Map<String, Object> values = response.getBody();
sessionData.setAccessToken((String) values.get("access_token"));
// NOTE(review): if the token endpoint does not return a new refresh_token, this stores null
// and later refreshes would send no refresh token. Confirm the server always rotates it.
sessionData.setRefreshToken((String) values.get("refresh_token"));
// NOTE(review): the (int) cast only works when the parsed value is an Integer; a Long or a
// String would throw ClassCastException. Confirm the JSON mapper's number type.
sessionData.setExpiresIn((int) values.get("expires_in"));