Skip to main content

snapcast_server/
session.rs

1//! Client session management — binary protocol server for snapclients.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use anyhow::{Context, Result};
7use snapcast_proto::MessageType;
8use snapcast_proto::message::base::BaseMessage;
9use snapcast_proto::message::codec_header::CodecHeader;
10use snapcast_proto::message::factory::{self, MessagePayload, TypedMessage};
11use snapcast_proto::message::server_settings::ServerSettings;
12use snapcast_proto::message::time::Time;
13use snapcast_proto::message::wire_chunk::WireChunk;
14use snapcast_proto::types::Timeval;
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::net::{TcpListener, TcpStream};
17use tokio::sync::{Mutex, broadcast, mpsc};
18
19use crate::ClientSettingsUpdate;
20use crate::ServerEvent;
21use crate::WireChunkData;
22use crate::time::now_usec;
23
/// Snapshot of one streaming client as announced in its Hello.
///
/// Entries stay in the server's client table after the connection drops; only
/// `connected` flips to `false`.
#[derive(Clone, Debug)]
pub struct ClientInfo {
    /// Client-chosen unique identifier taken from the Hello message.
    pub id: String,
    /// Host name reported by the client.
    pub host_name: String,
    /// MAC address reported by the client.
    pub mac: String,
    /// `true` while a session for this client is live.
    pub connected: bool,
}
36
37/// Manages all streaming client sessions.
38pub struct SessionServer {
39    port: u16,
40    buffer_ms: i32,
41    auth: Option<Arc<dyn crate::auth::AuthValidator>>,
42    clients: Arc<Mutex<HashMap<String, ClientInfo>>>,
43    settings_senders: Arc<Mutex<HashMap<String, mpsc::Sender<ClientSettingsUpdate>>>>,
44    #[cfg(feature = "custom-protocol")]
45    custom_senders: Arc<Mutex<HashMap<String, mpsc::Sender<CustomOutbound>>>>,
46}
47
/// A custom binary-protocol message queued for delivery to one client.
#[cfg(feature = "custom-protocol")]
#[derive(Clone, Debug)]
pub struct CustomOutbound {
    /// Wire message type ID (custom types are 9 and above).
    pub type_id: u16,
    /// Opaque payload bytes, sent as-is.
    pub payload: Vec<u8>,
}
57
58impl SessionServer {
59    /// Create a new session server.
60    pub fn new(
61        port: u16,
62        buffer_ms: i32,
63        auth: Option<Arc<dyn crate::auth::AuthValidator>>,
64    ) -> Self {
65        Self {
66            port,
67            buffer_ms,
68            auth,
69            clients: Arc::new(Mutex::new(HashMap::new())),
70            settings_senders: Arc::new(Mutex::new(HashMap::new())),
71            #[cfg(feature = "custom-protocol")]
72            custom_senders: Arc::new(Mutex::new(HashMap::new())),
73        }
74    }
75
76    /// Push a settings update to a specific streaming client.
77    pub async fn push_settings(&self, update: ClientSettingsUpdate) {
78        let senders = self.settings_senders.lock().await;
79        if let Some(tx) = senders.get(&update.client_id) {
80            let _ = tx.send(update).await;
81        }
82    }
83
84    /// Run the session server — accepts connections and spawns per-client tasks.
85    pub async fn run(
86        &self,
87        chunk_rx: broadcast::Sender<WireChunkData>,
88        codec: String,
89        codec_header: Vec<u8>,
90        event_tx: mpsc::Sender<ServerEvent>,
91    ) -> Result<()> {
92        let listener = TcpListener::bind(format!("0.0.0.0:{}", self.port)).await?;
93        tracing::info!(port = self.port, "Stream server listening");
94
95        loop {
96            let (stream, peer) = listener.accept().await?;
97            tracing::info!(%peer, "Client connecting");
98
99            let chunk_sub = chunk_rx.subscribe();
100            let clients = Arc::clone(&self.clients);
101            let settings_senders = Arc::clone(&self.settings_senders);
102            #[cfg(feature = "custom-protocol")]
103            let custom_senders = Arc::clone(&self.custom_senders);
104            let event_tx = event_tx.clone();
105            let buffer_ms = self.buffer_ms;
106            let auth = self.auth.clone();
107            let codec = codec.clone();
108            let codec_header = codec_header.clone();
109
110            tokio::spawn(async move {
111                let (settings_tx, settings_rx) = mpsc::channel(16);
112                #[cfg(feature = "custom-protocol")]
113                let (custom_tx, custom_rx) = mpsc::channel(64);
114                let result = handle_client(
115                    stream,
116                    chunk_sub,
117                    settings_rx,
118                    #[cfg(feature = "custom-protocol")]
119                    custom_rx,
120                    &clients,
121                    &settings_senders,
122                    #[cfg(feature = "custom-protocol")]
123                    &custom_senders,
124                    settings_tx,
125                    #[cfg(feature = "custom-protocol")]
126                    custom_tx,
127                    event_tx,
128                    auth.as_deref(),
129                    buffer_ms,
130                    &codec,
131                    &codec_header,
132                )
133                .await;
134                if let Err(e) = result {
135                    tracing::debug!(%peer, error = %e, "Client session ended");
136                }
137            });
138        }
139    }
140
141    /// Get list of connected clients.
142    pub async fn connected_clients(&self) -> Vec<ClientInfo> {
143        self.clients
144            .lock()
145            .await
146            .values()
147            .filter(|c| c.connected)
148            .cloned()
149            .collect()
150    }
151
152    /// Send a custom binary protocol message to a specific client.
153    #[cfg(feature = "custom-protocol")]
154    pub async fn send_custom(&self, client_id: &str, type_id: u16, payload: Vec<u8>) {
155        let senders = self.custom_senders.lock().await;
156        if let Some(tx) = senders.get(client_id) {
157            let _ = tx.send(CustomOutbound { type_id, payload }).await;
158        }
159    }
160}
161
162#[allow(clippy::too_many_arguments)]
163async fn handle_client(
164    mut stream: TcpStream,
165    chunk_rx: broadcast::Receiver<WireChunkData>,
166    settings_rx: mpsc::Receiver<ClientSettingsUpdate>,
167    #[cfg(feature = "custom-protocol")] custom_rx: mpsc::Receiver<CustomOutbound>,
168    clients: &Mutex<HashMap<String, ClientInfo>>,
169    settings_senders: &Mutex<HashMap<String, mpsc::Sender<ClientSettingsUpdate>>>,
170    #[cfg(feature = "custom-protocol")] custom_senders: &Mutex<
171        HashMap<String, mpsc::Sender<CustomOutbound>>,
172    >,
173    settings_tx: mpsc::Sender<ClientSettingsUpdate>,
174    #[cfg(feature = "custom-protocol")] custom_tx: mpsc::Sender<CustomOutbound>,
175    event_tx: mpsc::Sender<ServerEvent>,
176    auth: Option<&dyn crate::auth::AuthValidator>,
177    buffer_ms: i32,
178    codec: &str,
179    codec_header: &[u8],
180) -> Result<()> {
181    // 1. Read Hello
182    let hello_msg = read_frame_from(&mut stream).await?;
183    let hello = match hello_msg.payload {
184        MessagePayload::Hello(h) => h,
185        _ => anyhow::bail!("expected Hello, got {:?}", hello_msg.base.msg_type),
186    };
187
188    let client_id = hello.id.clone();
189    tracing::info!(id = %client_id, name = %hello.host_name, mac = %hello.mac, "Client hello");
190
191    // 1b. Validate auth if required
192    if let Some(validator) = auth {
193        let auth_result = match &hello.auth {
194            Some(a) => validator.validate(&a.scheme, &a.param),
195            None => Err(crate::auth::AuthError::Unauthorized(
196                "Authentication required".into(),
197            )),
198        };
199        match auth_result {
200            Ok(result) => {
201                if !result
202                    .permissions
203                    .iter()
204                    .any(|p| p == crate::auth::PERM_STREAMING)
205                {
206                    let err = snapcast_proto::message::error::Error {
207                        code: 403,
208                        message: "Forbidden".into(),
209                        error: "Permission 'Streaming' missing".into(),
210                    };
211                    send_msg(&mut stream, MessageType::Error, &MessagePayload::Error(err)).await?;
212                    anyhow::bail!("Client {client_id}: missing Streaming permission");
213                }
214                tracing::info!(id = %client_id, user = %result.username, "Authenticated");
215            }
216            Err(e) => {
217                let err = snapcast_proto::message::error::Error {
218                    code: e.code() as u32,
219                    message: e.message().to_string(),
220                    error: e.message().to_string(),
221                };
222                send_msg(&mut stream, MessageType::Error, &MessagePayload::Error(err)).await?;
223                anyhow::bail!("Client {client_id}: {e}");
224            }
225        }
226    }
227
228    // Register client + settings channel
229    {
230        clients.lock().await.insert(
231            client_id.clone(),
232            ClientInfo {
233                id: client_id.clone(),
234                host_name: hello.host_name.clone(),
235                mac: hello.mac.clone(),
236                connected: true,
237            },
238        );
239        settings_senders
240            .lock()
241            .await
242            .insert(client_id.clone(), settings_tx);
243        #[cfg(feature = "custom-protocol")]
244        custom_senders
245            .lock()
246            .await
247            .insert(client_id.clone(), custom_tx);
248    }
249
250    let _ = event_tx
251        .send(ServerEvent::ClientConnected {
252            id: client_id.clone(),
253            name: hello.host_name.clone(),
254            mac: hello.mac.clone(),
255        })
256        .await;
257
258    // 2. Send ServerSettings
259    let ss = ServerSettings {
260        buffer_ms,
261        latency: 0,
262        volume: 100,
263        muted: false,
264    };
265    send_msg(
266        &mut stream,
267        MessageType::ServerSettings,
268        &MessagePayload::ServerSettings(ss),
269    )
270    .await?;
271
272    // 3. Send CodecHeader
273    let ch = CodecHeader {
274        codec: codec.to_string(),
275        payload: codec_header.to_vec(),
276    };
277    send_msg(
278        &mut stream,
279        MessageType::CodecHeader,
280        &MessagePayload::CodecHeader(ch),
281    )
282    .await?;
283
284    // 4. Main loop
285    let result = session_loop(
286        &mut stream,
287        chunk_rx,
288        settings_rx,
289        #[cfg(feature = "custom-protocol")]
290        custom_rx,
291        #[cfg(feature = "custom-protocol")]
292        event_tx.clone(),
293        #[cfg(feature = "custom-protocol")]
294        client_id.clone(),
295    )
296    .await;
297
298    // Cleanup
299    {
300        let mut map = clients.lock().await;
301        if let Some(c) = map.get_mut(&client_id) {
302            c.connected = false;
303        }
304    }
305    settings_senders.lock().await.remove(&client_id);
306    #[cfg(feature = "custom-protocol")]
307    custom_senders.lock().await.remove(&client_id);
308    let _ = event_tx
309        .send(ServerEvent::ClientDisconnected { id: client_id })
310        .await;
311
312    result
313}
314
315async fn session_loop(
316    stream: &mut TcpStream,
317    mut chunk_rx: broadcast::Receiver<WireChunkData>,
318    mut settings_rx: mpsc::Receiver<ClientSettingsUpdate>,
319    #[cfg(feature = "custom-protocol")] mut custom_rx: mpsc::Receiver<CustomOutbound>,
320    #[cfg(feature = "custom-protocol")] event_tx: mpsc::Sender<ServerEvent>,
321    #[cfg(feature = "custom-protocol")] client_id: String,
322) -> Result<()> {
323    let (mut reader, mut writer) = stream.split();
324
325    #[cfg(not(feature = "custom-protocol"))]
326    let (mut custom_rx, _event_tx, _client_id): (mpsc::Receiver<()>, Option<()>, String) = {
327        let (_tx, rx) = mpsc::channel(1);
328        (rx, None, String::new())
329    };
330
331    loop {
332        tokio::select! {
333            chunk = chunk_rx.recv() => {
334                let chunk = chunk.context("broadcast closed")?;
335                let ts_usec = chunk.timestamp_usec;
336                let wc = WireChunk {
337                    timestamp: Timeval::from_usec(ts_usec),
338                    payload: chunk.data,
339                };
340                let frame = serialize_msg(MessageType::WireChunk, &MessagePayload::WireChunk(wc), 0)?;
341                writer.write_all(&frame).await.context("write chunk")?;
342            }
343            msg = read_frame_from(&mut reader) => {
344                let msg = msg?;
345                match msg.payload {
346                    MessagePayload::Time(t) => {
347                        let response = Time { latency: t.latency };
348                        let frame = serialize_msg(MessageType::Time, &MessagePayload::Time(response), msg.base.id)?;
349                        writer.write_all(&frame).await.context("write time")?;
350                    }
351                    #[cfg(feature = "custom-protocol")]
352                    MessagePayload::Custom(payload) => {
353                        if let MessageType::Custom(type_id) = msg.base.msg_type {
354                            let _ = event_tx.send(ServerEvent::CustomMessage {
355                                client_id: client_id.clone(),
356                                message: snapcast_proto::CustomMessage::new(type_id, payload),
357                            }).await;
358                        }
359                    }
360                    _ => {}
361                }
362            }
363            update = settings_rx.recv() => {
364                let Some(update) = update else { continue };
365                let ss = ServerSettings {
366                    buffer_ms: update.buffer_ms,
367                    latency: update.latency,
368                    volume: update.volume,
369                    muted: update.muted,
370                };
371                let frame = serialize_msg(
372                    MessageType::ServerSettings,
373                    &MessagePayload::ServerSettings(ss),
374                    0,
375                )?;
376                writer.write_all(&frame).await.context("write settings")?;
377                tracing::debug!(volume = update.volume, latency = update.latency, "Pushed settings to client");
378            }
379            outbound = custom_rx.recv() => {
380                #[cfg(feature = "custom-protocol")]
381                if let Some(msg) = outbound {
382                    let frame = serialize_msg(
383                        MessageType::Custom(msg.type_id),
384                        &MessagePayload::Custom(msg.payload),
385                        0,
386                    )?;
387                    writer.write_all(&frame).await.context("write custom")?;
388                }
389                #[cfg(not(feature = "custom-protocol"))]
390                let _ = outbound;
391            }
392        }
393    }
394}
395
396fn serialize_msg(
397    msg_type: MessageType,
398    payload: &MessagePayload,
399    refers_to: u16,
400) -> Result<Vec<u8>> {
401    let mut base = BaseMessage {
402        msg_type,
403        id: 0,
404        refers_to,
405        sent: now_timeval(),
406        received: Timeval::default(),
407        size: 0,
408    };
409    factory::serialize(&mut base, payload).map_err(|e| anyhow::anyhow!("serialize: {e}"))
410}
411
412async fn send_msg(
413    stream: &mut TcpStream,
414    msg_type: MessageType,
415    payload: &MessagePayload,
416) -> Result<()> {
417    let frame = serialize_msg(msg_type, payload, 0)?;
418    stream.write_all(&frame).await.context("write message")
419}
420
421async fn read_frame_from<R: AsyncReadExt + Unpin>(reader: &mut R) -> Result<TypedMessage> {
422    let mut header_buf = [0u8; BaseMessage::HEADER_SIZE];
423    reader
424        .read_exact(&mut header_buf)
425        .await
426        .context("read header")?;
427    let mut base =
428        BaseMessage::read_from(&mut &header_buf[..]).map_err(|e| anyhow::anyhow!("parse: {e}"))?;
429    base.received = now_timeval();
430    let mut payload_buf = vec![0u8; base.size as usize];
431    if !payload_buf.is_empty() {
432        reader
433            .read_exact(&mut payload_buf)
434            .await
435            .context("read payload")?;
436    }
437    factory::deserialize(base, &payload_buf).map_err(|e| anyhow::anyhow!("deserialize: {e}"))
438}
439
440fn now_timeval() -> Timeval {
441    Timeval::from_usec(now_usec())
442}