Source code for openapi.data.db

from dataclasses import Field, make_dataclass
from datetime import date, datetime
from decimal import Decimal
from functools import partial
from typing import (
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
    cast,
)

import sqlalchemy as sa
from sqlalchemy_utils import UUIDType

from . import fields

# Signature shared by every column converter: it receives the column, the
# ``required`` flag, the ``use_default`` flag and the field operations, and
# returns the ``(python type, dataclass field)`` pair describing the column.
ConverterType = Callable[[sa.Column, bool, bool, Sequence[str]], Tuple[Type, Field]]
# Registry of converters, filled by the ``converter`` decorator and keyed by
# the column type class (it is looked up with ``type(col.type)``).
CONVERTERS: Dict[Type, ConverterType] = {}


def dataclass_from_table(
    name: str,
    table: sa.Table,
    *,
    exclude: Optional[Sequence[str]] = None,
    include: Optional[Sequence[str]] = None,
    default: Union[bool, Sequence[str]] = False,
    required: Union[bool, Sequence[str]] = False,
    ops: Optional[Dict[str, Sequence[str]]] = None,
) -> Type:
    """Build a dataclass whose fields mirror the columns of a
    :class:`sqlalchemy.schema.Table`

    :param name: name given to the generated dataclass
    :param table: sqlalchemy table providing the columns
    :param exclude: column names to leave out of the dataclass
    :param include: column names to keep; when empty or ``None`` every
        column of the table is a candidate
    :param default: ``True`` to use the column defaults for all fields, or
        the names of the columns whose defaults should be used
    :param required: ``True`` to flag all fields as required, or the names
        of the columns to flag; converters may still relax the flag
        (for instance for nullable columns)
    :param ops: additional operations for fields, keyed by column name
    :raises NotImplementedError: when no converter is registered for the
        type of a selected column
    """
    selected = set(include or table.columns.keys()) - set(exclude or ())
    with_default = column_info(selected, default)
    mandatory = column_info(selected, required)
    extra_ops = cast(Dict[str, Sequence[str]], ops or {})

    dataclass_fields = []
    for column in table.columns:
        column_name = column.name
        if column_name in selected:
            # converters are registered against the exact column type class
            column_type = type(column.type)
            to_field = CONVERTERS.get(column_type)
            if not to_field:  # pragma: no cover
                raise NotImplementedError(
                    f"Cannot convert column {column_name}: {column_type}"
                )
            converted = to_field(
                column,
                column_name in mandatory,
                column_name in with_default,
                extra_ops.get(column_name, ()),
            )
            # make_dataclass expects (name, type, field) entries
            dataclass_fields.append((column_name, *converted))

    return make_dataclass(name, dataclass_fields)
def column_info(columns: Set[str], value: Union[bool, Sequence[str]]) -> Set[str]:
    """Turn a ``default``/``required`` style option into a set of column names

    :param columns: names of the columns selected for the dataclass
    :param value: ``True`` selects every column, ``False`` selects none,
        otherwise the names listed in the sequence are selected
    :return: a new set, ``columns`` itself is never handed back
    """
    if value is True:
        return columns.copy()
    if value is False:
        return set()
    # ``None`` behaves like ``True``, anything else is an explicit list
    return set(columns if value is None else value)


def converter(*types):
    """Create a decorator registering a function in ``CONVERTERS``

    :param types: column type classes handled by the decorated function
    :return: a decorator returning the decorated function unchanged
    """

    def register(func):
        CONVERTERS.update((column_type, func) for column_type in types)
        return func

    return register


@converter(sa.Boolean)
def bl(
    col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
) -> Tuple[Type, Field]:
    """Convert a boolean column into a ``bool`` field

    The field is built with ``fields.bool_field`` unless ``col.info``
    carries a ``data_field`` override.
    """
    make_field = col.info.get("data_field", fields.bool_field)
    options = info(col, required, use_default, ops)
    return bool, make_field(**options)


@converter(sa.Integer)
def integer(
    col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
) -> Tuple[Type, Field]:
    """Convert an integer column into an ``int`` field with precision 0

    The field is built with ``fields.number_field`` unless ``col.info``
    carries a ``data_field`` override.
    """
    make_field = col.info.get("data_field", fields.number_field)
    options = info(col, required, use_default, ops)
    return int, make_field(precision=0, **options)


@converter(sa.Numeric)
def number(
    col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
) -> Tuple[Type, Field]:
    """Convert a numeric column into a ``Decimal`` field

    The scale of the column type is forwarded as the field precision. The
    field is built with ``fields.decimal_field`` unless ``col.info``
    carries a ``data_field`` override.
    """
    make_field = col.info.get("data_field", fields.decimal_field)
    scale = col.type.scale
    options = info(col, required, use_default, ops)
    return Decimal, make_field(precision=scale, **options)


@converter(sa.String, sa.Text, sa.CHAR, sa.VARCHAR)
def string(
    col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
) -> Tuple[Type, Field]:
    """Convert a string-like column into a ``str`` field

    ``max_length`` is the length of the column type, or 0 when the type
    declares none. The field is built with ``fields.str_field`` unless
    ``col.info`` carries a ``data_field`` override.
    """
    make_field = col.info.get("data_field", fields.str_field)
    max_length = col.type.length or 0
    options = info(col, required, use_default, ops)
    return str, make_field(max_length=max_length, **options)


@converter(sa.DateTime)
def dt_ti(
    col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
) -> Tuple[Type, Field]:
    """Convert a date-time column into a ``datetime`` field

    The ``timezone`` flag of the column type is forwarded to the field.
    The field is built with ``fields.date_time_field`` unless ``col.info``
    carries a ``data_field`` override.
    """
    make_field = col.info.get("data_field", fields.date_time_field)
    timezone = col.type.timezone
    options = info(col, required, use_default, ops)
    return datetime, make_field(timezone=timezone, **options)


@converter(sa.Date)
def dt(
    col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
) -> Tuple[Type, Field]:
    """Convert a date column into a ``date`` field

    The field is built with ``fields.date_field`` unless ``col.info``
    carries a ``data_field`` override.
    """
    make_field = col.info.get("data_field", fields.date_field)
    options = info(col, required, use_default, ops)
    return date, make_field(**options)
@converter(sa.Enum)
def en(
    col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
) -> Tuple[Type, Field]:
    """Convert an enum column into a field typed with its enum class

    The enum class of the column type is both the field type and the first
    argument of the field factory. The field is built with
    ``fields.enum_field`` unless ``col.info`` carries a ``data_field``
    override.
    """
    data_field = col.info.get("data_field", fields.enum_field)
    return (
        col.type.enum_class,
        data_field(col.type.enum_class, **info(col, required, use_default, ops)),
    )


@converter(sa.JSON)
def js(
    col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
) -> Tuple[Type, Field]:
    """Convert a JSON column into a ``List`` or ``Dict`` field

    The field type is inferred from the value of the column default: a
    ``list`` gives ``List``, anything else (including no default) gives
    ``Dict``. The field is built with ``fields.json_field`` unless
    ``col.info`` carries a ``data_field`` override.
    """
    data_field = col.info.get("data_field", fields.json_field)
    val = None
    if col.default:
        arg = col.default.arg
        # SQLAlchemy wraps callable defaults so that they always take the
        # execution context as single positional argument: calling them
        # without arguments raises TypeError. No context is available here,
        # so pass None, exactly as ``info`` does for its default factory.
        val = arg(None) if col.default.is_callable else arg
    return (
        JsonTypes.get(type(val), Dict),
        data_field(**info(col, required, use_default, ops)),
    )


@converter(UUIDType)
def uuid(
    col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
) -> Tuple[Type, Field]:
    """Convert a UUID column into a ``str`` field

    The field is built with ``fields.uuid_field`` unless ``col.info``
    carries a ``data_field`` override.
    """
    data_field = col.info.get("data_field", fields.uuid_field)
    return (str, data_field(**info(col, required, use_default, ops)))


def info(
    col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
) -> Dict:
    """Collect the keyword arguments passed to a field factory

    :param col: column being converted
    :param required: whether the field was requested as required
    :param use_default: whether the column default should become the
        field default
    :param ops: additional operations for the field
    :return: dictionary with ``ops`` and ``required``, plus ``default`` or
        ``default_factory`` when ``use_default`` is set, ``description``
        when the column has a doc, and every entry of ``col.info`` except
        ``data_field`` (``col.info`` entries win over the computed ones)
    """
    data = dict(ops=ops)
    if use_default:
        default = col.default.arg if col.default is not None else None
        if callable(default):
            # the callable is invoked with None as its single argument
            data.update(default_factory=partial(default, None))
            required = False
        elif isinstance(default, (list, dict, set)):
            # hand a fresh copy to every instance, never the shared object
            data.update(default_factory=default.copy)
            required = False
        else:
            data.update(default=default)
            # a nullable column or an actual default makes the field optional
            if required and (col.nullable or default is not None):
                required = False
    elif required and col.nullable:
        required = False
    data.update(required=required)
    if col.doc:
        data.update(description=col.doc)
    data.update(col.info)
    # ``data_field`` selects the field factory, it is not a factory argument
    data.pop("data_field", None)
    return data


# Typing annotation used for a JSON column, keyed by the type of its default
JsonTypes = {list: List, dict: Dict}