"""
MIT License
Copyright (c) 2026-present SonoLink Development Team.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from __future__ import annotations
import asyncio
import logging
import os
from typing import TYPE_CHECKING, Any, Generic, Literal, overload
from typing_extensions import TypeVar
from sonolink import _registry
from sonolink._version import __version__
from sonolink.gateway.player import FrameworkLiteral, PlayerFactory
from sonolink.models.settings import CacheSettings, InactivitySettings
from sonolink.rest.enums import TrackSourceType
from ..node import Node
from ._base import DiscordClient
from ._factory import ClientFactory
if TYPE_CHECKING:
from sonolink.models.responses import SearchResult
from sonolink.models.track import Playable
from sonolink.network import SessionType
from .adapters._disnake import DisnakeClientProto
from .adapters._dpy import DpyClientProto
from .adapters._pycord import PycordClientProto
# Type variable for the concrete Node subclass a Client manages; when a Client
# is not explicitly parametrised it falls back to the base Node class.
N = TypeVar("N", bound=Node, default=Node)

# Logger for this module, named after the module's dotted path.
_log: logging.Logger = logging.getLogger(__name__)

# Names exported by ``from ... import *``.
__all__ = ("Client",)
class Client(Generic[N]):
    """
    Represents a SonoLink client.

    A client helps you manage all Node connections and players.

    Parameters
    ----------
    client: :class:`discord:discord.Client` (discord.py) | :class:`pycord:discord.Client` (py-cord) | :class:`disnake:disnake.Client`
        The Discord client this SonoLink client is attached to.
    node_cls: ``type[Node]``
        The class to use when creating new nodes. Defaults to :class:`Node`.
    framework: :class:`str` | :data:`None`
        The Discord framework to use. Accepted values are ``"discord.py"``,
        ``"pycord"``, and ``"disnake"``. When ``None``, the framework is
        detected automatically from whichever library is installed; if multiple
        are present, precedence follows ``discord.py`` → ``pycord`` → ``disnake``.
        If detection finds nothing, ``"discord.py"`` is used.
        Defaults to ``None``.

        .. warning::

            If you are using a custom :class:`~sonolink.Player` subclass, ensure it is defined **after**
            constructing the :class:`Client`, otherwise the framework adapter may not be resolved correctly.
            Alternatively, set the ``SONOLINK_FRAMEWORK`` environment variable before any imports to
            force a specific framework ahead of time.

    Raises
    ------
    RuntimeError
        A :class:`Client` is already attached to the given Discord client.
    """

    _framework: FrameworkLiteral
    _nodes: dict[str, N]
    _session: SessionType | None
    # Background ``node.close()`` tasks scheduled by ``_cleanup_node``, keyed by node id.
    _node_tasks: dict[str, asyncio.Task[Any]]

    @overload
    def __init__(
        self,
        client: DpyClientProto,
        *,
        node_cls: type[N] = ...,
        framework: Literal["discord.py"] = ...,
    ) -> None: ...

    @overload
    def __init__(
        self,
        client: PycordClientProto,
        *,
        node_cls: type[N] = ...,
        framework: Literal["pycord"] = ...,
    ) -> None: ...

    @overload
    def __init__(
        self,
        client: DisnakeClientProto,
        *,
        node_cls: type[N] = ...,
        framework: Literal["disnake"] = ...,
    ) -> None: ...

    @overload
    def __init__(
        self,
        client: Any,
        *,
        node_cls: type[N] = ...,
        framework: None = ...,
    ) -> None: ...

    def __init__(
        self,
        client: Any,
        *,
        node_cls: type[N] = Node,
        framework: FrameworkLiteral | None = None,
    ) -> None:
        if framework is None:
            # Fall back to discord.py when no supported library is detected.
            framework = PlayerFactory().detect_framework() or "discord.py"
        # Publish the resolved framework process-wide; presumably read when Player
        # subclasses resolve their adapter (see the class docstring warning).
        os.environ["SONOLINK_FRAMEWORK"] = framework
        self._client: DiscordClient[Any] = ClientFactory.create(client, framework)
        self._framework = framework
        self._nodes = {}
        self._session = None
        self._node_tasks = {}
        self._node_cls: type[N] = node_cls
        # Only one sonolink.Client may be attached to a given Discord client.
        if self._client._client in _registry.clients:
            raise RuntimeError(
                f"sonolink.Client already attached to this {framework}.Client"
            )
        _registry.clients[self._client._client] = self

    def __repr__(self) -> str:
        return f"<sonolink.Client nodes={len(self._nodes)}>"

    @property
    def nodes(self) -> list[N]:
        """The active nodes attached to this client (a fresh list on each access)."""
        return list(self._nodes.values())

    @property
    def framework(self) -> FrameworkLiteral:
        """
        The Discord framework used by this client
        (``"discord.py"``, ``"disnake"``, or ``"pycord"``).
        """
        return self._framework

    def create_node(
        self,
        *,
        uri: str,
        password: str,
        id: str | None = None,
        retries: int | None = None,
        resume_timeout: float = 60,
        cache_settings: CacheSettings | None = None,
        inactivity_settings: InactivitySettings | None = None,
        session: SessionType | None = None,
    ) -> N:
        """
        Creates a :class:`Node` attached to this client.

        Parameters
        ----------
        uri: :class:`str`
            The URI the node will connect to. You should only provide the base URI without
            any routes, as the library will do it for you.
        password: :class:`str`
            The password of the node.
        id: :class:`str` | :data:`None`
            The ID of this node. This is used internally to identify this node. If ``None`` is passed, it is
            generated automatically. If a node with the same ID is already registered, it is replaced
            (a warning is logged; the previous node is not closed automatically).
        retries: :class:`int` | :data:`None`
            The amount of retries to attempt when connecting or reconnecting this node. Whenever the limit
            is reached, it closes the node automatically. If this is set to ``None``, it retries indefinitely.
            Defaults to ``None``.
        resume_timeout: :class:`float`
            The maximum amount of seconds a resume can take before closing the node. Defaults to ``60``.
        cache_settings: :class:`CacheSettings` | :data:`None`
            The search result caching configuration.
            Defaults to ``CacheSettings.default()``.
        inactivity_settings: :class:`InactivitySettings` | :data:`None`
            The inactivity configuration for all players connected to this node.
            If ``None`` is passed, it uses ``InactivitySettings.default()``.
        session: ``aiohttp.ClientSession`` | ``curl_cffi.AsyncSession`` | :data:`None`
            The session this node should use. If ``None`` is provided, creates one. Defaults to ``None``.

        Returns
        -------
        :class:`Node`
            The node that was created.
        """
        i_settings = inactivity_settings or InactivitySettings.default()
        c_settings = cache_settings or CacheSettings.default()
        node = self._node_cls(
            client=self,
            uri=uri,
            password=password,
            id=id,
            retries=retries,
            resume_timeout=resume_timeout,
            cache_settings=c_settings,
            inactivity_settings=i_settings,
            session=session,
        )
        previous = self._nodes.get(node.id)
        if previous is not None:
            # Overwriting the mapping orphans the previous node: neither
            # remove_node() nor close() can reach it afterwards.
            _log.warning(
                "Replacing already registered Node %r; the previous node is not closed",
                previous,
            )
        self._nodes[node.id] = node
        return node

    def remove_node(self, identifier: str, /) -> None:
        """
        Removes a Node from this client.

        The node is detached immediately and closed in a background task, so this
        must be called while an asyncio event loop is running. Unknown identifiers
        are ignored.

        Parameters
        ----------
        identifier: :class:`str`
            The ID of the node to remove.
        """
        try:
            node = self._nodes.pop(identifier)
        except KeyError:
            pass
        else:
            self._cleanup_node(node)

    def clear_nodes(self) -> None:
        """Clears all Nodes from this Client."""
        # ``self.nodes`` is a snapshot, so removing while iterating is safe.
        for node in self.nodes:
            self.remove_node(node.id)

    async def start(self) -> None:
        """
        Connects all registered nodes to their respective Lavalink servers.

        This method should typically be called after the discord client is logged in,
        often within the ``on_ready`` event. Nodes that are already connected are
        skipped, and a failure to connect one node is logged without stopping the rest.
        """
        if not self._nodes:
            return
        for node in self.nodes:
            if node.is_connected:
                continue
            try:
                await node.connect()
            except Exception as exc:
                _log.exception(
                    "Ignoring exception while connecting Node %r", node, exc_info=exc
                )

    async def close(self) -> None:
        """
        Gracefully closes all :class:`Node` connections and cleans up internal resources.

        This will stop all active players and close the underlying websocket and HTTP sessions.
        It also waits for any node-close tasks previously scheduled by :meth:`remove_node`
        or :meth:`clear_nodes` to finish.
        """
        for node in self.nodes:
            if not node.is_connected:
                continue
            try:
                await node.close()
            except Exception as exc:
                _log.exception(
                    "Ignoring exception while closing Node %r", node, exc_info=exc
                )
        self._nodes.clear()
        # Snapshot: each task's done-callback removes it from the dict.
        pending = list(self._node_tasks.values())
        if pending:
            # asyncio.wait (unlike gather) does not cancel these tasks if close()
            # itself is cancelled; their failures are logged by the done-callback.
            await asyncio.wait(pending)

    def get_best_node(self) -> N:
        """
        Returns the best available :class:`Node` based on current load and connectivity.

        Returns
        -------
        :class:`Node`
            The node with the lowest penalty that is currently connected.

        Raises
        ------
        RuntimeError
            No nodes are currently connected to handle the request.
        """
        connected_nodes = [node for node in self.nodes if node.is_connected]
        if not connected_nodes:
            raise RuntimeError("No nodes are currently connected.")
        # NOTE(review): nodes without stats score a 0.0 penalty and are therefore
        # preferred over nodes that have reported stats; confirm this is intended.
        return min(
            connected_nodes,
            key=lambda node: node.stats.penalty if node.stats else 0.0,
        )

    async def search_track(
        self,
        query: str,
        *,
        source: TrackSourceType | str = TrackSourceType.YOUTUBE,
    ) -> SearchResult:
        """
        Searches for ``query`` in the best Node available, obtained with :meth:`Client.get_best_node`.

        Parameters
        ----------
        query: :class:`str`
            The query to search. This can be a full URL, or headed by hosts specified by any plugin.
        source: :class:`TrackSourceType` | :class:`str`
            The source to search from. This is, essentially, providing a host to ``query``. The library
            provides default source types under :class:`TrackSourceType`, but custom ones can be passed
            with a raw string.

        Returns
        -------
        :class:`SearchResult`
            The search result.

        Raises
        ------
        RuntimeError
            No nodes are currently connected.
        """
        node = self.get_best_node()
        return await node.search_track(query, source=source)

    async def decode_track(self, encoded: str) -> Playable:
        """
        Decodes a track from its encoded data using the best Node available, obtained with
        :meth:`Client.get_best_node`.

        When a track is fetched, the encoded data can be found under
        :attr:`sonolink.rest.schemas.Track.encoded`.

        Parameters
        ----------
        encoded: :class:`str`
            The encoded data to resolve the track from.

        Returns
        -------
        :class:`sonolink.models.Playable`
            The decoded resolved track.

        Raises
        ------
        RuntimeError
            No nodes are currently connected.
        """
        node = self.get_best_node()
        return await node.decode_track(encoded)

    async def decode_tracks(self, *encoded: str) -> list[Playable]:
        """
        Bulk decode encoded tracks using the best Node available, obtained with :meth:`Client.get_best_node`.

        Parameters
        ----------
        *encoded: :class:`str`
            The encoded data for each track to be decoded.

        Returns
        -------
        ``list[Playable]``
            The decoded resolved tracks.

        Raises
        ------
        RuntimeError
            No nodes are currently connected.
        """
        node = self.get_best_node()
        return await node.decode_tracks(*encoded)

    def _cleanup_node(self, node: N) -> asyncio.Task[None]:
        # Schedule ``node.close()`` in the background, reusing an in-flight task
        # for the same node id so a node is not closed twice concurrently.
        # Requires a running event loop (asyncio.create_task raises otherwise).
        existing = self._node_tasks.get(node.id)
        if existing is not None:
            return existing
        task = asyncio.create_task(node.close(), name=f"sonolink:node-close:{node.id}")
        self._node_tasks[node.id] = task

        def _on_done(finished: asyncio.Task[None]) -> None:
            # Only drop the entry if it still refers to this task.
            if self._node_tasks.get(node.id) is finished:
                del self._node_tasks[node.id]
            if finished.cancelled():
                return
            # Retrieve the exception so asyncio does not report it as
            # "never retrieved", and log it like start()/close() do.
            exc = finished.exception()
            if exc is not None:
                _log.error(
                    "Ignoring exception while closing Node %r", node, exc_info=exc
                )

        task.add_done_callback(_on_done)
        return task

    def _dispatch(self, event: str, *args: Any, **kwargs: Any) -> None:
        # Forward to the Discord client with a "sonolink_" event-name prefix.
        self._client.dispatch(f"sonolink_{event}", *args, **kwargs)

    def _build_ws_headers(self) -> dict[str, str]:
        # Headers sent when a node opens its websocket; needs the bot user id.
        if self._client.user is None:
            raise RuntimeError(
                "Cannot connect Nodes without the underlying client running."
            )
        return {
            "User-Id": str(self._client.user.id),
            "Client-Name": f"sonolink/{__version__}",
        }