Source code for openapi.ws.channels
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple
from openapi.data.validate import ValidationErrors
from .channel import CallbackType, Channel
from .errors import CannotSubscribe
if TYPE_CHECKING: # pragma: no cover
from .manager import SocketsManager
[docs]class Channels:
"""Manage channels for publish/subscribe"""
def __init__(self, sockets: "SocketsManager") -> None:
self.sockets: "SocketsManager" = sockets
self._channels: Dict[str, Channel] = {}
@property
def registered(self) -> Tuple[str, ...]:
"""Registered channels"""
return tuple(self._channels)
def __len__(self) -> int:
return len(self._channels)
def __contains__(self, channel_name: str) -> bool:
return channel_name in self._channels
def __iter__(self) -> Iterator[Channel]:
return iter(self._channels.values())
def clear(self) -> None:
self._channels.clear()
def get(self, channel_name: str) -> Optional[Channel]:
return self._channels.get(channel_name)
def info(self) -> Dict:
return {channel.name: channel.info() for channel in self}
async def __call__(self, channel_name: str, message: Dict) -> None:
"""Channel callback"""
channel = self.get(channel_name)
if channel:
closed = await channel(message)
for websocket in closed:
for channel_name, channel in tuple(self._channels.items()):
channel.remove_callback(websocket)
await self._maybe_remove_channel(channel)
[docs] async def register(
self, channel_name: str, event_name: str, callback: CallbackType
) -> Channel:
"""Register a callback
:param channel_name: name of the channel
:param event_name: name of the event in the channel or a pattern
:param callback: the callback to invoke when the `event` on `channel` occurs
"""
channel = self.get(channel_name)
if channel is None:
try:
await self.sockets.subscribe(channel_name)
except CannotSubscribe:
raise ValidationErrors(dict(channel="Invalid channel"))
else:
channel = Channel(channel_name)
self._channels[channel_name] = channel
event = channel.register(event_name, callback)
await self.sockets.subscribe_to_event(channel.name, event.name)
return channel
[docs] async def unregister(
self, channel_name: str, event: str, callback: CallbackType
) -> Optional[Channel]:
"""Safely unregister a callback from the list of event
callbacks for channel_name
"""
channel = self.get(channel_name)
if channel is None:
raise ValidationErrors(dict(channel="Invalid channel"))
channel.unregister(event, callback)
return await self._maybe_remove_channel(channel)
async def _maybe_remove_channel(self, channel: Channel) -> Channel:
if not channel:
await self.sockets.unsubscribe(channel.name)
self._channels.pop(channel.name)
return channel
def get_subscribed(self, callback: CallbackType) -> Dict[str, List[str]]:
subscribed = {}
for channel in self:
events = channel.get_subscribed(callback)
if events:
subscribed[channel.name] = events
return subscribed