From 8e53243b5ee160be00ac1dcb493e7153cefc8431 Mon Sep 17 00:00:00 2001 From: Anupam Kumar Date: Thu, 5 Mar 2026 14:14:19 +0530 Subject: [PATCH 01/17] feat: add kubernetes app role selection Signed-off-by: Anupam Kumar --- appinfo/info.xml | 14 ++++++++++++++ context_chat_backend/controller.py | 15 ++++++++------- context_chat_backend/task_fetcher.py | 4 ++++ context_chat_backend/types.py | 8 ++++++++ context_chat_backend/utils.py | 13 ++++++++++++- 5 files changed, 46 insertions(+), 8 deletions(-) create mode 100644 context_chat_backend/task_fetcher.py diff --git a/appinfo/info.xml b/appinfo/info.xml index 9760cd29..30194baa 100644 --- a/appinfo/info.xml +++ b/appinfo/info.xml @@ -82,5 +82,19 @@ Setup background job workers as described here: https://docs.nextcloud.com/serve Password to be used for authenticating requests to the OpenAI-compatible endpoint set in CC_EM_BASE_URL. + + + rp + Request Processing Mode + APP_ROLE=rp + true + + + indexing + Indexing Mode + APP_ROLE=indexing + false + + diff --git a/context_chat_backend/controller.py b/context_chat_backend/controller.py index 0588dabe..eddca6ac 100644 --- a/context_chat_backend/controller.py +++ b/context_chat_backend/controller.py @@ -75,6 +75,7 @@ def enabled_handler(enabled: bool, _: NextcloudApp | AsyncNextcloudApp) -> str: if enabled: app_enabled.set() + # todo: start bg threads to fetch docs, updates and requests to process else: app_enabled.clear() @@ -213,6 +214,13 @@ def _(): return JSONResponse(content={'enabled': app_enabled.is_set()}, status_code=200) +@app.post('/countIndexedDocuments') +@enabled_guard(app) +def _(): + counts = exec_in_proc(target=count_documents_by_provider, args=(vectordb_loader,)) + return JSONResponse(counts) + + @app.post('/updateAccessDeclarative') @enabled_guard(app) def _( @@ -328,13 +336,6 @@ def _(userId: str = Body(embed=True)): return JSONResponse('User deleted') -@app.post('/countIndexedDocuments') -@enabled_guard(app) -def _(): - counts = 
exec_in_proc(target=count_documents_by_provider, args=(vectordb_loader,)) - return JSONResponse(counts) - - @app.put('/loadSources') @enabled_guard(app) def _(sources: list[UploadFile]): diff --git a/context_chat_backend/task_fetcher.py b/context_chat_backend/task_fetcher.py new file mode 100644 index 00000000..5e2f317f --- /dev/null +++ b/context_chat_backend/task_fetcher.py @@ -0,0 +1,4 @@ +# +# SPDX-FileCopyrightText: 2026 Nextcloud GmbH and Nextcloud contributors +# SPDX-License-Identifier: AGPL-3.0-or-later +# diff --git a/context_chat_backend/types.py b/context_chat_backend/types.py index 500a97d0..78680866 100644 --- a/context_chat_backend/types.py +++ b/context_chat_backend/types.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors # SPDX-License-Identifier: AGPL-3.0-or-later # +from enum import Enum + from pydantic import BaseModel __all__ = [ @@ -71,3 +73,9 @@ class FatalEmbeddingException(EmbeddingException): Either malformed request, authentication error, or other non-retryable error. 
""" + + +class AppRole(str, Enum): + NORMAL = 'normal' + INDEXING = 'indexing' + RP = 'rp' diff --git a/context_chat_backend/utils.py b/context_chat_backend/utils.py index f6d6e672..224f466e 100644 --- a/context_chat_backend/utils.py +++ b/context_chat_backend/utils.py @@ -4,6 +4,7 @@ # import logging import multiprocessing as mp +import os import re import traceback from collections.abc import Callable @@ -14,7 +15,7 @@ from fastapi.responses import JSONResponse as FastAPIJSONResponse -from .types import TConfig, TEmbeddingAuthApiKey, TEmbeddingAuthBasic, TEmbeddingConfig +from .types import AppRole, TConfig, TEmbeddingAuthApiKey, TEmbeddingAuthBasic, TEmbeddingConfig T = TypeVar('T') _logger = logging.getLogger('ccb.utils') @@ -144,3 +145,13 @@ def redact_config(config: TConfig | TEmbeddingConfig) -> TConfig | TEmbeddingCon em_conf.auth.password = '***REDACTED***' # noqa: S105 return config_copy + + +def get_app_role() -> AppRole: + role = os.getenv('APP_ROLE', '').lower() + if role == '': + return AppRole.NORMAL + if role not in ['indexing', 'rp']: + _logger.warning(f'Invalid app role: {role}, defaulting to all roles') + return AppRole.NORMAL + return AppRole(role) From c282f3d401150fc4983dda37c78661fbc68e6ebd Mon Sep 17 00:00:00 2001 From: Anupam Kumar Date: Thu, 5 Mar 2026 16:42:41 +0530 Subject: [PATCH 02/17] feat: add thread start and stop logic Signed-off-by: Anupam Kumar --- context_chat_backend/controller.py | 17 ++++-- context_chat_backend/task_fetcher.py | 82 ++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 5 deletions(-) diff --git a/context_chat_backend/controller.py b/context_chat_backend/controller.py index eddca6ac..4c07a06e 100644 --- a/context_chat_backend/controller.py +++ b/context_chat_backend/controller.py @@ -42,6 +42,7 @@ from .models.types import LlmException from nc_py_api.ex_app import AppAPIAuthMiddleware from .utils import JSONResponse, exec_in_proc, is_valid_provider_id, is_valid_source_id, value_of +from .task_fetcher 
import start_bg_threads, stop_bg_threads from .vectordb.service import ( count_documents_by_provider, decl_update_access, @@ -73,11 +74,16 @@ app_enabled = Event() def enabled_handler(enabled: bool, _: NextcloudApp | AsyncNextcloudApp) -> str: - if enabled: - app_enabled.set() - # todo: start bg threads to fetch docs, updates and requests to process - else: - app_enabled.clear() + try: + if enabled: + app_enabled.set() + start_bg_threads() + else: + app_enabled.clear() + stop_bg_threads() + except Exception as e: + logger.exception('Error in enabled handler:', exc_info=e) + return f'Error in enabled handler: {e}' logger.info(f'App {("disabled", "enabled")[enabled]}') return '' @@ -95,6 +101,7 @@ async def lifespan(app: FastAPI): yield vectordb_loader.offload() llm_loader.offload() + stop_bg_threads() app_config = get_config(os.environ['CC_CONFIG_PATH']) diff --git a/context_chat_backend/task_fetcher.py b/context_chat_backend/task_fetcher.py index 5e2f317f..9660b44c 100644 --- a/context_chat_backend/task_fetcher.py +++ b/context_chat_backend/task_fetcher.py @@ -2,3 +2,85 @@ # SPDX-FileCopyrightText: 2026 Nextcloud GmbH and Nextcloud contributors # SPDX-License-Identifier: AGPL-3.0-or-later # + +from enum import Enum +from threading import Thread + +from .types import AppRole +from .utils import get_app_role + +APP_ROLE = get_app_role() +THREADS = {} +THREADS_STOP_EVENTS = {} + + +class ThreadType(Enum): + FILES_INDEXING = 'files_indexing' + UPDATES_PROCESSING = 'updates_processing' + REQUEST_PROCESSING = 'request_processing' + + +def files_indexing_thread(): + ... + + +def updates_processing_thread(): + ... + + +def request_processing_thread(): + ... 
+ + +def start_bg_threads(): + match APP_ROLE: + case AppRole.INDEXING | AppRole.NORMAL: + THREADS[ThreadType.FILES_INDEXING] = Thread( + target=files_indexing_thread, + name='FilesIndexingThread', + daemon=True, + ) + THREADS[ThreadType.UPDATES_PROCESSING] = Thread( + target=updates_processing_thread, + name='UpdatesProcessingThread', + daemon=True, + ) + THREADS[ThreadType.FILES_INDEXING].start() + THREADS[ThreadType.UPDATES_PROCESSING].start() + case AppRole.RP | AppRole.NORMAL: + THREADS[ThreadType.REQUEST_PROCESSING] = Thread( + target=request_processing_thread, + name='RequestProcessingThread', + daemon=True, + ) + THREADS[ThreadType.REQUEST_PROCESSING].start() + + +def stop_bg_threads(): + match APP_ROLE: + case AppRole.INDEXING | AppRole.NORMAL: + if ( + ThreadType.FILES_INDEXING not in THREADS + or ThreadType.UPDATES_PROCESSING not in THREADS + or ThreadType.FILES_INDEXING not in THREADS_STOP_EVENTS + or ThreadType.UPDATES_PROCESSING not in THREADS_STOP_EVENTS + ): + return + THREADS_STOP_EVENTS[ThreadType.FILES_INDEXING].set() + THREADS_STOP_EVENTS[ThreadType.UPDATES_PROCESSING].set() + THREADS[ThreadType.FILES_INDEXING].join() + THREADS[ThreadType.UPDATES_PROCESSING].join() + THREADS.pop(ThreadType.FILES_INDEXING) + THREADS.pop(ThreadType.UPDATES_PROCESSING) + THREADS_STOP_EVENTS.pop(ThreadType.FILES_INDEXING) + THREADS_STOP_EVENTS.pop(ThreadType.UPDATES_PROCESSING) + case AppRole.RP | AppRole.NORMAL: + if ( + ThreadType.REQUEST_PROCESSING not in THREADS + or ThreadType.REQUEST_PROCESSING not in THREADS_STOP_EVENTS + ): + return + THREADS_STOP_EVENTS[ThreadType.REQUEST_PROCESSING].set() + THREADS[ThreadType.REQUEST_PROCESSING].join() + THREADS.pop(ThreadType.REQUEST_PROCESSING) + THREADS_STOP_EVENTS.pop(ThreadType.REQUEST_PROCESSING) From ea5208a30b8b1ce64cf85407bf0ceb63c039b0bb Mon Sep 17 00:00:00 2001 From: Anupam Kumar Date: Mon, 9 Mar 2026 19:22:45 +0530 Subject: [PATCH 03/17] wip: migrate the indexing process Signed-off-by: Anupam Kumar --- 
.../chain/ingest/doc_loader.py | 53 +-- context_chat_backend/chain/ingest/injest.py | 201 ++++++----- context_chat_backend/controller.py | 157 +++++---- .../{chain/ingest => }/mimetype_list.py | 0 context_chat_backend/task_fetcher.py | 311 ++++++++++++++++-- context_chat_backend/types.py | 121 ++++++- context_chat_backend/vectordb/base.py | 9 +- context_chat_backend/vectordb/pgvector.py | 61 ++-- 8 files changed, 659 insertions(+), 254 deletions(-) rename context_chat_backend/{chain/ingest => }/mimetype_list.py (100%) diff --git a/context_chat_backend/chain/ingest/doc_loader.py b/context_chat_backend/chain/ingest/doc_loader.py index efb81b6d..d26f74b1 100644 --- a/context_chat_backend/chain/ingest/doc_loader.py +++ b/context_chat_backend/chain/ingest/doc_loader.py @@ -7,11 +7,10 @@ import re import tempfile from collections.abc import Callable -from typing import BinaryIO +from io import BytesIO import docx2txt from epub2txt import epub2txt -from fastapi import UploadFile from langchain_unstructured import UnstructuredLoader from odfdo import Document from pandas import read_csv, read_excel @@ -19,9 +18,11 @@ from pypdf.errors import FileNotDecryptedError as PdfFileNotDecryptedError from striprtf import striprtf +from ...types import SourceItem + logger = logging.getLogger('ccb.doc_loader') -def _temp_file_wrapper(file: BinaryIO, loader: Callable, sep: str = '\n') -> str: +def _temp_file_wrapper(file: BytesIO, loader: Callable, sep: str = '\n') -> str: raw_bytes = file.read() with tempfile.NamedTemporaryFile(mode='wb') as tmp: tmp.write(raw_bytes) @@ -35,46 +36,46 @@ def _temp_file_wrapper(file: BinaryIO, loader: Callable, sep: str = '\n') -> str # -- LOADERS -- # -def _load_pdf(file: BinaryIO) -> str: +def _load_pdf(file: BytesIO) -> str: pdf_reader = PdfReader(file) return '\n\n'.join([page.extract_text().strip() for page in pdf_reader.pages]) -def _load_csv(file: BinaryIO) -> str: +def _load_csv(file: BytesIO) -> str: return 
read_csv(file).to_string(header=False, na_rep='') -def _load_epub(file: BinaryIO) -> str: +def _load_epub(file: BytesIO) -> str: return _temp_file_wrapper(file, epub2txt).strip() -def _load_docx(file: BinaryIO) -> str: +def _load_docx(file: BytesIO) -> str: return docx2txt.process(file).strip() -def _load_odt(file: BinaryIO) -> str: +def _load_odt(file: BytesIO) -> str: return _temp_file_wrapper(file, lambda fp: Document(fp).get_formatted_text()).strip() -def _load_ppt_x(file: BinaryIO) -> str: +def _load_ppt_x(file: BytesIO) -> str: return _temp_file_wrapper(file, lambda fp: UnstructuredLoader(fp).load()).strip() -def _load_rtf(file: BinaryIO) -> str: +def _load_rtf(file: BytesIO) -> str: return striprtf.rtf_to_text(file.read().decode('utf-8', 'ignore')).strip() -def _load_xml(file: BinaryIO) -> str: +def _load_xml(file: BytesIO) -> str: data = file.read().decode('utf-8', 'ignore') data = re.sub(r'', '', data) return data.strip() -def _load_xlsx(file: BinaryIO) -> str: +def _load_xlsx(file: BytesIO) -> str: return read_excel(file, na_filter=False).to_string(header=False, na_rep='') -def _load_email(file: BinaryIO, ext: str = 'eml') -> str | None: +def _load_email(file: BytesIO, ext: str = 'eml') -> str | None: # NOTE: msg format is not tested if ext not in ['eml', 'msg']: return None @@ -115,30 +116,34 @@ def attachment_partitioner( } -def decode_source(source: UploadFile) -> str | None: +def decode_source(source: SourceItem) -> str | None: + io_obj: BytesIO | None = None try: # .pot files are powerpoint templates but also plain text files, # so we skip them to prevent decoding errors - if source.headers['title'].endswith('.pot'): + if source.title.endswith('.pot'): return None - mimetype = source.headers['type'] + mimetype = source.type if mimetype is None: return None + if isinstance(source.content, str): + io_obj = BytesIO(source.content.encode('utf-8', 'ignore')) + else: + io_obj = source.content + if _loader_map.get(mimetype): - result = 
_loader_map[mimetype](source.file) - source.file.close() + result = _loader_map[mimetype](io_obj) return result.encode('utf-8', 'ignore').decode('utf-8', 'ignore') - result = source.file.read().decode('utf-8', 'ignore') - source.file.close() - return result + return io_obj.read().decode('utf-8', 'ignore') except PdfFileNotDecryptedError: - logger.warning(f'PDF file ({source.filename}) is encrypted and cannot be read') + logger.warning(f'PDF file ({source.reference}) is encrypted and cannot be read') return None except Exception: - logger.exception(f'Error decoding source file ({source.filename})', stack_info=True) + logger.exception(f'Error decoding source file ({source.reference})', stack_info=True) return None finally: - source.file.close() # Ensure file is closed after processing + if io_obj is not None: + io_obj.close() diff --git a/context_chat_backend/chain/ingest/injest.py b/context_chat_backend/chain/ingest/injest.py index 5871ebb8..0eb70e0b 100644 --- a/context_chat_backend/chain/ingest/injest.py +++ b/context_chat_backend/chain/ingest/injest.py @@ -5,29 +5,23 @@ import logging import re -from fastapi.datastructures import UploadFile from langchain.schema import Document from ...dyn_loader import VectorDBLoader -from ...types import TConfig -from ...utils import is_valid_source_id, to_int +from ...types import IndexingError, SourceItem, TConfig from ...vectordb.base import BaseVectorDB from ...vectordb.types import DbException, SafeDbException, UpdateAccessOp from ..types import InDocument from .doc_loader import decode_source from .doc_splitter import get_splitter_for -from .mimetype_list import SUPPORTED_MIMETYPES logger = logging.getLogger('ccb.injest') -def _allowed_file(file: UploadFile) -> bool: - return file.headers['type'] in SUPPORTED_MIMETYPES - def _filter_sources( vectordb: BaseVectorDB, - sources: list[UploadFile] -) -> tuple[list[UploadFile], list[UploadFile]]: + sources: dict[int, SourceItem] +) -> tuple[dict[int, SourceItem], dict[int, 
SourceItem]]: ''' Returns ------- @@ -37,30 +31,42 @@ def _filter_sources( ''' try: - existing_sources, new_sources = vectordb.check_sources(sources) + existing_source_ids, to_embed_source_ids = vectordb.check_sources(sources) except Exception as e: - raise DbException('Error: Vectordb sources_to_embed error') from e + raise DbException('Error: Vectordb error while checking existing sources in indexing') from e + + existing_sources = {} + to_embed_sources = {} - return ([ - source for source in sources - if source.filename in existing_sources - ], [ - source for source in sources - if source.filename in new_sources - ]) + for db_id, source in sources.items(): + if source.reference in existing_source_ids: + existing_sources[db_id] = source + elif source.reference in to_embed_source_ids: + to_embed_sources[db_id] = source + return existing_sources, to_embed_sources -def _sources_to_indocuments(config: TConfig, sources: list[UploadFile]) -> list[InDocument]: - indocuments = [] - for source in sources: - logger.debug('processing source', extra={ 'source_id': source.filename }) +def _sources_to_indocuments( + config: TConfig, + sources: dict[int, SourceItem] +) -> tuple[dict[int, InDocument], dict[int, IndexingError]]: + indocuments = {} + errored_docs = {} + for db_id, source in sources.items(): + logger.debug('processing source', extra={ 'source_id': source.reference }) + + # todo: maybe fetch the content of the files here # transform the source to have text data content = decode_source(source) if content is None or (content := content.strip()) == '': - logger.debug('decoded empty source', extra={ 'source_id': source.filename }) + logger.debug('decoded empty source', extra={ 'source_id': source.reference }) + errored_docs[db_id] = IndexingError( + error='Decoded content is empty', + retryable=False, + ) continue # replace more than two newlines with two newlines (also blank spaces, more than 4) @@ -71,94 +77,123 @@ def _sources_to_indocuments(config: TConfig, sources: 
list[UploadFile]) -> list[ content = content.replace('\0', '') if content is None or content == '': - logger.debug('decoded empty source after cleanup', extra={ 'source_id': source.filename }) + logger.debug('decoded empty source after cleanup', extra={ 'source_id': source.reference }) + errored_docs[db_id] = IndexingError( + error='Decoded content is empty', + retryable=False, + ) continue - logger.debug('decoded non empty source', extra={ 'source_id': source.filename }) + logger.debug('decoded non empty source', extra={ 'source_id': source.reference }) metadata = { - 'source': source.filename, - 'title': _decode_latin_1(source.headers['title']), - 'type': source.headers['type'], + 'source': source.reference, + 'title': _decode_latin_1(source.title), + 'type': source.type, } doc = Document(page_content=content, metadata=metadata) - splitter = get_splitter_for(config.embedding_chunk_size, source.headers['type']) + splitter = get_splitter_for(config.embedding_chunk_size, source.type) split_docs = splitter.split_documents([doc]) logger.debug('split document into chunks', extra={ - 'source_id': source.filename, + 'source_id': source.reference, 'len(split_docs)': len(split_docs), }) - indocuments.append(InDocument( + indocuments[db_id] = InDocument( documents=split_docs, - userIds=list(map(_decode_latin_1, source.headers['userIds'].split(','))), - source_id=source.filename, # pyright: ignore[reportArgumentType] - provider=source.headers['provider'], - modified=to_int(source.headers['modified']), - )) + userIds=list(map(_decode_latin_1, source.userIds)), + source_id=source.reference, + provider=source.provider, + modified=source.modified, # pyright: ignore[reportArgumentType] + ) + + return indocuments, errored_docs + + +def _increase_access_for_existing_sources( + vectordb: BaseVectorDB, + existing_sources: dict[int, SourceItem] +) -> dict[int, IndexingError | None]: + ''' + update userIds for existing sources + allow the userIds as additional users, not as the only 
users + ''' + if len(existing_sources) == 0: + return {} - return indocuments + results = {} + logger.debug('Increasing access for existing sources', extra={ + 'source_ids': [source.reference for source in existing_sources.values()] + }) + for db_id, source in existing_sources.items(): + try: + vectordb.update_access( + UpdateAccessOp.allow, + list(map(_decode_latin_1, source.userIds)), + source.reference, + ) + results[db_id] = None + except SafeDbException as e: + logger.error(f'Failed to update access for source ({source.reference}): {e.args[0]}') + results[db_id] = IndexingError( + error=str(e), + retryable=False, + ) + continue + except Exception as e: + logger.error(f'Unexpected error while updating access for source ({source.reference}): {e}') + results[db_id] = IndexingError( + error='Unexpected error while updating access', + retryable=True, + ) + continue + return results def _process_sources( vectordb: BaseVectorDB, config: TConfig, - sources: list[UploadFile], -) -> tuple[list[str],list[str]]: + sources: dict[int, SourceItem] +) -> dict[int, IndexingError | None]: ''' Processes the sources and adds them to the vectordb. Returns the list of source ids that were successfully added and those that need to be retried. 
''' - existing_sources, filtered_sources = _filter_sources(vectordb, sources) + existing_sources, to_embed_sources = _filter_sources(vectordb, sources) logger.debug('db filter source results', extra={ 'len(existing_sources)': len(existing_sources), 'existing_sources': existing_sources, - 'len(filtered_sources)': len(filtered_sources), - 'filtered_sources': filtered_sources, + 'len(to_embed_sources)': len(to_embed_sources), + 'to_embed_sources': to_embed_sources, }) - loaded_source_ids = [source.filename for source in existing_sources] - # update userIds for existing sources - # allow the userIds as additional users, not as the only users - if len(existing_sources) > 0: - logger.debug('Increasing access for existing sources', extra={ - 'source_ids': [source.filename for source in existing_sources] - }) - for source in existing_sources: - try: - vectordb.update_access( - UpdateAccessOp.allow, - list(map(_decode_latin_1, source.headers['userIds'].split(','))), - source.filename, # pyright: ignore[reportArgumentType] - ) - except SafeDbException as e: - logger.error(f'Failed to update access for source ({source.filename}): {e.args[0]}') - continue - - if len(filtered_sources) == 0: + source_proc_results = _increase_access_for_existing_sources(vectordb, existing_sources) + + if len(to_embed_sources) == 0: # no new sources to embed logger.debug('Filtered all sources, nothing to embed') - return loaded_source_ids, [] # pyright: ignore[reportReturnType] + return source_proc_results logger.debug('Filtered sources:', extra={ - 'source_ids': [source.filename for source in filtered_sources] + 'source_ids': [source.reference for source in to_embed_sources.values()] }) # invalid/empty sources are filtered out here and not counted in loaded/retryable - indocuments = _sources_to_indocuments(config, filtered_sources) + indocuments, errored_docs = _sources_to_indocuments(config, to_embed_sources) - logger.debug('Converted all sources to documents') + 
source_proc_results.update(errored_docs) + logger.debug('Converted sources to documents') if len(indocuments) == 0: # filtered document(s) were invalid/empty, not an error logger.debug('All documents were found empty after being processed') - return loaded_source_ids, [] # pyright: ignore[reportReturnType] + return source_proc_results - added_source_ids, retry_source_ids = vectordb.add_indocuments(indocuments) - loaded_source_ids.extend(added_source_ids) + doc_add_results = vectordb.add_indocuments(indocuments) + source_proc_results.update(doc_add_results) logger.debug('Added documents to vectordb') - return loaded_source_ids, retry_source_ids # pyright: ignore[reportReturnType] + return source_proc_results def _decode_latin_1(s: str) -> str: @@ -172,31 +207,15 @@ def _decode_latin_1(s: str) -> str: def embed_sources( vectordb_loader: VectorDBLoader, config: TConfig, - sources: list[UploadFile], -) -> tuple[list[str],list[str]]: - # either not a file or a file that is allowed - sources_filtered = [ - source for source in sources - if is_valid_source_id(source.filename) # pyright: ignore[reportArgumentType] - or _allowed_file(source) - ] - + sources: dict[int, SourceItem] +) -> dict[int, IndexingError | None]: logger.debug('Embedding sources:', extra={ 'source_ids': [ - f'{source.filename} ({_decode_latin_1(source.headers["title"])})' - for source in sources_filtered - ], - 'invalid_source_ids': [ - source.filename for source in sources - if not is_valid_source_id(source.filename) # pyright: ignore[reportArgumentType] - ], - 'not_allowed_file_ids': [ - source.filename for source in sources - if not _allowed_file(source) + f'{source.reference} ({_decode_latin_1(source.title)})' + for source in sources.values() ], - 'len(source_ids)': len(sources_filtered), - 'len(total_source_ids)': len(sources), + 'len(source_ids)': len(sources), }) vectordb = vectordb_loader.load() - return _process_sources(vectordb, config, sources_filtered) + return _process_sources(vectordb, 
config, sources) diff --git a/context_chat_backend/controller.py b/context_chat_backend/controller.py index 4c07a06e..3e70ee1b 100644 --- a/context_chat_backend/controller.py +++ b/context_chat_backend/controller.py @@ -27,7 +27,7 @@ from time import sleep from typing import Annotated, Any -from fastapi import Body, FastAPI, Request, UploadFile +from fastapi import Body, FastAPI, Request from langchain.llms.base import LLM from nc_py_api import AsyncNextcloudApp, NextcloudApp from nc_py_api.ex_app import persistent_storage, set_handlers @@ -35,14 +35,13 @@ from starlette.responses import FileResponse from .chain.context import do_doc_search -from .chain.ingest.injest import embed_sources from .chain.one_shot import process_context_query, process_query from .config_parser import get_config from .dyn_loader import LLMModelLoader, VectorDBLoader from .models.types import LlmException from nc_py_api.ex_app import AppAPIAuthMiddleware from .utils import JSONResponse, exec_in_proc, is_valid_provider_id, is_valid_source_id, value_of -from .task_fetcher import start_bg_threads, stop_bg_threads +from .task_fetcher import start_bg_threads, wait_for_bg_threads from .vectordb.service import ( count_documents_by_provider, decl_update_access, @@ -57,6 +56,7 @@ repair_run() ensure_config_file() logger = logging.getLogger('ccb.controller') +app_config = get_config(os.environ['CC_CONFIG_PATH']) __download_models_from_hf = os.environ.get('CC_DOWNLOAD_MODELS_FROM_HF', 'true').lower() in ('1', 'true', 'yes') models_to_fetch = { @@ -77,10 +77,10 @@ def enabled_handler(enabled: bool, _: NextcloudApp | AsyncNextcloudApp) -> str: try: if enabled: app_enabled.set() - start_bg_threads() + start_bg_threads(app_config, app_enabled) else: app_enabled.clear() - stop_bg_threads() + wait_for_bg_threads() except Exception as e: logger.exception('Error in enabled handler:', exc_info=e) return f'Error in enabled handler: {e}' @@ -101,10 +101,9 @@ async def lifespan(app: FastAPI): yield 
vectordb_loader.offload() llm_loader.offload() - stop_bg_threads() + wait_for_bg_threads() -app_config = get_config(os.environ['CC_CONFIG_PATH']) app = FastAPI(debug=app_config.debug, lifespan=lifespan) # pyright: ignore[reportArgumentType] app.extra['CONFIG'] = app_config @@ -343,78 +342,78 @@ def _(userId: str = Body(embed=True)): return JSONResponse('User deleted') -@app.put('/loadSources') -@enabled_guard(app) -def _(sources: list[UploadFile]): - global _indexing - - if len(sources) == 0: - return JSONResponse('No sources provided', 400) - - for source in sources: - if not value_of(source.filename): - return JSONResponse(f'Invalid source filename for: {source.headers.get("title")}', 400) - - with index_lock: - if source.filename in _indexing: - # this request will be retried by the client - return JSONResponse( - f'This source ({source.filename}) is already being processed in another request, try again later', - 503, - headers={'cc-retry': 'true'}, - ) - - if not ( - value_of(source.headers.get('userIds')) - and source.headers.get('title', None) is not None - and value_of(source.headers.get('type')) - and value_of(source.headers.get('modified')) - and source.headers['modified'].isdigit() - and value_of(source.headers.get('provider')) - ): - logger.error('Invalid/missing headers received', extra={ - 'source_id': source.filename, - 'title': source.headers.get('title'), - 'headers': source.headers, - }) - return JSONResponse(f'Invaild/missing headers for: {source.filename}', 400) - - # wait for 10 minutes before failing the request - semres = doc_parse_semaphore.acquire(block=True, timeout=10*60) - if not semres: - return JSONResponse( - 'Document parser worker limit reached, try again in some time or consider increasing the limit', - 503, - headers={'cc-retry': 'true'} - ) - - with index_lock: - for source in sources: - _indexing[source.filename] = source.size - - try: - loaded_sources, not_added_sources = exec_in_proc( - target=embed_sources, - 
args=(vectordb_loader, app.extra['CONFIG'], sources) - ) - except (DbException, EmbeddingException): - raise - except Exception as e: - raise DbException('Error: failed to load sources') from e - finally: - with index_lock: - for source in sources: - _indexing.pop(source.filename, None) - doc_parse_semaphore.release() - - if len(loaded_sources) != len(sources): - logger.debug('Some sources were not loaded', extra={ - 'Count of loaded sources': f'{len(loaded_sources)}/{len(sources)}', - 'source_ids': loaded_sources, - }) - - # loaded sources include the existing sources that may only have their access updated - return JSONResponse({'loaded_sources': loaded_sources, 'sources_to_retry': not_added_sources}) +# @app.put('/loadSources') +# @enabled_guard(app) +# def _(sources: list[UploadFile]): +# global _indexing + +# if len(sources) == 0: +# return JSONResponse('No sources provided', 400) + +# for source in sources: +# if not value_of(source.filename): +# return JSONResponse(f'Invalid source filename for: {source.headers.get("title")}', 400) + +# with index_lock: +# if source.filename in _indexing: +# # this request will be retried by the client +# return JSONResponse( +# f'This source ({source.filename}) is already being processed in another request, try again later', +# 503, +# headers={'cc-retry': 'true'}, +# ) + +# if not ( +# value_of(source.headers.get('userIds')) +# and source.headers.get('title', None) is not None +# and value_of(source.headers.get('type')) +# and value_of(source.headers.get('modified')) +# and source.headers['modified'].isdigit() +# and value_of(source.headers.get('provider')) +# ): +# logger.error('Invalid/missing headers received', extra={ +# 'source_id': source.filename, +# 'title': source.headers.get('title'), +# 'headers': source.headers, +# }) +# return JSONResponse(f'Invaild/missing headers for: {source.filename}', 400) + +# # wait for 10 minutes before failing the request +# semres = doc_parse_semaphore.acquire(block=True, 
timeout=10*60) +# if not semres: +# return JSONResponse( +# 'Document parser worker limit reached, try again in some time or consider increasing the limit', +# 503, +# headers={'cc-retry': 'true'} +# ) + +# with index_lock: +# for source in sources: +# _indexing[source.filename] = source.size + +# try: +# loaded_sources, not_added_sources = exec_in_proc( +# target=embed_sources, +# args=(vectordb_loader, app.extra['CONFIG'], sources) +# ) +# except (DbException, EmbeddingException): +# raise +# except Exception as e: +# raise DbException('Error: failed to load sources') from e +# finally: +# with index_lock: +# for source in sources: +# _indexing.pop(source.filename, None) +# doc_parse_semaphore.release() + +# if len(loaded_sources) != len(sources): +# logger.debug('Some sources were not loaded', extra={ +# 'Count of loaded sources': f'{len(loaded_sources)}/{len(sources)}', +# 'source_ids': loaded_sources, +# }) + +# # loaded sources include the existing sources that may only have their access updated +# return JSONResponse({'loaded_sources': loaded_sources, 'sources_to_retry': not_added_sources}) class Query(BaseModel): diff --git a/context_chat_backend/chain/ingest/mimetype_list.py b/context_chat_backend/mimetype_list.py similarity index 100% rename from context_chat_backend/chain/ingest/mimetype_list.py rename to context_chat_backend/mimetype_list.py diff --git a/context_chat_backend/task_fetcher.py b/context_chat_backend/task_fetcher.py index 9660b44c..a548bcfd 100644 --- a/context_chat_backend/task_fetcher.py +++ b/context_chat_backend/task_fetcher.py @@ -3,15 +3,41 @@ # SPDX-License-Identifier: AGPL-3.0-or-later # +import asyncio +import logging +from contextlib import suppress from enum import Enum -from threading import Thread +from io import BytesIO +from threading import Event, Thread +from time import sleep -from .types import AppRole -from .utils import get_app_role +import niquests +from nc_py_api import AsyncNextcloudApp, NextcloudApp +from pydantic 
import ValidationError + +from .chain.ingest.injest import embed_sources +from .dyn_loader import VectorDBLoader +from .types import ( + AppRole, + EmbeddingException, + FilesQueueItem, + IndexingError, + IndexingException, + LoaderException, + ReceivedFileItem, + SourceItem, + TConfig, +) +from .utils import exec_in_proc, get_app_role +from .vectordb.types import DbException APP_ROLE = get_app_role() THREADS = {} -THREADS_STOP_EVENTS = {} +LOGGER = logging.getLogger('ccb.task_fetcher') +FILES_INDEXING_BATCH_SIZE = 64 # todo: config? +# max concurrent fetches to avoid overloading the NC server or hitting rate limits +CONCURRENT_FILE_FETCHES = 10 # todo: config? +MAX_FILE_SIZE = 100 * 1024 * 1024 # 100 MB, todo: config? class ThreadType(Enum): @@ -20,67 +46,294 @@ class ThreadType(Enum): REQUEST_PROCESSING = 'request_processing' -def files_indexing_thread(): - ... +async def __fetch_file_content( + semaphore: asyncio.Semaphore, + file_id: int, + user_id: str, + _rlimit = 3, +) -> BytesIO: + ''' + Raises + ------ + IndexingException + ''' + + async with semaphore: + nc = AsyncNextcloudApp() + try: + # a file pointer for storing the stream in memory until it is consumed + fp = BytesIO() + await nc._session.download2fp( + url_path=f'/apps/context_chat/files/{file_id}', + fp=fp, + dav=False, + params={ 'userId': user_id }, + ) + return fp + except niquests.exceptions.RequestException as e: + # todo: raise IndexingException with retryable=True for rate limit errors, + # todo: and handle it in the caller to not delete the source from the queue and retry later through + # todo: the normal lock expiry mechanism + if e.response is None: + raise + + if e.response.status_code == niquests.codes.too_many_requests: # pyright: ignore[reportAttributeAccessIssue] + # todo: implement rate limits in php CC? 
+ wait_for = int(e.response.headers.get('Retry-After', '30')) + if _rlimit <= 0: + raise IndexingException( + f'Rate limited when fetching content for file id {file_id}, user id {user_id},' + ' max retries exceeded', + retryable=True, + ) from e + LOGGER.warning( + f'Rate limited when fetching content for file id {file_id}, user id {user_id},' + f' waiting {wait_for} before retrying', + exc_info=e, + ) + await asyncio.sleep(wait_for) + return await __fetch_file_content(semaphore, file_id, user_id, _rlimit - 1) + + raise + except IndexingException: + raise + except Exception as e: + LOGGER.error(f'Error fetching content for file id {file_id}, user id {user_id}: {e}', exc_info=e) + raise IndexingException(f'Error fetching content for file id {file_id}, user id {user_id}: {e}') from e + + +async def __fetch_files_content( + files: dict[int, ReceivedFileItem] +) -> dict[int, SourceItem | IndexingError]: + source_items = {} + semaphore = asyncio.Semaphore(CONCURRENT_FILE_FETCHES) + tasks = [] + + for file_id, file_item in files.items(): + if file_item.size > MAX_FILE_SIZE: + LOGGER.info( + f'Skipping file id {file_id}, source id {file_item.reference} due to size' + f' {(file_item.size/(1024*1024)):.2f} MiB exceeding the limit {(MAX_FILE_SIZE/(1024*1024)):.2f} MiB', + ) + source_items[file_id] = IndexingError( + error=( + f'File size {(file_item.size/(1024*1024)):.2f} MiB' + f' exceeds the limit {(MAX_FILE_SIZE/(1024*1024)):.2f} MiB' + ), + retryable=False, + ) + continue + # todo: perform the existing file check before fetching the content to avoid unnecessary fetches + # any user id from the list should have read access to the file + tasks.append(asyncio.ensure_future(__fetch_file_content(semaphore, file_id, file_item.userIds[0]))) + results = await asyncio.gather(*tasks, return_exceptions=True) + for (file_id, file_item), result in zip(files.items(), results, strict=True): + if isinstance(result, IndexingException): + LOGGER.error( + f'Error fetching content for file 
id {file_id}, reference {file_item.reference}: {result}', + exc_info=result, + ) + source_items[file_id] = IndexingError( + error=str(result), + retryable=result.retryable, + ) + elif isinstance(result, str) or isinstance(result, BytesIO): + source_items[file_id] = SourceItem( + **file_item.model_dump(), + content=result, + ) + elif isinstance(result, BaseException): + LOGGER.error( + f'Unexpected error fetching content for file id {file_id}, reference {file_item.reference}: {result}', + exc_info=result, + ) + source_items[file_id] = IndexingError( + error=f'Unexpected error: {result}', + retryable=True, + ) + else: + LOGGER.error( + f'Unknown error fetching content for file id {file_id}, reference {file_item.reference}: {result}', + exc_info=True, + ) + source_items[file_id] = IndexingError( + error='Unknown error', + retryable=True, + ) + return source_items + + +def files_indexing_thread(app_config: TConfig, app_enabled: Event) -> None: + try: + vectordb_loader = VectorDBLoader(app_config) + except LoaderException as e: + LOGGER.error('Error initializing vector DB loader, files indexing thread will not start:', exc_info=e) + return + + def _load_sources(source_items: dict[int, SourceItem]) -> dict[int, IndexingError | None]: + try: + return exec_in_proc( + target=embed_sources, + args=(vectordb_loader, app_config, source_items), + ) + except (DbException, EmbeddingException): + raise + except Exception as e: + raise DbException('Error: failed to load sources') from e -def updates_processing_thread(): + + while True: + if not app_enabled.is_set(): + LOGGER.info('Files indexing thread is stopping as the app is disabled') + return + + try: + nc = NextcloudApp() + # todo: add the 'size' param to the return of this call. 
+ q_items_res = nc.ocs( + 'GET', + '/apps/context_chat/queues/documents', + params={ 'n': FILES_INDEXING_BATCH_SIZE } + ) + + try: + q_items = FilesQueueItem.model_validate(q_items_res) + except ValidationError as e: + raise Exception(f'Error validating queue items response: {e}\nResponse content: {q_items_res}') from e + + # populate files content and convert to source items + fetched_files = {} + source_files = {} + # unified error structure for files and content providers + source_errors = {} + + if q_items.files: + fetched_files = asyncio.run(__fetch_files_content(q_items.files)) + + for file_id, result in fetched_files.items(): + if isinstance(result, SourceItem): + source_files[file_id] = result + else: + source_errors[file_id] = result + + files_result = _load_sources(source_files) + providers_result = _load_sources(q_items.content_providers) + + if ( + any(isinstance(res, IndexingError) for res in files_result.values()) + or any(isinstance(res, IndexingError) for res in providers_result.values()) + ): + LOGGER.error('Some sources failed to index', extra={ + 'file_errors': { + file_id: error + for file_id, error in files_result.items() + if isinstance(error, IndexingError) + }, + 'provider_errors': { + provider_id: error + for provider_id, error in providers_result.items() + if isinstance(error, IndexingError) + }, + }) + except ( + niquests.exceptions.ConnectionError, + niquests.exceptions.Timeout, + ) as e: + LOGGER.info('Temporary error fetching documents to index, will retry:', exc_info=e) + sleep(5) + continue + except Exception as e: + LOGGER.exception('Error fetching documents to index:', exc_info=e) + sleep(5) + continue + + # delete the entries from the PHP side queue where indexing succeeded or the error is not retryable + to_delete_file_ids = [ + file_id for file_id, result in files_result.items() + if result is None or (isinstance(result, IndexingError) and not result.retryable) + ] + to_delete_provider_ids = [ + provider_id for provider_id, 
result in providers_result.items() + if result is None or (isinstance(result, IndexingError) and not result.retryable) + ] + + try: + nc.ocs( + 'DELETE', + '/apps/context_chat/queues/documents/', + json={ + 'files': to_delete_file_ids, + 'content_providers': to_delete_provider_ids, + }, + ) + except ( + niquests.exceptions.ConnectionError, + niquests.exceptions.Timeout, + ) as e: + LOGGER.info('Temporary error reporting indexing results, will retry:', exc_info=e) + sleep(5) + with suppress(Exception): + nc = NextcloudApp() + nc.ocs( + 'DELETE', + '/apps/context_chat/queues/documents/', + json={ + 'files': to_delete_file_ids, + 'content_providers': to_delete_provider_ids, + }, + ) + continue + except Exception as e: + LOGGER.exception('Error reporting indexing results:', exc_info=e) + sleep(5) + continue + + + +def updates_processing_thread(app_config: TConfig): ... -def request_processing_thread(): +def request_processing_thread(app_config: TConfig): ... -def start_bg_threads(): +def start_bg_threads(app_config: TConfig, app_enabled: Event): match APP_ROLE: case AppRole.INDEXING | AppRole.NORMAL: THREADS[ThreadType.FILES_INDEXING] = Thread( target=files_indexing_thread, + args=(app_config, Event), name='FilesIndexingThread', - daemon=True, ) THREADS[ThreadType.UPDATES_PROCESSING] = Thread( target=updates_processing_thread, + args=(app_config, Event), name='UpdatesProcessingThread', - daemon=True, ) THREADS[ThreadType.FILES_INDEXING].start() THREADS[ThreadType.UPDATES_PROCESSING].start() case AppRole.RP | AppRole.NORMAL: THREADS[ThreadType.REQUEST_PROCESSING] = Thread( target=request_processing_thread, + args=(app_config, Event), name='RequestProcessingThread', - daemon=True, ) THREADS[ThreadType.REQUEST_PROCESSING].start() -def stop_bg_threads(): +def wait_for_bg_threads(): match APP_ROLE: case AppRole.INDEXING | AppRole.NORMAL: - if ( - ThreadType.FILES_INDEXING not in THREADS - or ThreadType.UPDATES_PROCESSING not in THREADS - or ThreadType.FILES_INDEXING not in 
THREADS_STOP_EVENTS - or ThreadType.UPDATES_PROCESSING not in THREADS_STOP_EVENTS - ): + if (ThreadType.FILES_INDEXING not in THREADS or ThreadType.UPDATES_PROCESSING not in THREADS): return - THREADS_STOP_EVENTS[ThreadType.FILES_INDEXING].set() - THREADS_STOP_EVENTS[ThreadType.UPDATES_PROCESSING].set() THREADS[ThreadType.FILES_INDEXING].join() THREADS[ThreadType.UPDATES_PROCESSING].join() THREADS.pop(ThreadType.FILES_INDEXING) THREADS.pop(ThreadType.UPDATES_PROCESSING) - THREADS_STOP_EVENTS.pop(ThreadType.FILES_INDEXING) - THREADS_STOP_EVENTS.pop(ThreadType.UPDATES_PROCESSING) case AppRole.RP | AppRole.NORMAL: - if ( - ThreadType.REQUEST_PROCESSING not in THREADS - or ThreadType.REQUEST_PROCESSING not in THREADS_STOP_EVENTS - ): + if (ThreadType.REQUEST_PROCESSING not in THREADS): return - THREADS_STOP_EVENTS[ThreadType.REQUEST_PROCESSING].set() THREADS[ThreadType.REQUEST_PROCESSING].join() THREADS.pop(ThreadType.REQUEST_PROCESSING) - THREADS_STOP_EVENTS.pop(ThreadType.REQUEST_PROCESSING) diff --git a/context_chat_backend/types.py b/context_chat_backend/types.py index 78680866..97d48ce6 100644 --- a/context_chat_backend/types.py +++ b/context_chat_backend/types.py @@ -3,8 +3,13 @@ # SPDX-License-Identifier: AGPL-3.0-or-later # from enum import Enum +from io import BytesIO +from typing import Self -from pydantic import BaseModel +from pydantic import BaseModel, field_validator + +from .mimetype_list import SUPPORTED_MIMETYPES +from .utils import is_valid_provider_id, is_valid_source_id __all__ = [ 'DEFAULT_EM_MODEL_ALIAS', @@ -17,6 +22,7 @@ ] DEFAULT_EM_MODEL_ALIAS = 'em_model' +FILES_PROVIDER_ID = 'files__default' class TEmbeddingAuthApiKey(BaseModel): @@ -79,3 +85,116 @@ class AppRole(str, Enum): NORMAL = 'normal' INDEXING = 'indexing' RP = 'rp' + + +class CommonSourceItem(BaseModel): + userIds: list[str] + reference: str # source_id of the form "appId__providerId: itemId" + title: str + modified: int | str # todo: int/string? 
+ type: str + provider: str + size: int + + @field_validator('modified', mode='before') + @classmethod + def validate_modified(cls, v): + if isinstance(v, int): + return v + if isinstance(v, str): + try: + return int(v) + except ValueError as e: + raise ValueError(f'Invalid modified value: {v}') from e + raise ValueError(f'Invalid modified type: {type(v)}') + + @field_validator('reference', 'title', 'type', 'provider') + @classmethod + def validate_strings_non_empty(cls, v): + if not isinstance(v, str) or v.strip() == '': + raise ValueError('Must be a non-empty string') + return v.strip() + + @field_validator('userIds', mode='after') + def validate_user_ids(self) -> Self: + if ( + not isinstance(self.userIds, list) + or not all( + isinstance(uid, str) + and uid.strip() != '' + for uid in self.userIds + ) + or len(self.userIds) == 0 + ): + raise ValueError('userIds must be a non-empty list of non-empty strings') + self.userIds = [uid.strip() for uid in self.userIds] + return self + + @field_validator('reference', mode='after') + def validate_reference_format(self) -> Self: + # validate reference format: "appId__providerId: itemId" + if not is_valid_source_id(self.reference): + raise ValueError('Invalid reference format, must be "appId__providerId: itemId"') + return self + + @field_validator('provider', mode='after') + def validate_provider_format(self) -> Self: + # validate provider format: "appId__providerId" + if not is_valid_provider_id(self.provider): + raise ValueError('Invalid provider format, must be "appId__providerId"') + return self + + @field_validator('type', mode='after') + def validate_type(self) -> Self: + if self.reference.startswith(FILES_PROVIDER_ID) and self.type not in SUPPORTED_MIMETYPES: + raise ValueError(f'Unsupported file type: {self.type} for reference {self.reference}') + return self + + @field_validator('size', mode='after') + def validate_size(self) -> Self: + if not isinstance(self.size, int) or self.size < 0: + raise 
ValueError(f'Invalid size value: {self.size}, must be a non-negative integer') + return self + + +class ReceivedFileItem(CommonSourceItem): + content: None + + +class SourceItem(CommonSourceItem): + ''' + Used for the unified queue of items to process, after fetching the content for files + and for directly fetched content providers. + ''' + content: str | BytesIO + + @field_validator('content') + @classmethod + def validate_content(cls, v): + if isinstance(v, str): + if v.strip() == '': + raise ValueError('Content must be a non-empty string') + return v.strip() + if isinstance(v, BytesIO): + if v.getbuffer().nbytes == 0: + raise ValueError('Content must be a non-empty BytesIO') + return v + raise ValueError('Content must be either a non-empty string or a non-empty BytesIO') + + +class FilesQueueItem(BaseModel): + files: dict[int, ReceivedFileItem] # [db id]: FileItem + content_providers: dict[int, SourceItem] # [db id]: SourceItem + + +class IndexingException(Exception): + retryable: bool = False + + def __init__(self, message: str, retryable: bool = False): + super().__init__(message) + self.retryable = retryable + + +class IndexingError(BaseModel): + error: str + retryable: bool = False diff --git a/context_chat_backend/vectordb/base.py b/context_chat_backend/vectordb/base.py index 0bf10200..ebd54075 100644 --- a/context_chat_backend/vectordb/base.py +++ b/context_chat_backend/vectordb/base.py @@ -5,12 +5,12 @@ from abc import ABC, abstractmethod from typing import Any -from fastapi import UploadFile from langchain.schema import Document from langchain.schema.embeddings import Embeddings from langchain.schema.vectorstore import VectorStore from ..chain.types import InDocument, ScopeType +from ..types import IndexingError, SourceItem from ..utils import timed from .types import UpdateAccessOp @@ -62,7 +62,7 @@ def get_instance(self) -> VectorStore: ''' @abstractmethod - def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str],list[str]]: + def 
add_indocuments(self, indocuments: dict[int, InDocument]) -> dict[int, IndexingError | None]: ''' Adds the given indocuments to the vectordb and updates the docs + access tables. @@ -79,10 +79,7 @@ def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str],list @timed @abstractmethod - def check_sources( - self, - sources: list[UploadFile], - ) -> tuple[list[str], list[str]]: + def check_sources(self, sources: dict[int, SourceItem]) -> tuple[list[str], list[str]]: ''' Checks the sources in the vectordb if they are already embedded and are up to date. diff --git a/context_chat_backend/vectordb/pgvector.py b/context_chat_backend/vectordb/pgvector.py index f40390fe..5a5ded35 100644 --- a/context_chat_backend/vectordb/pgvector.py +++ b/context_chat_backend/vectordb/pgvector.py @@ -10,14 +10,13 @@ import sqlalchemy.dialects.postgresql as postgresql_dialects import sqlalchemy.orm as orm from dotenv import load_dotenv -from fastapi import UploadFile from langchain.schema import Document from langchain.vectorstores import VectorStore from langchain_core.embeddings import Embeddings from langchain_postgres.vectorstores import Base, PGVector from ..chain.types import InDocument, ScopeType -from ..types import EmbeddingException, RetryableEmbeddingException +from ..types import EmbeddingException, IndexingError, RetryableEmbeddingException, SourceItem from ..utils import timed from .base import BaseVectorDB from .types import DbException, SafeDbException, UpdateAccessOp @@ -129,17 +128,16 @@ def get_users(self) -> list[str]: except Exception as e: raise DbException('Error: getting a list of all users from access list') from e - def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str], list[str]]: + def add_indocuments(self, indocuments: dict[int, InDocument]) -> dict[int, IndexingError | None]: """ Raises EmbeddingException: if the embedding request definitively fails """ - added_sources = [] - retry_sources = [] + results = {} batch_size = 
PG_BATCH_SIZE // 5 with self.session_maker() as session: - for indoc in indocuments: + for php_db_id, indoc in indocuments.items(): try: # query paramerters limitation in postgres is 65535 (https://www.postgresql.org/docs/current/limits.html) # so we chunk the documents into (5 values * 10k) chunks @@ -158,7 +156,7 @@ def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str], lis session.commit() self.decl_update_access(indoc.userIds, indoc.source_id, session) - added_sources.append(indoc.source_id) + results[php_db_id] = None session.commit() except SafeDbException as e: # for when the source_id is not found. This here can be an error in the DB @@ -166,51 +164,67 @@ def add_indocuments(self, indocuments: list[InDocument]) -> tuple[list[str], lis logger.exception('Error adding documents to vectordb', exc_info=e, extra={ 'source_id': indoc.source_id, }) - retry_sources.append(indoc.source_id) + results[php_db_id] = IndexingError( + error=str(e), + retryable=True, + ) continue except RetryableEmbeddingException as e: # temporary error, continue with the next document logger.exception('Error adding documents to vectordb, should be retried later.', exc_info=e, extra={ 'source_id': indoc.source_id, }) - retry_sources.append(indoc.source_id) + results[php_db_id] = IndexingError( + error=str(e), + retryable=True, + ) continue except EmbeddingException as e: logger.exception('Error adding documents to vectordb', exc_info=e, extra={ 'source_id': indoc.source_id, }) - raise + results[php_db_id] = IndexingError( + error=str(e), + retryable=False, + ) + continue except Exception as e: logger.exception('Error adding documents to vectordb', exc_info=e, extra={ 'source_id': indoc.source_id, }) - retry_sources.append(indoc.source_id) + results[php_db_id] = IndexingError( + error='An unexpected error occurred while adding documents to the database.', + retryable=True, + ) continue - return added_sources, retry_sources + return results @timed - def 
check_sources(self, sources: list[UploadFile]) -> tuple[list[str], list[str]]: + def check_sources(self, sources: dict[int, SourceItem]) -> tuple[list[str], list[str]]: + ''' + returns a tuple of (existing_source_ids, to_embed_source_ids) + ''' with self.session_maker() as session: try: stmt = ( sa.select(DocumentsStore.source_id) - .filter(DocumentsStore.source_id.in_([source.filename for source in sources])) + .filter(DocumentsStore.source_id.in_([source.reference for source in sources.values()])) .with_for_update() ) results = session.execute(stmt).fetchall() existing_sources = {r.source_id for r in results} - to_embed = [source.filename for source in sources if source.filename not in existing_sources] + to_embed = [source.reference for source in sources.values() if source.reference not in existing_sources] to_delete = [] - for source in sources: + for source in sources.values(): stmt = ( sa.select(DocumentsStore.source_id) - .filter(DocumentsStore.source_id == source.filename) + .filter(DocumentsStore.source_id == source.reference) .filter(DocumentsStore.modified < sa.cast( - datetime.fromtimestamp(int(source.headers['modified'])), + datetime.fromtimestamp(int(source.modified)), sa.DateTime, )) ) @@ -227,14 +241,13 @@ def check_sources(self, sources: list[UploadFile]) -> tuple[list[str], list[str] session.rollback() raise DbException('Error: checking sources in vectordb') from e - still_existing_sources = [ - source - for source in existing_sources - if source not in to_delete + still_existing_source_ids = [ + source_id + for source_id in existing_sources + if source_id not in to_delete ] - # the pyright issue stems from source.filename, which has already been validated - return list(still_existing_sources), to_embed # pyright: ignore[reportReturnType] + return list(still_existing_source_ids), to_embed def decl_update_access(self, user_ids: list[str], source_id: str, session_: orm.Session | None = None): session = session_ or self.session_maker() From 
b3d0c22032ad6b1bf4d055bb97904b4afb9dcce5 Mon Sep 17 00:00:00 2001 From: Anupam Kumar Date: Mon, 9 Mar 2026 19:42:21 +0530 Subject: [PATCH 04/17] wip: parallelize file parsing and processing based on cpu count Signed-off-by: Anupam Kumar --- context_chat_backend/task_fetcher.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/context_chat_backend/task_fetcher.py b/context_chat_backend/task_fetcher.py index a548bcfd..853a68c8 100644 --- a/context_chat_backend/task_fetcher.py +++ b/context_chat_backend/task_fetcher.py @@ -5,6 +5,7 @@ import asyncio import logging +import os from contextlib import suppress from enum import Enum from io import BytesIO @@ -35,6 +36,8 @@ THREADS = {} LOGGER = logging.getLogger('ccb.task_fetcher') FILES_INDEXING_BATCH_SIZE = 64 # todo: config? +# divides the batch into these many chunks +PARALLEL_FILE_PARSING = max(1, (os.cpu_count() or 2) - 1) # todo: config? # max concurrent fetches to avoid overloading the NC server or hitting rate limits CONCURRENT_FILE_FETCHES = 10 # todo: config? MAX_FILE_SIZE = 100 * 1024 * 1024 # 100 MB, todo: config? 
@@ -217,8 +220,18 @@ def _load_sources(source_items: dict[int, SourceItem]) -> dict[int, IndexingErro else: source_errors[file_id] = result - files_result = _load_sources(source_files) - providers_result = _load_sources(q_items.content_providers) + files_result = {} + providers_result = {} + chunk_size = FILES_INDEXING_BATCH_SIZE // PARALLEL_FILE_PARSING + + # chunk file parsing for better file operation parallelism + for i in range(0, len(source_files), chunk_size): + chunk = dict(list(source_files.items())[i:i+chunk_size]) + files_result.update(_load_sources(chunk)) + + for i in range(0, len(q_items.content_providers), chunk_size): + chunk = dict(list(q_items.content_providers.items())[i:i+chunk_size]) + providers_result.update(_load_sources(chunk)) if ( any(isinstance(res, IndexingError) for res in files_result.values()) From a9b8c8f1003fcf330de87b746c918a932646c789 Mon Sep 17 00:00:00 2001 From: Anupam Kumar Date: Tue, 10 Mar 2026 17:36:03 +0530 Subject: [PATCH 05/17] ci: use the kubernetes branch of context_chat Signed-off-by: Anupam Kumar --- .github/workflows/integration-test.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 10e2d61b..fb06bafa 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -113,6 +113,8 @@ jobs: repository: nextcloud/context_chat path: apps/context_chat persist-credentials: false + # todo: remove later + ref: feat/reverse-content-flow - name: Checkout backend uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4 From 8867c1b2550fe1496a64ccd74be8593f802aef51 Mon Sep 17 00:00:00 2001 From: Anupam Kumar Date: Tue, 10 Mar 2026 17:43:27 +0530 Subject: [PATCH 06/17] fix typo Signed-off-by: Anupam Kumar --- context_chat_backend/task_fetcher.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/context_chat_backend/task_fetcher.py 
b/context_chat_backend/task_fetcher.py index 853a68c8..cfa9293c 100644 --- a/context_chat_backend/task_fetcher.py +++ b/context_chat_backend/task_fetcher.py @@ -304,11 +304,11 @@ def _load_sources(source_items: dict[int, SourceItem]) -> dict[int, IndexingErro -def updates_processing_thread(app_config: TConfig): +def updates_processing_thread(app_config: TConfig, app_enabled: Event) -> None: ... -def request_processing_thread(app_config: TConfig): +def request_processing_thread(app_config: TConfig, app_enabled: Event) -> None: ... @@ -317,12 +317,12 @@ def start_bg_threads(app_config: TConfig, app_enabled: Event): case AppRole.INDEXING | AppRole.NORMAL: THREADS[ThreadType.FILES_INDEXING] = Thread( target=files_indexing_thread, - args=(app_config, Event), + args=(app_config, app_enabled), name='FilesIndexingThread', ) THREADS[ThreadType.UPDATES_PROCESSING] = Thread( target=updates_processing_thread, - args=(app_config, Event), + args=(app_config, app_enabled), name='UpdatesProcessingThread', ) THREADS[ThreadType.FILES_INDEXING].start() @@ -330,7 +330,7 @@ def start_bg_threads(app_config: TConfig, app_enabled: Event): case AppRole.RP | AppRole.NORMAL: THREADS[ThreadType.REQUEST_PROCESSING] = Thread( target=request_processing_thread, - args=(app_config, Event), + args=(app_config, app_enabled), name='RequestProcessingThread', ) THREADS[ThreadType.REQUEST_PROCESSING].start() From 3c4d698afe881efd1f122dc922275cd7517300cc Mon Sep 17 00:00:00 2001 From: Anupam Kumar Date: Wed, 11 Mar 2026 11:58:50 +0530 Subject: [PATCH 07/17] migrate the update process to be thread based Signed-off-by: Anupam Kumar --- context_chat_backend/chain/ingest/injest.py | 2 +- context_chat_backend/controller.py | 203 ++++++++++---------- context_chat_backend/task_fetcher.py | 183 +++++++++++++++++- context_chat_backend/types.py | 183 +++++++++++++++++- context_chat_backend/vectordb/pgvector.py | 30 ++- context_chat_backend/vectordb/service.py | 54 +++++- context_chat_backend/vectordb/types.py | 4 
+- 7 files changed, 534 insertions(+), 125 deletions(-) diff --git a/context_chat_backend/chain/ingest/injest.py b/context_chat_backend/chain/ingest/injest.py index 0eb70e0b..7369f452 100644 --- a/context_chat_backend/chain/ingest/injest.py +++ b/context_chat_backend/chain/ingest/injest.py @@ -129,7 +129,7 @@ def _increase_access_for_existing_sources( for db_id, source in existing_sources.items(): try: vectordb.update_access( - UpdateAccessOp.allow, + UpdateAccessOp.ALLOW, list(map(_decode_latin_1, source.userIds)), source.reference, ) diff --git a/context_chat_backend/controller.py b/context_chat_backend/controller.py index 3e70ee1b..580416f7 100644 --- a/context_chat_backend/controller.py +++ b/context_chat_backend/controller.py @@ -6,7 +6,7 @@ # isort: off from .chain.types import ContextException, LLMOutput, ScopeType, SearchResult from .types import LoaderException, EmbeddingException -from .vectordb.types import DbException, SafeDbException, UpdateAccessOp +from .vectordb.types import DbException, SafeDbException from .setup_functions import ensure_config_file, repair_run, setup_env_vars # setup env vars before importing other modules @@ -25,9 +25,9 @@ from functools import wraps from threading import Event, Thread from time import sleep -from typing import Annotated, Any +from typing import Any -from fastapi import Body, FastAPI, Request +from fastapi import FastAPI, Request from langchain.llms.base import LLM from nc_py_api import AsyncNextcloudApp, NextcloudApp from nc_py_api.ex_app import persistent_storage, set_handlers @@ -40,16 +40,9 @@ from .dyn_loader import LLMModelLoader, VectorDBLoader from .models.types import LlmException from nc_py_api.ex_app import AppAPIAuthMiddleware -from .utils import JSONResponse, exec_in_proc, is_valid_provider_id, is_valid_source_id, value_of +from .utils import JSONResponse, exec_in_proc, value_of from .task_fetcher import start_bg_threads, wait_for_bg_threads -from .vectordb.service import ( - 
count_documents_by_provider, - decl_update_access, - delete_by_provider, - delete_by_source, - delete_user, - update_access, -) +from .vectordb.service import count_documents_by_provider # setup @@ -227,119 +220,131 @@ def _(): return JSONResponse(counts) -@app.post('/updateAccessDeclarative') -@enabled_guard(app) -def _( - userIds: Annotated[list[str], Body()], - sourceId: Annotated[str, Body()], -): - logger.debug('Update access declarative request:', extra={ - 'user_ids': userIds, - 'source_id': sourceId, - }) +@app.get('/downloadLogs') +def download_logs() -> FileResponse: + with tempfile.NamedTemporaryFile('wb', delete=False) as tmp: + with zipfile.ZipFile(tmp, mode='w', compression=zipfile.ZIP_DEFLATED) as zip_file: + files = os.listdir(os.path.join(persistent_storage(), 'logs')) + for file in files: + file_path = os.path.join(persistent_storage(), 'logs', file) + if os.path.isfile(file_path): # Might be a folder (just skip it then) + zip_file.write(file_path) + return FileResponse(tmp.name, media_type='application/zip', filename='docker_logs.zip') - if len(userIds) == 0: - return JSONResponse('Empty list of user ids', 400) - if not is_valid_source_id(sourceId): - return JSONResponse('Invalid source id', 400) +# @app.post('/updateAccessDeclarative') +# @enabled_guard(app) +# def _( +# userIds: Annotated[list[str], Body()], +# sourceId: Annotated[str, Body()], +# ): +# logger.debug('Update access declarative request:', extra={ +# 'user_ids': userIds, +# 'source_id': sourceId, +# }) - exec_in_proc(target=decl_update_access, args=(vectordb_loader, userIds, sourceId)) +# if len(userIds) == 0: +# return JSONResponse('Empty list of user ids', 400) - return JSONResponse('Access updated') +# if not is_valid_source_id(sourceId): +# return JSONResponse('Invalid source id', 400) +# exec_in_proc(target=decl_update_access, args=(vectordb_loader, userIds, sourceId)) -@app.post('/updateAccess') -@enabled_guard(app) -def _( - op: Annotated[UpdateAccessOp, Body()], - userIds: 
Annotated[list[str], Body()], - sourceId: Annotated[str, Body()], -): - logger.debug('Update access request', extra={ - 'op': op, - 'user_ids': userIds, - 'source_id': sourceId, - }) +# return JSONResponse('Access updated') - if len(userIds) == 0: - return JSONResponse('Empty list of user ids', 400) - if not is_valid_source_id(sourceId): - return JSONResponse('Invalid source id', 400) +# @app.post('/updateAccess') +# @enabled_guard(app) +# def _( +# op: Annotated[UpdateAccessOp, Body()], +# userIds: Annotated[list[str], Body()], +# sourceId: Annotated[str, Body()], +# ): +# logger.debug('Update access request', extra={ +# 'op': op, +# 'user_ids': userIds, +# 'source_id': sourceId, +# }) - exec_in_proc(target=update_access, args=(vectordb_loader, op, userIds, sourceId)) +# if len(userIds) == 0: +# return JSONResponse('Empty list of user ids', 400) - return JSONResponse('Access updated') +# if not is_valid_source_id(sourceId): +# return JSONResponse('Invalid source id', 400) +# exec_in_proc(target=update_access, args=(vectordb_loader, op, userIds, sourceId)) -@app.post('/updateAccessProvider') -@enabled_guard(app) -def _( - op: Annotated[UpdateAccessOp, Body()], - userIds: Annotated[list[str], Body()], - providerId: Annotated[str, Body()], -): - logger.debug('Update access by provider request', extra={ - 'op': op, - 'user_ids': userIds, - 'provider_id': providerId, - }) +# return JSONResponse('Access updated') - if len(userIds) == 0: - return JSONResponse('Empty list of user ids', 400) - if not is_valid_provider_id(providerId): - return JSONResponse('Invalid provider id', 400) +# @app.post('/updateAccessProvider') +# @enabled_guard(app) +# def _( +# op: Annotated[UpdateAccessOp, Body()], +# userIds: Annotated[list[str], Body()], +# providerId: Annotated[str, Body()], +# ): +# logger.debug('Update access by provider request', extra={ +# 'op': op, +# 'user_ids': userIds, +# 'provider_id': providerId, +# }) - exec_in_proc(target=update_access, args=(vectordb_loader, op, 
userIds, providerId)) +# if len(userIds) == 0: +# return JSONResponse('Empty list of user ids', 400) - return JSONResponse('Access updated') +# if not is_valid_provider_id(providerId): +# return JSONResponse('Invalid provider id', 400) +# exec_in_proc(target=update_access_provider, args=(vectordb_loader, op, userIds, providerId)) -@app.post('/deleteSources') -@enabled_guard(app) -def _(sourceIds: Annotated[list[str], Body(embed=True)]): - logger.debug('Delete sources request', extra={ - 'source_ids': sourceIds, - }) +# return JSONResponse('Access updated') - sourceIds = [source.strip() for source in sourceIds if source.strip() != ''] - if len(sourceIds) == 0: - return JSONResponse('No sources provided', 400) +# @app.post('/deleteSources') +# @enabled_guard(app) +# def _(sourceIds: Annotated[list[str], Body(embed=True)]): +# logger.debug('Delete sources request', extra={ +# 'source_ids': sourceIds, +# }) - res = exec_in_proc(target=delete_by_source, args=(vectordb_loader, sourceIds)) - if res is False: - return JSONResponse('Error: VectorDB delete failed, check vectordb logs for more info.', 400) +# sourceIds = [source.strip() for source in sourceIds if source.strip() != ''] - return JSONResponse('All valid sources deleted') +# if len(sourceIds) == 0: +# return JSONResponse('No sources provided', 400) +# res = exec_in_proc(target=delete_by_source, args=(vectordb_loader, sourceIds)) +# if res is False: +# return JSONResponse('Error: VectorDB delete failed, check vectordb logs for more info.', 400) -@app.post('/deleteProvider') -@enabled_guard(app) -def _(providerKey: str = Body(embed=True)): - logger.debug('Delete sources by provider for all users request', extra={ 'provider_key': providerKey }) +# return JSONResponse('All valid sources deleted') - if value_of(providerKey) is None: - return JSONResponse('Invalid provider key provided', 400) - exec_in_proc(target=delete_by_provider, args=(vectordb_loader, providerKey)) +# @app.post('/deleteProvider') +# 
@enabled_guard(app) +# def _(providerKey: str = Body(embed=True)): +# logger.debug('Delete sources by provider for all users request', extra={ 'provider_key': providerKey }) - return JSONResponse('All valid sources deleted') +# if value_of(providerKey) is None: +# return JSONResponse('Invalid provider key provided', 400) +# exec_in_proc(target=delete_by_provider, args=(vectordb_loader, providerKey)) -@app.post('/deleteUser') -@enabled_guard(app) -def _(userId: str = Body(embed=True)): - logger.debug('Remove access list for user, and orphaned sources', extra={ 'user_id': userId }) +# return JSONResponse('All valid sources deleted') - if value_of(userId) is None: - return JSONResponse('Invalid userId provided', 400) - exec_in_proc(target=delete_user, args=(vectordb_loader, userId)) +# @app.post('/deleteUser') +# @enabled_guard(app) +# def _(userId: str = Body(embed=True)): +# logger.debug('Remove access list for user, and orphaned sources', extra={ 'user_id': userId }) + +# if value_of(userId) is None: +# return JSONResponse('Invalid userId provided', 400) - return JSONResponse('User deleted') +# exec_in_proc(target=delete_user, args=(vectordb_loader, userId)) + +# return JSONResponse('User deleted') # @app.put('/loadSources') @@ -503,15 +508,3 @@ def _(query: Query) -> list[SearchResult]: query.scopeType, query.scopeList, )) - - -@app.get('/downloadLogs') -def download_logs() -> FileResponse: - with tempfile.NamedTemporaryFile('wb', delete=False) as tmp: - with zipfile.ZipFile(tmp, mode='w', compression=zipfile.ZIP_DEFLATED) as zip_file: - files = os.listdir(os.path.join(persistent_storage(), 'logs')) - for file in files: - file_path = os.path.join(persistent_storage(), 'logs', file) - if os.path.isfile(file_path): # Might be a folder (just skip it then) - zip_file.write(file_path) - return FileResponse(tmp.name, media_type='application/zip', filename='docker_logs.zip') diff --git a/context_chat_backend/task_fetcher.py b/context_chat_backend/task_fetcher.py index 
cfa9293c..84b974b2 100644 --- a/context_chat_backend/task_fetcher.py +++ b/context_chat_backend/task_fetcher.py @@ -19,9 +19,11 @@ from .chain.ingest.injest import embed_sources from .dyn_loader import VectorDBLoader from .types import ( + ActionsQueueItems, + ActionType, AppRole, EmbeddingException, - FilesQueueItem, + FilesQueueItems, IndexingError, IndexingException, LoaderException, @@ -30,7 +32,15 @@ TConfig, ) from .utils import exec_in_proc, get_app_role -from .vectordb.types import DbException +from .vectordb.service import ( + decl_update_access, + delete_by_provider, + delete_by_source, + delete_user, + update_access, + update_access_provider, +) +from .vectordb.types import DbException, SafeDbException APP_ROLE = get_app_role() THREADS = {} @@ -41,6 +51,8 @@ # max concurrent fetches to avoid overloading the NC server or hitting rate limits CONCURRENT_FILE_FETCHES = 10 # todo: config? MAX_FILE_SIZE = 100 * 1024 * 1024 # 100 MB, todo: config? +ACTIONS_BATCH_SIZE = 512 # todo: config? +POLLING_COOLDOWN = 30 class ThreadType(Enum): @@ -201,10 +213,15 @@ def _load_sources(source_items: dict[int, SourceItem]) -> dict[int, IndexingErro ) try: - q_items = FilesQueueItem.model_validate(q_items_res) + q_items: FilesQueueItems = FilesQueueItems.model_validate(q_items_res) except ValidationError as e: raise Exception(f'Error validating queue items response: {e}\nResponse content: {q_items_res}') from e + if not q_items.files and not q_items.content_providers: + LOGGER.debug('No documents to index') + sleep(POLLING_COOLDOWN) + continue + # populate files content and convert to source items fetched_files = {} source_files = {} @@ -305,7 +322,165 @@ def _load_sources(source_items: dict[int, SourceItem]) -> dict[int, IndexingErro def updates_processing_thread(app_config: TConfig, app_enabled: Event) -> None: - ... 
+ try: + vectordb_loader = VectorDBLoader(app_config) + except LoaderException as e: + LOGGER.error('Error initializing vector DB loader, files indexing thread will not start:', exc_info=e) + return + + while True: + if not app_enabled.is_set(): + LOGGER.info('Files indexing thread is stopping as the app is disabled') + return + + try: + nc = NextcloudApp() + q_items_res = nc.ocs( + 'GET', + '/apps/context_chat/queues/actions', + params={ 'n': ACTIONS_BATCH_SIZE } + ) + + try: + q_items: ActionsQueueItems = ActionsQueueItems.model_validate(q_items_res) + except ValidationError as e: + raise Exception(f'Error validating queue items response: {e}\nResponse content: {q_items_res}') from e + except ( + niquests.exceptions.ConnectionError, + niquests.exceptions.Timeout, + ) as e: + LOGGER.info('Temporary error fetching updates to process, will retry:', exc_info=e) + sleep(5) + continue + except Exception as e: + LOGGER.exception('Error fetching updates to process:', exc_info=e) + sleep(5) + continue + + if not q_items.actions: + LOGGER.debug('No updates to process') + sleep(POLLING_COOLDOWN) + continue + + processed_event_ids = [] + errored_events = {} + for i, (db_id, action_item) in enumerate(q_items.actions.items()): + try: + match action_item.type: + case ActionType.DELETE_SOURCE_IDS: + exec_in_proc(target=delete_by_source, args=(vectordb_loader, action_item.payload.sourceIds)) + + case ActionType.DELETE_PROVIDER_ID: + exec_in_proc(target=delete_by_provider, args=(vectordb_loader, action_item.payload.providerId)) + + case ActionType.DELETE_USER_ID: + exec_in_proc(target=delete_user, args=(vectordb_loader, action_item.payload.userId)) + + case ActionType.UPDATE_ACCESS_SOURCE_ID: + exec_in_proc( + target=update_access, + args=( + vectordb_loader, + action_item.payload.op, + action_item.payload.userIds, + action_item.payload.sourceId, + ), + ) + + case ActionType.UPDATE_ACCESS_PROVIDER_ID: + exec_in_proc( + target=update_access_provider, + args=( + vectordb_loader, + 
action_item.payload.op, + action_item.payload.userIds, + action_item.payload.providerId, + ), + ) + + case ActionType.UPDATE_ACCESS_DECL_SOURCE_ID: + exec_in_proc( + target=decl_update_access, + args=( + vectordb_loader, + action_item.payload.userIds, + action_item.payload.sourceId, + ), + ) + + case _: + LOGGER.warning( + f'Unknown action type {action_item.type} for action id {db_id},' + f' type {action_item.type}, skipping and marking as processed', + extra={ 'action_item': action_item }, + ) + continue + + processed_event_ids.append(db_id) + except SafeDbException as e: + LOGGER.debug( + f'Safe DB error thrown while processing action id {db_id}, type {action_item.type},' + " it's safe to ignore and mark as processed.", + exc_info=e, + extra={ 'action_item': action_item }, + ) + processed_event_ids.append(db_id) + continue + + except (LoaderException, DbException) as e: + LOGGER.error( + f'Error deleting source for action id {db_id}, type {action_item.type}: {e}', + exc_info=e, + extra={ 'action_item': action_item }, + ) + errored_events[db_id] = str(e) + continue + + except Exception as e: + LOGGER.error( + f'Unexpected error processing action id {db_id}, type {action_item.type}: {e}', + exc_info=e, + extra={ 'action_item': action_item }, + ) + errored_events[db_id] = f'Unexpected error: {e}' + continue + + if (i + 1) % 20 == 0: + LOGGER.debug(f'Processed {i + 1} updates, sleeping for a bit to allow other operations to proceed') + sleep(2) + + LOGGER.info(f'Processed {len(processed_event_ids)} updates with {len(errored_events)} errors', extra={ + 'errored_events': errored_events, + }) + + if len(processed_event_ids) == 0: + LOGGER.debug('No updates processed, skipping reporting to the server') + continue + + try: + nc.ocs( + 'DELETE', + '/apps/context_chat/queues/actions/', + json={ 'actions': processed_event_ids }, + ) + except ( + niquests.exceptions.ConnectionError, + niquests.exceptions.Timeout, + ) as e: + LOGGER.info('Temporary error reporting processed 
updates, will retry:', exc_info=e) + sleep(5) + with suppress(Exception): + nc = NextcloudApp() + nc.ocs( + 'DELETE', + '/apps/context_chat/queues/actions/', + json={ 'ids': processed_event_ids }, + ) + continue + except Exception as e: + LOGGER.exception('Error reporting processed updates:', exc_info=e) + sleep(5) + continue def request_processing_thread(app_config: TConfig, app_enabled: Event) -> None: diff --git a/context_chat_backend/types.py b/context_chat_backend/types.py index 97d48ce6..849c2e31 100644 --- a/context_chat_backend/types.py +++ b/context_chat_backend/types.py @@ -4,12 +4,13 @@ # from enum import Enum from io import BytesIO -from typing import Self +from typing import Annotated, Literal, Self -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, Discriminator, field_validator from .mimetype_list import SUPPORTED_MIMETYPES from .utils import is_valid_provider_id, is_valid_source_id +from .vectordb.types import UpdateAccessOp __all__ = [ 'DEFAULT_EM_MODEL_ALIAS', @@ -182,7 +183,7 @@ def validate_content(cls, v): raise ValueError('Content must be either a non-empty string or a non-empty BytesIO') -class FilesQueueItem(BaseModel): +class FilesQueueItems(BaseModel): files: dict[int, ReceivedFileItem] # [db id]: FileItem content_providers: dict[int, SourceItem] # [db id]: SourceItem @@ -198,3 +199,179 @@ def __init__(self, message: str, retryable: bool = False): class IndexingError(BaseModel): error: str retryable: bool = False + + +# PHP equivalent for reference: + +# class ActionType { +# // { sourceIds: array } +# public const DELETE_SOURCE_IDS = 'delete_source_ids'; +# // { providerId: string } +# public const DELETE_PROVIDER_ID = 'delete_provider_id'; +# // { userId: string } +# public const DELETE_USER_ID = 'delete_user_id'; +# // { op: string, userIds: array, sourceId: string } +# public const UPDATE_ACCESS_SOURCE_ID = 'update_access_source_id'; +# // { op: string, userIds: array, providerId: string } +# public const 
UPDATE_ACCESS_PROVIDER_ID = 'update_access_provider_id'; +# // { userIds: array, sourceId: string } +# public const UPDATE_ACCESS_DECL_SOURCE_ID = 'update_access_decl_source_id'; +# } + + +def _validate_source_ids(source_ids: list[str]) -> list[str]: + if ( + not isinstance(source_ids, list) + or not all(isinstance(sid, str) and sid.strip() != '' for sid in source_ids) + or len(source_ids) == 0 + ): + raise ValueError('sourceIds must be a non-empty list of non-empty strings') + return [sid.strip() for sid in source_ids] + + +def _validate_provider_id(provider_id: str) -> str: + if not isinstance(provider_id, str) or not is_valid_provider_id(provider_id): + raise ValueError('providerId must be a valid provider ID string') + return provider_id + + +def _validate_user_ids(user_ids: list[str]) -> list[str]: + if ( + not isinstance(user_ids, list) + or not all(isinstance(uid, str) and uid.strip() != '' for uid in user_ids) + or len(user_ids) == 0 + ): + raise ValueError('userIds must be a non-empty list of non-empty strings') + return [uid.strip() for uid in user_ids] + + +class ActionPayloadDeleteSourceIds(BaseModel): + sourceIds: list[str] + + @field_validator('sourceIds', mode='after') + def validate_source_ids(self) -> Self: + self.sourceIds = _validate_source_ids(self.sourceIds) + return self + + +class ActionPayloadDeleteProviderId(BaseModel): + providerId: str + + @field_validator('providerId') + def validate_provider_id(self) -> Self: + self.providerId = _validate_provider_id(self.providerId) + return self + + +class ActionPayloadDeleteUserId(BaseModel): + userId: str + + @field_validator('userId') + def validate_user_id(self) -> Self: + self.userId = _validate_user_ids([self.userId])[0] + return self + + +class ActionPayloadUpdateAccessSourceId(BaseModel): + op: UpdateAccessOp + userIds: list[str] + sourceId: str + + @field_validator('userIds', mode='after') + def validate_user_ids(self) -> Self: + self.userIds = _validate_user_ids(self.userIds) + return self + 
+ @field_validator('sourceId') + def validate_source_id(self) -> Self: + self.sourceId = _validate_source_ids([self.sourceId])[0] + return self + + +class ActionPayloadUpdateAccessProviderId(BaseModel): + op: UpdateAccessOp + userIds: list[str] + providerId: str + + @field_validator('userIds', mode='after') + def validate_user_ids(self) -> Self: + self.userIds = _validate_user_ids(self.userIds) + return self + + @field_validator('providerId') + def validate_provider_id(self) -> Self: + self.providerId = _validate_provider_id(self.providerId) + return self + + +class ActionPayloadUpdateAccessDeclSourceId(BaseModel): + userIds: list[str] + sourceId: str + + @field_validator('userIds', mode='after') + def validate_user_ids(self) -> Self: + self.userIds = _validate_user_ids(self.userIds) + return self + + @field_validator('sourceId') + def validate_source_id(self) -> Self: + self.sourceId = _validate_source_ids([self.sourceId])[0] + return self + + +class ActionType(str, Enum): + DELETE_SOURCE_IDS = 'delete_source_ids' + DELETE_PROVIDER_ID = 'delete_provider_id' + DELETE_USER_ID = 'delete_user_id' + UPDATE_ACCESS_SOURCE_ID = 'update_access_source_id' + UPDATE_ACCESS_PROVIDER_ID = 'update_access_provider_id' + UPDATE_ACCESS_DECL_SOURCE_ID = 'update_access_decl_source_id' + + +class CommonActionsQueueItem(BaseModel): + id: int + + +class ActionsQueueItemDeleteSourceIds(CommonActionsQueueItem): + type: Literal[ActionType.DELETE_SOURCE_IDS] + payload: ActionPayloadDeleteSourceIds + + +class ActionsQueueItemDeleteProviderId(CommonActionsQueueItem): + type: Literal[ActionType.DELETE_PROVIDER_ID] + payload: ActionPayloadDeleteProviderId + + +class ActionsQueueItemDeleteUserId(CommonActionsQueueItem): + type: Literal[ActionType.DELETE_USER_ID] + payload: ActionPayloadDeleteUserId + + +class ActionsQueueItemUpdateAccessSourceId(CommonActionsQueueItem): + type: Literal[ActionType.UPDATE_ACCESS_SOURCE_ID] + payload: ActionPayloadUpdateAccessSourceId + + +class 
ActionsQueueItemUpdateAccessProviderId(CommonActionsQueueItem): + type: Literal[ActionType.UPDATE_ACCESS_PROVIDER_ID] + payload: ActionPayloadUpdateAccessProviderId + + +class ActionsQueueItemUpdateAccessDeclSourceId(CommonActionsQueueItem): + type: Literal[ActionType.UPDATE_ACCESS_DECL_SOURCE_ID] + payload: ActionPayloadUpdateAccessDeclSourceId + + +ActionsQueueItem = Annotated[ + ActionsQueueItemDeleteSourceIds + | ActionsQueueItemDeleteProviderId + | ActionsQueueItemDeleteUserId + | ActionsQueueItemUpdateAccessSourceId + | ActionsQueueItemUpdateAccessProviderId + | ActionsQueueItemUpdateAccessDeclSourceId, + Discriminator('type'), +] + + +class ActionsQueueItems(BaseModel): + actions: dict[int, ActionsQueueItem] diff --git a/context_chat_backend/vectordb/pgvector.py b/context_chat_backend/vectordb/pgvector.py index 5a5ded35..8b0c864b 100644 --- a/context_chat_backend/vectordb/pgvector.py +++ b/context_chat_backend/vectordb/pgvector.py @@ -324,7 +324,7 @@ def update_access( ) match op: - case UpdateAccessOp.allow: + case UpdateAccessOp.ALLOW: stmt = ( postgresql_dialects.insert(AccessListStore) .values([ @@ -339,7 +339,7 @@ def update_access( session.execute(stmt) session.commit() - case UpdateAccessOp.deny: + case UpdateAccessOp.DENY: stmt = ( sa.delete(AccessListStore) .filter(AccessListStore.uid.in_(user_ids)) @@ -427,11 +427,12 @@ def delete_source_ids(self, source_ids: list[str], session_: orm.Session | None stmt_doc = ( sa.delete(DocumentsStore) .filter(DocumentsStore.source_id.in_(source_ids)) - .returning(DocumentsStore.chunks) + .returning(DocumentsStore.chunks, DocumentsStore.source_id) ) doc_result = session.execute(stmt_doc) chunks_to_delete = [str(c) for res in doc_result for c in res.chunks] + deleted_source_ids = [str(res.source_id) for res in doc_result] except Exception as e: session.rollback() if session_ is None: @@ -457,6 +458,14 @@ def delete_source_ids(self, source_ids: list[str], session_: orm.Session | None if session_ is None: 
session.close() + undeleted_source_ids = set(source_ids) - set(deleted_source_ids) + if len(undeleted_source_ids) > 0: + logger.info( + f'Source ids {undeleted_source_ids} were not deleted from documents store.' + ' This can be due to the source ids not existing in the documents store due to' + ' already being deleted or not having been added yet.' + ) + def delete_provider(self, provider_key: str): with self.session_maker() as session: try: @@ -470,6 +479,10 @@ def delete_provider(self, provider_key: str): doc_result = session.execute(stmt) chunks_to_delete = [str(c) for res in doc_result for c in res.chunks] + + if len(chunks_to_delete) == 0: + logger.info(f'No documents found for provider {provider_key} when attempting to delete provider.') + return except Exception as e: session.rollback() raise DbException('Error: deleting provider from docs store') from e @@ -503,7 +516,16 @@ def delete_user(self, user_id: str): session.rollback() raise DbException('Error: deleting user from access list') from e - self._cleanup_if_orphaned(list(source_ids), session) + try: + self._cleanup_if_orphaned(list(source_ids), session) + except Exception as e: + session.rollback() + logger.error( + 'Error cleaning up orphaned source ids after deleting user, manual cleanup might be required', + exc_info=e, + extra={ 'source_ids': list(source_ids) }, + ) + raise DbException('Error: cleaning up orphaned source ids after deleting user') from e def count_documents_by_provider(self) -> dict[str, int]: try: diff --git a/context_chat_backend/vectordb/service.py b/context_chat_backend/vectordb/service.py index 620a0b39..06a8e19e 100644 --- a/context_chat_backend/vectordb/service.py +++ b/context_chat_backend/vectordb/service.py @@ -6,27 +6,42 @@ from ..dyn_loader import VectorDBLoader from .base import BaseVectorDB -from .types import DbException, UpdateAccessOp +from .types import UpdateAccessOp logger = logging.getLogger('ccb.vectordb') -# todo: return source ids that were successfully 
deleted + def delete_by_source(vectordb_loader: VectorDBLoader, source_ids: list[str]): + ''' + Raises + ------ + DbException + LoaderException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug('deleting sources by id', extra={ 'source_ids': source_ids }) - try: - db.delete_source_ids(source_ids) - except Exception as e: - raise DbException('Error: Vectordb delete_source_ids error') from e + db.delete_source_ids(source_ids) def delete_by_provider(vectordb_loader: VectorDBLoader, provider_key: str): + ''' + Raises + ------ + DbException + LoaderException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug(f'deleting sources by provider: {provider_key}') db.delete_provider(provider_key) def delete_user(vectordb_loader: VectorDBLoader, user_id: str): + ''' + Raises + ------ + DbException + LoaderException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug(f'deleting user from db: {user_id}') db.delete_user(user_id) @@ -38,6 +53,13 @@ def update_access( user_ids: list[str], source_id: str, ): + ''' + Raises + ------ + DbException + LoaderException + SafeDbException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug('updating access', extra={ 'op': op, 'user_ids': user_ids, 'source_id': source_id }) db.update_access(op, user_ids, source_id) @@ -49,6 +71,13 @@ def update_access_provider( user_ids: list[str], provider_id: str, ): + ''' + Raises + ------ + DbException + LoaderException + SafeDbException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug('updating access by provider', extra={ 'op': op, 'user_ids': user_ids, 'provider_id': provider_id }) db.update_access_provider(op, user_ids, provider_id) @@ -59,11 +88,24 @@ def decl_update_access( user_ids: list[str], source_id: str, ): + ''' + Raises + ------ + DbException + LoaderException + SafeDbException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug('decl update access', extra={ 'user_ids': user_ids, 'source_id': source_id }) 
db.decl_update_access(user_ids, source_id) def count_documents_by_provider(vectordb_loader: VectorDBLoader): + ''' + Raises + ------ + DbException + LoaderException + ''' db: BaseVectorDB = vectordb_loader.load() logger.debug('counting documents by provider') return db.count_documents_by_provider() diff --git a/context_chat_backend/vectordb/types.py b/context_chat_backend/vectordb/types.py index df5c6dd7..30811797 100644 --- a/context_chat_backend/vectordb/types.py +++ b/context_chat_backend/vectordb/types.py @@ -14,5 +14,5 @@ class SafeDbException(Exception): class UpdateAccessOp(Enum): - allow = 'allow' - deny = 'deny' + ALLOW = 'allow' + DENY = 'deny' From 3f2d57d3d277c6e7ab487f9c03ffc98bf4350ba0 Mon Sep 17 00:00:00 2001 From: Anupam Kumar Date: Wed, 11 Mar 2026 14:33:39 +0530 Subject: [PATCH 08/17] fix pydantic types Signed-off-by: Anupam Kumar --- context_chat_backend/types.py | 180 ++++++++++++---------------------- context_chat_backend/utils.py | 10 -- 2 files changed, 64 insertions(+), 126 deletions(-) diff --git a/context_chat_backend/types.py b/context_chat_backend/types.py index 849c2e31..8577c931 100644 --- a/context_chat_backend/types.py +++ b/context_chat_backend/types.py @@ -2,14 +2,14 @@ # SPDX-FileCopyrightText: 2024 Nextcloud GmbH and Nextcloud contributors # SPDX-License-Identifier: AGPL-3.0-or-later # +import re from enum import Enum from io import BytesIO from typing import Annotated, Literal, Self -from pydantic import BaseModel, Discriminator, field_validator +from pydantic import AfterValidator, BaseModel, Discriminator, field_validator, model_validator from .mimetype_list import SUPPORTED_MIMETYPES -from .utils import is_valid_provider_id, is_valid_source_id from .vectordb.types import UpdateAccessOp __all__ = [ @@ -26,6 +26,49 @@ FILES_PROVIDER_ID = 'files__default' +def is_valid_source_id(source_id: str) -> bool: + # note the ":" in the item id part + return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+: [a-zA-Z0-9:-]+$', source_id) is not 
None + + +def is_valid_provider_id(provider_id: str) -> bool: + return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+$', provider_id) is not None + + +def _validate_source_ids(source_ids: list[str]) -> list[str]: + if ( + not isinstance(source_ids, list) + or not all(isinstance(sid, str) and sid.strip() != '' for sid in source_ids) + or len(source_ids) == 0 + ): + raise ValueError('sourceIds must be a non-empty list of non-empty strings') + return [sid.strip() for sid in source_ids] + + +def _validate_source_id(source_id: str) -> str: + return _validate_source_ids([source_id])[0] + + +def _validate_provider_id(provider_id: str) -> str: + if not isinstance(provider_id, str) or not is_valid_provider_id(provider_id): + raise ValueError('providerId must be a valid provider ID string') + return provider_id + + +def _validate_user_ids(user_ids: list[str]) -> list[str]: + if ( + not isinstance(user_ids, list) + or not all(isinstance(uid, str) and uid.strip() != '' for uid in user_ids) + or len(user_ids) == 0 + ): + raise ValueError('userIds must be a non-empty list of non-empty strings') + return [uid.strip() for uid in user_ids] + + +def _validate_user_id(user_id: str) -> str: + return _validate_user_ids([user_id])[0] + + class TEmbeddingAuthApiKey(BaseModel): apikey: str @@ -89,12 +132,13 @@ class AppRole(str, Enum): class CommonSourceItem(BaseModel): - userIds: list[str] - reference: str # source_id of the form "appId__providerId: itemId" + userIds: Annotated[list[str], AfterValidator(_validate_user_ids)] + # source_id of the form "appId__providerId: itemId" + reference: Annotated[str, AfterValidator(_validate_source_id)] title: str modified: int | str # todo: int/string? 
type: str - provider: str + provider: Annotated[str, AfterValidator(_validate_provider_id)] size: int @field_validator('modified', mode='before') @@ -116,42 +160,13 @@ def validate_strings_non_empty(cls, v): raise ValueError('Must be a non-empty string') return v.strip() - @field_validator('userIds', mode='after') - def validate_user_ids(self) -> Self: - if ( - not isinstance(self.userIds, list) - or not all( - isinstance(uid, str) - and uid.strip() != '' - for uid in self.userIds - ) - or len(self.userIds) == 0 - ): - raise ValueError('userIds must be a non-empty list of non-empty strings') - self.userIds = [uid.strip() for uid in self.userIds] - return self - - @field_validator('reference', mode='after') - def validate_reference_format(self) -> Self: - # validate reference format: "appId__providerId: itemId" - if not is_valid_source_id(self.reference): - raise ValueError('Invalid reference format, must be "appId__providerId: itemId"') - return self - - @field_validator('provider', mode='after') - def validate_provider_format(self) -> Self: - # validate provider format: "appId__providerId" - if not is_valid_provider_id(self.provider): - raise ValueError('Invalid provider format, must be "appId__providerId"') - return self - - @field_validator('type', mode='after') + @model_validator(mode='after') def validate_type(self) -> Self: if self.reference.startswith(FILES_PROVIDER_ID) and self.type not in SUPPORTED_MIMETYPES: raise ValueError(f'Unsupported file type: {self.type} for reference {self.reference}') return self - @field_validator('size', mode='after') + @model_validator(mode='after') def validate_size(self) -> Self: if not isinstance(self.size, int) or self.size < 0: raise ValueError(f'Invalid size value: {self.size}, must be a non-negative integer') @@ -182,6 +197,10 @@ def validate_content(cls, v): return v raise ValueError('Content must be either a non-empty string or a non-empty BytesIO') + class Config: + # to allow BytesIO in content field + 
arbitrary_types_allowed = True + class FilesQueueItems(BaseModel): files: dict[int, ReceivedFileItem] # [db id]: FileItem @@ -219,104 +238,33 @@ class IndexingError(BaseModel): # } -def _validate_source_ids(source_ids: list[str]) -> list[str]: - if ( - not isinstance(source_ids, list) - or not all(isinstance(sid, str) and sid.strip() != '' for sid in source_ids) - or len(source_ids) == 0 - ): - raise ValueError('sourceIds must be a non-empty list of non-empty strings') - return [sid.strip() for sid in source_ids] - - -def _validate_provider_id(provider_id: str) -> str: - if not isinstance(provider_id, str) or not is_valid_provider_id(provider_id): - raise ValueError('providerId must be a valid provider ID string') - return provider_id - - -def _validate_user_ids(user_ids: list[str]) -> list[str]: - if ( - not isinstance(user_ids, list) - or not all(isinstance(uid, str) and uid.strip() != '' for uid in user_ids) - or len(user_ids) == 0 - ): - raise ValueError('userIds must be a non-empty list of non-empty strings') - return [uid.strip() for uid in user_ids] - - class ActionPayloadDeleteSourceIds(BaseModel): - sourceIds: list[str] - - @field_validator('sourceIds', mode='after') - def validate_source_ids(self) -> Self: - self.sourceIds = _validate_source_ids(self.sourceIds) - return self + sourceIds: Annotated[list[str], AfterValidator(_validate_source_ids)] class ActionPayloadDeleteProviderId(BaseModel): - providerId: str - - @field_validator('providerId') - def validate_provider_id(self) -> Self: - self.providerId = _validate_provider_id(self.providerId) - return self + providerId: Annotated[str, AfterValidator(_validate_provider_id)] class ActionPayloadDeleteUserId(BaseModel): - userId: str - - @field_validator('userId') - def validate_user_id(self) -> Self: - self.userId = _validate_user_ids([self.userId])[0] - return self + userId: Annotated[str, AfterValidator(_validate_user_id)] class ActionPayloadUpdateAccessSourceId(BaseModel): op: UpdateAccessOp - userIds: 
list[str] - sourceId: str - - @field_validator('userIds', mode='after') - def validate_user_ids(self) -> Self: - self.userIds = _validate_user_ids(self.userIds) - return self - - @field_validator('sourceId') - def validate_source_id(self) -> Self: - self.sourceId = _validate_source_ids([self.sourceId])[0] - return self + userIds: Annotated[list[str], AfterValidator(_validate_user_ids)] + sourceId: Annotated[str, AfterValidator(_validate_source_id)] class ActionPayloadUpdateAccessProviderId(BaseModel): op: UpdateAccessOp - userIds: list[str] - providerId: str - - @field_validator('userIds', mode='after') - def validate_user_ids(self) -> Self: - self.userIds = _validate_user_ids(self.userIds) - return self - - @field_validator('providerId') - def validate_provider_id(self) -> Self: - self.providerId = _validate_provider_id(self.providerId) - return self + userIds: Annotated[list[str], AfterValidator(_validate_user_ids)] + providerId: Annotated[str, AfterValidator(_validate_provider_id)] class ActionPayloadUpdateAccessDeclSourceId(BaseModel): - userIds: list[str] - sourceId: str - - @field_validator('userIds', mode='after') - def validate_user_ids(self) -> Self: - self.userIds = _validate_user_ids(self.userIds) - return self - - @field_validator('sourceId') - def validate_source_id(self) -> Self: - self.sourceId = _validate_source_ids([self.sourceId])[0] - return self + userIds: Annotated[list[str], AfterValidator(_validate_user_ids)] + sourceId: Annotated[str, AfterValidator(_validate_source_id)] class ActionType(str, Enum): diff --git a/context_chat_backend/utils.py b/context_chat_backend/utils.py index 224f466e..c7e588b3 100644 --- a/context_chat_backend/utils.py +++ b/context_chat_backend/utils.py @@ -5,7 +5,6 @@ import logging import multiprocessing as mp import os -import re import traceback from collections.abc import Callable from functools import partial, wraps @@ -102,15 +101,6 @@ def exec_in_proc(group=None, target=None, name=None, args=(), kwargs={}, *, 
daem return result['value'] -def is_valid_source_id(source_id: str) -> bool: - # note the ":" in the item id part - return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+: [a-zA-Z0-9:-]+$', source_id) is not None - - -def is_valid_provider_id(provider_id: str) -> bool: - return re.match(r'^[a-zA-Z0-9_-]+__[a-zA-Z0-9_-]+$', provider_id) is not None - - def timed(func: Callable): ''' Decorator to time a function From fc5ca61a9a0be4e84584aea96ddca4fbdc7d7289 Mon Sep 17 00:00:00 2001 From: Anupam Kumar Date: Wed, 11 Mar 2026 14:34:55 +0530 Subject: [PATCH 09/17] fix: use a dedicated event to allow app halt without app being disabled Signed-off-by: Anupam Kumar --- context_chat_backend/controller.py | 1 + context_chat_backend/task_fetcher.py | 28 ++++++++++++++++++++++++---- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/context_chat_backend/controller.py b/context_chat_backend/controller.py index 580416f7..55206ca0 100644 --- a/context_chat_backend/controller.py +++ b/context_chat_backend/controller.py @@ -88,6 +88,7 @@ async def lifespan(app: FastAPI): nc = NextcloudApp() if nc.enabled_state: app_enabled.set() + start_bg_threads(app_config, app_enabled) logger.info(f'App enable state at startup: {app_enabled.is_set()}') t = Thread(target=background_thread_task, args=()) t.start() diff --git a/context_chat_backend/task_fetcher.py b/context_chat_backend/task_fetcher.py index 84b974b2..e93eac34 100644 --- a/context_chat_backend/task_fetcher.py +++ b/context_chat_backend/task_fetcher.py @@ -44,6 +44,7 @@ APP_ROLE = get_app_role() THREADS = {} +THREAD_STOP_EVENT = Event() LOGGER = logging.getLogger('ccb.task_fetcher') FILES_INDEXING_BATCH_SIZE = 64 # todo: config? 
# divides the batch into these many chunks @@ -199,8 +200,8 @@ def _load_sources(source_items: dict[int, SourceItem]) -> dict[int, IndexingErro while True: - if not app_enabled.is_set(): - LOGGER.info('Files indexing thread is stopping as the app is disabled') + if THREAD_STOP_EVENT.is_set(): + LOGGER.info('Files indexing thread is stopping due to stop event being set') return try: @@ -329,8 +330,8 @@ def updates_processing_thread(app_config: TConfig, app_enabled: Event) -> None: return while True: - if not app_enabled.is_set(): - LOGGER.info('Files indexing thread is stopping as the app is disabled') + if THREAD_STOP_EVENT.is_set(): + LOGGER.info('Updates processing thread is stopping due to stop event being set') return try: @@ -490,6 +491,14 @@ def request_processing_thread(app_config: TConfig, app_enabled: Event) -> None: def start_bg_threads(app_config: TConfig, app_enabled: Event): match APP_ROLE: case AppRole.INDEXING | AppRole.NORMAL: + if ( + ThreadType.FILES_INDEXING in THREADS + or ThreadType.UPDATES_PROCESSING in THREADS + ): + LOGGER.info('Background threads already running, skipping start') + return + + THREAD_STOP_EVENT.clear() THREADS[ThreadType.FILES_INDEXING] = Thread( target=files_indexing_thread, args=(app_config, app_enabled), @@ -502,7 +511,13 @@ def start_bg_threads(app_config: TConfig, app_enabled: Event): ) THREADS[ThreadType.FILES_INDEXING].start() THREADS[ThreadType.UPDATES_PROCESSING].start() + case AppRole.RP | AppRole.NORMAL: + if ThreadType.REQUEST_PROCESSING in THREADS: + LOGGER.info('Background threads already running, skipping start') + return + + THREAD_STOP_EVENT.clear() THREADS[ThreadType.REQUEST_PROCESSING] = Thread( target=request_processing_thread, args=(app_config, app_enabled), @@ -516,12 +531,17 @@ def wait_for_bg_threads(): case AppRole.INDEXING | AppRole.NORMAL: if (ThreadType.FILES_INDEXING not in THREADS or ThreadType.UPDATES_PROCESSING not in THREADS): return + + THREAD_STOP_EVENT.set() 
THREADS[ThreadType.FILES_INDEXING].join() THREADS[ThreadType.UPDATES_PROCESSING].join() THREADS.pop(ThreadType.FILES_INDEXING) THREADS.pop(ThreadType.UPDATES_PROCESSING) + case AppRole.RP | AppRole.NORMAL: if (ThreadType.REQUEST_PROCESSING not in THREADS): return + + THREAD_STOP_EVENT.set() THREADS[ThreadType.REQUEST_PROCESSING].join() THREADS.pop(ThreadType.REQUEST_PROCESSING) From 23d89b980e1e4a2ac205ce88dfde255f1fb41f3e Mon Sep 17 00:00:00 2001 From: Anupam Kumar Date: Wed, 11 Mar 2026 17:54:48 +0530 Subject: [PATCH 10/17] fix fetch url and pydantic types Signed-off-by: Anupam Kumar --- context_chat_backend/task_fetcher.py | 14 +++++++------- context_chat_backend/types.py | 17 +++++++++-------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/context_chat_backend/task_fetcher.py b/context_chat_backend/task_fetcher.py index e93eac34..5784d12b 100644 --- a/context_chat_backend/task_fetcher.py +++ b/context_chat_backend/task_fetcher.py @@ -80,7 +80,7 @@ async def __fetch_file_content( # a file pointer for storing the stream in memory until it is consumed fp = BytesIO() await nc._session.download2fp( - url_path=f'/apps/context_chat/files/{file_id}', + url_path=f'/ocs/v2.php/apps/context_chat/files/{file_id}', fp=fp, dav=False, params={ 'userId': user_id }, @@ -209,7 +209,7 @@ def _load_sources(source_items: dict[int, SourceItem]) -> dict[int, IndexingErro # todo: add the 'size' param to the return of this call. 
q_items_res = nc.ocs( 'GET', - '/apps/context_chat/queues/documents', + '/ocs/v2.php/apps/context_chat/queues/documents', params={ 'n': FILES_INDEXING_BATCH_SIZE } ) @@ -292,7 +292,7 @@ def _load_sources(source_items: dict[int, SourceItem]) -> dict[int, IndexingErro try: nc.ocs( 'DELETE', - '/apps/context_chat/queues/documents/', + '/ocs/v2.php/apps/context_chat/queues/documents/', json={ 'files': to_delete_file_ids, 'content_providers': to_delete_provider_ids, @@ -308,7 +308,7 @@ def _load_sources(source_items: dict[int, SourceItem]) -> dict[int, IndexingErro nc = NextcloudApp() nc.ocs( 'DELETE', - '/apps/context_chat/queues/documents/', + '/ocs/v2.php/apps/context_chat/queues/documents/', json={ 'files': to_delete_file_ids, 'content_providers': to_delete_provider_ids, @@ -338,7 +338,7 @@ def updates_processing_thread(app_config: TConfig, app_enabled: Event) -> None: nc = NextcloudApp() q_items_res = nc.ocs( 'GET', - '/apps/context_chat/queues/actions', + '/ocs/v2.php/apps/context_chat/queues/actions', params={ 'n': ACTIONS_BATCH_SIZE } ) @@ -461,7 +461,7 @@ def updates_processing_thread(app_config: TConfig, app_enabled: Event) -> None: try: nc.ocs( 'DELETE', - '/apps/context_chat/queues/actions/', + '/ocs/v2.php/apps/context_chat/queues/actions/', json={ 'actions': processed_event_ids }, ) except ( @@ -474,7 +474,7 @@ def updates_processing_thread(app_config: TConfig, app_enabled: Event) -> None: nc = NextcloudApp() nc.ocs( 'DELETE', - '/apps/context_chat/queues/actions/', + '/ocs/v2.php/apps/context_chat/queues/actions/', json={ 'ids': processed_event_ids }, ) continue diff --git a/context_chat_backend/types.py b/context_chat_backend/types.py index 8577c931..972756fa 100644 --- a/context_chat_backend/types.py +++ b/context_chat_backend/types.py @@ -136,10 +136,10 @@ class CommonSourceItem(BaseModel): # source_id of the form "appId__providerId: itemId" reference: Annotated[str, AfterValidator(_validate_source_id)] title: str - modified: int | str # todo: 
int/string? + modified: int type: str provider: Annotated[str, AfterValidator(_validate_provider_id)] - size: int + size: float @field_validator('modified', mode='before') @classmethod @@ -160,18 +160,19 @@ def validate_strings_non_empty(cls, v): raise ValueError('Must be a non-empty string') return v.strip() + @field_validator('size') + @classmethod + def validate_size(cls, v): + if isinstance(v, int | float) and v >= 0: + return float(v) + raise ValueError(f'Invalid size value: {v}, must be a non-negative number') + @model_validator(mode='after') def validate_type(self) -> Self: if self.reference.startswith(FILES_PROVIDER_ID) and self.type not in SUPPORTED_MIMETYPES: raise ValueError(f'Unsupported file type: {self.type} for reference {self.reference}') return self - @model_validator(mode='after') - def validate_size(self) -> Self: - if not isinstance(self.size, int) or self.size < 0: - raise ValueError(f'Invalid size value: {self.size}, must be a non-negative integer') - return self - class ReceivedFileItem(CommonSourceItem): content: None From c8399b581812fd89723fed2a0358ba3a7eb04118 Mon Sep 17 00:00:00 2001 From: Anupam Kumar Date: Wed, 11 Mar 2026 18:52:35 +0530 Subject: [PATCH 11/17] fix: use the correct file id Signed-off-by: Anupam Kumar --- context_chat_backend/controller.py | 9 ++-- context_chat_backend/task_fetcher.py | 79 +++++++++++++++++----------- context_chat_backend/types.py | 22 +++++++- 3 files changed, 75 insertions(+), 35 deletions(-) diff --git a/context_chat_backend/controller.py b/context_chat_backend/controller.py index 55206ca0..797ba201 100644 --- a/context_chat_backend/controller.py +++ b/context_chat_backend/controller.py @@ -24,7 +24,6 @@ from contextlib import asynccontextmanager from functools import wraps from threading import Event, Thread -from time import sleep from typing import Any from fastapi import FastAPI, Request @@ -130,9 +129,11 @@ async def lifespan(app: FastAPI): # logger background thread def background_thread_task(): 
- while(True): - logger.info(f'Currently indexing {len(_indexing)} documents (filename, size): ', extra={'_indexing': _indexing}) - sleep(10) + # todo + # while(True): + # logger.info(f'Currently indexing {len(_indexing)} documents (filename, size): ', extra={'_indexing': _indexing}) + # sleep(10) + ... # exception handlers diff --git a/context_chat_backend/task_fetcher.py b/context_chat_backend/task_fetcher.py index 5784d12b..0442cd53 100644 --- a/context_chat_backend/task_fetcher.py +++ b/context_chat_backend/task_fetcher.py @@ -125,15 +125,29 @@ async def __fetch_files_content( semaphore = asyncio.Semaphore(CONCURRENT_FILE_FETCHES) tasks = [] - for file_id, file_item in files.items(): - if file_item.size > MAX_FILE_SIZE: + for db_id, file in files.items(): + try: + # to detect any validation errors but it should not happen since file.reference is validated + file.file_id # noqa: B018 + except ValueError as e: + LOGGER.error( + f'Invalid file reference format for db id {db_id}, file reference {file.reference}: {e}', + exc_info=e, + ) + source_items[db_id] = IndexingError( + error=f'Invalid file reference format: {file.reference}', + retryable=False, + ) + continue + + if file.size > MAX_FILE_SIZE: LOGGER.info( - f'Skipping file id {file_id}, source id {file_item.reference} due to size' - f' {(file_item.size/(1024*1024)):.2f} MiB exceeding the limit {(MAX_FILE_SIZE/(1024*1024)):.2f} MiB', + f'Skipping db id {db_id}, file id {file.file_id}, source id {file.reference} due to size' + f' {(file.size/(1024*1024)):.2f} MiB exceeding the limit {(MAX_FILE_SIZE/(1024*1024)):.2f} MiB', ) - source_items[file_id] = IndexingError( + source_items[db_id] = IndexingError( error=( - f'File size {(file_item.size/(1024*1024)):.2f} MiB' + f'File size {(file.size/(1024*1024)):.2f} MiB' f' exceeds the limit {(MAX_FILE_SIZE/(1024*1024)):.2f} MiB' ), retryable=False, @@ -141,39 +155,44 @@ async def __fetch_files_content( continue # todo: perform the existing file check before fetching 
the content to avoid unnecessary fetches # any user id from the list should have read access to the file - tasks.append(asyncio.ensure_future(__fetch_file_content(semaphore, file_id, file_item.userIds[0]))) + tasks.append(asyncio.ensure_future(__fetch_file_content(semaphore, file.file_id, file.userIds[0]))) results = await asyncio.gather(*tasks, return_exceptions=True) - for (file_id, file_item), result in zip(files.items(), results, strict=True): + for (db_id, file), result in zip(files.items(), results, strict=True): if isinstance(result, IndexingException): LOGGER.error( - f'Error fetching content for file id {file_id}, reference {file_item.reference}: {result}', + f'Error fetching content for db id {db_id}, file id {file.file_id}, reference {file.reference}' + f': {result}', exc_info=result, ) - source_items[file_id] = IndexingError( + source_items[db_id] = IndexingError( error=str(result), retryable=result.retryable, ) elif isinstance(result, str) or isinstance(result, BytesIO): - source_items[file_id] = SourceItem( - **file_item.model_dump(), - content=result, + source_items[db_id] = SourceItem( + **{ + **file.model_dump(), + 'content': result, + } ) elif isinstance(result, BaseException): LOGGER.error( - f'Unexpected error fetching content for file id {file_id}, reference {file_item.reference}: {result}', + f'Unexpected error fetching content for db id {db_id}, file id {file.file_id},' + f' reference {file.reference}: {result}', exc_info=result, ) - source_items[file_id] = IndexingError( + source_items[db_id] = IndexingError( error=f'Unexpected error: {result}', retryable=True, ) else: LOGGER.error( - f'Unknown error fetching content for file id {file_id}, reference {file_item.reference}: {result}', + f'Unknown error fetching content for db id {db_id}, file id {file.file_id}, reference {file.reference}' + f': {result}', exc_info=True, ) - source_items[file_id] = IndexingError( + source_items[db_id] = IndexingError( error='Unknown error', retryable=True, ) @@ 
-232,11 +251,11 @@ def _load_sources(source_items: dict[int, SourceItem]) -> dict[int, IndexingErro if q_items.files: fetched_files = asyncio.run(__fetch_files_content(q_items.files)) - for file_id, result in fetched_files.items(): + for db_id, result in fetched_files.items(): if isinstance(result, SourceItem): - source_files[file_id] = result + source_files[db_id] = result else: - source_errors[file_id] = result + source_errors[db_id] = result files_result = {} providers_result = {} @@ -257,8 +276,8 @@ def _load_sources(source_items: dict[int, SourceItem]) -> dict[int, IndexingErro ): LOGGER.error('Some sources failed to index', extra={ 'file_errors': { - file_id: error - for file_id, error in files_result.items() + db_id: error + for db_id, error in files_result.items() if isinstance(error, IndexingError) }, 'provider_errors': { @@ -280,12 +299,12 @@ def _load_sources(source_items: dict[int, SourceItem]) -> dict[int, IndexingErro continue # delete the entries from the PHP side queue where indexing succeeded or the error is not retryable - to_delete_file_ids = [ - file_id for file_id, result in files_result.items() + to_delete_files_db_ids = [ + db_id for db_id, result in files_result.items() if result is None or (isinstance(result, IndexingError) and not result.retryable) ] - to_delete_provider_ids = [ - provider_id for provider_id, result in providers_result.items() + to_delete_provider_db_ids = [ + db_id for db_id, result in providers_result.items() if result is None or (isinstance(result, IndexingError) and not result.retryable) ] @@ -294,8 +313,8 @@ def _load_sources(source_items: dict[int, SourceItem]) -> dict[int, IndexingErro 'DELETE', '/ocs/v2.php/apps/context_chat/queues/documents/', json={ - 'files': to_delete_file_ids, - 'content_providers': to_delete_provider_ids, + 'files': to_delete_files_db_ids, + 'content_providers': to_delete_provider_db_ids, }, ) except ( @@ -310,8 +329,8 @@ def _load_sources(source_items: dict[int, SourceItem]) -> dict[int, 
IndexingErro 'DELETE', '/ocs/v2.php/apps/context_chat/queues/documents/', json={ - 'files': to_delete_file_ids, - 'content_providers': to_delete_provider_ids, + 'files': to_delete_files_db_ids, + 'content_providers': to_delete_provider_db_ids, }, ) continue diff --git a/context_chat_backend/types.py b/context_chat_backend/types.py index 972756fa..9f23e14f 100644 --- a/context_chat_backend/types.py +++ b/context_chat_backend/types.py @@ -7,7 +7,7 @@ from io import BytesIO from typing import Annotated, Literal, Self -from pydantic import AfterValidator, BaseModel, Discriminator, field_validator, model_validator +from pydantic import AfterValidator, BaseModel, Discriminator, computed_field, field_validator, model_validator from .mimetype_list import SUPPORTED_MIMETYPES from .vectordb.types import UpdateAccessOp @@ -69,6 +69,21 @@ def _validate_user_id(user_id: str) -> str: return _validate_user_ids([user_id])[0] +def _get_file_id_from_source_ref(source_ref: str) -> int: + ''' + source reference is in the format "FILES_PROVIDER_ID: ". 
+ ''' + if not source_ref.startswith(f'{FILES_PROVIDER_ID}: '): + raise ValueError(f'Source reference does not start with expected prefix: {source_ref}') + + try: + return int(source_ref[len(f'{FILES_PROVIDER_ID}: '):]) + except ValueError as e: + raise ValueError( + f'Invalid source reference format for extracting file_id: {source_ref}' + ) from e + + class TEmbeddingAuthApiKey(BaseModel): apikey: str @@ -177,6 +192,11 @@ def validate_type(self) -> Self: class ReceivedFileItem(CommonSourceItem): content: None + @computed_field + @property + def file_id(self) -> int: + return _get_file_id_from_source_ref(self.reference) + class SourceItem(CommonSourceItem): ''' From 17b32b66b68f61fee050bd200e5ecdc896dbed47 Mon Sep 17 00:00:00 2001 From: Anupam Kumar Date: Wed, 11 Mar 2026 19:24:51 +0530 Subject: [PATCH 12/17] fix: wip: improve embeddings exception handling Signed-off-by: Anupam Kumar --- context_chat_backend/network_em.py | 13 +++++++++---- context_chat_backend/task_fetcher.py | 1 + context_chat_backend/vectordb/pgvector.py | 17 ++++++----------- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/context_chat_backend/network_em.py b/context_chat_backend/network_em.py index 18bb11f4..d39ea56a 100644 --- a/context_chat_backend/network_em.py +++ b/context_chat_backend/network_em.py @@ -79,6 +79,7 @@ def _get_embedding(self, input_: str | list[str], try_: int = 3) -> list[float] raise FatalEmbeddingException(response.text) if response.status_code // 100 != 2: raise EmbeddingException(response.text) + # todo: rework exception handling and their downstream interpretation except FatalEmbeddingException as e: logger.error('Fatal error while getting embeddings: %s', str(e), exc_info=e) raise e @@ -108,10 +109,14 @@ def _get_embedding(self, input_: str | list[str], try_: int = 3) -> list[float] logger.error('Unexpected error while getting embeddings', exc_info=e) raise EmbeddingException('Error: unexpected error while getting embeddings') from e - # converts 
TypedDict to a pydantic model - resp = CreateEmbeddingResponse(**response.json()) - if isinstance(input_, str): - return resp['data'][0]['embedding'] + try: + # converts TypedDict to a pydantic model + resp = CreateEmbeddingResponse(**response.json()) + if isinstance(input_, str): + return resp['data'][0]['embedding'] + except Exception as e: + logger.error('Error parsing embedding response', exc_info=e) + raise EmbeddingException('Error: failed to parse embedding response') from e # only one embedding in d['embedding'] since truncate is True return [d['embedding'] for d in resp['data']] # pyright: ignore[reportReturnType] diff --git a/context_chat_backend/task_fetcher.py b/context_chat_backend/task_fetcher.py index 0442cd53..51f98e7d 100644 --- a/context_chat_backend/task_fetcher.py +++ b/context_chat_backend/task_fetcher.py @@ -261,6 +261,7 @@ def _load_sources(source_items: dict[int, SourceItem]) -> dict[int, IndexingErro providers_result = {} chunk_size = FILES_INDEXING_BATCH_SIZE // PARALLEL_FILE_PARSING + # todo: do it in asyncio, it's not truly parallel yet # chunk file parsing for better file operation parallelism for i in range(0, len(source_files), chunk_size): chunk = dict(list(source_files.items())[i:i+chunk_size]) diff --git a/context_chat_backend/vectordb/pgvector.py b/context_chat_backend/vectordb/pgvector.py index 8b0c864b..8f67c936 100644 --- a/context_chat_backend/vectordb/pgvector.py +++ b/context_chat_backend/vectordb/pgvector.py @@ -16,7 +16,7 @@ from langchain_postgres.vectorstores import Base, PGVector from ..chain.types import InDocument, ScopeType -from ..types import EmbeddingException, IndexingError, RetryableEmbeddingException, SourceItem +from ..types import EmbeddingException, FatalEmbeddingException, IndexingError, RetryableEmbeddingException, SourceItem from ..utils import timed from .base import BaseVectorDB from .types import DbException, SafeDbException, UpdateAccessOp @@ -169,7 +169,11 @@ def add_indocuments(self, indocuments: 
dict[int, InDocument]) -> dict[int, Index retryable=True, ) continue - except RetryableEmbeddingException as e: + except FatalEmbeddingException as e: + raise EmbeddingException( + f'Fatal error while embedding documents for source {indoc.source_id}: {e}' + ) from e + except (RetryableEmbeddingException, EmbeddingException) as e: # temporary error, continue with the next document logger.exception('Error adding documents to vectordb, should be retried later.', exc_info=e, extra={ 'source_id': indoc.source_id, @@ -179,15 +183,6 @@ def add_indocuments(self, indocuments: dict[int, InDocument]) -> dict[int, Index retryable=True, ) continue - except EmbeddingException as e: - logger.exception('Error adding documents to vectordb', exc_info=e, extra={ - 'source_id': indoc.source_id, - }) - results[php_db_id] = IndexingError( - error=str(e), - retryable=False, - ) - continue except Exception as e: logger.exception('Error adding documents to vectordb', exc_info=e, extra={ 'source_id': indoc.source_id, From 759f2c2d4c6f16c20a9548d6aeb45bde823bc14d Mon Sep 17 00:00:00 2001 From: Anupam Kumar Date: Wed, 11 Mar 2026 19:44:06 +0530 Subject: [PATCH 13/17] fix(ci): update to the latest changes Signed-off-by: Anupam Kumar --- .github/workflows/integration-test.yml | 104 ++++++++++++++++++------- 1 file changed, 76 insertions(+), 28 deletions(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index fb06bafa..9563bcdd 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -199,26 +199,87 @@ jobs: ls -la context_chat_backend/persistent_storage/* sleep 30 # Wait for the em server to get ready - - name: Scan files, baseline - run: | - ./occ files:scan admin - ./occ context_chat:scan admin -m text/plain - - - name: Check python memory usage + - name: Initial memory usage check run: | ps -p $(cat pid.txt) -o pid,cmd,%mem,rss --sort=-%mem ps -p $(cat pid.txt) -o %mem --no-headers > initial_mem.txt - - 
name: Scan files - run: | - ./occ files:scan admin - ./occ context_chat:scan admin -m text/markdown & - ./occ context_chat:scan admin -m text/x-rst - - - name: Check python memory usage + - name: Periodically check context_chat stats for 15 minutes to allow the backend to index the files run: | - ps -p $(cat pid.txt) -o pid,cmd,%mem,rss --sort=-%mem - ps -p $(cat pid.txt) -o %mem --no-headers > after_scan_mem.txt + success=0 + for i in {1..90}; do + echo "Checking stats, attempt $i..." + + mkfifo error_pipe + stats=$(timeout 5 ./occ context_chat:stats 2>error_pipe) + echo "Stats output:" + echo "$stats" + echo "---" + + # Check for critical errors in output + if echo "$stats" | grep -q "Error during request"; then + echo "Backend connection error detected, retrying..." + rm -f error_pipe + sleep 10 + continue + fi + + # Extract Total eligible files + total_files=$(echo "$stats" | grep -oP 'Total eligible files:\s*\K\d+' || echo "") + + # Extract Indexed documents count (files__default) + indexed_count=$(echo "$stats" | grep -oP "'files__default'\s*=>\s*\K\d+" || echo "") + + # Validate parsed values + if [ -z "$total_files" ] || [ -z "$indexed_count" ]; then + echo "Error: Could not parse stats output properly" + if echo "$stats" | grep -q "Indexed documents:"; then + echo " Indexed documents section found but could not extract count" + fi + rm -f error_pipe + sleep 10 + continue + fi + + echo "Total eligible files: $total_files" + echo "Indexed documents (files__default): $indexed_count" + + # Calculate absolute difference + diff=$((total_files - indexed_count)) + if [ $diff -lt 0 ]; then + diff=$((-diff)) + fi + + # Calculate 2% threshold using bc for floating point support + threshold=$(echo "scale=4; $total_files * 0.02" | bc) + + # Check if difference is within tolerance + if (( $(echo "$diff <= $threshold" | bc -l) )); then + echo "Indexing within 2% tolerance (diff=$diff, threshold=$threshold)" + rm -f error_pipe + success=1 + break + else + pct=$(echo 
"scale=2; ($diff / $total_files) * 100" | bc) + echo "Outside 2% tolerance: diff=$diff (${pct}%), threshold=$threshold" + fi + + # Check if backend is still alive + ccb_alive=$(ps -p $(cat pid.txt) -o cmd= | grep -c "main.py" || echo "0") + if [ "$ccb_alive" -eq 0 ]; then + echo "Error: Context Chat Backend process is not running. Exiting." + rm -f error_pipe + exit 1 + fi + + rm -f error_pipe + sleep 10 + done + + if [ $success -ne 1 ]; then + echo "Max attempts reached" + exit 1 + fi - name: Run the prompts run: | @@ -252,19 +313,6 @@ jobs: echo "Memory usage during scan is stable. No memory leak detected." fi - - name: Compare memory usage and detect leak - run: | - initial_mem=$(cat after_scan_mem.txt | tr -d ' ') - final_mem=$(cat after_prompt_mem.txt | tr -d ' ') - echo "Initial Memory Usage: $initial_mem%" - echo "Memory Usage after prompt: $final_mem%" - - if (( $(echo "$final_mem > $initial_mem" | bc -l) )); then - echo "Memory usage has increased during prompt. Possible memory leak detected!" - else - echo "Memory usage during prompt is stable. No memory leak detected." - fi - - name: Show server logs if: always() run: | From dda28b2d848b2306b43510c14146a17ae214763b Mon Sep 17 00:00:00 2001 From: Anupam Kumar Date: Thu, 12 Mar 2026 16:10:58 +0530 Subject: [PATCH 14/17] fix(ci): use file to store stderr Signed-off-by: Anupam Kumar --- .github/workflows/integration-test.yml | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 9563bcdd..de0f4659 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -210,16 +210,21 @@ jobs: for i in {1..90}; do echo "Checking stats, attempt $i..." - mkfifo error_pipe - stats=$(timeout 5 ./occ context_chat:stats 2>error_pipe) + stats_err=$(mktemp) + stats=$(timeout 5 ./occ context_chat:stats 2>"$stats_err") + stats_exit=$? 
echo "Stats output:" echo "$stats" + if [ -s "$stats_err" ]; then + echo "Stderr:" + cat "$stats_err" + fi echo "---" + rm -f "$stats_err" # Check for critical errors in output - if echo "$stats" | grep -q "Error during request"; then - echo "Backend connection error detected, retrying..." - rm -f error_pipe + if [ $stats_exit -ne 0 ] || echo "$stats" | grep -q "Error during request"; then + echo "Backend connection error detected (exit=$stats_exit), retrying..." sleep 10 continue fi @@ -236,7 +241,6 @@ jobs: if echo "$stats" | grep -q "Indexed documents:"; then echo " Indexed documents section found but could not extract count" fi - rm -f error_pipe sleep 10 continue fi @@ -256,7 +260,6 @@ jobs: # Check if difference is within tolerance if (( $(echo "$diff <= $threshold" | bc -l) )); then echo "Indexing within 2% tolerance (diff=$diff, threshold=$threshold)" - rm -f error_pipe success=1 break else @@ -268,11 +271,9 @@ jobs: ccb_alive=$(ps -p $(cat pid.txt) -o cmd= | grep -c "main.py" || echo "0") if [ "$ccb_alive" -eq 0 ]; then echo "Error: Context Chat Backend process is not running. Exiting." 
- rm -f error_pipe exit 1 fi - rm -f error_pipe sleep 10 done From 9806225e28ecff14a035b8fff5bcb7111b82f6e1 Mon Sep 17 00:00:00 2001 From: Anupam Kumar Date: Thu, 12 Mar 2026 17:17:38 +0530 Subject: [PATCH 15/17] fix(ci): add cron jobs Signed-off-by: Anupam Kumar --- .github/workflows/integration-test.yml | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index de0f4659..0d8e4229 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -204,9 +204,18 @@ jobs: ps -p $(cat pid.txt) -o pid,cmd,%mem,rss --sort=-%mem ps -p $(cat pid.txt) -o %mem --no-headers > initial_mem.txt + - name: Run cron jobs + run: | + # every 10 seconds indefinitely + while true; do + php cron.php + sleep 10 + done & + - name: Periodically check context_chat stats for 15 minutes to allow the backend to index the files run: | success=0 + echo "::group::Checking stats periodically for 15 minutes to allow the backend to index the files" for i in {1..90}; do echo "Checking stats, attempt $i..." @@ -277,6 +286,10 @@ jobs: sleep 10 done + echo "::endgroup::" + + ./occ context_chat:stats + if [ $success -ne 1 ]; then echo "Max attempts reached" exit 1 From 84f689601e544295be180ac9baa966efef70d027 Mon Sep 17 00:00:00 2001 From: Anupam Kumar Date: Thu, 12 Mar 2026 17:35:47 +0530 Subject: [PATCH 16/17] fix(ci): do a occ files scan before cron jobs Signed-off-by: Anupam Kumar --- .github/workflows/integration-test.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml index 0d8e4229..58f9f50c 100644 --- a/.github/workflows/integration-test.yml +++ b/.github/workflows/integration-test.yml @@ -169,6 +169,10 @@ jobs: cd .. 
rm -rf documentation + - name: Run files scan + run: | + ./occ files:scan --all + - name: Setup python 3.11 uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5 with: From 634db900bab129847db20b1202d662fbaa5f89a1 Mon Sep 17 00:00:00 2001 From: Anupam Kumar Date: Fri, 13 Mar 2026 18:05:02 +0530 Subject: [PATCH 17/17] feat: add support for multimodal indexing for images (OCR) and audio (Speech-to-text) files Signed-off-by: Anupam Kumar --- .../chain/ingest/doc_loader.py | 20 +- context_chat_backend/chain/ingest/injest.py | 47 ++- .../chain/ingest/task_proc.py | 289 ++++++++++++++++++ context_chat_backend/mimetype_list.py | 33 +- context_chat_backend/task_fetcher.py | 2 + context_chat_backend/types.py | 25 +- context_chat_backend/utils.py | 19 +- 7 files changed, 420 insertions(+), 15 deletions(-) create mode 100644 context_chat_backend/chain/ingest/task_proc.py diff --git a/context_chat_backend/chain/ingest/doc_loader.py b/context_chat_backend/chain/ingest/doc_loader.py index d26f74b1..b6bc17d8 100644 --- a/context_chat_backend/chain/ingest/doc_loader.py +++ b/context_chat_backend/chain/ingest/doc_loader.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: AGPL-3.0-or-later # +import asyncio import logging import re import tempfile @@ -18,7 +19,8 @@ from pypdf.errors import FileNotDecryptedError as PdfFileNotDecryptedError from striprtf import striprtf -from ...types import SourceItem +from ...types import SourceItem, TaskProcException +from .task_proc import do_ocr, do_transcription logger = logging.getLogger('ccb.doc_loader') @@ -128,6 +130,22 @@ def decode_source(source: SourceItem) -> str | None: if mimetype is None: return None + try: + if mimetype.startswith('image/'): + return asyncio.run(do_ocr(source.userIds[0], source.file_id)) + if mimetype.startswith('audio/'): + return asyncio.run(do_transcription(source.userIds[0], source.file_id)) + except TaskProcException as e: + # todo: convert this to error obj return + # todo: short circuit all other 
ocr/transcription files when a fatal error arrives + # todo: maybe with a global ttl, with a retryable tag + logger.warning(f'OCR task failed for source file ({source.reference}): {e}') + return None + except ValueError: + # should not happen + logger.warning(f'Unexpected ValueError for source file ({source.reference})') + return None + if isinstance(source.content, str): io_obj = BytesIO(source.content.encode('utf-8', 'ignore')) else: diff --git a/context_chat_backend/chain/ingest/injest.py b/context_chat_backend/chain/ingest/injest.py index 7369f452..9484ab9f 100644 --- a/context_chat_backend/chain/ingest/injest.py +++ b/context_chat_backend/chain/ingest/injest.py @@ -2,22 +2,45 @@ # SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors # SPDX-License-Identifier: AGPL-3.0-or-later # +import asyncio import logging import re from langchain.schema import Document from ...dyn_loader import VectorDBLoader -from ...types import IndexingError, SourceItem, TConfig +from ...mimetype_list import AUDIO_MIMETYPES, IMAGE_MIMETYPES, SUPPORTED_MIMETYPES +from ...types import FILES_PROVIDER_ID, IndexingError, SourceItem, TConfig from ...vectordb.base import BaseVectorDB from ...vectordb.types import DbException, SafeDbException, UpdateAccessOp from ..types import InDocument from .doc_loader import decode_source from .doc_splitter import get_splitter_for +from .task_proc import OCR_TASK_TYPE, SPEECH_TO_TEXT_TASK_TYPE, is_task_type_available logger = logging.getLogger('ccb.injest') +def _do_extended_mimetype_validation(source: SourceItem, ocr_available: bool, stt_available: bool) -> None: + ''' + Raises + ------ + ValueError + ''' + + extended_mimetypes = ( + *SUPPORTED_MIMETYPES, + *(([], IMAGE_MIMETYPES)[ocr_available]), + *(([], AUDIO_MIMETYPES)[stt_available]), + ) + + if source.reference.startswith(FILES_PROVIDER_ID) and source.type not in extended_mimetypes: + raise ValueError( + f'Unsupported file type: {source.type} for reference {source.reference}.' 
+ f' OCR available: {ocr_available}, Speech-to-text available: {stt_available}.' + ) + + def _filter_sources( vectordb: BaseVectorDB, sources: dict[int, SourceItem] @@ -33,6 +56,7 @@ def _filter_sources( try: existing_source_ids, to_embed_source_ids = vectordb.check_sources(sources) except Exception as e: + # todo: maybe handle this and other errors as IndexingErrors raise DbException('Error: Vectordb error while checking existing sources in indexing') from e existing_sources = {} @@ -217,5 +241,24 @@ def embed_sources( 'len(source_ids)': len(sources), }) + mime_filtered_sources: dict[int, SourceItem] = {} + mime_errored_sources: dict[int, IndexingError] = {} + + ocr_available = asyncio.run(is_task_type_available(OCR_TASK_TYPE)) + stt_available = asyncio.run(is_task_type_available(SPEECH_TO_TEXT_TASK_TYPE)) + for db_id, source in sources.items(): + try: + _do_extended_mimetype_validation(source, ocr_available, stt_available) + mime_filtered_sources[db_id] = source + except ValueError as e: + mime_errored_sources[db_id] = IndexingError( + error=str(e), + retryable=False, + ) + continue + vectordb = vectordb_loader.load() - return _process_sources(vectordb, config, sources) + processed_sources = _process_sources(vectordb, config, mime_filtered_sources) + processed_sources.update(mime_errored_sources) + + return processed_sources diff --git a/context_chat_backend/chain/ingest/task_proc.py b/context_chat_backend/chain/ingest/task_proc.py new file mode 100644 index 00000000..a6f7e4cd --- /dev/null +++ b/context_chat_backend/chain/ingest/task_proc.py @@ -0,0 +1,289 @@ +# +# SPDX-FileCopyrightText: 2026 Nextcloud GmbH and Nextcloud contributors +# SPDX-License-Identifier: AGPL-3.0-or-later +# + +import asyncio +import json +import logging +import os +from typing import Any, Literal + +import niquests +from nc_py_api import AsyncNextcloudApp, NextcloudException +from pydantic import BaseModel, ValidationError + +from ...types import TaskProcException, 
TaskProcFatalException +from ...utils import timed_cache_async + +LOGGER = logging.getLogger('ccb.task_proc') +OCR_TASK_TYPE = 'core:image2text:ocr' +SPEECH_TO_TEXT_TASK_TYPE = 'core:audio2text' +CACHE_TTL = 15 * 60 # cache values for 15 minutes +OCP_TASK_PROC_SCHED_RETRIES = 3 +OCP_TASK_TIMEOUT = 20 * 60 # 20 mins to wait for a task to complete + + +class Task(BaseModel): + id: int + status: str + output: dict[str, Any] | None = None + + +class TaskResponse(BaseModel): + task: Task + + +InputShapeType = Literal[ + 'Number', + 'Text', + 'Audio', + 'Image', + 'Video', + 'File', + 'Enum', + 'ListOfNumbers', + 'ListOfTexts', + 'ListOfImages', + 'ListOfAudios', + 'ListOfVideos', + 'ListOfFiles', +] + +class InputShape(BaseModel): + name: str + description: str + type: InputShapeType + + +class InputShapeEnum(BaseModel): + name: str + value: str + + +class TaskType(BaseModel): + name: str + description: str + inputShape: dict[str, InputShape] + inputShapeEnumValues: dict[str, list[InputShapeEnum]] + inputShapeDefaults: dict[str, str | int | float] + optionalInputShape: dict[str, InputShape] + optionalInputShapeEnumValues: dict[str, list[InputShapeEnum]] + optionalInputShapeDefaults: dict[str, str | int | float] + outputShape: dict[str, InputShape] + outputShapeEnumValues: dict[str, list[InputShapeEnum]] + optionalOutputShape: dict[str, InputShape] + optionalOutputShapeEnumValues: dict[str, list[InputShapeEnum]] + + +class TaskTypesResponse(BaseModel): + types: dict[str, TaskType] + + + +def __try_parse_ocs_response(response: niquests.Response | None) -> dict | str: + if response is None or response.text is None: + return 'No response' + try: + ocs_response = json.loads(response.text) + if not (ocs_data := ocs_response.get('ocs', {}).get('data')): + return response.text + return ocs_data + except json.JSONDecodeError: + return response.text + + +async def __schedule_task(user_id: str, task_type: str, custom_id: str, task_input: dict) -> Task: + ''' + Raises + ------ + 
TaskProcException + ''' + nc = AsyncNextcloudApp() + await nc.set_user(user_id) + + for sched_tries in range(OCP_TASK_PROC_SCHED_RETRIES): + try: + response = await nc.ocs( + 'POST', + '/ocs/v2.php/taskprocessing/schedule', + json={ + 'type': task_type, + 'appId': os.getenv('APP_ID', 'context_chat_backend'), + 'customId': f'ccb-{custom_id}', + 'input': task_input, + }, + ) + try: + task = TaskResponse.model_validate(response).task + LOGGER.debug('TaskProcessing task schedule response', extra={ + 'task': task, + }) + return task + except ValidationError as e: + raise TaskProcException('Failed to parse TaskProcessing task result') from e + except NextcloudException as e: + if e.status_code == niquests.codes.precondition_failed: # type: ignore[attr-defined] + raise TaskProcFatalException( + 'Failed to schedule Nextcloud TaskProcessing task:' + f' No provider of {task_type} is installed on this Nextcloud instance.' + ' Please install a suitable provider from the AI overview:' + ' https://docs.nextcloud.com/server/latest/admin_manual/ai/overview.html.', + ) from e + + if e.status_code == niquests.codes.too_many_requests: # type: ignore[attr-defined] + LOGGER.warning( + 'Rate limited during TaskProcessing task scheduling, waiting 30s before retrying', + extra={ + 'task_type': task_type, + 'sched_try': sched_tries, + }, + ) + await asyncio.sleep(30) + continue + + ocs_response = __try_parse_ocs_response(e.response) + if e.status_code // 100 == 4: + raise TaskProcFatalException( + f'Failed to schedule TaskProcessing task due to client error: {ocs_response}', + ) from e + + LOGGER.error('NextcloudException during TaskProcessing task scheduling', exc_info=e, extra={ + 'task_type': task_type, + 'sched_try': sched_tries, + 'nc_exc_reason': str(e.reason), + 'nc_exc_info': str(e.info), + 'nc_exc_status_code': str(e.status_code), + 'ocs_response': str(ocs_response), + }) + raise TaskProcException(f'Failed to schedule TaskProcessing task: {ocs_response}') from e + except 
TaskProcException: + raise + except Exception as e: + raise TaskProcException(f'Failed to schedule TaskProcessing task: {e}') from e + + raise TaskProcException('Failed to schedule TaskProcessing task, tried 3 times') + + +async def __get_task_result(user_id: str, task: Task) -> Any: + nc = AsyncNextcloudApp() + await nc.set_user(user_id) + + i = 0 + now_waiting_for = 0 + + while task.status != 'STATUS_SUCCESSFUL' and task.status != 'STATUS_FAILED' and now_waiting_for < OCP_TASK_TIMEOUT: + i += 1 + now_waiting_for += 10 + await asyncio.sleep(10) + + try: + response = await nc.ocs('GET', f'/ocs/v2.php/taskprocessing/task/{task.id}') + except NextcloudException as e: + if e.status_code == niquests.codes.too_many_requests: # type: ignore[attr-defined] + LOGGER.warning( + 'Rate limited during TaskProcessing task polling, waiting 10s before retrying', + extra={ + 'task_id': task.id, + 'tries_so_far': i, + 'waiting_time': now_waiting_for, + }, + ) + now_waiting_for += 60 + await asyncio.sleep(60) + continue + raise TaskProcException('Failed to poll TaskProcessing task') from e + except niquests.RequestException as e: + LOGGER.warning('Ignored error during TaskProcessing task polling', exc_info=e, extra={ + 'task_id': task.id, + 'tries_so_far': i, + 'waiting_time': now_waiting_for, + }) + continue + + try: + task = TaskResponse.model_validate(response).task + LOGGER.debug(f'TaskProcessing task poll ({now_waiting_for}s) response', extra={ + 'task_id': task.id, + 'tries_so_far': i, + 'waiting_time': now_waiting_for, + 'task': task, + }) + except ValidationError as e: + raise TaskProcException('Failed to parse TaskProcessing task result') from e + + if task.status != 'STATUS_SUCCESSFUL': + raise TaskProcException( + f'TaskProcessing task id {task.id} failed with status {task.status}' + f' after waiting {now_waiting_for} seconds', + ) + + if not isinstance(task.output, dict) or 'output' not in task.output: + raise TaskProcException(f'"output" key not found or invalid in 
TaskProcessing task result: {task.output}') + + return task.output['output'] + + +async def do_ocr(user_id: str, file_id: int) -> str: + try: + task = await __schedule_task(user_id, OCR_TASK_TYPE, str(file_id), {'input': [file_id]}) + output = await __get_task_result(user_id, task) + if not isinstance(output, list) or len(output) == 0 or not isinstance(output[0], str): + raise TaskProcException(f'OCR task returned empty or invalid output: {output}') + return output[0] + except TaskProcException as e: + LOGGER.error(f'Failed to perform OCR for file_id {file_id}', exc_info=e) + raise + + +async def do_transcription(user_id: str, file_id: int) -> str: + try: + task = await __schedule_task(user_id, SPEECH_TO_TEXT_TASK_TYPE, str(file_id), {'input': file_id}) + output = await __get_task_result(user_id, task) + if not isinstance(output, str) or len(output.strip()) == 0: + raise TaskProcException(f'Speech-to-text task returned empty or invalid output: {output}') + return output + except TaskProcException as e: + LOGGER.error(f'Failed to perform transcription for file_id {file_id}', exc_info=e) + raise + + +@timed_cache_async(CACHE_TTL) +async def __get_task_types() -> TaskTypesResponse: + ''' + Raises + ------ + TaskProcException + ''' + nc = AsyncNextcloudApp() + + # NC 33 required for this + try: + response = await nc.ocs( + 'GET', + '/ocs/v2.php/taskprocessing/tasks_consumer/tasktypes', + ) + except NextcloudException as e: + raise TaskProcException('Failed to fetch Nextcloud TaskProcessing types') from e + + try: + task_types = TaskTypesResponse.model_validate(response) + LOGGER.debug('Fetched task types', extra={ + 'task_types': task_types, + }) + except (KeyError, TypeError, ValidationError) as e: + raise TaskProcException('Failed to parse Nextcloud TaskProcessing types') from e + + return task_types + + +@timed_cache_async(CACHE_TTL) +async def is_task_type_available(task_type: str) -> bool: + try: + task_types = await __get_task_types() + except Exception as e: + 
LOGGER.warning(f'Failed to fetch task types: {e}', exc_info=e) + return False + if task_type not in task_types.types: + return False + return True diff --git a/context_chat_backend/mimetype_list.py b/context_chat_backend/mimetype_list.py index 87f10241..ce21e6ea 100644 --- a/context_chat_backend/mimetype_list.py +++ b/context_chat_backend/mimetype_list.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: 2023 Nextcloud GmbH and Nextcloud contributors # SPDX-License-Identifier: AGPL-3.0-or-later # -SUPPORTED_MIMETYPES = [ +SUPPORTED_MIMETYPES = ( 'text/plain', 'text/markdown', 'application/json', @@ -22,4 +22,33 @@ 'message/rfc822', 'application/vnd.ms-outlook', 'text/org', -] +) + +IMAGE_MIMETYPES = ( + 'image/bmp', + 'image/bpg', + 'image/emf', + 'image/gif', + 'image/heic', + 'image/heif', + 'image/jp2', + 'image/jpeg', + 'image/png', + 'image/svg+xml', + 'image/tga', + 'image/tiff', + 'image/webp', + 'image/x-dcraw', + 'image/x-icon', +) + +AUDIO_MIMETYPES = ( + 'audio/aac', + 'audio/flac', + 'audio/mp4', + 'audio/mpeg', + 'audio/ogg', + 'audio/wav', + 'audio/webm', + 'audio/x-scpls', +) diff --git a/context_chat_backend/task_fetcher.py b/context_chat_backend/task_fetcher.py index 51f98e7d..919cfccd 100644 --- a/context_chat_backend/task_fetcher.py +++ b/context_chat_backend/task_fetcher.py @@ -242,6 +242,7 @@ def _load_sources(source_items: dict[int, SourceItem]) -> dict[int, IndexingErro sleep(POLLING_COOLDOWN) continue + LOGGER.debug(f'Fetched {len(q_items.files)} files and {len(q_items.content_providers)} content providers') # populate files content and convert to source items fetched_files = {} source_files = {} @@ -383,6 +384,7 @@ def updates_processing_thread(app_config: TConfig, app_enabled: Event) -> None: sleep(POLLING_COOLDOWN) continue + LOGGER.debug(f'Fetched {len(q_items.actions)} updates') processed_event_ids = [] errored_events = {} for i, (db_id, action_item) in enumerate(q_items.actions.items()): diff --git a/context_chat_backend/types.py 
b/context_chat_backend/types.py index 9f23e14f..fa710b17 100644 --- a/context_chat_backend/types.py +++ b/context_chat_backend/types.py @@ -5,11 +5,10 @@ import re from enum import Enum from io import BytesIO -from typing import Annotated, Literal, Self +from typing import Annotated, Literal -from pydantic import AfterValidator, BaseModel, Discriminator, computed_field, field_validator, model_validator +from pydantic import AfterValidator, BaseModel, Discriminator, computed_field, field_validator -from .mimetype_list import SUPPORTED_MIMETYPES from .vectordb.types import UpdateAccessOp __all__ = [ @@ -36,6 +35,7 @@ def is_valid_provider_id(provider_id: str) -> bool: def _validate_source_ids(source_ids: list[str]) -> list[str]: + # todo: use is_valid_source_id() if ( not isinstance(source_ids, list) or not all(isinstance(sid, str) and sid.strip() != '' for sid in source_ids) @@ -182,12 +182,6 @@ def validate_size(cls, v): return float(v) raise ValueError(f'Invalid size value: {v}, must be a non-negative number') - @model_validator(mode='after') - def validate_type(self) -> Self: - if self.reference.startswith(FILES_PROVIDER_ID) and self.type not in SUPPORTED_MIMETYPES: - raise ValueError(f'Unsupported file type: {self.type} for reference {self.reference}') - return self - class ReceivedFileItem(CommonSourceItem): content: None @@ -205,6 +199,11 @@ class SourceItem(CommonSourceItem): ''' content: str | BytesIO + @computed_field + @property + def file_id(self) -> int: + return _get_file_id_from_source_ref(self.reference) + @field_validator('content') @classmethod def validate_content(cls, v): @@ -344,3 +343,11 @@ class ActionsQueueItemUpdateAccessDeclSourceId(CommonActionsQueueItem): class ActionsQueueItems(BaseModel): actions: dict[int, ActionsQueueItem] + + +class TaskProcException(Exception): + ... + + +class TaskProcFatalException(TaskProcException): + ... 
def timed_cache_async(ttl: int):
    '''
    Cache the awaited result of an async function for `ttl` seconds.

    A separate entry is kept for every distinct combination of positional
    and keyword arguments, so all arguments must be hashable.
    Exceptions raised by the wrapped function are not cached.
    Callers that miss the cache at the same time each invoke the wrapped
    function; the last one to finish wins.

    Args
    ----
    ttl: int
        Number of seconds a cached value stays valid. A value of 0 or less
        disables caching.

    Returns
    -------
    Callable
        A decorator to be applied to an async function.
    '''
    # local import: a monotonic clock is not affected by wall-clock adjustments
    from time import monotonic

    def decorator(fn: Callable):
        cached_store: dict[tuple, tuple[float, Any]] = {}

        @wraps(fn)
        async def wrapper(*args, **kwargs):
            # kwargs are part of the key so that calls differing only in
            # keyword arguments never share a cached value
            key = (args, tuple(sorted(kwargs.items())))
            now = monotonic()

            entry = cached_store.get(key)
            if entry is not None and (now - entry[0]) < ttl:
                return entry[1]

            # drop expired entries so keys that are never requested again
            # do not stay in memory forever
            expired_keys = [k for k, (ts, _) in cached_store.items() if (now - ts) >= ttl]
            for expired_key in expired_keys:
                del cached_store[expired_key]

            new_val = await fn(*args, **kwargs)
            # timestamp taken after the await so the ttl counts from when the value was produced
            cached_store[key] = (monotonic(), new_val)
            return new_val

        return wrapper

    return decorator