Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@

class SimpleRetriever(BaseRetriever):
    def retrieve(self) -> set[str]:
        """Retrieve context chunks for this retriever's prompt and doc_id.

        Returns:
            Set of context strings; may be empty if retrieval finds nothing
            even after the retry below.
        """
        context = self._simple_retrieval()
        if not context:
            # UN-1288 For Pinecone, we are seeing an inconsistent case where
            # a document indexed just moments earlier is not yet visible to
            # queries; to mitigate the indexing lag,
            # the following sleep is added
            # Note: This will not fix the issue. Since this issue is inconsistent
            # and not reproducible easily, this is just a safety net.
            app.logger.info(
                f"[doc_id: {self.doc_id}] Could not retrieve context, "
                "retrying after 2 secs to handle issues due to lag"
            )
            time.sleep(2)
            context = self._simple_retrieval()
        return context
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,33 @@ def replace_custom_data_variable(

@staticmethod
@lru_cache(maxsize=128)
def _extract_variables_cached(prompt_text: str) -> tuple[str, ...]:
    """Memoized variable extraction from a prompt string.

    The result is returned as a tuple rather than a list because
    ``lru_cache`` requires hashable (immutable) return values to be
    safely shared between callers.
    """
    matches = re.findall(VariableConstants.VARIABLE_REGEX, prompt_text)
    return tuple(matches)

@staticmethod
def extract_variables_from_prompt(prompt_text: str) -> list[str]:
    """Extract variables from prompt with caching and stats logging.

    Uses lru_cache internally and logs cache statistics periodically
    to help determine if caching is beneficial.

    Args:
        prompt_text: Raw prompt text to scan for variable placeholders.

    Returns:
        List of variable names matched by the variable regex, in order
        of appearance.
    """
    result = VariableReplacementHelper._extract_variables_cached(prompt_text)

    # Log stats periodically (every 50 calls) so operators can judge
    # whether the cache is earning its keep.
    info_after = VariableReplacementHelper._extract_variables_cached.cache_info()
    total_calls = info_after.hits + info_after.misses

    if total_calls % 50 == 0 and total_calls > 0:
        hit_rate = info_after.hits / total_calls * 100
        app.logger.info(
            f"[VariableCache] total={total_calls} hits={info_after.hits} "
            f"misses={info_after.misses} hit_rate={hit_rate:.1f}% "
            f"size={info_after.currsize}/{info_after.maxsize} "
            f"prompt_chars={len(prompt_text)}"
        )

    # Callers expect a list; convert from the cached tuple.
    return list(result)

@staticmethod
def fetch_dynamic_variable_value(url: str, data: str) -> Any:
Expand Down
48 changes: 37 additions & 11 deletions prompt-service/src/unstract/prompt_service/services/retrieval.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import datetime
from typing import Any

from flask import current_app as app

from unstract.prompt_service.constants import PromptServiceConstants as PSKeys
from unstract.prompt_service.constants import RetrievalStrategy
from unstract.prompt_service.core.retrievers.automerging import AutomergingRetriever
Expand Down Expand Up @@ -32,13 +34,23 @@ def perform_retrieval( # type:ignore
file_path: str,
context_retrieval_metrics: dict[str, Any],
) -> tuple[str, list[str]]:
prompt_name = output.get(PSKeys.NAME, "<unknown>")
vector_db_id = (
getattr(vector_db, "_adapter_instance_id", None) if vector_db else None
)
app.logger.info(
f"[Retrieval] prompt='{prompt_name}' doc_id={doc_id} "
f"chunk_size={chunk_size} method={'complete_context' if chunk_size == 0 else 'chunked'}"
+ (f" vector_db={vector_db_id}" if vector_db_id else "")
)

context: list[str]
if chunk_size == 0:
context = RetrievalService.retrieve_complete_context(
execution_source=execution_source,
file_path=file_path,
context_retrieval_metrics=context_retrieval_metrics,
prompt_key=output[PSKeys.PROMPTX],
prompt_key=prompt_name,
)
else:
context = RetrievalService.run_retrieval(
Expand Down Expand Up @@ -101,9 +113,14 @@ def run_retrieval( # type:ignore
llm=llm,
)
context = retriever.retrieve()
context_retrieval_metrics[prompt_key] = {
"time_taken(s)": Metrics.elapsed_time(start_time=retrieval_start_time)
}
elapsed = Metrics.elapsed_time(start_time=retrieval_start_time)
context_retrieval_metrics[prompt_key] = {"time_taken(s)": elapsed}

app.logger.info(
f"[Retrieval] prompt='{prompt_key}' doc_id={doc_id} "
f"strategy='{retrieval_type}' top_k={top_k} chunks={len(context)} time={elapsed:.3f}s"
)

return list(context)

@staticmethod
def retrieve_complete_context(
    execution_source: str,
    file_path: str,
    context_retrieval_metrics: dict[str, Any],
    prompt_key: str,
) -> list[str]:
    """Loads full context from raw file for zero chunk size retrieval.

    Args:
        execution_source: Source of execution (e.g., "api", "workflow").
        file_path: Path to the extracted text file.
        context_retrieval_metrics: Dict to store retrieval timing metrics
            (modified in-place).
        prompt_key: Name/identifier of the prompt for metrics tracking.

    Returns:
        List containing the complete file content as a single string.
    """
    fs_instance = FileUtils.get_fs_instance(execution_source=execution_source)
    retrieval_start_time = datetime.datetime.now()
    context = fs_instance.read(path=file_path, mode="r")
    elapsed = Metrics.elapsed_time(start_time=retrieval_start_time)
    context_retrieval_metrics[prompt_key] = {"time_taken(s)": elapsed}

    app.logger.info(
        f"[Retrieval] prompt='{prompt_key}' complete_context "
        f"chars={len(context)} time={elapsed:.3f}s"
    )

    return [context]