diff --git a/src/datajoint/admin.py b/src/datajoint/admin.py index 64a91bb48..51ca42fc2 100644 --- a/src/datajoint/admin.py +++ b/src/datajoint/admin.py @@ -1,26 +1,42 @@ +""" +Administrative utilities for managing database connections. + +This module provides functions for viewing and terminating database connections +through the MySQL processlist interface. +""" + +from __future__ import annotations + import logging import pymysql -from .connection import conn +from .connection import Connection, conn logger = logging.getLogger(__name__.split(".")[0]) -def kill(restriction=None, connection=None, order_by=None): +def kill( + restriction: str | None = None, + connection: Connection | None = None, + order_by: str | list[str] | None = None, +) -> None: """ - view and kill database connections. + View and interactively kill database connections. - :param restriction: restriction to be applied to processlist - :param connection: a datajoint.Connection object. Default calls datajoint.conn() - :param order_by: order by a single attribute or the list of attributes. defaults to 'id'. + Displays active database connections matching the optional restriction and + prompts the user to select connections to terminate. - Restrictions are specified as strings and can involve any of the attributes of - information_schema.processlist: ID, USER, HOST, DB, COMMAND, TIME, STATE, INFO. + Args: + restriction: SQL WHERE clause condition to filter the processlist. + Can reference any column from information_schema.processlist: + ID, USER, HOST, DB, COMMAND, TIME, STATE, INFO. + connection: A datajoint.Connection object. If None, uses datajoint.conn(). + order_by: Column name(s) to sort results by. Defaults to 'id'. Examples: - dj.kill('HOST LIKE "%compute%"') lists only connections from hosts containing "compute". - dj.kill('TIME > 600') lists only connections in their current state for more than 10 minutes + >>> dj.kill('HOST LIKE "%compute%"') # connections from hosts containing "compute" + >>> dj.kill('TIME > 600') # connections idle for more than 10 minutes """ if connection is None: @@ -59,18 +75,28 @@ def kill(restriction=None, connection=None, order_by=None): logger.warn("Process not found") -def kill_quick(restriction=None, connection=None): +def kill_quick( + restriction: str | None = None, + connection: Connection | None = None, +) -> int: """ - Kill database connections without prompting. Returns number of terminated connections. + Kill database connections without prompting. + + Terminates all database connections matching the optional restriction + without user confirmation. - :param restriction: restriction to be applied to processlist - :param connection: a datajoint.Connection object. Default calls datajoint.conn() + Args: + restriction: SQL WHERE clause condition to filter the processlist. + Can reference any column from information_schema.processlist: + ID, USER, HOST, DB, COMMAND, TIME, STATE, INFO. + connection: A datajoint.Connection object. If None, uses datajoint.conn(). - Restrictions are specified as strings and can involve any of the attributes of - information_schema.processlist: ID, USER, HOST, DB, COMMAND, TIME, STATE, INFO. + Returns: + Number of connections terminated. Examples: - dj.kill('HOST LIKE "%compute%"') terminates connections from hosts containing "compute". + >>> dj.kill_quick('HOST LIKE "%compute%"') # kill connections from "compute" hosts + >>> dj.kill_quick('TIME > 600') # kill connections idle for more than 10 minutes """ if connection is None: connection = conn() diff --git a/src/datajoint/cli.py b/src/datajoint/cli.py index 6437ebbc5..8e153b02b 100644 --- a/src/datajoint/cli.py +++ b/src/datajoint/cli.py @@ -1,16 +1,41 @@ +""" +Command-line interface for DataJoint Python. + +This module provides a console interface for interacting with DataJoint databases, +allowing users to connect to servers and work with virtual modules from the command line. + +Usage: + datajoint [-u USER] [-p PASSWORD] [-h HOST] [-s SCHEMA:MODULE ...] + +Example: + datajoint -u root -h localhost -s mydb:experiment mydb:subject +""" + +from __future__ import annotations + import argparse from code import interact from collections import ChainMap +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Sequence import datajoint as dj -def cli(args: list = None): +def cli(args: Sequence[str] | None = None) -> None: """ - Console interface for DataJoint Python + Console interface for DataJoint Python. + + Launches an interactive Python shell with DataJoint configured and optional + virtual modules loaded for database schemas. + + Args: + args: List of command-line arguments. If None, reads from sys.argv. - :param args: List of arguments to be passed in, defaults to reading stdin - :type args: list, optional + Raises: + SystemExit: Always raised when the interactive session ends. """ parser = argparse.ArgumentParser( prog="datajoint", diff --git a/src/datajoint/connection.py b/src/datajoint/connection.py index 66d926694..2bb1d8642 100644 --- a/src/datajoint/connection.py +++ b/src/datajoint/connection.py @@ -1,14 +1,19 @@ """ -This module contains the Connection class that manages the connection to the database, and -the ``conn`` function that provides access to a persistent connection in datajoint. +Database connection management for DataJoint. + +This module contains the Connection class that manages the connection to the database, +and the ``conn`` function that provides access to a persistent connection in datajoint. """ +from __future__ import annotations + import logging import pathlib import re import warnings from contextlib import contextmanager from getpass import getpass +from typing import TYPE_CHECKING, Any import pymysql as client @@ -19,6 +24,10 @@ from .settings import config from .version import __version__ +if TYPE_CHECKING: + from collections.abc import Generator + + logger = logging.getLogger(__name__.split(".")[0]) query_log_max_length = 300 @@ -26,13 +35,17 @@ cache_key = "query_cache" # the key to lookup the query_cache folder in dj.config -def translate_query_error(client_error, query): +def translate_query_error(client_error: Exception, query: str) -> Exception: """ - Take client error and original query and return the corresponding DataJoint exception. + Translate a database client error into the corresponding DataJoint exception. - :param client_error: the exception raised by the client interface - :param query: sql query with placeholders - :return: an instance of the corresponding subclass of datajoint.errors.DataJointError + Args: + client_error: The exception raised by the pymysql client interface. + query: The SQL query that caused the error (with placeholders). + + Returns: + An instance of the appropriate DataJointError subclass, or the original + error if no specific translation is available. """ logger.debug("type: {}, args: {}".format(type(client_error), client_error.args)) @@ -71,22 +84,37 @@ def translate_query_error(client_error, query): return client_error -def conn(host=None, user=None, password=None, *, init_fun=None, reset=False, use_tls=None): +def conn( + host: str | None = None, + user: str | None = None, + password: str | None = None, + *, + init_fun: str | None = None, + reset: bool = False, + use_tls: bool | dict[str, Any] | None = None, +) -> Connection: """ - Returns a persistent connection object to be shared by multiple modules. + Return a persistent connection object to be shared by multiple modules. + If the connection is not yet established or reset=True, a new connection is set up. If connection information is not provided, it is taken from config which takes the - information from dj_local_conf.json. If the password is not specified in that file + information from dj_local_conf.json. If the password is not specified in that file, datajoint prompts for the password. - :param host: hostname - :param user: mysql user - :param password: mysql password - :param init_fun: initialization function - :param reset: whether the connection should be reset or not - :param use_tls: TLS encryption option. Valid options are: True (required), False - (required no TLS), None (TLS preferred, default), dict (Manually specify values per - https://dev.mysql.com/doc/refman/8.0/en/connection-options.html#encrypted-connection-options). + Args: + host: Database hostname, optionally with port (host:port). + user: MySQL username. + password: MySQL password. + init_fun: SQL initialization statement to execute on connection. + reset: If True, close existing connection and create a new one. + use_tls: TLS encryption option: + - True: Require TLS + - False: Require no TLS + - None: TLS preferred (default) + - dict: Manual SSL configuration options + + Returns: + A shared Connection object. """ if not hasattr(conn, "connection") or reset: host = host if host is not None else config["database.host"] @@ -103,45 +131,71 @@ def conn(host=None, user=None, password=None, *, init_fun=None, reset=False, use class EmulatedCursor: - """acts like a cursor""" + """ + A cursor-like object that wraps pre-fetched query results. - def __init__(self, data): + Used when query caching is enabled to provide a cursor interface + over cached data. + """ + + def __init__(self, data: list[tuple | dict]) -> None: self._data = data self._iter = iter(self._data) - def __iter__(self): + def __iter__(self) -> EmulatedCursor: return self - def __next__(self): + def __next__(self) -> tuple | dict: return next(self._iter) - def fetchall(self): + def fetchall(self) -> list[tuple | dict]: + """Return all remaining rows.""" return self._data - def fetchone(self): + def fetchone(self) -> tuple | dict: + """Return the next row.""" return next(self._iter) @property - def rowcount(self): + def rowcount(self) -> int: + """Return the total number of rows.""" return len(self._data) class Connection: """ - A dj.Connection object manages a connection to a database server. - It also catalogues modules, schemas, tables, and their dependencies (foreign keys). - - Most of the parameters below should be set in the local configuration file. - - :param host: host name, may include port number as hostname:port, in which case it overrides the value in port - :param user: user name - :param password: password - :param port: port number - :param init_fun: connection initialization function (SQL) - :param use_tls: TLS encryption option + Manage a connection to a DataJoint database server. + + This class handles database connectivity, query execution, transaction management, + and maintains references to schemas and their dependencies (foreign keys). + + Most parameters should be configured in the local configuration file rather than + passed directly. + + Args: + host: Hostname, may include port as hostname:port. + user: Database username. + password: Database password. + port: Port number (overridden if included in host). + init_fun: SQL initialization statement to execute on connection. + use_tls: TLS encryption option (True/False/None/dict). + + Attributes: + conn_info: Dictionary of connection parameters. + schemas: Dictionary mapping database names to Schema objects. + dependencies: Dependencies graph for foreign key relationships. + connection_id: MySQL connection ID for this session. """ - def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None): + def __init__( + self, + host: str, + user: str, + password: str, + port: int | None = None, + init_fun: str | None = None, + use_tls: bool | dict[str, Any] | None = None, + ) -> None: if ":" in host: # the port in the hostname overrides the port argument host, port = host.split(":") @@ -165,15 +219,17 @@ def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None) self.schemas = dict() self.dependencies = Dependencies(self) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, Connection): + return NotImplemented return self.conn_info == other.conn_info - def __repr__(self): + def __repr__(self) -> str: connected = "connected" if self.is_connected else "disconnected" return "DataJoint connection ({connected}) {user}@{host}:{port}".format(connected=connected, **self.conn_info) - def connect(self): - """Connect to the database server.""" + def connect(self) -> None: + """Establish connection to the database server.""" with warnings.catch_warnings(): warnings.filterwarnings("ignore", ".*deprecated.*") try: @@ -198,38 +254,44 @@ def connect(self): ) self._conn.autocommit(True) - def set_query_cache(self, query_cache=None): + def set_query_cache(self, query_cache: str | None = None) -> None: """ - When query_cache is not None, the connection switches into the query caching mode, which entails: - 1. Only SELECT queries are allowed. - 2. The results of queries are cached under the path indicated by dj.config['query_cache'] - 3. query_cache is a string that differentiates different cache states. + Enable or disable query caching mode. + + When query_cache is not None, the connection switches into caching mode: + 1. Only SELECT queries are allowed + 2. Results are cached under dj.config['query_cache'] + 3. The query_cache string differentiates cache states - :param query_cache: a string to initialize the hash for query results + Args: + query_cache: String to initialize the hash for query results, + or None to disable caching. """ self._query_cache = query_cache - def purge_query_cache(self): - """Purges all query cache.""" + def purge_query_cache(self) -> None: + """Remove all cached query results from the cache directory.""" if isinstance(config.get(cache_key), str) and pathlib.Path(config[cache_key]).is_dir(): for path in pathlib.Path(config[cache_key]).iterdir(): if not path.is_dir(): path.unlink() - def close(self): + def close(self) -> None: + """Close the database connection.""" self._conn.close() - def register(self, schema): + def register(self, schema: Any) -> None: + """Register a schema with this connection.""" self.schemas[schema.database] = schema self.dependencies.clear() - def ping(self): - """Ping the connection or raises an exception if the connection is closed.""" + def ping(self) -> None: + """Ping the connection; raises an exception if disconnected.""" self._conn.ping(reconnect=False) @property - def is_connected(self): - """Return true if the object is connected to the database server.""" + def is_connected(self) -> bool: + """Return True if connected to the database server.""" try: self.ping() except: @@ -237,7 +299,7 @@ def is_connected(self): return True @staticmethod - def _execute_query(cursor, query, args, suppress_warnings): + def _execute_query(cursor: Any, query: str, args: tuple, suppress_warnings: bool) -> None: try: with warnings.catch_warnings(): if suppress_warnings: @@ -247,16 +309,31 @@ def _execute_query(cursor, query, args, suppress_warnings): except client.err.Error as err: raise translate_query_error(err, query) - def query(self, query, args=(), *, as_dict=False, suppress_warnings=True, reconnect=None): + def query( + self, + query: str, + args: tuple = (), + *, + as_dict: bool = False, + suppress_warnings: bool = True, + reconnect: bool | None = None, + ) -> Any: """ - Execute the specified query and return the tuple generator (cursor). - - :param query: SQL query - :param args: additional arguments for the client.cursor - :param as_dict: If as_dict is set to True, the returned cursor objects returns - query results as dictionary. - :param suppress_warnings: If True, suppress all warnings arising from underlying query library - :param reconnect: when None, get from config, when True, attempt to reconnect if disconnected + Execute an SQL query and return a cursor with the results. + + Args: + query: The SQL query string. + args: Parameters to substitute into the query. + as_dict: If True, return results as dictionaries instead of tuples. + suppress_warnings: If True, suppress warnings from the database driver. + reconnect: If True, reconnect on connection loss. If None, use config setting. + + Returns: + A cursor object (or EmulatedCursor when caching is enabled). + + Raises: + DataJointError: If caching is enabled and query is not SELECT/SHOW. + LostConnectionError: If connection is lost and reconnect fails. """ # check cache first: use_query_cache = bool(self._query_cache) @@ -300,24 +377,28 @@ def query(self, query, args=(), *, as_dict=False, suppress_warnings=True, reconn return cursor - def get_user(self): + def get_user(self) -> str: """ - :return: the user name and host name provided by the client to the server. + Return the current database user. + + Returns: + The username and host in 'user@host' format. """ return self.query("SELECT user()").fetchone()[0] # ---------- transaction processing @property - def in_transaction(self): - """ - :return: True if there is an open transaction. - """ + def in_transaction(self) -> bool: + """Return True if there is an open transaction.""" self._in_transaction = self._in_transaction and self.is_connected return self._in_transaction - def start_transaction(self): + def start_transaction(self) -> None: """ - Starts a transaction error. + Start a new database transaction. + + Raises: + DataJointError: If already in a transaction (nesting not supported). """ if self.in_transaction: raise errors.DataJointError("Nested connections are not supported.") @@ -325,19 +406,14 @@ def start_transaction(self): self._in_transaction = True logger.debug("Transaction started") - def cancel_transaction(self): - """ - Cancels the current transaction and rolls back all changes made during the transaction. - """ + def cancel_transaction(self) -> None: + """Cancel the current transaction and roll back all changes.""" self.query("ROLLBACK") self._in_transaction = False logger.debug("Transaction cancelled. Rolling back ...") - def commit_transaction(self): - """ - Commit all changes made during the transaction and close it. - - """ + def commit_transaction(self) -> None: + """Commit all changes made during the transaction and close it.""" self.query("COMMIT") self._in_transaction = False logger.debug("Transaction committed and closed.") @@ -345,16 +421,22 @@ def commit_transaction(self): # -------- context manager for transactions @property @contextmanager - def transaction(self): + def transaction(self) -> Generator[Connection, None, None]: """ - Context manager for transactions. Opens an transaction and closes it after the with statement. - If an error is caught during the transaction, the commits are automatically rolled back. - All errors are raised again. + Context manager for database transactions. + + Opens a transaction and commits it after the with block completes successfully. + If an exception is raised, the transaction is rolled back automatically. + + Yields: + This Connection object. Example: - >>> import datajoint as dj - >>> with dj.conn().transaction as conn: - >>> # transaction is open here + >>> import datajoint as dj + >>> with dj.conn().transaction as conn: + ... # transaction is open here + ... table.insert(data) + # transaction is committed on successful exit """ try: self.start_transaction() diff --git a/src/datajoint/expression.py b/src/datajoint/expression.py index 62359be94..7fa3db770 100644 --- a/src/datajoint/expression.py +++ b/src/datajoint/expression.py @@ -1,8 +1,19 @@ +""" +Query expression classes for building DataJoint queries. + +This module provides the QueryExpression class and related classes (Aggregation, Union, U) +for constructing SQL queries through a Pythonic interface. Query expressions support +restriction, projection, joining, aggregation, and union operations. +""" + +from __future__ import annotations + import copy import inspect import logging import re from itertools import count +from typing import TYPE_CHECKING, Any from .condition import ( AndList, @@ -20,6 +31,12 @@ from .preview import preview, repr_html from .settings import config +if TYPE_CHECKING: + from collections.abc import Iterator + + from .connection import Connection + from .heading import Heading + logger = logging.getLogger(__name__.split(".")[0]) @@ -60,48 +77,50 @@ class QueryExpression: _distinct = False @property - def connection(self): - """a dj.Connection object""" + def connection(self) -> Connection: + """The database connection for this expression.""" assert self._connection is not None return self._connection @property - def support(self): - """A list of table names or subqueries to from the FROM clause""" + def support(self) -> list: + """List of table names or subqueries forming the FROM clause.""" assert self._support is not None return self._support @property - def heading(self): - """a dj.Heading object, reflects the effects of the projection operator .proj""" + def heading(self) -> Heading: + """The Heading object reflecting projection effects.""" return self._heading @property - def original_heading(self): - """a dj.Heading object reflecting the attributes before projection""" + def original_heading(self) -> Heading: + """The Heading object before any projection was applied.""" return self._original_heading or self.heading @property - def restriction(self): - """a AndList object of restrictions applied to input to produce the result""" + def restriction(self) -> AndList: + """AndList of restrictions applied to produce the result.""" if self._restriction is None: self._restriction = AndList() return self._restriction @property - def restriction_attributes(self): - """the set of attribute names invoked in the WHERE clause""" + def restriction_attributes(self) -> set[str]: + """Set of attribute names used in the WHERE clause.""" if self._restriction_attributes is None: self._restriction_attributes = set() return self._restriction_attributes @property - def primary_key(self): + def primary_key(self) -> list[str]: + """List of primary key attribute names.""" return self.heading.primary_key _subquery_alias_count = count() # count for alias names used in the FROM clause - def from_clause(self): + def from_clause(self) -> str: + """Generate the FROM clause for the SQL query.""" support = ( ( "(" + src.make_sql() + ") as `$%x`" % next(self._subquery_alias_count) @@ -115,10 +134,12 @@ def from_clause(self): clause += " NATURAL{left} JOIN {clause}".format(left=" LEFT" if left else "", clause=s) return clause - def where_clause(self): + def where_clause(self) -> str: + """Generate the WHERE clause for the SQL query.""" return "" if not self.restriction else " WHERE (%s)" % ")AND(".join(str(s) for s in self.restriction) - def sorting_clauses(self): + def sorting_clauses(self) -> str: + """Generate ORDER BY and LIMIT clauses for the SQL query.""" if not self._top: return "" clause = ", ".join(_wrap_attributes(_flatten_attribute_list(self.primary_key, self._top.order_by))) @@ -129,11 +150,15 @@ def sorting_clauses(self): return clause - def make_sql(self, fields=None): + def make_sql(self, fields: list[str] | None = None) -> str: """ - Make the SQL SELECT statement. + Generate the complete SQL SELECT statement. - :param fields: used to explicitly set the select attributes + Args: + fields: Attribute names to select. If None, uses heading names. + + Returns: + The complete SQL SELECT query string. """ return "SELECT {distinct}{fields} FROM {from_}{where}{sorting}".format( distinct="DISTINCT " if self._distinct else "", @@ -144,8 +169,13 @@ def make_sql(self, fields=None): ) # --------- query operators ----------- - def make_subquery(self): - """create a new SELECT statement where self is the FROM clause""" + def make_subquery(self) -> QueryExpression: + """ + Create a new query expression with this expression as a subquery. + + Returns: + A new QueryExpression with self in the FROM clause. + """ result = QueryExpression() result._connection = self.connection result._support = [self] @@ -520,8 +550,8 @@ def tail(self, limit=25, **fetch_kwargs): """ return self.fetch(order_by="KEY DESC", limit=limit, **fetch_kwargs)[::-1] - def __len__(self): - """:return: number of elements in the result set e.g. ``len(q1)``.""" + def __len__(self) -> int: + """Return the number of rows in the result set.""" result = self.make_subquery() if self._top else copy.copy(self) return result.connection.query( "SELECT {select_} FROM {from_}{where}".format( @@ -537,46 +567,42 @@ def __len__(self): ) ).fetchone()[0] - def __bool__(self): - """ - :return: True if the result is not empty. Equivalent to len(self) > 0 but often - faster e.g. ``bool(q1)``. - """ + def __bool__(self) -> bool: + """Return True if the result set is not empty.""" return bool( self.connection.query( "SELECT EXISTS(SELECT 1 FROM {from_}{where})".format(from_=self.from_clause(), where=self.where_clause()) ).fetchone()[0] ) - def __contains__(self, item): + def __contains__(self, item: Any) -> bool: """ - returns True if the restriction in item matches any entries in self - e.g. ``restriction in q1``. + Check if any entries match the given restriction. - :param item: any restriction - (item in query_expression) is equivalent to bool(query_expression & item) but may be - executed more efficiently. - """ - return bool(self & item) # May be optimized e.g. using an EXISTS query + Args: + item: Any valid restriction. - def __iter__(self): + Returns: + True if at least one entry matches the restriction. """ - returns an iterator-compatible QueryExpression object e.g. ``iter(q1)``. + return bool(self & item) # May be optimized e.g. using an EXISTS query - :param self: iterator-compatible QueryExpression object - """ + def __iter__(self) -> QueryExpression: + """Return an iterator over the query results.""" self._iter_only_key = all(v.in_key for v in self.heading.attributes.values()) self._iter_keys = self.fetch("KEY") return self - def __next__(self): + def __next__(self) -> dict: """ - returns the next record on an iterator-compatible QueryExpression object - e.g. ``next(q1)``. + Return the next record from the iterator. - :param self: A query expression - :type self: :class:`QueryExpression` - :rtype: dict + Returns: + Dictionary containing the next row's attribute values. + + Raises: + TypeError: If __iter__ was not called first. + StopIteration: When no more rows are available. """ try: key = self._iter_keys.pop(0) @@ -596,31 +622,39 @@ def __next__(self): # -- move on to next entry. return next(self) - def cursor(self, as_dict=False): + def cursor(self, as_dict: bool = False) -> Any: """ - See expression.fetch() for input description. - :return: query cursor + Execute the query and return a database cursor. + + Args: + as_dict: If True, return rows as dictionaries. + + Returns: + A database cursor object. """ sql = self.make_sql() logger.debug(sql) return self.connection.query(sql, as_dict=as_dict) - def __repr__(self): - """ - returns the string representation of a QueryExpression object e.g. ``str(q1)``. + def __repr__(self) -> str: + """Return string representation of the query expression.""" + return super().__repr__() if config["loglevel"].lower() == "debug" else self.preview() - :param self: A query expression - :type self: :class:`QueryExpression` - :rtype: str + def preview(self, limit: int | None = None, width: int | None = None) -> str: """ - return super().__repr__() if config["loglevel"].lower() == "debug" else self.preview() + Return a formatted preview of the query results. - def preview(self, limit=None, width=None): - """:return: a string of preview of the contents of the query.""" + Args: + limit: Maximum number of rows to show. + width: Maximum display width. + + Returns: + Formatted string representation of query results. + """ return preview(self, limit, width) - def _repr_html_(self): - """:return: HTML to display table in Jupyter notebook.""" + def _repr_html_(self) -> str: + """Return HTML representation for Jupyter notebook display.""" return repr_html(self) @@ -919,11 +953,16 @@ def aggr(self, group, **named_attributes): aggregate = aggr # alias for aggr -def _flatten_attribute_list(primary_key, attrs): +def _flatten_attribute_list(primary_key: list[str], attrs: list[str]) -> Iterator[str]: """ - :param primary_key: list of attributes in primary key - :param attrs: list of attribute names, which may include "KEY", "KEY DESC" or "KEY ASC" - :return: generator of attributes where "KEY" is replaced with its component attributes + Expand "KEY" placeholders in attribute lists to actual primary key attributes. + + Args: + primary_key: List of primary key attribute names. + attrs: List of attribute names, which may include "KEY", "KEY DESC", or "KEY ASC". + + Yields: + Attribute names with "KEY" expanded to primary key components. """ for a in attrs: if re.match(r"^\s*KEY(\s+[aA][Ss][Cc])?\s*$", a): @@ -936,6 +975,15 @@ def _flatten_attribute_list(primary_key, attrs): yield a -def _wrap_attributes(attr): +def _wrap_attributes(attr: Iterator[str]) -> Iterator[str]: + """ + Wrap attribute names in SQL backquotes for safe identifier usage. + + Args: + attr: Iterator of attribute names/expressions. + + Yields: + Attribute expressions with identifiers wrapped in backquotes. + """ for entry in attr: # wrap attribute names in backquotes yield re.sub(r"\b((?!asc|desc)\w+)\b", r"`\1`", entry, flags=re.IGNORECASE) diff --git a/src/datajoint/fetch.py b/src/datajoint/fetch.py index 0029a898f..67a2e7c24 100644 --- a/src/datajoint/fetch.py +++ b/src/datajoint/fetch.py @@ -1,9 +1,19 @@ +""" +Data fetching utilities for DataJoint query expressions. + +This module provides the Fetch and Fetch1 classes that handle retrieving +data from the database, unpacking blobs, and downloading external files. +""" + +from __future__ import annotations + import itertools import json import numbers import uuid from functools import partial from pathlib import Path +from typing import TYPE_CHECKING, Any import numpy as np import pandas @@ -17,36 +27,69 @@ from .storage import StorageBackend from .utils import safe_write +if TYPE_CHECKING: + from collections.abc import Generator + + from .connection import Connection + from .expression import QueryExpression + from .heading import Attribute + class key: """ - object that allows requesting the primary key as an argument in expression.fetch() - The string "KEY" can be used instead of the class key + Sentinel object for requesting primary key in expression.fetch(). + + The string "KEY" can be used interchangeably with this class. + + Example: + >>> table.fetch('attribute', dj.key) # fetch attribute values and keys + >>> table.fetch('KEY') # equivalent using string """ pass -def is_key(attr): +def is_key(attr: Any) -> bool: + """Check if an attribute reference represents the primary key.""" return attr is key or attr == "KEY" -def to_dicts(recarray): - """convert record array to a dictionaries""" +def to_dicts(recarray: np.ndarray) -> Generator[dict, None, None]: + """ + Convert a numpy record array to a generator of dictionaries. + + Args: + recarray: A numpy structured/record array. + + Yields: + Dictionary for each row with field names as keys. + """ for rec in recarray: yield dict(zip(recarray.dtype.names, rec.tolist())) -def _get(connection, attr, data, squeeze, download_path): +def _get( + connection: Connection, + attr: Attribute, + data: Any, + squeeze: bool, + download_path: str | Path, +) -> Any: """ - This function is called for every attribute - - :param connection: a dj.Connection object - :param attr: attribute name from the table's heading - :param data: literal value fetched from the table - :param squeeze: if True squeeze blobs - :param download_path: for fetches that download data, e.g. attachments - :return: unpacked data + Process and unpack a single attribute value from the database. + + Handles special attribute types including blobs, attachments, external storage, + UUIDs, JSON, and object references. + + Args: + connection: The database connection for accessing external stores. + attr: Attribute metadata from the table's heading. + data: Raw value fetched from the database. + squeeze: If True, remove extra dimensions from blob arrays. + download_path: Directory for downloading attachments. + + Returns: + The unpacked/processed attribute value. """ if data is None: return @@ -114,43 +157,53 @@ def adapt(x): class Fetch: """ - A fetch object that handles retrieving elements from the table expression. + Handler for retrieving multiple rows from a query expression. + + Provides flexible data retrieval with support for various output formats, + attribute selection, ordering, and pagination. - :param expression: the QueryExpression object to fetch from. + Args: + expression: The QueryExpression to fetch data from. """ - def __init__(self, expression): + def __init__(self, expression: QueryExpression) -> None: self._expression = expression def __call__( self, - *attrs, - offset=None, - limit=None, - order_by=None, - format=None, - as_dict=None, - squeeze=False, - download_path=".", - ): + *attrs: str, + offset: int | None = None, + limit: int | None = None, + order_by: str | list[str] | None = None, + format: str | None = None, + as_dict: bool | None = None, + squeeze: bool = False, + download_path: str | Path = ".", + ) -> np.ndarray | list[dict] | pandas.DataFrame | list | tuple: """ - Fetches the expression results from the database into an np.array or list of dictionaries and - unpacks blob attributes. - - :param attrs: zero or more attributes to fetch. If not provided, the call will return all attributes of this - table. If provided, returns tuples with an entry for each attribute. - :param offset: the number of tuples to skip in the returned result - :param limit: the maximum number of tuples to return - :param order_by: a single attribute or the list of attributes to order the results. No ordering should be assumed - if order_by=None. To reverse the order, add DESC to the attribute name or names: e.g. ("age DESC", - "frequency") To order by primary key, use "KEY" or "KEY DESC" - :param format: Effective when as_dict=None and when attrs is empty None: default from config['fetch_format'] or - 'array' if not configured "array": use numpy.key_array "frame": output pandas.DataFrame. . - :param as_dict: returns a list of dictionaries instead of a record array. Defaults to False for .fetch() and to - True for .fetch('KEY') - :param squeeze: if True, remove extra dimensions from arrays - :param download_path: for fetches that download data, e.g. attachments - :return: the contents of the table in the form of a structured numpy.array or a dict list + Fetch results from the database into various output formats. + + Args: + *attrs: Attribute names to fetch. If empty, fetches all attributes. + Use "KEY" or dj.key to include primary key values. + offset: Number of rows to skip before returning results. + limit: Maximum number of rows to return. + order_by: Attribute(s) for sorting. Use "KEY" for primary key ordering, + append " DESC" for descending order (e.g., "timestamp DESC"). + format: Output format when fetching all attributes: + - None: Use config['fetch_format'] default + - "array": Return numpy structured array + - "frame": Return pandas DataFrame + as_dict: If True, return list of dictionaries. Defaults to True for + "KEY" fetches, False otherwise. + squeeze: If True, remove extra dimensions from blob arrays. + download_path: Directory for downloading attachments. + + Returns: + Data in the requested format: + - Single attr: array of values + - Multiple attrs: tuple of arrays + - No attrs: structured array, DataFrame, or list of dicts """ if offset or order_by or limit: self._expression = self._expression.restrict( @@ -251,30 +304,44 @@ def __call__( class Fetch1: """ - Fetch object for fetching the result of a query yielding one row. + Handler for fetching exactly one row from a query expression. - :param expression: a query expression to fetch from. + Raises an error if the query returns zero or more than one row. + + Args: + expression: The QueryExpression to fetch from. """ - def __init__(self, expression): + def __init__(self, expression: QueryExpression) -> None: self._expression = expression - def __call__(self, *attrs, squeeze=False, download_path="."): + def __call__( + self, + *attrs: str, + squeeze: bool = False, + download_path: str | Path = ".", + ) -> dict | Any | tuple: """ - Fetches the result of a query expression that yields one entry. + Fetch exactly one row from the query expression. + + Args: + *attrs: Attribute names to fetch. If empty, returns all attributes + as a dictionary. If specified, returns values as a tuple. + squeeze: If True, remove extra dimensions from blob arrays. + download_path: Directory for downloading attachments. + + Returns: + - No attrs: Dictionary with all attribute values + - One attr: Single value + - Multiple attrs: Tuple of values - If no attributes are specified, returns the result as a dict. - If attributes are specified returns the corresponding results as a tuple. + Raises: + DataJointError: If the query returns zero or more than one row. Examples: - d = rel.fetch1() # as a dictionary - a, b = rel.fetch1('a', 'b') # as a tuple - - :params *attrs: attributes to return when expanding into a tuple. - If attrs is empty, the return result is a dict - :param squeeze: When true, remove extra dimensions from arrays in attributes - :param download_path: for fetches that download data, e.g. attachments - :return: the one tuple in the table in the form of a dict + >>> d = rel.fetch1() # returns dict + >>> a, b = rel.fetch1('a', 'b') # returns tuple + >>> val = rel.fetch1('value') # returns single value """ heading = self._expression.heading diff --git a/src/datajoint/hash.py b/src/datajoint/hash.py index f58c65732..b5256bc6c 100644 --- a/src/datajoint/hash.py +++ b/src/datajoint/hash.py @@ -1,14 +1,37 @@ +""" +Hashing utilities for DataJoint. + +This module provides functions for computing hashes of data, streams, and files. +These are used for checksums, content-addressable storage, and primary key hashing. +""" + +from __future__ import annotations + import hashlib import io import uuid +from collections.abc import Mapping from pathlib import Path +from typing import IO -def key_hash(mapping): +def key_hash(mapping: Mapping) -> str: """ - 32-byte hash of the mapping's key values sorted by the key name. - This is often used to convert a long primary key value into a shorter hash. - For example, the JobTable in datajoint.jobs uses this function to hash the primary key of autopopulated tables. + Compute a 32-character hex hash of a mapping's values, sorted by key name. + + This is commonly used to convert long primary key values into shorter hashes. + For example, the JobTable in datajoint.jobs uses this function to hash the + primary keys of autopopulated tables. + + Args: + mapping: A dict-like object whose values will be hashed. + + Returns: + A 32-character hexadecimal MD5 hash string. + + Example: + >>> key_hash({'subject_id': 1, 'session': 5}) + 'a1b2c3d4e5f6...' """ hashed = hashlib.md5() for k, v in sorted(mapping.items()): @@ -16,11 +39,18 @@ def key_hash(mapping): return hashed.hexdigest() -def uuid_from_stream(stream, *, init_string=""): +def uuid_from_stream(stream: IO[bytes], *, init_string: str = "") -> uuid.UUID: """ - :return: 16-byte digest of stream data - :stream: stream object or open file handle - :init_string: string to initialize the checksum + Compute a UUID from the contents of a binary stream. + + Reads the stream in chunks and computes an MD5 hash, returning it as a UUID. + + Args: + stream: A binary stream object (file handle opened in 'rb' mode or BytesIO). + init_string: Optional string to initialize the hash (acts as a salt). + + Returns: + A UUID object derived from the MD5 hash of the stream contents. """ hashed = hashlib.md5(init_string.encode()) chunk = True @@ -31,9 +61,29 @@ def uuid_from_stream(stream, *, init_string=""): return uuid.UUID(bytes=hashed.digest()) -def uuid_from_buffer(buffer=b"", *, init_string=""): +def uuid_from_buffer(buffer: bytes = b"", *, init_string: str = "") -> uuid.UUID: + """ + Compute a UUID from a bytes buffer. + + Args: + buffer: The binary data to hash. + init_string: Optional string to initialize the hash (acts as a salt). + + Returns: + A UUID object derived from the MD5 hash of the buffer. + """ return uuid_from_stream(io.BytesIO(buffer), init_string=init_string) -def uuid_from_file(filepath, *, init_string=""): +def uuid_from_file(filepath: str | Path, *, init_string: str = "") -> uuid.UUID: + """ + Compute a UUID from the contents of a file. + + Args: + filepath: Path to the file to hash. + init_string: Optional string to initialize the hash (acts as a salt). + + Returns: + A UUID object derived from the MD5 hash of the file contents. + """ return uuid_from_stream(Path(filepath).open("rb"), init_string=init_string) diff --git a/src/datajoint/heading.py b/src/datajoint/heading.py index dc305db71..2a4bc0b12 100644 --- a/src/datajoint/heading.py +++ b/src/datajoint/heading.py @@ -243,9 +243,6 @@ def _init_from_database(self): as_dict=True, ).fetchone() if info is None: - if table_name == "~log": - logger.warning("Could not create the ~log table") - return raise DataJointError( "The table `{database}`.`{table_name}` is not defined.".format(table_name=table_name, database=database) ) diff --git a/src/datajoint/logging.py b/src/datajoint/logging.py index b432e1a4b..10e8dd5af 100644 --- a/src/datajoint/logging.py +++ b/src/datajoint/logging.py @@ -1,6 +1,18 @@ +""" +Logging configuration for the DataJoint package. + +This module sets up the default logging handler and format for DataJoint, +and provides a custom exception hook to log uncaught exceptions. + +The log level can be configured via the DJ_LOG_LEVEL environment variable. +""" + +from __future__ import annotations + import logging import os import sys +from types import TracebackType logger = logging.getLogger(__name__.split(".")[0]) @@ -15,7 +27,22 @@ logger.handlers = [stream_handler] -def excepthook(exc_type, exc_value, exc_traceback): +def excepthook( + exc_type: type[BaseException], + exc_value: BaseException, + exc_traceback: TracebackType | None, +) -> None: + """ + Custom exception hook that logs uncaught exceptions. + + Keyboard interrupts are passed to the default handler; all other exceptions + are logged as errors with full traceback information. + + Args: + exc_type: The exception class. + exc_value: The exception instance. + exc_traceback: The traceback object. + """ if issubclass(exc_type, KeyboardInterrupt): sys.__excepthook__(exc_type, exc_value, exc_traceback) return diff --git a/src/datajoint/schemas.py b/src/datajoint/schemas.py index 9df3ba34d..5ce898394 100644 --- a/src/datajoint/schemas.py +++ b/src/datajoint/schemas.py @@ -11,7 +11,7 @@ from .external import ExternalMapping from .heading import Heading from .settings import config -from .table import FreeTable, Log, lookup_class_name +from .table import FreeTable, lookup_class_name from .user_tables import Computed, Imported, Lookup, Manual, Part, _get_tier from .utils import to_camel_case, user_choice @@ -63,7 +63,6 @@ def __init__( :param add_objects: a mapping with additional objects to make available to the context in which table classes are declared. """ - self._log = None self.connection = connection self.database = None self.context = context @@ -136,7 +135,7 @@ def activate( "Schema `{name}` does not exist and could not be created. Check permissions.".format(name=schema_name) ) else: - self.log("created") + logger.info("Created schema `%s`", schema_name) self.connection.register(self) # decorate all tables already decorated @@ -231,13 +230,6 @@ def _decorate_table(self, table_class, context, assert_declared=False): if table_class not in self._auto_populated_tables: self._auto_populated_tables.append(table_class) - @property - def log(self): - self._assert_exists() - if self._log is None: - self._log = Log(self.connection, self.database) - return self._log - def __repr__(self): return "Schema `{name}`\n".format(name=self.database) @@ -420,7 +412,7 @@ def replace(s): def list_tables(self): """ Return a list of all tables in the schema except tables with ~ in first character such - as ~logs and ~job + as ~job :return: A list of table names from the database schema. """ diff --git a/src/datajoint/table.py b/src/datajoint/table.py index 5cedacfdc..ce961af4c 100644 --- a/src/datajoint/table.py +++ b/src/datajoint/table.py @@ -5,7 +5,6 @@ import json import logging import mimetypes -import platform import re import uuid from datetime import datetime, timezone @@ -34,7 +33,6 @@ from .staged_insert import staged_insert1 as _staged_insert1 from .storage import StorageBackend, build_object_path, verify_or_create_store_metadata from .utils import get_master, is_camel_case, user_choice -from .version import __version__ as version logger = logging.getLogger(__name__.split(".")[0]) @@ -73,7 +71,6 @@ class Table(QueryExpression): """ _table_name = None # must be defined in subclass - _log_ = None # placeholder for the Log table object # These properties must be set by the schema decorator (schemas.py) at class level # or by FreeTable at instance level @@ -118,7 +115,7 @@ def declare(self, context=None): # skip if no create privilege pass else: - self._log("Declared " + self.full_table_name) + logger.info("Declared %s", self.full_table_name) # Populate lineage entries for semantic matching self._populate_lineage() @@ -153,7 +150,7 @@ def alter(self, prompt=True, context=None): self.__class__._heading = Heading(table_info=self.heading.table_info) if prompt: logger.info("Table altered") - self._log("Altered " + self.full_table_name) + logger.info("Altered %s", self.full_table_name) def _populate_lineage(self): """ @@ -293,16 +290,6 @@ def full_table_name(self): """ return r"`{0:s}`.`{1:s}`".format(self.database, self.table_name) - @property - def _log(self): - if self._log_ is None: - self._log_ = Log( - self.connection, - database=self.database, - skip_logging=self.table_name.startswith("~"), - ) - return self._log_ - @property def external(self): return self.connection.schemas[self.database].external @@ -609,7 +596,7 @@ def delete_quick(self, get_count=False): query = "DELETE FROM " + self.full_table_name + self.where_clause() self.connection.query(query) count = self.connection.query("SELECT ROW_COUNT()").fetchone()[0] if get_count else None - self._log(query[:255]) + logger.debug("Deleted from %s", self.full_table_name) return count def delete( @@ -787,10 +774,9 @@ def drop_quick(self): self.connection.query(query) # Clean up lineage entries delete_lineage_entries(self.connection, self.database, self.table_name) - logger.info("Dropped table %s" % self.full_table_name) - self._log(query[:255]) + logger.info("Dropped table %s", self.full_table_name) else: - logger.info("Nothing to drop: table %s is not declared" % self.full_table_name) + logger.info("Nothing to drop: table %s is not declared", self.full_table_name) def drop(self): """ @@ -1097,76 +1083,3 @@ def __init__(self, conn, full_table_name): def __repr__(self): return "FreeTable(`%s`.`%s`)\n" % (self.database, self._table_name) + super().__repr__() - - -class Log(Table): - """ - The log table for each schema. - Instances are callable. Calls log the time and identifying information along with the event. - - :param skip_logging: if True, then log entry is skipped by default. See __call__ - """ - - _table_name = "~log" - - def __init__(self, conn, database, skip_logging=False): - self.database = database - self.skip_logging = skip_logging - self._connection = conn - self._heading = Heading(table_info=dict(conn=conn, database=database, table_name=self.table_name, context=None)) - self._support = [self.full_table_name] - - self._definition = """ # event logging table for `{database}` - id :int unsigned auto_increment # event order id - --- - timestamp = CURRENT_TIMESTAMP : timestamp # event timestamp - version :varchar(12) # datajoint version - user :varchar(255) # user@host - host="" :varchar(255) # system hostname - event="" :varchar(255) # event message - """.format(database=database) - - super().__init__() - - if not self.is_declared: - self.declare() - self.connection.dependencies.clear() - self._user = self.connection.get_user() - - @property - def definition(self): - return self._definition - - def __call__(self, event, skip_logging=None): - """ - - :param event: string to write into the log table - :param skip_logging: If True then do not log. If None, then use self.skip_logging - """ - skip_logging = self.skip_logging if skip_logging is None else skip_logging - if not skip_logging: - try: - self.insert1( - dict( - user=self._user, - version=version + "py", - host=platform.uname().node, - event=event, - ), - skip_duplicates=True, - ignore_extra_fields=True, - ) - except DataJointError: - logger.info("could not log event in table ~log") - - def delete(self): - """ - bypass interactive prompts and cascading dependencies - - :return: number of deleted items - """ - return self.delete_quick(get_count=True) - - def drop(self): - """bypass interactive prompts and cascading dependencies""" - self.drop_quick() diff --git a/src/datajoint/utils.py b/src/datajoint/utils.py index 16927965e..8ff091cb0 100644 --- a/src/datajoint/utils.py +++ b/src/datajoint/utils.py @@ -1,28 +1,55 @@ -"""General-purpose utilities""" +""" +General-purpose utilities for DataJoint. + +This module provides helper functions for common operations including +naming conventions, file operations, and SQL parsing. +""" + +from __future__ import annotations import re import shutil +from collections.abc import Callable, Generator from pathlib import Path +from typing import Any from .errors import DataJointError class ClassProperty: - def __init__(self, f): + """ + Descriptor for defining class-level properties. + + Similar to @property but works on the class itself rather than instances. + """ + + def __init__(self, f: Callable) -> None: self.f = f - def __get__(self, obj, owner): + def __get__(self, obj: Any, owner: type) -> Any: return self.f(owner) -def user_choice(prompt, choices=("yes", "no"), default=None): +def user_choice( + prompt: str, + choices: tuple[str, ...] = ("yes", "no"), + default: str | None = None, +) -> str: """ - Prompts the user for confirmation. The default value, if any, is capitalized. + Prompt the user to select from a list of choices. - :param prompt: Information to display to the user. - :param choices: an iterable of possible choices. - :param default: default choice - :return: the user's choice + The default value, if any, is displayed capitalized. + + Args: + prompt: Message to display to the user. + choices: Tuple of valid response options. + default: Default choice if user presses Enter without input. + + Returns: + The user's selected choice (lowercase). + + Raises: + AssertionError: If default is not None and not in choices. """ assert default is None or default in choices choice_list = ", ".join((choice.title() if choice == default else choice for choice in choices)) @@ -52,46 +79,62 @@ def get_master(full_table_name: str) -> str: return match["master"] + "`" if match else "" -def is_camel_case(s): +def is_camel_case(s: str) -> bool: """ Check if a string is in CamelCase notation. - :param s: string to check - :returns: True if the string is in CamelCase notation, False otherwise - Example: - >>> is_camel_case("TableName") # returns True - >>> is_camel_case("table_name") # returns False + Args: + s: The string to check. + + Returns: + True if the string matches CamelCase pattern (starts with uppercase, + contains only alphanumeric characters). + + Examples: + >>> is_camel_case("TableName") # True + >>> is_camel_case("table_name") # False """ return bool(re.match(r"^[A-Z][A-Za-z0-9]*$", s)) -def to_camel_case(s): +def to_camel_case(s: str) -> str: """ - Convert names with under score (_) separation into camel case names. + Convert underscore-separated names to CamelCase. + + Args: + s: String in underscore_notation. + + Returns: + String in CamelCase notation. - :param s: string in under_score notation - :returns: string in CamelCase notation Example: - >>> to_camel_case("table_name") # returns "TableName" + >>> to_camel_case("table_name") # "TableName" """ - def to_upper(match): + def to_upper(match: re.Match) -> str: return match.group(0)[-1].upper() return re.sub(r"(^|[_\W])+[a-zA-Z]", to_upper, s) -def from_camel_case(s): +def from_camel_case(s: str) -> str: """ - Convert names in camel case into underscore (_) separated names + Convert CamelCase names to underscore-separated lowercase. + + Args: + s: String in CamelCase notation. + + Returns: + String in underscore_notation. + + Raises: + DataJointError: If the input is not valid CamelCase. - :param s: string in CamelCase notation - :returns: string in under_score notation Example: - >>> from_camel_case("TableName") # yields "table_name" + >>> from_camel_case("TableName") # "table_name" """ - def convert(match): + def convert(match: re.Match) -> str: return ("_" if match.groups()[0] else "") + match.group(0).lower() if not is_camel_case(s): @@ -99,12 +142,16 @@ def convert(match): return re.sub(r"(\B[A-Z])|(\b[A-Z])", convert, s) -def safe_write(filepath, blob): +def safe_write(filepath: str | Path, blob: bytes) -> None: """ - A two-step write. + Write binary data to a file atomically using a two-step process. + + Creates a temporary file first, then renames it to the target path. + This prevents partial writes from corrupting the file. - :param filename: full path - :param blob: binary data + Args: + filepath: Destination file path. + blob: Binary data to write. """ filepath = Path(filepath) if not filepath.is_file(): @@ -114,9 +161,21 @@ def safe_write(filepath, blob): temp_file.rename(filepath) -def safe_copy(src, dest, overwrite=False): +def safe_copy( + src: str | Path, + dest: str | Path, + overwrite: bool = False, +) -> None: """ - Copy the contents of src file into dest file as a two-step process. Skip if dest exists already + Copy a file atomically using a two-step process. + + Creates a temporary file first, then renames it. Skips if destination + exists (unless overwrite=True) or if src and dest are the same file. + + Args: + src: Source file path. + dest: Destination file path. + overwrite: If True, overwrite existing destination file. """ src, dest = Path(src), Path(dest) if not (dest.exists() and src.samefile(dest)) and (overwrite or not dest.is_file()): @@ -126,12 +185,20 @@ def safe_copy(src, dest, overwrite=False): temp_file.rename(dest) -def parse_sql(filepath): +def parse_sql(filepath: str | Path) -> Generator[str, None, None]: """ - yield SQL statements from an SQL file + Parse SQL statements from a file. + + Handles custom delimiters and skips SQL comments. + + Args: + filepath: Path to the SQL file. + + Yields: + Individual SQL statements as strings. """ delimiter = ";" - statement = [] + statement: list[str] = [] with Path(filepath).open("rt") as f: for line in f: line = line.strip() diff --git a/tests/test_log.py b/tests/test_log.py deleted file mode 100644 index 87905cd47..000000000 --- a/tests/test_log.py +++ /dev/null @@ -1,3 +0,0 @@ -def test_log(schema_any): - ts, events = (schema_any.log & 'event like "Declared%%"').fetch("timestamp", "event") - assert len(ts) >= 2