Source code for openapi.db.container

import os
from contextlib import asynccontextmanager
from typing import Any, Optional

import sqlalchemy as sa
from sqlalchemy.engine import Engine, create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from openapi.types import Connection
from openapi.utils import str2bool

from ..exc import ImproperlyConfigured

# Engine tuning knobs read once at import time from the process environment.
# An unset or empty variable falls back to the literal default on the right.
DBPOOL_MAX_SIZE = int(os.getenv("DBPOOL_MAX_SIZE") or "10")
DBPOOL_MAX_OVERFLOW = int(os.getenv("DBPOOL_MAX_OVERFLOW") or "10")
DBECHO = str2bool(os.getenv("DBECHO") or "no")


class Database:
    """A container for tables in a database and a manager of asynchronous
    connections to a postgresql database

    :param dsn: Data source name used for database connections
    :param metadata: :class:`sqlalchemy.schema.MetaData` containing tables,
        a new empty one is created when omitted
    """

    def __init__(
        self, dsn: str = "", metadata: Optional[sa.MetaData] = None
    ) -> None:
        self._dsn = dsn
        # Explicit ``is None`` so a caller-supplied metadata object is always
        # kept, regardless of how it evaluates in a boolean context.
        self._metadata = sa.MetaData() if metadata is None else metadata
        # The asynchronous engine is created lazily by the ``engine`` property
        self._engine = None

    def __repr__(self) -> str:
        # NOTE(review): the DSN is returned verbatim; if it embeds credentials
        # they will show up wherever the instance is printed or logged.
        return self._dsn

    __str__ = __repr__

    @property
    def dsn(self) -> str:
        """Data source name used for database connections"""
        return self._dsn

    @property
    def metadata(self) -> sa.MetaData:
        """The :class:`sqlalchemy.schema.MetaData` containing tables"""
        return self._metadata

    @property
    def engine(self) -> AsyncEngine:
        """The :class:`sqlalchemy.ext.asyncio.AsyncEngine` creating connection
        and transactions

        The engine is created on first access and reused until :meth:`close`
        is called.

        :raises ImproperlyConfigured: if no DSN was provided
        """
        if self._engine is None:
            if not self._dsn:
                raise ImproperlyConfigured("DSN not available")
            self._engine = create_async_engine(
                self._dsn,
                echo=DBECHO,
                pool_size=DBPOOL_MAX_SIZE,
                max_overflow=DBPOOL_MAX_OVERFLOW,
            )
        return self._engine

    @property
    def sync_engine(self) -> Engine:
        """The :class:`sqlalchemy.engine.Engine` for synchronous operations

        A new engine is created on every access (the ``+asyncpg`` driver
        marker is stripped from the DSN), so callers should ``dispose()`` it
        once they are done with it.
        """
        return create_engine(self._dsn.replace("+asyncpg", ""))

    def __getattr__(self, name: str) -> Any:
        """Retrieve a :class:`sqlalchemy.schema.Table` from metadata tables

        :param name: if this is a valid table name in the tables of
            :attr:`.metadata` it returns the table, otherwise it defaults to
            superclass method
        :raises AttributeError: if ``name`` is neither a table nor an attribute
        """
        # Read ``_metadata`` straight from the instance dict: going through
        # ``self._metadata`` would re-enter ``__getattr__`` and recurse forever
        # on an instance whose ``__init__`` has not run (copy, pickle, __new__)
        metadata = self.__dict__.get("_metadata")
        if metadata is not None and name in metadata.tables:
            return metadata.tables[name]
        return super().__getattribute__(name)

    @asynccontextmanager
    async def connection(self) -> Connection:
        """Context manager for obtaining an asynchronous connection"""
        async with self.engine.connect() as conn:
            yield conn

    @asynccontextmanager
    async def transaction(self) -> Connection:
        """Context manager for initializing an asynchronous database
        transaction"""
        async with self.engine.begin() as conn:
            yield conn

    @asynccontextmanager
    async def ensure_transaction(
        self, conn: Optional[Connection] = None
    ) -> Connection:
        """Context manager for ensuring a connection has initialized a
        database transaction

        :param conn: optional connection to reuse; when omitted a new
            connection is obtained from :meth:`transaction`
        """
        if conn is None:
            async with self.transaction() as new_conn:
                yield new_conn
        elif conn.in_transaction():
            # Already inside a transaction, nothing to start
            yield conn
        else:
            async with conn.begin():
                yield conn

    # backward compatibility
    ensure_connection = ensure_transaction

    async def close(self) -> None:
        """Close the asynchronous db engine if opened"""
        if self._engine:
            # Clear the reference first so a later ``engine`` access builds a
            # fresh engine instead of returning the one being disposed
            engine, self._engine = self._engine, None
            await engine.dispose()

    # SQL Alchemy Sync Operations

    def create_all(self) -> None:
        """Create all tables defined in :attr:`metadata`"""
        engine = self.sync_engine
        try:
            self.metadata.create_all(engine)
        finally:
            # ``sync_engine`` builds a throw-away engine: release its pool
            engine.dispose()

    def drop_all(self) -> None:
        """Truncate all tables from :attr:`metadata` in database and drop the
        ``alembic_version`` table if it exists"""
        table_names = ", ".join(self.metadata.tables)
        engine = self.sync_engine
        try:
            with engine.begin() as conn:
                # ``truncate`` without a table list is a syntax error
                if table_names:
                    conn.execute(sa.text(f"truncate {table_names}"))
                # ``if exists`` rather than swallowing the error: a failed
                # statement aborts the postgresql transaction, which would
                # silently roll back the truncate above on exit
                conn.execute(sa.text("drop table if exists alembic_version"))
        finally:
            engine.dispose()

    def drop_all_schemas(self) -> None:
        """Drop the ``public`` schema with everything in it and recreate it
        empty"""
        engine = self.sync_engine
        try:
            with engine.begin() as conn:
                conn.execute(sa.text("DROP SCHEMA IF EXISTS public CASCADE"))
                conn.execute(sa.text("CREATE SCHEMA IF NOT EXISTS public"))
        finally:
            engine.dispose()