Source code for openapi.ws.path

import hashlib
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Dict

from aiohttp import web

from openapi.ws.channels import Channels

from .. import json
from ..data.validate import ValidationErrors, validated_schema
from ..utils import compact
from .errors import CONNECTION_ERRORS
from .manager import SocketsManager, Websocket

logger = logging.getLogger("openapi.ws")


@dataclass
class RpcProtocol:
    id: str
    method: str
    payload: Dict = field(default_factory=dict)


class ProtocolError(RuntimeError):
    pass


[docs]class WsPathMixin(Websocket): """Api Path mixin for Websocket RPC protocol""" SOCKETS_KEY = "web_sockets" """Key in the app where the Web Sockets manager is located""" @property def sockets(self) -> SocketsManager: """Connected websockets""" return self.request.app[self.SOCKETS_KEY] @property def channels(self) -> Channels: """Channels for pub/sub""" return self.sockets.channels async def get(self): response = web.WebSocketResponse() available = response.can_prepare(self.request) if not available: raise web.HTTPBadRequest( **self.api_response_data( {"message": "Unable to open websocket connection"} ) ) await response.prepare(self.request) self.response = response self.started = time.time() key = "%s - %s" % (self.request.remote, self.started) self.socket_id = hashlib.sha224(key.encode("utf-8")).hexdigest() # # Add to set of sockets if available self.sockets.add(self) # try: async for msg in response: if msg.type == web.WSMsgType.TEXT: await self.on_message(msg) except CONNECTION_ERRORS: logger.info("lost connection with websocket %s", self) finally: self.sockets.remove(self) return response
[docs] def decode_message(self, msg: str) -> Any: """Decode JSON string message, override for different protocol""" try: return json.loads(msg) except json.JSONDecodeError: raise ProtocolError("JSON string expected") from None
[docs] def encode_message(self, msg: Any) -> str: """Encode as JSON string message, override for different protocol""" try: return json.dumps(msg) except TypeError: raise ProtocolError("JSON object expected") from None
async def on_message(self, msg): id_ = None rpc = None try: data = self.decode_message(msg.data) if not isinstance(data, dict): raise ProtocolError( "Malformed message; expected dictionary, " f"got {type(data).__name__}" ) id_ = data.get("id") rpc = validated_schema(RpcProtocol, data) method = getattr(self, f"ws_rpc_{rpc.method}", None) if not method: raise ValidationErrors( dict(method=f"{rpc.method} method not available") ) response = await method(rpc.payload or {}) await self.write(dict(id=rpc.id, method=rpc.method, response=response)) except ProtocolError as exc: logger.error("Protocol error: %s", exc) await self.error_message( str(exc), id=id_, method=rpc.method if rpc else None ) except ValidationErrors as exc: await self.error_message( "Invalid RPC parameters", errors=exc.errors, id=id_, method=rpc.method if rpc else None, ) async def error_message(self, message, *, errors=None, **kw): error = dict(message=message) if errors: error["errors"] = errors await self.write(compact(error=error, **kw)) async def write(self, msg: Dict) -> None: text = self.encode_message(msg) await self.response.send_str(text) async def close(self) -> None: await self.response.close() self.sockets.remove(self)