diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml
index 10e2d61b..58f9f50c 100644
--- a/.github/workflows/integration-test.yml
+++ b/.github/workflows/integration-test.yml
@@ -113,6 +113,8 @@ jobs:
repository: nextcloud/context_chat
path: apps/context_chat
persist-credentials: false
+ # TODO: remove this ref once feat/reverse-content-flow is merged into nextcloud/context_chat
+ ref: feat/reverse-content-flow
- name: Checkout backend
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4
@@ -167,6 +169,10 @@ jobs:
cd ..
rm -rf documentation
+ - name: Run files scan
+ run: |
+ ./occ files:scan --all
+
- name: Setup python 3.11
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5
with:
@@ -197,26 +203,101 @@ jobs:
ls -la context_chat_backend/persistent_storage/*
sleep 30 # Wait for the em server to get ready
- - name: Scan files, baseline
- run: |
- ./occ files:scan admin
- ./occ context_chat:scan admin -m text/plain
-
- - name: Check python memory usage
+ - name: Initial memory usage check
run: |
ps -p $(cat pid.txt) -o pid,cmd,%mem,rss --sort=-%mem
ps -p $(cat pid.txt) -o %mem --no-headers > initial_mem.txt
- - name: Scan files
+ - name: Run cron jobs
run: |
- ./occ files:scan admin
- ./occ context_chat:scan admin -m text/markdown &
- ./occ context_chat:scan admin -m text/x-rst
+ # every 10 seconds indefinitely
+ while true; do
+ php cron.php
+ sleep 10
+ done &
- - name: Check python memory usage
+ - name: Periodically check context_chat stats for 15 minutes to allow the backend to index the files
run: |
- ps -p $(cat pid.txt) -o pid,cmd,%mem,rss --sort=-%mem
- ps -p $(cat pid.txt) -o %mem --no-headers > after_scan_mem.txt
+ success=0
+ echo "::group::Checking stats periodically for 15 minutes to allow the backend to index the files"
+ for i in {1..90}; do
+ echo "Checking stats, attempt $i..."
+
+ stats_err=$(mktemp)
+ stats=$(timeout 5 ./occ context_chat:stats 2>"$stats_err")
+ stats_exit=$?
+ echo "Stats output:"
+ echo "$stats"
+ if [ -s "$stats_err" ]; then
+ echo "Stderr:"
+ cat "$stats_err"
+ fi
+ echo "---"
+ rm -f "$stats_err"
+
+ # Check for critical errors in output
+ if [ $stats_exit -ne 0 ] || echo "$stats" | grep -q "Error during request"; then
+ echo "Backend connection error detected (exit=$stats_exit), retrying..."
+ sleep 10
+ continue
+ fi
+
+ # Extract Total eligible files
+ total_files=$(echo "$stats" | grep -oP 'Total eligible files:\s*\K\d+' || echo "")
+
+ # Extract Indexed documents count (files__default)
+ indexed_count=$(echo "$stats" | grep -oP "'files__default'\s*=>\s*\K\d+" || echo "")
+
+ # Validate parsed values
+ if [ -z "$total_files" ] || [ -z "$indexed_count" ]; then
+ echo "Error: Could not parse stats output properly"
+ if echo "$stats" | grep -q "Indexed documents:"; then
+ echo " Indexed documents section found but could not extract count"
+ fi
+ sleep 10
+ continue
+ fi
+
+ echo "Total eligible files: $total_files"
+ echo "Indexed documents (files__default): $indexed_count"
+
+ # Calculate absolute difference
+ diff=$((total_files - indexed_count))
+ if [ $diff -lt 0 ]; then
+ diff=$((-diff))
+ fi
+
+ # Calculate 2% threshold using bc for floating point support
+ threshold=$(echo "scale=4; $total_files * 0.02" | bc)
+
+ # Check if difference is within tolerance
+ if (( $(echo "$diff <= $threshold" | bc -l) )); then
+ echo "Indexing within 2% tolerance (diff=$diff, threshold=$threshold)"
+ success=1
+ break
+ else
+ pct=$(echo "scale=2; ($diff / $total_files) * 100" | bc)
+ echo "Outside 2% tolerance: diff=$diff (${pct}%), threshold=$threshold"
+ fi
+
+ # Check if backend is still alive
+ ccb_alive=$(ps -p $(cat pid.txt) -o cmd= | grep -c "main.py" || true)
+ if [ "$ccb_alive" -eq 0 ]; then
+ echo "Error: Context Chat Backend process is not running. Exiting."
+ exit 1
+ fi
+
+ sleep 10
+ done
+
+ echo "::endgroup::"
+
+ ./occ context_chat:stats
+
+ if [ $success -ne 1 ]; then
+ echo "Max attempts reached"
+ exit 1
+ fi
- name: Run the prompts
run: |
@@ -250,19 +331,6 @@ jobs:
echo "Memory usage during scan is stable. No memory leak detected."
fi
- - name: Compare memory usage and detect leak
- run: |
- initial_mem=$(cat after_scan_mem.txt | tr -d ' ')
- final_mem=$(cat after_prompt_mem.txt | tr -d ' ')
- echo "Initial Memory Usage: $initial_mem%"
- echo "Memory Usage after prompt: $final_mem%"
-
- if (( $(echo "$final_mem > $initial_mem" | bc -l) )); then
- echo "Memory usage has increased during prompt. Possible memory leak detected!"
- else
- echo "Memory usage during prompt is stable. No memory leak detected."
- fi
-
- name: Show server logs
if: always()
run: |
diff --git a/appinfo/info.xml b/appinfo/info.xml
index 9760cd29..30194baa 100644
--- a/appinfo/info.xml
+++ b/appinfo/info.xml
@@ -82,5 +82,19 @@ Setup background job workers as described here: https://docs.nextcloud.com/serve
Password to be used for authenticating requests to the OpenAI-compatible endpoint set in CC_EM_BASE_URL.
+
+
+ rp
+ Request Processing Mode
+ APP_ROLE=rp
+ true
+
+
+ indexing
+ Indexing Mode
+ APP_ROLE=indexing
+ false
+
+
diff --git a/context_chat_backend/chain/ingest/doc_loader.py b/context_chat_backend/chain/ingest/doc_loader.py
index efb81b6d..b6bc17d8 100644
--- a/context_chat_backend/chain/ingest/doc_loader.py
+++ b/context_chat_backend/chain/ingest/doc_loader.py
@@ -3,15 +3,15 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
#
+import asyncio
import logging
import re
import tempfile
from collections.abc import Callable
-from typing import BinaryIO
+from io import BytesIO
import docx2txt
from epub2txt import epub2txt
-from fastapi import UploadFile
from langchain_unstructured import UnstructuredLoader
from odfdo import Document
from pandas import read_csv, read_excel
@@ -19,9 +19,12 @@
from pypdf.errors import FileNotDecryptedError as PdfFileNotDecryptedError
from striprtf import striprtf
+from ...types import SourceItem, TaskProcException
+from .task_proc import do_ocr, do_transcription
+
logger = logging.getLogger('ccb.doc_loader')
-def _temp_file_wrapper(file: BinaryIO, loader: Callable, sep: str = '\n') -> str:
+def _temp_file_wrapper(file: BytesIO, loader: Callable, sep: str = '\n') -> str:
raw_bytes = file.read()
with tempfile.NamedTemporaryFile(mode='wb') as tmp:
tmp.write(raw_bytes)
@@ -35,46 +38,46 @@ def _temp_file_wrapper(file: BinaryIO, loader: Callable, sep: str = '\n') -> str
# -- LOADERS -- #
-def _load_pdf(file: BinaryIO) -> str:
+def _load_pdf(file: BytesIO) -> str:
pdf_reader = PdfReader(file)
return '\n\n'.join([page.extract_text().strip() for page in pdf_reader.pages])
-def _load_csv(file: BinaryIO) -> str:
+def _load_csv(file: BytesIO) -> str:
return read_csv(file).to_string(header=False, na_rep='')
-def _load_epub(file: BinaryIO) -> str:
+def _load_epub(file: BytesIO) -> str:
return _temp_file_wrapper(file, epub2txt).strip()
-def _load_docx(file: BinaryIO) -> str:
+def _load_docx(file: BytesIO) -> str:
return docx2txt.process(file).strip()
-def _load_odt(file: BinaryIO) -> str:
+def _load_odt(file: BytesIO) -> str:
return _temp_file_wrapper(file, lambda fp: Document(fp).get_formatted_text()).strip()
-def _load_ppt_x(file: BinaryIO) -> str:
+def _load_ppt_x(file: BytesIO) -> str:
return _temp_file_wrapper(file, lambda fp: UnstructuredLoader(fp).load()).strip()
-def _load_rtf(file: BinaryIO) -> str:
+def _load_rtf(file: BytesIO) -> str:
return striprtf.rtf_to_text(file.read().decode('utf-8', 'ignore')).strip()
-def _load_xml(file: BinaryIO) -> str:
+def _load_xml(file: BytesIO) -> str:
data = file.read().decode('utf-8', 'ignore')
data = re.sub(r'', '', data)
return data.strip()
-def _load_xlsx(file: BinaryIO) -> str:
+def _load_xlsx(file: BytesIO) -> str:
return read_excel(file, na_filter=False).to_string(header=False, na_rep='')
-def _load_email(file: BinaryIO, ext: str = 'eml') -> str | None:
+def _load_email(file: BytesIO, ext: str = 'eml') -> str | None:
# NOTE: msg format is not tested
if ext not in ['eml', 'msg']:
return None
@@ -115,30 +118,50 @@ def attachment_partitioner(
}
-def decode_source(source: UploadFile) -> str | None:
+def decode_source(source: SourceItem) -> str | None:
+ io_obj: BytesIO | None = None
try:
# .pot files are powerpoint templates but also plain text files,
# so we skip them to prevent decoding errors
- if source.headers['title'].endswith('.pot'):
+ if source.title.endswith('.pot'):
return None
- mimetype = source.headers['type']
+ mimetype = source.type
if mimetype is None:
return None
+ try:
+ if mimetype.startswith('image/'):
+ return asyncio.run(do_ocr(source.userIds[0], source.file_id))
+ if mimetype.startswith('audio/'):
+ return asyncio.run(do_transcription(source.userIds[0], source.file_id))
+ except TaskProcException as e:
+ # todo: convert this to error obj return
+ # todo: short circuit all other ocr/transcription files when a fatal error arrives
+ # todo: maybe with a global ttl, with a retryable tag
+ logger.warning(f'OCR task failed for source file ({source.reference}): {e}')
+ return None
+ except ValueError:
+ # should not happen
+ logger.warning(f'Unexpected ValueError for source file ({source.reference})')
+ return None
+
+ if isinstance(source.content, str):
+ io_obj = BytesIO(source.content.encode('utf-8', 'ignore'))
+ else:
+ io_obj = source.content
+
if _loader_map.get(mimetype):
- result = _loader_map[mimetype](source.file)
- source.file.close()
+ result = _loader_map[mimetype](io_obj)
return result.encode('utf-8', 'ignore').decode('utf-8', 'ignore')
- result = source.file.read().decode('utf-8', 'ignore')
- source.file.close()
- return result
+ return io_obj.read().decode('utf-8', 'ignore')
except PdfFileNotDecryptedError:
- logger.warning(f'PDF file ({source.filename}) is encrypted and cannot be read')
+ logger.warning(f'PDF file ({source.reference}) is encrypted and cannot be read')
return None
except Exception:
- logger.exception(f'Error decoding source file ({source.filename})', stack_info=True)
+ logger.exception(f'Error decoding source file ({source.reference})', stack_info=True)
return None
finally:
- source.file.close() # Ensure file is closed after processing
+ if io_obj is not None:
+ io_obj.close()
diff --git a/context_chat_backend/chain/ingest/injest.py b/context_chat_backend/chain/ingest/injest.py
index 5871ebb8..9484ab9f 100644
--- a/context_chat_backend/chain/ingest/injest.py
+++ b/context_chat_backend/chain/ingest/injest.py
@@ -2,32 +2,49 @@
# SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors
# SPDX-License-Identifier: AGPL-3.0-or-later
#
+import asyncio
import logging
import re
-from fastapi.datastructures import UploadFile
from langchain.schema import Document
from ...dyn_loader import VectorDBLoader
-from ...types import TConfig
-from ...utils import is_valid_source_id, to_int
+from ...mimetype_list import AUDIO_MIMETYPES, IMAGE_MIMETYPES, SUPPORTED_MIMETYPES
+from ...types import FILES_PROVIDER_ID, IndexingError, SourceItem, TConfig
from ...vectordb.base import BaseVectorDB
from ...vectordb.types import DbException, SafeDbException, UpdateAccessOp
from ..types import InDocument
from .doc_loader import decode_source
from .doc_splitter import get_splitter_for
-from .mimetype_list import SUPPORTED_MIMETYPES
+from .task_proc import OCR_TASK_TYPE, SPEECH_TO_TEXT_TASK_TYPE, is_task_type_available
logger = logging.getLogger('ccb.injest')
-def _allowed_file(file: UploadFile) -> bool:
- return file.headers['type'] in SUPPORTED_MIMETYPES
+
+def _do_extended_mimetype_validation(source: SourceItem, ocr_available: bool, stt_available: bool) -> None:
+ '''
+ Raises
+ ------
+ ValueError
+ '''
+
+ extended_mimetypes = (
+ *SUPPORTED_MIMETYPES,
+ *(([], IMAGE_MIMETYPES)[ocr_available]),
+ *(([], AUDIO_MIMETYPES)[stt_available]),
+ )
+
+ if source.reference.startswith(FILES_PROVIDER_ID) and source.type not in extended_mimetypes:
+ raise ValueError(
+ f'Unsupported file type: {source.type} for reference {source.reference}.'
+ f' OCR available: {ocr_available}, Speech-to-text available: {stt_available}.'
+ )
def _filter_sources(
vectordb: BaseVectorDB,
- sources: list[UploadFile]
-) -> tuple[list[UploadFile], list[UploadFile]]:
+ sources: dict[int, SourceItem]
+) -> tuple[dict[int, SourceItem], dict[int, SourceItem]]:
'''
Returns
-------
@@ -37,30 +54,43 @@ def _filter_sources(
'''
try:
- existing_sources, new_sources = vectordb.check_sources(sources)
+ existing_source_ids, to_embed_source_ids = vectordb.check_sources(sources)
except Exception as e:
- raise DbException('Error: Vectordb sources_to_embed error') from e
+ # todo: maybe handle this and other errors as IndexingErrors
+ raise DbException('Error: Vectordb error while checking existing sources in indexing') from e
+
+ existing_sources = {}
+ to_embed_sources = {}
- return ([
- source for source in sources
- if source.filename in existing_sources
- ], [
- source for source in sources
- if source.filename in new_sources
- ])
+ for db_id, source in sources.items():
+ if source.reference in existing_source_ids:
+ existing_sources[db_id] = source
+ elif source.reference in to_embed_source_ids:
+ to_embed_sources[db_id] = source
+ return existing_sources, to_embed_sources
-def _sources_to_indocuments(config: TConfig, sources: list[UploadFile]) -> list[InDocument]:
- indocuments = []
- for source in sources:
- logger.debug('processing source', extra={ 'source_id': source.filename })
+def _sources_to_indocuments(
+ config: TConfig,
+ sources: dict[int, SourceItem]
+) -> tuple[dict[int, InDocument], dict[int, IndexingError]]:
+ indocuments = {}
+ errored_docs = {}
+ for db_id, source in sources.items():
+ logger.debug('processing source', extra={ 'source_id': source.reference })
+
+ # todo: maybe fetch the content of the files here
# transform the source to have text data
content = decode_source(source)
if content is None or (content := content.strip()) == '':
- logger.debug('decoded empty source', extra={ 'source_id': source.filename })
+ logger.debug('decoded empty source', extra={ 'source_id': source.reference })
+ errored_docs[db_id] = IndexingError(
+ error='Decoded content is empty',
+ retryable=False,
+ )
continue
# replace more than two newlines with two newlines (also blank spaces, more than 4)
@@ -71,94 +101,123 @@ def _sources_to_indocuments(config: TConfig, sources: list[UploadFile]) -> list[
content = content.replace('\0', '')
if content is None or content == '':
- logger.debug('decoded empty source after cleanup', extra={ 'source_id': source.filename })
+ logger.debug('decoded empty source after cleanup', extra={ 'source_id': source.reference })
+ errored_docs[db_id] = IndexingError(
+ error='Decoded content is empty',
+ retryable=False,
+ )
continue
- logger.debug('decoded non empty source', extra={ 'source_id': source.filename })
+ logger.debug('decoded non empty source', extra={ 'source_id': source.reference })
metadata = {
- 'source': source.filename,
- 'title': _decode_latin_1(source.headers['title']),
- 'type': source.headers['type'],
+ 'source': source.reference,
+ 'title': _decode_latin_1(source.title),
+ 'type': source.type,
}
doc = Document(page_content=content, metadata=metadata)
- splitter = get_splitter_for(config.embedding_chunk_size, source.headers['type'])
+ splitter = get_splitter_for(config.embedding_chunk_size, source.type)
split_docs = splitter.split_documents([doc])
logger.debug('split document into chunks', extra={
- 'source_id': source.filename,
+ 'source_id': source.reference,
'len(split_docs)': len(split_docs),
})
- indocuments.append(InDocument(
+ indocuments[db_id] = InDocument(
documents=split_docs,
- userIds=list(map(_decode_latin_1, source.headers['userIds'].split(','))),
- source_id=source.filename, # pyright: ignore[reportArgumentType]
- provider=source.headers['provider'],
- modified=to_int(source.headers['modified']),
- ))
+ userIds=list(map(_decode_latin_1, source.userIds)),
+ source_id=source.reference,
+ provider=source.provider,
+ modified=source.modified, # pyright: ignore[reportArgumentType]
+ )
+
+ return indocuments, errored_docs
- return indocuments
+
+def _increase_access_for_existing_sources(
+ vectordb: BaseVectorDB,
+ existing_sources: dict[int, SourceItem]
+) -> dict[int, IndexingError | None]:
+ '''
+ update userIds for existing sources
+ allow the userIds as additional users, not as the only users
+ '''
+ if len(existing_sources) == 0:
+ return {}
+
+ results = {}
+ logger.debug('Increasing access for existing sources', extra={
+ 'source_ids': [source.reference for source in existing_sources.values()]
+ })
+ for db_id, source in existing_sources.items():
+ try:
+ vectordb.update_access(
+ UpdateAccessOp.ALLOW,
+ list(map(_decode_latin_1, source.userIds)),
+ source.reference,
+ )
+ results[db_id] = None
+ except SafeDbException as e:
+ logger.error(f'Failed to update access for source ({source.reference}): {e.args[0]}')
+ results[db_id] = IndexingError(
+ error=str(e),
+ retryable=False,
+ )
+ continue
+ except Exception as e:
+ logger.error(f'Unexpected error while updating access for source ({source.reference}): {e}')
+ results[db_id] = IndexingError(
+ error='Unexpected error while updating access',
+ retryable=True,
+ )
+ continue
+ return results
def _process_sources(
vectordb: BaseVectorDB,
config: TConfig,
- sources: list[UploadFile],
-) -> tuple[list[str],list[str]]:
+ sources: dict[int, SourceItem]
+) -> dict[int, IndexingError | None]:
'''
Processes the sources and adds them to the vectordb.
Returns the list of source ids that were successfully added and those that need to be retried.
'''
- existing_sources, filtered_sources = _filter_sources(vectordb, sources)
+ existing_sources, to_embed_sources = _filter_sources(vectordb, sources)
logger.debug('db filter source results', extra={
'len(existing_sources)': len(existing_sources),
'existing_sources': existing_sources,
- 'len(filtered_sources)': len(filtered_sources),
- 'filtered_sources': filtered_sources,
+ 'len(to_embed_sources)': len(to_embed_sources),
+ 'to_embed_sources': to_embed_sources,
})
- loaded_source_ids = [source.filename for source in existing_sources]
- # update userIds for existing sources
- # allow the userIds as additional users, not as the only users
- if len(existing_sources) > 0:
- logger.debug('Increasing access for existing sources', extra={
- 'source_ids': [source.filename for source in existing_sources]
- })
- for source in existing_sources:
- try:
- vectordb.update_access(
- UpdateAccessOp.allow,
- list(map(_decode_latin_1, source.headers['userIds'].split(','))),
- source.filename, # pyright: ignore[reportArgumentType]
- )
- except SafeDbException as e:
- logger.error(f'Failed to update access for source ({source.filename}): {e.args[0]}')
- continue
-
- if len(filtered_sources) == 0:
+ source_proc_results = _increase_access_for_existing_sources(vectordb, existing_sources)
+
+ if len(to_embed_sources) == 0:
# no new sources to embed
logger.debug('Filtered all sources, nothing to embed')
- return loaded_source_ids, [] # pyright: ignore[reportReturnType]
+ return source_proc_results
logger.debug('Filtered sources:', extra={
- 'source_ids': [source.filename for source in filtered_sources]
+ 'source_ids': [source.reference for source in to_embed_sources.values()]
})
# invalid/empty sources are filtered out here and not counted in loaded/retryable
- indocuments = _sources_to_indocuments(config, filtered_sources)
+ indocuments, errored_docs = _sources_to_indocuments(config, to_embed_sources)
- logger.debug('Converted all sources to documents')
+ source_proc_results.update(errored_docs)
+ logger.debug('Converted sources to documents')
if len(indocuments) == 0:
# filtered document(s) were invalid/empty, not an error
logger.debug('All documents were found empty after being processed')
- return loaded_source_ids, [] # pyright: ignore[reportReturnType]
+ return source_proc_results
- added_source_ids, retry_source_ids = vectordb.add_indocuments(indocuments)
- loaded_source_ids.extend(added_source_ids)
+ doc_add_results = vectordb.add_indocuments(indocuments)
+ source_proc_results.update(doc_add_results)
logger.debug('Added documents to vectordb')
- return loaded_source_ids, retry_source_ids # pyright: ignore[reportReturnType]
+ return source_proc_results
def _decode_latin_1(s: str) -> str:
@@ -172,31 +231,34 @@ def _decode_latin_1(s: str) -> str:
def embed_sources(
vectordb_loader: VectorDBLoader,
config: TConfig,
- sources: list[UploadFile],
-) -> tuple[list[str],list[str]]:
- # either not a file or a file that is allowed
- sources_filtered = [
- source for source in sources
- if is_valid_source_id(source.filename) # pyright: ignore[reportArgumentType]
- or _allowed_file(source)
- ]
-
+ sources: dict[int, SourceItem]
+) -> dict[int, IndexingError | None]:
logger.debug('Embedding sources:', extra={
'source_ids': [
- f'{source.filename} ({_decode_latin_1(source.headers["title"])})'
- for source in sources_filtered
- ],
- 'invalid_source_ids': [
- source.filename for source in sources
- if not is_valid_source_id(source.filename) # pyright: ignore[reportArgumentType]
- ],
- 'not_allowed_file_ids': [
- source.filename for source in sources
- if not _allowed_file(source)
+ f'{source.reference} ({_decode_latin_1(source.title)})'
+ for source in sources.values()
],
- 'len(source_ids)': len(sources_filtered),
- 'len(total_source_ids)': len(sources),
+ 'len(source_ids)': len(sources),
})
+ mime_filtered_sources: dict[int, SourceItem] = {}
+ mime_errored_sources: dict[int, IndexingError] = {}
+
+ ocr_available = asyncio.run(is_task_type_available(OCR_TASK_TYPE))
+ stt_available = asyncio.run(is_task_type_available(SPEECH_TO_TEXT_TASK_TYPE))
+ for db_id, source in sources.items():
+ try:
+ _do_extended_mimetype_validation(source, ocr_available, stt_available)
+ mime_filtered_sources[db_id] = source
+ except ValueError as e:
+ mime_errored_sources[db_id] = IndexingError(
+ error=str(e),
+ retryable=False,
+ )
+ continue
+
vectordb = vectordb_loader.load()
- return _process_sources(vectordb, config, sources_filtered)
+ processed_sources = _process_sources(vectordb, config, mime_filtered_sources)
+ processed_sources.update(mime_errored_sources)
+
+ return processed_sources
diff --git a/context_chat_backend/chain/ingest/task_proc.py b/context_chat_backend/chain/ingest/task_proc.py
new file mode 100644
index 00000000..a6f7e4cd
--- /dev/null
+++ b/context_chat_backend/chain/ingest/task_proc.py
@@ -0,0 +1,289 @@
+#
+# SPDX-FileCopyrightText: 2026 Nextcloud GmbH and Nextcloud contributors
+# SPDX-License-Identifier: AGPL-3.0-or-later
+#
+
+import asyncio
+import json
+import logging
+import os
+from typing import Any, Literal
+
+import niquests
+from nc_py_api import AsyncNextcloudApp, NextcloudException
+from pydantic import BaseModel, ValidationError
+
+from ...types import TaskProcException, TaskProcFatalException
+from ...utils import timed_cache_async
+
+LOGGER = logging.getLogger('ccb.task_proc')
+OCR_TASK_TYPE = 'core:image2text:ocr'
+SPEECH_TO_TEXT_TASK_TYPE = 'core:audio2text'
+CACHE_TTL = 15 * 60 # cache values for 15 minutes
+OCP_TASK_PROC_SCHED_RETRIES = 3
+OCP_TASK_TIMEOUT = 20 * 60 # 20 mins to wait for a task to complete
+
+
+class Task(BaseModel):
+ id: int
+ status: str
+ output: dict[str, Any] | None = None
+
+
+class TaskResponse(BaseModel):
+ task: Task
+
+
+InputShapeType = Literal[
+ 'Number',
+ 'Text',
+ 'Audio',
+ 'Image',
+ 'Video',
+ 'File',
+ 'Enum',
+ 'ListOfNumbers',
+ 'ListOfTexts',
+ 'ListOfImages',
+ 'ListOfAudios',
+ 'ListOfVideos',
+ 'ListOfFiles',
+]
+
+class InputShape(BaseModel):
+ name: str
+ description: str
+ type: InputShapeType
+
+
+class InputShapeEnum(BaseModel):
+ name: str
+ value: str
+
+
+class TaskType(BaseModel):
+ name: str
+ description: str
+ inputShape: dict[str, InputShape]
+ inputShapeEnumValues: dict[str, list[InputShapeEnum]]
+ inputShapeDefaults: dict[str, str | int | float]
+ optionalInputShape: dict[str, InputShape]
+ optionalInputShapeEnumValues: dict[str, list[InputShapeEnum]]
+ optionalInputShapeDefaults: dict[str, str | int | float]
+ outputShape: dict[str, InputShape]
+ outputShapeEnumValues: dict[str, list[InputShapeEnum]]
+ optionalOutputShape: dict[str, InputShape]
+ optionalOutputShapeEnumValues: dict[str, list[InputShapeEnum]]
+
+
+class TaskTypesResponse(BaseModel):
+ types: dict[str, TaskType]
+
+
+
+def __try_parse_ocs_response(response: niquests.Response | None) -> dict | str:
+ if response is None or response.text is None:
+ return 'No response'
+ try:
+ ocs_response = json.loads(response.text)
+ if not (ocs_data := ocs_response.get('ocs', {}).get('data')):
+ return response.text
+ return ocs_data
+ except json.JSONDecodeError:
+ return response.text
+
+
+async def __schedule_task(user_id: str, task_type: str, custom_id: str, task_input: dict) -> Task:
+ '''
+ Raises
+ ------
+ TaskProcException
+ '''
+ nc = AsyncNextcloudApp()
+ await nc.set_user(user_id)
+
+ for sched_tries in range(OCP_TASK_PROC_SCHED_RETRIES):
+ try:
+ response = await nc.ocs(
+ 'POST',
+ '/ocs/v2.php/taskprocessing/schedule',
+ json={
+ 'type': task_type,
+ 'appId': os.getenv('APP_ID', 'context_chat_backend'),
+ 'customId': f'ccb-{custom_id}',
+ 'input': task_input,
+ },
+ )
+ try:
+ task = TaskResponse.model_validate(response).task
+ LOGGER.debug('TaskProcessing task schedule response', extra={
+ 'task': task,
+ })
+ return task
+ except ValidationError as e:
+ raise TaskProcException('Failed to parse TaskProcessing task result') from e
+ except NextcloudException as e:
+ if e.status_code == niquests.codes.precondition_failed: # type: ignore[attr-defined]
+ raise TaskProcFatalException(
+ 'Failed to schedule Nextcloud TaskProcessing task:'
+ f' No provider of {task_type} is installed on this Nextcloud instance.'
+ ' Please install a suitable provider from the AI overview:'
+ ' https://docs.nextcloud.com/server/latest/admin_manual/ai/overview.html.',
+ ) from e
+
+ if e.status_code == niquests.codes.too_many_requests: # type: ignore[attr-defined]
+ LOGGER.warning(
+ 'Rate limited during TaskProcessing task scheduling, waiting 30s before retrying',
+ extra={
+ 'task_type': task_type,
+ 'sched_try': sched_tries,
+ },
+ )
+ await asyncio.sleep(30)
+ continue
+
+ ocs_response = __try_parse_ocs_response(e.response)
+ if e.status_code // 100 == 4:
+ raise TaskProcFatalException(
+ f'Failed to schedule TaskProcessing task due to client error: {ocs_response}',
+ ) from e
+
+ LOGGER.error('NextcloudException during TaskProcessing task scheduling', exc_info=e, extra={
+ 'task_type': task_type,
+ 'sched_try': sched_tries,
+ 'nc_exc_reason': str(e.reason),
+ 'nc_exc_info': str(e.info),
+ 'nc_exc_status_code': str(e.status_code),
+ 'ocs_response': str(ocs_response),
+ })
+ raise TaskProcException(f'Failed to schedule TaskProcessing task: {ocs_response}') from e
+ except TaskProcException:
+ raise
+ except Exception as e:
+ raise TaskProcException(f'Failed to schedule TaskProcessing task: {e}') from e
+
+ raise TaskProcException(f'Failed to schedule TaskProcessing task, tried {OCP_TASK_PROC_SCHED_RETRIES} times')
+
+
+async def __get_task_result(user_id: str, task: Task) -> Any:
+ nc = AsyncNextcloudApp()
+ await nc.set_user(user_id)
+
+ i = 0
+ now_waiting_for = 0
+
+ while task.status != 'STATUS_SUCCESSFUL' and task.status != 'STATUS_FAILED' and now_waiting_for < OCP_TASK_TIMEOUT:
+ i += 1
+ now_waiting_for += 10
+ await asyncio.sleep(10)
+
+ try:
+ response = await nc.ocs('GET', f'/ocs/v2.php/taskprocessing/task/{task.id}')
+ except NextcloudException as e:
+ if e.status_code == niquests.codes.too_many_requests: # type: ignore[attr-defined]
+ LOGGER.warning(
+ 'Rate limited during TaskProcessing task polling, waiting 60s before retrying',
+ extra={
+ 'task_id': task.id,
+ 'tries_so_far': i,
+ 'waiting_time': now_waiting_for,
+ },
+ )
+ now_waiting_for += 60
+ await asyncio.sleep(60)
+ continue
+ raise TaskProcException('Failed to poll TaskProcessing task') from e
+ except niquests.RequestException as e:
+ LOGGER.warning('Ignored error during TaskProcessing task polling', exc_info=e, extra={
+ 'task_id': task.id,
+ 'tries_so_far': i,
+ 'waiting_time': now_waiting_for,
+ })
+ continue
+
+ try:
+ task = TaskResponse.model_validate(response).task
+ LOGGER.debug(f'TaskProcessing task poll ({now_waiting_for}s) response', extra={
+ 'task_id': task.id,
+ 'tries_so_far': i,
+ 'waiting_time': now_waiting_for,
+ 'task': task,
+ })
+ except ValidationError as e:
+ raise TaskProcException('Failed to parse TaskProcessing task result') from e
+
+ if task.status != 'STATUS_SUCCESSFUL':
+ raise TaskProcException(
+ f'TaskProcessing task id {task.id} failed with status {task.status}'
+ f' after waiting {now_waiting_for} seconds',
+ )
+
+ if not isinstance(task.output, dict) or 'output' not in task.output:
+ raise TaskProcException(f'"output" key not found or invalid in TaskProcessing task result: {task.output}')
+
+ return task.output['output']
+
+
+async def do_ocr(user_id: str, file_id: int) -> str:
+ try:
+ task = await __schedule_task(user_id, OCR_TASK_TYPE, str(file_id), {'input': [file_id]})
+ output = await __get_task_result(user_id, task)
+ if not isinstance(output, list) or len(output) == 0 or not isinstance(output[0], str):
+ raise TaskProcException(f'OCR task returned empty or invalid output: {output}')
+ return output[0]
+ except TaskProcException as e:
+ LOGGER.error(f'Failed to perform OCR for file_id {file_id}', exc_info=e)
+ raise
+
+
+async def do_transcription(user_id: str, file_id: int) -> str:
+ try:
+ task = await __schedule_task(user_id, SPEECH_TO_TEXT_TASK_TYPE, str(file_id), {'input': file_id})
+ output = await __get_task_result(user_id, task)
+ if not isinstance(output, str) or len(output.strip()) == 0:
+ raise TaskProcException(f'Speech-to-text task returned empty or invalid output: {output}')
+ return output
+ except TaskProcException as e:
+ LOGGER.error(f'Failed to perform transcription for file_id {file_id}', exc_info=e)
+ raise
+
+
+@timed_cache_async(CACHE_TTL)
+async def __get_task_types() -> TaskTypesResponse:
+ '''
+ Raises
+ ------
+ TaskProcException
+ '''
+ nc = AsyncNextcloudApp()
+
+ # NC 33 required for this
+ try:
+ response = await nc.ocs(
+ 'GET',
+ '/ocs/v2.php/taskprocessing/tasks_consumer/tasktypes',
+ )
+ except NextcloudException as e:
+ raise TaskProcException('Failed to fetch Nextcloud TaskProcessing types') from e
+
+ try:
+ task_types = TaskTypesResponse.model_validate(response)
+ LOGGER.debug('Fetched task types', extra={
+ 'task_types': task_types,
+ })
+ except (KeyError, TypeError, ValidationError) as e:
+ raise TaskProcException('Failed to parse Nextcloud TaskProcessing types') from e
+
+ return task_types
+
+
+@timed_cache_async(CACHE_TTL)
+async def is_task_type_available(task_type: str) -> bool:
+ try:
+ task_types = await __get_task_types()
+ except Exception as e:
+ LOGGER.warning(f'Failed to fetch task types: {e}', exc_info=e)
+ return False
+ if task_type not in task_types.types:
+ return False
+ return True
diff --git a/context_chat_backend/controller.py b/context_chat_backend/controller.py
index 0588dabe..797ba201 100644
--- a/context_chat_backend/controller.py
+++ b/context_chat_backend/controller.py
@@ -6,7 +6,7 @@
# isort: off
from .chain.types import ContextException, LLMOutput, ScopeType, SearchResult
from .types import LoaderException, EmbeddingException
-from .vectordb.types import DbException, SafeDbException, UpdateAccessOp
+from .vectordb.types import DbException, SafeDbException
from .setup_functions import ensure_config_file, repair_run, setup_env_vars
# setup env vars before importing other modules
@@ -24,10 +24,9 @@
from contextlib import asynccontextmanager
from functools import wraps
from threading import Event, Thread
-from time import sleep
-from typing import Annotated, Any
+from typing import Any
-from fastapi import Body, FastAPI, Request, UploadFile
+from fastapi import FastAPI, Request
from langchain.llms.base import LLM
from nc_py_api import AsyncNextcloudApp, NextcloudApp
from nc_py_api.ex_app import persistent_storage, set_handlers
@@ -35,27 +34,21 @@
from starlette.responses import FileResponse
from .chain.context import do_doc_search
-from .chain.ingest.injest import embed_sources
from .chain.one_shot import process_context_query, process_query
from .config_parser import get_config
from .dyn_loader import LLMModelLoader, VectorDBLoader
from .models.types import LlmException
from nc_py_api.ex_app import AppAPIAuthMiddleware
-from .utils import JSONResponse, exec_in_proc, is_valid_provider_id, is_valid_source_id, value_of
-from .vectordb.service import (
- count_documents_by_provider,
- decl_update_access,
- delete_by_provider,
- delete_by_source,
- delete_user,
- update_access,
-)
+from .utils import JSONResponse, exec_in_proc, value_of
+from .task_fetcher import start_bg_threads, wait_for_bg_threads
+from .vectordb.service import count_documents_by_provider
# setup
repair_run()
ensure_config_file()
logger = logging.getLogger('ccb.controller')
+app_config = get_config(os.environ['CC_CONFIG_PATH'])
__download_models_from_hf = os.environ.get('CC_DOWNLOAD_MODELS_FROM_HF', 'true').lower() in ('1', 'true', 'yes')
models_to_fetch = {
@@ -73,10 +66,16 @@
app_enabled = Event()
def enabled_handler(enabled: bool, _: NextcloudApp | AsyncNextcloudApp) -> str:
- if enabled:
- app_enabled.set()
- else:
- app_enabled.clear()
+ try:
+ if enabled:
+ app_enabled.set()
+ start_bg_threads(app_config, app_enabled)
+ else:
+ app_enabled.clear()
+ wait_for_bg_threads()
+ except Exception as e:
+ logger.exception('Error in enabled handler:', exc_info=e)
+ return f'Error in enabled handler: {e}'
logger.info(f'App {("disabled", "enabled")[enabled]}')
return ''
@@ -88,15 +87,16 @@ async def lifespan(app: FastAPI):
nc = NextcloudApp()
if nc.enabled_state:
app_enabled.set()
+ start_bg_threads(app_config, app_enabled)
logger.info(f'App enable state at startup: {app_enabled.is_set()}')
t = Thread(target=background_thread_task, args=())
t.start()
yield
vectordb_loader.offload()
llm_loader.offload()
+ wait_for_bg_threads()
-app_config = get_config(os.environ['CC_CONFIG_PATH'])
app = FastAPI(debug=app_config.debug, lifespan=lifespan) # pyright: ignore[reportArgumentType]
app.extra['CONFIG'] = app_config
@@ -129,9 +129,11 @@ async def lifespan(app: FastAPI):
# logger background thread
def background_thread_task():
- while(True):
- logger.info(f'Currently indexing {len(_indexing)} documents (filename, size): ', extra={'_indexing': _indexing})
- sleep(10)
+ # todo
+ # while(True):
+ # logger.info(f'Currently indexing {len(_indexing)} documents (filename, size): ', extra={'_indexing': _indexing})
+ # sleep(10)
+ ...
# exception handlers
@@ -213,200 +215,212 @@ def _():
return JSONResponse(content={'enabled': app_enabled.is_set()}, status_code=200)
-@app.post('/updateAccessDeclarative')
+@app.post('/countIndexedDocuments')
@enabled_guard(app)
-def _(
- userIds: Annotated[list[str], Body()],
- sourceId: Annotated[str, Body()],
-):
- logger.debug('Update access declarative request:', extra={
- 'user_ids': userIds,
- 'source_id': sourceId,
- })
+def _():
+ counts = exec_in_proc(target=count_documents_by_provider, args=(vectordb_loader,))
+ return JSONResponse(counts)
- if len(userIds) == 0:
- return JSONResponse('Empty list of user ids', 400)
- if not is_valid_source_id(sourceId):
- return JSONResponse('Invalid source id', 400)
+@app.get('/downloadLogs')
+def download_logs() -> FileResponse:
+ with tempfile.NamedTemporaryFile('wb', delete=False) as tmp:
+ with zipfile.ZipFile(tmp, mode='w', compression=zipfile.ZIP_DEFLATED) as zip_file:
+ files = os.listdir(os.path.join(persistent_storage(), 'logs'))
+ for file in files:
+ file_path = os.path.join(persistent_storage(), 'logs', file)
+ if os.path.isfile(file_path): # Might be a folder (just skip it then)
+ zip_file.write(file_path)
+ return FileResponse(tmp.name, media_type='application/zip', filename='docker_logs.zip')
- exec_in_proc(target=decl_update_access, args=(vectordb_loader, userIds, sourceId))
- return JSONResponse('Access updated')
+# @app.post('/updateAccessDeclarative')
+# @enabled_guard(app)
+# def _(
+# userIds: Annotated[list[str], Body()],
+# sourceId: Annotated[str, Body()],
+# ):
+# logger.debug('Update access declarative request:', extra={
+# 'user_ids': userIds,
+# 'source_id': sourceId,
+# })
+# if len(userIds) == 0:
+# return JSONResponse('Empty list of user ids', 400)
-@app.post('/updateAccess')
-@enabled_guard(app)
-def _(
- op: Annotated[UpdateAccessOp, Body()],
- userIds: Annotated[list[str], Body()],
- sourceId: Annotated[str, Body()],
-):
- logger.debug('Update access request', extra={
- 'op': op,
- 'user_ids': userIds,
- 'source_id': sourceId,
- })
+# if not is_valid_source_id(sourceId):
+# return JSONResponse('Invalid source id', 400)
- if len(userIds) == 0:
- return JSONResponse('Empty list of user ids', 400)
+# exec_in_proc(target=decl_update_access, args=(vectordb_loader, userIds, sourceId))
- if not is_valid_source_id(sourceId):
- return JSONResponse('Invalid source id', 400)
+# return JSONResponse('Access updated')
- exec_in_proc(target=update_access, args=(vectordb_loader, op, userIds, sourceId))
- return JSONResponse('Access updated')
+# @app.post('/updateAccess')
+# @enabled_guard(app)
+# def _(
+# op: Annotated[UpdateAccessOp, Body()],
+# userIds: Annotated[list[str], Body()],
+# sourceId: Annotated[str, Body()],
+# ):
+# logger.debug('Update access request', extra={
+# 'op': op,
+# 'user_ids': userIds,
+# 'source_id': sourceId,
+# })
+# if len(userIds) == 0:
+# return JSONResponse('Empty list of user ids', 400)
-@app.post('/updateAccessProvider')
-@enabled_guard(app)
-def _(
- op: Annotated[UpdateAccessOp, Body()],
- userIds: Annotated[list[str], Body()],
- providerId: Annotated[str, Body()],
-):
- logger.debug('Update access by provider request', extra={
- 'op': op,
- 'user_ids': userIds,
- 'provider_id': providerId,
- })
+# if not is_valid_source_id(sourceId):
+# return JSONResponse('Invalid source id', 400)
- if len(userIds) == 0:
- return JSONResponse('Empty list of user ids', 400)
+# exec_in_proc(target=update_access, args=(vectordb_loader, op, userIds, sourceId))
- if not is_valid_provider_id(providerId):
- return JSONResponse('Invalid provider id', 400)
+# return JSONResponse('Access updated')
- exec_in_proc(target=update_access, args=(vectordb_loader, op, userIds, providerId))
- return JSONResponse('Access updated')
+# @app.post('/updateAccessProvider')
+# @enabled_guard(app)
+# def _(
+# op: Annotated[UpdateAccessOp, Body()],
+# userIds: Annotated[list[str], Body()],
+# providerId: Annotated[str, Body()],
+# ):
+# logger.debug('Update access by provider request', extra={
+# 'op': op,
+# 'user_ids': userIds,
+# 'provider_id': providerId,
+# })
+# if len(userIds) == 0:
+# return JSONResponse('Empty list of user ids', 400)
-@app.post('/deleteSources')
-@enabled_guard(app)
-def _(sourceIds: Annotated[list[str], Body(embed=True)]):
- logger.debug('Delete sources request', extra={
- 'source_ids': sourceIds,
- })
+# if not is_valid_provider_id(providerId):
+# return JSONResponse('Invalid provider id', 400)
- sourceIds = [source.strip() for source in sourceIds if source.strip() != '']
+# exec_in_proc(target=update_access_provider, args=(vectordb_loader, op, userIds, providerId))
- if len(sourceIds) == 0:
- return JSONResponse('No sources provided', 400)
+# return JSONResponse('Access updated')
- res = exec_in_proc(target=delete_by_source, args=(vectordb_loader, sourceIds))
- if res is False:
- return JSONResponse('Error: VectorDB delete failed, check vectordb logs for more info.', 400)
- return JSONResponse('All valid sources deleted')
+# @app.post('/deleteSources')
+# @enabled_guard(app)
+# def _(sourceIds: Annotated[list[str], Body(embed=True)]):
+# logger.debug('Delete sources request', extra={
+# 'source_ids': sourceIds,
+# })
+
+# sourceIds = [source.strip() for source in sourceIds if source.strip() != '']
+
+# if len(sourceIds) == 0:
+# return JSONResponse('No sources provided', 400)
+
+# res = exec_in_proc(target=delete_by_source, args=(vectordb_loader, sourceIds))
+# if res is False:
+# return JSONResponse('Error: VectorDB delete failed, check vectordb logs for more info.', 400)
+
+# return JSONResponse('All valid sources deleted')
+
+# @app.post('/deleteProvider')
+# @enabled_guard(app)
+# def _(providerKey: str = Body(embed=True)):
+# logger.debug('Delete sources by provider for all users request', extra={ 'provider_key': providerKey })
-@app.post('/deleteProvider')
-@enabled_guard(app)
-def _(providerKey: str = Body(embed=True)):
- logger.debug('Delete sources by provider for all users request', extra={ 'provider_key': providerKey })
+# if value_of(providerKey) is None:
+# return JSONResponse('Invalid provider key provided', 400)
- if value_of(providerKey) is None:
- return JSONResponse('Invalid provider key provided', 400)
+# exec_in_proc(target=delete_by_provider, args=(vectordb_loader, providerKey))
- exec_in_proc(target=delete_by_provider, args=(vectordb_loader, providerKey))
+# return JSONResponse('All valid sources deleted')
+
+
+# @app.post('/deleteUser')
+# @enabled_guard(app)
+# def _(userId: str = Body(embed=True)):
+# logger.debug('Remove access list for user, and orphaned sources', extra={ 'user_id': userId })
- return JSONResponse('All valid sources deleted')
+# if value_of(userId) is None:
+# return JSONResponse('Invalid userId provided', 400)
+# exec_in_proc(target=delete_user, args=(vectordb_loader, userId))
-@app.post('/deleteUser')
-@enabled_guard(app)
-def _(userId: str = Body(embed=True)):
- logger.debug('Remove access list for user, and orphaned sources', extra={ 'user_id': userId })
+# return JSONResponse('User deleted')
- if value_of(userId) is None:
- return JSONResponse('Invalid userId provided', 400)
- exec_in_proc(target=delete_user, args=(vectordb_loader, userId))
+# @app.put('/loadSources')
+# @enabled_guard(app)
+# def _(sources: list[UploadFile]):
+# global _indexing
+
+# if len(sources) == 0:
+# return JSONResponse('No sources provided', 400)
+
+# for source in sources:
+# if not value_of(source.filename):
+# return JSONResponse(f'Invalid source filename for: {source.headers.get("title")}', 400)
+
+# with index_lock:
+# if source.filename in _indexing:
+# # this request will be retried by the client
+# return JSONResponse(
+# f'This source ({source.filename}) is already being processed in another request, try again later',
+# 503,
+# headers={'cc-retry': 'true'},
+# )
- return JSONResponse('User deleted')
+# if not (
+# value_of(source.headers.get('userIds'))
+# and source.headers.get('title', None) is not None
+# and value_of(source.headers.get('type'))
+# and value_of(source.headers.get('modified'))
+# and source.headers['modified'].isdigit()
+# and value_of(source.headers.get('provider'))
+# ):
+# logger.error('Invalid/missing headers received', extra={
+# 'source_id': source.filename,
+# 'title': source.headers.get('title'),
+# 'headers': source.headers,
+# })
+# return JSONResponse(f'Invaild/missing headers for: {source.filename}', 400)
+
+# # wait for 10 minutes before failing the request
+# semres = doc_parse_semaphore.acquire(block=True, timeout=10*60)
+# if not semres:
+# return JSONResponse(
+# 'Document parser worker limit reached, try again in some time or consider increasing the limit',
+# 503,
+# headers={'cc-retry': 'true'}
+# )
+# with index_lock:
+# for source in sources:
+# _indexing[source.filename] = source.size
-@app.post('/countIndexedDocuments')
-@enabled_guard(app)
-def _():
- counts = exec_in_proc(target=count_documents_by_provider, args=(vectordb_loader,))
- return JSONResponse(counts)
-
-
-@app.put('/loadSources')
-@enabled_guard(app)
-def _(sources: list[UploadFile]):
- global _indexing
-
- if len(sources) == 0:
- return JSONResponse('No sources provided', 400)
-
- for source in sources:
- if not value_of(source.filename):
- return JSONResponse(f'Invalid source filename for: {source.headers.get("title")}', 400)
-
- with index_lock:
- if source.filename in _indexing:
- # this request will be retried by the client
- return JSONResponse(
- f'This source ({source.filename}) is already being processed in another request, try again later',
- 503,
- headers={'cc-retry': 'true'},
- )
-
- if not (
- value_of(source.headers.get('userIds'))
- and source.headers.get('title', None) is not None
- and value_of(source.headers.get('type'))
- and value_of(source.headers.get('modified'))
- and source.headers['modified'].isdigit()
- and value_of(source.headers.get('provider'))
- ):
- logger.error('Invalid/missing headers received', extra={
- 'source_id': source.filename,
- 'title': source.headers.get('title'),
- 'headers': source.headers,
- })
- return JSONResponse(f'Invaild/missing headers for: {source.filename}', 400)
-
- # wait for 10 minutes before failing the request
- semres = doc_parse_semaphore.acquire(block=True, timeout=10*60)
- if not semres:
- return JSONResponse(
- 'Document parser worker limit reached, try again in some time or consider increasing the limit',
- 503,
- headers={'cc-retry': 'true'}
- )
-
- with index_lock:
- for source in sources:
- _indexing[source.filename] = source.size
-
- try:
- loaded_sources, not_added_sources = exec_in_proc(
- target=embed_sources,
- args=(vectordb_loader, app.extra['CONFIG'], sources)
- )
- except (DbException, EmbeddingException):
- raise
- except Exception as e:
- raise DbException('Error: failed to load sources') from e
- finally:
- with index_lock:
- for source in sources:
- _indexing.pop(source.filename, None)
- doc_parse_semaphore.release()
-
- if len(loaded_sources) != len(sources):
- logger.debug('Some sources were not loaded', extra={
- 'Count of loaded sources': f'{len(loaded_sources)}/{len(sources)}',
- 'source_ids': loaded_sources,
- })
-
- # loaded sources include the existing sources that may only have their access updated
- return JSONResponse({'loaded_sources': loaded_sources, 'sources_to_retry': not_added_sources})
+# try:
+# loaded_sources, not_added_sources = exec_in_proc(
+# target=embed_sources,
+# args=(vectordb_loader, app.extra['CONFIG'], sources)
+# )
+# except (DbException, EmbeddingException):
+# raise
+# except Exception as e:
+# raise DbException('Error: failed to load sources') from e
+# finally:
+# with index_lock:
+# for source in sources:
+# _indexing.pop(source.filename, None)
+# doc_parse_semaphore.release()
+
+# if len(loaded_sources) != len(sources):
+# logger.debug('Some sources were not loaded', extra={
+# 'Count of loaded sources': f'{len(loaded_sources)}/{len(sources)}',
+# 'source_ids': loaded_sources,
+# })
+
+# # loaded sources include the existing sources that may only have their access updated
+# return JSONResponse({'loaded_sources': loaded_sources, 'sources_to_retry': not_added_sources})
class Query(BaseModel):
@@ -496,15 +510,3 @@ def _(query: Query) -> list[SearchResult]:
query.scopeType,
query.scopeList,
))
-
-
-@app.get('/downloadLogs')
-def download_logs() -> FileResponse:
- with tempfile.NamedTemporaryFile('wb', delete=False) as tmp:
- with zipfile.ZipFile(tmp, mode='w', compression=zipfile.ZIP_DEFLATED) as zip_file:
- files = os.listdir(os.path.join(persistent_storage(), 'logs'))
- for file in files:
- file_path = os.path.join(persistent_storage(), 'logs', file)
- if os.path.isfile(file_path): # Might be a folder (just skip it then)
- zip_file.write(file_path)
- return FileResponse(tmp.name, media_type='application/zip', filename='docker_logs.zip')
diff --git a/context_chat_backend/chain/ingest/mimetype_list.py b/context_chat_backend/mimetype_list.py
similarity index 64%
rename from context_chat_backend/chain/ingest/mimetype_list.py
rename to context_chat_backend/mimetype_list.py
index 87f10241..ce21e6ea 100644
--- a/context_chat_backend/chain/ingest/mimetype_list.py
+++ b/context_chat_backend/mimetype_list.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors
# SPDX-License-Identifier: AGPL-3.0-or-later
#
-SUPPORTED_MIMETYPES = [
+SUPPORTED_MIMETYPES = (
'text/plain',
'text/markdown',
'application/json',
@@ -22,4 +22,33 @@
'message/rfc822',
'application/vnd.ms-outlook',
'text/org',
-]
+)
+
+IMAGE_MIMETYPES = (
+ 'image/bmp',
+ 'image/bpg',
+ 'image/emf',
+ 'image/gif',
+ 'image/heic',
+ 'image/heif',
+ 'image/jp2',
+ 'image/jpeg',
+ 'image/png',
+ 'image/svg+xml',
+ 'image/tga',
+ 'image/tiff',
+ 'image/webp',
+ 'image/x-dcraw',
+ 'image/x-icon',
+)
+
+AUDIO_MIMETYPES = (
+ 'audio/aac',
+ 'audio/flac',
+ 'audio/mp4',
+ 'audio/mpeg',
+ 'audio/ogg',
+ 'audio/wav',
+ 'audio/webm',
+ 'audio/x-scpls',
+)
diff --git a/context_chat_backend/network_em.py b/context_chat_backend/network_em.py
index 18bb11f4..d39ea56a 100644
--- a/context_chat_backend/network_em.py
+++ b/context_chat_backend/network_em.py
@@ -79,6 +79,7 @@ def _get_embedding(self, input_: str | list[str], try_: int = 3) -> list[float]
raise FatalEmbeddingException(response.text)
if response.status_code // 100 != 2:
raise EmbeddingException(response.text)
+ # todo: rework exception handling and their downstream interpretation
except FatalEmbeddingException as e:
logger.error('Fatal error while getting embeddings: %s', str(e), exc_info=e)
raise e
@@ -108,10 +109,14 @@ def _get_embedding(self, input_: str | list[str], try_: int = 3) -> list[float]
logger.error('Unexpected error while getting embeddings', exc_info=e)
raise EmbeddingException('Error: unexpected error while getting embeddings') from e
- # converts TypedDict to a pydantic model
- resp = CreateEmbeddingResponse(**response.json())
- if isinstance(input_, str):
- return resp['data'][0]['embedding']
+ try:
+ # converts TypedDict to a pydantic model
+ resp = CreateEmbeddingResponse(**response.json())
+ if isinstance(input_, str):
+ return resp['data'][0]['embedding']
+ except Exception as e:
+ logger.error('Error parsing embedding response', exc_info=e)
+ raise EmbeddingException('Error: failed to parse embedding response') from e
# only one embedding in d['embedding'] since truncate is True
return [d['embedding'] for d in resp['data']] # pyright: ignore[reportReturnType]
diff --git a/context_chat_backend/task_fetcher.py b/context_chat_backend/task_fetcher.py
new file mode 100644
index 00000000..919cfccd
--- /dev/null
+++ b/context_chat_backend/task_fetcher.py
@@ -0,0 +1,569 @@
+#
+# SPDX-FileCopyrightText: 2026 Nextcloud GmbH and Nextcloud contributors
+# SPDX-License-Identifier: AGPL-3.0-or-later
+#
+
+import asyncio
+import logging
+import os
+from contextlib import suppress
+from enum import Enum
+from io import BytesIO
+from threading import Event, Thread
+from time import sleep
+
+import niquests
+from nc_py_api import AsyncNextcloudApp, NextcloudApp
+from pydantic import ValidationError
+
+from .chain.ingest.injest import embed_sources
+from .dyn_loader import VectorDBLoader
+from .types import (
+ ActionsQueueItems,
+ ActionType,
+ AppRole,
+ EmbeddingException,
+ FilesQueueItems,
+ IndexingError,
+ IndexingException,
+ LoaderException,
+ ReceivedFileItem,
+ SourceItem,
+ TConfig,
+)
+from .utils import exec_in_proc, get_app_role
+from .vectordb.service import (
+ decl_update_access,
+ delete_by_provider,
+ delete_by_source,
+ delete_user,
+ update_access,
+ update_access_provider,
+)
+from .vectordb.types import DbException, SafeDbException
+
+APP_ROLE = get_app_role()
+THREADS = {}
+THREAD_STOP_EVENT = Event()
+LOGGER = logging.getLogger('ccb.task_fetcher')
+FILES_INDEXING_BATCH_SIZE = 64 # todo: config?
+# divides the batch into these many chunks
+PARALLEL_FILE_PARSING = max(1, (os.cpu_count() or 2) - 1) # todo: config?
+# max concurrent fetches to avoid overloading the NC server or hitting rate limits
+CONCURRENT_FILE_FETCHES = 10 # todo: config?
+MAX_FILE_SIZE = 100 * 1024 * 1024 # 100 MB, todo: config?
+ACTIONS_BATCH_SIZE = 512 # todo: config?
+POLLING_COOLDOWN = 30
+
+
+class ThreadType(Enum):
+ FILES_INDEXING = 'files_indexing'
+ UPDATES_PROCESSING = 'updates_processing'
+ REQUEST_PROCESSING = 'request_processing'
+
+
+async def __fetch_file_content(
+    semaphore: asyncio.Semaphore,
+    file_id: int,
+    user_id: str,
+    _rlimit = 3,
+) -> BytesIO:
+    '''
+    Raises: IndexingException (on fetch failure or when rate-limit retries are exhausted)
+    '''
+    async with semaphore:
+        nc = AsyncNextcloudApp()
+        # retry in a loop, not by recursion: a recursive call re-acquires the semaphore while this task still holds a permit, which deadlocks once every permit holder is rate-limited
+        while True:
+            try:
+                # a file pointer for storing the stream in memory until it is consumed
+                fp = BytesIO()
+                await nc._session.download2fp(
+                    url_path=f'/ocs/v2.php/apps/context_chat/files/{file_id}',
+                    fp=fp,
+                    dav=False,
+                    params={ 'userId': user_id },
+                )
+                return fp
+            except niquests.exceptions.RequestException as e:
+                # todo: raise IndexingException with retryable=True for rate limit errors,
+                # todo: and handle it in the caller to not delete the source from the queue and retry later through
+                # todo: the normal lock expiry mechanism
+                if e.response is None:
+                    raise
+
+                if e.response.status_code == niquests.codes.too_many_requests: # pyright: ignore[reportAttributeAccessIssue]
+                    # todo: implement rate limits in php CC?
+                    wait_for = int(e.response.headers.get('Retry-After', '30'))
+                    if _rlimit <= 0:
+                        raise IndexingException(
+                            f'Rate limited when fetching content for file id {file_id}, user id {user_id},'
+                            ' max retries exceeded',
+                            retryable=True,
+                        ) from e
+                    LOGGER.warning(
+                        f'Rate limited when fetching content for file id {file_id}, user id {user_id},'
+                        f' waiting {wait_for} before retrying',
+                        exc_info=e,
+                    )
+                    await asyncio.sleep(wait_for)
+                    _rlimit -= 1
+                    continue
+
+                raise
+            except IndexingException:
+                raise
+            except Exception as e:
+                LOGGER.error(f'Error fetching content for file id {file_id}, user id {user_id}: {e}', exc_info=e)
+                raise IndexingException(f'Error fetching content for file id {file_id}, user id {user_id}: {e}') from e
+
+
+async def __fetch_files_content(
+    files: dict[int, ReceivedFileItem]
+) -> dict[int, SourceItem | IndexingError]:
+    '''Fetch file contents concurrently; per queue db id, yields a SourceItem or an IndexingError.'''
+    source_items = {}
+    semaphore = asyncio.Semaphore(CONCURRENT_FILE_FETCHES)
+    # keyed by db id so gather results can be matched back (skipped files get no task)
+    tasks = {}
+
+    for db_id, file in files.items():
+        try:
+            # to detect any validation errors but it should not happen since file.reference is validated
+            file.file_id  # noqa: B018
+        except ValueError as e:
+            LOGGER.error(
+                f'Invalid file reference format for db id {db_id}, file reference {file.reference}: {e}',
+                exc_info=e,
+            )
+            source_items[db_id] = IndexingError(
+                error=f'Invalid file reference format: {file.reference}',
+                retryable=False,
+            )
+            continue
+
+        if file.size > MAX_FILE_SIZE:
+            LOGGER.info(
+                f'Skipping db id {db_id}, file id {file.file_id}, source id {file.reference} due to size'
+                f' {(file.size/(1024*1024)):.2f} MiB exceeding the limit {(MAX_FILE_SIZE/(1024*1024)):.2f} MiB',
+            )
+            source_items[db_id] = IndexingError(
+                error=(
+                    f'File size {(file.size/(1024*1024)):.2f} MiB'
+                    f' exceeds the limit {(MAX_FILE_SIZE/(1024*1024)):.2f} MiB'
+                ),
+                retryable=False,
+            )
+            continue
+        # todo: perform the existing file check before fetching the content to avoid unnecessary fetches
+        # any user id from the list should have read access to the file
+        tasks[db_id] = asyncio.ensure_future(__fetch_file_content(semaphore, file.file_id, file.userIds[0]))
+
+    results = await asyncio.gather(*tasks.values(), return_exceptions=True)
+    # iterate only over files that had a fetch task; a strict zip over files.items() would raise whenever any file above was skipped
+    for db_id, result in zip(tasks.keys(), results, strict=True):
+        file = files[db_id]
+        if isinstance(result, IndexingException):
+            LOGGER.error(
+                f'Error fetching content for db id {db_id}, file id {file.file_id}, reference {file.reference}'
+                f': {result}',
+                exc_info=result,
+            )
+            source_items[db_id] = IndexingError(
+                error=str(result),
+                retryable=result.retryable,
+            )
+        elif isinstance(result, BytesIO):
+            source_items[db_id] = SourceItem(
+                **{
+                    **file.model_dump(),
+                    'content': result,
+                }
+            )
+        elif isinstance(result, BaseException):
+            LOGGER.error(
+                f'Unexpected error fetching content for db id {db_id}, file id {file.file_id},'
+                f' reference {file.reference}: {result}',
+                exc_info=result,
+            )
+            source_items[db_id] = IndexingError(
+                error=f'Unexpected error: {result}',
+                retryable=True,
+            )
+        else:
+            # unreachable with __fetch_file_content's contract; kept as a defensive fallback
+            source_items[db_id] = IndexingError(
+                error='Unknown error',
+                retryable=True,
+            )
+    return source_items
+
+
+def files_indexing_thread(app_config: TConfig, app_enabled: Event) -> None:
+ try:
+ vectordb_loader = VectorDBLoader(app_config)
+ except LoaderException as e:
+ LOGGER.error('Error initializing vector DB loader, files indexing thread will not start:', exc_info=e)
+ return
+
+ def _load_sources(source_items: dict[int, SourceItem]) -> dict[int, IndexingError | None]:
+ try:
+ return exec_in_proc(
+ target=embed_sources,
+ args=(vectordb_loader, app_config, source_items),
+ )
+ except (DbException, EmbeddingException):
+ raise
+ except Exception as e:
+ raise DbException('Error: failed to load sources') from e
+
+
+ while True:
+ if THREAD_STOP_EVENT.is_set():
+ LOGGER.info('Files indexing thread is stopping due to stop event being set')
+ return
+
+ try:
+ nc = NextcloudApp()
+ # todo: add the 'size' param to the return of this call.
+ q_items_res = nc.ocs(
+ 'GET',
+ '/ocs/v2.php/apps/context_chat/queues/documents',
+ params={ 'n': FILES_INDEXING_BATCH_SIZE }
+ )
+
+ try:
+ q_items: FilesQueueItems = FilesQueueItems.model_validate(q_items_res)
+ except ValidationError as e:
+ raise Exception(f'Error validating queue items response: {e}\nResponse content: {q_items_res}') from e
+
+ if not q_items.files and not q_items.content_providers:
+ LOGGER.debug('No documents to index')
+ sleep(POLLING_COOLDOWN)
+ continue
+
+ LOGGER.debug(f'Fetched {len(q_items.files)} files and {len(q_items.content_providers)} content providers')
+ # populate files content and convert to source items
+ fetched_files = {}
+ source_files = {}
+ # unified error structure for files and content providers
+ source_errors = {}
+
+ if q_items.files:
+ fetched_files = asyncio.run(__fetch_files_content(q_items.files))
+
+ for db_id, result in fetched_files.items():
+ if isinstance(result, SourceItem):
+ source_files[db_id] = result
+ else:
+ source_errors[db_id] = result
+
+            files_result = {}
+            providers_result = {}
+            # guard against a 0 step when PARALLEL_FILE_PARSING > batch size (range() step must be >= 1)
+            chunk_size = max(1, FILES_INDEXING_BATCH_SIZE // PARALLEL_FILE_PARSING)
+
+            # todo: do it in asyncio, it's not truly parallel yet
+            # chunk file parsing for better file operation parallelism
+            for i in range(0, len(source_files), chunk_size):
+                chunk = dict(list(source_files.items())[i:i+chunk_size])
+                files_result.update(_load_sources(chunk))
+            for i in range(0, len(q_items.content_providers), chunk_size):
+                chunk = dict(list(q_items.content_providers.items())[i:i+chunk_size])
+                providers_result.update(_load_sources(chunk))
+
+ if (
+ any(isinstance(res, IndexingError) for res in files_result.values())
+ or any(isinstance(res, IndexingError) for res in providers_result.values())
+ ):
+ LOGGER.error('Some sources failed to index', extra={
+ 'file_errors': {
+ db_id: error
+ for db_id, error in files_result.items()
+ if isinstance(error, IndexingError)
+ },
+ 'provider_errors': {
+ provider_id: error
+ for provider_id, error in providers_result.items()
+ if isinstance(error, IndexingError)
+ },
+ })
+ except (
+ niquests.exceptions.ConnectionError,
+ niquests.exceptions.Timeout,
+ ) as e:
+ LOGGER.info('Temporary error fetching documents to index, will retry:', exc_info=e)
+ sleep(5)
+ continue
+ except Exception as e:
+ LOGGER.exception('Error fetching documents to index:', exc_info=e)
+ sleep(5)
+ continue
+
+            # delete queue entries that succeeded or failed non-retryably (incl. fetch errors in source_errors, else they are re-fetched forever)
+            to_delete_files_db_ids = [
+                db_id for db_id, result in {**source_errors, **files_result}.items()
+                if result is None or (isinstance(result, IndexingError) and not result.retryable)
+            ]
+            to_delete_provider_db_ids = [
+                db_id for db_id, result in providers_result.items()
+                if result is None or (isinstance(result, IndexingError) and not result.retryable)
+            ]
+
+ try:
+ nc.ocs(
+ 'DELETE',
+ '/ocs/v2.php/apps/context_chat/queues/documents/',
+ json={
+ 'files': to_delete_files_db_ids,
+ 'content_providers': to_delete_provider_db_ids,
+ },
+ )
+ except (
+ niquests.exceptions.ConnectionError,
+ niquests.exceptions.Timeout,
+ ) as e:
+ LOGGER.info('Temporary error reporting indexing results, will retry:', exc_info=e)
+ sleep(5)
+ with suppress(Exception):
+ nc = NextcloudApp()
+ nc.ocs(
+ 'DELETE',
+ '/ocs/v2.php/apps/context_chat/queues/documents/',
+ json={
+ 'files': to_delete_files_db_ids,
+ 'content_providers': to_delete_provider_db_ids,
+ },
+ )
+ continue
+ except Exception as e:
+ LOGGER.exception('Error reporting indexing results:', exc_info=e)
+ sleep(5)
+ continue
+
+
+
+def updates_processing_thread(app_config: TConfig, app_enabled: Event) -> None:
+ try:
+ vectordb_loader = VectorDBLoader(app_config)
+ except LoaderException as e:
+ LOGGER.error('Error initializing vector DB loader, files indexing thread will not start:', exc_info=e)
+ return
+
+ while True:
+ if THREAD_STOP_EVENT.is_set():
+ LOGGER.info('Updates processing thread is stopping due to stop event being set')
+ return
+
+ try:
+ nc = NextcloudApp()
+ q_items_res = nc.ocs(
+ 'GET',
+ '/ocs/v2.php/apps/context_chat/queues/actions',
+ params={ 'n': ACTIONS_BATCH_SIZE }
+ )
+
+ try:
+ q_items: ActionsQueueItems = ActionsQueueItems.model_validate(q_items_res)
+ except ValidationError as e:
+ raise Exception(f'Error validating queue items response: {e}\nResponse content: {q_items_res}') from e
+ except (
+ niquests.exceptions.ConnectionError,
+ niquests.exceptions.Timeout,
+ ) as e:
+ LOGGER.info('Temporary error fetching updates to process, will retry:', exc_info=e)
+ sleep(5)
+ continue
+ except Exception as e:
+ LOGGER.exception('Error fetching updates to process:', exc_info=e)
+ sleep(5)
+ continue
+
+ if not q_items.actions:
+ LOGGER.debug('No updates to process')
+ sleep(POLLING_COOLDOWN)
+ continue
+
+ LOGGER.debug(f'Fetched {len(q_items.actions)} updates')
+ processed_event_ids = []
+ errored_events = {}
+ for i, (db_id, action_item) in enumerate(q_items.actions.items()):
+ try:
+ match action_item.type:
+ case ActionType.DELETE_SOURCE_IDS:
+ exec_in_proc(target=delete_by_source, args=(vectordb_loader, action_item.payload.sourceIds))
+
+ case ActionType.DELETE_PROVIDER_ID:
+ exec_in_proc(target=delete_by_provider, args=(vectordb_loader, action_item.payload.providerId))
+
+ case ActionType.DELETE_USER_ID:
+ exec_in_proc(target=delete_user, args=(vectordb_loader, action_item.payload.userId))
+
+ case ActionType.UPDATE_ACCESS_SOURCE_ID:
+ exec_in_proc(
+ target=update_access,
+ args=(
+ vectordb_loader,
+ action_item.payload.op,
+ action_item.payload.userIds,
+ action_item.payload.sourceId,
+ ),
+ )
+
+ case ActionType.UPDATE_ACCESS_PROVIDER_ID:
+ exec_in_proc(
+ target=update_access_provider,
+ args=(
+ vectordb_loader,
+ action_item.payload.op,
+ action_item.payload.userIds,
+ action_item.payload.providerId,
+ ),
+ )
+
+ case ActionType.UPDATE_ACCESS_DECL_SOURCE_ID:
+ exec_in_proc(
+ target=decl_update_access,
+ args=(
+ vectordb_loader,
+ action_item.payload.userIds,
+ action_item.payload.sourceId,
+ ),
+ )
+
+ case _:
+ LOGGER.warning(
+ f'Unknown action type {action_item.type} for action id {db_id},'
+ f' type {action_item.type}, skipping and marking as processed',
+ extra={ 'action_item': action_item },
+ )
+ continue
+
+ processed_event_ids.append(db_id)
+ except SafeDbException as e:
+ LOGGER.debug(
+ f'Safe DB error thrown while processing action id {db_id}, type {action_item.type},'
+ " it's safe to ignore and mark as processed.",
+ exc_info=e,
+ extra={ 'action_item': action_item },
+ )
+ processed_event_ids.append(db_id)
+ continue
+
+ except (LoaderException, DbException) as e:
+ LOGGER.error(
+ f'Error deleting source for action id {db_id}, type {action_item.type}: {e}',
+ exc_info=e,
+ extra={ 'action_item': action_item },
+ )
+ errored_events[db_id] = str(e)
+ continue
+
+ except Exception as e:
+ LOGGER.error(
+ f'Unexpected error processing action id {db_id}, type {action_item.type}: {e}',
+ exc_info=e,
+ extra={ 'action_item': action_item },
+ )
+ errored_events[db_id] = f'Unexpected error: {e}'
+ continue
+
+ if (i + 1) % 20 == 0:
+ LOGGER.debug(f'Processed {i + 1} updates, sleeping for a bit to allow other operations to proceed')
+ sleep(2)
+
+ LOGGER.info(f'Processed {len(processed_event_ids)} updates with {len(errored_events)} errors', extra={
+ 'errored_events': errored_events,
+ })
+
+ if len(processed_event_ids) == 0:
+ LOGGER.debug('No updates processed, skipping reporting to the server')
+ continue
+
+        try:
+            nc.ocs(
+                'DELETE',
+                '/ocs/v2.php/apps/context_chat/queues/actions/',
+                json={ 'actions': processed_event_ids },
+            )
+        except (
+            niquests.exceptions.ConnectionError,
+            niquests.exceptions.Timeout,
+        ) as e:
+            LOGGER.info('Temporary error reporting processed updates, will retry:', exc_info=e)
+            sleep(5)
+            with suppress(Exception):
+                nc = NextcloudApp()
+                nc.ocs(
+                    'DELETE',
+                    '/ocs/v2.php/apps/context_chat/queues/actions/',
+                    json={ 'actions': processed_event_ids },  # same payload key as the first attempt
+                )
+            continue
+        except Exception as e:
+            LOGGER.exception('Error reporting processed updates:', exc_info=e)
+            sleep(5)
+            continue
+
+
+def request_processing_thread(app_config: TConfig, app_enabled: Event) -> None:
+ ...
+
+
+def start_bg_threads(app_config: TConfig, app_enabled: Event):
+    '''Start the background worker threads appropriate for this app's role (idempotent).'''
+    # plain ifs instead of match/case: a match stops at the first matching case,
+    # so AppRole.NORMAL would never start the request processing thread
+    if APP_ROLE in (AppRole.INDEXING, AppRole.NORMAL):
+        if (
+            ThreadType.FILES_INDEXING in THREADS
+            or ThreadType.UPDATES_PROCESSING in THREADS
+        ):
+            LOGGER.info('Indexing background threads already running, skipping start')
+        else:
+            THREAD_STOP_EVENT.clear()
+            THREADS[ThreadType.FILES_INDEXING] = Thread(
+                target=files_indexing_thread,
+                args=(app_config, app_enabled),
+                name='FilesIndexingThread',
+            )
+            THREADS[ThreadType.UPDATES_PROCESSING] = Thread(
+                target=updates_processing_thread,
+                args=(app_config, app_enabled),
+                name='UpdatesProcessingThread',
+            )
+            THREADS[ThreadType.FILES_INDEXING].start()
+            THREADS[ThreadType.UPDATES_PROCESSING].start()
+
+    if APP_ROLE in (AppRole.RP, AppRole.NORMAL):
+        if ThreadType.REQUEST_PROCESSING in THREADS:
+            LOGGER.info('Request processing thread already running, skipping start')
+        else:
+            THREAD_STOP_EVENT.clear()
+            THREADS[ThreadType.REQUEST_PROCESSING] = Thread(
+                target=request_processing_thread,
+                args=(app_config, app_enabled),
+                name='RequestProcessingThread',
+            )
+            THREADS[ThreadType.REQUEST_PROCESSING].start()
+
+
+def wait_for_bg_threads():
+    '''Signal the stop event and join every background thread owned by this role.'''
+    # plain ifs (not match/case) so AppRole.NORMAL joins both thread groups
+    THREAD_STOP_EVENT.set()
+    if APP_ROLE in (AppRole.INDEXING, AppRole.NORMAL):
+        if (
+            ThreadType.FILES_INDEXING in THREADS
+            and ThreadType.UPDATES_PROCESSING in THREADS
+        ):
+            THREADS[ThreadType.FILES_INDEXING].join()
+            THREADS[ThreadType.UPDATES_PROCESSING].join()
+            THREADS.pop(ThreadType.FILES_INDEXING)
+            THREADS.pop(ThreadType.UPDATES_PROCESSING)
+
+    if APP_ROLE in (AppRole.RP, AppRole.NORMAL):
+        if ThreadType.REQUEST_PROCESSING in THREADS:
+            # threads check THREAD_STOP_EVENT between batches, so join() eventually returns
+            THREADS[ThreadType.REQUEST_PROCESSING].join()
+            THREADS.pop(ThreadType.REQUEST_PROCESSING)
diff --git a/context_chat_backend/types.py b/context_chat_backend/types.py
index 500a97d0..fa710b17 100644
--- a/context_chat_backend/types.py
+++ b/context_chat_backend/types.py
@@ -2,7 +2,14 @@
# SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors
# SPDX-License-Identifier: AGPL-3.0-or-later
#
-from pydantic import BaseModel
+import re
+from enum import Enum
+from io import BytesIO
+from typing import Annotated, Literal
+
+from pydantic import AfterValidator, BaseModel, Discriminator, computed_field, field_validator
+
+from .vectordb.types import UpdateAccessOp
__all__ = [
'DEFAULT_EM_MODEL_ALIAS',
@@ -15,6 +22,66 @@
]
DEFAULT_EM_MODEL_ALIAS = 'em_model'
+FILES_PROVIDER_ID = 'files__default'
+
+
+def is_valid_source_id(source_id: str) -> bool:
+ # note the ":" in the item id part
+ return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+: [a-zA-Z0-9:-]+$', source_id) is not None
+
+
+def is_valid_provider_id(provider_id: str) -> bool:
+ return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+$', provider_id) is not None
+
+
+def _validate_source_ids(source_ids: list[str]) -> list[str]:
+ # todo: use is_valid_source_id()
+ if (
+ not isinstance(source_ids, list)
+ or not all(isinstance(sid, str) and sid.strip() != '' for sid in source_ids)
+ or len(source_ids) == 0
+ ):
+ raise ValueError('sourceIds must be a non-empty list of non-empty strings')
+ return [sid.strip() for sid in source_ids]
+
+
+def _validate_source_id(source_id: str) -> str:
+ return _validate_source_ids([source_id])[0]
+
+
+def _validate_provider_id(provider_id: str) -> str:
+ if not isinstance(provider_id, str) or not is_valid_provider_id(provider_id):
+ raise ValueError('providerId must be a valid provider ID string')
+ return provider_id
+
+
+def _validate_user_ids(user_ids: list[str]) -> list[str]:
+ if (
+ not isinstance(user_ids, list)
+ or not all(isinstance(uid, str) and uid.strip() != '' for uid in user_ids)
+ or len(user_ids) == 0
+ ):
+ raise ValueError('userIds must be a non-empty list of non-empty strings')
+ return [uid.strip() for uid in user_ids]
+
+
+def _validate_user_id(user_id: str) -> str:
+ return _validate_user_ids([user_id])[0]
+
+
+def _get_file_id_from_source_ref(source_ref: str) -> int:
+ '''
+    source reference is in the format "FILES_PROVIDER_ID: <file_id>", e.g. "files__default: 123".
+ '''
+ if not source_ref.startswith(f'{FILES_PROVIDER_ID}: '):
+ raise ValueError(f'Source reference does not start with expected prefix: {source_ref}')
+
+ try:
+ return int(source_ref[len(f'{FILES_PROVIDER_ID}: '):])
+ except ValueError as e:
+ raise ValueError(
+ f'Invalid source reference format for extracting file_id: {source_ref}'
+ ) from e
class TEmbeddingAuthApiKey(BaseModel):
@@ -71,3 +138,216 @@ class FatalEmbeddingException(EmbeddingException):
Either malformed request, authentication error, or other non-retryable error.
"""
+
+
+class AppRole(str, Enum):
+ NORMAL = 'normal'
+ INDEXING = 'indexing'
+ RP = 'rp'
+
+
+class CommonSourceItem(BaseModel):
+ userIds: Annotated[list[str], AfterValidator(_validate_user_ids)]
+ # source_id of the form "appId__providerId: itemId"
+ reference: Annotated[str, AfterValidator(_validate_source_id)]
+ title: str
+ modified: int
+ type: str
+ provider: Annotated[str, AfterValidator(_validate_provider_id)]
+ size: float
+
+ @field_validator('modified', mode='before')
+ @classmethod
+ def validate_modified(cls, v):
+ if isinstance(v, int):
+ return v
+ if isinstance(v, str):
+ try:
+ return int(v)
+ except ValueError as e:
+ raise ValueError(f'Invalid modified value: {v}') from e
+ raise ValueError(f'Invalid modified type: {type(v)}')
+
+ @field_validator('reference', 'title', 'type', 'provider')
+ @classmethod
+ def validate_strings_non_empty(cls, v):
+ if not isinstance(v, str) or v.strip() == '':
+ raise ValueError('Must be a non-empty string')
+ return v.strip()
+
+ @field_validator('size')
+ @classmethod
+ def validate_size(cls, v):
+ if isinstance(v, int | float) and v >= 0:
+ return float(v)
+ raise ValueError(f'Invalid size value: {v}, must be a non-negative number')
+
+
+class ReceivedFileItem(CommonSourceItem):
+ content: None
+
+ @computed_field
+ @property
+ def file_id(self) -> int:
+ return _get_file_id_from_source_ref(self.reference)
+
+
+class SourceItem(CommonSourceItem):
+ '''
+ Used for the unified queue of items to process, after fetching the content for files
+ and for directly fetched content providers.
+ '''
+ content: str | BytesIO
+
+    @computed_field  # NOTE(review): computed fields are serialized on model dump; file_id raises for non-files refs, but SourceItem also covers content providers — confirm it is never dumped for those
+ @property
+ def file_id(self) -> int:
+ return _get_file_id_from_source_ref(self.reference)
+
+ @field_validator('content')
+ @classmethod
+ def validate_content(cls, v):
+ if isinstance(v, str):
+ if v.strip() == '':
+ raise ValueError('Content must be a non-empty string')
+ return v.strip()
+ if isinstance(v, BytesIO):
+ if v.getbuffer().nbytes == 0:
+ raise ValueError('Content must be a non-empty BytesIO')
+ return v
+ raise ValueError('Content must be either a non-empty string or a non-empty BytesIO')
+
+ class Config:
+ # to allow BytesIO in content field
+ arbitrary_types_allowed = True
+
+
+class FilesQueueItems(BaseModel):
+    files: dict[int, ReceivedFileItem] # [db id]: ReceivedFileItem
+ content_providers: dict[int, SourceItem] # [db id]: SourceItem
+
+
+class IndexingException(Exception):
+ retryable: bool = False
+
+ def __init__(self, message: str, retryable: bool = False):
+ super().__init__(message)
+ self.retryable = retryable
+
+
+class IndexingError(BaseModel):
+ error: str
+ retryable: bool = False
+
+
+# PHP equivalent for reference:
+
+# class ActionType {
+# // { sourceIds: array }
+# public const DELETE_SOURCE_IDS = 'delete_source_ids';
+# // { providerId: string }
+# public const DELETE_PROVIDER_ID = 'delete_provider_id';
+# // { userId: string }
+# public const DELETE_USER_ID = 'delete_user_id';
+# // { op: string, userIds: array, sourceId: string }
+# public const UPDATE_ACCESS_SOURCE_ID = 'update_access_source_id';
+# // { op: string, userIds: array, providerId: string }
+# public const UPDATE_ACCESS_PROVIDER_ID = 'update_access_provider_id';
+# // { userIds: array, sourceId: string }
+# public const UPDATE_ACCESS_DECL_SOURCE_ID = 'update_access_decl_source_id';
+# }
+
+
+class ActionPayloadDeleteSourceIds(BaseModel):
+ sourceIds: Annotated[list[str], AfterValidator(_validate_source_ids)]
+
+
+class ActionPayloadDeleteProviderId(BaseModel):
+ providerId: Annotated[str, AfterValidator(_validate_provider_id)]
+
+
+class ActionPayloadDeleteUserId(BaseModel):
+ userId: Annotated[str, AfterValidator(_validate_user_id)]
+
+
+class ActionPayloadUpdateAccessSourceId(BaseModel):
+ op: UpdateAccessOp
+ userIds: Annotated[list[str], AfterValidator(_validate_user_ids)]
+ sourceId: Annotated[str, AfterValidator(_validate_source_id)]
+
+
+class ActionPayloadUpdateAccessProviderId(BaseModel):
+ op: UpdateAccessOp
+ userIds: Annotated[list[str], AfterValidator(_validate_user_ids)]
+ providerId: Annotated[str, AfterValidator(_validate_provider_id)]
+
+
+class ActionPayloadUpdateAccessDeclSourceId(BaseModel):
+ userIds: Annotated[list[str], AfterValidator(_validate_user_ids)]
+ sourceId: Annotated[str, AfterValidator(_validate_source_id)]
+
+
+class ActionType(str, Enum):
+ DELETE_SOURCE_IDS = 'delete_source_ids'
+ DELETE_PROVIDER_ID = 'delete_provider_id'
+ DELETE_USER_ID = 'delete_user_id'
+ UPDATE_ACCESS_SOURCE_ID = 'update_access_source_id'
+ UPDATE_ACCESS_PROVIDER_ID = 'update_access_provider_id'
+ UPDATE_ACCESS_DECL_SOURCE_ID = 'update_access_decl_source_id'
+
+
+class CommonActionsQueueItem(BaseModel):
+ id: int
+
+
+class ActionsQueueItemDeleteSourceIds(CommonActionsQueueItem):
+ type: Literal[ActionType.DELETE_SOURCE_IDS]
+ payload: ActionPayloadDeleteSourceIds
+
+
+class ActionsQueueItemDeleteProviderId(CommonActionsQueueItem):
+ type: Literal[ActionType.DELETE_PROVIDER_ID]
+ payload: ActionPayloadDeleteProviderId
+
+
+class ActionsQueueItemDeleteUserId(CommonActionsQueueItem):
+ type: Literal[ActionType.DELETE_USER_ID]
+ payload: ActionPayloadDeleteUserId
+
+
+class ActionsQueueItemUpdateAccessSourceId(CommonActionsQueueItem):
+ type: Literal[ActionType.UPDATE_ACCESS_SOURCE_ID]
+ payload: ActionPayloadUpdateAccessSourceId
+
+
+class ActionsQueueItemUpdateAccessProviderId(CommonActionsQueueItem):
+ type: Literal[ActionType.UPDATE_ACCESS_PROVIDER_ID]
+ payload: ActionPayloadUpdateAccessProviderId
+
+
+class ActionsQueueItemUpdateAccessDeclSourceId(CommonActionsQueueItem):
+ type: Literal[ActionType.UPDATE_ACCESS_DECL_SOURCE_ID]
+ payload: ActionPayloadUpdateAccessDeclSourceId
+
+
+ActionsQueueItem = Annotated[
+ ActionsQueueItemDeleteSourceIds
+ | ActionsQueueItemDeleteProviderId
+ | ActionsQueueItemDeleteUserId
+ | ActionsQueueItemUpdateAccessSourceId
+ | ActionsQueueItemUpdateAccessProviderId
+ | ActionsQueueItemUpdateAccessDeclSourceId,
+ Discriminator('type'),
+]
+
+
+class ActionsQueueItems(BaseModel):
+ actions: dict[int, ActionsQueueItem]
+
+
+class TaskProcException(Exception):
+ ...
+
+
+class TaskProcFatalException(TaskProcException):
+ ...
diff --git a/context_chat_backend/utils.py b/context_chat_backend/utils.py
index f6d6e672..d3a7bfd1 100644
--- a/context_chat_backend/utils.py
+++ b/context_chat_backend/utils.py
@@ -4,17 +4,17 @@
#
import logging
import multiprocessing as mp
-import re
+import os
import traceback
from collections.abc import Callable
from functools import partial, wraps
from multiprocessing.connection import Connection
-from time import perf_counter_ns
+from time import perf_counter_ns, time
from typing import Any, TypeGuard, TypeVar
from fastapi.responses import JSONResponse as FastAPIJSONResponse
-from .types import TConfig, TEmbeddingAuthApiKey, TEmbeddingAuthBasic, TEmbeddingConfig
+from .types import AppRole, TConfig, TEmbeddingAuthApiKey, TEmbeddingAuthBasic, TEmbeddingConfig
T = TypeVar('T')
_logger = logging.getLogger('ccb.utils')
@@ -101,15 +101,6 @@ def exec_in_proc(group=None, target=None, name=None, args=(), kwargs={}, *, daem
return result['value']
-def is_valid_source_id(source_id: str) -> bool:
- # note the ":" in the item id part
- return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+: [a-zA-Z0-9:-]+$', source_id) is not None
-
-
-def is_valid_provider_id(provider_id: str) -> bool:
- return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+$', provider_id) is not None
-
-
def timed(func: Callable):
'''
Decorator to time a function
@@ -144,3 +135,30 @@ def redact_config(config: TConfig | TEmbeddingConfig) -> TConfig | TEmbeddingCon
em_conf.auth.password = '***REDACTED***' # noqa: S105
return config_copy
+
+
+def get_app_role() -> AppRole:
+ role = os.getenv('APP_ROLE', '').lower()
+ if role == '':
+ return AppRole.NORMAL
+ if role not in ['indexing', 'rp']:
+ _logger.warning(f'Invalid app role: {role}, defaulting to all roles')
+ return AppRole.NORMAL
+ return AppRole(role)
+
+
+# does not support caching of kwargs for recall
+def timed_cache_async(ttl: int):
+ def decorator(fn: Callable):
+ cached_store: dict[tuple, tuple[float, Any]] = {}
+ @wraps(fn)
+ async def wrapper(*args, **kwargs):
+ if args in cached_store:
+ cached_time, cached_value = cached_store[args]
+ if (time() - cached_time) < ttl:
+ return cached_value
+ new_val = await fn(*args, **kwargs)
+ cached_store[args] = (time(), new_val)
+ return new_val
+ return wrapper
+ return decorator
diff --git a/context_chat_backend/vectordb/base.py b/context_chat_backend/vectordb/base.py
index 0bf10200..ebd54075 100644
--- a/context_chat_backend/vectordb/base.py
+++ b/context_chat_backend/vectordb/base.py
@@ -5,12 +5,12 @@
from abc import ABC, abstractmethod
from typing import Any
-from fastapi import UploadFile
from langchain.schema import Document
from langchain.schema.embeddings import Embeddings
from langchain.schema.vectorstore import VectorStore
from ..chain.types import InDocument, ScopeType
+from ..types import IndexingError, SourceItem
from ..utils import timed
from .types import UpdateAccessOp
@@ -62,7 +62,7 @@ def get_instance(self) -> VectorStore:
'''
@abstractmethod
- def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str],list[str]]:
+ def add_indocuments(self, indocuments: dict[int, InDocument]) -> dict[int, IndexingError | None]:
'''
Adds the given indocuments to the vectordb and updates the docs + access tables.
@@ -79,10 +79,7 @@ def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str],list
@timed
@abstractmethod
- def check_sources(
- self,
- sources: list[UploadFile],
- ) -> tuple[list[str], list[str]]:
+ def check_sources(self, sources: dict[int, SourceItem]) -> tuple[list[str], list[str]]:
'''
Checks the sources in the vectordb if they are already embedded
and are up to date.
diff --git a/context_chat_backend/vectordb/pgvector.py b/context_chat_backend/vectordb/pgvector.py
index f40390fe..8f67c936 100644
--- a/context_chat_backend/vectordb/pgvector.py
+++ b/context_chat_backend/vectordb/pgvector.py
@@ -10,14 +10,13 @@
import sqlalchemy.dialects.postgresql as postgresql_dialects
import sqlalchemy.orm as orm
from dotenv import load_dotenv
-from fastapi import UploadFile
from langchain.schema import Document
from langchain.vectorstores import VectorStore
from langchain_core.embeddings import Embeddings
from langchain_postgres.vectorstores import Base, PGVector
from ..chain.types import InDocument, ScopeType
-from ..types import EmbeddingException, RetryableEmbeddingException
+from ..types import EmbeddingException, FatalEmbeddingException, IndexingError, RetryableEmbeddingException, SourceItem
from ..utils import timed
from .base import BaseVectorDB
from .types import DbException, SafeDbException, UpdateAccessOp
@@ -129,17 +128,16 @@ def get_users(self) -> list[str]:
except Exception as e:
raise DbException('Error: getting a list of all users from access list') from e
- def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str], list[str]]:
+ def add_indocuments(self, indocuments: dict[int, InDocument]) -> dict[int, IndexingError | None]:
"""
Raises
EmbeddingException: if the embedding request definitively fails
"""
- added_sources = []
- retry_sources = []
+ results = {}
batch_size = PG_BATCH_SIZE // 5
with self.session_maker() as session:
- for indoc in indocuments:
+ for php_db_id, indoc in indocuments.items():
try:
# query paramerters limitation in postgres is 65535 (https://www.postgresql.org/docs/current/limits.html)
# so we chunk the documents into (5 values * 10k) chunks
@@ -158,7 +156,7 @@ def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str], lis
session.commit()
self.decl_update_access(indoc.userIds, indoc.source_id, session)
- added_sources.append(indoc.source_id)
+ results[php_db_id] = None
session.commit()
except SafeDbException as e:
# for when the source_id is not found. This here can be an error in the DB
@@ -166,51 +164,62 @@ def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str], lis
logger.exception('Error adding documents to vectordb', exc_info=e, extra={
'source_id': indoc.source_id,
})
- retry_sources.append(indoc.source_id)
+ results[php_db_id] = IndexingError(
+ error=str(e),
+ retryable=True,
+ )
continue
- except RetryableEmbeddingException as e:
+ except FatalEmbeddingException as e:
+ raise EmbeddingException(
+ f'Fatal error while embedding documents for source {indoc.source_id}: {e}'
+ ) from e
+ except (RetryableEmbeddingException, EmbeddingException) as e:
# temporary error, continue with the next document
logger.exception('Error adding documents to vectordb, should be retried later.', exc_info=e, extra={
'source_id': indoc.source_id,
})
- retry_sources.append(indoc.source_id)
+ results[php_db_id] = IndexingError(
+ error=str(e),
+ retryable=True,
+ )
continue
- except EmbeddingException as e:
- logger.exception('Error adding documents to vectordb', exc_info=e, extra={
- 'source_id': indoc.source_id,
- })
- raise
except Exception as e:
logger.exception('Error adding documents to vectordb', exc_info=e, extra={
'source_id': indoc.source_id,
})
- retry_sources.append(indoc.source_id)
+ results[php_db_id] = IndexingError(
+ error='An unexpected error occurred while adding documents to the database.',
+ retryable=True,
+ )
continue
- return added_sources, retry_sources
+ return results
@timed
- def check_sources(self, sources: list[UploadFile]) -> tuple[list[str], list[str]]:
+ def check_sources(self, sources: dict[int, SourceItem]) -> tuple[list[str], list[str]]:
+ '''
+ returns a tuple of (existing_source_ids, to_embed_source_ids)
+ '''
with self.session_maker() as session:
try:
stmt = (
sa.select(DocumentsStore.source_id)
- .filter(DocumentsStore.source_id.in_([source.filename for source in sources]))
+ .filter(DocumentsStore.source_id.in_([source.reference for source in sources.values()]))
.with_for_update()
)
results = session.execute(stmt).fetchall()
existing_sources = {r.source_id for r in results}
- to_embed = [source.filename for source in sources if source.filename not in existing_sources]
+ to_embed = [source.reference for source in sources.values() if source.reference not in existing_sources]
to_delete = []
- for source in sources:
+ for source in sources.values():
stmt = (
sa.select(DocumentsStore.source_id)
- .filter(DocumentsStore.source_id == source.filename)
+ .filter(DocumentsStore.source_id == source.reference)
.filter(DocumentsStore.modified < sa.cast(
- datetime.fromtimestamp(int(source.headers['modified'])),
+ datetime.fromtimestamp(int(source.modified)),
sa.DateTime,
))
)
@@ -227,14 +236,13 @@ def check_sources(self, sources: list[UploadFile]) -> tuple[list[str], list[str]
session.rollback()
raise DbException('Error: checking sources in vectordb') from e
- still_existing_sources = [
- source
- for source in existing_sources
- if source not in to_delete
+ still_existing_source_ids = [
+ source_id
+ for source_id in existing_sources
+ if source_id not in to_delete
]
- # the pyright issue stems from source.filename, which has already been validated
- return list(still_existing_sources), to_embed # pyright: ignore[reportReturnType]
+ return list(still_existing_source_ids), to_embed
def decl_update_access(self, user_ids: list[str], source_id: str, session_: orm.Session | None = None):
session = session_ or self.session_maker()
@@ -311,7 +319,7 @@ def update_access(
)
match op:
- case UpdateAccessOp.allow:
+ case UpdateAccessOp.ALLOW:
stmt = (
postgresql_dialects.insert(AccessListStore)
.values([
@@ -326,7 +334,7 @@ def update_access(
session.execute(stmt)
session.commit()
- case UpdateAccessOp.deny:
+ case UpdateAccessOp.DENY:
stmt = (
sa.delete(AccessListStore)
.filter(AccessListStore.uid.in_(user_ids))
@@ -414,11 +422,13 @@ def delete_source_ids(self, source_ids: list[str], session_: orm.Session | None
stmt_doc = (
sa.delete(DocumentsStore)
.filter(DocumentsStore.source_id.in_(source_ids))
- .returning(DocumentsStore.chunks)
+ .returning(DocumentsStore.chunks, DocumentsStore.source_id)
)
doc_result = session.execute(stmt_doc)
- chunks_to_delete = [str(c) for res in doc_result for c in res.chunks]
+ rows = doc_result.fetchall()
+ chunks_to_delete = [str(c) for res in rows for c in res.chunks]
+ deleted_source_ids = [str(res.source_id) for res in rows]
except Exception as e:
session.rollback()
if session_ is None:
@@ -444,6 +453,14 @@ def delete_source_ids(self, source_ids: list[str], session_: orm.Session | None
if session_ is None:
session.close()
+ undeleted_source_ids = set(source_ids) - set(deleted_source_ids)
+ if len(undeleted_source_ids) > 0:
+ logger.info(
+ f'Source ids {undeleted_source_ids} were not deleted from documents store.'
+ ' This can be due to the source ids not existing in the documents store due to'
+ ' already being deleted or not having been added yet.'
+ )
+
def delete_provider(self, provider_key: str):
with self.session_maker() as session:
try:
@@ -457,6 +474,10 @@ def delete_provider(self, provider_key: str):
doc_result = session.execute(stmt)
chunks_to_delete = [str(c) for res in doc_result for c in res.chunks]
+
+ if len(chunks_to_delete) == 0:
+ logger.info(f'No documents found for provider {provider_key} when attempting to delete provider.')
+ return
except Exception as e:
session.rollback()
raise DbException('Error: deleting provider from docs store') from e
@@ -490,7 +511,16 @@ def delete_user(self, user_id: str):
session.rollback()
raise DbException('Error: deleting user from access list') from e
- self._cleanup_if_orphaned(list(source_ids), session)
+ try:
+ self._cleanup_if_orphaned(list(source_ids), session)
+ except Exception as e:
+ session.rollback()
+ logger.error(
+ 'Error cleaning up orphaned source ids after deleting user, manual cleanup might be required',
+ exc_info=e,
+ extra={ 'source_ids': list(source_ids) },
+ )
+ raise DbException('Error: cleaning up orphaned source ids after deleting user') from e
def count_documents_by_provider(self) -> dict[str, int]:
try:
diff --git a/context_chat_backend/vectordb/service.py b/context_chat_backend/vectordb/service.py
index 620a0b39..06a8e19e 100644
--- a/context_chat_backend/vectordb/service.py
+++ b/context_chat_backend/vectordb/service.py
@@ -6,27 +6,42 @@
from ..dyn_loader import VectorDBLoader
from .base import BaseVectorDB
-from .types import DbException, UpdateAccessOp
+from .types import UpdateAccessOp
logger = logging.getLogger('ccb.vectordb')
-# todo: return source ids that were successfully deleted
+
def delete_by_source(vectordb_loader: VectorDBLoader, source_ids: list[str]):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug('deleting sources by id', extra={ 'source_ids': source_ids })
- try:
- db.delete_source_ids(source_ids)
- except Exception as e:
- raise DbException('Error: Vectordb delete_source_ids error') from e
+ db.delete_source_ids(source_ids)
def delete_by_provider(vectordb_loader: VectorDBLoader, provider_key: str):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug(f'deleting sources by provider: {provider_key}')
db.delete_provider(provider_key)
def delete_user(vectordb_loader: VectorDBLoader, user_id: str):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug(f'deleting user from db: {user_id}')
db.delete_user(user_id)
@@ -38,6 +53,13 @@ def update_access(
user_ids: list[str],
source_id: str,
):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ SafeDbException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug('updating access', extra={ 'op': op, 'user_ids': user_ids, 'source_id': source_id })
db.update_access(op, user_ids, source_id)
@@ -49,6 +71,13 @@ def update_access_provider(
user_ids: list[str],
provider_id: str,
):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ SafeDbException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug('updating access by provider', extra={ 'op': op, 'user_ids': user_ids, 'provider_id': provider_id })
db.update_access_provider(op, user_ids, provider_id)
@@ -59,11 +88,24 @@ def decl_update_access(
user_ids: list[str],
source_id: str,
):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ SafeDbException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug('decl update access', extra={ 'user_ids': user_ids, 'source_id': source_id })
db.decl_update_access(user_ids, source_id)
def count_documents_by_provider(vectordb_loader: VectorDBLoader):
+ '''
+ Raises
+ ------
+ DbException
+ LoaderException
+ '''
db: BaseVectorDB = vectordb_loader.load()
logger.debug('counting documents by provider')
return db.count_documents_by_provider()
diff --git a/context_chat_backend/vectordb/types.py b/context_chat_backend/vectordb/types.py
index df5c6dd7..30811797 100644
--- a/context_chat_backend/vectordb/types.py
+++ b/context_chat_backend/vectordb/types.py
@@ -14,5 +14,5 @@ class SafeDbException(Exception):
class UpdateAccessOp(Enum):
- allow = 'allow'
- deny = 'deny'
+ ALLOW = 'allow'
+ DENY = 'deny'