Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,8 @@ import org.apache.texera.auth.{
}
import org.apache.texera.dao.SqlServer
import org.apache.texera.service.activity.UserActivityEventListener
import org.apache.texera.service.resource.{
AccessControlResource,
HealthCheckResource,
LiteLLMModelsResource,
LiteLLMProxyResource
}
import org.apache.texera.service.resource.{AccessControlResource, HealthCheckResource}
import org.eclipse.jetty.server.session.SessionHandler
import org.glassfish.jersey.server.filter.RolesAllowedDynamicFeature
import java.nio.file.Path

class AccessControlService extends Application[AccessControlServiceConfiguration] with LazyLogging {
Expand Down Expand Up @@ -73,8 +67,6 @@ class AccessControlService extends Application[AccessControlServiceConfiguration

environment.jersey.register(classOf[HealthCheckResource])
environment.jersey.register(classOf[AccessControlResource])
environment.jersey.register(classOf[LiteLLMProxyResource])
environment.jersey.register(classOf[LiteLLMModelsResource])

// Register JWT authentication filter
environment.jersey.register(new AuthDynamicFeature(classOf[JwtAuthFilter]))
Expand All @@ -85,9 +77,6 @@ class AccessControlService extends Application[AccessControlServiceConfiguration
new io.dropwizard.auth.AuthValueFactoryProvider.Binder(classOf[SessionUser])
)

// Required for @RolesAllowed on resources to be enforced.
environment.jersey.register(classOf[RolesAllowedDynamicFeature])

// Record USER_LAST_ACTIVE_TIME on every matched, completed request.
// Lives only in this service because authenticated client sessions
// contact access-control-service often enough to capture activity
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ import org.glassfish.jersey.server.monitoring.{
* Lives in access-control-service because USER_LAST_ACTIVE_TIME is a
* user-management concern; the assumption is that any authenticated
* client session contacts this service often enough (UI navigation,
* permission checks, LiteLLM proxy) to capture activity with high
* recall, so other services do not need to mirror this listener.
* permission checks) to capture activity with high recall, so other
* services do not need to mirror this listener.
*/
@Provider
class UserActivityEventListener(track: Integer => Unit = UserActivityTracker.markActive)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,14 @@ package org.apache.texera.service.resource
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import com.typesafe.scalalogging.LazyLogging
import jakarta.annotation.security.{PermitAll, RolesAllowed}
import jakarta.ws.rs.client.{Client, ClientBuilder, Entity}
import jakarta.annotation.security.PermitAll
import jakarta.ws.rs.core._
import jakarta.ws.rs.{Consumes, DELETE, GET, POST, Path, Produces}
import jakarta.ws.rs.{DELETE, GET, POST, Path, Produces}
import org.apache.texera.auth.JwtParser.parseToken
import org.apache.texera.auth.SessionUser
import org.apache.texera.auth.util.{ComputingUnitAccess, HeaderField}
import org.apache.texera.config.{GuiConfig, KubernetesConfig, LLMConfig}
import org.apache.texera.dao.jooq.generated.enums.PrivilegeEnum
import org.apache.texera.config.KubernetesConfig
import org.apache.texera.dao.jooq.generated.enums.{PrivilegeEnum, UserRoleEnum}

import java.net.URLDecoder
import java.nio.charset.StandardCharsets
Expand All @@ -45,6 +44,8 @@ object AccessControlResource extends LazyLogging {
private val apiExecutionsStats: Regex = """.*/api/executions/[0-9]+/stats/[0-9]+.*""".r
private val apiExecutionsResultExport: Regex = """.*/api/executions/result/export.*""".r
private val pveRoute: Regex = """^/?(?:auth/)?(?:api/|wsapi/)?pve(?:/.*)?$""".r
// Agent service: authenticate any /api/agents request (Phase 1 — see #5561).
private val apiAgents: Regex = """.*/api/agents.*""".r
// Path patterns whose cuid lives in the URL path rather than the query string.
private val pvePvesCuidPath: Regex = """^/?(?:auth/)?(?:api/|wsapi/)?pve/pves/([0-9]+)$""".r
private val pvePackagesCuidPath: Regex =
Expand All @@ -69,12 +70,68 @@ object AccessControlResource extends LazyLogging {
case wsapiWorkflowWebsocket() | apiExecutionsStats() | apiExecutionsResultExport() |
pveRoute() =>
checkComputingUnitAccess(uriInfo, headers, bodyOpt)
case apiAgents() =>
checkAgentAccess(uriInfo, headers, bodyOpt)
case _ =>
logger.warn(s"No authorization logic for path: $path. Denying access.")
Response.status(Response.Status.FORBIDDEN).build()
}
}

// Extract the bearer token from the access-token query param, the
// Authorization header, or a "token" field in the body (in that order).
private def extractBearerToken(
uriInfo: UriInfo,
headers: HttpHeaders,
bodyOpt: Option[String]
): String = {
val qToken = Option(uriInfo.getQueryParameters().getFirst("access-token"))
.map(_.trim)
.filter(_.nonEmpty)
val hToken = Option(headers.getRequestHeader("Authorization"))
.flatMap(_.asScala.headOption)
.map(_.replaceFirst("(?i)^Bearer\\s+", "")) // case-insensitive "Bearer "
.map(_.trim)
.filter(_.nonEmpty)
val bToken = bodyOpt.flatMap(extractTokenFromBody)
qToken.orElse(hToken).orElse(bToken).getOrElse("")
}

// Phase 1 agent authorization: authenticate the JWT and require a
// REGULAR/ADMIN role. Any such user may reach any agent for now; per-agent
// ownership is deferred (see #5302 / #5561). On success, forward the user
// identity headers so the agent service can trust them.
private def checkAgentAccess(
uriInfo: UriInfo,
headers: HttpHeaders,
bodyOpt: Option[String]
): Response = {
val token = extractBearerToken(uriInfo, headers, bodyOpt)
val userSession: Optional[SessionUser] =
try parseToken(token)
catch {
case e: Exception =>
logger.error(s"Failed parsing token for agent request: $e")
Optional.empty()
}

if (userSession.isEmpty) {
return Response.status(Response.Status.UNAUTHORIZED).build()
}

val user = userSession.get()
if (!(user.isRoleOf(UserRoleEnum.REGULAR) || user.isRoleOf(UserRoleEnum.ADMIN))) {
return Response.status(Response.Status.FORBIDDEN).build()
}

Response
.ok()
.header(HeaderField.UserId, user.getUid.toString)
.header(HeaderField.UserName, user.getName)
.header(HeaderField.UserEmail, user.getEmail)
.build()
}

private def checkComputingUnitAccess(
uriInfo: UriInfo,
headers: HttpHeaders,
Expand Down Expand Up @@ -242,155 +299,3 @@ class AccessControlResource extends LazyLogging {
AccessControlResource.authorize(uriInfo, headers)
}
}

// Forwards chat completions to LiteLLM with the server's master key, so
// only authenticated users may call it.
@Path("/chat")
@RolesAllowed(Array("REGULAR", "ADMIN"))
@Produces(Array(MediaType.APPLICATION_JSON))
@Consumes(Array(MediaType.APPLICATION_JSON))
class LiteLLMProxyResource(
copilotEnabled: Boolean,
litellmBaseUrl: String,
litellmApiKey: String
) extends LazyLogging {

// No-arg constructor for Jersey reflection. Tests use the param-ful form.
def this() =
this(
GuiConfig.guiWorkflowWorkspaceCopilotEnabled,
LLMConfig.baseUrl,
LLMConfig.masterKey
)

private val client: Client = ClientBuilder.newClient()

@POST
@Path("/{path:.*}")
def proxyPost(
@Context uriInfo: UriInfo,
@Context headers: HttpHeaders,
body: String
): Response = {
if (!copilotEnabled) {
return Response
.status(Response.Status.FORBIDDEN)
.entity(LiteLLMProxyResource.CopilotDisabledBody)
.build()
}

// uriInfo.getPath returns "chat/completions" for /api/chat/completions
// We want to forward as "/chat/completions" to LiteLLM
val fullPath = uriInfo.getPath
val targetUrl = s"$litellmBaseUrl/$fullPath"

logger.info(s"Proxying POST request to LiteLLM: $targetUrl")

try {
val requestBuilder = client
.target(targetUrl)
.request(MediaType.APPLICATION_JSON)
.header("Authorization", s"Bearer $litellmApiKey")

// Forward other relevant headers from the original request
headers.getRequestHeaders.asScala.foreach {
case (key, values)
if !key.equalsIgnoreCase("Authorization") &&
!key.equalsIgnoreCase("Host") &&
!key.equalsIgnoreCase("Content-Length") =>
values.asScala.foreach(value => requestBuilder.header(key, value))
case _ => // Skip Authorization, Host, and Content-Length headers
}

val response = requestBuilder.post(Entity.json(body))

// Build response with same status and body from LiteLLM
val responseBody = response.readEntity(classOf[String])
val responseBuilder = Response
.status(response.getStatus)
.entity(responseBody)

// Forward response headers
response.getHeaders.asScala.foreach {
case (key, values) =>
values.asScala.foreach(value => responseBuilder.header(key, value))
}

responseBuilder.build()
} catch {
case e: Exception =>
logger.error(s"Error proxying request to LiteLLM: ${e.getMessage}", e)
Response
.status(Response.Status.BAD_GATEWAY)
.entity(s"""{"error": "Failed to proxy request to LiteLLM: ${e.getMessage}"}""")
.build()
}
}
}

object LiteLLMProxyResource {
val CopilotDisabledBody: String = """{"error": "Copilot feature is disabled"}"""
}

@Path("/models")
@RolesAllowed(Array("REGULAR", "ADMIN"))
@Produces(Array(MediaType.APPLICATION_JSON))
class LiteLLMModelsResource(
copilotEnabled: Boolean,
litellmBaseUrl: String,
litellmApiKey: String
) extends LazyLogging {

// No-arg constructor for Jersey reflection. Tests use the param-ful form.
def this() =
this(
GuiConfig.guiWorkflowWorkspaceCopilotEnabled,
LLMConfig.baseUrl,
LLMConfig.masterKey
)

private val client: Client = ClientBuilder.newClient()

@GET
def getModels: Response = {
if (!copilotEnabled) {
return Response
.status(Response.Status.FORBIDDEN)
.entity(LiteLLMProxyResource.CopilotDisabledBody)
.build()
}

val targetUrl = s"$litellmBaseUrl/models"

logger.info(s"Fetching models from LiteLLM: $targetUrl")

try {
val response = client
.target(targetUrl)
.request(MediaType.APPLICATION_JSON)
.header("Authorization", s"Bearer $litellmApiKey")
.get()

// Build response with same status and body from LiteLLM
val responseBody = response.readEntity(classOf[String])
val responseBuilder = Response
.status(response.getStatus)
.entity(responseBody)

// Forward response headers
response.getHeaders.asScala.foreach {
case (key, values) =>
values.asScala.foreach(value => responseBuilder.header(key, value))
}

responseBuilder.build()
} catch {
case e: Exception =>
logger.error(s"Error fetching models from LiteLLM: ${e.getMessage}", e)
Response
.status(Response.Status.BAD_GATEWAY)
.entity(s"""{"error": "Failed to fetch models from LiteLLM: ${e.getMessage}"}""")
.build()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import io.dropwizard.jetty.MutableServletContextHandler
import io.dropwizard.jetty.setup.ServletEnvironment
import org.apache.texera.auth.UnauthorizedExceptionMapper
import org.apache.texera.service.activity.UserActivityEventListener
import org.glassfish.jersey.server.filter.RolesAllowedDynamicFeature
import org.mockito.ArgumentMatchers.isA
import org.mockito.Mockito.{mock, verify, when}
import org.scalatest.flatspec.AnyFlatSpec
Expand All @@ -47,8 +46,6 @@ class AccessControlServiceRunSpec extends AnyFlatSpec with Matchers {

verify(jersey).register(isA(classOf[UserActivityEventListener]))
verify(jersey).register(classOf[UnauthorizedExceptionMapper])
// Without this feature Jersey ignores @RolesAllowed on the LiteLLM proxies.
verify(jersey).register(classOf[RolesAllowedDynamicFeature])
verify(jersey).setUrlPattern("/api/*")
}
}
Loading
Loading