1use crate::{
21 model::{IncomingEvent, Opcode, OutgoingEvent, PlayerUpdate, Stats, StatsCpu, StatsMemory},
22 player::PlayerManager,
23};
24use futures_util::{
25 lock::BiLock,
26 sink::SinkExt,
27 stream::{Stream, StreamExt},
28};
29use http::{header::HeaderName, Request, Response, StatusCode};
30use std::{
31 error::Error,
32 fmt::{Debug, Display, Formatter, Result as FmtResult},
33 net::SocketAddr,
34 pin::Pin,
35 task::{Context, Poll},
36 time::Duration,
37};
38use tokio::{
39 net::TcpStream,
40 sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
41 time as tokio_time,
42};
43use tokio_tungstenite::{
44 tungstenite::{client::IntoClientRequest, Error as TungsteniteError, Message},
45 MaybeTlsStream, WebSocketStream,
46};
47use twilight_model::id::{marker::UserMarker, Id};
48
/// An error that occurred while connecting to, or exchanging messages with,
/// a node.
///
/// Pairs a [`NodeErrorType`] describing what failed with an optional
/// underlying error that caused it.
#[derive(Debug)]
pub struct NodeError {
    /// Category of the failure.
    kind: NodeErrorType,
    /// Underlying cause of the failure, if one was recorded.
    source: Option<Box<dyn Error + Send + Sync>>,
}
56
57impl NodeError {
58 #[must_use = "retrieving the type has no effect if left unused"]
60 pub const fn kind(&self) -> &NodeErrorType {
61 &self.kind
62 }
63
64 #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
66 pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
67 self.source
68 }
69
70 #[must_use = "consuming the error into its parts has no effect if left unused"]
72 pub fn into_parts(self) -> (NodeErrorType, Option<Box<dyn Error + Send + Sync>>) {
73 (self.kind, self.source)
74 }
75}
76
77impl Display for NodeError {
78 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
79 match &self.kind {
80 NodeErrorType::BuildingConnectionRequest { .. } => {
81 f.write_str("failed to build connection request")
82 }
83 NodeErrorType::Connecting { .. } => f.write_str("Failed to connect to the node"),
84 NodeErrorType::SerializingMessage { .. } => {
85 f.write_str("failed to serialize outgoing message as json")
86 }
87 NodeErrorType::Unauthorized { address, .. } => {
88 f.write_str("the authorization used to connect to node ")?;
89 Display::fmt(address, f)?;
90
91 f.write_str(" is invalid")
92 }
93 }
94 }
95}
96
97impl Error for NodeError {
98 fn source(&self) -> Option<&(dyn Error + 'static)> {
99 self.source
100 .as_ref()
101 .map(|source| &**source as &(dyn Error + 'static))
102 }
103}
104
105#[derive(Debug)]
107#[non_exhaustive]
108pub enum NodeErrorType {
109 BuildingConnectionRequest,
111 Connecting,
113 SerializingMessage {
115 message: OutgoingEvent,
117 },
118 Unauthorized {
120 address: SocketAddr,
122 authorization: String,
124 },
125}
126
/// An error that occurred while sending an event through a [`NodeSender`].
#[derive(Debug)]
pub struct NodeSenderError {
    /// Category of the failure.
    kind: NodeSenderErrorType,
    /// Underlying cause of the failure, if one was recorded.
    source: Option<Box<dyn Error + Send + Sync>>,
}
133
134impl NodeSenderError {
135 pub const fn kind(&self) -> &NodeSenderErrorType {
137 &self.kind
138 }
139
140 pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
142 self.source
143 }
144
145 #[must_use = "consuming the error into its parts has no effect if left unused"]
147 pub fn into_parts(self) -> (NodeSenderErrorType, Option<Box<dyn Error + Send + Sync>>) {
148 (self.kind, self.source)
149 }
150}
151
152impl Display for NodeSenderError {
153 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
154 match &self.kind {
155 NodeSenderErrorType::Sending => f.write_str("failed to send over channel"),
156 }
157 }
158}
159
160impl Error for NodeSenderError {
161 fn source(&self) -> Option<&(dyn Error + 'static)> {
162 self.source
163 .as_ref()
164 .map(|source| &**source as &(dyn Error + 'static))
165 }
166}
167
/// Type of [`NodeSenderError`] that occurred.
#[derive(Debug)]
#[non_exhaustive]
pub enum NodeSenderErrorType {
    /// Sending the event over the channel failed.
    Sending,
}
175
/// Stream of [`IncomingEvent`]s received from a node.
///
/// Returned alongside a [`Node`] by [`Node::connect`].
pub struct IncomingEvents {
    /// Receiving half of the channel the connection loop forwards events to.
    inner: UnboundedReceiver<IncomingEvent>,
}
180
181impl IncomingEvents {
182 pub fn close(&mut self) {
184 self.inner.close();
185 }
186}
187
188impl Stream for IncomingEvents {
189 type Item = IncomingEvent;
190
191 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
192 self.inner.poll_recv(cx)
193 }
194}
195
/// Handle for sending [`OutgoingEvent`]s to a node's connection loop.
///
/// Obtained from [`Node::sender`].
pub struct NodeSender {
    /// Sending half of the channel read by the connection loop.
    inner: UnboundedSender<OutgoingEvent>,
}
200
201impl NodeSender {
202 pub fn is_closed(&self) -> bool {
204 self.inner.is_closed()
205 }
206
207 pub fn send(&self, msg: OutgoingEvent) -> Result<(), NodeSenderError> {
218 self.inner.send(msg).map_err(|source| NodeSenderError {
219 kind: NodeSenderErrorType::Sending,
220 source: Some(Box::new(source)),
221 })
222 }
223}
224
/// Configuration used to connect to a node.
///
/// `Debug` is implemented manually so that `authorization` is redacted.
#[derive(Clone, Eq, PartialEq)]
#[non_exhaustive]
pub struct NodeConfig {
    /// Address of the node; the websocket URL is built as `ws://{address}`.
    pub address: SocketAddr,
    /// Value sent in the `Authorization` header when connecting.
    pub authorization: String,
    /// Session resuming configuration; when `Some`, a `Resume-Key` header is
    /// sent when connecting.
    pub resume: Option<Resume>,
    /// User ID sent in the `User-Id` header when connecting.
    pub user_id: Id<UserMarker>,
}
241
242impl Debug for NodeConfig {
243 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
244 struct Redacted;
248
249 impl Debug for Redacted {
250 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
251 f.write_str("<redacted>")
252 }
253 }
254
255 f.debug_struct("NodeConfig")
256 .field("address", &self.address)
257 .field("authorization", &Redacted)
258 .field("resume", &self.resume)
259 .field("user_id", &self.user_id)
260 .finish()
261 }
262}
263
/// Configuration for resuming a session with a node.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct Resume {
    /// Resume timeout, in seconds.
    pub timeout: u64,
}

impl Resume {
    /// Creates a resume configuration with a timeout of `seconds` seconds.
    pub const fn new(seconds: u64) -> Self {
        let timeout = seconds;

        Self { timeout }
    }
}

impl Default for Resume {
    /// Creates a resume configuration with a 60 second timeout.
    fn default() -> Self {
        Self::new(60)
    }
}
288
289impl NodeConfig {
290 pub fn new(
298 user_id: Id<UserMarker>,
299 address: impl Into<SocketAddr>,
300 authorization: impl Into<String>,
301 resume: impl Into<Option<Resume>>,
302 ) -> Self {
303 Self::_new(user_id, address.into(), authorization.into(), resume.into())
304 }
305
306 const fn _new(
307 user_id: Id<UserMarker>,
308 address: SocketAddr,
309 authorization: String,
310 resume: Option<Resume>,
311 ) -> Self {
312 Self {
313 address,
314 authorization,
315 resume,
316 user_id,
317 }
318 }
319}
320
/// Handle to a connected node.
///
/// Created with [`Node::connect`], which also spawns the background task that
/// drives the websocket connection.
#[derive(Debug)]
pub struct Node {
    /// Configuration the node was connected with.
    config: NodeConfig,
    /// Sender for events that the connection loop forwards to the node.
    lavalink_tx: UnboundedSender<OutgoingEvent>,
    /// Player manager shared with the connection loop.
    players: PlayerManager,
    /// One half of the lock around the latest statistics; the other half is
    /// held by the connection loop, which writes to it.
    stats: BiLock<Stats>,
}
334
335impl Node {
336 pub async fn connect(
361 config: NodeConfig,
362 players: PlayerManager,
363 ) -> Result<(Self, IncomingEvents), NodeError> {
364 let (bilock_left, bilock_right) = BiLock::new(Stats {
365 cpu: StatsCpu {
366 cores: 0,
367 lavalink_load: 0f64,
368 system_load: 0f64,
369 },
370 frames: None,
371 memory: StatsMemory {
372 allocated: 0,
373 free: 0,
374 used: 0,
375 reservable: 0,
376 },
377 players: 0,
378 playing_players: 0,
379 op: Opcode::Stats,
380 uptime: 0,
381 });
382
383 tracing::debug!("starting connection to {}", config.address);
384
385 let (conn_loop, lavalink_tx, lavalink_rx) =
386 Connection::connect(config.clone(), players.clone(), bilock_right).await?;
387
388 tracing::debug!("started connection to {}", config.address);
389
390 tokio::spawn(conn_loop.run());
391
392 Ok((
393 Self {
394 config,
395 lavalink_tx,
396 players,
397 stats: bilock_left,
398 },
399 IncomingEvents { inner: lavalink_rx },
400 ))
401 }
402
403 pub const fn config(&self) -> &NodeConfig {
405 &self.config
406 }
407
408 pub const fn players(&self) -> &PlayerManager {
410 &self.players
411 }
412
413 pub fn send(&self, event: OutgoingEvent) -> Result<(), NodeSenderError> {
423 self.sender().send(event)
424 }
425
426 pub fn sender(&self) -> NodeSender {
431 NodeSender {
432 inner: self.lavalink_tx.clone(),
433 }
434 }
435
436 pub async fn stats(&self) -> Stats {
438 (*self.stats.lock().await).clone()
439 }
440
441 #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
446 pub async fn penalty(&self) -> i32 {
447 let stats = self.stats.lock().await;
448 let cpu = 1.05f64.powf(100f64 * stats.cpu.system_load) * 10f64 - 10f64;
449
450 let (deficit_frame, null_frame) = (
451 1.03f64
452 .powf(500f64 * (stats.frames.as_ref().map_or(0, |f| f.deficit) as f64 / 3000f64))
453 * 300f64
454 - 300f64,
455 (1.03f64
456 .powf(500f64 * (stats.frames.as_ref().map_or(0, |f| f.nulled) as f64 / 3000f64))
457 * 300f64
458 - 300f64)
459 * 2f64,
460 );
461
462 stats.playing_players as i32 + cpu as i32 + deficit_frame as i32 + null_frame as i32
463 }
464}
465
/// State of the background task that drives the websocket connection to a
/// node.
struct Connection {
    /// Configuration of the node, used when (re)connecting.
    config: NodeConfig,
    /// The websocket connection to the node.
    connection: WebSocketStream<MaybeTlsStream<TcpStream>>,
    /// Events queued by [`Node`]/[`NodeSender`] to be forwarded to the node.
    node_from: UnboundedReceiver<OutgoingEvent>,
    /// Channel used to hand received events to [`IncomingEvents`].
    node_to: UnboundedSender<IncomingEvent>,
    /// Player manager whose players are updated from player update events.
    players: PlayerManager,
    /// Half of the lock around the latest statistics; written on stats events.
    stats: BiLock<Stats>,
}
474
475impl Connection {
476 async fn connect(
477 config: NodeConfig,
478 players: PlayerManager,
479 stats: BiLock<Stats>,
480 ) -> Result<
481 (
482 Self,
483 UnboundedSender<OutgoingEvent>,
484 UnboundedReceiver<IncomingEvent>,
485 ),
486 NodeError,
487 > {
488 let connection = reconnect(&config).await?;
489
490 let (to_node, from_lavalink) = mpsc::unbounded_channel();
491 let (to_lavalink, from_node) = mpsc::unbounded_channel();
492
493 Ok((
494 Self {
495 config,
496 connection,
497 node_from: from_node,
498 node_to: to_node,
499 players,
500 stats,
501 },
502 to_lavalink,
503 from_lavalink,
504 ))
505 }
506
507 async fn run(mut self) -> Result<(), NodeError> {
508 loop {
509 tokio::select! {
510 incoming = self.connection.next() => {
511 if let Some(Ok(incoming)) = incoming {
512 self.incoming(incoming).await?;
513 } else {
514 tracing::debug!("connection to {} closed, reconnecting", self.config.address);
515 self.connection = reconnect(&self.config).await?;
516 }
517 }
518 outgoing = self.node_from.recv() => {
519 if let Some(outgoing) = outgoing {
520 tracing::debug!(
521 "forwarding event to {}: {outgoing:?}",
522 self.config.address,
523 );
524
525 let payload = serde_json::to_string(&outgoing).map_err(|source| NodeError {
526 kind: NodeErrorType::SerializingMessage { message: outgoing },
527 source: Some(Box::new(source)),
528 })?;
529 let msg = Message::Text(payload);
530 self.connection.send(msg).await.unwrap();
531 } else {
532 tracing::debug!("node {} closed, ending connection", self.config.address);
533
534 break;
535 }
536 }
537 }
538 }
539
540 Ok(())
541 }
542
543 async fn incoming(&mut self, incoming: Message) -> Result<bool, NodeError> {
544 tracing::debug!(
545 "received message from {}: {incoming:?}",
546 self.config.address,
547 );
548
549 let text = match incoming {
550 Message::Close(_) => {
551 tracing::debug!("got close, closing connection");
552 let _result = self.connection.send(Message::Close(None)).await;
553
554 return Ok(false);
555 }
556 Message::Ping(data) => {
557 tracing::debug!("got ping, sending pong");
558 let msg = Message::Pong(data);
559
560 let _result = self.connection.send(msg).await;
562
563 return Ok(true);
564 }
565 Message::Text(text) => text,
566 other => {
567 tracing::debug!("got pong or bytes payload: {other:?}");
568
569 return Ok(true);
570 }
571 };
572
573 let Ok(event) = serde_json::from_str(&text) else {
574 tracing::warn!("unknown message from lavalink node: {text}");
575
576 return Ok(true);
577 };
578
579 match &event {
580 IncomingEvent::PlayerUpdate(update) => self.player_update(update)?,
581 IncomingEvent::Stats(stats) => self.stats(stats).await?,
582 _ => {}
583 }
584
585 if !self.node_to.is_closed() {
588 let _result = self.node_to.send(event);
589 }
590
591 Ok(true)
592 }
593
594 fn player_update(&self, update: &PlayerUpdate) -> Result<(), NodeError> {
595 let Some(player) = self.players.get(&update.guild_id) else {
596 tracing::warn!(
597 "invalid player update for guild {}: {update:?}",
598 update.guild_id,
599 );
600
601 return Ok(());
602 };
603
604 player.set_position(update.state.position.unwrap_or(0));
605 player.set_time(update.state.time);
606
607 Ok(())
608 }
609
610 async fn stats(&self, stats: &Stats) -> Result<(), NodeError> {
611 *self.stats.lock().await = stats.clone();
612
613 Ok(())
614 }
615}
616
617impl Drop for Connection {
618 fn drop(&mut self) {
619 self.players
621 .players
622 .retain(|_, v| v.node().config().address != self.config.address);
623 }
624}
625
626fn connect_request(state: &NodeConfig) -> Result<Request<()>, NodeError> {
627 let mut request = format!("ws://{}", state.address)
628 .into_client_request()
629 .map_err(|source| NodeError {
630 kind: NodeErrorType::BuildingConnectionRequest,
631 source: Some(Box::new(source)),
632 })?;
633 let headers = request.headers_mut();
634 headers.insert("Authorization", state.authorization.parse().unwrap());
635 headers.insert("User-Id", state.user_id.get().into());
636
637 if state.resume.is_some() {
638 headers.insert("Resume-Key", state.address.to_string().parse().unwrap());
639 }
640
641 Ok(request)
642}
643
644async fn reconnect(
645 config: &NodeConfig,
646) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, NodeError> {
647 let (mut stream, res) = backoff(config).await?;
648
649 let headers = res.headers();
650
651 if let Some(resume) = config.resume.as_ref() {
652 let header = HeaderName::from_static("session-resumed");
653
654 if let Some(value) = headers.get(header) {
655 if value.as_bytes() == b"false" {
656 tracing::debug!("session to node {} didn't resume", config.address);
657
658 let payload = serde_json::json!({
659 "op": "configureResuming",
660 "key": config.address,
661 "timeout": resume.timeout,
662 });
663 let msg = Message::Text(serde_json::to_string(&payload).unwrap());
664
665 stream.send(msg).await.unwrap();
666 } else {
667 tracing::debug!("session to {} resumed", config.address);
668 }
669 }
670 }
671
672 Ok(stream)
673}
674
675async fn backoff(
676 config: &NodeConfig,
677) -> Result<
678 (
679 WebSocketStream<MaybeTlsStream<TcpStream>>,
680 Response<Option<Vec<u8>>>,
681 ),
682 NodeError,
683> {
684 let mut seconds = 1;
685
686 loop {
687 let request = connect_request(config)?;
688
689 match tokio_tungstenite::connect_async(request).await {
690 Ok((stream, response)) => return Ok((stream, response)),
691 Err(source) => {
692 tracing::warn!("failed to connect to node {source}: {:?}", config.address);
693
694 if matches!(&source, TungsteniteError::Http(resp) if resp.status() == StatusCode::UNAUTHORIZED)
695 {
696 return Err(NodeError {
697 kind: NodeErrorType::Unauthorized {
698 address: config.address,
699 authorization: config.authorization.clone(),
700 },
701 source: None,
702 });
703 }
704
705 if seconds > 64 {
706 tracing::debug!("no longer trying to connect to node {}", config.address);
707
708 return Err(NodeError {
709 kind: NodeErrorType::Connecting,
710 source: Some(Box::new(source)),
711 });
712 }
713
714 tracing::debug!(
715 "waiting {seconds} seconds before attempting to connect to node {} again",
716 config.address,
717 );
718 tokio_time::sleep(Duration::from_secs(seconds)).await;
719
720 seconds *= 2;
721
722 continue;
723 }
724 }
725 }
726}
727
#[cfg(test)]
mod tests {
    use super::{Node, NodeConfig, NodeError, NodeErrorType, Resume};
    use static_assertions::{assert_fields, assert_impl_all};
    use std::{
        error::Error,
        fmt::Debug,
        net::{Ipv4Addr, SocketAddr, SocketAddrV4},
    };
    use twilight_model::id::Id;

    // Compile-time checks of the public fields and trait implementations.
    assert_fields!(NodeConfig: address, authorization, resume, user_id);
    assert_impl_all!(NodeConfig: Clone, Debug, Send, Sync);
    assert_fields!(NodeErrorType::SerializingMessage: message);
    assert_fields!(NodeErrorType::Unauthorized: address, authorization);
    assert_impl_all!(NodeErrorType: Debug, Send, Sync);
    assert_impl_all!(NodeError: Error, Send, Sync);
    assert_impl_all!(Node: Debug, Send, Sync);
    assert_fields!(Resume: timeout);
    assert_impl_all!(Resume: Clone, Debug, Default, Eq, PartialEq, Send, Sync);

    /// The debug output must show the authorization as redacted.
    #[test]
    fn node_config_debug() {
        let localhost = SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 1312);
        let config = NodeConfig {
            address: SocketAddr::V4(localhost),
            authorization: "some auth".to_owned(),
            resume: None,
            user_id: Id::new(123),
        };

        let rendered = format!("{config:?}");

        assert!(rendered.contains("authorization: <redacted>"));
    }
}