import enum
from dataclasses import Field, dataclass, field, fields
from datetime import date, datetime, time
from decimal import Decimal, InvalidOperation
from numbers import Number
from typing import Any, Callable, Dict, Iterator, Optional, Tuple
from uuid import UUID
from dateutil.parser import parse as parse_date
from email_validator import EmailNotValidError, validate_email
from .. import json, tz
from ..utils import compact_dict, str2bool
# Keys of the ``metadata`` mapping attached to fields built by ``data_field``
REQUIRED = "required"
VALIDATOR = "OPENAPI_VALIDATOR"
DESCRIPTION = "description"
POST_PROCESS = "post_process"
DUMP = "dump"
FORMAT = "format"
OPS = "ops"
ITEMS = "items"
HIDDEN = "hidden"
# Schema fragment for each supported primitive Python type: a "type" entry
# plus, for some of the types, an entry stored under the ``FORMAT`` key
PRIMITIVE_TYPES: Dict[Any, Dict] = {
str: {"type": "string"},
bytes: {"type": "string", FORMAT: "binary"},
int: {"type": "integer", FORMAT: "int32"},
float: {"type": "number", FORMAT: "float"},
bool: {"type": "boolean"},
date: {"type": "string", FORMAT: "date"},
datetime: {"type": "string", FORMAT: "date-time"},
Decimal: {"type": "number"},
}
class Ops(enum.Enum):
    """Operation identifiers

    Member names mirror the rich-comparison operator names; the values are
    assigned automatically by :func:`enum.auto` in declaration order.
    """

    eq = enum.auto()
    ne = enum.auto()
    gt = enum.auto()
    ge = enum.auto()
    lt = enum.auto()
    le = enum.auto()
class ValidationError(ValueError):
    """Error raised when a value is rejected for a field

    :param field: name of the field the rejected value belongs to
    :param message: description of why the value was rejected
    """

    def __init__(self, field: str, message: str) -> None:
        # Forward the arguments to ``ValueError`` so that ``args`` is
        # populated: copying and pickling an exception re-create it by
        # calling the class with ``args``, which fails when ``args`` is empty
        super().__init__(field, message)
        self.field = field
        self.message = message

    def __str__(self) -> str:
        return f"{self.field}: {self.message}"
def field_dict(dc: type) -> Dict[str, Field]:
    """Map each field name of the dataclass ``dc`` to its ``Field`` object

    :param dc: a dataclass (or dataclass instance) accepted by
        :func:`dataclasses.fields`
    """
    by_name: Dict[str, Field] = {}
    for dc_field in fields(dc):
        by_name[dc_field.name] = dc_field
    return by_name
def data_field(
    required: bool = False,
    validator: Optional[Callable[[Field, Any], Any]] = None,
    dump: Optional[Callable[[Any], Any]] = None,
    format: Optional[str] = None,
    description: Optional[str] = None,
    items: Optional[Field] = None,
    post_process: Optional[Callable[[Any], Any]] = None,
    ops: Tuple = (),
    hidden: bool = False,
    meta: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> Field:
    """Extend a dataclass field with the following metadata

    :param required: boolean specifying if field is required
    :param validator: optional callable which accepts the field and the raw
        value as inputs and returns the validated value
    :param dump: optional callable which receives the field value and converts
        it to the desired value to serve in requests; when omitted and
        `validator` is a :class:`Validator`, the validator's ``dump`` is used
    :param format: optional string which represents the JSON schema format
    :param description: optional field description
    :param items: field for items of the current field
        (only used for `List` and `Dict` fields)
    :param post_process: post processor function executed after validation
    :param ops: optional tuple of strings specifying available operations
    :param hidden: when `True` the field is not added to the Openapi documentation
    :param meta: optional dictionary with additional metadata, its entries
        override the metadata built from the other parameters
    :param kwargs: additional keyword arguments passed to
        :func:`dataclasses.field`
    """
    if isinstance(validator, Validator) and not dump:
        dump = validator.dump
    # Give every field a default of None (unless a default_factory is
    # supplied): dataclasses do not allow a field without a default to
    # follow a field with one
    if "default_factory" not in kwargs:
        kwargs.setdefault("default", None)
    meta = meta or {}
    metadata = compact_dict(
        {
            VALIDATOR: validator,
            REQUIRED: required,
            DUMP: dump,
            DESCRIPTION: description,
            ITEMS: items,
            POST_PROCESS: post_process,
            FORMAT: format,
            OPS: ops,
            HIDDEN: hidden,
            **meta,
        }
    )
    return field(metadata=metadata, **kwargs)
def str_field(min_length: int = 0, max_length: int = 0, **kw) -> Field:
    """A specialized :func:`.data_field` for strings

    A :class:`StrValidator` is installed unless a ``validator`` is given.

    :param min_length: minimum length of string (``0`` means no minimum)
    :param max_length: maximum length of string (``0`` means no maximum)
    """
    kw.setdefault(
        "validator", StrValidator(min_length=min_length, max_length=max_length)
    )
    return data_field(**kw)
def bool_field(**kw) -> Field:
    """Specialized :func:`.data_field` for bool types

    A :class:`BoolValidator` is installed unless a ``validator`` is given.
    """
    kw.setdefault("validator", BoolValidator())
    return data_field(**kw)
def uuid_field(format: str = "uuid", **kw) -> Field:
    """A UUID field with validation

    A :class:`UUIDValidator` is installed unless a ``validator`` is given.

    :param format: format string stored in the field metadata
    """
    kw.setdefault("validator", UUIDValidator())
    return data_field(format=format, **kw)
def number_field(
    min_value: Optional[Number] = None,
    max_value: Optional[Number] = None,
    precision: Optional[int] = None,
    **kw,
) -> Field:
    """A specialized :func:`.data_field` for numeric values

    A :class:`NumberValidator` is installed unless a ``validator`` is given.

    :param min_value: minimum value
    :param max_value: maximum value
    :param precision: decimal precision
    """
    kw.setdefault("validator", NumberValidator(min_value, max_value, precision))
    return data_field(**kw)
def integer_field(
    min_value: Optional[Number] = None,
    max_value: Optional[Number] = None,
    **kw,
) -> Field:
    """A specialized :func:`.data_field` for integer values

    An :class:`IntegerValidator` is installed unless a ``validator`` is given.

    :param min_value: minimum value
    :param max_value: maximum value
    """
    kw.setdefault("validator", IntegerValidator(min_value, max_value))
    return data_field(**kw)
def decimal_field(
    min_value: Optional[Number] = None,
    max_value: Optional[Number] = None,
    precision: Optional[int] = None,
    **kw,
) -> Field:
    """A specialized :func:`.data_field` for decimal values

    A :class:`DecimalValidator` is installed unless a ``validator`` is given.

    :param min_value: minimum value
    :param max_value: maximum value
    :param precision: decimal precision
    """
    kw.setdefault("validator", DecimalValidator(min_value, max_value, precision))
    return data_field(**kw)
def email_field(min_length: int = 0, max_length: int = 0, **kw) -> Field:
    """A specialized :func:`.data_field` for emails, validation via the
    `email_validator` third party library

    An :class:`EmailValidator` is installed unless a ``validator`` is given.

    :param min_length: minimum length of email
    :param max_length: maximum length of email
    """
    kw.setdefault(
        "validator", EmailValidator(min_length=min_length, max_length=max_length)
    )
    return data_field(**kw)
def enum_field(EnumClass, **kw) -> Field:
    """A specialized :func:`.data_field` for enums

    An :class:`EnumValidator` is installed unless a ``validator`` is given.

    :param EnumClass: enum for validation
    """
    kw.setdefault("validator", EnumValidator(EnumClass))
    return data_field(**kw)
def date_field(**kw) -> Field:
    """A specialized :func:`.data_field` for dates

    A :class:`DateValidator` is installed unless a ``validator`` is given.
    """
    kw.setdefault("validator", DateValidator())
    return data_field(**kw)
def date_time_field(timezone=False, **kw) -> Field:
    """A specialized :func:`.data_field` for datetimes

    A :class:`DateTimeValidator` is installed unless a ``validator`` is given.

    :param timezone: flag passed to :class:`DateTimeValidator`; when true the
        validator checks datetimes for timezone information
    """
    kw.setdefault("validator", DateTimeValidator(timezone=timezone))
    return data_field(**kw)
def as_field(item: Any, *, field: Optional[Field] = None, **kw) -> Field:
    """Return a dataclass field whose type is ``item``

    :param item: a ``Field`` (returned untouched) or the type to assign
    :param field: optional field to assign the type to; when missing a new
        one is created via :func:`.data_field` with the ``kw`` arguments
    :raises RuntimeError: if the field already has a different type
    """
    if isinstance(item, Field):
        return item
    target = data_field(**kw) if not field else field
    current_type = target.type
    if current_type and current_type is not item:
        raise RuntimeError("Cannot override field type")
    target.type = item
    return target
def json_field(**kw) -> Field:
    """A specialized :func:`.data_field` for JSON data

    A :class:`JSONValidator` is installed unless a ``validator`` is given.
    """
    kw.setdefault("validator", JSONValidator())
    return data_field(**kw)
# Utilities
def field_ops(field: Field) -> Iterator[str]:
    """Yield the field name followed by one ``name:op`` entry for each
    operation stored under the ``OPS`` key of the field metadata
    """
    name = field.name
    yield name
    yield from (f"{name}:{op}" for op in field.metadata.get(OPS, ()))
# VALIDATORS
class Validator:
    """Base class for field validators

    Subclasses override ``__call__`` to validate a value, ``openapi`` to add
    constraints to a property dictionary and ``dump`` to convert a value
    for output.
    """

    def __call__(self, field: Field, value: Any) -> Any:
        """Validate ``value`` for ``field``; the base class rejects everything"""
        raise ValidationError(field.name, "invalid")

    def openapi(self, prop: Dict) -> None:
        """Update the property dictionary ``prop``; nothing to add here"""
        return None

    def dump(self, value: Any) -> Any:
        """Return ``value`` unchanged"""
        return value
@dataclass
class StrValidator(Validator):
    """Validate strings, optionally bounding their length

    A bound equal to ``0`` is treated as "no bound".
    """

    min_length: int = 0
    max_length: int = 0

    def __call__(self, field: Field, value: Any) -> Any:
        """Return ``value`` if it is a string within the length bounds

        :raises ValidationError: if not a string, too short or too long
        """
        if not isinstance(value, str):
            raise ValidationError(field.name, "Must be a string")
        size = len(value)
        shortest, longest = self.min_length, self.max_length
        if shortest and size < shortest:
            raise ValidationError(field.name, "Too short")
        if longest and size > longest:
            raise ValidationError(field.name, "Too long")
        return value

    def openapi(self, prop: Dict) -> None:
        """Add ``minLength``/``maxLength`` to ``prop`` for the bounds set"""
        bounds = (("minLength", self.min_length), ("maxLength", self.max_length))
        for key, bound in bounds:
            if bound:
                prop[key] = bound
@dataclass
class EmailValidator(StrValidator):
    """String validator which also checks the value with ``validate_email``"""

    def __call__(self, field: Field, value: Any) -> Any:
        """Return ``value`` if it passes the string checks and is an email

        :raises ValidationError: if the string checks or the email check fail
        """
        email = super().__call__(field, value)
        try:
            # only the syntax is checked, no deliverability lookup
            validate_email(email, check_deliverability=False)
        except EmailNotValidError:
            raise ValidationError(field.name, "%s not a valid email" % email) from None
        else:
            return email
class ListValidator(Validator):
    """Chain several validators, feeding each one the previous result"""

    def __init__(self, validators) -> None:
        self.validators = validators

    def __call__(self, field: Field, value: Any) -> Any:
        """Run ``value`` through every validator in order"""
        result = value
        for check in self.validators:
            result = check(field, result)
        return result

    def dump(self, value: Any) -> Any:
        """Run ``value`` through the ``dump`` of every validator having one"""
        result = value
        for check in self.validators:
            dumper = getattr(check, "dump", None)
            if not hasattr(dumper, "__call__"):
                continue
            result = dumper(result)
        return result

    def openapi(self, prop: Dict) -> None:
        """Let every :class:`Validator` in the chain update ``prop``"""
        for check in self.validators:
            if not isinstance(check, Validator):
                # plain callables have no ``openapi`` hook
                continue
            check.openapi(prop)
class UUIDValidator(Validator):
    """Validate UUIDs, returning their 32-character hexadecimal string"""

    def __call__(self, field: Field, value: Any) -> Any:
        """Return the hex form of ``value``

        :raises ValidationError: if ``value`` cannot be converted to a UUID
        """
        try:
            if not isinstance(value, UUID):
                value = UUID(str(value))
            return value.hex
        except ValueError:
            # wrap value in a tuple so that a tuple input is formatted as a
            # single argument rather than unpacked by the % operator
            raise ValidationError(field.name, "%s not a valid uuid" % (value,))

    def dump(self, value: Any) -> Any:
        """Return the hex form of a UUID, any other value unchanged"""
        if isinstance(value, UUID):
            return value.hex
        return value
class EnumValidator(Validator):
    """Enum validator accepting member names (str) and members of the enum"""

    def __init__(self, EnumClass) -> None:
        self.EnumClass = EnumClass

    def __call__(self, field: Field, value: Any) -> Any:
        """Validate ``value`` against the enum

        Strings are looked up as attributes of the enum class. The member is
        returned when the field type is the enum class, otherwise its name.

        :raises ValidationError: if ``value`` does not resolve to a member
        """
        try:
            e = value
            if isinstance(e, str):
                e = getattr(self.EnumClass, value)
            if isinstance(e, self.EnumClass):
                return e if field.type == self.EnumClass else e.name
            raise AttributeError
        except AttributeError:
            # wrap value in a tuple so that a tuple input is formatted as a
            # single argument rather than unpacked by the % operator
            raise ValidationError(field.name, "%s not valid" % (value,))

    def dump(self, value: Any) -> Any:
        """Return the name of an enum member, any other value unchanged"""
        if isinstance(value, self.EnumClass):
            return value.name
        return value
class Choice(Validator):
    """Accept only values contained in ``choices``"""

    def __init__(self, choices) -> None:
        self.choices = choices

    def __call__(self, field: Field, value: Any) -> Any:
        """Return ``value`` if it is one of the choices

        :raises ValidationError: if ``value`` is not among the choices
        """
        try:
            valid = value in self.choices
        except TypeError:
            # an unhashable value (list, dict) tested against a set or dict
            # of choices cannot be a valid choice
            valid = False
        if not valid:
            # wrap value in a tuple so that a tuple input is formatted as a
            # single argument rather than unpacked by the % operator
            raise ValidationError(field.name, "%s not valid" % (value,))
        return value

    def openapi(self, prop: Dict) -> None:
        """Store the sorted choices under the ``enum`` key of ``prop``"""
        prop["enum"] = sorted(self.choices)
class DateValidator(Validator):
    """Validate dates, parsing them from strings when needed"""

    def dump(self, value: Any) -> Any:
        """Return the ISO string of the date part of a date or datetime,
        any other value unchanged
        """
        if isinstance(value, datetime):
            return value.date().isoformat()
        elif isinstance(value, date):
            return value.isoformat()
        return value

    def __call__(self, field: Field, value: Any) -> Any:
        """Return ``value`` as a date

        Strings are parsed and reduced to their date part; ``date`` instances
        (``datetime`` included) are returned as they are.

        :raises ValidationError: if ``value`` is not a date and cannot be
            parsed into one
        """
        if isinstance(value, str):
            try:
                value = parse_date(value).date()
            except (ValueError, OverflowError):
                # the parser raises OverflowError, which is not a ValueError,
                # for out of range numbers; the check below rejects the value
                pass
        if not isinstance(value, date):
            # wrap value in a tuple so that a tuple input is formatted as a
            # single argument rather than unpacked by the % operator
            raise ValidationError(field.name, "%s not valid format" % (value,))
        return value
class DateTimeValidator(Validator):
    """Validate datetimes, parsing them from strings when needed

    :param timezone: when true, datetimes without timezone information are
        rejected unless their time is exactly midnight
    """

    def __init__(self, timezone=False) -> None:
        self.timezone = timezone

    def dump(self, value: Any) -> Any:
        """Return the ISO string of a datetime, any other value unchanged"""
        if isinstance(value, datetime):
            return value.isoformat()
        return value

    def __call__(self, field: Field, value: Any) -> Any:
        """Return ``value`` as a datetime

        :raises ValidationError: if ``value`` is not a datetime and cannot be
            parsed into one, or if timezone information is required and missing
        """
        if isinstance(value, str):
            try:
                value = parse_date(value)
            except (ValueError, OverflowError):
                # the parser raises OverflowError, which is not a ValueError,
                # for out of range numbers; the check below rejects the value
                pass
        if not isinstance(value, datetime):
            # wrap value in a tuple so that a tuple input is formatted as a
            # single argument rather than unpacked by the % operator
            raise ValidationError(field.name, "%s not valid format" % (value,))
        if self.timezone and not value.tzinfo:
            if value.time() == time():
                # a naive datetime at exactly midnight is accepted and handed
                # to ``tz.as_utc``
                value = tz.as_utc(value)
            else:
                raise ValidationError(field.name, "Timezone information required")
        return value
# Exceptions treated as "not a valid number" by the numeric validators
NumericErrors = (TypeError, ValueError, InvalidOperation)


class BoundedNumberValidator(Validator):
    """Base validator checking a value against optional bounds

    :param min_value: smallest accepted value, ``None`` for no lower bound
    :param max_value: largest accepted value, ``None`` for no upper bound
    """

    def __init__(
        self, min_value: Optional[Number] = None, max_value: Optional[Number] = None
    ) -> None:
        self.min_value = min_value
        self.max_value = max_value

    def __call__(self, field: Field, value: Any) -> Any:
        """Return ``value`` if it lies within the bounds

        :raises ValidationError: if ``value`` is outside the bounds or cannot
            be compared with them
        """
        try:
            below = self.min_value is not None and value < self.min_value
            above = (
                not below
                and self.max_value is not None
                and value > self.max_value
            )
        except NumericErrors:
            # the comparison itself can fail: TypeError for values which are
            # not numbers, InvalidOperation for a Decimal NaN
            raise ValidationError(field.name, "%s not valid number" % (value,))
        if below:
            raise ValidationError(
                field.name, "%s less than %s" % (value, self.min_value)
            )
        if above:
            raise ValidationError(
                field.name, "%s greater than %s" % (value, self.max_value)
            )
        return value

    def dump(self, value: Any) -> Number:
        """Return ``value`` converted with :meth:`to_number`"""
        return self.to_number(value)

    def openapi(self, prop: Dict) -> None:
        """Add ``minimum``/``maximum`` to ``prop`` for the bounds set"""
        if self.min_value is not None:
            prop["minimum"] = self.min_value
        if self.max_value is not None:
            prop["maximum"] = self.max_value

    def to_number(self, value: Any) -> Number:
        """Convert a string to an ``int`` or, failing that, a ``Decimal``;
        any other value is returned unchanged

        :raises decimal.InvalidOperation: if a string is not a number
        """
        if isinstance(value, str):
            try:
                return int(value)
            except ValueError:
                return Decimal(value)
        else:
            return value
class NumberValidator(BoundedNumberValidator):
    """Bounded number validator with optional rounding

    :param min_value: smallest accepted value, ``None`` for no lower bound
    :param max_value: largest accepted value, ``None`` for no upper bound
    :param precision: number of digits passed to ``round``, ``None`` to
        skip rounding
    """

    def __init__(
        self,
        min_value: Optional[Number] = None,
        max_value: Optional[Number] = None,
        precision: Optional[int] = None,
    ) -> None:
        super().__init__(min_value=min_value, max_value=max_value)
        self.precision = precision

    def __call__(self, field: Field, value: Any) -> Any:
        """Convert, round and bound-check ``value``

        :raises ValidationError: if the conversion or the rounding fail, or
            if the value is outside the bounds
        """
        try:
            value = self.to_number(value)
            if self.precision is not None:
                value = round(value, self.precision)
        except NumericErrors:
            # wrap value in a tuple so that a tuple input is formatted as a
            # single argument rather than unpacked by the % operator
            raise ValidationError(field.name, "%s not valid number" % (value,))
        return super().__call__(field, value)

    def dump(self, value: Any) -> Any:
        """Return ``value`` converted to a number and rounded when a
        precision is set
        """
        value = self.to_number(value)
        if self.precision is not None:
            return round(value, self.precision)
        return value
class IntegerValidator(BoundedNumberValidator):
    """Bounded number validator accepting ``int`` values only"""

    def __call__(self, field: Field, value: Any) -> Any:
        """Convert ``value`` with ``to_number`` and bound-check it

        :raises ValidationError: if the converted value is not an ``int`` or
            it is outside the bounds
        """
        try:
            value = self.to_number(value)
            if not isinstance(value, int):
                raise ValueError
        except NumericErrors:
            # wrap value in a tuple so that a tuple input is formatted as a
            # single argument rather than unpacked by the % operator
            raise ValidationError(field.name, "%s not valid integer" % (value,))
        return super().__call__(field, value)
class DecimalValidator(NumberValidator):
    """Number validator which converts values to ``Decimal``"""

    def __call__(self, field: Field, value: Any) -> Any:
        """Convert ``value`` to a ``Decimal`` and apply the number checks

        :raises ValidationError: if ``value`` cannot be converted or the
            number checks fail
        """
        try:
            value = self.to_number(value)
            if not isinstance(value, Decimal):
                # go through str so that floats keep their short
                # representation
                value = Decimal(str(value))
        except NumericErrors:
            # wrap value in a tuple so that a tuple input is formatted as a
            # single argument rather than unpacked by the % operator
            raise ValidationError(field.name, "%s not valid Decimal" % (value,))
        return super().__call__(field, value)
class BoolValidator(Validator):
    """Validator converting values with ``str2bool``"""

    def __call__(self, field: Field, value: Any) -> bool:
        """Return the ``str2bool`` conversion of ``value``"""
        as_bool = str2bool(value)
        return as_bool

    def dump(self, value: Any) -> bool:
        """Return the ``str2bool`` conversion of ``value``"""
        as_bool = str2bool(value)
        return as_bool
class JSONValidator(Validator):
    """Validate JSON data"""

    def __call__(self, field: Field, value: Any) -> Any:
        """Return ``value`` processed by :meth:`dump`

        :raises ValidationError: if :meth:`dump` fails with a decode or a
            type error
        """
        try:
            return self.dump(value)
        except (json.JSONDecodeError, TypeError):
            # wrap value in a tuple so that a tuple input is formatted as a
            # single argument rather than unpacked by the % operator
            raise ValidationError(field.name, "%s not valid" % (value,))

    def dump(self, value: Any) -> Any:
        """Return ``value`` after a ``dumps``/``loads`` round trip

        A string is decoded first; a string which does not decode is kept
        as a plain string.
        """
        if isinstance(value, str):
            try:
                value = json.loads(value)
            except json.JSONDecodeError:
                pass
        return json.loads(json.dumps(value))