twilight_lavalink/
node.rs

1//! Nodes for communicating with a Lavalink server.
2//!
3//! Using nodes, you can send events to a server and receive events.
4//!
5//! This is a bit more low level than using the [`Lavalink`] client because you
6//! will need to provide your own `VoiceUpdate` events when your bot joins
7//! channels, meaning you will have to accumulate and combine voice state update
8//! and voice server update events from the Discord gateway to send them to
9//! a node.
10//!
11//! Additionally, you will have to create and manage your own [`PlayerManager`]
12//! and make your own players for guilds when your bot joins voice channels.
13//!
14//! This can be a lot of work, and there's not really much reason to do it
15//! yourself. For that reason, you should almost always use the `Lavalink`
16//! client which does all of this for you.
17//!
18//! [`Lavalink`]: crate::client::Lavalink
19
20use crate::{
21    model::{IncomingEvent, Opcode, OutgoingEvent, PlayerUpdate, Stats, StatsCpu, StatsMemory},
22    player::PlayerManager,
23};
24use futures_util::{
25    lock::BiLock,
26    sink::SinkExt,
27    stream::{Stream, StreamExt},
28};
29use http::{header::HeaderName, Request, Response, StatusCode};
30use std::{
31    error::Error,
32    fmt::{Debug, Display, Formatter, Result as FmtResult},
33    net::SocketAddr,
34    pin::Pin,
35    task::{Context, Poll},
36    time::Duration,
37};
38use tokio::{
39    net::TcpStream,
40    sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
41    time as tokio_time,
42};
43use tokio_tungstenite::{
44    tungstenite::{client::IntoClientRequest, Error as TungsteniteError, Message},
45    MaybeTlsStream, WebSocketStream,
46};
47use twilight_model::id::{marker::UserMarker, Id};
48
49/// An error occurred while either initializing a connection or while running
50/// its event loop.
51#[derive(Debug)]
52pub struct NodeError {
53    kind: NodeErrorType,
54    source: Option<Box<dyn Error + Send + Sync>>,
55}
56
57impl NodeError {
58    /// Immutable reference to the type of error that occurred.
59    #[must_use = "retrieving the type has no effect if left unused"]
60    pub const fn kind(&self) -> &NodeErrorType {
61        &self.kind
62    }
63
64    /// Consume the error, returning the source error if there is any.
65    #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
66    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
67        self.source
68    }
69
70    /// Consume the error, returning the owned error type and the source error.
71    #[must_use = "consuming the error into its parts has no effect if left unused"]
72    pub fn into_parts(self) -> (NodeErrorType, Option<Box<dyn Error + Send + Sync>>) {
73        (self.kind, self.source)
74    }
75}
76
77impl Display for NodeError {
78    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
79        match &self.kind {
80            NodeErrorType::BuildingConnectionRequest { .. } => {
81                f.write_str("failed to build connection request")
82            }
83            NodeErrorType::Connecting { .. } => f.write_str("Failed to connect to the node"),
84            NodeErrorType::SerializingMessage { .. } => {
85                f.write_str("failed to serialize outgoing message as json")
86            }
87            NodeErrorType::Unauthorized { address, .. } => {
88                f.write_str("the authorization used to connect to node ")?;
89                Display::fmt(address, f)?;
90
91                f.write_str(" is invalid")
92            }
93        }
94    }
95}
96
97impl Error for NodeError {
98    fn source(&self) -> Option<&(dyn Error + 'static)> {
99        self.source
100            .as_ref()
101            .map(|source| &**source as &(dyn Error + 'static))
102    }
103}
104
105/// Type of [`NodeError`] that occurred.
106#[derive(Debug)]
107#[non_exhaustive]
108pub enum NodeErrorType {
109    /// Building the HTTP request to initialize a connection failed.
110    BuildingConnectionRequest,
111    /// Connecting to the Lavalink server failed after several backoff attempts.
112    Connecting,
113    /// Serializing a JSON message to be sent to a Lavalink node failed.
114    SerializingMessage {
115        /// The message that couldn't be serialized.
116        message: OutgoingEvent,
117    },
118    /// The given authorization for the node is incorrect.
119    Unauthorized {
120        /// The address of the node that failed to authorize.
121        address: SocketAddr,
122        /// The authorization used to connect to the node.
123        authorization: String,
124    },
125}
126
127/// An error that can occur while sending an event over a node.
128#[derive(Debug)]
129pub struct NodeSenderError {
130    kind: NodeSenderErrorType,
131    source: Option<Box<dyn Error + Send + Sync>>,
132}
133
134impl NodeSenderError {
135    /// Immutable reference to the type of error that occurred.
136    pub const fn kind(&self) -> &NodeSenderErrorType {
137        &self.kind
138    }
139
140    /// Consume the error, returning the source error if there is any.
141    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
142        self.source
143    }
144
145    /// Consume the error, returning the owned error type and the source error.
146    #[must_use = "consuming the error into its parts has no effect if left unused"]
147    pub fn into_parts(self) -> (NodeSenderErrorType, Option<Box<dyn Error + Send + Sync>>) {
148        (self.kind, self.source)
149    }
150}
151
152impl Display for NodeSenderError {
153    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
154        match &self.kind {
155            NodeSenderErrorType::Sending => f.write_str("failed to send over channel"),
156        }
157    }
158}
159
160impl Error for NodeSenderError {
161    fn source(&self) -> Option<&(dyn Error + 'static)> {
162        self.source
163            .as_ref()
164            .map(|source| &**source as &(dyn Error + 'static))
165    }
166}
167
/// Category of a [`NodeSenderError`].
#[derive(Debug)]
#[non_exhaustive]
pub enum NodeSenderErrorType {
    /// The event couldn't be pushed onto the node's channel.
    Sending,
}
175
176/// Stream of incoming events from a node.
177pub struct IncomingEvents {
178    inner: UnboundedReceiver<IncomingEvent>,
179}
180
181impl IncomingEvents {
182    /// Closes the receiving half of a channel without dropping it.
183    pub fn close(&mut self) {
184        self.inner.close();
185    }
186}
187
188impl Stream for IncomingEvents {
189    type Item = IncomingEvent;
190
191    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
192        self.inner.poll_recv(cx)
193    }
194}
195
196/// Send outgoing events to the associated node.
197pub struct NodeSender {
198    inner: UnboundedSender<OutgoingEvent>,
199}
200
201impl NodeSender {
202    /// Returns whether this channel is closed without needing a context.
203    pub fn is_closed(&self) -> bool {
204        self.inner.is_closed()
205    }
206
207    /// Sends a message along this channel.
208    ///
209    /// This is an unbounded sender, so this function differs from `Sink::send`
210    /// by ensuring the return type reflects that the channel is always ready to
211    /// receive messages.
212    ///
213    /// # Errors
214    ///
215    /// Returns a [`NodeSenderErrorType::Sending`] error type if node is no
216    /// longer connected.
217    pub fn send(&self, msg: OutgoingEvent) -> Result<(), NodeSenderError> {
218        self.inner.send(msg).map_err(|source| NodeSenderError {
219            kind: NodeSenderErrorType::Sending,
220            source: Some(Box::new(source)),
221        })
222    }
223}
224
225/// The configuration that a [`Node`] uses to connect to a Lavalink server.
226#[derive(Clone, Eq, PartialEq)]
227#[non_exhaustive]
228// Keep fields in sync with its Debug implementation.
229pub struct NodeConfig {
230    /// The address of the node.
231    pub address: SocketAddr,
232    /// The password to use when authenticating.
233    pub authorization: String,
234    /// The details for resuming a Lavalink session, if any.
235    ///
236    /// Set this to `None` to disable resume capability.
237    pub resume: Option<Resume>,
238    /// The user ID of the bot.
239    pub user_id: Id<UserMarker>,
240}
241
242impl Debug for NodeConfig {
243    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
244        /// Debug as `<redacted>`. Necessary because debugging a struct field
245        /// with a value of of `"<redacted>"` will insert quotations in the
246        /// string, which doesn't align with other token debugs.
247        struct Redacted;
248
249        impl Debug for Redacted {
250            fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
251                f.write_str("<redacted>")
252            }
253        }
254
255        f.debug_struct("NodeConfig")
256            .field("address", &self.address)
257            .field("authorization", &Redacted)
258            .field("resume", &self.resume)
259            .field("user_id", &self.user_id)
260            .finish()
261    }
262}
263
/// Settings for a Lavalink session that may be resumed after a disconnect.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct Resume {
    /// How many seconds the Lavalink server keeps the session resumable
    /// after the connection drops.
    ///
    /// The default is 60.
    pub timeout: u64,
}
274
275impl Resume {
276    /// Configure resume capability, providing the number of seconds that the
277    /// Lavalink server should queue events for when the connection is resumed.
278    pub const fn new(seconds: u64) -> Self {
279        Self { timeout: seconds }
280    }
281}
282
283impl Default for Resume {
284    fn default() -> Self {
285        Self { timeout: 60 }
286    }
287}
288
289impl NodeConfig {
290    /// Create a new configuration for connecting to a node via
291    /// [`Node::connect`].
292    ///
293    /// If adding a node through the [`Lavalink`] client then you don't need to
294    /// do this yourself.
295    ///
296    /// [`Lavalink`]: crate::client::Lavalink
297    pub fn new(
298        user_id: Id<UserMarker>,
299        address: impl Into<SocketAddr>,
300        authorization: impl Into<String>,
301        resume: impl Into<Option<Resume>>,
302    ) -> Self {
303        Self::_new(user_id, address.into(), authorization.into(), resume.into())
304    }
305
306    const fn _new(
307        user_id: Id<UserMarker>,
308        address: SocketAddr,
309        authorization: String,
310        resume: Option<Resume>,
311    ) -> Self {
312        Self {
313            address,
314            authorization,
315            resume,
316            user_id,
317        }
318    }
319}
320
321/// A connection to a single Lavalink server. It receives events and forwards
322/// events from players to the server.
323///
324/// Please refer to the [module] documentation.
325///
326/// [module]: crate
327#[derive(Debug)]
328pub struct Node {
329    config: NodeConfig,
330    lavalink_tx: UnboundedSender<OutgoingEvent>,
331    players: PlayerManager,
332    stats: BiLock<Stats>,
333}
334
335impl Node {
336    /// Connect to a node, providing a player manager so that the node can
337    /// update player details.
338    ///
339    /// Please refer to the [module] documentation for some additional
340    /// information about directly creating and using nodes. You are encouraged
341    /// to use the [`Lavalink`] client instead.
342    ///
343    /// [`Lavalink`]: crate::client::Lavalink
344    /// [module]: crate
345    ///
346    /// # Errors
347    ///
348    /// Returns an error of type [`Connecting`] if the connection fails after
349    /// several backoff attempts.
350    ///
351    /// Returns an error of type [`BuildingConnectionRequest`] if the request
352    /// failed to build.
353    ///
354    /// Returns an error of type [`Unauthorized`] if the supplied authorization
355    /// is rejected by the node.
356    ///
357    /// [`Connecting`]: crate::node::NodeErrorType::Connecting
358    /// [`BuildingConnectionRequest`]: crate::node::NodeErrorType::BuildingConnectionRequest
359    /// [`Unauthorized`]: crate::node::NodeErrorType::Unauthorized
360    pub async fn connect(
361        config: NodeConfig,
362        players: PlayerManager,
363    ) -> Result<(Self, IncomingEvents), NodeError> {
364        let (bilock_left, bilock_right) = BiLock::new(Stats {
365            cpu: StatsCpu {
366                cores: 0,
367                lavalink_load: 0f64,
368                system_load: 0f64,
369            },
370            frames: None,
371            memory: StatsMemory {
372                allocated: 0,
373                free: 0,
374                used: 0,
375                reservable: 0,
376            },
377            players: 0,
378            playing_players: 0,
379            op: Opcode::Stats,
380            uptime: 0,
381        });
382
383        tracing::debug!("starting connection to {}", config.address);
384
385        let (conn_loop, lavalink_tx, lavalink_rx) =
386            Connection::connect(config.clone(), players.clone(), bilock_right).await?;
387
388        tracing::debug!("started connection to {}", config.address);
389
390        tokio::spawn(conn_loop.run());
391
392        Ok((
393            Self {
394                config,
395                lavalink_tx,
396                players,
397                stats: bilock_left,
398            },
399            IncomingEvents { inner: lavalink_rx },
400        ))
401    }
402
403    /// Retrieve an immutable reference to the node's configuration.
404    pub const fn config(&self) -> &NodeConfig {
405        &self.config
406    }
407
408    /// Retrieve an immutable reference to the player manager used by the node.
409    pub const fn players(&self) -> &PlayerManager {
410        &self.players
411    }
412
413    /// Retrieve an immutable reference to the node's configuration.
414    ///
415    /// Note that sending player events through the node's sender won't update
416    /// player states, such as whether it's paused.
417    ///
418    /// # Errors
419    ///
420    /// Returns a [`NodeSenderErrorType::Sending`] error type if node is no
421    /// longer connected.
422    pub fn send(&self, event: OutgoingEvent) -> Result<(), NodeSenderError> {
423        self.sender().send(event)
424    }
425
426    /// Retrieve a unique sender to send events to the Lavalink server.
427    ///
428    /// Note that sending player events through the node's sender won't update
429    /// player states, such as whether it's paused.
430    pub fn sender(&self) -> NodeSender {
431        NodeSender {
432            inner: self.lavalink_tx.clone(),
433        }
434    }
435
436    /// Retrieve a copy of the node's stats.
437    pub async fn stats(&self) -> Stats {
438        (*self.stats.lock().await).clone()
439    }
440
441    /// Retrieve the calculated penalty score of the node.
442    ///
443    /// This score can be used to calculate how loaded the server is. A higher
444    /// number means it is more heavily loaded.
445    #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
446    pub async fn penalty(&self) -> i32 {
447        let stats = self.stats.lock().await;
448        let cpu = 1.05f64.powf(100f64 * stats.cpu.system_load) * 10f64 - 10f64;
449
450        let (deficit_frame, null_frame) = (
451            1.03f64
452                .powf(500f64 * (stats.frames.as_ref().map_or(0, |f| f.deficit) as f64 / 3000f64))
453                * 300f64
454                - 300f64,
455            (1.03f64
456                .powf(500f64 * (stats.frames.as_ref().map_or(0, |f| f.nulled) as f64 / 3000f64))
457                * 300f64
458                - 300f64)
459                * 2f64,
460        );
461
462        stats.playing_players as i32 + cpu as i32 + deficit_frame as i32 + null_frame as i32
463    }
464}
465
466struct Connection {
467    config: NodeConfig,
468    connection: WebSocketStream<MaybeTlsStream<TcpStream>>,
469    node_from: UnboundedReceiver<OutgoingEvent>,
470    node_to: UnboundedSender<IncomingEvent>,
471    players: PlayerManager,
472    stats: BiLock<Stats>,
473}
474
475impl Connection {
476    async fn connect(
477        config: NodeConfig,
478        players: PlayerManager,
479        stats: BiLock<Stats>,
480    ) -> Result<
481        (
482            Self,
483            UnboundedSender<OutgoingEvent>,
484            UnboundedReceiver<IncomingEvent>,
485        ),
486        NodeError,
487    > {
488        let connection = reconnect(&config).await?;
489
490        let (to_node, from_lavalink) = mpsc::unbounded_channel();
491        let (to_lavalink, from_node) = mpsc::unbounded_channel();
492
493        Ok((
494            Self {
495                config,
496                connection,
497                node_from: from_node,
498                node_to: to_node,
499                players,
500                stats,
501            },
502            to_lavalink,
503            from_lavalink,
504        ))
505    }
506
507    async fn run(mut self) -> Result<(), NodeError> {
508        loop {
509            tokio::select! {
510                incoming = self.connection.next() => {
511                    if let Some(Ok(incoming)) = incoming {
512                        self.incoming(incoming).await?;
513                    } else {
514                        tracing::debug!("connection to {} closed, reconnecting", self.config.address);
515                        self.connection = reconnect(&self.config).await?;
516                    }
517                }
518                outgoing = self.node_from.recv() => {
519                    if let Some(outgoing) = outgoing {
520                        tracing::debug!(
521                            "forwarding event to {}: {outgoing:?}",
522                            self.config.address,
523                        );
524
525                        let payload = serde_json::to_string(&outgoing).map_err(|source| NodeError {
526                            kind: NodeErrorType::SerializingMessage { message: outgoing },
527                            source: Some(Box::new(source)),
528                        })?;
529                        let msg = Message::Text(payload);
530                        self.connection.send(msg).await.unwrap();
531                    } else {
532                        tracing::debug!("node {} closed, ending connection", self.config.address);
533
534                        break;
535                    }
536                }
537            }
538        }
539
540        Ok(())
541    }
542
543    async fn incoming(&mut self, incoming: Message) -> Result<bool, NodeError> {
544        tracing::debug!(
545            "received message from {}: {incoming:?}",
546            self.config.address,
547        );
548
549        let text = match incoming {
550            Message::Close(_) => {
551                tracing::debug!("got close, closing connection");
552                let _result = self.connection.send(Message::Close(None)).await;
553
554                return Ok(false);
555            }
556            Message::Ping(data) => {
557                tracing::debug!("got ping, sending pong");
558                let msg = Message::Pong(data);
559
560                // We don't need to immediately care if a pong fails.
561                let _result = self.connection.send(msg).await;
562
563                return Ok(true);
564            }
565            Message::Text(text) => text,
566            other => {
567                tracing::debug!("got pong or bytes payload: {other:?}");
568
569                return Ok(true);
570            }
571        };
572
573        let Ok(event) = serde_json::from_str(&text) else {
574            tracing::warn!("unknown message from lavalink node: {text}");
575
576            return Ok(true);
577        };
578
579        match &event {
580            IncomingEvent::PlayerUpdate(update) => self.player_update(update)?,
581            IncomingEvent::Stats(stats) => self.stats(stats).await?,
582            _ => {}
583        }
584
585        // It's fine if the rx end dropped, often users don't need to care about
586        // these events.
587        if !self.node_to.is_closed() {
588            let _result = self.node_to.send(event);
589        }
590
591        Ok(true)
592    }
593
594    fn player_update(&self, update: &PlayerUpdate) -> Result<(), NodeError> {
595        let Some(player) = self.players.get(&update.guild_id) else {
596            tracing::warn!(
597                "invalid player update for guild {}: {update:?}",
598                update.guild_id,
599            );
600
601            return Ok(());
602        };
603
604        player.set_position(update.state.position.unwrap_or(0));
605        player.set_time(update.state.time);
606
607        Ok(())
608    }
609
610    async fn stats(&self, stats: &Stats) -> Result<(), NodeError> {
611        *self.stats.lock().await = stats.clone();
612
613        Ok(())
614    }
615}
616
617impl Drop for Connection {
618    fn drop(&mut self) {
619        // Cleanup local players associated with the node
620        self.players
621            .players
622            .retain(|_, v| v.node().config().address != self.config.address);
623    }
624}
625
626fn connect_request(state: &NodeConfig) -> Result<Request<()>, NodeError> {
627    let mut request = format!("ws://{}", state.address)
628        .into_client_request()
629        .map_err(|source| NodeError {
630            kind: NodeErrorType::BuildingConnectionRequest,
631            source: Some(Box::new(source)),
632        })?;
633    let headers = request.headers_mut();
634    headers.insert("Authorization", state.authorization.parse().unwrap());
635    headers.insert("User-Id", state.user_id.get().into());
636
637    if state.resume.is_some() {
638        headers.insert("Resume-Key", state.address.to_string().parse().unwrap());
639    }
640
641    Ok(request)
642}
643
644async fn reconnect(
645    config: &NodeConfig,
646) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, NodeError> {
647    let (mut stream, res) = backoff(config).await?;
648
649    let headers = res.headers();
650
651    if let Some(resume) = config.resume.as_ref() {
652        let header = HeaderName::from_static("session-resumed");
653
654        if let Some(value) = headers.get(header) {
655            if value.as_bytes() == b"false" {
656                tracing::debug!("session to node {} didn't resume", config.address);
657
658                let payload = serde_json::json!({
659                    "op": "configureResuming",
660                    "key": config.address,
661                    "timeout": resume.timeout,
662                });
663                let msg = Message::Text(serde_json::to_string(&payload).unwrap());
664
665                stream.send(msg).await.unwrap();
666            } else {
667                tracing::debug!("session to {} resumed", config.address);
668            }
669        }
670    }
671
672    Ok(stream)
673}
674
675async fn backoff(
676    config: &NodeConfig,
677) -> Result<
678    (
679        WebSocketStream<MaybeTlsStream<TcpStream>>,
680        Response<Option<Vec<u8>>>,
681    ),
682    NodeError,
683> {
684    let mut seconds = 1;
685
686    loop {
687        let request = connect_request(config)?;
688
689        match tokio_tungstenite::connect_async(request).await {
690            Ok((stream, response)) => return Ok((stream, response)),
691            Err(source) => {
692                tracing::warn!("failed to connect to node {source}: {:?}", config.address);
693
694                if matches!(&source, TungsteniteError::Http(resp) if resp.status() == StatusCode::UNAUTHORIZED)
695                {
696                    return Err(NodeError {
697                        kind: NodeErrorType::Unauthorized {
698                            address: config.address,
699                            authorization: config.authorization.clone(),
700                        },
701                        source: None,
702                    });
703                }
704
705                if seconds > 64 {
706                    tracing::debug!("no longer trying to connect to node {}", config.address);
707
708                    return Err(NodeError {
709                        kind: NodeErrorType::Connecting,
710                        source: Some(Box::new(source)),
711                    });
712                }
713
714                tracing::debug!(
715                    "waiting {seconds} seconds before attempting to connect to node {} again",
716                    config.address,
717                );
718                tokio_time::sleep(Duration::from_secs(seconds)).await;
719
720                seconds *= 2;
721
722                continue;
723            }
724        }
725    }
726}
727
#[cfg(test)]
mod tests {
    use super::{Node, NodeConfig, NodeError, NodeErrorType, Resume};
    use static_assertions::{assert_fields, assert_impl_all};
    use std::{
        error::Error,
        fmt::Debug,
        net::{Ipv4Addr, SocketAddr, SocketAddrV4},
    };
    use twilight_model::id::Id;

    // Node
    assert_impl_all!(Node: Debug, Send, Sync);

    // NodeConfig
    assert_fields!(NodeConfig: address, authorization, resume, user_id);
    assert_impl_all!(NodeConfig: Clone, Debug, Send, Sync);

    // NodeError / NodeErrorType
    assert_impl_all!(NodeError: Error, Send, Sync);
    assert_fields!(NodeErrorType::SerializingMessage: message);
    assert_fields!(NodeErrorType::Unauthorized: address, authorization);
    assert_impl_all!(NodeErrorType: Debug, Send, Sync);

    // Resume
    assert_fields!(Resume: timeout);
    assert_impl_all!(Resume: Clone, Debug, Default, Eq, PartialEq, Send, Sync);

    /// `NodeConfig`'s debug output shows the authorization as `<redacted>`.
    #[test]
    fn node_config_debug() {
        let localhost = Ipv4Addr::new(127, 0, 0, 1);
        let config = NodeConfig {
            address: SocketAddr::V4(SocketAddrV4::new(localhost, 1312)),
            authorization: "some auth".to_owned(),
            resume: None,
            user_id: Id::new(123),
        };

        let rendered = format!("{config:?}");

        assert!(rendered.contains("authorization: <redacted>"));
    }
}