"""openapi.ws.channel: websocket channels dispatching events to registered callbacks."""

from __future__ import annotations

import asyncio
import logging
import re
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Dict, Sequence, Set

from .errors import ChannelCallbackError
from .utils import redis_to_py_pattern

# Module logger; callback failures in channels are reported through it.
logger = logging.getLogger(name="trading.websocket")


@dataclass
class Event:
    """An event registered on a channel together with its subscribed callbacks."""

    # event name as passed at registration time (``"*"`` for "any event")
    name: str
    # regular-expression source for ``name``; also the key under which the
    # event is stored by ``Channel``
    pattern: str
    # compiled form of ``pattern`` (``Channel.register`` uses ``re.compile``)
    regex: Any
    # callbacks to run when an incoming event name matches ``regex``;
    # a fresh set per instance
    callbacks: Set[CallbackType] = field(default_factory=set)


# Channel callbacks are awaited as ``callback(channel_name, matched_event, data)``.
CallbackType = Callable[
    [str, str, Any],
    Awaitable[Any],
]


@dataclass
class Channel:
    """A websocket channel.

    A channel holds events, each identified by a regular-expression pattern
    derived from an event name via :meth:`event_pattern`, together with the
    callbacks registered for that event.
    """

    name: str
    # Registered events keyed by their regex pattern (``Event.pattern``).
    _events: Dict[str, Event] = field(default_factory=dict)

    @property
    def events(self) -> Sequence[str]:
        """Tuple of event names this channel is registered with"""
        return tuple(e.name for e in self._events.values())

    def __len__(self) -> int:
        return len(self._events)

    def __contains__(self, pattern: str) -> bool:
        # Membership is by regex pattern, not by event name.
        return pattern in self._events

    def __iter__(self):
        # Iterates over the event patterns (the keys of ``_events``).
        return iter(self._events)

    def info(self) -> Dict[str, int]:
        """Map each event name to the number of callbacks registered for it"""
        return {e.name: len(e.callbacks) for e in self._events.values()}

    async def __call__(self, message: Dict[str, Any]) -> Sequence[CallbackType]:
        """Execute callbacks from a new message.

        ``message["event"]`` is matched against every registered event and the
        callbacks of *all* matching events are run concurrently with
        ``message["data"]`` as payload. A callback registered under several
        matching events is run only once.

        :return: the callbacks which raised ``ChannelCallbackError`` or any
            other ``Exception`` (empty when none failed or nothing matched)
        """
        event_name = message.get("event") or ""
        data = message.get("data")
        # callback -> (event, matched text). A dict keeps registration order
        # and de-duplicates callbacks subscribed to overlapping patterns.
        pending = {}
        for event in tuple(self._events.values()):
            found = event.regex.match(event_name)
            if found is None:
                continue
            matched = found.group()
            for callback in event.callbacks:
                pending.setdefault(callback, (event, matched))
        if not pending:
            return ()
        results = await asyncio.gather(
            *(
                self._execute_callback(callback, event, matched, data)
                for callback, (event, matched) in pending.items()
            )
        )
        return tuple(c for c in results if c)

    def register(self, event_name: str, callback: CallbackType) -> Event:
        """Register a ``callback`` for ``event_name``.

        A falsy ``event_name`` registers for ``"*"``. Returns the
        :class:`Event` the callback was added to (created if needed).
        """
        event_name = event_name or "*"
        pattern = self.event_pattern(event_name)
        event = self._events.get(pattern)
        if event is None:
            event = Event(name=event_name, pattern=pattern, regex=re.compile(pattern))
            self._events[event.pattern] = event
        event.callbacks.add(callback)
        return event

    def get_subscribed(self, callback: CallbackType) -> Sequence[str]:
        """List the names of the events ``callback`` is registered for"""
        return [e.name for e in self._events.values() if callback in e.callbacks]

    def unregister(self, event_name: str, callback: CallbackType) -> None:
        """Remove ``callback`` from ``event_name`` (no-op if not registered)"""
        event = self._events.get(self.event_pattern(event_name))
        if event is not None:
            self.remove_event_callback(event, callback)

    def event_pattern(self, event: str) -> str:
        """Channel pattern for an event name (falsy names map to ``"*"``)"""
        return redis_to_py_pattern(event or "*")

    def remove_callback(self, callback: CallbackType) -> None:
        """Remove ``callback`` from every event of this channel"""
        # Snapshot: remove_event_callback may delete entries from _events.
        for event in tuple(self._events.values()):
            self.remove_event_callback(event, callback)

    def remove_event_callback(self, event: Event, callback: CallbackType) -> None:
        """Remove ``callback`` from ``event``; drop the event once it is empty"""
        event.callbacks.discard(callback)
        # Only drop the entry if it is still this very event: a stale Event
        # (already removed, possibly replaced by a re-registration under the
        # same pattern) must neither evict the live one nor raise KeyError.
        if not event.callbacks and self._events.get(event.pattern) is event:
            del self._events[event.pattern]

    async def _execute_callback(
        self, callback: CallbackType, event: Event, match: str, data: Any
    ) -> Any:
        """Await ``callback``; return it if it failed, ``None`` otherwise."""
        try:
            await callback(self.name, match, data)
        except ChannelCallbackError:
            return callback
        except Exception:
            logger.exception(
                'callback exception: channel "%s" event "%s"', self.name, event.name
            )
            return callback
        return None