diff --git a/docs/byok_guide.md b/docs/byok_guide.md
index 29ac81151..4adac9fd7 100644
--- a/docs/byok_guide.md
+++ b/docs/byok_guide.md
@@ -16,7 +16,7 @@ The BYOK (Bring Your Own Knowledge) feature in Lightspeed Core enables users to
  * [Step 2: Create Vector Database](#step-2-create-vector-database)
  * [Step 3: Configure Embedding Model](#step-3-configure-embedding-model)
  * [Step 4: Configure Llama Stack](#step-4-configure-llama-stack)
- * [Step 5: Enable RAG Tools](#step-5-enable-rag-tools)
+ * [Step 5: Configure RAG Strategy](#step-5-configure-rag-strategy)
  * [Supported Vector Database Types](#supported-vector-database-types)
  * [Configuration Examples](#configuration-examples)
  * [Conclusion](#conclusion)
@@ -34,27 +34,58 @@ BYOK (Bring Your Own Knowledge) is Lightspeed Core's implementation of Retrieval
 
 ## How BYOK Works
 
-The BYOK system operates through a sophisticated chain of components:
+BYOK knowledge sources can be queried in two complementary modes, configured independently:
 
-1. **Agent Orchestrator**: The AI agent acts as the central coordinator, using the LLM as its reasoning engine
-2. **RAG Tool**: When the agent needs external information, it queries your custom vector database
-3. **Vector Database**: Your indexed knowledge sources, stored as vector embeddings for semantic search
-4. **Embedding Model**: Converts queries and documents into vector representations for similarity matching
-5. **Context Integration**: Retrieved knowledge is integrated into the AI's response generation process
+### Inline RAG
+
+Context is fetched from your BYOK vector stores and/or OKP (Offline Knowledge Portal) and injected before the LLM request. No tool calls are required.
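+
+For example, a minimal `rag` section in `lightspeed-stack.yaml` could enable Inline RAG on its own (the `my-docs` ID is illustrative and must match a `rag_id` declared in `byok_rag`; see Step 5 for the full syntax):
+
+```yaml
+rag:
+  inline:
+    - my-docs  # BYOK store whose context is injected into every query
+    - okp      # optionally also inject OKP context
+```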
+
+```mermaid
+graph TD
+    A[User Query] --> B[Fetch Context]
+    B --> C[BYOK Vector Stores]
+    B --> D[OKP Vector Stores]
+    C --> E[Retrieved Chunks]
+    D --> E
+    E --> F[Inject Context into Prompt]
+    F --> G[LLM Generates Response]
+    G --> H[Response to User]
+```
+
+### Tool RAG (on-demand retrieval)
+
+The LLM can call the `file_search` tool during generation when it decides external knowledge is needed. Both BYOK vector stores and OKP are supported in Tool RAG mode.
 
 ```mermaid
 graph TD
-    A[User Query] --> B[AI Agent]
+    A[User Query] --> P{Inline RAG enabled?}
+    P -->|Yes| Q[Fetch Context]
+    Q --> R[BYOK / OKP Vector Stores]
+    R --> S[Inject Context into Prompt]
+    S --> B[LLM]
+    P -->|No| B
     B --> C{Need External Knowledge?}
-    C -->|Yes| D[RAG Tool]
+    C -->|Yes| D[file_search Tool]
     C -->|No| E[Generate Response]
-    D --> F[Vector Database]
+    D --> F[BYOK / OKP Vector Stores]
     F --> G[Retrieve Relevant Context]
-    G --> H[Integrate Context]
-    H --> E
-    E --> I[Response to User]
+    G --> B
+    E --> H[Response to User]
 ```
 
+Both modes rely on:
+- **Vector Database**: Your indexed knowledge sources stored as vector embeddings
+- **Embedding Model**: Converts queries and documents into vector representations for similarity matching
+
+Inline RAG additionally supports:
+- **Score Multiplier**: Optional weight applied per BYOK vector store when mixing multiple sources. Allows custom prioritization of content.
+
+> [!NOTE]
+> OKP and BYOK scores are not directly comparable (different scoring systems), so
+> `score_multiplier` does not apply to OKP results. To control the amount of retrieved
+> context, set the `BYOK_RAG_MAX_CHUNKS` and `OKP_RAG_MAX_CHUNKS` constants in `src/constants.py`
+> (defaults: 10 and 5 respectively). For Tool RAG, use `TOOL_RAG_MAX_CHUNKS` (default: 10).
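+
+As an illustrative sketch of how `score_multiplier` reweights Inline RAG results (the store names and scores below are hypothetical), consider two BYOK stores:
+
+```yaml
+byok_rag:
+  - rag_id: product-docs
+    # ...other required fields omitted...
+    score_multiplier: 2.0  # raw score 0.40 -> weighted 0.80
+  - rag_id: community-kb
+    # ...other required fields omitted...
+    score_multiplier: 0.5  # raw score 0.60 -> weighted 0.30
+```
+
+After weighting, a `product-docs` chunk with raw score 0.40 (weighted 0.80) outranks a `community-kb` chunk with raw score 0.60 (weighted 0.30), even though its raw relevance was lower.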
+
 ---
 
 ## Prerequisites
@@ -244,12 +275,58 @@ registered_resources:
 
 **⚠️ Important**: The `vector_store_id` value must exactly match the ID you provided when creating the vector database using the rag-content tool. This identifier links your Llama Stack configuration to the specific vector database index you created.
 
-### Step 5: Enable RAG Tools
+> [!TIP]
+> Instead of manually editing `run.yaml`, you can declare your knowledge sources in the `byok_rag`
+> section of `lightspeed-stack.yaml`. The lightspeed-stack service automatically generates the required configuration
+> at startup.
+>
+> ```yaml
+> byok_rag:
+>   - rag_id: my-docs  # Unique identifier for this knowledge source
+>     rag_type: inline::faiss
+>     embedding_model: sentence-transformers/all-mpnet-base-v2
+>     embedding_dimension: 768
+>     vector_db_id: your-index-id  # Llama Stack vector store ID (from index generation)
+>     db_path: /path/to/vector_db/faiss_store.db
+>     score_multiplier: 1.0  # Optional: weight results when mixing multiple sources
+> ```
+>
+> When multiple BYOK sources are configured, `score_multiplier` adjusts the relative importance of
+> each store's results during Inline RAG retrieval. Values above 1.0 boost a store; below 1.0 reduce it.
+
+### Step 5: Configure RAG Strategy
+
+Add a `rag` section to your `lightspeed-stack.yaml` to choose how BYOK knowledge is used.
+Each list entry is a `rag_id` from `byok_rag`, or the special value `okp` for OKP.
+
+```yaml
+rag:
+  # Inline RAG: inject context before the LLM request (no tool calls needed)
+  inline:
+    - my-docs  # rag_id from byok_rag
+    - okp      # include OKP context inline
+
+  # Tool RAG: the LLM can call file_search to retrieve context on demand
+  # Omit to use all registered BYOK stores (backward compatibility)
+  tool:
+    - my-docs  # expose this BYOK store as the file_search tool
+    - okp      # expose OKP as the file_search tool
+
+# OKP provider settings (only relevant when okp is listed above)
+okp:
+  offline: true  # true = use parent_id for source URLs, false = use reference_url
+```
+
+Both modes can be enabled simultaneously. Choose based on your latency and control preferences:
 
-The configuration above automatically enables the RAG tools. The system will:
+| Mode | When context is fetched | Tool call needed | score_multiplier |
+|------|------------------------|------------------|-----------------|
+| Inline RAG | With every query | No | Yes (BYOK only) |
+| Tool RAG | On LLM demand | Yes | No |
 
-1. **Detect RAG availability**: Automatically identify when RAG is available
-2. **Enhance prompts**: Encourage the AI to use RAG tools
+> [!TIP]
+> A ready-to-use example combining BYOK and OKP is available at
+> [`examples/lightspeed-stack-byok-okp-rag.yaml`](../examples/lightspeed-stack-byok-okp-rag.yaml).
 
 ---
 
diff --git a/docs/config.md b/docs/config.md
index 6d1dde908..8ba10ad7e 100644
--- a/docs/config.md
+++ b/docs/config.md
@@ -110,15 +110,32 @@ Microsoft Entra ID authentication attributes for Azure.
 
 BYOK (Bring Your Own Knowledge) RAG configuration.
 
+Each entry registers a local vector store. The `rag_id` is the
+identifier used in `rag.inline` and `rag.tool` to select which stores to use.
+ +Example: + +```yaml +byok_rag: + - rag_id: my-docs # referenced in rag.inline / rag.tool + rag_type: inline::faiss + embedding_model: sentence-transformers/all-MiniLM-L6-v2 + embedding_dimension: 384 + vector_db_id: vs_abc123 + db_path: /path/to/faiss_store.db + score_multiplier: 1.0 +``` + | Field | Type | Description | |-------|------|-------------| | rag_id | string | Unique RAG ID | -| rag_type | string | Type of RAG database. | +| rag_type | string | Type of RAG database (e.g. `inline::faiss`). | | embedding_model | string | Embedding model identification | | embedding_dimension | integer | Dimensionality of embedding vectors. | | vector_db_id | string | Vector database identification. | | db_path | string | Path to RAG database. | +| score_multiplier | number | Multiplier applied to relevance scores from this vector store when querying multiple sources. Values > 1 boost results; values < 1 reduce them. Default: 1.0. | ## CORSConfiguration @@ -170,7 +187,7 @@ Global service configuration. | azure_entra_id | | | | splunk | | Splunk HEC configuration for sending telemetry events. | | deployment_environment | string | Deployment environment name (e.g., 'development', 'staging', 'production'). Used in telemetry events. | -| solr | | Configuration for Solr vector search operations. | +| rag | | RAG strategy configuration (OKP and BYOK). Controls pre-query (Inline RAG) and tool-based (Tool RAG) retrieval. | ## ConversationHistoryConfiguration @@ -520,19 +537,60 @@ the service can handle requests concurrently. | cors | | Cross-Origin Resource Sharing configuration for cross-domain requests | -## SolrConfiguration +## RagConfiguration + + +Top-level RAG strategy configuration. Controls two complementary retrieval modes: + +- **Inline RAG**: context is fetched from the listed sources and injected before the + LLM request. +- **Tool RAG**: the LLM can call the `file_search` tool during generation to retrieve + context on demand from the listed vector stores. 
Supports both BYOK and OKP. + +Each strategy is configured as a list of RAG IDs referencing entries in `byok_rag`. +The special ID `okp` activates the OKP provider (no `byok_rag` entry needed). + +**Backward compatibility**: omitting `tool` uses all registered BYOK vector stores +(equivalent to the old `tool.byok.enabled = True`). Omitting `inline` means no +context is injected before the LLM request. + +Example: + +```yaml +rag: + inline: + - my-docs # inject context from my-docs before the LLM request + tool: + - okp # LLM can search OKP as a tool + - my-docs # LLM can also search my-docs as a tool + +okp: + offline: true # use parent_id for OKP URL construction +``` + + +| Field | Type | Description | +|-------|------|-------------| +| inline | list[string] | RAG IDs whose content is injected before the LLM request. Use `okp` for OKP. Empty by default (no inline RAG). | +| tool | list[string] or null | RAG IDs exposed as a `file_search` tool the LLM can invoke. Use `okp` to include OKP. When omitted, all registered BYOK vector stores are used (backward compatibility). | + +## OkpConfiguration -Solr configuration for vector search queries. +OKP (Offline Knowledge Portal) provider settings. Only used when `okp` is listed in `rag.inline` or `rag.tool`. -Controls whether to use offline or online mode when building document URLs -from vector search results, and enables/disables Solr vector IO functionality. +Example: +```yaml +okp: + offline: true # use parent_id for OKP URL construction + chunk_filter_query: "is_chunk:true" +``` | Field | Type | Description | |-------|------|-------------| -| enabled | boolean | When True, enables Solr vector IO functionality for vector search queries. When False, disables Solr vector search processing. | -| offline | boolean | When True, use parent_id for chunk source URLs. When False, use reference_url for chunk source URLs. | +| offline | boolean | When `true` (default), use `parent_id` for OKP chunk source URLs. 
When `false`, use `reference_url`. | +| chunk_filter_query | string | OKP filter query (`fq`) applied to every OKP search request. Defaults to `"is_chunk:true"`. Extend with `AND` for extra constraints. | ## SplunkConfiguration diff --git a/docs/openapi.json b/docs/openapi.json index 23bac9b99..1f855a6d1 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -5503,6 +5503,13 @@ "format": "file-path", "title": "DB path", "description": "Path to RAG database." + }, + "score_multiplier": { + "type": "number", + "exclusiveMinimum": 0.0, + "title": "Score multiplier", + "description": "Multiplier applied to relevance scores from this vector store. Used to weight results when querying multiple knowledge sources. Values > 1 boost this store's results; values < 1 reduce them.", + "default": 1.0 } }, "additionalProperties": false, @@ -5714,17 +5721,15 @@ "description": "Deployment environment name (e.g., 'development', 'staging', 'production'). Used in telemetry events.", "default": "development" }, - "solr": { - "anyOf": [ - { - "$ref": "#/components/schemas/SolrConfiguration" - }, - { - "type": "null" - } - ], - "title": "Solr configuration", - "description": "Configuration for Solr vector search operations." + "rag": { + "$ref": "#/components/schemas/RagConfiguration", + "title": "RAG configuration", + "description": "Configuration for all RAG strategies (inline and tool-based)." + }, + "okp": { + "$ref": "#/components/schemas/OkpConfiguration", + "title": "OKP configuration", + "description": "OKP provider settings. Only used when 'okp' is listed in rag.inline or rag.tool." } }, "additionalProperties": false, @@ -7575,6 +7580,26 @@ "title": "OAuthFlows", "description": "Defines the configuration for the supported OAuth 2.0 flows." }, + "OkpConfiguration": { + "properties": { + "offline": { + "type": "boolean", + "title": "OKP offline mode", + "description": "When True, use parent_id for OKP chunk source URLs. 
When False, use reference_url for chunk source URLs.", + "default": true + }, + "chunk_filter_query": { + "type": "string", + "title": "OKP chunk filter query", + "description": "OKP filter query applied to every OKP search request. Defaults to 'is_chunk:true' to restrict results to chunk documents. To add extra constraints, extend the expression using boolean syntax, e.g. 'is_chunk:true AND product:*openshift*'.", + "default": "is_chunk:true" + } + }, + "additionalProperties": false, + "type": "object", + "title": "OkpConfiguration", + "description": "OKP (Offline Knowledge Portal) provider configuration.\n\nControls provider-specific behaviour for the OKP vector store.\nOnly relevant when ``\"okp\"`` is listed in ``rag.inline`` or ``rag.tool``." + }, "OpenIdConnectSecurityScheme": { "properties": { "description": { @@ -8749,6 +8774,37 @@ "title": "RHIdentityConfiguration", "description": "Red Hat Identity authentication configuration." }, + "RagConfiguration": { + "properties": { + "inline": { + "items": { + "type": "string" + }, + "type": "array", + "title": "Inline RAG IDs", + "description": "RAG IDs whose sources are injected as context before the LLM call. Use 'okp' to enable OKP inline RAG. Empty by default (no inline RAG)." + }, + "tool": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Tool RAG IDs", + "description": "RAG IDs made available to the LLM as a file_search tool. Use 'okp' to include the OKP vector store. When omitted, all registered BYOK vector stores are used (backward compatibility)." + } + }, + "additionalProperties": false, + "type": "object", + "title": "RagConfiguration", + "description": "RAG strategy configuration.\n\nControls which RAG sources are used for inline and tool-based retrieval.\n\nEach strategy lists RAG IDs to include. 
The special ID ``\"okp\"`` defined in constants,\nactivates the OKP provider; all other IDs refer to entries in ``byok_rag``.\n\nBackward compatibility:\n - ``inline`` defaults to ``[]`` (no inline RAG).\n - ``tool`` defaults to ``None`` which means all registered vector stores\n are used (identical to the previous ``tool.byok.enabled = True`` default)." + }, "ReadinessResponse": { "properties": { "ready": { @@ -9260,26 +9316,6 @@ } ] }, - "SolrConfiguration": { - "properties": { - "enabled": { - "type": "boolean", - "title": "Solr enabled", - "description": "When True, enables Solr vector IO functionality for vector search queries. When False, disables Solr vector search processing.", - "default": false - }, - "offline": { - "type": "boolean", - "title": "Offline mode", - "description": "When True, use parent_id for chunk source URLs. When False, use reference_url for chunk source URLs.", - "default": true - } - }, - "additionalProperties": false, - "type": "object", - "title": "SolrConfiguration", - "description": "Solr configuration for vector search queries.\n\nControls whether to use offline or online mode when building document URLs\nfrom vector search results, and enables/disables Solr vector IO functionality." - }, "SplunkConfiguration": { "properties": { "enabled": { diff --git a/docs/rag_guide.md b/docs/rag_guide.md index d07a03b0d..fbf2e1eb2 100644 --- a/docs/rag_guide.md +++ b/docs/rag_guide.md @@ -5,7 +5,7 @@ This document explains how to configure and customize your RAG pipeline using th * Initialize a vector store * Download and point to a local embedding model * Configure an inference provider (LLM) -* Enable Agent-based RAG querying +* Choose a RAG strategy (Inline RAG or Tool RAG) --- @@ -26,12 +26,17 @@ This document explains how to configure and customize your RAG pipeline using th # Introduction -RAG in Lightspeed Core Stack (LCS) is yet only supported via the Agents API. 
The agent is responsible for planning and deciding when to query the vector index. +Lightspeed Core Stack (LCS) supports two complementary RAG strategies: -The system operates a chain of command. The **Agent** is the orchestrator, using the LLM as its reasoning engine. When a plan requires external information, the Agent queries the **Vector Store**. This is your database of indexed knowledge, which you are responsible for creating before running the stack. The **Embedding Model** is used to convert the queries to vectors. +- **Inline RAG**: context is fetched from BYOK vector stores and/or OKP and injected before the LLM request. No tool calls are required. +- **Tool RAG**: the LLM can call the `file_search` tool during generation to retrieve context on demand from BYOK vector stores and/or OKP. + +Both strategies can be enabled independently via the `rag` section of `lightspeed-stack.yaml`. See [BYOK Feature Documentation](byok_guide.md) for configuration details. + +The **Embedding Model** is used to convert queries and documents into vector representations for similarity matching. > [!NOTE] -> The same Embedding Model should be used to both create the store and to query it. +> The same Embedding Model should be used to both create the vector store and to query it. --- @@ -268,74 +273,36 @@ The OKP (Offline Knowledge Portal) Solr Vector IO is a read-only vector search p #### How to Enable Solr Vector IO -**1. Configure Llama Stack (`run.yaml`):** +**1. 
Configure Lightspeed Stack (`lightspeed-stack.yaml`):** ```yaml -providers: - vector_io: - - provider_id: solr-vector - provider_type: remote::solr_vector_io - config: - solr_url: http://localhost:8983/solr - collection_name: portal-rag - vector_field: chunk_vector - content_field: chunk - embedding_dimension: 384 - embedding_model: ${env.EMBEDDING_MODEL_DIR} - chunk_window_config: - chunk_parent_id_field: "parent_id" - chunk_content_field: "chunk_field" - chunk_index_field: "chunk_index" - chunk_token_count_field: "num_tokens" - parent_total_chunks_field: "total_chunks" - parent_total_tokens_field: "total_tokens" - chunk_filter_query: "is_chunk:true" - persistence: - namespace: portal-rag - backend: kv_default - -registered_resources: - vector_stores: - - vector_store_id: portal-rag - provider_id: solr-vector - embedding_model: granite-embedding-30m - embedding_dimension: 384 +rag: + inline: + - okp # inject OKP context before the LLM request + tool: + - okp # expose OKP as the file_search tool + +okp: + offline: true # true = use parent_id for source URLs (offline mode) + # false = use reference_url (online mode) ``` -Note: if the vector database (portal-rag) is not in the persistent data store within the vector_io provider -(e.g. after deleting the llama stack cache) you will need to register the vector database under registered resources: - - -```yaml - vector_stores: - - embedding_dimension: 384 - embedding_model: sentence-transformers/${env.EMBEDDING_MODEL_DIR} - provider_id: solr-vector - vector_store_id: portal-rag -``` - - -**2. 
Configure Lightspeed Stack (`lightspeed-stack.yaml`):** - -```yaml -solr: - enabled: true # Enable Solr vector IO functionality - offline: true # Use parent_id for document URLs (offline mode) - # Set to false to use reference_url (online mode) -``` +> [!NOTE] +> When `okp` is listed in `rag.inline` or `rag.tool`, Lightspeed Stack automatically enriches +> the Llama Stack `run.yaml` at startup with the required `vector_io` provider and `registered_resources` +> entries for the OKP vector store. No manual registration is needed. **Query Request Example:** ``` curl -sX POST http://localhost:8080/v1/query \ -H "Content-Type: application/json" \ - -d '{"query" : "how do I secure a nodejs application with keycloak?", "no_tools":true}' | jq . + -d '{"query" : "how do I secure a nodejs application with keycloak?"}' | jq . ``` -Note: Solr does not currently work with RAG tools. You will need to specify "no_tools": true in request. **Query Processing:** -1. When Solr is enabled, queries use the `portal-rag` vector store +1. When OKP is enabled, queries use the `portal-rag` vector store 2. Vector search is performed with configurable parameters: - `k`: Number of results (default: 5) - `score_threshold`: Minimum similarity score (default: 0.0) @@ -347,11 +314,19 @@ Note: Solr does not currently work with RAG tools. You will need to specify "no_ **Query Filtering:** -To filter the Solr context edit the *chunk_filter_query* field in the -Solr **vector_io** provider in the `run.yaml`. Filters should follow the key:value format: -ex. `"product:*openshift*"` +To filter the Solr context, set the `chunk_filter_query` field in the `okp` section of +`lightspeed-stack.yaml`. Filters follow the Solr key:value format and are applied as a static +`fq` parameter on every OKP search request. The default value `"is_chunk:true"` restricts +results to chunk documents. 
To add extra constraints, extend the expression using Solr boolean +syntax: -Note: This static filter is a temporary work-around. +```yaml +okp: + chunk_filter_query: "is_chunk:true AND product:*openshift*" +``` + +> [!NOTE] +> This static filter is a temporary work-around until dynamic per-request filtering is supported. **Prerequisites:** @@ -359,6 +334,18 @@ Note: This static filter is a temporary work-around. for instructions on how to pull and run the OKP Solr image visit: https://github.com/lightspeed-core/lightspeed-providers/lightspeed_stack_providers/providers/remote/solr_vector_io/solr_vector_io/README.md +**Chunk volume:** + +OKP and BYOK scores are not directly comparable (different scoring systems), so +`score_multiplier` (a BYOK-only concept) does not apply to OKP results. To control +the number of retrieved chunks, set the constants in `src/constants.py`: + +| Constant | Default | Description | +|----------|---------|-------------| +| `OKP_RAG_MAX_CHUNKS` | 5 | Max chunks retrieved from OKP (Inline RAG) | +| `BYOK_RAG_MAX_CHUNKS` | 10 | Max chunks retrieved from BYOK stores (Inline RAG) | +| `TOOL_RAG_MAX_CHUNKS` | 10 | Max chunks retrieved via Tool RAG (`file_search`) | + **Limitations:** - This is a **read-only** provider - no insert/delete operations diff --git a/examples/lightspeed-stack-byok-okp-rag.yaml b/examples/lightspeed-stack-byok-okp-rag.yaml new file mode 100644 index 000000000..3cd358ff1 --- /dev/null +++ b/examples/lightspeed-stack-byok-okp-rag.yaml @@ -0,0 +1,71 @@ +name: Lightspeed Core Service (LCS) +service: + host: localhost + port: 8080 + auth_enabled: false + workers: 1 + color_log: true + access_log: true +llama_stack: + use_as_library_client: false + url: http://localhost:8321 + api_key: xyzzy +user_data_collection: + feedback_enabled: true + feedback_storage: "/tmp/data/feedback" + transcripts_enabled: true + transcripts_storage: "/tmp/data/transcripts" +authentication: + module: "noop" +quota_handlers: + sqlite: + db_path: 
quota.sqlite + limiters: + - name: user_monthly_limits + type: user_limiter + initial_quota: 50 + quota_increase: 50 + period: "30 seconds" + - name: cluster_monthly_limits + type: cluster_limiter + initial_quota: 100 + quota_increase: 100 + period: "30 seconds" + scheduler: + # scheduler ticks in seconds + period: 10 +byok_rag: + - rag_id: ocp-docs # referenced in rag.inline / rag.tool + rag_type: inline::faiss + embedding_dimension: 1024 + vector_db_id: vs_123 # Llama-stack vector_store_id + db_path: /tmp/ocp.faiss + score_multiplier: 1.0 # Weight for this vector store's results (Inline RAG only) + - rag_id: knowledge-base # referenced in rag.inline / rag.tool + rag_type: inline::faiss + embedding_dimension: 384 + vector_db_id: vs_456 # Llama-stack vector_store_id + db_path: /tmp/kb.faiss + score_multiplier: 1.2 # Weight for this vector store's results (Inline RAG only) + +# RAG configuration +rag: + # Inline RAG: context injected before the LLM request from the listed sources + # List rag_ids from byok_rag, or 'okp' to include OKP + inline: + - ocp-docs + - knowledge-base + - okp + # Tool RAG: LLM can call file_search on demand to retrieve context + # List rag_ids from byok_rag, or 'okp' to include OKP + # Omit to use all registered BYOK stores (backward compatibility) + tool: + - ocp-docs + - knowledge-base + +# OKP provider settings (only used when 'okp' is listed in rag.inline or rag.tool) +okp: + offline: true # true = use parent_id for source URLs, false = use reference_url + # Solr fq applied to every OKP search request. 
Combine with AND for extra constraints: + # chunk_filter_query: "is_chunk:true AND product:*openshift*" + chunk_filter_query: "is_chunk:true" diff --git a/examples/lightspeed-stack-byok-rag.yaml b/examples/lightspeed-stack-byok-rag.yaml deleted file mode 100644 index 7780ac21f..000000000 --- a/examples/lightspeed-stack-byok-rag.yaml +++ /dev/null @@ -1,47 +0,0 @@ -name: Lightspeed Core Service (LCS) -service: - host: localhost - port: 8080 - auth_enabled: false - workers: 1 - color_log: true - access_log: true -llama_stack: - use_as_library_client: false - url: http://localhost:8321 - api_key: xyzzy -user_data_collection: - feedback_enabled: true - feedback_storage: "/tmp/data/feedback" - transcripts_enabled: true - transcripts_storage: "/tmp/data/transcripts" -authentication: - module: "noop" -quota_handlers: - sqlite: - db_path: quota.sqlite - limiters: - - name: user_monthly_limits - type: user_limiter - initial_quota: 50 - quota_increase: 50 - period: "30 seconds" - - name: cluster_monthly_limits - type: cluster_limiter - initial_quota: 100 - quota_increase: 100 - period: "30 seconds" - scheduler: - # scheduler ticks in seconds - period: 10 -byok_rag: - - rag_id: ocp_docs - rag_type: inline::faiss - embedding_dimension: 1024 - vector_db_id: vector_byok_1 - db_path: /tmp/ocp.faiss - - rag_id: knowledge_base - rag_type: inline::faiss - embedding_dimension: 384 - vector_db_id: vector_byok_2 - db_path: /tmp/kb.faiss diff --git a/lightspeed-stack.yaml b/lightspeed-stack.yaml index 98b2555a8..fe655a810 100644 --- a/lightspeed-stack.yaml +++ b/lightspeed-stack.yaml @@ -31,8 +31,3 @@ conversation_cache: authentication: module: "noop" - -# OKP Solr for supplementary RAG -solr: - enabled: false - offline: true \ No newline at end of file diff --git a/run.yaml b/run.yaml index 29ce3cae3..b7e56d249 100644 --- a/run.yaml +++ b/run.yaml @@ -24,10 +24,7 @@ providers: config: api_key: ${env.OPENAI_API_KEY} allowed_models: ["${env.E2E_OPENAI_MODEL:=gpt-4o-mini}"] - - config: - 
allowed_models: - - ${env.EMBEDDING_MODEL_DIR} - provider_id: sentence-transformers + - provider_id: sentence-transformers provider_type: inline::sentence-transformers files: - config: @@ -58,27 +55,7 @@ providers: provider_id: rag-runtime provider_type: inline::rag-runtime vector_io: - - provider_id: solr-vector - provider_type: remote::solr_vector_io - config: - solr_url: http://localhost:8983/solr - collection_name: portal-rag - vector_field: chunk_vector - content_field: chunk - embedding_dimension: 384 - embedding_model: ${env.EMBEDDING_MODEL_DIR} - chunk_window_config: - chunk_parent_id_field: "parent_id" - chunk_content_field: "chunk_field" - chunk_index_field: "chunk_index" - chunk_token_count_field: "num_tokens" - parent_total_chunks_field: "total_chunks" - parent_total_tokens_field: "total_tokens" - chunk_filter_query: "is_chunk:true" - persistence: - namespace: portal-rag - backend: kv_default - - config: # Define the storage backend for RAG + - config: persistence: namespace: vector_io::faiss backend: kv_default @@ -149,28 +126,21 @@ storage: namespace: prompts backend: kv_default registered_resources: - models: - - model_id: granite-embedding-30m - model_type: embedding - provider_id: sentence-transformers - provider_model_id: ${env.EMBEDDING_MODEL_DIR} - metadata: - embedding_dimension: 384 + models: [] shields: - shield_id: llama-guard provider_id: llama-guard provider_shield_id: openai/gpt-4o-mini - vector_stores: - - embedding_dimension: 384 - embedding_model: sentence-transformers/${env.EMBEDDING_MODEL_DIR} - provider_id: solr-vector - vector_store_id: portal-rag + vector_stores: [] datasets: [] scoring_fns: [] benchmarks: [] tool_groups: - toolgroup_id: builtin::rag # Register the RAG tool provider_id: rag-runtime +# REQUIRED: This section is necessary for file_search tool calls to work. +# Without it, llama-stack's rag-runtime silently fails all file_search operations +# with no error logged. 
 vector_stores:
   default_provider_id: faiss
   default_embedding_model: # Define the default embedding model for RAG
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index cbd06e0e7..659c55f3a 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -9,8 +9,8 @@
 from llama_stack_api.openai_responses import OpenAIResponseObject
 from llama_stack_client import (
     APIConnectionError,
-    AsyncLlamaStackClient,
     APIStatusError as LLSApiStatusError,
+    AsyncLlamaStackClient,
 )
 from openai._exceptions import (
     APIStatusError as OpenAIAPIStatusError,
@@ -22,9 +22,9 @@
 from authorization.middleware import authorize
 from client import AsyncLlamaStackClientHolder
 from configuration import configuration
+from log import get_logger
 from models.config import Action
 from models.requests import QueryRequest
-
 from models.responses import (
     ForbiddenResponse,
     InternalServerErrorResponse,
@@ -40,10 +40,11 @@
     check_configuration_loaded,
     validate_and_retrieve_conversation,
 )
-from utils.mcp_headers import mcp_headers_dependency, McpHeaders
+from utils.mcp_headers import McpHeaders, mcp_headers_dependency
 from utils.query import (
     consume_query_tokens,
     handle_known_apistatus_errors,
+    prepare_input,
     store_query_results,
     update_azure_token,
     validate_attachments_metadata,
@@ -67,8 +68,7 @@
     ResponsesApiParams,
     TurnSummary,
 )
-from utils.vector_search import perform_vector_search, format_rag_context_for_injection
-from log import get_logger
+from utils.vector_search import build_rag_context

 logger = get_logger(__name__)
 router = APIRouter(tags=["query"])
@@ -155,15 +155,14 @@ async def query_endpoint_handler(
     client = AsyncLlamaStackClientHolder().get_client()

-    _, _, doc_ids_from_chunks, pre_rag_chunks = await perform_vector_search(
-        client, query_request.query, query_request.solr
+    # Build RAG context from Inline RAG sources
+    inline_rag_context = await build_rag_context(
+        client, query_request.query, query_request.vector_store_ids, query_request.solr
     )
-    rag_context = format_rag_context_for_injection(pre_rag_chunks)
-    if rag_context:
-        # safest: mutate a local copy so we don't surprise other logic
-        query_request = query_request.model_copy(deep=True)  # pydantic v2
-        query_request.query = query_request.query + rag_context
+    # Moderation input is the raw user content (query + attachments) without injected RAG
+    # context, to avoid false positives from retrieved document content.
+    moderation_input = prepare_input(query_request)

     # Prepare API request parameters
     responses_params = await prepare_responses_params(
@@ -175,6 +174,7 @@
         stream=False,
         store=True,
         request_headers=request.headers,
+        inline_rag_context=inline_rag_context.context_text or None,
     )

     # Handle Azure token refresh if needed
@@ -197,15 +197,21 @@
         query_request.shield_ids,
         vector_store_ids,
         rag_id_mapping,
+        moderation_input=moderation_input,
     )

-    if pre_rag_chunks:
-        turn_summary.rag_chunks = pre_rag_chunks + (turn_summary.rag_chunks or [])
-
-    if doc_ids_from_chunks:
-        turn_summary.referenced_documents = deduplicate_referenced_documents(
-            doc_ids_from_chunks + turn_summary.referenced_documents
-        )
+    # Combine inline RAG results (BYOK + Solr) with tool-based RAG results for the transcript
+    rag_chunks = inline_rag_context.rag_chunks
+    tool_rag_chunks = turn_summary.rag_chunks or []
+    logger.info("RAG as a tool retrieved %d chunks", len(tool_rag_chunks))
+    turn_summary.rag_chunks = rag_chunks + tool_rag_chunks
+
+    # Add tool-based RAG documents and chunks
+    rag_documents = inline_rag_context.referenced_documents
+    tool_rag_documents = turn_summary.referenced_documents or []
+    turn_summary.referenced_documents = deduplicate_referenced_documents(
+        rag_documents + tool_rag_documents
+    )

     # Get topic summary for new conversation
     if not user_conversation and query_request.generate_topic_summary:
@@ -268,6 +274,7 @@ async def retrieve_response(  # pylint: disable=too-many-locals
     shield_ids: Optional[list[str]] = None,
     vector_store_ids: Optional[list[str]] = None,
     rag_id_mapping: Optional[dict[str, str]] = None,
+    moderation_input: Optional[str] = None,
 ) -> TurnSummary:
     """
     Retrieve response from LLMs and agents.
@@ -281,6 +288,9 @@ async def retrieve_response(  # pylint: disable=too-many-locals
         shield_ids: Optional list of shield IDs for moderation.
         vector_store_ids: Vector store IDs used in the query for source resolution.
         rag_id_mapping: Mapping from vector_db_id to user-facing rag_id.
+        moderation_input: Text to moderate. Should be the raw user content (query +
+            attachments) without injected RAG context to avoid false positives.
+            Falls back to responses_params.input if not provided.

     Returns:
         TurnSummary: Summary of the LLM response content
@@ -288,7 +298,9 @@
     response: Optional[OpenAIResponseObject] = None
     try:
         moderation_result = await run_shield_moderation(
-            client, cast(str, responses_params.input), shield_ids
+            client,
+            moderation_input or cast(str, responses_params.input),
+            shield_ids,
         )
         if moderation_result.decision == "blocked":
             # Handle shield moderation blocking
diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index 81f0c1dcf..6c9fe639d 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -3,7 +3,6 @@
 import asyncio
 import datetime
 import json
-
 from typing import Annotated, Any, AsyncIterator, Optional, cast

 from fastapi import APIRouter, Depends, HTTPException, Request
@@ -23,6 +22,7 @@
     APIStatusError as LLSApiStatusError,
 )
 from openai._exceptions import APIStatusError as OpenAIAPIStatusError
+
 import metrics
 from authentication import get_auth_dependency
 from authentication.interface import AuthTuple
@@ -40,6 +40,7 @@
     MEDIA_TYPE_JSON,
     MEDIA_TYPE_TEXT,
 )
+from log import get_logger
 from models.config import Action
 from models.context import ResponseGeneratorContext
 from models.requests import QueryRequest
@@ -55,16 +56,16 @@
     UnauthorizedResponse,
     UnprocessableEntityResponse,
 )
-from utils.types import ReferencedDocument
 from utils.endpoints import (
     check_configuration_loaded,
     validate_and_retrieve_conversation,
 )
-from utils.mcp_headers import mcp_headers_dependency, McpHeaders
+from utils.mcp_headers import McpHeaders, mcp_headers_dependency
 from utils.query import (
     consume_query_tokens,
     extract_provider_and_model_from_model_id,
     handle_known_apistatus_errors,
+    prepare_input,
     store_query_results,
     update_azure_token,
     validate_attachments_metadata,
@@ -90,9 +91,8 @@
 from utils.stream_interrupts import get_stream_interrupt_registry
 from utils.suid import get_suid, normalize_conversation_id
 from utils.token_counter import TokenCounter
-from utils.types import ResponsesApiParams, TurnSummary
-from utils.vector_search import format_rag_context_for_injection, perform_vector_search
-from log import get_logger
+from utils.types import ReferencedDocument, ResponsesApiParams, TurnSummary
+from utils.vector_search import build_rag_context

 logger = get_logger(__name__)
 router = APIRouter(tags=["streaming_query"])
@@ -185,15 +185,10 @@ async def streaming_query_endpoint_handler(  # pylint: disable=too-many-locals
     client = AsyncLlamaStackClientHolder().get_client()

-    _, _, doc_ids_from_chunks, pre_rag_chunks = await perform_vector_search(
-        client, query_request.query, query_request.solr
+    # Build RAG context from Inline RAG sources
+    inline_rag_context = await build_rag_context(
+        client, query_request.query, query_request.vector_store_ids, query_request.solr
     )
-
-    rag_context = format_rag_context_for_injection(pre_rag_chunks)
-    if rag_context:
-        query_request = query_request.model_copy(deep=True)
-        query_request.query = query_request.query + rag_context
-
     # Prepare API request parameters
     responses_params = await prepare_responses_params(
         client=client,
@@ -204,6 +199,7 @@ async def streaming_query_endpoint_handler(  # pylint: disable=too-many-locals
         stream=True,
         store=True,
         request_headers=request.headers,
+        inline_rag_context=inline_rag_context.context_text or None,
     )

     # Handle Azure token refresh if needed
@@ -240,7 +236,7 @@
     generator, turn_summary = await retrieve_response_generator(
         responses_params=responses_params,
         context=context,
-        doc_ids_from_chunks=doc_ids_from_chunks,
+        inline_rag_documents=inline_rag_context.referenced_documents,
     )

     response_media_type = (
@@ -263,7 +259,7 @@
 async def retrieve_response_generator(
     responses_params: ResponsesApiParams,
     context: ResponseGeneratorContext,
-    doc_ids_from_chunks: list[ReferencedDocument],
+    inline_rag_documents: list[ReferencedDocument],
 ) -> tuple[AsyncIterator[str], TurnSummary]:
     """
     Retrieve the appropriate response generator.
@@ -275,7 +271,7 @@
     Args:
         responses_params: The Responses API parameters
         context: The response generator context
-        doc_ids_from_chunks: List of ReferencedDocument objects extracted from static RAG
+        inline_rag_documents: Referenced documents from inline RAG (BYOK + Solr)

     Returns:
         tuple[AsyncIterator[str], TurnSummary]: The response generator and turn summary
@@ -285,7 +281,7 @@
     try:
         moderation_result = await run_shield_moderation(
             context.client,
-            cast(str, responses_params.input),
+            prepare_input(context.query_request),
             context.query_request.shield_ids,
         )
         if moderation_result.decision == "blocked":
@@ -305,8 +301,8 @@
         response = await context.client.responses.create(
             **responses_params.model_dump(exclude_none=True)
         )
-        # Store pre-RAG documents for later merging
-        turn_summary.pre_rag_documents = doc_ids_from_chunks
+        # Store pre-RAG documents for later merging with tool-based RAG
+        turn_summary.inline_rag_documents = inline_rag_documents
         return response_generator(response, context, turn_summary), turn_summary

     # Handle known LLS client errors only at stream creation time and shield execution
@@ -747,8 +743,9 @@ async def response_generator(  # pylint: disable=too-many-branches,too-many-stat
         rag_id_mapping=context.rag_id_mapping,
     )

+    # Merge pre-RAG documents with tool-based documents and deduplicate
     turn_summary.referenced_documents = deduplicate_referenced_documents(
-        tool_based_documents + turn_summary.pre_rag_documents
+        turn_summary.inline_rag_documents + tool_based_documents
     )
diff --git a/src/configuration.py b/src/configuration.py
index c918be9ea..c9ea8e4af 100644
--- a/src/configuration.py
+++ b/src/configuration.py
@@ -7,6 +7,7 @@
 from llama_stack.core.stack import replace_env_vars
 import yaml

+import constants
 from models.config import (
     A2AStateConfiguration,
     AuthorizationConfiguration,
@@ -14,6 +15,8 @@
     Configuration,
     Customization,
     LlamaStackConfiguration,
+    OkpConfiguration,
+    RagConfiguration,
     UserDataCollection,
     ServiceConfiguration,
     ModelContextProtocolServer,
@@ -22,7 +25,6 @@
     DatabaseConfiguration,
     ConversationHistoryConfiguration,
     QuotaHandlersConfiguration,
-    SolrConfiguration,
     SplunkConfiguration,
 )

@@ -365,11 +367,18 @@ def deployment_environment(self) -> str:
         return self._configuration.deployment_environment

     @property
-    def solr(self) -> Optional[SolrConfiguration]:
-        """Return Solr configuration, or None if not provided."""
+    def rag(self) -> "RagConfiguration":
+        """Return RAG configuration."""
         if self._configuration is None:
             raise LogicError("logic error: configuration is not loaded")
-        return self._configuration.solr
+        return self._configuration.rag
+
+    @property
+    def okp(self) -> "OkpConfiguration":
+        """Return OKP configuration."""
+        if self._configuration is None:
+            raise LogicError("logic error: configuration is not loaded")
+        return self._configuration.okp

     @property
     def rag_id_mapping(self) -> dict[str, str]:
@@ -386,6 +395,38 @@
             raise LogicError("logic error: configuration is not loaded")
         return {brag.vector_db_id: brag.rag_id for brag in self._configuration.byok_rag}

+    @property
+    def score_multiplier_mapping(self) -> dict[str, float]:
+        """Return mapping from vector_db_id to score_multiplier from BYOK RAG config.
+
+        Returns:
+            dict[str, float]: Mapping where keys are llama-stack vector_db_ids
+                and values are score multipliers from configuration.
+
+        Raises:
+            LogicError: If the configuration has not been loaded.
+        """
+        if self._configuration is None:
+            raise LogicError("logic error: configuration is not loaded")
+        return {
+            brag.vector_db_id: brag.score_multiplier
+            for brag in self._configuration.byok_rag
+        }
+
+    @property
+    def inline_solr_enabled(self) -> bool:
+        """Return whether OKP is included in the inline RAG list.
+
+        Returns:
+            bool: True if 'okp' appears in rag.inline, False otherwise.
+
+        Raises:
+            LogicError: If the configuration has not been loaded.
+        """
+        if self._configuration is None:
+            raise LogicError("logic error: configuration is not loaded")
+        return constants.OKP_RAG_ID in self._configuration.rag.inline
+
     def resolve_index_name(
         self, vector_store_id: str, rag_id_mapping: Optional[dict[str, str]] = None
     ) -> str:
diff --git a/src/constants.py b/src/constants.py
index 902db920c..0c5437fb2 100644
--- a/src/constants.py
+++ b/src/constants.py
@@ -131,9 +131,6 @@
 MCP_AUTH_CLIENT = "client"
 MCP_AUTH_OAUTH = "oauth"

-# default RAG tool value
-DEFAULT_RAG_TOOL = "file_search"
-
 # Media type constants for streaming responses
 MEDIA_TYPE_JSON = "application/json"
 MEDIA_TYPE_TEXT = "text/plain"
@@ -174,14 +171,39 @@
 USER_QUOTA_LIMITER = "user_limiter"
 CLUSTER_QUOTA_LIMITER = "cluster_limiter"

-# Vector search constants
-VECTOR_SEARCH_DEFAULT_K = 5
-VECTOR_SEARCH_DEFAULT_SCORE_THRESHOLD = 0.0
-VECTOR_SEARCH_DEFAULT_MODE = "hybrid"
+# RAG as a tool constants
+DEFAULT_RAG_TOOL = "file_search"
+TOOL_RAG_MAX_CHUNKS = 10  # retrieved from RAG as a tool
+
+# Inline RAG constants
+BYOK_RAG_MAX_CHUNKS = 10  # retrieved from BYOK RAG
+OKP_RAG_MAX_CHUNKS = 5  # retrieved from OKP RAG
+
+# Solr OKP constants
+SOLR_VECTOR_SEARCH_DEFAULT_K = 5
+SOLR_VECTOR_SEARCH_DEFAULT_SCORE_THRESHOLD = 0.3
+SOLR_VECTOR_SEARCH_DEFAULT_MODE = "hybrid"

 # SOLR OKP RAG
 MIMIR_DOC_URL = "https://mimir.corp.redhat.com"
+SOLR_PROVIDER_ID = "okp_solr"
+
+# Solr default configuration values (can be overridden via environment variables)
+SOLR_DEFAULT_VECTOR_STORE_ID = "portal-rag"
+SOLR_DEFAULT_VECTOR_FIELD = "chunk_vector"
+SOLR_DEFAULT_CONTENT_FIELD = "chunk"
+SOLR_DEFAULT_EMBEDDING_MODEL = (
+    "sentence-transformers/ibm-granite/granite-embedding-30m-english"
+)
+SOLR_DEFAULT_EMBEDDING_DIMENSION = 384
+
+# Default score multiplier for BYOK RAG vector stores
+DEFAULT_SCORE_MULTIPLIER = 1.0
+
+# Special RAG ID that activates the OKP provider when listed in rag.inline or rag.tool
+OKP_RAG_ID = "okp"
+
 # Logging configuration constants
 # Environment variable name for configurable log level
 LIGHTSPEED_STACK_LOG_LEVEL_ENV_VAR = "LIGHTSPEED_STACK_LOG_LEVEL"
diff --git a/src/llama_stack_configuration.py b/src/llama_stack_configuration.py
index b74fbd7f9..fba64ee8b 100644
--- a/src/llama_stack_configuration.py
+++ b/src/llama_stack_configuration.py
@@ -14,6 +14,8 @@
 from azure.core.exceptions import ClientAuthenticationError
 from azure.identity import ClientSecretCredential, CredentialUnavailableError
 from llama_stack.core.stack import replace_env_vars
+
+import constants
 from log import get_logger

 logger = get_logger(__name__)
@@ -137,11 +139,13 @@ def construct_storage_backends_section(
     # add new backends for each BYOK RAG
     for brag in byok_rag:
-        vector_db_id = brag.get("vector_db_id", "")
-        backend_name = f"byok_{vector_db_id}_storage"
+        if not brag.get("rag_id"):
+            raise ValueError(f"BYOK RAG entry is missing required 'rag_id': {brag}")
+        rag_id = brag["rag_id"]
+        backend_name = f"byok_{rag_id}_storage"
         output[backend_name] = {
             "type": "kv_sqlite",
-            "db_path": brag.get("db_path", f".llama/{vector_db_id}.db"),
+            "db_path": brag.get("db_path", f".llama/{rag_id}.db"),
         }
     logger.info(
         "Added %s backends into storage.backends section, total backends %s",
@@ -183,16 +187,24 @@ def construct_vector_stores_section(
     existing_store_ids = {vs.get("vector_store_id") for vs in output}
     added = 0
     for brag in byok_rag:
-        vector_db_id = brag.get("vector_db_id", "")
+        if not brag.get("rag_id"):
+            raise ValueError(f"BYOK RAG entry is missing required 'rag_id': {brag}")
+        if not brag.get("vector_db_id"):
+            raise ValueError(
+                f"BYOK RAG entry is missing required 'vector_db_id': {brag}"
+            )
+        rag_id = brag["rag_id"]
+        vector_db_id = brag["vector_db_id"]
         if vector_db_id in existing_store_ids:
             continue
         existing_store_ids.add(vector_db_id)
         added += 1
+        embedding_model = brag.get("embedding_model", constants.DEFAULT_EMBEDDING_MODEL)
         output.append(
             {
                 "vector_store_id": vector_db_id,
-                "provider_id": f"byok_{vector_db_id}",
-                "embedding_model": brag.get("embedding_model", ""),
+                "provider_id": f"byok_{rag_id}",
+                "embedding_model": embedding_model,
                 "embedding_dimension": brag.get("embedding_dimension"),
             }
         )
@@ -227,10 +239,16 @@ def construct_models_section(
     # add embedding models for each BYOK RAG
     for brag in byok_rag:
-        embedding_model = brag.get("embedding_model", "")
-        vector_db_id = brag.get("vector_db_id", "")
+        if not brag.get("rag_id"):
+            raise ValueError(f"BYOK RAG entry is missing required 'rag_id': {brag}")
+        rag_id = brag["rag_id"]
+        embedding_model = brag.get("embedding_model", constants.DEFAULT_EMBEDDING_MODEL)
         embedding_dimension = brag.get("embedding_dimension")

+        # Skip if no embedding model specified
+        if not embedding_model:
+            continue
+
         # Strip sentence-transformers/ prefix if present
         provider_model_id = embedding_model
         if provider_model_id.startswith("sentence-transformers/"):
@@ -243,7 +261,7 @@
         output.append(
             {
-                "model_id": f"byok_{vector_db_id}_embedding",
+                "model_id": f"byok_{rag_id}_embedding",
                 "model_type": "embedding",
                 "provider_id": "sentence-transformers",
                 "provider_model_id": provider_model_id,
@@ -290,9 +308,11 @@ def construct_vector_io_providers_section(
     # append new vector_io entries
     for brag in byok_rag:
-        vector_db_id = brag.get("vector_db_id", "")
-        backend_name = f"byok_{vector_db_id}_storage"
-        provider_id = f"byok_{vector_db_id}"
+        if not brag.get("rag_id"):
+            raise ValueError(f"BYOK RAG entry is missing required 'rag_id': {brag}")
+        rag_id = brag["rag_id"]
+        backend_name = f"byok_{rag_id}_storage"
+        provider_id = f"byok_{rag_id}"
         output.append(
             {
                 "provider_id": provider_id,
@@ -353,6 +373,146 @@ def enrich_byok_rag(ls_config: dict[str, Any], byok_rag: list[dict[str, Any]]) -
     )

+
+# =============================================================================
+# Enrichment: Solr
+# =============================================================================
+
+
+def enrich_solr(ls_config: dict[str, Any], solr_config: dict[str, Any]) -> None:
+    """Enrich Llama Stack config with Solr settings.
+
+    Args:
+        ls_config: Llama Stack configuration dict (modified in place)
+        solr_config: Solr configuration dict. Expected keys:
+            - enabled (bool): whether Solr enrichment should run
+            - chunk_filter_query (str): Solr filter query for chunk retrieval
+    """
+    if not solr_config or not solr_config.get("enabled"):
+        logger.info("OKP is not enabled: skipping")
+        return
+
+    logger.info("Enriching Llama Stack config with OKP")
+
+    # Add vector_io provider for Solr
+    if "providers" not in ls_config:
+        ls_config["providers"] = {}
+    if "vector_io" not in ls_config["providers"]:
+        ls_config["providers"]["vector_io"] = []
+
+    # Add Solr provider if not already present
+    existing_providers = [
+        p.get("provider_id") for p in ls_config["providers"]["vector_io"]
+    ]
+    if constants.SOLR_PROVIDER_ID not in existing_providers:
+        # Build environment variable expressions
+        solr_url_env = "${env.SOLR_URL:=http://localhost:8081/solr}"
+        collection_env = (
+            f"${{env.SOLR_COLLECTION:={constants.SOLR_DEFAULT_VECTOR_STORE_ID}}}"
+        )
+        vector_field_env = (
+            f"${{env.SOLR_VECTOR_FIELD:={constants.SOLR_DEFAULT_VECTOR_FIELD}}}"
+        )
+        content_field_env = (
+            f"${{env.SOLR_CONTENT_FIELD:={constants.SOLR_DEFAULT_CONTENT_FIELD}}}"
+        )
+        embedding_model_env = (
+            f"${{env.SOLR_EMBEDDING_MODEL:={constants.SOLR_DEFAULT_EMBEDDING_MODEL}}}"
+        )
+        embedding_dim_env = (
+            f"${{env.SOLR_EMBEDDING_DIM:={constants.SOLR_DEFAULT_EMBEDDING_DIMENSION}}}"
+        )
+
+        chunk_filter_query = solr_config.get("chunk_filter_query", "is_chunk:true")
+
+        ls_config["providers"]["vector_io"].append(
+            {
+                "provider_id": constants.SOLR_PROVIDER_ID,
+                "provider_type": "remote::solr_vector_io",
+                "config": {
+                    "solr_url": solr_url_env,
+                    "collection_name": collection_env,
+                    "vector_field": vector_field_env,
+                    "content_field": content_field_env,
+                    "embedding_model": embedding_model_env,
+                    "embedding_dimension": embedding_dim_env,
+                    "chunk_window_config": {
+                        "chunk_parent_id_field": "parent_id",
+                        "chunk_content_field": "chunk_field",
+                        "chunk_index_field": "chunk_index",
+                        "chunk_token_count_field": "num_tokens",
+                        "parent_total_chunks_field": "total_chunks",
+                        "parent_total_tokens_field": "total_tokens",
+                        "chunk_filter_query": chunk_filter_query,
+                    },
+                    "persistence": {
+                        "namespace": constants.SOLR_DEFAULT_VECTOR_STORE_ID,
+                        "backend": "kv_default",
+                    },
+                },
+            }
+        )
+        logger.info("Added OKP provider to providers/vector_io")
+
+    # Add vector store registration for Solr
+    if "registered_resources" not in ls_config:
+        ls_config["registered_resources"] = {}
+    if "vector_stores" not in ls_config["registered_resources"]:
+        ls_config["registered_resources"]["vector_stores"] = []
+
+    # Add Solr vector store if not already present
+    existing_stores = [
+        vs.get("vector_store_id")
+        for vs in ls_config["registered_resources"]["vector_stores"]
+    ]
+    if constants.SOLR_DEFAULT_VECTOR_STORE_ID not in existing_stores:
+        # Build environment variable expression
+        embedding_model_env = (
+            f"${{env.SOLR_EMBEDDING_MODEL:={constants.SOLR_DEFAULT_EMBEDDING_MODEL}}}"
+        )
+
+        ls_config["registered_resources"]["vector_stores"].append(
+            {
+                "vector_store_id": constants.SOLR_DEFAULT_VECTOR_STORE_ID,
+                "provider_id": constants.SOLR_PROVIDER_ID,
+                "embedding_model": embedding_model_env,
+                "embedding_dimension": constants.SOLR_DEFAULT_EMBEDDING_DIMENSION,
+            }
+        )
+        logger.info(
+            "Added %s vector store to registered_resources",
+            constants.SOLR_DEFAULT_VECTOR_STORE_ID,
+        )
+
+    # Add Solr embedding model to registered_resources.models if not already present
+    if "models" not in ls_config["registered_resources"]:
+        ls_config["registered_resources"]["models"] = []
+
+    # Strip sentence-transformers/ prefix from constant for provider_model_id
+    provider_model_id = constants.SOLR_DEFAULT_EMBEDDING_MODEL
+    if provider_model_id.startswith("sentence-transformers/"):
+        provider_model_id = provider_model_id[len("sentence-transformers/") :]
+
+    # Check if already registered
+    registered_models = ls_config["registered_resources"]["models"]
+    existing_model_ids = [m.get("provider_model_id") for m in registered_models]
+    if provider_model_id not in existing_model_ids:
+        # Build environment variable expression
+        provider_model_env = f"${{env.SOLR_EMBEDDING_MODEL:={provider_model_id}}}"
+
+        ls_config["registered_resources"]["models"].append(
+            {
+                "model_id": "solr_embedding",
+                "model_type": "embedding",
+                "provider_id": "sentence-transformers",
+                "provider_model_id": provider_model_env,
+                "metadata": {
+                    "embedding_dimension": constants.SOLR_DEFAULT_EMBEDDING_DIMENSION,
+                },
+            }
+        )
+        logger.info("Added OKP embedding model to registered_resources.models")
+
+
 # =============================================================================
 # Main Generation Function (service/container mode only)
 # =============================================================================
@@ -383,6 +543,17 @@ def generate_configuration(
     # Enrichment: BYOK RAG
     enrich_byok_rag(ls_config, config.get("byok_rag", []))

+    # Enrichment: Solr - enabled when "okp" appears in either inline or tool list
+    rag_config = config.get("rag", {})
+    inline_ids = rag_config.get("inline") or []
+    tool_ids = rag_config.get("tool") or []
+    okp_enabled = constants.OKP_RAG_ID in inline_ids or constants.OKP_RAG_ID in tool_ids
+    okp_config = config.get("okp", {})
+    chunk_filter_query = okp_config.get("chunk_filter_query", "is_chunk:true")
+    enrich_solr(
+        ls_config, {"enabled": okp_enabled, "chunk_filter_query": chunk_filter_query}
+    )
+
     logger.info("Writing Llama Stack configuration into file %s", output_file)
     with open(output_file, "w", encoding="utf-8") as file:
diff --git a/src/models/config.py b/src/models/config.py
index 1aca828ad..af29553ce 100644
--- a/src/models/config.py
+++ b/src/models/config.py
@@ -2,37 +2,35 @@

 # pylint: disable=too-many-lines

-from pathlib import Path
-from typing import Optional, Any, Pattern
+import re
 from enum import Enum
 from functools import cached_property
-import re
-import yaml
+from pathlib import Path
+from typing import Any, Optional, Pattern

 import jsonpath_ng
+import yaml
 from jsonpath_ng.exceptions import JSONPathError
 from pydantic import (
+    AnyHttpUrl,
     BaseModel,
     ConfigDict,
     Field,
-    field_validator,
-    model_validator,
     FilePath,
-    AnyHttpUrl,
-    PositiveInt,
     NonNegativeInt,
-    SecretStr,
+    PositiveInt,
     PrivateAttr,
+    SecretStr,
+    field_validator,
+    model_validator,
 )
-
 from pydantic.dataclasses import dataclass
-from typing_extensions import Self, Literal
+from typing_extensions import Literal, Self

 import constants
-
+from log import get_logger
 from utils import checks
 from utils.mcp_auth_headers import resolve_authorization_headers
-from log import get_logger

 logger = get_logger(__name__)

@@ -1565,6 +1563,15 @@ class ByokRag(ConfigurationBase):
         description="Path to RAG database.",
     )

+    score_multiplier: float = Field(
+        constants.DEFAULT_SCORE_MULTIPLIER,
+        gt=0,
+        title="Score multiplier",
+        description="Multiplier applied to relevance scores from this vector store. "
+        "Used to weight results when querying multiple knowledge sources. "
+        "Values > 1 boost this store's results; values < 1 reduce them.",
+    )
+

 class QuotaLimiterConfiguration(ConfigurationBase):
     """Configuration for one quota limiter.
@@ -1687,27 +1694,59 @@ class QuotaHandlersConfiguration(ConfigurationBase):
     )

-class SolrConfiguration(ConfigurationBase):
-    """Solr configuration for vector search queries.
+class RagConfiguration(ConfigurationBase):
+    """RAG strategy configuration.
+
+    Controls which RAG sources are used for inline and tool-based retrieval.

-    Controls whether to use offline or online mode when building document URLs
-    from vector search results, and enables/disables Solr vector IO functionality.
+    Each strategy lists RAG IDs to include. The special ID ``"okp"``, defined in
+    constants, activates the OKP provider; all other IDs refer to entries in ``byok_rag``.
+
+    Backward compatibility:
+    - ``inline`` defaults to ``[]`` (no inline RAG).
+    - ``tool`` defaults to ``None``, which means all registered vector stores
+      are used (identical to the previous ``tool.byok.enabled = True`` default).
     """

-    enabled: bool = Field(
-        False,
-        title="Solr enabled",
-        description="When True, enables Solr vector IO functionality for vector search queries. "
-        "When False, disables Solr vector search processing.",
+    inline: list[str] = Field(
+        default_factory=list,
+        title="Inline RAG IDs",
+        description="RAG IDs whose sources are injected as context before the LLM call. "
+        f"Use '{constants.OKP_RAG_ID}' to enable OKP inline RAG. Empty by default (no inline RAG).",
+    )
+
+    tool: Optional[list[str]] = Field(
+        default=None,
+        title="Tool RAG IDs",
+        description="RAG IDs made available to the LLM as a file_search tool. "
+        f"Use '{constants.OKP_RAG_ID}' to include the OKP vector store. "
+        "When omitted, all registered BYOK vector stores are used (backward compatibility).",
     )

+
+class OkpConfiguration(ConfigurationBase):
+    """OKP (Offline Knowledge Portal) provider configuration.
+
+    Controls provider-specific behaviour for the OKP vector store.
+    Only relevant when ``"okp"`` is listed in ``rag.inline`` or ``rag.tool``.
+    """
+
     offline: bool = Field(
-        True,
-        title="Offline mode",
-        description="When True, use parent_id for chunk source URLs. "
+        default=True,
+        title="OKP offline mode",
+        description="When True, use parent_id for OKP chunk source URLs. "
         "When False, use reference_url for chunk source URLs.",
     )

+    chunk_filter_query: str = Field(
+        default="is_chunk:true",
+        title="OKP chunk filter query",
+        description="OKP filter query applied to every OKP search request. "
+        "Defaults to 'is_chunk:true' to restrict results to chunk documents. "
+        "To add extra constraints, extend the expression using boolean syntax, "
+        "e.g. 'is_chunk:true AND product:*openshift*'.",
+    )
+

 class AzureEntraIdConfiguration(ConfigurationBase):
     """Microsoft Entra ID authentication attributes for Azure."""

@@ -1847,10 +1886,17 @@ class Configuration(ConfigurationBase):
         "Used in telemetry events.",
     )

-    solr: Optional[SolrConfiguration] = Field(
-        default=None,
-        title="Solr configuration",
-        description="Configuration for Solr vector search operations.",
+    rag: RagConfiguration = Field(
+        default_factory=RagConfiguration,
+        title="RAG configuration",
+        description="Configuration for all RAG strategies (inline and tool-based).",
+    )
+
+    okp: OkpConfiguration = Field(
+        default_factory=OkpConfiguration,
+        title="OKP configuration",
+        description=f"OKP provider settings. Only used when '{constants.OKP_RAG_ID}' is listed "
+        "in rag.inline or rag.tool.",
     )

     @model_validator(mode="after")
diff --git a/src/utils/query.py b/src/utils/query.py
index 9d447ff9c..8d96b5eb6 100644
--- a/src/utils/query.py
+++ b/src/utils/query.py
@@ -1,23 +1,31 @@
 """Utility functions for working with queries."""

+import sqlite3
 from datetime import UTC, datetime
 from typing import Optional

+import psycopg2
+from fastapi import HTTPException
 from llama_stack_client import (
     APIConnectionError,
     APIStatusError as LLSApiStatusError,
     AsyncLlamaStackClient,
 )
-from openai._exceptions import APIStatusError as OpenAIAPIStatusError
 from llama_stack_client.types import Shield
-
-from fastapi import HTTPException
+from openai._exceptions import APIStatusError as OpenAIAPIStatusError
 from sqlalchemy import func
+from sqlalchemy.exc import SQLAlchemyError
+
+import constants
+from app.database import get_session
+from authorization.azure_token_manager import AzureEntraIDManager
+from cache.cache_error import CacheError
+from client import AsyncLlamaStackClientHolder
 from configuration import configuration
+from log import get_logger
 from models.cache_entry import CacheEntry
 from models.config import Action
 from models.database.conversations import UserConversation, UserTurn
-import constants
 from models.requests import Attachment, QueryRequest
 from models.responses import (
     AbstractErrorResponse,
@@ -28,23 +36,15 @@
     ServiceUnavailableResponse,
     UnprocessableEntityResponse,
 )
-from authorization.azure_token_manager import AzureEntraIDManager
-from cache.cache_error import CacheError
-import psycopg2
-import sqlite3
-from sqlalchemy.exc import SQLAlchemyError
-from app.database import get_session
-from client import AsyncLlamaStackClientHolder
+from utils.quota import consume_tokens
+from utils.suid import normalize_conversation_id
+from utils.token_counter import TokenCounter
 from utils.transcripts import (
     create_transcript,
     create_transcript_metadata,
     store_transcript,
 )
-from utils.quota import consume_tokens
-from utils.suid import normalize_conversation_id
-from utils.token_counter import TokenCounter
 from utils.types import TurnSummary
-from log import get_logger

 logger = get_logger(__name__)

@@ -192,19 +192,27 @@ async def update_azure_token(
     )

-def prepare_input(query_request: QueryRequest) -> str:
+def prepare_input(
+    query_request: QueryRequest, inline_rag_context: Optional[str] = None
+) -> str:
     """
-    Prepare input text for Responses API by appending attachments.
+    Prepare input text for Responses API by appending RAG context and attachments.

-    Takes the query text and appends any attachment content with type labels.
+    Takes the query text, appends any inline RAG context for the LLM call, then
+    appends any attachment content with type labels.

     Args:
         query_request: The query request containing the query and optional attachments
+        inline_rag_context: Optional RAG context to inject into the query before
+            sending to the LLM. Passed separately to keep QueryRequest a pure public
+            API model.

     Returns:
-        str: The input text with attachments appended (if any)
+        str: The input text with RAG context and attachments appended (if any)
     """
     input_text = query_request.query
+    if inline_rag_context:
+        input_text += f"\n\n{inline_rag_context}"
     if query_request.attachments:
         for attachment in query_request.attachments:
             # Append attachment content with type label
diff --git a/src/utils/responses.py b/src/utils/responses.py
index 71e8afbe9..b44fb8d28 100644
--- a/src/utils/responses.py
+++ b/src/utils/responses.py
@@ -11,8 +11,11 @@
     OpenAIResponseContentPartRefusal as ContentPartRefusal,
     OpenAIResponseInputMessageContent as InputMessageContent,
     OpenAIResponseInputMessageContentText as InputTextPart,
+    OpenAIResponseInputTool as InputTool,
     OpenAIResponseInputToolFileSearch as InputToolFileSearch,
     OpenAIResponseInputToolMCP as InputToolMCP,
+    OpenAIResponseMCPApprovalRequest as MCPApprovalRequest,
+    OpenAIResponseMCPApprovalResponse as MCPApprovalResponse,
     OpenAIResponseMessage as ResponseMessage,
     OpenAIResponseObject as ResponseObject,
     OpenAIResponseOutput as ResponseOutput,
@@ -23,10 +26,7 @@
     OpenAIResponseOutputMessageMCPCall as MCPCall,
     OpenAIResponseOutputMessageMCPListTools as MCPListTools,
     OpenAIResponseOutputMessageWebSearchToolCall as WebSearchCall,
-    OpenAIResponseMCPApprovalRequest as MCPApprovalRequest,
-    OpenAIResponseMCPApprovalResponse as MCPApprovalResponse,
     OpenAIResponseUsage as ResponseUsage,
-    OpenAIResponseInputTool as InputTool,
 )
 from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient

@@ -34,6 +34,7 @@
 import metrics
 from configuration import configuration
 from constants import DEFAULT_RAG_TOOL
+from log import get_logger
 from models.config import ByokRag
 from models.database.conversations import UserConversation
 from models.requests import QueryRequest
@@ -42,6 +43,7 @@
     NotFoundResponse,
     ServiceUnavailableResponse,
 )
+from utils.mcp_headers import McpHeaders, extract_propagated_headers
 from utils.mcp_oauth_probe import probe_mcp_oauth_and_raise_401
 from utils.prompts import get_system_prompt, get_topic_summary_system_prompt
 from utils.query import (
@@ -49,7 +51,6 @@
     handle_known_apistatus_errors,
     prepare_input,
 )
-from utils.mcp_headers import McpHeaders, extract_propagated_headers
 from utils.suid import to_llama_stack_conversation_id
 from utils.token_counter import TokenCounter
 from utils.types import (
@@ -61,12 +62,49 @@
     ToolResultSummary,
     TurnSummary,
 )
-from log import get_logger

 logger = get_logger(__name__)

-async def get_topic_summary(
+async def get_vector_store_ids(
+    client: AsyncLlamaStackClient,
+    vector_store_ids: Optional[list[str]] = None,
+) -> list[str]:
+    """Get vector store IDs for querying.
+
+    If vector_store_ids are provided, returns them. Otherwise fetches all
+    available vector stores from Llama Stack.
+
+    Args:
+        client: The AsyncLlamaStackClient to use for fetching stores
+        vector_store_ids: Optional list of vector store IDs. If provided,
+            returns this list. If None, fetches all available vector stores.
+
+    Returns:
+        List of vector store IDs to query
+
+    Raises:
+        HTTPException: With ServiceUnavailableResponse if connection fails,
+            or InternalServerErrorResponse if API returns an error status
+    """
+    if vector_store_ids is not None:
+        return vector_store_ids
+
+    try:
+        vector_stores = await client.vector_stores.list()
+        return [vector_store.id for vector_store in vector_stores.data]
+    except APIConnectionError as e:
+        error_response = ServiceUnavailableResponse(
+            backend_name="Llama Stack",
+            cause=str(e),
+        )
+        raise HTTPException(**error_response.model_dump()) from e
+    except APIStatusError as e:
+        error_response = InternalServerErrorResponse.generic()
+        raise HTTPException(**error_response.model_dump()) from e
+
+
+async def get_topic_summary(  # pylint: disable=too-many-nested-blocks
     question: str, client: AsyncLlamaStackClient, model_id: str
 ) -> str:
     """Get a topic summary for a question using Responses API.
@@ -129,28 +167,22 @@ async def prepare_tools(  # pylint: disable=too-many-arguments,too-many-position
         return None

     toolgroups: list[InputTool] = []
-    # Get all vector stores if vector stores are not restricted by request
-    if vector_store_ids is None:
-        try:
-            vector_stores = await client.vector_stores.list()
-            vector_store_ids = [vector_store.id for vector_store in vector_stores.data]
-        except APIConnectionError as e:
-            error_response = ServiceUnavailableResponse(
-                backend_name="Llama Stack",
-                cause=str(e),
-            )
-            raise HTTPException(**error_response.model_dump()) from e
-        except APIStatusError as e:
-            error_response = InternalServerErrorResponse.generic()
-            raise HTTPException(**error_response.model_dump()) from e
-    else:
-        # Translate customer-facing BYOK rag_ids to llama-stack vector_db_ids
-        vector_store_ids = resolve_vector_store_ids(
-            vector_store_ids, configuration.configuration.byok_rag
+
+    # Priority: per-request IDs > rag.tool config > all registered stores.
+    # Customer-facing rag_ids are translated to internal vector_db_ids;
+    # IDs fetched from llama-stack are already internal and need no translation.
+    byok_rags = configuration.configuration.byok_rag
+    if vector_store_ids is not None:
+        effective_ids: list[str] = resolve_vector_store_ids(vector_store_ids, byok_rags)
+    elif configuration.configuration.rag.tool is not None:
+        effective_ids = resolve_vector_store_ids(
+            configuration.configuration.rag.tool, byok_rags
         )
+    else:
+        effective_ids = await get_vector_store_ids(client, None)

     # Add RAG tools if vector stores are available
-    rag_tools = get_rag_tools(vector_store_ids)
+    rag_tools = get_rag_tools(effective_ids)
     if rag_tools:
         toolgroups.extend(rag_tools)

@@ -210,6 +242,7 @@ async def prepare_responses_params(  # pylint: disable=too-many-arguments,too-ma
     stream: bool = False,
     store: bool = True,
     request_headers: Optional[Mapping[str, str]] = None,
+    inline_rag_context: Optional[str] = None,
 ) -> ResponsesApiParams:
     """Prepare API request parameters for Responses API.

@@ -222,6 +255,9 @@ async def prepare_responses_params(  # pylint: disable=too-many-arguments,too-ma
         stream: Whether to stream the response
         store: Whether to store the response
        request_headers: Incoming HTTP request headers for allowlist propagation
+        inline_rag_context: Optional RAG context to inject into the query before
+            sending to the LLM. Passed separately to keep QueryRequest a pure public
+            API model.

     Returns:
         ResponsesApiParams containing all prepared parameters for the API request
@@ -251,7 +287,8 @@ async def prepare_responses_params(  # pylint: disable=too-many-arguments,too-ma
     )

     # Prepare input for Responses API
-    input_text = prepare_input(query_request)
+    # Adds inline RAG context and attachments
+    input_text = prepare_input(query_request, inline_rag_context)

     # Handle conversation ID for Responses API
     conversation_id = query_request.conversation_id
@@ -318,10 +355,11 @@ def extract_vector_store_ids_from_tools(
 def resolve_vector_store_ids(
     vector_store_ids: list[str], byok_rags: list[ByokRag]
 ) -> list[str]:
-    """Translate customer-facing BYOK rag_ids to llama-stack vector_db_ids.
+    """Translate customer-facing rag_ids to llama-stack vector_db_ids.

     Each ID is looked up against the BYOK RAG configuration. If a matching
     ``rag_id`` is found, the corresponding ``vector_db_id`` is returned.
+    The special ``okp`` ID is mapped to the Solr vector store ID.
     Otherwise the ID is passed through unchanged (assumed to already be a
     llama-stack vector store ID).

@@ -334,6 +372,9 @@ def resolve_vector_store_ids(
         List of llama-stack vector_db_ids ready for the Llama Stack API.
     """
     rag_id_to_vector_db_id = {brag.rag_id: brag.vector_db_id for brag in byok_rags}
+    rag_id_to_vector_db_id[constants.OKP_RAG_ID] = (
+        constants.SOLR_DEFAULT_VECTOR_STORE_ID
+    )
     return [rag_id_to_vector_db_id.get(vs_id, vs_id) for vs_id in vector_store_ids]

@@ -344,16 +385,16 @@ def get_rag_tools(vector_store_ids: list[str]) -> Optional[list[InputToolFileSea
         vector_store_ids: List of vector store identifiers

     Returns:
-        List containing file_search tool configuration, or None if no vector stores provided
+        List containing file_search tool configuration, or empty list if no stores available
     """
     if not vector_store_ids:
-        return None
+        return []

     return [
         InputToolFileSearch(
             type="file_search",
             vector_store_ids=vector_store_ids,
-            max_num_results=10,
+            max_num_results=constants.TOOL_RAG_MAX_CHUNKS,
         )
     ]

diff --git a/src/utils/types.py b/src/utils/types.py
index 220a85239..c3a0c71d3 100644
--- a/src/utils/types.py
+++ b/src/utils/types.py
@@ -21,8 +21,6 @@
 from llama_stack_client.lib.agents.tool_parser import ToolParser
 from llama_stack_client.lib.agents.types import (
     CompletionMessage as AgentCompletionMessage,
-)
-from llama_stack_client.lib.agents.types import (
     ToolCall as AgentToolCall,
 )
 from pydantic import AnyUrl, BaseModel, Field
@@ -285,6 +283,26 @@ class ReferencedDocument(BaseModel):
     )


+class RAGContext(BaseModel):
+    """Result of building RAG context from all enabled pre-query RAG sources.
+
+    Attributes:
+        context_text: Formatted RAG context string for injection into the query.
+        rag_chunks: RAG chunks from pre-query sources (BYOK + Solr).
+        referenced_documents: Referenced documents from pre-query sources.
+    """
+
+    context_text: str = Field(default="", description="Formatted context for injection")
+    rag_chunks: list[RAGChunk] = Field(
+        default_factory=list,
+        description="RAG chunks from pre-query sources",
+    )
+    referenced_documents: list[ReferencedDocument] = Field(
+        default_factory=list,
+        description="Documents from pre-query sources",
+    )
+
+
 class TurnSummary(BaseModel):
     """Summary of a turn in llama stack."""

@@ -293,7 +311,7 @@ class TurnSummary(BaseModel):
     tool_results: list[ToolResultSummary] = Field(default_factory=list)
     rag_chunks: list[RAGChunk] = Field(default_factory=list)
     referenced_documents: list[ReferencedDocument] = Field(default_factory=list)
-    pre_rag_documents: list[ReferencedDocument] = Field(default_factory=list)
+    inline_rag_documents: list[ReferencedDocument] = Field(default_factory=list)
     token_usage: TokenCounter = Field(default_factory=TokenCounter)

diff --git a/src/utils/vector_search.py b/src/utils/vector_search.py
index e39e9ec04..485914e0b 100644
--- a/src/utils/vector_search.py
+++ b/src/utils/vector_search.py
@@ -4,61 +4,180 @@
 and processing RAG chunks that is shared between
 query_v2.py and streaming_query_v2.py.
 """

+import asyncio
 import traceback
 from typing import Any, Optional
 from urllib.parse import urljoin

 from llama_stack_client import AsyncLlamaStackClient
-from llama_stack_client.types.query_chunks_response import Chunk
 from pydantic import AnyUrl

 import constants
 from configuration import configuration
 from log import get_logger
 from models.responses import ReferencedDocument
-from utils.types import RAGChunk
+from utils.responses import resolve_vector_store_ids
+from utils.types import RAGChunk, RAGContext

 logger = get_logger(__name__)


 def _is_solr_enabled() -> bool:
-    """Check if Solr is enabled in configuration."""
-    return bool(configuration.solr and configuration.solr.enabled)
+    """Check if Solr is enabled for inline RAG in configuration."""
+    return configuration.inline_solr_enabled


-def _get_vector_store_ids(solr_enabled: bool) -> list[str]:
+def _get_solr_vector_store_ids() -> list[str]:
     """Get vector store IDs based on Solr configuration."""
-    if solr_enabled:
-        vector_store_ids = ["portal-rag"]
-        logger.info(
-            "Using portal-rag vector store for Solr query: %s",
-            vector_store_ids,
-        )
-        return vector_store_ids
-    return []
+    vector_store_ids = [constants.SOLR_DEFAULT_VECTOR_STORE_ID]
+    logger.info(
+        "Using %s vector store for OKP query: %s",
+        constants.SOLR_DEFAULT_VECTOR_STORE_ID,
+        vector_store_ids,
+    )
+    return vector_store_ids


 def _build_query_params(solr: Optional[dict[str, Any]] = None) -> dict[str, Any]:
     """Build query parameters for vector search."""
     params = {
-        "k": constants.VECTOR_SEARCH_DEFAULT_K,
-        "score_threshold": constants.VECTOR_SEARCH_DEFAULT_SCORE_THRESHOLD,
-        "mode": constants.VECTOR_SEARCH_DEFAULT_MODE,
+        "k": constants.SOLR_VECTOR_SEARCH_DEFAULT_K,
+        "score_threshold": constants.SOLR_VECTOR_SEARCH_DEFAULT_SCORE_THRESHOLD,
+        "mode": constants.SOLR_VECTOR_SEARCH_DEFAULT_MODE,
     }
-    logger.info("Initial params: %s", params)
-    logger.info("solr: %s", solr)
+    logger.debug("Initial params: %s", params)
+    logger.debug("query_request.solr: %s", solr)

     if solr:
         params["solr"] = solr
-        logger.info("Final params with solr filters: %s", params)
+        logger.debug("Final params with solr filters: %s", params)
     else:
-        logger.info("No solr filters provided")
+        logger.debug("No solr filters provided")

-    logger.info("Final params being sent to vector_io.query: %s", params)
+    logger.debug("Final params being sent to vector_io.query: %s", params)
     return params


-def _extract_document_metadata(
+def _extract_byok_rag_chunks(
+    search_response: Any, vector_store_id: str, weight: float
+) -> list[dict[str, Any]]:
+    """Extract and weight result chunks from vector search for BYOK RAG.
+
+    Args:
+        search_response: Response from vector_io.query
+        vector_store_id: ID of the vector store that produced these results
+        weight: Score multiplier to apply to this store's results
+
+    Returns:
+        List of result dictionaries with weighted scores
+    """
+    result_chunks = []
+    for chunk, score in zip(
+        search_response.chunks, search_response.scores, strict=True
+    ):
+        weighted_score = score * weight
+        doc_id = (
+            chunk.metadata.get("document_id", chunk.chunk_id)
+            if chunk.metadata
+            else chunk.chunk_id
+        )
+        logger.debug(
+            "  [%s] score=%.4f weighted=%.4f",
+            vector_store_id,
+            score,
+            weighted_score,
+        )
+        result_chunks.append(
+            {
+                "content": chunk.content,
+                "score": score,
+                "weighted_score": weighted_score,
+                "source": vector_store_id,
+                "doc_id": doc_id,
+                "metadata": chunk.metadata or {},
+            }
+        )
+    return result_chunks
+
+
+def _format_rag_context(rag_chunks: list[RAGChunk], query: str) -> str:
+    """Format RAG chunks for pre-query context injection.
+
+    This format is used for both BYOK RAG and Solr RAG chunks.
+    Format is inspired by llama-stack file_search tool implementation.
+
+    Args:
+        rag_chunks: List of RAG chunks from pre-query sources (BYOK + Solr)
+        query: The original search query
+
+    Returns:
+        Formatted string with RAG context metadata attributes
+    """
+    if not rag_chunks:
+        return ""
+
+    output = f"file_search found {len(rag_chunks)} chunks:\n"
+    output += "BEGIN of file_search results.\n"
+
+    for i, chunk in enumerate(rag_chunks, 1):
+        # Build metadata text with source and score
+        metadata_parts = []
+        if chunk.source:
+            metadata_parts.append(f"document_id: {chunk.source}")
+        if chunk.score is not None:
+            metadata_parts.append(f"score: {chunk.score:.4f}")
+
+        metadata_text = ", ".join(metadata_parts)
+
+        # Add additional attributes if present
+        if chunk.attributes:
+            metadata_text += f", attributes: {chunk.attributes}"
+
+        # Format chunk with metadata and content
+        output += f"[{i}] {metadata_text}\n{chunk.content}\n\n"
+
+    output += "END of file_search results.\n"
+
+    output += (
+        f'The above results were retrieved to help answer the user\'s query: "{query}". '
+        "Use them as supporting information only in answering this query. "
+    )
+    return output
+
+
+async def _query_store_for_byok_rag(
+    client: AsyncLlamaStackClient,
+    vector_store_id: str,
+    query: str,
+    weight: float,
+) -> list[dict[str, Any]]:
+    """Query a single vector store for BYOK RAG.
+
+    Args:
+        client: AsyncLlamaStackClient for vector_io queries
+        vector_store_id: ID of the vector store to query
+        query: Search query string
+        weight: Score multiplier to apply
+
+    Returns:
+        List of weighted result dictionaries, or empty list on error
+    """
+    try:
+        search_response = await client.vector_io.query(
+            vector_store_id=vector_store_id,
+            query=query,
+            params={
+                "max_chunks": constants.BYOK_RAG_MAX_CHUNKS,
+                "mode": "vector",
+            },
+        )
+        return _extract_byok_rag_chunks(search_response, vector_store_id, weight)
+    except Exception as e:  # pylint: disable=broad-exception-caught
+        logger.warning("Failed to search '%s': %s", vector_store_id, e)
+        return []
+
+
+def _extract_solr_document_metadata(
     chunk: Any,
 ) -> tuple[Optional[str], Optional[str], Optional[str]]:
     """Extract document ID, title, and reference URL from chunk metadata."""
@@ -86,17 +205,83 @@ def _extract_document_metadata(
     return doc_id, title, reference_url


-def _process_chunks_for_documents(
+def _process_byok_rag_chunks_for_documents(
+    result_chunks: list[dict[str, Any]],
+) -> list[ReferencedDocument]:
+    """Process BYOK RAG result chunks to extract referenced documents.
+
+    Args:
+        result_chunks: Processed result dictionaries from BYOK RAG
+            (output of _extract_byok_rag_chunks)
+
+    Returns:
+        List of referenced documents extracted from BYOK RAG chunks
+    """
+    referenced_documents = []
+    seen_doc_ids = set()
+
+    for result in result_chunks:
+        metadata = result.get("metadata", {})
+        doc_id = result.get("doc_id") or metadata.get("document_id")
+        title = metadata.get("title")
+        reference_url = (
+            metadata.get("reference_url")
+            or metadata.get("doc_url")
+            or metadata.get("docs_url")
+        )
+
+        if not doc_id and not reference_url:
+            continue
+
+        # Use doc_id or reference_url as deduplication key
+        dedup_key = reference_url or doc_id
+        if dedup_key and dedup_key not in seen_doc_ids:
+            seen_doc_ids.add(dedup_key)
+
+            # Build document URL
+            parsed_url: Optional[AnyUrl] = None
+            if reference_url:
+                try:
+                    parsed_url = AnyUrl(reference_url)
+                except Exception:  # pylint: disable=broad-exception-caught
+                    parsed_url = None
+
+            referenced_documents.append(
+                ReferencedDocument(
+                    doc_title=title,
+                    doc_url=parsed_url,
+                    source=result.get("source"),  # Vector store ID
+                )
+            )
+
+    logger.info(
+        "Extracted %d unique documents from BYOK RAG",
+        len(referenced_documents),
+    )
+    return referenced_documents
+
+
+def _process_solr_chunks_for_documents(
     chunks: list[Any], offline: bool
 ) -> list[ReferencedDocument]:
-    """Process chunks to extract referenced documents."""
+    """Process Solr chunks to extract referenced documents.
+
+    Args:
+        chunks: Raw chunks from Solr vector store
+        offline: Whether to use offline mode for URL construction
+
+    Returns:
+        List of referenced documents extracted from Solr chunks
+    """
     doc_ids_from_chunks = []
     metadata_doc_ids = set()

     for chunk in chunks:
-        logger.info("Extract doc ids from chunk: %s", chunk)
+        logger.debug(
+            "Extracting doc ids from chunk id: %s", getattr(chunk, "chunk_id", None)
+        )

-        doc_id, title, reference_url = _extract_document_metadata(chunk)
+        doc_id, title, reference_url = _extract_solr_document_metadata(chunk)

         if not doc_id and not reference_url:
             continue
@@ -118,23 +303,124 @@ def _process_chunks_for_documents(
                 ReferencedDocument(
                     doc_title=title,
                     doc_url=parsed_url,
+                    source=constants.OKP_RAG_ID,
                 )
             )

-    logger.info(
-        "Extracted %d unique document IDs from chunks",
+    logger.debug(
+        "Extracted %d unique document IDs from OKP chunks",
         len(doc_ids_from_chunks),
     )

     return doc_ids_from_chunks


-async def perform_vector_search(
+async def _fetch_byok_rag(
     client: AsyncLlamaStackClient,
     query: str,
-    solr: Optional[dict[str, Any]] = None,
-) -> tuple[list[Any], list[float], list[ReferencedDocument], list[RAGChunk]]:
+    vector_store_ids: Optional[list[str]] = None,
+) -> tuple[list[RAGChunk], list[ReferencedDocument]]:
+    """Fetch chunks and documents from BYOK RAG sources.
+
+    Args:
+        client: The AsyncLlamaStackClient to use for the request
+        query: The search query
+        vector_store_ids: Optional list of vector store IDs to query.
+            If provided, only these stores will be queried. If None, all stores
+            (excluding Solr) will be queried.
+
+    Returns:
+        Tuple containing:
+        - rag_chunks: RAG chunks from BYOK RAG
+        - referenced_documents: Documents referenced in BYOK RAG results
     """
-    Perform vector search and extract RAG chunks and referenced documents.
+    rag_chunks: list[RAGChunk] = []
+    referenced_documents: list[ReferencedDocument] = []
+
+    # Determine which BYOK vector stores to query for inline RAG.
+    # Per-request override takes precedence; otherwise use config-based inline list.
+    if vector_store_ids is not None:
+        # Request-level override: filter out Solr store, use the rest
+        vector_store_ids_to_query = [
+            vs_id
+            for vs_id in vector_store_ids
+            if vs_id != constants.SOLR_DEFAULT_VECTOR_STORE_ID
+        ]
+    else:
+        inline_rag_ids = [
+            rid
+            for rid in configuration.configuration.rag.inline
+            if rid != constants.OKP_RAG_ID
+        ]
+        vector_store_ids_to_query = resolve_vector_store_ids(
+            inline_rag_ids, configuration.configuration.byok_rag
+        )
+
+    # If no inline BYOK stores are defined, inline RAG is disabled for backward compatibility
+    if not vector_store_ids_to_query:
+        logger.info("No inline BYOK RAG sources configured, skipping BYOK RAG search")
+        return rag_chunks, referenced_documents
+
+    try:
+        # Get score multiplier and rag_id mappings
+        score_multiplier_mapping = configuration.score_multiplier_mapping
+        rag_id_mapping = configuration.rag_id_mapping
+
+        # Query all vector stores in parallel
+        results_per_store = await asyncio.gather(
+            *[
+                _query_store_for_byok_rag(
+                    client,
+                    vector_store_id,
+                    query,
+                    score_multiplier_mapping.get(vector_store_id, 1.0),
+                )
+                for vector_store_id in vector_store_ids_to_query
+            ]
+        )
+
+        # Flatten, sort by weighted score, and take top results
+        all_results: list[dict[str, Any]] = []
+        for store_results in results_per_store:
+            all_results.extend(store_results)
+        all_results.sort(key=lambda x: x["weighted_score"], reverse=True)
+        top_results = all_results[: constants.BYOK_RAG_MAX_CHUNKS]
+
+        # Resolve source, log, and convert to RAGChunk in a single pass
+        logger.info("Filtered top %d chunks from BYOK RAG", len(top_results))
+        for result in top_results:
+            result["source"] = rag_id_mapping.get(result["source"], result["source"])
+            logger.debug(
+                "  [%s] score=%.4f weighted=%.4f",
+                result["source"],
+                result["score"],
+                result["weighted_score"],
+            )
+            rag_chunks.append(
+                RAGChunk(
+                    content=result["content"],
+                    source=result["source"],
+                    score=result["weighted_score"],
+                    attributes=result.get("metadata", {}),
+                )
+            )

+        # Extract referenced documents from BYOK RAG chunks (now with resolved sources)
+        referenced_documents = _process_byok_rag_chunks_for_documents(top_results)
+
+    except Exception as e:  # pylint: disable=broad-exception-caught
+        logger.warning("Failed to perform BYOK RAG search: %s", e)
+        logger.debug("BYOK RAG error details: %s", traceback.format_exc())
+
+    return rag_chunks, referenced_documents
+
+
+async def _fetch_solr_rag(
+    client: AsyncLlamaStackClient,
+    query: str,
+    solr: Optional[dict[str, Any]] = None,
+) -> tuple[list[RAGChunk], list[ReferencedDocument]]:
+    """Fetch chunks and documents from Solr RAG source.

     Args:
         client: The AsyncLlamaStackClient to use for the request
@@ -143,28 +429,24 @@
     Returns:
         Tuple containing:
-        - retrieved_chunks: Raw chunks from vector store
-        - retrieved_scores: Scores for each chunk
-        - doc_ids_from_chunks: Referenced documents extracted from chunks
-        - rag_chunks: Processed RAG chunks ready for use
+        - rag_chunks: RAG chunks from Solr
+        - referenced_documents: Documents referenced in Solr results
     """
-    retrieved_chunks: list[Chunk] = []
-    retrieved_scores: list[float] = []
-    doc_ids_from_chunks: list[ReferencedDocument] = []
     rag_chunks: list[RAGChunk] = []
+    referenced_documents: list[ReferencedDocument] = []

-    # Check if Solr is enabled in configuration
     if not _is_solr_enabled():
-        logger.info("Solr vector IO is disabled, skipping vector search")
-        return retrieved_chunks, retrieved_scores, doc_ids_from_chunks, rag_chunks
+        logger.info("OKP vector IO is disabled, skipping OKP search")
+        return rag_chunks, referenced_documents

     # Get offline setting from configuration
-    offline = configuration.solr.offline if configuration.solr else True
+    offline = configuration.okp.offline

     try:
-        vector_store_ids = _get_vector_store_ids(True)
+        vector_store_ids = _get_solr_vector_store_ids()

         if vector_store_ids:
+            # Assuming only one Solr vector store is registered
             vector_store_id = vector_store_ids[0]
             params = _build_query_params(solr)
@@ -174,31 +456,86 @@
                 params=params,
             )

-            logger.info("The query response total payload: %s", query_response)
+            logger.debug(
+                "OKP query returned %d chunks", len(query_response.chunks or [])
+            )

             if query_response.chunks:
-                retrieved_chunks = query_response.chunks
                 retrieved_scores = (
                     query_response.scores if hasattr(query_response, "scores") else []
                 )

-                # Extract doc_ids from chunks for referenced_documents
-                doc_ids_from_chunks = _process_chunks_for_documents(
-                    query_response.chunks, offline
+                # Limit to top N chunks
+                top_chunks = query_response.chunks[: constants.OKP_RAG_MAX_CHUNKS]
+                top_scores = retrieved_scores[: constants.OKP_RAG_MAX_CHUNKS]
+
+                # Extract referenced documents from Solr chunks
+                referenced_documents = _process_solr_chunks_for_documents(
+                    top_chunks, offline
                 )

                 # Convert retrieved chunks to RAGChunk format
-                rag_chunks = _convert_chunks_to_rag_format(
-                    retrieved_chunks, retrieved_scores, offline
+                rag_chunks = _convert_solr_chunks_to_rag_format(
+                    top_chunks, top_scores, offline
+                )
+                logger.debug(
+                    "Filtered top %d chunks from OKP RAG (%d were retrieved)",
+                    len(rag_chunks),
+                    len(query_response.chunks),
                 )
-                logger.info("Retrieved %d chunks from vector DB", len(rag_chunks))

     except Exception as e:  # pylint: disable=broad-exception-caught
-        logger.warning("Failed to query vector database for chunks: %s", e)
-        logger.debug("Vector DB query error details: %s", traceback.format_exc())
-        # Continue without RAG chunks
+        logger.warning("Failed to query OKP for chunks: %s", e)
+        logger.debug("OKP query error details: %s", traceback.format_exc())

-    return retrieved_chunks, retrieved_scores, doc_ids_from_chunks, rag_chunks
+    return rag_chunks, referenced_documents
+
+
+async def build_rag_context(
+    client: AsyncLlamaStackClient,
+    query: str,
+    vector_store_ids: Optional[list[str]],
+    solr: Optional[dict[str, Any]] = None,
+) -> RAGContext:
+    """Build RAG context by fetching and merging chunks from all enabled sources.
+
+    Enabled sources can be BYOK and/or Solr OKP.
+
+    Args:
+        client: The AsyncLlamaStackClient to use for the request
+        query: The user's query string
+        vector_store_ids: Optional request-level override of vector store IDs
+        solr: Optional Solr filter parameters
+
+    Returns:
+        RAGContext containing formatted context text and referenced documents
+    """
+    # Fetch from all enabled RAG sources in parallel
+    byok_chunks_task = _fetch_byok_rag(client, query, vector_store_ids)
+    solr_chunks_task = _fetch_solr_rag(client, query, solr)
+
+    (byok_chunks, byok_docs), (solr_chunks, solr_docs) = await asyncio.gather(
+        byok_chunks_task, solr_chunks_task
+    )
+
+    # Merge chunks from all sources (BYOK + Solr)
+    context_chunks = byok_chunks + solr_chunks
+
+    context_text = _format_rag_context(context_chunks, query)
+
+    logger.debug(
+        "Inline RAG context built: %d chunks, %d characters",
+        len(context_chunks),
+        len(context_text),
+    )
+
+    # Merge referenced documents from all sources (BYOK + Solr)
+    top_documents = byok_docs + solr_docs
+
+    return RAGContext(
+        context_text=context_text,
+        rag_chunks=context_chunks,
+        referenced_documents=top_documents,
+    )


 def _build_document_url(
@@ -233,13 +570,13 @@
     return doc_url, reference_doc


-def _convert_chunks_to_rag_format(
+def _convert_solr_chunks_to_rag_format(
     retrieved_chunks: list[Any],
     retrieved_scores: list[float],
     offline: bool,
 ) -> list[RAGChunk]:
     """
-    Convert retrieved chunks to RAGChunk format.
+    Convert retrieved chunks to RAGChunk format for Solr OKP.
     Args:
         retrieved_chunks: Raw chunks from vector store
@@ -252,15 +589,28 @@
     rag_chunks = []

     for i, chunk in enumerate(retrieved_chunks):
-        # Extract source from chunk metadata based on offline flag
-        source = None
+        # Build attributes with document metadata
+        attributes = {}
+
+        # Legacy logic: extract doc_url from chunk metadata based on offline flag
         if chunk.metadata:
             if offline:
                 parent_id = chunk.metadata.get("parent_id")
                 if parent_id:
-                    source = urljoin(constants.MIMIR_DOC_URL, parent_id)
+                    attributes["doc_url"] = urljoin(constants.MIMIR_DOC_URL, parent_id)
             else:
-                source = chunk.metadata.get("reference_url")
+                reference_url = chunk.metadata.get("reference_url")
+                if reference_url:
+                    attributes["doc_url"] = reference_url
+
+        # For Solr chunks, also extract from chunk_metadata
+        if hasattr(chunk, "chunk_metadata") and chunk.chunk_metadata:
+            if hasattr(chunk.chunk_metadata, "document_id"):
+                doc_id = chunk.chunk_metadata.document_id
+                attributes["document_id"] = doc_id
+                # Build URL if not already set
+                if "doc_url" not in attributes and offline and doc_id:
+                    attributes["doc_url"] = urljoin(constants.MIMIR_DOC_URL, doc_id)

         # Get score from retrieved_scores list if available
         score = retrieved_scores[i] if i < len(retrieved_scores) else None
@@ -268,36 +618,10 @@
         rag_chunks.append(
             RAGChunk(
                 content=chunk.content,
-                source=source,
+                source=constants.OKP_RAG_ID,
                 score=score,
+                attributes=attributes if attributes else None,
             )
         )

     return rag_chunks
-
-
-def format_rag_context_for_injection(
-    rag_chunks: list[RAGChunk], max_chunks: int = 5
-) -> str:
-    """
-    Format RAG context for injection into user message.
-
-    Args:
-        rag_chunks: List of RAG chunks to format
-        max_chunks: Maximum number of chunks to include (default: 5)
-
-    Returns:
-        Formatted RAG context string ready for injection
-    """
-    if not rag_chunks:
-        return ""
-
-    context_chunks = []
-    for chunk in rag_chunks[:max_chunks]:  # Limit to top chunks
-        chunk_text = f"Source: {chunk.source or 'Unknown'}\n{chunk.content}"
-        context_chunks.append(chunk_text)
-
-    rag_context = "\n\nRelevant documentation:\n" + "\n\n".join(context_chunks)
-    logger.info("Injecting %d RAG chunks into user message", len(context_chunks))
-
-    return rag_context
diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py
index 1599c78f2..044fb5bf2 100644
--- a/tests/unit/app/endpoints/test_query.py
+++ b/tests/unit/app/endpoints/test_query.py
@@ -712,4 +712,4 @@ async def test_retrieve_response_with_tool_calls(
     assert result.token_usage.output_tokens == 5
     assert result.rag_chunks == []
     assert result.referenced_documents == []
-    assert result.pre_rag_documents == []
+    assert result.inline_rag_documents == []
diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py
index eebed248f..6f10467ee 100644
--- a/tests/unit/app/endpoints/test_streaming_query.py
+++ b/tests/unit/app/endpoints/test_streaming_query.py
@@ -52,7 +52,7 @@
 from models.responses import InternalServerErrorResponse
 from utils.token_counter import TokenCounter
 from utils.stream_interrupts import StreamInterruptRegistry
-from utils.types import ReferencedDocument, ResponsesApiParams, TurnSummary
+from utils.types import RAGContext, ReferencedDocument, ResponsesApiParams, TurnSummary

 MOCK_AUTH_STREAMING = (
     "00000001-0001-0001-0001-000000000001",
@@ -330,12 +330,8 @@ async def test_successful_streaming_query(
     mocker.patch("app.endpoints.streaming_query.check_tokens_available")
     mocker.patch("app.endpoints.streaming_query.validate_model_provider_override")
     mocker.patch(
-        "app.endpoints.streaming_query.perform_vector_search",
-        new=mocker.AsyncMock(return_value=([], [], [], [])),
-    )
-    mocker.patch(
-        "app.endpoints.streaming_query.perform_vector_search",
-        new=mocker.AsyncMock(return_value=([], [], [], [])),
+        "app.endpoints.streaming_query.build_rag_context",
+        new=mocker.AsyncMock(return_value=RAGContext()),
     )

     mock_client = mocker.AsyncMock(spec=AsyncLlamaStackClient)
@@ -417,8 +413,8 @@ async def test_streaming_query_text_media_type_header(
     mocker.patch("app.endpoints.streaming_query.check_tokens_available")
     mocker.patch("app.endpoints.streaming_query.validate_model_provider_override")
     mocker.patch(
-        "app.endpoints.streaming_query.perform_vector_search",
-        new=mocker.AsyncMock(return_value=([], [], [], [])),
+        "app.endpoints.streaming_query.build_rag_context",
+        new=mocker.AsyncMock(return_value=RAGContext()),
     )

     mock_client = mocker.AsyncMock(spec=AsyncLlamaStackClient)
@@ -503,8 +499,8 @@ async def test_streaming_query_with_conversation(
     mocker.patch("app.endpoints.streaming_query.check_tokens_available")
     mocker.patch("app.endpoints.streaming_query.validate_model_provider_override")
     mocker.patch(
-        "app.endpoints.streaming_query.perform_vector_search",
-        new=mocker.AsyncMock(return_value=([], [], [], [])),
+        "app.endpoints.streaming_query.build_rag_context",
+        new=mocker.AsyncMock(return_value=RAGContext()),
     )
     mocker.patch(
         "app.endpoints.streaming_query.normalize_conversation_id",
@@ -600,8 +596,8 @@ async def test_streaming_query_with_attachments(
     mocker.patch("app.endpoints.streaming_query.check_tokens_available")
     mocker.patch("app.endpoints.streaming_query.validate_model_provider_override")
     mocker.patch(
-        "app.endpoints.streaming_query.perform_vector_search",
-        new=mocker.AsyncMock(return_value=([], [], [], [])),
+        "app.endpoints.streaming_query.build_rag_context",
+        new=mocker.AsyncMock(return_value=RAGContext()),
    )
     mock_validate = mocker.patch(
         "app.endpoints.streaming_query.validate_attachments_metadata"
     )
@@ -685,8 +681,8 @@ async def test_streaming_query_azure_token_refresh(
     mocker.patch("app.endpoints.streaming_query.check_tokens_available")
     mocker.patch("app.endpoints.streaming_query.validate_model_provider_override")
     mocker.patch(
-        "app.endpoints.streaming_query.perform_vector_search",
-        new=mocker.AsyncMock(return_value=([], [], [], [])),
+        "app.endpoints.streaming_query.build_rag_context",
+        new=mocker.AsyncMock(return_value=RAGContext()),
     )

     mock_client = mocker.AsyncMock(spec=AsyncLlamaStackClient)
diff --git a/tests/unit/models/config/test_byok_rag.py b/tests/unit/models/config/test_byok_rag.py
index e0cb7a8fb..8cf71ede5 100644
--- a/tests/unit/models/config/test_byok_rag.py
+++ b/tests/unit/models/config/test_byok_rag.py
@@ -3,16 +3,15 @@
 from pathlib import Path

 import pytest
-
 from pydantic import ValidationError

-from models.config import ByokRag
-
 from constants import (
-    DEFAULT_RAG_TYPE,
-    DEFAULT_EMBEDDING_MODEL,
     DEFAULT_EMBEDDING_DIMENSION,
+    DEFAULT_EMBEDDING_MODEL,
+    DEFAULT_RAG_TYPE,
+    DEFAULT_SCORE_MULTIPLIER,
 )
+from models.config import ByokRag


 def test_byok_rag_configuration_default_values() -> None:
@@ -30,6 +29,7 @@ def test_byok_rag_configuration_default_values() -> None:
     assert byok_rag.embedding_dimension == DEFAULT_EMBEDDING_DIMENSION
     assert byok_rag.vector_db_id == "vector_db_id"
     assert byok_rag.db_path == "tests/configuration/rag.txt"
+    assert byok_rag.score_multiplier == DEFAULT_SCORE_MULTIPLIER


 def test_byok_rag_configuration_nondefault_values() -> None:
@@ -142,3 +142,27 @@ def test_byok_rag_configuration_empty_vector_db_id() -> None:
         vector_db_id="",
         db_path=Path("tests/configuration/rag.txt"),
     )
+
+
+def test_byok_rag_configuration_custom_score_multiplier() -> None:
+    """Test ByokRag with custom score_multiplier."""
+
+    byok_rag = ByokRag(
+        rag_id="rag_id",
+        vector_db_id="vector_db_id",
+        db_path="tests/configuration/rag.txt",
+        score_multiplier=2.5,
+    )
+    assert byok_rag.score_multiplier == 2.5
+
+
+def test_byok_rag_configuration_score_multiplier_must_be_positive() -> None:
+    """Test that score_multiplier must be greater than 0."""
+
+    with pytest.raises(ValidationError, match="greater than 0"):
+        _ = ByokRag(
+            rag_id="rag_id",
+            vector_db_id="vector_db_id",
+            db_path="tests/configuration/rag.txt",
+            score_multiplier=0.0,
+        )
diff --git a/tests/unit/models/config/test_dump_configuration.py b/tests/unit/models/config/test_dump_configuration.py
index 0867db6b1..29df175cb 100644
--- a/tests/unit/models/config/test_dump_configuration.py
+++ b/tests/unit/models/config/test_dump_configuration.py
@@ -206,7 +206,14 @@ def test_dump_configuration(tmp_path: Path) -> None:
             "postgres": None,
         },
         "azure_entra_id": None,
-        "solr": None,
+        "rag": {
+            "inline": [],
+            "tool": None,
+        },
+        "okp": {
+            "offline": True,
+            "chunk_filter_query": "is_chunk:true",
+        },
         "splunk": None,
         "deployment_environment": "development",
     }
@@ -550,7 +557,14 @@ def test_dump_configuration_with_quota_limiters(tmp_path: Path) -> None:
             "postgres": None,
         },
         "azure_entra_id": None,
-        "solr": None,
+        "rag": {
+            "inline": [],
+            "tool": None,
+        },
+        "okp": {
+            "offline": True,
+            "chunk_filter_query": "is_chunk:true",
+        },
         "splunk": None,
         "deployment_environment": "development",
     }
@@ -772,7 +786,14 @@ def test_dump_configuration_with_quota_limiters_different_values(
             "postgres": None,
         },
         "azure_entra_id": None,
-        "solr": None,
+        "rag": {
+            "inline": [],
+            "tool": None,
+        },
+        "okp": {
+            "offline": True,
+            "chunk_filter_query": "is_chunk:true",
+        },
         "splunk": None,
         "deployment_environment": "development",
     }
@@ -950,6 +971,7 @@ def test_dump_configuration_byok(tmp_path: Path) -> None:
                 "rag_id": "rag_id",
                 "rag_type": "inline::faiss",
                 "vector_db_id": "vector_db_id",
+                "score_multiplier": 1.0,
             },
         ],
         "quota_handlers": {
@@ -968,7 +990,14 @@ def test_dump_configuration_byok(tmp_path: Path) -> None:
             "postgres": None,
         },
         "azure_entra_id": None,
-        "solr": None,
+        "rag": {
+            "inline": [],
+            "tool": None,
+        },
+        "okp": {
+            "offline": True,
+            "chunk_filter_query": "is_chunk:true",
+        },
         "splunk": None,
         "deployment_environment": "development",
     }
@@ -1150,7 +1179,14 @@ def test_dump_configuration_pg_namespace(tmp_path: Path) -> None:
             "postgres": None,
         },
         "azure_entra_id": None,
-        "solr": None,
+        "rag": {
+            "inline": [],
+            "tool": None,
+        },
+        "okp": {
+            "offline": True,
+            "chunk_filter_query": "is_chunk:true",
+        },
         "splunk": None,
         "deployment_environment": "development",
     }
diff --git a/tests/unit/models/config/test_rag_configuration.py b/tests/unit/models/config/test_rag_configuration.py
new file mode 100644
index 000000000..f13539189
--- /dev/null
+++ b/tests/unit/models/config/test_rag_configuration.py
@@ -0,0 +1,93 @@
+"""Unit tests for RAG and OKP configuration models."""
+
+# pylint: disable=no-member
+# Pydantic Field(default_factory=...) pattern confuses pylint's static analysis
+
+import pytest
+from pydantic import ValidationError
+
+import constants
+from models.config import OkpConfiguration, RagConfiguration
+
+
+class TestRagConfiguration:
+    """Tests for RagConfiguration model."""
+
+    def test_default_values(self) -> None:
+        """Test that RagConfiguration has correct default values."""
+        config = RagConfiguration()
+        assert config.inline == []
+        assert config.tool is None
+
+    def test_inline_with_byok_ids(self) -> None:
+        """Test inline list with BYOK rag IDs."""
+        config = RagConfiguration(inline=["store-1", "store-2"])
+        assert config.inline == ["store-1", "store-2"]
+        assert config.tool is None
+
+    def test_inline_with_okp_rag(self) -> None:
+        """Test inline list including the special OKP ID."""
+        config = RagConfiguration(inline=[constants.OKP_RAG_ID, "store-1"])
+        assert constants.OKP_RAG_ID in config.inline
+        assert "store-1" in config.inline
+
+    def test_tool_with_okp_rag_and_byok(self) -> None:
+        """Test tool list with OKP and BYOK IDs."""
+        config = RagConfiguration(
+            inline=["store-1"],
+            tool=[constants.OKP_RAG_ID, "store-1"],
+        )
+        assert config.inline == ["store-1"]
+        assert config.tool == [constants.OKP_RAG_ID, "store-1"]
+
+    def test_tool_empty_list(self) -> None:
+        """Test that an explicit empty tool list disables tool RAG."""
+        config = RagConfiguration(tool=[])
+        assert config.tool == []
+
+    def test_tool_none_means_all_stores(self) -> None:
+        """Test that tool=None (default) means all registered stores are used."""
+        config = RagConfiguration()
+        assert config.tool is None
+
+    def test_no_unknown_fields_allowed(self) -> None:
+        """Test that RagConfiguration rejects unknown fields."""
+        with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
+            RagConfiguration(unknown_field="value")  # type: ignore[call-arg]
+
+    def test_fully_custom_config(self) -> None:
+        """Test RagConfiguration with all fields set."""
+        config = RagConfiguration(
+            inline=[constants.OKP_RAG_ID, "store-1"],
+            tool=["store-1"],
+        )
+        assert constants.OKP_RAG_ID in config.inline
+        assert "store-1" in config.inline
+        assert config.tool == ["store-1"]
+
+
+class TestOkpConfiguration:
+    """Tests for OkpConfiguration model."""
+
+    def test_default_values(self) -> None:
+        """Test that OkpConfiguration has correct default values."""
+        config = OkpConfiguration()
+        assert config.offline is True
+        assert config.chunk_filter_query == "is_chunk:true"
+
+    def test_offline_false(self) -> None:
+        """Test offline can be set to False (online mode)."""
+        config = OkpConfiguration(offline=False)
+        assert config.offline is False
+
+    def test_custom_chunk_filter_query(self) -> None:
+        """Test that chunk_filter_query can be customised."""
+        config = OkpConfiguration(
+            chunk_filter_query="is_chunk:true AND product:*openshift*"
+        )
+        assert config.chunk_filter_query == "is_chunk:true AND product:*openshift*"
+
+    def test_no_unknown_fields_allowed(self) -> None:
+        """Test that OkpConfiguration rejects unknown fields."""
+        with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
+            OkpConfiguration(unknown_field="value")  #
type: ignore[call-arg] diff --git a/tests/unit/models/responses/test_rag_chunk.py b/tests/unit/models/responses/test_rag_chunk.py index d9809eade..621e9a4ac 100644 --- a/tests/unit/models/responses/test_rag_chunk.py +++ b/tests/unit/models/responses/test_rag_chunk.py @@ -1,6 +1,7 @@ -"""Unit tests for RAGChunk model.""" +"""Unit tests for RAGChunk and RAGContext models.""" -from utils.types import RAGChunk +from utils.types import RAGChunk, RAGContext +from models.responses import ReferencedDocument class TestRAGChunk: @@ -110,3 +111,84 @@ def test_url_as_source(self) -> None: ) assert chunk.source == url_source assert chunk.score == 0.92 + + def test_attributes_field(self) -> None: + """Test RAGChunk with attributes field.""" + attributes = { + "doc_url": "https://example.com/doc", + "title": "Example Document", + "author": "John Doe", + } + chunk = RAGChunk( + content="Test content", source="test-source", attributes=attributes + ) + assert chunk.attributes == attributes + assert chunk.attributes["doc_url"] == "https://example.com/doc" + + def test_attributes_none(self) -> None: + """Test RAGChunk with attributes=None.""" + chunk = RAGChunk(content="Test content", attributes=None) + assert chunk.attributes is None + + +class TestRAGContext: + """Test cases for the RAGContext model.""" + + def test_default_values(self) -> None: + """Test RAGContext with default values.""" + context = RAGContext() + assert context.context_text == "" + assert context.rag_chunks == [] + assert context.referenced_documents == [] + + def test_with_context_text(self) -> None: + """Test RAGContext with context text.""" + context = RAGContext(context_text="Test context") + assert context.context_text == "Test context" + assert context.rag_chunks == [] + assert context.referenced_documents == [] + + def test_with_rag_chunks(self) -> None: + """Test RAGContext with RAG chunks.""" + chunks = [ + RAGChunk(content="Chunk 1", source="source1", score=0.9), + RAGChunk(content="Chunk 2", 
source="source2", score=0.8), + ] + context = RAGContext(rag_chunks=chunks) + assert len(context.rag_chunks) == 2 + assert context.rag_chunks[0].content == "Chunk 1" + assert context.rag_chunks[1].content == "Chunk 2" + + def test_with_referenced_documents(self) -> None: + """Test RAGContext with referenced documents.""" + docs = [ + ReferencedDocument( + doc_title="Doc 1", + doc_url="https://example.com/doc1", + source="source1", + ), + ReferencedDocument( + doc_title="Doc 2", + doc_url="https://example.com/doc2", + source="source2", + ), + ] + context = RAGContext(referenced_documents=docs) + assert len(context.referenced_documents) == 2 + assert context.referenced_documents[0].doc_title == "Doc 1" + assert context.referenced_documents[1].doc_title == "Doc 2" + + def test_fully_populated(self) -> None: + """Test RAGContext with all fields populated.""" + chunks = [RAGChunk(content="Test chunk", source="source1", score=0.95)] + docs = [ + ReferencedDocument(doc_title="Test Doc", doc_url="https://example.com/doc") + ] + context = RAGContext( + context_text="Formatted context", + rag_chunks=chunks, + referenced_documents=docs, + ) + assert context.context_text == "Formatted context" + assert len(context.rag_chunks) == 1 + assert len(context.referenced_documents) == 1 diff --git a/tests/unit/test_configuration.py b/tests/unit/test_configuration.py index b99f68e71..1cd11df86 100644 --- a/tests/unit/test_configuration.py +++ b/tests/unit/test_configuration.py @@ -1,5 +1,7 @@ """Unit tests for functions defined in src/configuration.py.""" +# pylint: disable=too-many-lines + from pathlib import Path from typing import Any, Generator from pydantic import ValidationError @@ -994,3 +996,81 @@ def test_rag_id_mapping_not_loaded() -> None: cfg._configuration = None with pytest.raises(LogicError): _ = cfg.rag_id_mapping + + +def test_score_multiplier_mapping_empty_when_no_byok(minimal_config: AppConfig) -> None: + """Test that score_multiplier_mapping returns empty dict when 
no BYOK RAG configured.""" + assert minimal_config.score_multiplier_mapping == {} + + +def test_score_multiplier_mapping_with_byok_defaults(tmp_path: Path) -> None: + """Test that score_multiplier_mapping uses default multiplier when not specified.""" + db_file = tmp_path / "test.db" + db_file.touch() + cfg = AppConfig() + cfg.init_from_dict( + { + "name": "test", + "service": {"host": "localhost", "port": 8080}, + "llama_stack": { + "api_key": "k", + "url": "http://test.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": {}, + "authentication": {"module": "noop"}, + "byok_rag": [ + { + "rag_id": "my-kb", + "vector_db_id": "vs-001", + "db_path": str(db_file), + }, + ], + } + ) + assert cfg.score_multiplier_mapping == {"vs-001": 1.0} + + +def test_score_multiplier_mapping_with_custom_values(tmp_path: Path) -> None: + """Test that score_multiplier_mapping builds correct mapping with custom values.""" + db_file1 = tmp_path / "test1.db" + db_file1.touch() + db_file2 = tmp_path / "test2.db" + db_file2.touch() + cfg = AppConfig() + cfg.init_from_dict( + { + "name": "test", + "service": {"host": "localhost", "port": 8080}, + "llama_stack": { + "api_key": "k", + "url": "http://test.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": {}, + "authentication": {"module": "noop"}, + "byok_rag": [ + { + "rag_id": "kb1", + "vector_db_id": "vs-001", + "db_path": str(db_file1), + "score_multiplier": 1.5, + }, + { + "rag_id": "kb2", + "vector_db_id": "vs-002", + "db_path": str(db_file2), + "score_multiplier": 0.75, + }, + ], + } + ) + assert cfg.score_multiplier_mapping == {"vs-001": 1.5, "vs-002": 0.75} + + +def test_score_multiplier_mapping_not_loaded() -> None: + """Test that score_multiplier_mapping raises when config not loaded.""" + cfg = AppConfig() + cfg._configuration = None + with pytest.raises(LogicError): + _ = cfg.score_multiplier_mapping diff --git a/tests/unit/test_llama_stack_configuration.py 
b/tests/unit/test_llama_stack_configuration.py index aad0c1d72..8d17a96c7 100644 --- a/tests/unit/test_llama_stack_configuration.py +++ b/tests/unit/test_llama_stack_configuration.py @@ -12,6 +12,7 @@ construct_vector_io_providers_section, construct_storage_backends_section, construct_models_section, + enrich_solr, ) from models.config import ( Configuration, @@ -63,7 +64,7 @@ def test_construct_vector_stores_section_adds_new() -> None: output = construct_vector_stores_section(ls_config, byok_rag) assert len(output) == 1 assert output[0]["vector_store_id"] == "store1" - assert output[0]["provider_id"] == "byok_store1" + assert output[0]["provider_id"] == "byok_rag1" assert output[0]["embedding_model"] == "test-model" assert output[0]["embedding_dimension"] == 512 @@ -73,7 +74,7 @@ def test_construct_vector_stores_section_merge() -> None: ls_config = { "registered_resources": {"vector_stores": [{"vector_store_id": "existing"}]} } - byok_rag = [{"vector_db_id": "new_store"}] + byok_rag = [{"rag_id": "rag1", "vector_db_id": "new_store"}] output = construct_vector_stores_section(ls_config, byok_rag) assert len(output) == 2 @@ -89,6 +90,7 @@ def test_construct_vector_stores_section_skips_duplicate_from_existing() -> None } byok_rag = [ { + "rag_id": "rag1", "vector_db_id": "store1", "embedding_model": "test-model", "embedding_dimension": 512, @@ -104,11 +106,13 @@ def test_construct_vector_stores_section_skips_duplicate_within_byok() -> None: ls_config: dict[str, Any] = {} byok_rag = [ { + "rag_id": "rag1", "vector_db_id": "store1", "embedding_model": "model-a", "embedding_dimension": 512, }, { + "rag_id": "rag2", "vector_db_id": "store1", "embedding_model": "model-b", "embedding_dimension": 768, @@ -142,19 +146,20 @@ def test_construct_vector_io_providers_section_preserves_existing() -> None: def test_construct_vector_io_providers_section_adds_new() -> None: - """Test adds new BYOK RAG entries.""" + """Test adds new BYOK RAG entries using rag_id for provider naming.""" 
ls_config: dict[str, Any] = {"providers": {}} byok_rag = [ { + "rag_id": "rag1", "vector_db_id": "store1", "rag_type": "inline::faiss", }, ] output = construct_vector_io_providers_section(ls_config, byok_rag) assert len(output) == 1 - assert output[0]["provider_id"] == "byok_store1" + assert output[0]["provider_id"] == "byok_rag1" assert output[0]["provider_type"] == "inline::faiss" - assert output[0]["config"]["persistence"]["backend"] == "byok_store1_storage" + assert output[0]["config"]["persistence"]["backend"] == "byok_rag1_storage" assert output[0]["config"]["persistence"]["namespace"] == "vector_io::faiss" @@ -187,19 +192,20 @@ def test_construct_storage_backends_section_preserves_existing() -> None: def test_construct_storage_backends_section_adds_new() -> None: - """Test adds new BYOK RAG backend entries.""" + """Test adds new BYOK RAG backend entries using rag_id for backend naming.""" ls_config: dict[str, Any] = {} byok_rag = [ { + "rag_id": "rag1", "vector_db_id": "store1", "db_path": "/path/to/store1.db", }, ] output = construct_storage_backends_section(ls_config, byok_rag) assert len(output) == 1 - assert "byok_store1_storage" in output - assert output["byok_store1_storage"]["type"] == "kv_sqlite" - assert output["byok_store1_storage"]["db_path"] == "/path/to/store1.db" + assert "byok_rag1_storage" in output + assert output["byok_rag1_storage"]["type"] == "kv_sqlite" + assert output["byok_rag1_storage"]["db_path"] == "/path/to/store1.db" # ============================================================================= @@ -229,10 +235,11 @@ def test_construct_models_section_preserves_existing() -> None: def test_construct_models_section_adds_embedding_model() -> None: - """Test adds embedding model from BYOK RAG.""" + """Test adds embedding model from BYOK RAG using rag_id for model naming.""" ls_config: dict[str, Any] = {} byok_rag = [ { + "rag_id": "rag1", "vector_db_id": "store1", "embedding_model": "sentence-transformers/all-mpnet-base-v2", 
"embedding_dimension": 768, @@ -240,7 +247,7 @@ def test_construct_models_section_adds_embedding_model() -> None: ] output = construct_models_section(ls_config, byok_rag) assert len(output) == 1 - assert output[0]["model_id"] == "byok_store1_embedding" + assert output[0]["model_id"] == "byok_rag1_embedding" assert output[0]["model_type"] == "embedding" assert output[0]["provider_id"] == "sentence-transformers" assert output[0]["provider_model_id"] == "all-mpnet-base-v2" @@ -252,6 +259,7 @@ def test_construct_models_section_strips_prefix() -> None: ls_config: dict[str, Any] = {} byok_rag = [ { + "rag_id": "rag1", "vector_db_id": "store1", "embedding_model": "sentence-transformers//usr/path/model", "embedding_dimension": 768, @@ -262,6 +270,46 @@ def test_construct_models_section_strips_prefix() -> None: assert output[0]["provider_model_id"] == "/usr/path/model" +def test_construct_storage_backends_section_raises_on_missing_rag_id() -> None: + """Test raises ValueError when rag_id is missing from a BYOK RAG entry.""" + ls_config: dict[str, Any] = {} + byok_rag = [{"vector_db_id": "store1"}] + with pytest.raises(ValueError, match="missing required 'rag_id'"): + construct_storage_backends_section(ls_config, byok_rag) + + +def test_construct_vector_stores_section_raises_on_missing_rag_id() -> None: + """Test raises ValueError when rag_id is missing from a BYOK RAG entry.""" + ls_config: dict[str, Any] = {} + byok_rag = [{"vector_db_id": "store1"}] + with pytest.raises(ValueError, match="missing required 'rag_id'"): + construct_vector_stores_section(ls_config, byok_rag) + + +def test_construct_vector_stores_section_raises_on_missing_vector_db_id() -> None: + """Test raises ValueError when vector_db_id is missing from a BYOK RAG entry.""" + ls_config: dict[str, Any] = {} + byok_rag = [{"rag_id": "rag1"}] + with pytest.raises(ValueError, match="missing required 'vector_db_id'"): + construct_vector_stores_section(ls_config, byok_rag) + + +def 
test_construct_vector_io_section_raises_on_missing_rag_id() -> None: + """Test raises ValueError when rag_id is missing from a BYOK RAG entry.""" + ls_config: dict[str, Any] = {} + byok_rag = [{"vector_db_id": "store1"}] + with pytest.raises(ValueError, match="missing required 'rag_id'"): + construct_vector_io_providers_section(ls_config, byok_rag) + + +def test_construct_models_section_raises_on_missing_rag_id() -> None: + """Test raises ValueError when rag_id is missing from a BYOK RAG entry.""" + ls_config: dict[str, Any] = {} + byok_rag = [{"vector_db_id": "store1", "embedding_model": "some-model"}] + with pytest.raises(ValueError, match="missing required 'rag_id'"): + construct_models_section(ls_config, byok_rag) + + # ============================================================================= # Test generate_configuration # ============================================================================= @@ -338,13 +386,110 @@ def test_generate_configuration_with_byok(tmp_path: Path) -> None: ] assert "store1" in store_ids - # Check storage.backends - assert "byok_store1_storage" in result["storage"]["backends"] + # Check storage.backends - named after rag_id + assert "byok_rag1_storage" in result["storage"]["backends"] - # Check providers.vector_io + # Check providers.vector_io - named after rag_id provider_ids = [p["provider_id"] for p in result["providers"]["vector_io"]] - assert "byok_store1" in provider_ids + assert "byok_rag1" in provider_ids - # Check registered_resources.models for embedding model + # Check registered_resources.models for embedding model - named after rag_id model_ids = [m["model_id"] for m in result["registered_resources"]["models"]] - assert "byok_store1_embedding" in model_ids + assert "byok_rag1_embedding" in model_ids + + +# ============================================================================= +# Test enrich_solr +# ============================================================================= + + +def 
test_enrich_solr_skips_when_not_enabled() -> None: + """Test enrich_solr does nothing when Solr is not enabled.""" + ls_config: dict[str, Any] = {} + enrich_solr(ls_config, {"enabled": False}) + assert not ls_config + + +def test_enrich_solr_skips_when_empty_config() -> None: + """Test enrich_solr does nothing with empty config.""" + ls_config: dict[str, Any] = {} + enrich_solr(ls_config, {}) + assert not ls_config + + +def test_enrich_solr_adds_vector_io_provider() -> None: + """Test enrich_solr adds Solr provider to vector_io section.""" + ls_config: dict[str, Any] = {} + enrich_solr(ls_config, {"enabled": True}) + + assert "providers" in ls_config + assert "vector_io" in ls_config["providers"] + provider_ids = [p["provider_id"] for p in ls_config["providers"]["vector_io"]] + assert "okp_solr" in provider_ids + + +def test_enrich_solr_adds_vector_store_registration() -> None: + """Test enrich_solr registers the Solr vector store.""" + ls_config: dict[str, Any] = {} + enrich_solr(ls_config, {"enabled": True}) + + assert "registered_resources" in ls_config + store_ids = [ + s["vector_store_id"] for s in ls_config["registered_resources"]["vector_stores"] + ] + assert "portal-rag" in store_ids + + +def test_enrich_solr_adds_embedding_model() -> None: + """Test enrich_solr registers the Solr embedding model.""" + ls_config: dict[str, Any] = {} + enrich_solr(ls_config, {"enabled": True}) + + model_ids = [m["model_id"] for m in ls_config["registered_resources"]["models"]] + assert "solr_embedding" in model_ids + + +def test_enrich_solr_skips_duplicate_provider() -> None: + """Test enrich_solr does not add duplicate Solr provider.""" + ls_config: dict[str, Any] = { + "providers": {"vector_io": [{"provider_id": "okp_solr"}]} + } + enrich_solr(ls_config, {"enabled": True}) + + provider_ids = [p["provider_id"] for p in ls_config["providers"]["vector_io"]] + assert provider_ids.count("okp_solr") == 1 + + +def test_enrich_solr_skips_duplicate_vector_store() -> None: + """Test 
enrich_solr does not add duplicate vector store registration.""" + ls_config: dict[str, Any] = { + "registered_resources": {"vector_stores": [{"vector_store_id": "portal-rag"}]} + } + enrich_solr(ls_config, {"enabled": True}) + + store_ids = [ + s["vector_store_id"] for s in ls_config["registered_resources"]["vector_stores"] + ] + assert store_ids.count("portal-rag") == 1 + + +def test_enrich_solr_preserves_existing_config() -> None: + """Test enrich_solr preserves existing providers and resources.""" + ls_config: dict[str, Any] = { + "providers": {"vector_io": [{"provider_id": "existing_provider"}]}, + "registered_resources": { + "vector_stores": [{"vector_store_id": "existing_store"}], + "models": [{"model_id": "existing_model"}], + }, + } + enrich_solr(ls_config, {"enabled": True}) + + provider_ids = [p["provider_id"] for p in ls_config["providers"]["vector_io"]] + assert "existing_provider" in provider_ids + assert "okp_solr" in provider_ids + + store_ids = [ + s["vector_store_id"] for s in ls_config["registered_resources"]["vector_stores"] + ] + assert "existing_store" in store_ids + assert "portal-rag" in store_ids diff --git a/tests/unit/utils/test_responses.py b/tests/unit/utils/test_responses.py index 9f0a1d598..7d84f515c 100644 --- a/tests/unit/utils/test_responses.py +++ b/tests/unit/utils/test_responses.py @@ -11,21 +11,25 @@ from llama_stack_api.openai_responses import ( OpenAIResponseInputToolFileSearch as InputToolFileSearch, OpenAIResponseInputToolMCP as InputToolMCP, + OpenAIResponseMCPApprovalRequest as MCPApprovalRequest, + OpenAIResponseMCPApprovalResponse as MCPApprovalResponse, OpenAIResponseOutputMessageFileSearchToolCall as FileSearchCall, OpenAIResponseOutputMessageFunctionToolCall as FunctionCall, OpenAIResponseOutputMessageMCPCall as MCPCall, OpenAIResponseOutputMessageMCPListTools as MCPListTools, - OpenAIResponseMCPApprovalRequest as MCPApprovalRequest, - OpenAIResponseMCPApprovalResponse as MCPApprovalResponse, 
OpenAIResponseOutputMessageWebSearchToolCall as WebSearchCall, ) from llama_stack_client import APIConnectionError, APIStatusError, AsyncLlamaStackClient from pydantic import AnyUrl from pytest_mock import MockerFixture +import constants from models.config import ByokRag, ModelContextProtocolServer from models.requests import QueryRequest from utils.responses import ( + _build_chunk_attributes, + _increment_llm_call_metric, + _resolve_source_for_result, build_mcp_tool_call_from_arguments_done, build_tool_call_summary, build_tool_result_from_mcp_output_item_done, @@ -37,14 +41,12 @@ get_mcp_tools, get_rag_tools, get_topic_summary, + get_vector_store_ids, parse_arguments_string, parse_referenced_documents, prepare_responses_params, prepare_tools, resolve_vector_store_ids, - _build_chunk_attributes, - _increment_llm_call_metric, - _resolve_source_for_result, ) from utils.types import RAGChunk @@ -334,8 +336,8 @@ class TestGetRAGTools: """Test cases for get_rag_tools utility function.""" def test_get_rag_tools_empty_list(self) -> None: - """Test get_rag_tools returns None for empty list.""" - assert get_rag_tools([]) is None + """Test get_rag_tools returns empty list for empty vector store IDs.""" + assert not get_rag_tools([]) def test_get_rag_tools_with_vector_stores(self) -> None: """Test get_rag_tools returns correct tool format for vector stores.""" @@ -1119,6 +1121,7 @@ async def test_does_not_translate_when_ids_fetched_from_llama_stack( mock_byok_rag.vector_db_id = "vs-translated" mock_config = mocker.Mock() mock_config.configuration.byok_rag = [mock_byok_rag] + mock_config.configuration.rag.tool = None mocker.patch("utils.responses.configuration", mock_config) result = await prepare_tools(mock_client, None, False, "token") @@ -2354,3 +2357,80 @@ def test_multiple_stores_source_is_none(self, mocker: MockerFixture) -> None: assert len(docs) == 1 assert docs[0].source is None + + +class TestGetVectorStoreIds: + """Tests for get_vector_store_ids utility 
function.""" + + @pytest.mark.asyncio + async def test_returns_provided_ids_directly(self, mocker: MockerFixture) -> None: + """Test that provided vector_store_ids are returned without fetching.""" + client_mock = mocker.AsyncMock() + result = await get_vector_store_ids(client_mock, ["vs1", "vs2"]) + assert result == ["vs1", "vs2"] + client_mock.vector_stores.list.assert_not_called() + + @pytest.mark.asyncio + async def test_fetches_all_when_no_ids_provided( + self, mocker: MockerFixture + ) -> None: + """Test that all vector stores are fetched when no IDs provided.""" + mock_store1 = mocker.Mock() + mock_store1.id = "vs-fetched-1" + mock_store2 = mocker.Mock() + mock_store2.id = "vs-fetched-2" + + mock_list_result = mocker.Mock() + mock_list_result.data = [mock_store1, mock_store2] + + client_mock = mocker.AsyncMock() + client_mock.vector_stores.list.return_value = mock_list_result + + result = await get_vector_store_ids(client_mock, None) + assert result == ["vs-fetched-1", "vs-fetched-2"] + client_mock.vector_stores.list.assert_called_once() + + @pytest.mark.asyncio + async def test_raises_on_connection_error(self, mocker: MockerFixture) -> None: + """Test that APIConnectionError raises HTTPException 503.""" + client_mock = mocker.AsyncMock() + client_mock.vector_stores.list.side_effect = APIConnectionError.__new__( + APIConnectionError + ) + + with pytest.raises(HTTPException) as exc_info: + await get_vector_store_ids(client_mock, None) + assert exc_info.value.status_code == 503 + + @pytest.mark.asyncio + async def test_raises_on_api_status_error(self, mocker: MockerFixture) -> None: + """Test that APIStatusError raises HTTPException 500.""" + mock_response = mocker.Mock() + mock_response.status_code = 500 + mock_response.headers = {} + mock_response.text = "error" + + client_mock = mocker.AsyncMock() + client_mock.vector_stores.list.side_effect = APIStatusError( + "error", response=mock_response, body=None + ) + + with pytest.raises(HTTPException) as exc_info: 
+ await get_vector_store_ids(client_mock, None) + assert exc_info.value.status_code == 500 + + +class TestGetRAGToolsWithConfig: + """Tests for get_rag_tools with configuration checks.""" + + def test_returns_empty_when_no_vector_store_ids(self) -> None: + """Test get_rag_tools returns empty list when no vector store IDs are provided.""" + # pylint: disable-next=use-implicit-booleaness-not-comparison + assert get_rag_tools([]) == [] + + def test_returns_tools_when_stores_provided(self) -> None: + """Test get_rag_tools returns tools when vector store IDs are provided.""" + tools = get_rag_tools(["vs1"]) + assert tools is not None + assert tools[0].type == constants.DEFAULT_RAG_TOOL + assert tools[0].vector_store_ids == ["vs1"] diff --git a/tests/unit/utils/test_vector_search.py b/tests/unit/utils/test_vector_search.py new file mode 100644 index 000000000..4930cb846 --- /dev/null +++ b/tests/unit/utils/test_vector_search.py @@ -0,0 +1,504 @@ +"""Unit tests for vector search utilities.""" + +import pytest + +import constants +from configuration import AppConfig +from utils.types import RAGChunk +from utils.vector_search import ( + _build_document_url, + _build_query_params, + _convert_solr_chunks_to_rag_format, + _extract_byok_rag_chunks, + _extract_solr_document_metadata, + _fetch_byok_rag, + _fetch_solr_rag, + _format_rag_context, + _get_solr_vector_store_ids, + _is_solr_enabled, + build_rag_context, +) + + +class TestIsSolrEnabled: + """Tests for _is_solr_enabled function.""" + + def test_solr_enabled_true(self, mocker) -> None: # type: ignore[no-untyped-def] + """Test when Solr is enabled in configuration.""" + config_mock = mocker.Mock(spec=AppConfig) + config_mock.inline_solr_enabled = True + mocker.patch("utils.vector_search.configuration", config_mock) + assert _is_solr_enabled() is True + + def test_solr_enabled_false(self, mocker) -> None: # type: ignore[no-untyped-def] + """Test when Solr is disabled in configuration.""" + config_mock = 
mocker.Mock(spec=AppConfig) + config_mock.inline_solr_enabled = False + mocker.patch("utils.vector_search.configuration", config_mock) + assert _is_solr_enabled() is False + + +class TestGetSolrVectorStoreIds: # pylint: disable=too-few-public-methods + """Tests for _get_solr_vector_store_ids function.""" + + def test_returns_default_vector_store_id(self) -> None: + """Test that function returns the default Solr vector store ID.""" + result = _get_solr_vector_store_ids() + assert result == [constants.SOLR_DEFAULT_VECTOR_STORE_ID] + assert len(result) == 1 + + +class TestBuildQueryParams: + """Tests for _build_query_params function.""" + + def test_default_params(self) -> None: + """Test default parameters when no solr filters provided.""" + params = _build_query_params() + + assert params["k"] == constants.SOLR_VECTOR_SEARCH_DEFAULT_K + assert ( + params["score_threshold"] + == constants.SOLR_VECTOR_SEARCH_DEFAULT_SCORE_THRESHOLD + ) + assert params["mode"] == constants.SOLR_VECTOR_SEARCH_DEFAULT_MODE + assert "solr" not in params + + def test_with_solr_filters(self) -> None: + """Test parameters when solr filters are provided.""" + solr_filters = {"filter": "value"} + params = _build_query_params(solr=solr_filters) + + assert params["solr"] == solr_filters + assert params["k"] == constants.SOLR_VECTOR_SEARCH_DEFAULT_K + + +class TestExtractByokRagChunks: + """Tests for _extract_byok_rag_chunks function.""" + + def test_extract_chunks_with_metadata(self, mocker) -> None: # type: ignore[no-untyped-def] + """Test extraction of chunks with metadata.""" + # Create mock chunks + chunk1 = mocker.Mock() + chunk1.content = "Content 1" + chunk1.chunk_id = "chunk_1" + chunk1.metadata = {"document_id": "doc_1", "title": "Document 1"} + + chunk2 = mocker.Mock() + chunk2.content = "Content 2" + chunk2.chunk_id = "chunk_2" + chunk2.metadata = {"document_id": "doc_2", "title": "Document 2"} + + # Create mock search response + search_response = mocker.Mock() + 
+        search_response.chunks = [chunk1, chunk2]
+        search_response.scores = [0.9, 0.8]
+
+        result = _extract_byok_rag_chunks(
+            search_response, vector_store_id="test_store", weight=1.5
+        )
+
+        assert len(result) == 2
+        assert result[0]["content"] == "Content 1"
+        assert result[0]["score"] == 0.9
+        assert result[0]["weighted_score"] == 0.9 * 1.5
+        assert result[0]["source"] == "test_store"
+        assert result[0]["doc_id"] == "doc_1"
+
+    def test_extract_chunks_without_metadata(self, mocker) -> None:  # type: ignore[no-untyped-def]
+        """Test extraction of chunks without metadata."""
+        chunk = mocker.Mock()
+        chunk.content = "Test content"
+        chunk.chunk_id = "chunk_id"
+        chunk.metadata = None
+
+        search_response = mocker.Mock()
+        search_response.chunks = [chunk]
+        search_response.scores = [0.75]
+
+        result = _extract_byok_rag_chunks(
+            search_response, vector_store_id="test_store", weight=1.0
+        )
+
+        assert len(result) == 1
+        assert result[0]["doc_id"] == "chunk_id"
+        assert result[0]["metadata"] == {}
+
+
+class TestFormatRagContext:
+    """Tests for _format_rag_context function."""
+
+    def test_empty_chunks(self) -> None:
+        """Test formatting with empty chunks list."""
+        result = _format_rag_context([], "test query")
+        assert result == ""
+
+    def test_format_single_chunk(self) -> None:
+        """Test formatting with a single chunk."""
+        chunks = [RAGChunk(content="Test content", source="test_source", score=0.95)]
+        result = _format_rag_context(chunks, "test query")
+
+        assert "file_search found 1 chunks:" in result
+        assert "BEGIN of file_search results." in result
+        assert "Test content" in result
+        assert "document_id: test_source" in result
+        assert "score: 0.9500" in result
+        assert "END of file_search results." in result
+        assert 'answer the user\'s query: "test query"' in result
+
+    def test_format_multiple_chunks(self) -> None:
+        """Test formatting with multiple chunks."""
+        chunks = [
+            RAGChunk(content="Content 1", source="source_1", score=0.9),
+            RAGChunk(content="Content 2", source="source_2", score=0.8),
+            RAGChunk(
+                content="Content 3",
+                source="source_3",
+                score=0.7,
+                attributes={"url": "http://example.com"},
+            ),
+        ]
+        result = _format_rag_context(chunks, "test query")
+
+        assert "file_search found 3 chunks:" in result
+        assert "Content 1" in result
+        assert "Content 2" in result
+        assert "Content 3" in result
+        assert "document_id: source_1" in result
+        assert "[1]" in result
+        assert "[2]" in result
+        assert "[3]" in result
+
+    def test_format_chunk_with_attributes(self) -> None:
+        """Test formatting chunk with additional attributes."""
+        chunks = [
+            RAGChunk(
+                content="Test content",
+                source="test_source",
+                score=0.85,
+                attributes={"title": "Test Doc", "author": "John Doe"},
+            )
+        ]
+        result = _format_rag_context(chunks, "test query")
+
+        assert "attributes:" in result
+        assert "title" in result or "author" in result
+
+
+class TestExtractSolrDocumentMetadata:
+    """Tests for _extract_solr_document_metadata function."""
+
+    def test_extract_from_dict_metadata(self, mocker) -> None:  # type: ignore[no-untyped-def]
+        """Test extraction from dict-based metadata."""
+        chunk = mocker.Mock()
+        chunk.metadata = {
+            "doc_id": "doc_123",
+            "title": "Test Document",
+            "reference_url": "https://example.com/doc",
+        }
+
+        doc_id, title, reference_url = _extract_solr_document_metadata(chunk)
+
+        assert doc_id == "doc_123"
+        assert title == "Test Document"
+        assert reference_url == "https://example.com/doc"
+
+    def test_extract_from_chunk_metadata_object(  # type: ignore[no-untyped-def]
+        self, mocker
+    ) -> None:
+        """Test extraction from typed chunk_metadata object."""
+        chunk_meta = mocker.Mock()
+        chunk_meta.doc_id = "doc_456"
+        chunk_meta.title = "Another Document"
+        chunk_meta.reference_url = "https://example.com/another"
+
+        chunk = mocker.Mock()
+        chunk.metadata = {}
+        chunk.chunk_metadata = chunk_meta
+
+        doc_id, title, reference_url = _extract_solr_document_metadata(chunk)
+
+        assert doc_id == "doc_456"
+        assert title == "Another Document"
+        assert reference_url == "https://example.com/another"
+
+    def test_extract_with_missing_fields(self, mocker) -> None:  # type: ignore[no-untyped-def]
+        """Test extraction when some fields are missing."""
+        chunk = mocker.Mock()
+        chunk.metadata = {"doc_id": "doc_789"}
+
+        doc_id, title, reference_url = _extract_solr_document_metadata(chunk)
+
+        assert doc_id == "doc_789"
+        assert title is None
+        assert reference_url is None
+
+
+class TestBuildDocumentUrl:
+    """Tests for _build_document_url function."""
+
+    def test_offline_mode_with_doc_id(self) -> None:
+        """Test URL building in offline mode with doc_id."""
+        doc_url, reference_doc = _build_document_url(
+            offline=True, doc_id="doc_123", reference_url=None
+        )
+
+        assert doc_url == constants.MIMIR_DOC_URL + "doc_123"
+        assert reference_doc == "doc_123"
+
+    def test_online_mode_with_reference_url(self) -> None:
+        """Test URL building in online mode with reference_url."""
+        doc_url, reference_doc = _build_document_url(
+            offline=False,
+            doc_id="doc_123",
+            reference_url="https://docs.example.com/page",
+        )
+
+        assert doc_url == "https://docs.example.com/page"
+        assert reference_doc == "https://docs.example.com/page"
+
+    def test_online_mode_without_http(self) -> None:
+        """Test online mode when reference_url doesn't start with http."""
+        doc_url, reference_doc = _build_document_url(
+            offline=False, doc_id="doc_123", reference_url="relative/path"
+        )
+
+        assert doc_url == constants.MIMIR_DOC_URL + "relative/path"
+        assert reference_doc == "relative/path"
+
+    def test_offline_mode_without_doc_id(self) -> None:
+        """Test offline mode when doc_id is None."""
+        doc_url, reference_doc = _build_document_url(
+            offline=True, doc_id=None, reference_url="https://example.com"
+        )
+
+        assert doc_url == ""
+        assert reference_doc is None
+
+
+class TestConvertSolrChunksToRagFormat:
+    """Tests for _convert_solr_chunks_to_rag_format function."""
+
+    def test_convert_with_metadata_offline(self, mocker) -> None:  # type: ignore[no-untyped-def]
+        """Test conversion with metadata in offline mode."""
+        chunk = mocker.Mock()
+        chunk.content = "Test content"
+        chunk.metadata = {"parent_id": "parent_123"}
+        chunk.chunk_metadata = None
+
+        result = _convert_solr_chunks_to_rag_format([chunk], [0.85], offline=True)
+
+        assert len(result) == 1
+        assert result[0].content == "Test content"
+        assert result[0].source == constants.OKP_RAG_ID
+        assert result[0].score == 0.85
+        assert "doc_url" in result[0].attributes
+        assert "parent_123" in result[0].attributes["doc_url"]
+
+    def test_convert_with_metadata_online(self, mocker) -> None:  # type: ignore[no-untyped-def]
+        """Test conversion with metadata in online mode."""
+        chunk = mocker.Mock()
+        chunk.content = "Test content"
+        chunk.metadata = {"reference_url": "https://example.com/doc"}
+        chunk.chunk_metadata = None
+
+        result = _convert_solr_chunks_to_rag_format([chunk], [0.75], offline=False)
+
+        assert len(result) == 1
+        assert result[0].attributes["doc_url"] == "https://example.com/doc"
+
+    def test_convert_with_chunk_metadata(self, mocker) -> None:  # type: ignore[no-untyped-def]
+        """Test conversion with chunk_metadata object."""
+        chunk_meta = mocker.Mock()
+        chunk_meta.document_id = "doc_456"
+
+        chunk = mocker.Mock()
+        chunk.content = "Test content"
+        chunk.metadata = {}
+        chunk.chunk_metadata = chunk_meta
+
+        result = _convert_solr_chunks_to_rag_format([chunk], [0.9], offline=True)
+
+        assert len(result) == 1
+        assert result[0].attributes["document_id"] == "doc_456"
+
+    def test_convert_multiple_chunks(self, mocker) -> None:  # type: ignore[no-untyped-def]
+        """Test conversion of multiple chunks."""
+        chunk1 = mocker.Mock()
+        chunk1.content = "Content 1"
+        chunk1.metadata = {"parent_id": "parent_1"}
+        chunk1.chunk_metadata = None
+
+        chunk2 = mocker.Mock()
+        chunk2.content = "Content 2"
+        chunk2.metadata = {"parent_id": "parent_2"}
+        chunk2.chunk_metadata = None
+
+        result = _convert_solr_chunks_to_rag_format(
+            [chunk1, chunk2], [0.9, 0.8], offline=True
+        )
+
+        assert len(result) == 2
+        assert result[0].content == "Content 1"
+        assert result[1].content == "Content 2"
+        assert result[0].score == 0.9
+        assert result[1].score == 0.8
+
+
+class TestFetchByokRag:
+    """Tests for _fetch_byok_rag async function."""
+
+    @pytest.mark.asyncio
+    async def test_byok_no_inline_ids(self, mocker) -> None:  # type: ignore[no-untyped-def]
+        """Test when no inline BYOK sources are configured."""
+        config_mock = mocker.Mock(spec=AppConfig)
+        config_mock.configuration.rag.inline = []
+        config_mock.configuration.byok_rag = []
+        mocker.patch("utils.vector_search.configuration", config_mock)
+
+        client_mock = mocker.AsyncMock()
+        rag_chunks, referenced_docs = await _fetch_byok_rag(client_mock, "test query")
+
+        assert rag_chunks == []
+        assert referenced_docs == []
+        client_mock.vector_io.query.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_byok_enabled_success(self, mocker) -> None:  # type: ignore[no-untyped-def]
+        """Test successful BYOK RAG fetch when inline IDs are configured."""
+        # Mock configuration
+        config_mock = mocker.Mock(spec=AppConfig)
+        byok_rag_mock = mocker.Mock()
+        byok_rag_mock.rag_id = "rag_1"
+        byok_rag_mock.vector_db_id = "vs_1"
+        config_mock.configuration.rag.inline = ["rag_1"]
+        config_mock.configuration.byok_rag = [byok_rag_mock]
+        config_mock.score_multiplier_mapping = {"vs_1": 1.5}
+        config_mock.rag_id_mapping = {"vs_1": "rag_1"}
+        mocker.patch("utils.vector_search.configuration", config_mock)
+
+        # Mock search response
+        chunk_mock = mocker.Mock()
+        chunk_mock.content = "Test content"
+        chunk_mock.chunk_id = "chunk_1"
+        chunk_mock.metadata = {
+            "document_id": "doc_1",
+            "title": "Test Doc",
+            "reference_url": "https://example.com/doc",
+        }
+
+        search_response = mocker.Mock()
+        search_response.chunks = [chunk_mock]
+        search_response.scores = [0.9]
+
+        # Mock client
+        client_mock = mocker.AsyncMock()
+        client_mock.vector_io.query.return_value = search_response
+
+        rag_chunks, referenced_docs = await _fetch_byok_rag(client_mock, "test query")
+
+        assert len(rag_chunks) > 0
+        assert rag_chunks[0].content == "Test content"
+        assert len(referenced_docs) > 0
+
+
+class TestFetchSolrRag:
+    """Tests for _fetch_solr_rag async function."""
+
+    @pytest.mark.asyncio
+    async def test_solr_disabled(self, mocker) -> None:  # type: ignore[no-untyped-def]
+        """Test when Solr is disabled."""
+        config_mock = mocker.Mock(spec=AppConfig)
+        config_mock.inline_solr_enabled = False
+        mocker.patch("utils.vector_search.configuration", config_mock)
+
+        client_mock = mocker.AsyncMock()
+        rag_chunks, referenced_docs = await _fetch_solr_rag(client_mock, "test query")
+
+        assert rag_chunks == []
+        assert referenced_docs == []
+        client_mock.vector_io.query.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_solr_enabled_success(self, mocker) -> None:  # type: ignore[no-untyped-def]
+        """Test successful Solr RAG fetch."""
+        # Mock configuration
+        config_mock = mocker.Mock(spec=AppConfig)
+        config_mock.inline_solr_enabled = True
+        config_mock.okp.offline = True
+        mocker.patch("utils.vector_search.configuration", config_mock)
+
+        # Mock chunk
+        chunk_mock = mocker.Mock()
+        chunk_mock.content = "Solr content"
+        chunk_mock.metadata = {"parent_id": "parent_1", "title": "Solr Doc"}
+        chunk_mock.chunk_metadata = None
+
+        # Mock query response
+        query_response = mocker.Mock()
+        query_response.chunks = [chunk_mock]
+        query_response.scores = [0.85]
+
+        # Mock client
+        client_mock = mocker.AsyncMock()
+        client_mock.vector_io.query.return_value = query_response
+
+        rag_chunks, _referenced_docs = await _fetch_solr_rag(client_mock, "test query")
+
+        assert len(rag_chunks) > 0
+        assert rag_chunks[0].content == "Solr content"
+        assert rag_chunks[0].source == constants.OKP_RAG_ID
+
+
+class TestBuildRagContext:
+    """Tests for build_rag_context async function."""
+
+    @pytest.mark.asyncio
+    async def test_both_sources_disabled(self, mocker) -> None:  # type: ignore[no-untyped-def]
+        """Test when both BYOK inline and Solr inline are not configured."""
+        config_mock = mocker.Mock(spec=AppConfig)
+        config_mock.configuration.rag.inline = []
+        config_mock.configuration.byok_rag = []
+        config_mock.inline_solr_enabled = False
+        mocker.patch("utils.vector_search.configuration", config_mock)
+
+        client_mock = mocker.AsyncMock()
+        context = await build_rag_context(client_mock, "test query", None)
+
+        assert context.context_text == ""
+        assert context.rag_chunks == []
+        assert context.referenced_documents == []
+
+    @pytest.mark.asyncio
+    async def test_byok_enabled_only(self, mocker) -> None:  # type: ignore[no-untyped-def]
+        """Test when only inline BYOK is configured."""
+        # Mock configuration
+        config_mock = mocker.Mock(spec=AppConfig)
+        byok_rag_mock = mocker.Mock()
+        byok_rag_mock.rag_id = "rag_1"
+        byok_rag_mock.vector_db_id = "vs_1"
+        config_mock.configuration.rag.inline = ["rag_1"]
+        config_mock.configuration.byok_rag = [byok_rag_mock]
+        config_mock.inline_solr_enabled = False
+        config_mock.score_multiplier_mapping = {"vs_1": 1.0}
+        config_mock.rag_id_mapping = {"vs_1": "rag_1"}
+        mocker.patch("utils.vector_search.configuration", config_mock)
+
+        # Mock chunk
+        chunk_mock = mocker.Mock()
+        chunk_mock.content = "BYOK content"
+        chunk_mock.chunk_id = "chunk_1"
+        chunk_mock.metadata = {"document_id": "doc_1"}
+
+        search_response = mocker.Mock()
+        search_response.chunks = [chunk_mock]
+        search_response.scores = [0.9]
+
+        # Mock client
+        client_mock = mocker.AsyncMock()
+        client_mock.vector_io.query.return_value = search_response
+
+        context = await build_rag_context(client_mock, "test query", None)
+
+        assert len(context.rag_chunks) > 0
+        assert "BYOK content" in context.context_text
+        assert "file_search found" in context.context_text