Commit 359bf2b9 authored by Sonia Zorba's avatar Sonia Zorba
Browse files

Added BasicAuth/JWT service endpoints

parent fa4bf6f1
......@@ -13,8 +13,8 @@ public class CustomIdTokenConverter extends DefaultUserAuthenticationConverter {
private final JwkTokenStore jwkTokenStore;
public CustomIdTokenConverter(String keySetUri) {
this.jwkTokenStore = new JwkTokenStore(keySetUri);
public CustomIdTokenConverter(JwkTokenStore jwkTokenStore) {
this.jwkTokenStore = jwkTokenStore;
}
@Override
......
package it.inaf.ia2.gms.authn;
import java.io.IOException;
import java.security.Principal;
import java.util.Map;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.provider.token.store.jwk.JwkTokenStore;
public class JWTFilter implements Filter {
private final JwkTokenStore jwkTokenStore;
public JWTFilter(JwkTokenStore jwkTokenStore) {
this.jwkTokenStore = jwkTokenStore;
}
@Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain fc) throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) req;
HttpServletResponse response = (HttpServletResponse) res;
String authHeader = request.getHeader("Authorization");
if (authHeader == null) {
response.sendError(HttpServletResponse.SC_UNAUTHORIZED, "Missing Authorization token");
return;
}
authHeader = authHeader.replace("Bearer", "").trim();
OAuth2AccessToken accessToken = jwkTokenStore.readAccessToken(authHeader);
if (accessToken.isExpired()) {
response.sendError(HttpServletResponse.SC_UNAUTHORIZED, "Access token is expired");
return;
}
Map<String, Object> claims = accessToken.getAdditionalInformation();
String principal = (String) claims.get("sub");
if (principal == null) {
response.sendError(HttpServletResponse.SC_UNAUTHORIZED, "Invalid access token: missing sub claim");
return;
}
ServletRequest wrappedRequest = new ServletRequestWithJWTPrincipal(request, principal);
fc.doFilter(wrappedRequest, res);
}
private static class ServletRequestWithJWTPrincipal extends HttpServletRequestWrapper {
private final String principal;
public ServletRequestWithJWTPrincipal(HttpServletRequest request, String principal) {
super(request);
this.principal = principal;
}
@Override
public Principal getUserPrincipal() {
return new UsernamePasswordAuthenticationToken(principal, null);
}
}
}
......@@ -15,6 +15,7 @@ import org.springframework.security.oauth2.provider.ClientDetailsService;
import org.springframework.security.oauth2.provider.client.InMemoryClientDetailsService;
import org.springframework.security.oauth2.provider.token.DefaultAccessTokenConverter;
import org.springframework.security.oauth2.provider.token.RemoteTokenServices;
import org.springframework.security.oauth2.provider.token.store.jwk.JwkTokenStore;
import org.springframework.web.client.RestTemplate;
/**
......@@ -24,9 +25,6 @@ import org.springframework.web.client.RestTemplate;
@Configuration
public class OAuth2Config extends AuthorizationServerEndpointsConfiguration {
@Value("${security.oauth2.resource.jwk.key-set-uri}")
private String keySetUri;
@Value("${security.oauth2.resource.token-info-uri}")
private String checkTokenEndpointUrl;
......@@ -37,11 +35,11 @@ public class OAuth2Config extends AuthorizationServerEndpointsConfiguration {
private String clientSecret;
@Bean
public RemoteTokenServices resourceServerTokenServices() {
public RemoteTokenServices resourceServerTokenServices(JwkTokenStore jwkTokenStore) {
RemoteTokenServices tokenService = new RemoteTokenServices();
DefaultAccessTokenConverter accessTokenConverter = new DefaultAccessTokenConverter();
accessTokenConverter.setUserTokenConverter(new CustomIdTokenConverter(keySetUri));
accessTokenConverter.setUserTokenConverter(new CustomIdTokenConverter(jwkTokenStore));
tokenService.setAccessTokenConverter(accessTokenConverter);
tokenService.setCheckTokenEndpointUrl(checkTokenEndpointUrl);
......
......@@ -14,6 +14,7 @@ import org.springframework.http.HttpMethod;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.builders.WebSecurity;
import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
import org.springframework.security.oauth2.provider.token.store.jwk.JwkTokenStore;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.cors.UrlBasedCorsConfigurationSource;
import org.springframework.web.filter.CorsFilter;
......@@ -28,6 +29,14 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter {
@Value("${cors.allowed.origin}")
private String corsAllowedOrigin;
@Value("${security.oauth2.resource.jwk.key-set-uri}")
private String keySetUri;
@Bean
public JwkTokenStore jwkTokenStore() {
return new JwkTokenStore(keySetUri);
}
@Override
public void configure(HttpSecurity http) throws Exception {
......@@ -43,23 +52,36 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter {
}
/**
* The authentication is ignored for these endpoints. The "/ws" endpoints
* (web service API for programmatic access) are protected by the custom
* WebServiceAuthorizationFilter that checks BasicAuth for GMS clients.
* The authentication is ignored for these endpoints. The "/ws/basic"
* endpoints (web service API for programmatic access) are protected by the
* custom ServiceBasicAuthFilter that checks BasicAuth for GMS clients,
* while the "/ws/jwt" endpoints are protected by the JWTFilter.
*/
@Override
public void configure(WebSecurity web) throws Exception {
web.ignoring().antMatchers("/ws/**", "/error");
web.ignoring().antMatchers("/ws/basic/**", "/ws/jwt/**", "/error");
}
/**
* Checks the BasicAuth for GMS clients.
*/
@Bean
public FilterRegistrationBean webServiceAuthorizationFilter() {
public FilterRegistrationBean serviceBasicAuthFilter() {
FilterRegistrationBean bean = new FilterRegistrationBean();
bean.setFilter(new ServiceBasicAuthFilter());
bean.addUrlPatterns("/ws/basic/*");
bean.setOrder(Ordered.HIGHEST_PRECEDENCE);
return bean;
}
/**
* Checks JWT for web services.
*/
@Bean
public FilterRegistrationBean serviceJWTFilter(JwkTokenStore jwkTokenStore) {
FilterRegistrationBean bean = new FilterRegistrationBean();
bean.setFilter(new WebServiceAuthorizationFilter());
bean.addUrlPatterns("/ws/*");
bean.setFilter(new JWTFilter(jwkTokenStore));
bean.addUrlPatterns("/ws/jwt/*");
bean.setOrder(Ordered.HIGHEST_PRECEDENCE);
return bean;
}
......
......@@ -20,7 +20,7 @@ import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.WebApplicationContextUtils;
public class WebServiceAuthorizationFilter implements Filter {
public class ServiceBasicAuthFilter implements Filter {
@Override
public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {
......
......@@ -15,6 +15,7 @@ public class SessionData {
private HttpServletRequest request;
private String userId;
private String userName;
private String accessToken;
private String refreshToken;
......@@ -22,6 +23,7 @@ public class SessionData {
public void init() {
CustomAuthenticationData authn = (CustomAuthenticationData) ((OAuth2Authentication) request.getUserPrincipal()).getUserAuthentication();
userId = (String) authn.getPrincipal();
userName = (String) authn.getAttributes().get("name");
accessToken = (String) authn.getAccessToken().getValue();
refreshToken = authn.getRefreshToken();
}
......@@ -45,4 +47,8 @@ public class SessionData {
public void setRefreshToken(String refreshToken) {
this.refreshToken = refreshToken;
}
public String getUserName() {
return userName;
}
}
......@@ -25,9 +25,12 @@ import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
/**
* Controller for programmatic access using registered clients.
*/
@RestController
@RequestMapping("/ws")
public class WebServiceController {
@RequestMapping("/ws/basic")
public class BasicAuthWebServiceController {
@Autowired
private GroupsService groupsService;
......
......@@ -27,7 +27,7 @@ public class HomePageController {
HomePageResponse response = new HomePageResponse();
response.setUser(session.getUserId());
response.setUser(session.getUserName());
GroupsTabResponse groupsTabResponse = groupsTabResponseBuilder.getGroupsTab(request);
response.setBreadcrumbs(groupsTabResponse.getBreadcrumbs());
......
package it.inaf.ia2.gms.controller;
import it.inaf.ia2.gms.persistence.GroupsDAO;
import it.inaf.ia2.gms.persistence.MembershipsDAO;
import it.inaf.ia2.gms.persistence.model.GroupEntity;
import java.io.IOException;
import java.io.PrintWriter;
import java.security.Principal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.servlet.http.HttpServletResponse;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
/**
* Web service called by other web applications using JWT (delegation).
*/
@RestController
@RequestMapping("/ws/jwt")
public class JWTWebServiceController {
@Autowired
private MembershipsDAO membershipsDAO;
@Autowired
private GroupsDAO groupsDAO;
/**
* This endpoint is compliant with the IVOA GMS standard.
*/
@GetMapping(value = "/search", produces = MediaType.TEXT_PLAIN_VALUE)
public void getGroups(Principal principal, HttpServletResponse response) throws IOException {
List<GroupEntity> memberships = membershipsDAO.getUserMemberships(principal.getName());
// We need to return the complete group name, so it is necessary to load
// all the parents too.
Map<String, String> idNameMap = new HashMap<>();
Set<String> allIdentifiers = getAllIdentifiers(memberships);
for (GroupEntity group : groupsDAO.findGroupsByIds(allIdentifiers)) {
idNameMap.put(group.getId(), group.getName());
}
try (PrintWriter pw = new PrintWriter(response.getOutputStream())) {
for (GroupEntity group : memberships) {
pw.println(getGroupCompleteName(group, idNameMap));
}
}
}
private Set<String> getAllIdentifiers(List<GroupEntity> groups) {
Set<String> allIdentifiers = new HashSet<>();
for (GroupEntity group : groups) {
if (!"".equals(group.getPath())) {
String[] ids = group.getPath().split("\\.");
for (String id : ids) {
allIdentifiers.add(id);
}
}
}
return allIdentifiers;
}
private String getGroupCompleteName(GroupEntity group, Map<String, String> idNameMap) {
List<String> names = new ArrayList<>();
for (String groupId : group.getPath().split("\\.")) {
String groupName = idNameMap.get(groupId);
// Dot inside names is considered a special character (because it is
// used to separate the group from its parents), so we use a
// backslash to escape it (client apps need to be aware of this).
groupName = groupName.replace("\\.", "\\\\.");
names.add(groupName);
}
return String.join(".", names);
}
}
......@@ -103,6 +103,37 @@ public class GroupsDAO {
});
}
public List<GroupEntity> findGroupsByIds(Set<String> identifiers) {
if (identifiers.isEmpty()) {
return new ArrayList<>();
}
return jdbcTemplate.query(conn -> {
String sql = "SELECT id, name, path from gms_group WHERE id IN (";
sql += String.join(",", identifiers.stream().map(p -> "?").collect(Collectors.toList()));
sql += ")";
PreparedStatement ps = conn.prepareStatement(sql);
int i = 0;
for (String id : identifiers) {
ps.setString(++i, id);
}
return ps;
}, resultSet -> {
List<GroupEntity> groups = new ArrayList<>();
while (resultSet.next()) {
GroupEntity group = new GroupEntity();
group.setId(resultSet.getString("id"));
group.setName(resultSet.getString("name"));
group.setPath(resultSet.getString("path"));
groups.add(group);
}
return groups;
});
}
public Optional<GroupEntity> findGroupByParentAndName(String parentPath, String childName) {
String sql = "SELECT id, path from gms_group WHERE name = ? AND path ~ ?";
......
package it.inaf.ia2.gms.persistence;
import it.inaf.ia2.gms.persistence.model.GroupEntity;
import it.inaf.ia2.gms.persistence.model.MembershipEntity;
import java.sql.PreparedStatement;
import java.util.ArrayList;
......@@ -40,6 +41,47 @@ public class MembershipsDAO {
});
}
public List<GroupEntity> getUserMemberships(String userId) {
String sql = "SELECT g.id, g.name, g.path FROM "
+ " gms_membership m "
+ " JOIN gms_group g ON g.id = m.group_id"
+ " WHERE m.user_id = ?";
return jdbcTemplate.query(conn -> {
PreparedStatement ps = conn.prepareStatement(sql);
ps.setString(1, userId);
return ps;
}, resultSet -> {
List<GroupEntity> memberships = new ArrayList<>();
while (resultSet.next()) {
GroupEntity group = new GroupEntity();
group.setId(resultSet.getString("id"));
group.setName(resultSet.getString("name"));
group.setPath(resultSet.getString("path"));
memberships.add(group);
}
return memberships;
});
}
public boolean isMemberOf(String userId, String groupId) {
String sql = "SELECT COUNT(*) FROM gms_membership "
+ " WHERE user_id = ? AND group_id = ?";
return jdbcTemplate.query(conn -> {
PreparedStatement ps = conn.prepareStatement(sql);
ps.setString(1, userId);
ps.setString(2, groupId);
return ps;
}, resultSet -> {
resultSet.next();
int count = resultSet.getInt(1);
return count == 1;
});
}
public MembershipEntity addMember(MembershipEntity membership) {
String sql = "INSERT INTO gms_membership (group_id, user_id) VALUES (?, ?)";
......
......@@ -23,9 +23,9 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@RunWith(JUnit4.class)
public class WebServiceAuthorizationFilterTest {
public class ServiceBasicAuthFilterTest {
private WebServiceAuthorizationFilter filter;
private ServiceBasicAuthFilter filter;
private HttpServletRequest request;
private HttpServletResponse response;
......@@ -43,7 +43,7 @@ public class WebServiceAuthorizationFilterTest {
when(clientsDAO.findClientById("test")).thenReturn(Optional.of(client));
filter = spy(new WebServiceAuthorizationFilter());
filter = spy(new ServiceBasicAuthFilter());
doReturn(clientsDAO).when(filter).getClientsDAO(any());
request = mock(HttpServletRequest.class);
......@@ -54,7 +54,7 @@ public class WebServiceAuthorizationFilterTest {
@Test
public void testValidCredentials() throws Exception {
when(request.getServletPath()).thenReturn("/ws/group");
when(request.getServletPath()).thenReturn("/ws/basic/group");
when(request.getHeader("Authorization")).thenReturn("Basic dGVzdDpwYXNzd29yZA=="); // test:password
filter.doFilter(request, response, chain);
......@@ -65,7 +65,7 @@ public class WebServiceAuthorizationFilterTest {
@Test
public void testInvalidCredentials() throws Exception {
when(request.getServletPath()).thenReturn("/ws/group");
when(request.getServletPath()).thenReturn("/ws/basic/group");
when(request.getHeader("Authorization")).thenReturn("Basic dGVzdDp0ZXN0"); // test:test
filter.doFilter(request, response, chain);
......@@ -77,7 +77,7 @@ public class WebServiceAuthorizationFilterTest {
@Test
public void testMissingHeader() throws Exception {
when(request.getServletPath()).thenReturn("/ws/group");
when(request.getServletPath()).thenReturn("/ws/basic/group");
filter.doFilter(request, response, chain);
......
......@@ -34,7 +34,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
@RunWith(MockitoJUnitRunner.class)
public class WebServiceControllerTest {
public class BasicAuthWebServiceControllerTest {
@Mock
private GroupsService groupsService;
......@@ -46,7 +46,7 @@ public class WebServiceControllerTest {
private PermissionsService permissionsService;
@InjectMocks
private WebServiceController controller;
private BasicAuthWebServiceController controller;
private MockMvc mockMvc;
......@@ -78,7 +78,7 @@ public class WebServiceControllerTest {
List<String> names = Arrays.asList("LBT", "INAF");
mockMvc.perform(post("/ws/group")
mockMvc.perform(post("/ws/basic/group")
.content(mapper.writeValueAsString(names))
.contentType(MediaType.APPLICATION_JSON_UTF8))
.andExpect(status().isCreated())
......@@ -97,7 +97,7 @@ public class WebServiceControllerTest {
when(groupsService.findGroupByNames(names)).thenReturn(Optional.of(inaf));
mockMvc.perform(delete("/ws/group?names=LBT&names=INAF"))
mockMvc.perform(delete("/ws/basic/group?names=LBT&names=INAF"))
.andExpect(status().isNoContent());
verify(groupsService, times(1)).deleteGroup(eq(inaf));
......@@ -123,7 +123,7 @@ public class WebServiceControllerTest {
when(membersService.addMember(eq(inaf.getId()), eq(request.getUserId())))
.thenReturn(membership);
mockMvc.perform(post("/ws/member")
mockMvc.perform(post("/ws/basic/member")
.content(mapper.writeValueAsString(request))
.contentType(MediaType.APPLICATION_JSON_UTF8))
.andExpect(status().isCreated())
......@@ -140,7 +140,7 @@ public class WebServiceControllerTest {
when(groupsService.findGroupByNames(names)).thenReturn(Optional.of(inaf));
mockMvc.perform(delete("/ws/member?names=LBT&names=INAF&userId=user_id"))
mockMvc.perform(delete("/ws/basic/member?names=LBT&names=INAF&userId=user_id"))
.andExpect(status().isNoContent());
verify(membersService, times(1)).removeMember(eq(inaf.getId()), eq("user_id"));
......@@ -168,7 +168,7 @@ public class WebServiceControllerTest {
when(permissionsService.addPermission(eq(inaf), eq(request.getUserId()),
eq(request.getPermission()))).thenReturn(permissionEntity);
mockMvc.perform(post("/ws/permission")
mockMvc.perform(post("/ws/basic/permission")
.content(mapper.writeValueAsString(request))
.contentType(MediaType.APPLICATION_JSON_UTF8))
.andExpect(status().isCreated())
......@@ -187,7 +187,7 @@ public class WebServiceControllerTest {
GroupEntity inaf = getInafGroup();
when(groupsService.findGroupByNames(names)).thenReturn(Optional.of(inaf));
mockMvc.perform(delete("/ws/permission?names=LBT&names=INAF&userId=user_id&permission=ADMIN"))
mockMvc.perform(delete("/ws/basic/permission?names=LBT&names=INAF&userId=user_id&permission=ADMIN"))
.andExpect(status().isNoContent());
verify(permissionsService, times(1)).removePermission(eq(inaf), eq("user_id"));
......@@ -200,7 +200,7 @@ public class WebServiceControllerTest {
request.setFromUserId("from_user");
request.setToUserId("to_user");
mockMvc.perform(post("/ws/prepare-join")
mockMvc.perform(post("/ws/basic/prepare-join")
.content(mapper.writeValueAsString(request))
.contentType(MediaType.APPLICATION_JSON_UTF8))
.andExpect(status().isOk());
......
......@@ -39,7 +39,7 @@ public class HomePageControllerTest {
@Test
public void testGetHomePageModel() throws Exception {
when(session.getUserId()).thenReturn("admin_id");
when(session.getUserName()).thenReturn("Name Surname");
when(groupsTabResponseBuilder.getGroupsTab(any())).thenReturn(new GroupsTabResponse());
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment