Source code for slipstream.utils

"""Slipstream utilities."""

from asyncio import Queue
from inspect import iscoroutinefunction, signature
from typing import (
    Any,
    AsyncIterator,
    Awaitable,
    Callable,
    TypeAlias,
)

[docs] AsyncCallable: TypeAlias = Callable[..., Awaitable[Any]] | Callable[..., Any]
def iscoroutinecallable(o: Any) -> bool:
    """Tell whether calling ``o`` produces a coroutine.

    Returns ``True`` when ``o`` itself is a coroutine function, or when
    ``o`` exposes a ``__call__`` attribute that is a coroutine function
    (e.g. an instance of a class defining ``async def __call__``).
    Returns ``False`` otherwise.
    """
    if iscoroutinefunction(o):
        return True
    # Objects without ``__call__`` cannot be invoked at all.
    if not hasattr(o, '__call__'):
        return False
    return iscoroutinefunction(o.__call__)
def get_param_names(o: Any):
    """Return the parameter names of callable ``o`` as a tuple.

    Names appear in the order reported by ``inspect.signature``; a
    callable without parameters yields an empty tuple.
    """
    # Iterating the parameters mapping yields its keys (the names).
    return tuple(signature(o).parameters)
class Singleton(type):
    """Metaclass that keeps one shared instance per class.

    The first call to a class using this metaclass constructs the
    instance; every later call returns that same object instead of
    building a new one.
    """

    # Registry mapping each class to the instance created for it.
    _instances: dict['Singleton', Any] = {}

    def __call__(cls, *args: Any, **kwargs: Any):
        """Return the shared instance of ``cls``, creating it if needed.

        If the instance has an ``__update__`` attribute it is called with
        the same arguments on every call, including the call that
        created the instance.
        """
        registry = cls._instances
        if cls not in registry:
            registry[cls] = super().__call__(*args, **kwargs)

        instance = registry[cls]
        if hasattr(instance, '__update__'):
            instance.__update__(*args, **kwargs)
        return instance
class PubSub(metaclass=Singleton):
    """Singleton publish subscribe pattern class.

    Listeners are stored per topic name in a class-level registry. The
    class uses the ``Singleton`` metaclass, so repeated ``PubSub()``
    calls return the same object.
    """

    # Maps a topic name to its listeners, in subscription order.
    _topics: dict[str, list[AsyncCallable]] = {}

    def subscribe(self, topic: str, listener: AsyncCallable) -> None:
        """Subscribe callable to topic.

        The topic entry is created on first use. The same listener may
        be added more than once, in which case it is called once per
        registration.
        """
        self._topics.setdefault(topic, []).append(listener)

    def unsubscribe(self, topic: str, listener: AsyncCallable) -> None:
        """Unsubscribe callable from topic.

        Removes one registration of ``listener`` and drops the topic
        entry once it has no listeners left. Unknown topics are ignored.

        Raises:
            ValueError: If the topic exists but ``listener`` is not
                registered on it.
        """
        if topic in self._topics:
            self._topics[topic].remove(listener)
            if not self._topics[topic]:
                del self._topics[topic]

    def publish(
        self,
        topic: str,
        *args: Any,
        **kwargs: Any
    ) -> None:
        """Publish message to subscribers of topic.

        Every listener is called synchronously with ``*args`` and
        ``**kwargs``; coroutine listeners are not awaited here (use
        ``apublish`` for those). Topics without listeners are ignored.
        """
        if topic not in self._topics:
            return
        # Iterate over a snapshot: a listener may subscribe/unsubscribe
        # while being called, and mutating the live list mid-iteration
        # would make the loop skip other listeners.
        for listener in tuple(self._topics[topic]):
            listener(*args, **kwargs)

    async def apublish(
        self,
        topic: str,
        *args: Any,
        **kwargs: Any
    ) -> None:
        """Publish message to subscribers of topic.

        Coroutine listeners are awaited one after another; all other
        listeners are called directly. Topics without listeners are
        ignored.
        """
        if topic not in self._topics:
            return
        # Iterate over a snapshot: each ``await`` can let other tasks
        # run, and those may change the subscriptions of this topic.
        for listener in tuple(self._topics[topic]):
            if iscoroutinecallable(listener):
                await listener(*args, **kwargs)
            else:
                listener(*args, **kwargs)

    async def iter_topic(self, topic: str) -> AsyncIterator[Any]:
        """Asynchronously iterate over messages published to a topic.

        A private queue is subscribed to the topic for the lifetime of
        the iterator and unsubscribed again when the generator is closed
        or exits with an error. The queue's ``put_nowait`` is the
        listener, so each publish to the topic must pass exactly one
        positional argument, which becomes the yielded item.
        """
        queue: Queue[Any] = Queue()
        enqueue = queue.put_nowait
        self.subscribe(topic, enqueue)
        try:
            while True:
                yield await queue.get()
        finally:
            self.unsubscribe(topic, enqueue)