import os
from collections import OrderedDict
from dataclasses import MISSING, Field, asdict, dataclass, field
from dataclasses import fields as get_fields
from dataclasses import is_dataclass
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Set, Type, Union, cast
from aiohttp import hdrs, web
from ..data import fields
from ..data.exc import ErrorMessage, FieldError, ValidationErrors, error_response_schema
from ..exc import InvalidSpecException, InvalidTypeException
from ..utils import TypingInfo, compact, is_subclass
from .path import ApiPath
from .redoc import Redoc
from .server import default_server
from .utils import load_yaml_from_docstring, trim_docstring
# OpenAPI specification version written to the "openapi" entry of the spec document
OPENAPI = "3.0.3"
# lower-cased HTTP method names; used to look up per-method handlers on path classes
METHODS = [method.lower() for method in hdrs.METH_ALL]
# NOTE(review): not referenced by the code visible in this module - confirm it is still needed
SCHEMAS_TO_SCHEMA = ("response_schema", "body_schema")
# prefix of the "$ref" pointers emitted for dataclass schemas
SCHEMA_BASE_REF = "#/components/schemas/"
# field defaults treated as "no default": they are never written to the spec
EMPTY_DEFAULTS = frozenset((None, MISSING, ""))
# default path serving the spec, overridable via the SPEC_ROUTE environment variable
SPEC_ROUTE = os.environ.get("SPEC_ROUTE", "/spec")
@dataclass
class Contact:
    """Contact details published in the ``info`` section of the spec.

    Every attribute has a placeholder default, so ``Contact()`` is a valid,
    fully populated object.
    """

    # name of the person or team to contact
    name: str = "API Support"
    # web page with the contact information
    url: str = "http://www.example.com/support"
    # e-mail address of the contact
    email: str = "support@example.com"
@dataclass
class License:
    """License information published in the ``info`` section of the spec.

    Defaults to the Apache 2.0 license.
    """

    # name of the license covering the API
    name: str = "Apache 2.0"
    # URL of the full license text
    url: str = "https://www.apache.org/licenses/LICENSE-2.0.html"
@dataclass
class OpenApiInfo:
    """Open API Info object

    Metadata about the API. :class:`SpecDoc` converts it to a plain dict
    with ``dataclasses.asdict`` and stores it under the ``info`` entry of
    the spec document.
    """

    # title of the API
    title: str = "Open API"
    # free-text description of the API
    description: str = ""
    # version of the API document
    version: str = "0.1.0"
    # terms of service entry (camelCase on purpose: it is the emitted key)
    termsOfService: str = ""
    # contact details; a fresh default Contact per instance
    contact: Contact = field(default_factory=Contact)
    # license details; a fresh default License per instance
    license: License = field(default_factory=License)


# for backward compatibility
OpenApi = OpenApiInfo
@dataclass
class OpenApiSpec:
    """Open API Specification

    Holds the configuration of the spec document and registers, via
    :meth:`setup_app`, the route serving it.
    """

    info: OpenApiInfo = field(default_factory=OpenApiInfo)
    """openapi info object, it provides metadata about the API"""
    default_content_type: str = "application/json"
    default_responses: Dict = field(default_factory=dict)
    security: Dict = field(default_factory=dict)
    servers: List[Dict] = field(default_factory=list)
    validate_docs: bool = False
    allowed_tags: Set = field(default_factory=set)
    spec_url: str = SPEC_ROUTE
    """the path serving the JSON OpenAPI specification"""
    redoc: Optional[Redoc] = None
    """Optional object for rendering the specification as an HTML page via redoc"""

    def routes(self, request: web.Request) -> Iterable:
        """Routes to include in the spec"""
        return request.app.router.routes()

    def setup_app(self, app: web.Application):
        """Store the spec in ``app["spec"]`` and register its routes.

        The JSON spec is served at :attr:`spec_url` under the route name
        ``openapi_spec``; when :attr:`redoc` is set its page is registered too.
        """
        app["spec"] = self
        app.router.add_get(self.spec_url, self.spec_route, name="openapi_spec")
        if self.redoc:
            app.router.add_get(self.redoc.path, self.redoc.handle_doc)

    async def spec_route(self, request: web.Request) -> web.Response:
        """Return the OpenApi spec"""
        return web.json_response(self.build(request))

    def build(self, request: web.Request) -> Dict:
        """Build the spec document for the application handling ``request``.

        :param request: request whose application routes are documented
        :return: the OpenAPI document as a JSON-serializable dictionary
        """
        doc = SpecDoc(request, self)
        # Copy every scheme dict, not just the outer mapping: the document
        # builder removes "scopes" from the schemes it receives, and with a
        # shallow copy that would strip them from this spec's configuration,
        # leaving every build after the first one without scopes.
        security = {name: dict(scheme) for name, scheme in self.security.items()}
        servers = list(self.servers) if self.servers else []
        return doc(security, servers)
class SchemaParser:
    """Utility class for parsing schemas

    Converts dataclass schemas into JSON-schema dictionaries. Dataclasses
    met while parsing are queued in :attr:`schemas_to_parse` and referenced
    with a ``$ref``; :meth:`parsed_schemas` drains the queue.
    """

    def __init__(self, validate_docs: bool = False) -> None:
        # when True, fields without a description raise InvalidSpecException
        self.validate_docs = validate_docs
        # dataclasses waiting to be converted, keyed by class name
        self.schemas_to_parse: Dict[str, type] = {}

    def get_parameters(self, schema: Any, default_in: str = "path") -> List:
        """Extract parameters list from a dataclass schema

        :param schema: dataclass schema describing the parameters
        :param default_in: value of the ``in`` entry of every parameter
        :return: list of parameter objects
        """
        params = []
        json_schema = self.dataclass2json(schema)
        required = set(json_schema.get("required", ()))
        for name, prop in json_schema["properties"].items():
            # dataclass2json may register the very same property dict under
            # several names, so read the description without popping it:
            # popping would leave every name after the first one without it
            entry = compact(
                name=name,
                description=prop.get("description"),
                schema={k: v for k, v in prop.items() if k != "description"},
                required=name in required,
            )
            entry["in"] = default_in
            params.append(entry)
        return params

    def field2json(
        self, field_or_type: Union[Type, Field], validate: bool = True
    ) -> Dict[str, dict]:
        """Convert a dataclass field to Json schema

        :param field_or_type: a dataclass field or a bare type
        :param validate: when False the description check is skipped even if
            :attr:`validate_docs` is set
        :return: the JSON property, or an empty dict for hidden fields
        :raises InvalidSpecException: description missing while validating
        """
        dc_field = fields.as_field(field_or_type)
        meta = dc_field.metadata
        if meta.get(fields.HIDDEN):
            return {}
        items = meta.get(fields.ITEMS)
        json_property = self.get_schema_info(dc_field.type, items=items)
        field_description = meta.get(fields.DESCRIPTION)
        if field_description:
            json_property["description"] = field_description
        elif self.validate_docs and validate:
            raise InvalidSpecException(
                f'Missing description for field "{dc_field.name}"'
            )
        fmt = meta.get(fields.FORMAT)
        if fmt:
            json_property[fields.FORMAT] = fmt
        self.add_default(dc_field, json_property)
        validator = meta.get(fields.VALIDATOR)
        # add additional parameters fields from validators
        if isinstance(validator, fields.Validator):
            validator.openapi(json_property)
        return json_property

    def dataclass2json(self, schema: Any) -> Dict[str, Any]:
        """Extract the object representation of a dataclass schema

        :param schema: a dataclass, or anything ``TypingInfo.get`` resolves
            to a dataclass
        :return: JSON schema of type ``object``
        :raises InvalidSpecException: if ``schema`` is not a dataclass
        """
        type_info = cast(TypingInfo, TypingInfo.get(schema))
        if not type_info or not type_info.is_dataclass:
            raise InvalidSpecException(
                "Schema must be a dataclass, got "
                f"{type_info.element if type_info else None}"
            )
        properties = {}
        required = []
        for item in get_fields(type_info.element):
            json_property = self.field2json(item)
            # "required" is set by get_schema_info for union types only
            field_required = json_property.pop("required", True)
            if not json_property:
                # hidden field
                continue
            if item.metadata.get(fields.REQUIRED, field_required):
                required.append(item.name)
            for name in fields.field_ops(item):
                properties[name] = json_property
        json_schema = {
            "type": "object",
            # document the dataclass itself: ``schema`` may be a wrapper
            # around it, whose docstring is unrelated to the schema
            "description": trim_docstring(type_info.element.__doc__ or ""),
            "properties": properties,
            "additionalProperties": False,
        }
        if required:
            json_schema["required"] = required
        return json_schema

    # for backward compatibility
    schema2json = dataclass2json

    def get_schema_info(
        self, schema: Any, items: Optional[Field] = None
    ) -> Dict[str, Any]:
        """Return the JSON schema of a type.

        Lists become ``array``, dictionaries ``object`` with
        ``additionalProperties``, unions ``oneOf`` (or the single non-None
        member) with an extra boolean ``required`` entry, dataclasses a
        ``$ref`` (the dataclass is queued for parsing); anything else is
        handled by :meth:`get_primitive_info`.

        :param schema: the type to describe
        :param items: optional field describing the elements of a container
        """
        type_info = cast(TypingInfo, TypingInfo.get(schema))
        if type_info.container is list:
            item_schema: Any
            if type_info.element is Any:
                item_schema = {"type": "object", "additionalProperties": True}
            else:
                item_schema = self.field2json(
                    fields.as_field(type_info.element, field=items), False
                )
            return {"type": "array", "items": item_schema}
        if type_info.container is dict:
            value_schema: Any
            if type_info.element is Any:
                value_schema = True
            else:
                value_schema = self.field2json(
                    fields.as_field(type_info.element, field=items), False
                )
            return {"type": "object", "additionalProperties": value_schema}
        if type_info.is_union:
            required = True
            one_of = []
            for member in type_info.element:
                if member.is_none:
                    # a None member makes the field optional
                    required = False
                else:
                    one_of.append(self.get_schema_info(member))
            info = one_of[0] if len(one_of) == 1 else {"oneOf": one_of}
            info["required"] = required
            return info
        if type_info.is_dataclass:
            name = self.add_schema_to_parse(type_info.element)
            return {"$ref": f"{SCHEMA_BASE_REF}{name}"}
        return self.get_primitive_info(type_info.element)

    def get_primitive_info(self, schema: Type) -> Dict[str, Any]:
        """Return the JSON schema of a primitive type or of an Enum.

        :raises InvalidTypeException: if the type is neither a known
            primitive nor an Enum subclass
        """
        mapping = fields.PRIMITIVE_TYPES.get(schema)
        if mapping:
            # copy, callers add entries to the returned dictionary
            return dict(mapping)
        if is_subclass(schema, Enum):
            enum_type = cast(Type[Enum], schema)
            return {"type": "string", "enum": [e.name for e in enum_type]}
        raise InvalidTypeException(f"Cannot understand {schema} while parsing")

    def add_schema_to_parse(self, schema: type) -> str:
        """Queue a dataclass for parsing and return its name.

        :raises InvalidTypeException: if ``schema`` is not a dataclass
        """
        if not is_dataclass(schema):
            raise InvalidTypeException(f"Schema must be a dataclass, got {schema}")
        name = schema.__name__
        self.schemas_to_parse[name] = schema
        return name

    def parsed_schemas(self) -> Dict[str, Dict]:
        """Convert all the queued dataclasses, including the ones queued
        while converting, and return them keyed by name.

        A name is converted once only: a schema queued again by one of its
        dependants is skipped, so schemas referencing each other cannot keep
        the loop running forever.
        """
        parsed: Dict[str, Dict] = {}
        while self.schemas_to_parse:
            # parsing a schema can queue further schemas: swap the queue
            # before iterating it
            to_parse = self.schemas_to_parse
            self.schemas_to_parse = {}
            for name, schema in to_parse.items():
                if name not in parsed:
                    parsed[name] = self.dataclass2json(schema)
        return parsed

    def add_default(self, field: Field, json_property: Dict):
        """Add the field default to ``json_property``, when meaningful.

        Nothing is added for empty defaults, for fields flagged as required
        and for types which are neither primitives nor Enums; Enum defaults
        are written by name.
        """
        if field.default in EMPTY_DEFAULTS or field.metadata.get(fields.REQUIRED):
            return
        if type_info := TypingInfo.get(field.type):
            default = field.default
            if type_info.element not in fields.PRIMITIVE_TYPES:
                if not (
                    is_subclass(type_info.element, Enum)
                    and isinstance(default, type_info.element)
                ):
                    return
                default = default.name
            json_property["default"] = default
class SpecDoc(SchemaParser):
    """Build the OpenAPI Spec doc

    An instance is bound to a request and a spec; calling it walks the
    routes returned by the spec and returns the whole document.
    """

    def __init__(
        self,
        request: web.Request,
        spec: OpenApiSpec,
        public: bool = True,
        private: bool = False,
    ) -> None:
        """
        :param request: request whose application is documented
        :param spec: the spec configuration
        :param public: include the paths/operations not flagged as private
        :param private: include the paths/operations flagged as private
        """
        super().__init__(spec.validate_docs)
        self.request: web.Request = request
        self.spec: OpenApiSpec = spec
        self.logger = request.app.logger
        self.public: bool = public
        self.private: bool = private
        self.parameters: Dict = {}
        self.responses: Dict = {}
        self.tags: Dict = {}
        self.plugins: Dict = {}
        self.doc: Dict = dict(
            openapi=OPENAPI,
            info=asdict(self.spec.info),
            paths=OrderedDict(),
        )

    @property
    def app(self) -> web.Application:
        """The application of the bound request"""
        return self.request.app

    def __call__(self, security: Dict, servers: List) -> Dict:
        """Build and return the spec document.

        :param security: security schemes by name; the ``scopes`` entry of
            each scheme becomes its security requirement, the remaining
            entries its definition. The mapping is not modified.
        :param servers: server objects; when empty a default server is
            built from the request
        :return: the OpenAPI document
        """
        # Add errors schemas
        self.add_schema_to_parse(ValidationErrors)
        self.add_schema_to_parse(ErrorMessage)
        self.add_schema_to_parse(FieldError)
        self.doc["security"] = [
            {name: scheme.get("scopes", [])} for name, scheme in security.items()
        ]
        # scheme definitions are the schemes without their scopes; build new
        # dicts rather than popping from the caller's ones, which may be the
        # spec configuration shared by all the requests
        schemes = {
            name: {key: value for key, value in scheme.items() if key != "scopes"}
            for name, scheme in security.items()
        }
        # Build paths
        self._build_paths()
        s = self.parsed_schemas()
        p = self.parameters
        r = self.responses
        doc = self.doc
        doc.update(
            compact(
                tags=[self.tags[name] for name in sorted(self.tags)],
                components=compact(
                    schemas=OrderedDict((k, s[k]) for k in sorted(s)),
                    parameters=OrderedDict((k, p[k]) for k in sorted(p)),
                    responses=OrderedDict((k, r[k]) for k in sorted(r)),
                    securitySchemes=OrderedDict(
                        (k, schemes[k]) for k in sorted(schemes)
                    ),
                ),
                servers=servers,
            )
        )
        if not doc.get("servers"):
            # build the server info
            doc["servers"] = [default_server(self.request)]
        return doc

    # Internals

    def _build_paths(self) -> None:
        """Loop through app paths and add
        schemas, parameters and paths objects to the spec
        """
        paths = self.doc["paths"]
        # paths are documented without their first len(base_path) characters
        prefix_length = len(self.app["cli"].base_path)
        for route in self.spec.routes(self.request):
            route_info = route.get_info()
            path = route_info.get("path", route_info.get("formatter", None))
            handler = route.handler
            if is_subclass(handler, ApiPath) and self._include(handler.private):
                path = path[prefix_length:]
                paths[path] = self._build_path_object(path, handler)
        if self.validate_docs:
            self._validate_tags()

    def _validate_tags(self) -> None:
        """Check that every collected tag is allowed and has a description.

        :raises InvalidSpecException: on the first offending tag
        """
        for tag_name, tag_obj in self.tags.items():
            if self.spec.allowed_tags and tag_name not in self.spec.allowed_tags:
                raise InvalidSpecException(f'Tag "{tag_name}" not allowed')
            if "description" not in tag_obj:
                raise InvalidSpecException(f'Missing tag "{tag_name}" description')

    def _build_path_object(self, path: str, handler):
        """Build the path item of ``handler`` from its docstrings and from
        the operations attached to its method handlers.

        :raises InvalidSpecException: tags missing while validating, or an
            invalid operation; the message names the offending route
        """
        path_obj = load_yaml_from_docstring(handler.__doc__) or {}
        doc_tags = path_obj.pop("tags", None)
        if not doc_tags and self.validate_docs:
            raise InvalidSpecException(f"Missing tags docstring for route '{path}'")
        tags = self._extend_tags(doc_tags)
        if handler.path_schema:
            path_obj["parameters"] = self.get_parameters(handler.path_schema)
        for method in METHODS:
            try:
                method_handler = getattr(handler, method, None)
                if method_handler is None:
                    continue
                operation = getattr(method_handler, "op", None)
                if operation is None:
                    self.logger.warning(
                        "No operation defined for %s.%s", handler.__name__, method
                    )
                    continue
                method_doc = load_yaml_from_docstring(method_handler.__doc__) or {}
                if not self._include(method_doc.pop("private", self.private)):
                    continue
                # path tags plus the ones declared by the method
                mtags = tags.copy()
                mtags.update(self._extend_tags(method_doc.pop("tags", None)))
                self._get_response_object(operation.response_schema, method_doc)
                self._get_request_body_object(operation.body_schema, method_doc)
                self._get_query_parameters(operation.query_schema, method_doc)
                method_info = self._get_method_info(method_handler, method_doc)
                method_doc.update(method_info)
                # sorted: iterating a set of strings gives an order which
                # changes between runs, the document must be reproducible
                method_doc["tags"] = sorted(mtags)
                path_obj[method] = method_doc
            except (InvalidSpecException, InvalidTypeException) as exc:
                raise InvalidSpecException(
                    f"Invalid spec in route '{method} {path}': {exc}"
                ) from None
        return path_obj

    def _get_method_info(self, method_handler, method_doc):
        """Return summary and description of an operation.

        :raises InvalidSpecException: either is missing while validating
        """
        summary = method_doc.get("summary", "")
        description = method_doc.get("description", "")
        if self.validate_docs:
            if not summary:
                raise InvalidSpecException(
                    f'Missing method summary for "{method_handler}"'
                )
            if not description:
                raise InvalidSpecException(
                    f'Missing method description for "{method_handler}"'
                )
        return {"summary": summary, "description": description}

    def _get_response_object(
        self, type_info: Optional[TypingInfo], doc: Dict[str, Any]
    ) -> None:
        """Replace ``doc["responses"]`` with full response objects.

        Responses with a status of 400 or more get the error schema of that
        status, the others the operation response schema. Nothing is done
        when the operation has no response schema.
        """
        if type_info:
            schema = self.get_schema_info(type_info)
            responses = {}
            for response, data in doc.get("responses", {}).items():
                rschema = schema
                if response >= 400:
                    rschema = self.get_schema_info(error_response_schema(response))
                content = data.get("content", self.spec.default_content_type)
                responses[response] = {
                    "description": data.get("description", ""),
                    "content": {content: {"schema": rschema}},
                }
            doc["responses"] = responses

    def _get_request_body_object(
        self, type_info: Optional[TypingInfo], doc: Dict[str, Any]
    ) -> None:
        """Set ``doc["requestBody"]`` when the operation has a body schema.

        The content type is the ``body_content`` entry of ``doc`` (removed
        from it) or the spec default.
        """
        if type_info:
            content = doc.pop("body_content", self.spec.default_content_type)
            schema = self.get_schema_info(type_info)
            doc["requestBody"] = {"content": {content: {"schema": schema}}}

    def _get_query_parameters(
        self, type_info: Optional[TypingInfo], doc: Dict[str, Any]
    ) -> None:
        """Set ``doc["parameters"]`` when the operation has a query schema."""
        if type_info:
            doc["parameters"] = self.get_parameters(type_info, "query")

    def _extend_tags(self, tags):
        """Register tags in :attr:`tags` and return the set of their names.

        A tag is either a name or a dictionary with a ``name`` entry;
        dictionaries without a name are ignored. A tag already registered is
        updated with the new entries.
        """
        names = set()
        for tag in tags or ():
            if isinstance(tag, str):
                tag = {"name": tag}
            name = tag.get("name")
            if name:
                if name not in self.tags:
                    self.tags[name] = tag
                else:
                    self.tags[name].update(tag)
                names.add(name)
        return names

    def _include(self, is_private: bool) -> bool:
        """True when an item with the given privacy belongs to the document"""
        return (is_private and self.private) or (not is_private and self.public)