Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ object AccessControlResource extends LazyLogging {
def authorize(
uriInfo: UriInfo,
headers: HttpHeaders,
securityContext: SecurityContext,
bodyOpt: Option[String] = None
): Response = {
val path = uriInfo.getPath
Expand All @@ -68,7 +69,7 @@ object AccessControlResource extends LazyLogging {
path match {
case wsapiWorkflowWebsocket() | apiExecutionsStats() | apiExecutionsResultExport() |
pveRoute() =>
checkComputingUnitAccess(uriInfo, headers, bodyOpt)
checkComputingUnitAccess(uriInfo, headers, securityContext, bodyOpt)
case _ =>
logger.warn(s"No authorization logic for path: $path. Denying access.")
Response.status(Response.Status.FORBIDDEN).build()
Expand All @@ -78,6 +79,7 @@ object AccessControlResource extends LazyLogging {
private def checkComputingUnitAccess(
uriInfo: UriInfo,
headers: HttpHeaders,
securityContext: SecurityContext,
bodyOpt: Option[String]
): Response = {
val queryParams: Map[String, String] = uriInfo
Expand All @@ -91,18 +93,6 @@ object AccessControlResource extends LazyLogging {
s"Request URI: ${uriInfo.getRequestUri} and headers: ${headers.getRequestHeaders.asScala} and queryParams: $queryParams"
)

val token: String = {
val qToken = queryParams.get("access-token").filter(_.nonEmpty)
val hToken = Option(headers.getRequestHeader("Authorization"))
.flatMap(_.asScala.headOption)
.map(_.replaceFirst("(?i)^Bearer\\s+", "")) // case-insensitive "Bearer "
.map(_.trim)
.filter(_.nonEmpty)
val bToken = bodyOpt.flatMap(extractTokenFromBody)
qToken.orElse(hToken).orElse(bToken).getOrElse("")
}
logger.info(s"token extracted from request $token")

val cuid = queryParams.get("cuid").filter(_.nonEmpty).getOrElse {
uriInfo.getPath match {
case pvePvesCuidPath(c) => c
Expand All @@ -121,7 +111,20 @@ object AccessControlResource extends LazyLogging {
var cuAccess: PrivilegeEnum = PrivilegeEnum.NONE
var userSession: Optional[SessionUser] = Optional.empty()
try {
userSession = parseToken(token)
// The Authorization header is parsed once by JwtAuthFilter, which
// installs a SessionUser into the SecurityContext. Reuse it when
// present, and only fall back to parsing the query / body token
// (which the filter does not see) when no principal is available.
val principal = Option(securityContext)
.flatMap(sc => Option(sc.getUserPrincipal))
.collect { case u: SessionUser => u }
userSession = principal match {
case Some(user) => Optional.of(user)
case None =>
val qToken = queryParams.get("access-token").filter(_.nonEmpty)
val bToken = bodyOpt.flatMap(extractTokenFromBody)
parseToken(qToken.orElse(bToken).getOrElse(""))
}
if (userSession.isEmpty)
return Response.status(Response.Status.FORBIDDEN).build()

Expand Down Expand Up @@ -217,29 +220,37 @@ class AccessControlResource extends LazyLogging {
@Path("/{path:.*}")
def authorizeGet(
@Context uriInfo: UriInfo,
@Context headers: HttpHeaders
@Context headers: HttpHeaders,
@Context securityContext: SecurityContext
): Response = {
AccessControlResource.authorize(uriInfo, headers)
AccessControlResource.authorize(uriInfo, headers, securityContext)
}

@POST
@Path("/{path:.*}")
def authorizePost(
@Context uriInfo: UriInfo,
@Context headers: HttpHeaders,
@Context securityContext: SecurityContext,
body: String
): Response = {
logger.info("Request body: " + body)
AccessControlResource.authorize(uriInfo, headers, Option(body).map(_.trim).filter(_.nonEmpty))
AccessControlResource.authorize(
uriInfo,
headers,
securityContext,
Option(body).map(_.trim).filter(_.nonEmpty)
)
}

@DELETE
@Path("/{path:.*}")
def authorizeDelete(
@Context uriInfo: UriInfo,
@Context headers: HttpHeaders
@Context headers: HttpHeaders,
@Context securityContext: SecurityContext
): Response = {
AccessControlResource.authorize(uriInfo, headers)
AccessControlResource.authorize(uriInfo, headers, securityContext)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.texera

import jakarta.ws.rs.core.{HttpHeaders, MultivaluedHashMap, Response, UriInfo}
import jakarta.ws.rs.core.{HttpHeaders, MultivaluedHashMap, Response, SecurityContext, UriInfo}
import org.apache.texera.auth.JwtAuth
import org.apache.texera.auth.util.HeaderField
import org.apache.texera.dao.MockTexeraDB
Expand All @@ -36,6 +36,7 @@ import org.apache.texera.dao.jooq.generated.tables.pojos.{
User,
WorkflowComputingUnit
}
import org.apache.texera.auth.SessionUser
import org.apache.texera.service.resource.AccessControlResource
import org.mockito.Mockito._
import org.scalatest.flatspec.AnyFlatSpec
Expand Down Expand Up @@ -86,6 +87,22 @@ class AccessControlResourceSpec

private var token: String = _

// SecurityContext with no principal — simulates a request that did not
// pass the JwtAuthFilter, exercising the query/body token fallback.
private def emptySecurityContext: SecurityContext = {
val sc = mock(classOf[SecurityContext])
when(sc.getUserPrincipal).thenReturn(null)
sc
}

// SecurityContext already populated by the filter from a valid
// Authorization header — the production path for Bearer-token requests.
private def securityContextOf(user: User): SecurityContext = {
val sc = mock(classOf[SecurityContext])
when(sc.getUserPrincipal).thenReturn(new SessionUser(user))
sc
}

override protected def beforeAll(): Unit = {
initializeDBAndReplaceDSLContext()
val userDao = new UserDao(getDSLContext.configuration())
Expand Down Expand Up @@ -125,7 +142,8 @@ class AccessControlResourceSpec
when(mockHttpHeaders.getRequestHeader("Authorization")).thenReturn(new util.ArrayList[String]())

val accessControlResource = new AccessControlResource()
val response = accessControlResource.authorizeGet(mockUriInfo, mockHttpHeaders)
val response =
accessControlResource.authorizeGet(mockUriInfo, mockHttpHeaders, emptySecurityContext)

response.getStatus shouldBe Response.Status.FORBIDDEN.getStatusCode
}
Expand All @@ -146,7 +164,8 @@ class AccessControlResourceSpec
.thenReturn(util.Arrays.asList("Bearer dummy-token"))

val accessControlResource = new AccessControlResource()
val response = accessControlResource.authorizeGet(mockUriInfo, mockHttpHeaders)
val response =
accessControlResource.authorizeGet(mockUriInfo, mockHttpHeaders, emptySecurityContext)

response.getStatus shouldBe Response.Status.FORBIDDEN.getStatusCode
}
Expand All @@ -165,7 +184,8 @@ class AccessControlResourceSpec
when(mockHttpHeaders.getRequestHeader("Authorization")).thenReturn(new util.ArrayList[String]())

val accessControlResource = new AccessControlResource()
val response = accessControlResource.authorizePost(mockUriInfo, mockHttpHeaders, null)
val response =
accessControlResource.authorizePost(mockUriInfo, mockHttpHeaders, emptySecurityContext, null)

response.getStatus shouldBe Response.Status.FORBIDDEN.getStatusCode
}
Expand Down Expand Up @@ -193,7 +213,12 @@ class AccessControlResourceSpec

// Instantiate the resource and call the method under test
val accessControlResource = new AccessControlResource()
val response = accessControlResource.authorizeGet(mockUriInfo, mockHttpHeaders)
val response =
accessControlResource.authorizeGet(
mockUriInfo,
mockHttpHeaders,
securityContextOf(testUser1)
)

// Assert that the response status is FORBIDDEN
response.getStatus shouldBe Response.Status.FORBIDDEN.getStatusCode
Expand Down Expand Up @@ -222,7 +247,12 @@ class AccessControlResourceSpec

// Instantiate the resource and call the method under test
val accessControlResource = new AccessControlResource()
val response = accessControlResource.authorizeGet(mockUriInfo, mockHttpHeaders)
val response =
accessControlResource.authorizeGet(
mockUriInfo,
mockHttpHeaders,
securityContextOf(testUser1)
)

// Assert that the response status is OK and headers are correct
response.getStatus shouldBe Response.Status.OK.getStatusCode
Expand Down Expand Up @@ -259,36 +289,63 @@ class AccessControlResourceSpec

it should "return OK for /pve/system with cuid as query parameter" in {
val (uri, headers) = mockRequest("/pve/system", Some(testCU.getCuid.toString))
val response = new AccessControlResource().authorizeGet(uri, headers)
val response =
new AccessControlResource().authorizeGet(uri, headers, securityContextOf(testUser1))

response.getStatus shouldBe Response.Status.OK.getStatusCode
}

it should "return OK for /pve/pves/{cuid} (cuid extracted from path)" in {
val (uri, headers) = mockRequest(s"/pve/pves/${testCU.getCuid}", None)
val response = new AccessControlResource().authorizeDelete(uri, headers)
val response =
new AccessControlResource().authorizeDelete(uri, headers, securityContextOf(testUser1))

response.getStatus shouldBe Response.Status.OK.getStatusCode
}

it should "return OK for /pve/{cuid}/{pveName}/packages/{packageName} (cuid extracted from path)" in {
val (uri, headers) = mockRequest(s"/pve/${testCU.getCuid}/myenv/packages/numpy", None)
val response = new AccessControlResource().authorizeDelete(uri, headers)
val response =
new AccessControlResource().authorizeDelete(uri, headers, securityContextOf(testUser1))

response.getStatus shouldBe Response.Status.OK.getStatusCode
}

it should "return FORBIDDEN for a PVE path with no cuid in query or path" in {
val (uri, headers) = mockRequest("/pve/no-cuid-anywhere", None)
val response = new AccessControlResource().authorizeGet(uri, headers)
val response =
new AccessControlResource().authorizeGet(uri, headers, securityContextOf(testUser1))

response.getStatus shouldBe Response.Status.FORBIDDEN.getStatusCode
}

it should "return FORBIDDEN for a non-PVE / non-whitelisted path" in {
val (uri, headers) = mockRequest("/random/garbage", Some(testCU.getCuid.toString))
val response = new AccessControlResource().authorizeGet(uri, headers)
val response =
new AccessControlResource().authorizeGet(uri, headers, securityContextOf(testUser1))

response.getStatus shouldBe Response.Status.FORBIDDEN.getStatusCode
}

it should "reuse SessionUser from SecurityContext without any token in the request" in {
val mockUriInfo = mock(classOf[UriInfo])
val mockHttpHeaders = mock(classOf[HttpHeaders])
val queryParams = new MultivaluedHashMap[String, String]()
queryParams.add("cuid", testCU.getCuid.toString)
val requestHeaders = new MultivaluedHashMap[String, String]()

when(mockUriInfo.getQueryParameters).thenReturn(queryParams)
when(mockUriInfo.getRequestUri).thenReturn(new URI(testURI))
when(mockUriInfo.getPath).thenReturn(testPath)
when(mockHttpHeaders.getRequestHeaders).thenReturn(requestHeaders)
when(mockHttpHeaders.getRequestHeader("Authorization")).thenReturn(new util.ArrayList[String]())

val sc = mock(classOf[SecurityContext])
when(sc.getUserPrincipal).thenReturn(new SessionUser(testUser1))

val response = new AccessControlResource().authorizeGet(mockUriInfo, mockHttpHeaders, sc)

response.getStatus shouldBe Response.Status.OK.getStatusCode
response.getHeaderString(HeaderField.UserId) shouldBe testUser1.getUid.toString
}
}
Loading