streamduck_core/socket/
mod.rs

1//! Socket related definitions
2
3use std::io::Write;
4use tokio::io::{AsyncWrite, AsyncWriteExt};
5use std::ops::Deref;
6use std::sync::Arc;
7use serde::{Deserialize, Serialize};
8use serde::de::{DeserializeOwned, Error};
9use serde_json::Value;
10use tokio::sync::{Mutex, Notify, RwLock};
11use async_recursion::async_recursion;
12use crate::modules::events::SDGlobalEvent;
13
14/// Type for listener's socket handles
15pub type SocketHandle<'a> = &'a mut (dyn AsyncWrite + Unpin + Send);
16
17/// Boxed socket listener
18pub type UniqueSocketListener = Arc<dyn SocketListener + Send + Sync>;
19
20/// Socket packet
21#[derive(Serialize, Deserialize, Debug, Clone)]
22pub struct SocketPacket {
23    /// Data type
24    pub ty: String,
25    /// Possible requester, for letting client understand what response is for which request
26    pub requester: Option<String>,
27    /// Parse-able data
28    pub data: Option<Value>
29}
30
31/// Socket listener, something that can listen in to socket connections
32#[async_trait]
33pub trait SocketListener {
34    /// Called when message is received, handle can be used to send back a response
35    async fn message(&self, socket: SocketHandle<'_>, packet: SocketPacket);
36}
37
/// Associates a Rust type with the packet type name used on the wire
///
/// `NAME` is written into [`SocketPacket::ty`] when sending and compared against it when parsing.
pub trait SocketData {
    /// Packet type name for this request/response type
    const NAME: &'static str;
}
43
44/// Attempts to parse socket data into a specified type
45pub fn parse_packet_to_data<T: SocketData + DeserializeOwned>(packet: &SocketPacket) -> Result<T, serde_json::Error> {
46    if packet.ty == T::NAME {
47        if let Some(data) = &packet.data {
48            Ok(serde_json::from_value(data.clone())?)
49        } else {
50            Err(serde_json::Error::custom("Missing data"))
51        }
52    } else {
53        Err(serde_json::Error::custom("Invalid packet"))
54    }
55}
56
57/// Checks if packet is of a certain type, for requests without any parameters
58pub fn check_packet_for_data<T: SocketData>(packet: &SocketPacket) -> bool {
59    packet.ty == T::NAME
60}
61
62/// Writes bytes in chunks
63pub async fn write_in_chunks(handle: SocketHandle<'_>, data: String) -> Result<(), SocketError> {
64    for chunk in data.into_bytes().chunks(250) {
65        handle.write(chunk).await?;
66    }
67
68    Ok(())
69}
70
71/// Writes bytes in chunks with sync IO
72pub fn write_in_chunks_sync(handle: &mut dyn Write, data: String) -> Result<(), SocketError> {
73    for chunk in data.into_bytes().chunks(250) {
74        handle.write(chunk)?;
75    }
76
77    Ok(())
78}
79
80/// Sends a packet with included requester ID from previous package
81pub async fn send_packet<T: SocketData + Serialize>(handle: SocketHandle<'_>, previous_packet: &SocketPacket, data: &T) -> Result<(), SocketError> {
82    let packet = SocketPacket {
83        ty: T::NAME.to_string(),
84        requester: previous_packet.requester.clone(),
85        data: Some(serde_json::to_value(data)?)
86    };
87
88    send_packet_as_is(handle, packet).await?;
89
90    Ok(())
91}
92
93/// Sends a packet with included requester ID from previous package with sync IO
94pub async fn send_packet_sync<T: SocketData + Serialize>(handle: &mut dyn Write, previous_packet: &SocketPacket, data: &T) -> Result<(), SocketError> {
95    let packet = SocketPacket {
96        ty: T::NAME.to_string(),
97        requester: previous_packet.requester.clone(),
98        data: Some(serde_json::to_value(data)?)
99    };
100
101    send_packet_as_is_sync(handle, packet)?;
102
103    Ok(())
104}
105
106/// Sends a packet with included requester ID from previous package
107pub async fn send_packet_with_requester<T: SocketData + Serialize>(handle: SocketHandle<'_>, requester: &str, data: &T) -> Result<(), SocketError> {
108    let packet = SocketPacket {
109        ty: T::NAME.to_string(),
110        requester: Some(requester.to_string()),
111        data: Some(serde_json::to_value(data)?)
112    };
113
114    send_packet_as_is(handle, packet).await?;
115
116    Ok(())
117}
118
119/// Sends a packet with included requester ID from previous package with sync IO
120pub fn send_packet_with_requester_sync<T: SocketData + Serialize>(handle: &mut dyn Write, requester: &str, data: &T) -> Result<(), SocketError> {
121    let packet = SocketPacket {
122        ty: T::NAME.to_string(),
123        requester: Some(requester.to_string()),
124        data: Some(serde_json::to_value(data)?)
125    };
126
127    send_packet_as_is_sync(handle, packet)?;
128
129    Ok(())
130}
131
132/// Sends a packet with included requester ID from previous package, without data
133pub async fn send_no_data_packet_with_requester<T: SocketData>(handle: SocketHandle<'_>, requester: &str) -> Result<(), SocketError> {
134    let packet = SocketPacket {
135        ty: T::NAME.to_string(),
136        requester: Some(requester.to_string()),
137        data: None
138    };
139
140    send_packet_as_is(handle, packet).await?;
141
142    Ok(())
143}
144
145/// Sends a packet with included requester ID from previous package, without data, with sync IO
146pub fn send_no_data_packet_with_requester_sync<T: SocketData>(handle: &mut dyn Write, requester: &str) -> Result<(), SocketError> {
147    let packet = SocketPacket {
148        ty: T::NAME.to_string(),
149        requester: Some(requester.to_string()),
150        data: None
151    };
152
153    send_packet_as_is_sync(handle, packet)?;
154
155    Ok(())
156}
157
158/// Sends a packet as is
159pub async fn send_packet_as_is(handle: SocketHandle<'_>, data: SocketPacket) -> Result<(), SocketError> {
160    write_in_chunks(handle, format!("{}\u{0004}", serde_json::to_string(&data)?)).await?;
161
162    Ok(())
163}
164
165/// Sends a packet as is with sync IO
166pub fn send_packet_as_is_sync(handle: &mut dyn Write, data: SocketPacket) -> Result<(), SocketError> {
167    write_in_chunks_sync(handle, format!("{}\u{0004}", serde_json::to_string(&data)?))?;
168
169    Ok(())
170}
171
172/// Enumeration of various errors during sending and parsing packets
173#[derive(Debug)]
174pub enum SocketError {
175    /// Failed to (de)serialize
176    SerdeError(serde_json::Error),
177    /// Failed to write to the socket
178    WriteError(std::io::Error),
179}
180
181impl From<serde_json::Error> for SocketError {
182    fn from(err: serde_json::Error) -> Self {
183        SocketError::SerdeError(err)
184    }
185}
186
187impl From<std::io::Error> for SocketError {
188    fn from(err: std::io::Error) -> Self {
189        SocketError::WriteError(err)
190    }
191}
192
193/// Manager of socket listeners
194pub struct SocketManager {
195    listeners: RwLock<Vec<UniqueSocketListener>>,
196    pools: RwLock<Vec<Arc<SocketPool>>>
197}
198
199impl SocketManager {
200    /// Creates a new socket manager
201    pub fn new() -> Arc<SocketManager> {
202        Arc::new(SocketManager {
203            listeners: Default::default(),
204            pools: Default::default()
205        })
206    }
207
208    /// Adds socket listener to manager
209    pub async fn add_listener(&self, listener: UniqueSocketListener) {
210        self.listeners.write().await.push(listener);
211    }
212
213    /// Sends a message to all listeners, for socket implementation to trigger all listeners when message is received
214    pub async fn received_message(&self, handle: SocketHandle<'_>, packet: SocketPacket) {
215        for listener in self.listeners.read().await.deref() {
216            listener.message(handle, packet.clone()).await;
217        }
218    }
219
220    /// Creates a new message pool
221    pub async fn get_pool(&self) -> Arc<SocketPool> {
222        let mut pools = self.pools.write().await;
223
224        let new_pool = Arc::new(SocketPool {
225            messages: Mutex::new(vec![]),
226            notification: Default::default(),
227            is_open: RwLock::new(true)
228        });
229
230        pools.push(new_pool.clone());
231
232        new_pool
233    }
234
235    /// For listeners or modules to send messages to all active socket connections, for event purposes
236    pub async fn send_message(&self, packet: SocketPacket) {
237        let mut pools = self.pools.write().await;
238
239        let mut pools_to_delete = vec![];
240
241        for (index, pool) in pools.iter().enumerate() {
242            if *pool.is_open.read().await {
243                pool.add_message(packet.clone()).await
244            } else {
245                pools_to_delete.push(index);
246            }
247        }
248
249        for pool_to_delete in pools_to_delete {
250            pools.remove(pool_to_delete);
251        }
252    }
253}
254
255/// Puts together an event packet and sends it
256pub async fn send_event_to_socket(socket_manager: &Arc<SocketManager>, event: SDGlobalEvent) {
257    socket_manager.send_message(SocketPacket {
258        ty: "event".to_string(),
259        requester: None,
260        data: Some(serde_json::to_value(event).unwrap())
261    }).await
262}
263
264/// Pool of messages for socket implementations
265pub struct SocketPool {
266    messages: Mutex<Vec<SocketPacket>>,
267    notification: Notify,
268    is_open: RwLock<bool>
269}
270
271impl SocketPool {
272    /// Puts message into the pool
273    pub async fn add_message(&self, message: SocketPacket) {
274        let mut messages = self.messages.lock().await;
275        messages.insert(0, message);
276        self.notification.notify_waiters();
277    }
278
279    /// Retrieves a message, will block if pool is currently empty
280    #[async_recursion]
281    pub async fn take_message(&self) -> SocketPacket {
282        // Checking if message exists before waiting
283        {
284            let mut guard = self.messages.lock().await;
285            if !guard.is_empty() {
286                return guard.pop().unwrap();
287            }
288        }
289
290        // Waiting for wake-up if empty pool
291        self.notification.notified().await;
292        let mut guard = self.messages.lock().await;
293
294        if let Some(packet) = guard.pop() {
295            packet
296        } else {
297            drop(guard);
298            self.take_message().await
299        }
300    }
301
302    /// If the pool is still open
303    pub async fn is_open(&self) -> bool {
304        *self.is_open.read().await
305    }
306
307    /// CLoses the pool from receiving any packets
308    pub async fn close(&self) {
309        *self.is_open.write().await = false;
310    }
311}