Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions docs/tutorial/sharing_pool.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
---
title: Sharing a Connection Pool
---

By default, each component (broker, result backend, schedule source) creates and manages its own connection pool. If you are integrating taskiq-postgres into an application that already maintains a pool — or if you simply want to reduce the total number of database connections — you can pass a single pool to all three components.

## How it works

`PsycopgBroker`, `PsycopgResultBackend`, and `PsycopgScheduleSource` each accept an optional `pool` (or `write_pool`) keyword argument. When a pool is provided:

- The component sets `_owns_pool = False` and uses the pool as-is.
- `startup()` opens the pool if it is not yet open, but will **not** close it on `shutdown()`.
- Lifecycle management (opening, closing) is your responsibility.
Comment on lines +11 to +13

## Basic example

```python
import asyncio

from psycopg import AsyncConnection, AsyncRawCursor
from psycopg_pool import AsyncConnectionPool
from taskiq import TaskiqScheduler

from taskiq_pg.psycopg import PsycopgBroker, PsycopgResultBackend, PsycopgScheduleSource

DSN = "postgres://user:password@localhost:5432/mydb"

async def main() -> None:
# Create one pool shared by all components.
pool = AsyncConnectionPool(conninfo=DSN, open=False)
# A dedicated connection is required by the broker for LISTEN/NOTIFY.
read_conn = await AsyncConnection.connect(
conninfo=DSN, autocommit=True, cursor_factory=AsyncRawCursor
)

broker = PsycopgBroker(
write_pool=pool,
read_connection=read_conn,
).with_result_backend(
PsycopgResultBackend(pool=pool)
)

schedule_source = PsycopgScheduleSource(broker=broker, pool=pool)
scheduler = TaskiqScheduler(broker=broker, sources=[schedule_source])

await broker.startup()
# ... run your application ...
await broker.shutdown()

# Close shared resources after all components have shut down.
await read_conn.close()
await pool.close()


if __name__ == "__main__":
asyncio.run(main())
```

You can see fully working example inside repository in `examples/example_with_shared_pool.py`.
70 changes: 70 additions & 0 deletions examples/example_with_shared_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""
How to run:

1) Run worker in one terminal:
uv run taskiq worker examples.example_with_shared_pool:get_broker --workers 1

2) Run this script in another terminal:
uv run python -m examples.example_with_shared_pool
"""

import asyncio

from psycopg import AsyncConnection, AsyncRawCursor
from psycopg_pool import AsyncConnectionPool
from taskiq import async_shared_broker

from taskiq_pg.psycopg import PsycopgBroker, PsycopgResultBackend


DSN = "postgres://taskiq_postgres:look_in_vault@localhost:5432/taskiq_postgres"


def create_pool() -> AsyncConnectionPool:
return AsyncConnectionPool(conninfo=DSN, open=False, timeout=5)


async def create_connection() -> AsyncConnection:
return await AsyncConnection.connect(
conninfo=DSN,
autocommit=True,
cursor_factory=AsyncRawCursor,
)


def make_broker(pool: AsyncConnectionPool, connection: AsyncConnection) -> PsycopgBroker:
broker = PsycopgBroker(
write_pool=pool,
read_connection=connection,
).with_result_backend(PsycopgResultBackend(pool=pool))
async_shared_broker.default_broker(broker)
return broker


@async_shared_broker.task("solve_all_problems")
async def best_task_ever() -> None:
"""Solve all problems in the world."""
await asyncio.sleep(2)
print("All problems are solved!")


def get_broker() -> PsycopgBroker:
"""Sync factory used by the taskiq worker CLI."""
pool = create_pool()
connection = asyncio.run(create_connection())
return make_broker(pool, connection)


async def main() -> None:
pool = create_pool()
connection = await create_connection()
broker = make_broker(pool, connection)

await broker.startup()
task = await best_task_ever.kiq()
print(await task.wait_result())
await broker.shutdown()


if __name__ == "__main__":
asyncio.run(main())
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ nav:
- Tutorial:
- tutorial/result_backend.md
- tutorial/schedule_source.md
- tutorial/sharing_pool.md
- tutorial/common_issues.md
- API:
- reference.md
Expand Down
14 changes: 10 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ classifiers = [
"Intended Audience :: Developers",
"Intended Audience :: Information Technology",
"Framework :: AsyncIO",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
Expand Down Expand Up @@ -61,6 +60,7 @@ dev = [
{include-group = "test"},
{include-group = "docs"},
"prek>=0.2.19",
"ty>=0.0.34",
]
lint = [
"ruff>=0.14.8",
Expand All @@ -73,11 +73,11 @@ lint = [
]
test = [
"polyfactory>=3.1.0",
"pytest>=9.0.1",
"pytest>=9.0.2",
"pytest-asyncio>=1.3.0",
"pytest-cov>=7.0.0",
# for database in tests
"sqlalchemy-utils>=0.42.0",
"sqlalchemy-utils>=0.42.1",
# for faster asyncio loop in tests
"uvloop>=0.22.1",
]
Expand All @@ -87,7 +87,7 @@ docs = [
]

[build-system]
requires = ["uv_build>=0.9,<0.10"]
requires = ["uv_build>=0.11,<0.12"]
build-backend = "uv_build"

[tool.uv.build-backend]
Expand Down Expand Up @@ -145,6 +145,9 @@ ignore = [
"D203", # with D211
"D212", # with D213
"COM812", # with formatter

"EM101",
"TRY003",
]

[tool.ruff.lint.per-file-ignores]
Expand Down Expand Up @@ -183,6 +186,9 @@ convention = "google"
known-local-folder = ["taskiq_pg"]
lines-after-imports = 2

[tool.ruff.lint.pylint]
max-args = 9

[tool.mypy]
python_version = "3.10"
modules = "taskiq_pg"
Expand Down
5 changes: 1 addition & 4 deletions src/taskiq_pg/_internal/broker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from __future__ import annotations

import abc
import typing as tp

Expand All @@ -16,7 +14,7 @@
class BasePostgresBroker(AsyncBroker, abc.ABC):
"""Base class for Postgres brokers."""

def __init__( # noqa: PLR0913
def __init__(
self,
dsn: str | tp.Callable[[], str] = "postgresql://postgres:postgres@localhost:5432/postgres",
result_backend: AsyncResultBackend[_T] | None = None,
Expand All @@ -39,7 +37,6 @@ def __init__( # noqa: PLR0913
max_retry_attempts: Maximum number of message processing attempts.
read_kwargs: Additional arguments for read connection creation.
write_kwargs: Additional arguments for write pool creation.

"""
super().__init__(
result_backend=result_backend,
Expand Down
1 change: 0 additions & 1 deletion src/taskiq_pg/_internal/result_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def __init__(
field_for_task_id: type of the field to store task_id.
serializer: serializer class to serialize/deserialize result from task.
connect_kwargs: additional arguments for creating connection pool.

"""
self._dsn: tp.Final = dsn
self.keep_results: tp.Final = keep_results
Expand Down
8 changes: 1 addition & 7 deletions src/taskiq_pg/_internal/schedule_source.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@
from __future__ import annotations

import typing as tp
import uuid
from logging import getLogger

from pydantic import ValidationError
from taskiq import ScheduleSource
from taskiq.abc.broker import AsyncBroker
from taskiq.scheduler.scheduled_task import ScheduledTask


if tp.TYPE_CHECKING:
from taskiq.abc.broker import AsyncBroker


logger = getLogger("taskiq_pg")


Expand All @@ -37,7 +32,6 @@ def __init__(
broker: The TaskIQ broker instance to use for finding and managing tasks.
Required if startup_schedule is provided.
**connect_kwargs: Additional keyword arguments passed to the database connection pool.

"""
self._broker: tp.Final = broker
self._dsn: tp.Final = dsn
Expand Down
80 changes: 72 additions & 8 deletions src/taskiq_pg/aiopg/result_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from aiopg import Pool, create_pool
from taskiq import TaskiqResult
from taskiq.abc.serializer import TaskiqSerializer
from taskiq.depends.progress_tracker import TaskProgress

from taskiq_pg import exceptions
Expand All @@ -13,19 +14,82 @@ class AiopgResultBackend(BasePostgresResultBackend):
"""Result backend for TaskIQ based on Aiopg."""

_database_pool: Pool
_owns_pool: bool

@tp.overload
def __init__(
self,
dsn: tp.Callable[[], str] | str | None = ...,
keep_results: bool = ...,
table_name: str = ...,
field_for_task_id: tp.Literal["VarChar", "Text", "Uuid"] = ...,
serializer: TaskiqSerializer | None = ...,
*,
pool: None = ...,
**connect_kwargs: tp.Any,
) -> None: ...

@tp.overload
def __init__(
self,
dsn: tp.Callable[[], str] | str | None = ...,
keep_results: bool = ...,
table_name: str = ...,
field_for_task_id: tp.Literal["VarChar", "Text", "Uuid"] = ...,
serializer: TaskiqSerializer | None = ...,
*,
pool: Pool,
) -> None: ...

def __init__(
self,
dsn: tp.Callable[[], str] | str | None = "postgres://postgres:postgres@localhost:5432/postgres",
keep_results: bool = True,
table_name: str = "taskiq_results",
field_for_task_id: tp.Literal["VarChar", "Text", "Uuid"] = "VarChar",
serializer: TaskiqSerializer | None = None,
*,
pool: Pool | None = None,
**connect_kwargs: tp.Any,
) -> None:
"""
Construct a new AiopgResultBackend.

Args:
dsn: PostgreSQL connection string or callable. Can be None if pool is provided.
keep_results: Whether to keep results after reading.
table_name: Table to store results in.
field_for_task_id: Column type for task_id.
serializer: Serializer for task results.
pool: An existing connection pool to reuse.
**connect_kwargs: Extra kwargs for connection pool creation.
"""
self._owns_pool = True
if pool is not None:
self._owns_pool = False
self._database_pool = pool

super().__init__(
dsn=dsn,
keep_results=keep_results,
table_name=table_name,
field_for_task_id=field_for_task_id,
serializer=serializer,
**connect_kwargs,
)

async def startup(self) -> None:
"""
Initialize the result backend.

Construct new connection pool
and create new table for results if not exists.
Construct new connection pool (if not provided externally) and create new table for results if not exists.
"""
try:
self._database_pool = await create_pool(
self.dsn,
**self.connect_kwargs,
)
if self._owns_pool:
self._database_pool = await create_pool(
self.dsn,
**self.connect_kwargs,
)

async with self._database_pool.acquire() as connection, connection.cursor() as cursor:
await cursor.execute(
Expand All @@ -50,8 +114,8 @@ async def startup(self) -> None:
raise exceptions.DatabaseConnectionError(str(error)) from error

async def shutdown(self) -> None:
"""Close the connection pool."""
if getattr(self, "_database_pool", None) is not None:
"""Close the connection pool if created by this backend."""
if self._owns_pool and getattr(self, "_database_pool", None) is not None:
self._database_pool.close()

async def set_result(
Expand Down
Loading