Skip to main content

snapcast_server/
session.rs

1//! Client session management — binary protocol server for snapclients.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use anyhow::{Context, Result};
7use snapcast_proto::MessageType;
8use snapcast_proto::message::base::BaseMessage;
9use snapcast_proto::message::codec_header::CodecHeader;
10use snapcast_proto::message::factory::{self, MessagePayload, TypedMessage};
11use snapcast_proto::message::server_settings::ServerSettings;
12use snapcast_proto::message::time::Time;
13use snapcast_proto::message::wire_chunk::WireChunk;
14use snapcast_proto::types::Timeval;
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::net::{TcpListener, TcpStream};
17use tokio::sync::{Mutex, broadcast, mpsc};
18
19use crate::ClientSettingsUpdate;
20use crate::ServerEvent;
21use crate::stream::manager::WireChunkData;
22use crate::time::now_usec;
23
/// Snapshot of a streaming client known to the session server.
#[derive(Clone, Debug)]
pub struct ClientInfo {
    /// Identifier the client announced in its Hello message.
    pub id: String,
    /// Host name reported by the client.
    pub host_name: String,
    /// MAC address reported by the client.
    pub mac: String,
    /// `true` while the client's session is active.
    pub connected: bool,
}
36
37/// Manages all streaming client sessions.
38pub struct SessionServer {
39    port: u16,
40    buffer_ms: i32,
41    auth: Option<Arc<dyn crate::auth::AuthValidator>>,
42    clients: Arc<Mutex<HashMap<String, ClientInfo>>>,
43    settings_senders: Arc<Mutex<HashMap<String, mpsc::Sender<ClientSettingsUpdate>>>>,
44    #[cfg(feature = "custom-protocol")]
45    custom_senders: Arc<Mutex<HashMap<String, mpsc::Sender<CustomOutbound>>>>,
46}
47
/// A custom message queued for delivery to one particular client.
#[cfg(feature = "custom-protocol")]
#[derive(Clone, Debug)]
pub struct CustomOutbound {
    /// Numeric message type ID (9+).
    pub type_id: u16,
    /// Opaque payload bytes, sent as-is.
    pub payload: Vec<u8>,
}
57
58impl SessionServer {
59    /// Create a new session server.
60    pub fn new(
61        port: u16,
62        buffer_ms: i32,
63        auth: Option<Arc<dyn crate::auth::AuthValidator>>,
64    ) -> Self {
65        Self {
66            port,
67            buffer_ms,
68            auth,
69            clients: Arc::new(Mutex::new(HashMap::new())),
70            settings_senders: Arc::new(Mutex::new(HashMap::new())),
71            #[cfg(feature = "custom-protocol")]
72            custom_senders: Arc::new(Mutex::new(HashMap::new())),
73        }
74    }
75
76    /// Push a settings update to a specific streaming client.
77    pub async fn push_settings(&self, update: ClientSettingsUpdate) {
78        let senders = self.settings_senders.lock().await;
79        if let Some(tx) = senders.get(&update.client_id) {
80            let _ = tx.send(update).await;
81        }
82    }
83
84    /// Run the session server — accepts connections and spawns per-client tasks.
85    pub async fn run(
86        &self,
87        chunk_rx: broadcast::Sender<WireChunkData>,
88        codec: String,
89        codec_header: Vec<u8>,
90        event_tx: mpsc::Sender<ServerEvent>,
91    ) -> Result<()> {
92        let listener = TcpListener::bind(format!("0.0.0.0:{}", self.port)).await?;
93        tracing::info!(port = self.port, "Stream server listening");
94
95        loop {
96            let (stream, peer) = listener.accept().await?;
97            tracing::info!(%peer, "Client connecting");
98
99            let chunk_sub = chunk_rx.subscribe();
100            let clients = Arc::clone(&self.clients);
101            let settings_senders = Arc::clone(&self.settings_senders);
102            #[cfg(feature = "custom-protocol")]
103            let custom_senders = Arc::clone(&self.custom_senders);
104            let event_tx = event_tx.clone();
105            let buffer_ms = self.buffer_ms;
106            let auth = self.auth.clone();
107            let codec = codec.clone();
108            let codec_header = codec_header.clone();
109
110            tokio::spawn(async move {
111                let (settings_tx, settings_rx) = mpsc::channel(16);
112                #[cfg(feature = "custom-protocol")]
113                let (custom_tx, custom_rx) = mpsc::channel(64);
114                let result = handle_client(
115                    stream,
116                    chunk_sub,
117                    settings_rx,
118                    #[cfg(feature = "custom-protocol")]
119                    custom_rx,
120                    &clients,
121                    &settings_senders,
122                    #[cfg(feature = "custom-protocol")]
123                    &custom_senders,
124                    settings_tx,
125                    #[cfg(feature = "custom-protocol")]
126                    custom_tx,
127                    event_tx,
128                    auth.as_deref(),
129                    buffer_ms,
130                    &codec,
131                    &codec_header,
132                )
133                .await;
134                if let Err(e) = result {
135                    tracing::debug!(%peer, error = %e, "Client session ended");
136                }
137            });
138        }
139    }
140
141    /// Get list of connected clients.
142    pub async fn connected_clients(&self) -> Vec<ClientInfo> {
143        self.clients
144            .lock()
145            .await
146            .values()
147            .filter(|c| c.connected)
148            .cloned()
149            .collect()
150    }
151
152    /// Send a custom binary protocol message to a specific client.
153    #[cfg(feature = "custom-protocol")]
154    pub async fn send_custom(&self, client_id: &str, type_id: u16, payload: Vec<u8>) {
155        let senders = self.custom_senders.lock().await;
156        if let Some(tx) = senders.get(client_id) {
157            let _ = tx.send(CustomOutbound { type_id, payload }).await;
158        }
159    }
160}
161
162#[allow(clippy::too_many_arguments)]
163async fn handle_client(
164    mut stream: TcpStream,
165    chunk_rx: broadcast::Receiver<WireChunkData>,
166    settings_rx: mpsc::Receiver<ClientSettingsUpdate>,
167    #[cfg(feature = "custom-protocol")] custom_rx: mpsc::Receiver<CustomOutbound>,
168    clients: &Mutex<HashMap<String, ClientInfo>>,
169    settings_senders: &Mutex<HashMap<String, mpsc::Sender<ClientSettingsUpdate>>>,
170    #[cfg(feature = "custom-protocol")] custom_senders: &Mutex<
171        HashMap<String, mpsc::Sender<CustomOutbound>>,
172    >,
173    settings_tx: mpsc::Sender<ClientSettingsUpdate>,
174    #[cfg(feature = "custom-protocol")] custom_tx: mpsc::Sender<CustomOutbound>,
175    event_tx: mpsc::Sender<ServerEvent>,
176    auth: Option<&dyn crate::auth::AuthValidator>,
177    buffer_ms: i32,
178    codec: &str,
179    codec_header: &[u8],
180) -> Result<()> {
181    // 1. Read Hello
182    let hello_msg = read_frame_from(&mut stream).await?;
183    let hello = match hello_msg.payload {
184        MessagePayload::Hello(h) => h,
185        _ => anyhow::bail!("expected Hello, got {:?}", hello_msg.base.msg_type),
186    };
187
188    let client_id = hello.id.clone();
189    tracing::info!(id = %client_id, name = %hello.host_name, mac = %hello.mac, "Client hello");
190
191    // 1b. Validate auth if required
192    if let Some(validator) = auth {
193        let auth_result = match &hello.auth {
194            Some(a) => validator.validate(&a.scheme, &a.param),
195            None => Err(crate::auth::AuthError::Unauthorized(
196                "Authentication required".into(),
197            )),
198        };
199        match auth_result {
200            Ok(result) => {
201                if !result
202                    .permissions
203                    .iter()
204                    .any(|p| p == crate::auth::PERM_STREAMING)
205                {
206                    let err = snapcast_proto::message::error::Error {
207                        code: 403,
208                        message: "Forbidden".into(),
209                        error: "Permission 'Streaming' missing".into(),
210                    };
211                    send_msg(&mut stream, MessageType::Error, &MessagePayload::Error(err)).await?;
212                    anyhow::bail!("Client {client_id}: missing Streaming permission");
213                }
214                tracing::info!(id = %client_id, user = %result.username, "Authenticated");
215            }
216            Err(e) => {
217                let err = snapcast_proto::message::error::Error {
218                    code: e.code() as u32,
219                    message: e.message().to_string(),
220                    error: e.message().to_string(),
221                };
222                send_msg(&mut stream, MessageType::Error, &MessagePayload::Error(err)).await?;
223                anyhow::bail!("Client {client_id}: {e}");
224            }
225        }
226    }
227
228    // Register client + settings channel
229    {
230        clients.lock().await.insert(
231            client_id.clone(),
232            ClientInfo {
233                id: client_id.clone(),
234                host_name: hello.host_name.clone(),
235                mac: hello.mac.clone(),
236                connected: true,
237            },
238        );
239        settings_senders
240            .lock()
241            .await
242            .insert(client_id.clone(), settings_tx);
243        #[cfg(feature = "custom-protocol")]
244        custom_senders
245            .lock()
246            .await
247            .insert(client_id.clone(), custom_tx);
248    }
249
250    let _ = event_tx
251        .send(ServerEvent::ClientConnected {
252            id: client_id.clone(),
253            name: hello.host_name.clone(),
254        })
255        .await;
256
257    // 2. Send ServerSettings
258    let ss = ServerSettings {
259        buffer_ms,
260        latency: 0,
261        volume: 100,
262        muted: false,
263    };
264    send_msg(
265        &mut stream,
266        MessageType::ServerSettings,
267        &MessagePayload::ServerSettings(ss),
268    )
269    .await?;
270
271    // 3. Send CodecHeader
272    let ch = CodecHeader {
273        codec: codec.to_string(),
274        payload: codec_header.to_vec(),
275    };
276    send_msg(
277        &mut stream,
278        MessageType::CodecHeader,
279        &MessagePayload::CodecHeader(ch),
280    )
281    .await?;
282
283    // 4. Main loop
284    let result = session_loop(
285        &mut stream,
286        chunk_rx,
287        settings_rx,
288        #[cfg(feature = "custom-protocol")]
289        custom_rx,
290        #[cfg(feature = "custom-protocol")]
291        event_tx.clone(),
292        #[cfg(feature = "custom-protocol")]
293        client_id.clone(),
294    )
295    .await;
296
297    // Cleanup
298    {
299        let mut map = clients.lock().await;
300        if let Some(c) = map.get_mut(&client_id) {
301            c.connected = false;
302        }
303    }
304    settings_senders.lock().await.remove(&client_id);
305    #[cfg(feature = "custom-protocol")]
306    custom_senders.lock().await.remove(&client_id);
307    let _ = event_tx
308        .send(ServerEvent::ClientDisconnected { id: client_id })
309        .await;
310
311    result
312}
313
314async fn session_loop(
315    stream: &mut TcpStream,
316    mut chunk_rx: broadcast::Receiver<WireChunkData>,
317    mut settings_rx: mpsc::Receiver<ClientSettingsUpdate>,
318    #[cfg(feature = "custom-protocol")] mut custom_rx: mpsc::Receiver<CustomOutbound>,
319    #[cfg(feature = "custom-protocol")] event_tx: mpsc::Sender<ServerEvent>,
320    #[cfg(feature = "custom-protocol")] client_id: String,
321) -> Result<()> {
322    let (mut reader, mut writer) = stream.split();
323
324    #[cfg(not(feature = "custom-protocol"))]
325    let (mut custom_rx, _event_tx, _client_id): (mpsc::Receiver<()>, Option<()>, String) = {
326        let (_tx, rx) = mpsc::channel(1);
327        (rx, None, String::new())
328    };
329
330    loop {
331        tokio::select! {
332            chunk = chunk_rx.recv() => {
333                let chunk = chunk.context("broadcast closed")?;
334                let ts_usec = chunk.timestamp_usec;
335                let wc = WireChunk {
336                    timestamp: Timeval::from_usec(ts_usec),
337                    payload: chunk.data,
338                };
339                let frame = serialize_msg(MessageType::WireChunk, &MessagePayload::WireChunk(wc), 0)?;
340                writer.write_all(&frame).await.context("write chunk")?;
341            }
342            msg = read_frame_from(&mut reader) => {
343                let msg = msg?;
344                match msg.payload {
345                    MessagePayload::Time(t) => {
346                        let response = Time { latency: t.latency };
347                        let frame = serialize_msg(MessageType::Time, &MessagePayload::Time(response), msg.base.id)?;
348                        writer.write_all(&frame).await.context("write time")?;
349                    }
350                    #[cfg(feature = "custom-protocol")]
351                    MessagePayload::Custom(payload) => {
352                        if let MessageType::Custom(type_id) = msg.base.msg_type {
353                            let _ = event_tx.send(ServerEvent::CustomMessage {
354                                client_id: client_id.clone(),
355                                message: snapcast_proto::CustomMessage::new(type_id, payload),
356                            }).await;
357                        }
358                    }
359                    _ => {}
360                }
361            }
362            update = settings_rx.recv() => {
363                let Some(update) = update else { continue };
364                let ss = ServerSettings {
365                    buffer_ms: update.buffer_ms,
366                    latency: update.latency,
367                    volume: update.volume,
368                    muted: update.muted,
369                };
370                let frame = serialize_msg(
371                    MessageType::ServerSettings,
372                    &MessagePayload::ServerSettings(ss),
373                    0,
374                )?;
375                writer.write_all(&frame).await.context("write settings")?;
376                tracing::debug!(volume = update.volume, latency = update.latency, "Pushed settings to client");
377            }
378            outbound = custom_rx.recv() => {
379                #[cfg(feature = "custom-protocol")]
380                if let Some(msg) = outbound {
381                    let frame = serialize_msg(
382                        MessageType::Custom(msg.type_id),
383                        &MessagePayload::Custom(msg.payload),
384                        0,
385                    )?;
386                    writer.write_all(&frame).await.context("write custom")?;
387                }
388                #[cfg(not(feature = "custom-protocol"))]
389                let _ = outbound;
390            }
391        }
392    }
393}
394
395fn serialize_msg(
396    msg_type: MessageType,
397    payload: &MessagePayload,
398    refers_to: u16,
399) -> Result<Vec<u8>> {
400    let mut base = BaseMessage {
401        msg_type,
402        id: 0,
403        refers_to,
404        sent: now_timeval(),
405        received: Timeval::default(),
406        size: 0,
407    };
408    factory::serialize(&mut base, payload).map_err(|e| anyhow::anyhow!("serialize: {e}"))
409}
410
411async fn send_msg(
412    stream: &mut TcpStream,
413    msg_type: MessageType,
414    payload: &MessagePayload,
415) -> Result<()> {
416    let frame = serialize_msg(msg_type, payload, 0)?;
417    stream.write_all(&frame).await.context("write message")
418}
419
420async fn read_frame_from<R: AsyncReadExt + Unpin>(reader: &mut R) -> Result<TypedMessage> {
421    let mut header_buf = [0u8; BaseMessage::HEADER_SIZE];
422    reader
423        .read_exact(&mut header_buf)
424        .await
425        .context("read header")?;
426    let mut base =
427        BaseMessage::read_from(&mut &header_buf[..]).map_err(|e| anyhow::anyhow!("parse: {e}"))?;
428    base.received = now_timeval();
429    let mut payload_buf = vec![0u8; base.size as usize];
430    if !payload_buf.is_empty() {
431        reader
432            .read_exact(&mut payload_buf)
433            .await
434            .context("read payload")?;
435    }
436    factory::deserialize(base, &payload_buf).map_err(|e| anyhow::anyhow!("deserialize: {e}"))
437}
438
439fn now_timeval() -> Timeval {
440    Timeval::from_usec(now_usec())
441}