diff --git a/access-control-service/src/main/scala/org/apache/texera/service/resource/AccessControlResource.scala b/access-control-service/src/main/scala/org/apache/texera/service/resource/AccessControlResource.scala index 0c90a6ce31f..c8061bbff32 100644 --- a/access-control-service/src/main/scala/org/apache/texera/service/resource/AccessControlResource.scala +++ b/access-control-service/src/main/scala/org/apache/texera/service/resource/AccessControlResource.scala @@ -60,6 +60,7 @@ object AccessControlResource extends LazyLogging { def authorize( uriInfo: UriInfo, headers: HttpHeaders, + securityContext: SecurityContext, bodyOpt: Option[String] = None ): Response = { val path = uriInfo.getPath @@ -68,7 +69,7 @@ object AccessControlResource extends LazyLogging { path match { case wsapiWorkflowWebsocket() | apiExecutionsStats() | apiExecutionsResultExport() | pveRoute() => - checkComputingUnitAccess(uriInfo, headers, bodyOpt) + checkComputingUnitAccess(uriInfo, headers, securityContext, bodyOpt) case _ => logger.warn(s"No authorization logic for path: $path. Denying access.") Response.status(Response.Status.FORBIDDEN).build() @@ -78,6 +79,7 @@ object AccessControlResource extends LazyLogging { private def checkComputingUnitAccess( uriInfo: UriInfo, headers: HttpHeaders, + securityContext: SecurityContext, bodyOpt: Option[String] ): Response = { val queryParams: Map[String, String] = uriInfo @@ -91,18 +93,6 @@ object AccessControlResource extends LazyLogging { s"Request URI: ${uriInfo.getRequestUri} and headers: ${headers.getRequestHeaders.asScala} and queryParams: $queryParams" ) - val token: String = { - val qToken = queryParams.get("access-token").filter(_.nonEmpty) - val hToken = Option(headers.getRequestHeader("Authorization")) - .flatMap(_.asScala.headOption) - .map(_.replaceFirst("(?i)^Bearer\\s+", "")) // case-insensitive "Bearer " - .map(_.trim) - .filter(_.nonEmpty) - val bToken = bodyOpt.flatMap(extractTokenFromBody) - qToken.orElse(hToken).orElse(bToken).getOrElse("") - } - logger.info(s"token extracted from request $token") - val cuid = queryParams.get("cuid").filter(_.nonEmpty).getOrElse { uriInfo.getPath match { case pvePvesCuidPath(c) => c @@ -121,7 +111,20 @@ object AccessControlResource extends LazyLogging { var cuAccess: PrivilegeEnum = PrivilegeEnum.NONE var userSession: Optional[SessionUser] = Optional.empty() try { - userSession = parseToken(token) + // The Authorization header is parsed once by JwtAuthFilter, which + // installs a SessionUser into the SecurityContext. Reuse it when + // present, and only fall back to parsing the query / body token + // (which the filter does not see) when no principal is available. + val principal = Option(securityContext) + .flatMap(sc => Option(sc.getUserPrincipal)) + .collect { case u: SessionUser => u } + userSession = principal match { + case Some(user) => Optional.of(user) + case None => + val qToken = queryParams.get("access-token").filter(_.nonEmpty) + val bToken = bodyOpt.flatMap(extractTokenFromBody) + parseToken(qToken.orElse(bToken).getOrElse("")) + } if (userSession.isEmpty) return Response.status(Response.Status.FORBIDDEN).build() @@ -217,9 +220,10 @@ class AccessControlResource extends LazyLogging { @Path("/{path:.*}") def authorizeGet( @Context uriInfo: UriInfo, - @Context headers: HttpHeaders + @Context headers: HttpHeaders, + @Context securityContext: SecurityContext ): Response = { - AccessControlResource.authorize(uriInfo, headers) + AccessControlResource.authorize(uriInfo, headers, securityContext) } @POST @@ -227,19 +231,26 @@ class AccessControlResource extends LazyLogging { def authorizePost( @Context uriInfo: UriInfo, @Context headers: HttpHeaders, + @Context securityContext: SecurityContext, body: String ): Response = { logger.info("Request body: " + body) - AccessControlResource.authorize(uriInfo, headers, Option(body).map(_.trim).filter(_.nonEmpty)) + AccessControlResource.authorize( + uriInfo, + headers, + securityContext, + Option(body).map(_.trim).filter(_.nonEmpty) + ) } @DELETE @Path("/{path:.*}") def authorizeDelete( @Context uriInfo: UriInfo, - @Context headers: HttpHeaders + @Context headers: HttpHeaders, + @Context securityContext: SecurityContext ): Response = { - AccessControlResource.authorize(uriInfo, headers) + AccessControlResource.authorize(uriInfo, headers, securityContext) } } diff --git a/access-control-service/src/test/scala/org/apache/texera/AccessControlResourceSpec.scala b/access-control-service/src/test/scala/org/apache/texera/AccessControlResourceSpec.scala index 75f3bacb107..9ab5a5342d3 100644 --- a/access-control-service/src/test/scala/org/apache/texera/AccessControlResourceSpec.scala +++ b/access-control-service/src/test/scala/org/apache/texera/AccessControlResourceSpec.scala @@ -17,7 +17,7 @@ package org.apache.texera -import jakarta.ws.rs.core.{HttpHeaders, MultivaluedHashMap, Response, UriInfo} +import jakarta.ws.rs.core.{HttpHeaders, MultivaluedHashMap, Response, SecurityContext, UriInfo} import org.apache.texera.auth.JwtAuth import org.apache.texera.auth.util.HeaderField import org.apache.texera.dao.MockTexeraDB @@ -36,6 +36,7 @@ import org.apache.texera.dao.jooq.generated.tables.pojos.{ User, WorkflowComputingUnit } +import org.apache.texera.auth.SessionUser import org.apache.texera.service.resource.AccessControlResource import org.mockito.Mockito._ import org.scalatest.flatspec.AnyFlatSpec @@ -86,6 +87,22 @@ class AccessControlResourceSpec private var token: String = _ + // SecurityContext with no principal — simulates a request that did not + // pass the JwtAuthFilter, exercising the query/body token fallback. + private def emptySecurityContext: SecurityContext = { + val sc = mock(classOf[SecurityContext]) + when(sc.getUserPrincipal).thenReturn(null) + sc + } + + // SecurityContext already populated by the filter from a valid + // Authorization header — the production path for Bearer-token requests. + private def securityContextOf(user: User): SecurityContext = { + val sc = mock(classOf[SecurityContext]) + when(sc.getUserPrincipal).thenReturn(new SessionUser(user)) + sc + } + override protected def beforeAll(): Unit = { initializeDBAndReplaceDSLContext() val userDao = new UserDao(getDSLContext.configuration()) @@ -125,7 +142,8 @@ class AccessControlResourceSpec when(mockHttpHeaders.getRequestHeader("Authorization")).thenReturn(new util.ArrayList[String]()) val accessControlResource = new AccessControlResource() - val response = accessControlResource.authorizeGet(mockUriInfo, mockHttpHeaders) + val response = + accessControlResource.authorizeGet(mockUriInfo, mockHttpHeaders, emptySecurityContext) response.getStatus shouldBe Response.Status.FORBIDDEN.getStatusCode } @@ -146,7 +164,8 @@ class AccessControlResourceSpec .thenReturn(util.Arrays.asList("Bearer dummy-token")) val accessControlResource = new AccessControlResource() - val response = accessControlResource.authorizeGet(mockUriInfo, mockHttpHeaders) + val response = + accessControlResource.authorizeGet(mockUriInfo, mockHttpHeaders, emptySecurityContext) response.getStatus shouldBe Response.Status.FORBIDDEN.getStatusCode } @@ -165,7 +184,8 @@ class AccessControlResourceSpec when(mockHttpHeaders.getRequestHeader("Authorization")).thenReturn(new util.ArrayList[String]()) val accessControlResource = new AccessControlResource() - val response = accessControlResource.authorizePost(mockUriInfo, mockHttpHeaders, null) + val response = + accessControlResource.authorizePost(mockUriInfo, mockHttpHeaders, emptySecurityContext, null) response.getStatus shouldBe Response.Status.FORBIDDEN.getStatusCode } @@ -193,7 +213,12 @@ class AccessControlResourceSpec // Instantiate the resource and call the method under test val accessControlResource = new AccessControlResource() - val response = accessControlResource.authorizeGet(mockUriInfo, mockHttpHeaders) + val response = + accessControlResource.authorizeGet( + mockUriInfo, + mockHttpHeaders, + securityContextOf(testUser1) + ) // Assert that the response status is FORBIDDEN response.getStatus shouldBe Response.Status.FORBIDDEN.getStatusCode @@ -222,7 +247,12 @@ class AccessControlResourceSpec // Instantiate the resource and call the method under test val accessControlResource = new AccessControlResource() - val response = accessControlResource.authorizeGet(mockUriInfo, mockHttpHeaders) + val response = + accessControlResource.authorizeGet( + mockUriInfo, + mockHttpHeaders, + securityContextOf(testUser1) + ) // Assert that the response status is OK and headers are correct response.getStatus shouldBe Response.Status.OK.getStatusCode @@ -259,36 +289,63 @@ class AccessControlResourceSpec it should "return OK for /pve/system with cuid as query parameter" in { val (uri, headers) = mockRequest("/pve/system", Some(testCU.getCuid.toString)) - val response = new AccessControlResource().authorizeGet(uri, headers) + val response = + new AccessControlResource().authorizeGet(uri, headers, securityContextOf(testUser1)) response.getStatus shouldBe Response.Status.OK.getStatusCode } it should "return OK for /pve/pves/{cuid} (cuid extracted from path)" in { val (uri, headers) = mockRequest(s"/pve/pves/${testCU.getCuid}", None) - val response = new AccessControlResource().authorizeDelete(uri, headers) + val response = + new AccessControlResource().authorizeDelete(uri, headers, securityContextOf(testUser1)) response.getStatus shouldBe Response.Status.OK.getStatusCode } it should "return OK for /pve/{cuid}/{pveName}/packages/{packageName} (cuid extracted from path)" in { val (uri, headers) = mockRequest(s"/pve/${testCU.getCuid}/myenv/packages/numpy", None) - val response = new AccessControlResource().authorizeDelete(uri, headers) + val response = + new AccessControlResource().authorizeDelete(uri, headers, securityContextOf(testUser1)) response.getStatus shouldBe Response.Status.OK.getStatusCode } it should "return FORBIDDEN for a PVE path with no cuid in query or path" in { val (uri, headers) = mockRequest("/pve/no-cuid-anywhere", None) - val response = new AccessControlResource().authorizeGet(uri, headers) + val response = + new AccessControlResource().authorizeGet(uri, headers, securityContextOf(testUser1)) response.getStatus shouldBe Response.Status.FORBIDDEN.getStatusCode } it should "return FORBIDDEN for a non-PVE / non-whitelisted path" in { val (uri, headers) = mockRequest("/random/garbage", Some(testCU.getCuid.toString)) - val response = new AccessControlResource().authorizeGet(uri, headers) + val response = + new AccessControlResource().authorizeGet(uri, headers, securityContextOf(testUser1)) response.getStatus shouldBe Response.Status.FORBIDDEN.getStatusCode } + + it should "reuse SessionUser from SecurityContext without any token in the request" in { + val mockUriInfo = mock(classOf[UriInfo]) + val mockHttpHeaders = mock(classOf[HttpHeaders]) + val queryParams = new MultivaluedHashMap[String, String]() + queryParams.add("cuid", testCU.getCuid.toString) + val requestHeaders = new MultivaluedHashMap[String, String]() + + when(mockUriInfo.getQueryParameters).thenReturn(queryParams) + when(mockUriInfo.getRequestUri).thenReturn(new URI(testURI)) + when(mockUriInfo.getPath).thenReturn(testPath) + when(mockHttpHeaders.getRequestHeaders).thenReturn(requestHeaders) + when(mockHttpHeaders.getRequestHeader("Authorization")).thenReturn(new util.ArrayList[String]()) + + val sc = mock(classOf[SecurityContext]) + when(sc.getUserPrincipal).thenReturn(new SessionUser(testUser1)) + + val response = new AccessControlResource().authorizeGet(mockUriInfo, mockHttpHeaders, sc) + + response.getStatus shouldBe Response.Status.OK.getStatusCode + response.getHeaderString(HeaderField.UserId) shouldBe testUser1.getUid.toString + } }