Source code for openapi.db.dbmodel

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast

from sqlalchemy import Column, Table, func, insert, select
from sqlalchemy.sql import Select, and_, or_
from sqlalchemy.sql.dml import Delete, Insert, Update

from ..db.container import Database
from ..pagination import (
    Pagination,
    PaginationVisitor,
    Search,
    SearchVisitor,
    fields_flip_sign,
)
from ..pagination.cursor import cursor_to_python
from ..types import Connection, Record, Records

# Statement types that CrudDB.get_query can attach a WHERE clause to.
QueryType = Union[Delete, Select, Update]
# Statement types that CrudDB.search_query / DbSearchVisitor can filter.
SelectUpdate = Union[Select, Update]


class CrudDB(Database):
    """A :class:`.Database` with additional methods for CRUD operations"""

    async def db_select(
        self,
        table: Table,
        filters: Dict,
        *,
        conn: Optional[Connection] = None,
        consumer: Any = None,
    ) -> Records:
        """Select rows from a given table

        :param table: sqlalchemy Table
        :param filters: key-value pairs for filtering rows
            (see :meth:`.get_query` for the key format)
        :param conn: optional db connection
        :param consumer: optional consumer (see :meth:`.get_query`)
        """
        sql_query = self.get_query(
            table, table.select(), consumer=consumer, params=filters
        )
        async with self.ensure_connection(conn) as conn:
            return await conn.execute(sql_query)

    async def db_delete(
        self,
        table: Table,
        filters: Dict,
        *,
        conn: Optional[Connection] = None,
        consumer: Any = None,
    ) -> Records:
        """Delete rows from a given table

        The statement uses ``RETURNING`` with every column of ``table``, so the
        result contains the deleted rows.

        :param table: sqlalchemy Table
        :param filters: key-value pairs for filtering rows
            (see :meth:`.get_query` for the key format)
        :param conn: optional db connection
        :param consumer: optional consumer (see :meth:`.get_query`)
        """
        sql_query = self.get_query(
            table,
            table.delete().returning(*table.columns),
            consumer=consumer,
            params=filters,
        )
        async with self.ensure_connection(conn) as conn:
            return await conn.execute(sql_query)

    async def db_count(
        self,
        table: Table,
        filters: Optional[Dict] = None,
        *,
        conn: Optional[Connection] = None,
        consumer: Any = None,
    ) -> int:
        """Count rows in a table

        :param table: sqlalchemy Table
        :param filters: key-value pairs for filtering rows
        :param conn: optional db connection
        :param consumer: optional consumer (see :meth:`.get_query`)
        """
        sql_query = self.get_query(
            table, table.select(), consumer=consumer, params=filters
        )
        return await self.db_count_query(sql_query, conn=conn)

    async def db_count_query(
        self,
        sql_query: Select,
        *,
        conn: Optional[Connection] = None,
    ) -> int:
        """Count the rows produced by ``sql_query``

        The query is wrapped as a subquery aliased ``inner``, so any filters
        already applied to it are honoured:
        ``SELECT count(*) FROM (<sql_query>) AS inner``.

        :param sql_query: the select statement whose rows are counted
        :param conn: optional db connection
        """
        count_query = select(func.count()).select_from(sql_query.alias("inner"))
        async with self.ensure_connection(conn) as conn:
            result = await conn.execute(count_query)
            return result.scalar()

    async def db_insert(
        self,
        table: Table,
        data: Union[List[Dict], Dict],
        *,
        conn: Optional[Connection] = None,
    ) -> Records:
        """Perform an insert into a table

        :param table: sqlalchemy Table
        :param data: key-value pairs for columns values, or a list of them
            for a multi-row insert (see :meth:`.insert_query`)
        :param conn: optional db connection
        """
        async with self.ensure_connection(conn) as conn:
            sql_query = self.insert_query(table, data)
            return await conn.execute(sql_query)

    async def db_update(
        self,
        table: Table,
        filters: Dict,
        data: Dict,
        *,
        conn: Optional[Connection] = None,
        consumer: Any = None,
    ) -> Records:
        """Perform an update of rows

        The statement uses ``RETURNING`` with every column of ``table``, so the
        result contains the updated rows.

        :param table: sqlalchemy Table
        :param filters: key-value pairs for filtering rows to update
        :param data: key-value pairs for updating columns values of selected rows
        :param conn: optional db connection
        :param consumer: optional consumer (see :meth:`.get_query`)
        """
        update = (
            cast(
                Update,
                self.get_query(
                    table, table.update(), consumer=consumer, params=filters
                ),
            )
            .values(**data)
            .returning(*table.columns)
        )
        async with self.ensure_connection(conn) as conn:
            return await conn.execute(update)

    async def db_upsert(
        self,
        table: Table,
        filters: Dict,
        data: Optional[Dict] = None,
        *,
        conn: Optional[Connection] = None,
        consumer: Any = None,
    ) -> Record:
        """Perform an upsert for a single record

        If ``data`` is non-empty, the rows matching ``filters`` are updated with
        it. Otherwise they are only selected. If no row matches, a new row is
        inserted from ``data`` merged with ``filters``, with ``filters`` winning
        on key clashes.

        Every step runs on a single connection, obtained once from
        :meth:`ensure_connection`. Previously, when ``conn`` was None, the
        lookup and the insert could each acquire their own connection.

        :param table: sqlalchemy Table
        :param filters: key-value pairs for filtering rows to update
        :param data: key-value pairs for updating columns values of selected rows
        :param conn: optional db connection
        :param consumer: optional consumer (see :meth:`.get_query`)
        :return: the updated, selected or newly inserted record. SQLAlchemy's
            ``Result.one_or_none`` raises if more than one row matches.
        """
        async with self.ensure_connection(conn) as conn:
            if data:
                result = await self.db_update(
                    table, filters, data, conn=conn, consumer=consumer
                )
            else:
                result = await self.db_select(
                    table, filters, conn=conn, consumer=consumer
                )
            record = result.one_or_none()
            if record is None:
                insert_data = data.copy() if data else {}
                insert_data.update(filters)
                result = await self.db_insert(table, insert_data, conn=conn)
                record = result.one()
            return record

    async def db_paginate(
        self,
        table: Table,
        sql_query: Select,
        pagination: Pagination,
        *,
        conn: Optional[Connection] = None,
    ) -> Tuple[Records, Optional[int]]:
        """Apply ``pagination`` to ``sql_query`` and execute it

        :param table: sqlalchemy Table the query selects from
        :param sql_query: the (already filtered) select statement
        :param pagination: pagination object; it applies itself to a
            :class:`DbPaginationVisitor`
        :param conn: optional db connection
        :return: a ``(records, total)`` tuple. ``total`` is None unless the
            visitor computed a count; offset pagination does, cursor pagination
            does not.
        """
        pagination_visitor = DbPaginationVisitor(
            db=self, table=table, sql_query=sql_query
        )
        pagination.apply(pagination_visitor)
        async with self.ensure_connection(conn) as conn:
            return await pagination_visitor.execute(conn)

    # Query methods

    def insert_query(self, table: Table, records: Union[List[Dict], Dict]) -> Insert:
        """Build an ``INSERT ... RETURNING`` statement for one or more records

        A single dict is treated as a one-element list. With a list, each
        record missing a column that appears in another record is copied and
        padded with ``None`` for that column. The input dicts are never
        mutated. This gives every row the same keys, which SQLAlchemy's
        multi-row ``values()`` expects.

        :param table: sqlalchemy Table
        :param records: a record or a list of records (column -> value)
        """
        if isinstance(records, dict):
            records = [records]
        else:
            cols: Set[str] = set()
            for record in records:
                cols.update(record)
            new_records = []
            for record in records:
                if len(record) < len(cols):
                    record = record.copy()
                    missing = cols.difference(record)
                    for col in missing:
                        record[col] = None
                new_records.append(record)
            records = new_records
        return insert(table).values(records).returning(*table.columns)

    # backward compatibility
    get_insert = insert_query

    def get_query(
        self,
        table: Table,
        sql_query: QueryType,
        *,
        params: Optional[Dict] = None,
        consumer: Any = None,
    ) -> QueryType:
        """Build an SqlAlchemy query

        Each key of ``params`` is either ``"<column>"`` or
        ``"<column>:<op>"``; the operator defaults to ``eq``. If ``consumer``
        has a ``filter_<column>`` method, it is called as ``(op, value)`` and
        may return None (no filter), one expression, or a list/tuple of
        expressions. Otherwise the column is looked up on ``table`` and
        :meth:`default_filter_field` builds the expression. All resulting
        expressions are AND-ed into a single WHERE clause.

        :param table: sqlalchemy Table
        :param sql_query: sqlalchemy query type
        :param params: key-value pairs for the query
        :param consumer: optional consumer for manipulating parameters
        """
        filters: List = []
        columns = table.c
        params = params or {}
        for key, value in params.items():
            bits = key.split(":")
            field = bits[0]
            # NOTE(review): keys with more than one ":" fall back to "eq"
            op = bits[1] if len(bits) == 2 else "eq"
            filter_field = getattr(consumer, f"filter_{field}", None)
            if filter_field:
                result = filter_field(op, value)
            else:
                field = getattr(columns, field)
                result = self.default_filter_field(field, op, value)
            if result is not None:
                if not isinstance(result, (list, tuple)):
                    result = (result,)
                filters.extend(result)
        if filters:
            whereclause = and_(*filters) if len(filters) > 1 else filters[0]
            sql_query = cast(Select, sql_query).where(whereclause)
        return sql_query

    def search_query(
        self, table: Table, sql_query: SelectUpdate, search: Search
    ) -> SelectUpdate:
        """Build an SqlAlchemy query for a search

        :param table: sqlalchemy Table
        :param sql_query: sqlalchemy query type
        :param search: the search dataclass; it applies itself to a
            :class:`DbSearchVisitor`
        """
        search_visitor = DbSearchVisitor(
            db=self, table=table, sql_query=cast(SelectUpdate, sql_query)
        )
        search.apply(search_visitor)
        return search_visitor.sql_query

    def order_by_query(
        self,
        table: Table,
        sql_query: Select,
        order_by: Optional[Union[str, Sequence[str]]],
    ) -> Select:
        """Apply ordering to a sql_query

        :param table: sqlalchemy Table providing the columns
        :param sql_query: the select statement to order
        :param order_by: a column name or a sequence of column names. A leading
            ``-`` sorts that column descending. Names that do not match a
            column of ``table`` are ignored.
        """
        if isinstance(order_by, str):
            order_by = (order_by,)
        for name in order_by or ():
            if name.startswith("-"):
                order_by_column = getattr(table.c, name[1:], None)
                if order_by_column is not None:
                    order_by_column = order_by_column.desc()
            else:
                order_by_column = getattr(table.c, name, None)
            if order_by_column is not None:
                sql_query = sql_query.order_by(order_by_column)
        return sql_query

    # backward compatibility
    order_by = order_by_query

    def default_filter_field(self, field: Column, op: str, value: Any):
        """Apply a filter on a field.

        Supported operators are 'eq', 'ne', 'gt', 'ge', 'lt' and 'le'.

        If ``value`` is a list or tuple:

        - 'eq' and 'ne' become ``IN`` and ``NOT IN``;
        - the ordering operators compare against the first element only.

        Notes on 'ne' op:
            Example data: [None, 'john', 'roger']
            ne:john would return only roger (i.e. nulls excluded)
            ne: would return john and roger

        There is no 'search' operator here; free-text search is handled by
        :meth:`search_query`.

        :param field: the sqlalchemy column to filter
        :param op: 'eq', 'ne', 'gt', 'ge', 'lt' or 'le'
        :param value: comparison value, scalar or list/tuple
        :return: a sqlalchemy filter expression
        :raises ValueError: if ``op`` is not supported, or if an empty
            list/tuple is given with an ordering operator
        """
        if isinstance(value, (list, tuple)):
            if op == "eq":
                return field.in_(value)
            if op == "ne":
                return ~field.in_(value)
            if not value:
                # Explicit check rather than assert, which -O would strip
                raise ValueError(
                    f"empty list of values for filter operator '{op}' on {field}"
                )
            value = value[0]
        if op == "eq":
            return field == value
        if op == "ne":
            return field != value
        if op == "gt":
            return field > value
        if op == "ge":
            return field >= value
        if op == "lt":
            return field < value
        if op == "le":
            return field <= value
        # Returning None here would make get_query silently drop the filter,
        # so a mistyped operator on db_delete/db_update would hit every row.
        raise ValueError(f"unsupported filter operator '{op}' on {field}")
@dataclass
class DbSearchVisitor(SearchVisitor):
    """Search visitor that narrows a SELECT or UPDATE statement."""

    db: CrudDB
    table: Table
    sql_query: SelectUpdate

    def apply_search(self, search: str, search_fields: Sequence[str]) -> None:
        """Restrict ``sql_query`` to rows where any of ``search_fields``
        contains ``search`` as a case-insensitive substring (``ILIKE``).

        ``%``, ``_`` and the escape character ``/`` in ``search`` are escaped,
        so user input matches literally instead of acting as LIKE wildcards.
        An empty ``search`` or an empty ``search_fields`` leaves the query
        unchanged.

        :param search: the text to look for
        :param search_fields: names of columns of ``table`` to search in
        """
        if not search or not search_fields:
            return
        escaped = search.replace("/", "//").replace("%", "/%").replace("_", "/_")
        pattern = f"%{escaped}%"
        columns = [getattr(self.table.c, col) for col in search_fields]
        self.sql_query = self.sql_query.where(
            or_(*(col.ilike(pattern, escape="/") for col in columns))
        )


@dataclass
class DbPaginationVisitor(PaginationVisitor):
    """Pagination visitor that rewrites ``sql_query`` for offset or cursor
    pagination and then executes it."""

    db: CrudDB
    table: Table
    sql_query: Select
    # Unpaginated query kept for counting; set only by offset pagination
    initial_sql: Optional[QueryType] = None

    def apply_offset_pagination(
        self,
        limit: int,
        offset: int,
        order_by: Optional[Union[str, Sequence[str]]],
    ) -> None:
        """Order the query, then apply ``offset`` and ``limit`` when truthy.

        The unpaginated query is kept in ``initial_sql`` so :meth:`execute`
        can also return the total row count.
        """
        self.initial_sql = self.sql_query
        sql_query = self.db.order_by_query(self.table, self.sql_query, order_by)
        if offset:
            sql_query = sql_query.offset(offset)
        if limit:
            sql_query = sql_query.limit(limit)
        self.sql_query = sql_query

    def apply_cursor_pagination(
        self,
        cursor: Sequence[Tuple[str, str]],
        limit: int,
        order_by: Sequence[str],
        previous: bool = False,
    ) -> None:
        """Filter the query by the cursor values and limit it.

        Each ``(field, value)`` pair in ``cursor`` adds an inclusive range
        filter (see :meth:`filter`). When paging backwards (``previous``), the
        ordering directions are flipped.
        """
        sql_query = self.sql_query
        for key, value in cursor:
            sql_query = sql_query.where(self.filter(key, value, previous))
        # NOTE(review): one extra row presumably detects whether another page
        # exists; the second extra row when going backwards presumably
        # accounts for the cursor row itself (filters are inclusive) — confirm
        # against the Pagination implementation.
        extra = 1
        if previous:
            extra += 1
            order_by = fields_flip_sign(order_by)
        self.sql_query = self.db.order_by_query(self.table, sql_query, order_by).limit(
            limit + extra
        )

    async def execute(self, conn: Connection) -> Tuple[Records, Optional[int]]:
        """Run the paginated query on ``conn``.

        :return: ``(records, total)``. ``total`` is the count of
            ``initial_sql`` when offset pagination set it, otherwise None.
        """
        total = None
        if self.initial_sql is not None:
            total = await self.db.db_count_query(self.initial_sql, conn=conn)
        values = await conn.execute(self.sql_query)
        return values, total

    def filter(self, field: str, value: str, previous: bool) -> Column:
        """Build the inclusive cursor filter for one ordering field.

        ``field`` may carry a leading ``-`` (descending). Ascending fields
        use ``>=`` going forward and ``<=`` going backwards; descending fields
        use the reverse. ``value`` is converted with ``cursor_to_python`` to
        the column's Python type.
        """
        if field.startswith("-"):
            field = field[1:]
            op = "ge" if previous else "le"
        else:
            op = "le" if previous else "ge"
        column = getattr(self.table.c, field)
        py_value = cursor_to_python(column.type.python_type, value)
        return self.db.default_filter_field(column, op, py_value)