Skip to main content

snapcast_server/
session.rs

1//! Client session management — binary protocol server for snapclients.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use anyhow::{Context, Result};
7use snapcast_proto::MessageType;
8use snapcast_proto::message::base::BaseMessage;
9use snapcast_proto::message::codec_header::CodecHeader;
10use snapcast_proto::message::factory::{self, MessagePayload, TypedMessage};
11use snapcast_proto::message::server_settings::ServerSettings;
12use snapcast_proto::message::time::Time;
13use snapcast_proto::message::wire_chunk::WireChunk;
14use snapcast_proto::types::Timeval;
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::net::{TcpListener, TcpStream};
17use tokio::sync::{Mutex, broadcast, mpsc, watch};
18
19use crate::ClientSettingsUpdate;
20use crate::ServerEvent;
21use crate::WireChunkData;
22use crate::time::now_usec;
23
24// ── Routing ───────────────────────────────────────────────────
25
/// Per-session routing state — updated when groups/streams/mute change.
///
/// A plain value type: two snapshots are equal when they name the same
/// stream and carry the same mute flags.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionRouting {
    /// Stream ID this client should receive audio from.
    pub stream_id: String,
    /// Client is muted.
    pub client_muted: bool,
    /// Client's group is muted.
    pub group_muted: bool,
}
36
/// Per-stream codec header, stored at stream registration time.
///
/// Compared by value (codec name and header bytes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamCodecInfo {
    /// Codec name (e.g. "flac", "pcm", "f32lz4").
    pub codec: String,
    /// Encoded codec header bytes.
    pub header: Vec<u8>,
}
45
46// ── Session context (private fields, method API) ──────────────
47
48/// Shared context for all sessions.
49///
50/// # Lock ordering
51///
52/// When acquiring multiple locks, always follow this order to prevent deadlocks:
53///
54/// 1. `shared_state` (server state — groups, clients, streams)
55/// 2. `routing_senders` (per-client watch channels)
56/// 3. `settings_senders` / `custom_senders` (per-client mpsc channels)
57/// 4. `codec_headers` (per-stream codec info)
58///
59/// Never hold a lower-numbered lock while acquiring a higher-numbered one.
60/// In practice, most paths only need one or two locks:
61/// - Routing updates: `shared_state` → `routing_senders`
62/// - Settings push: `settings_senders` only
63/// - Codec lookup: `codec_headers` only
64/// - Client registration: `shared_state`, then separately `routing_senders`
65struct SessionContext {
66    buffer_ms: i32,
67    auth: Option<Arc<dyn crate::auth::AuthValidator>>,
68    send_audio_to_muted: bool,
69    settings_senders: Mutex<HashMap<String, mpsc::Sender<ClientSettingsUpdate>>>,
70    #[cfg(feature = "custom-protocol")]
71    custom_senders: Mutex<HashMap<String, mpsc::Sender<CustomOutbound>>>,
72    routing_senders: Mutex<HashMap<String, watch::Sender<SessionRouting>>>,
73    codec_headers: Mutex<HashMap<String, StreamCodecInfo>>,
74    shared_state: Arc<tokio::sync::Mutex<crate::state::ServerState>>,
75    default_stream: String,
76}
77
78impl SessionContext {
79    /// Build routing for a single client from server state.
80    /// State lock must be held by caller.
81    fn build_routing(state: &crate::state::ServerState, client_id: &str) -> Option<SessionRouting> {
82        let group = state
83            .groups
84            .iter()
85            .find(|g| g.clients.contains(&client_id.to_string()))?;
86        let client_muted = state
87            .clients
88            .get(client_id)
89            .map(|c| c.config.volume.muted)
90            .unwrap_or(false);
91        Some(SessionRouting {
92            stream_id: group.stream_id.clone(),
93            client_muted,
94            group_muted: group.muted,
95        })
96    }
97
98    /// Send routing update to a single client's watch channel.
99    async fn push_routing(&self, client_id: &str) {
100        let s = self.shared_state.lock().await;
101        if let Some(routing) = Self::build_routing(&s, client_id) {
102            let senders = self.routing_senders.lock().await;
103            if let Some(tx) = senders.get(client_id) {
104                let _ = tx.send(routing);
105            }
106        }
107    }
108
109    /// Send routing updates to all clients in a group.
110    async fn push_routing_for_group(&self, group_id: &str) {
111        let s = self.shared_state.lock().await;
112        let senders = self.routing_senders.lock().await;
113        let Some(group) = s.groups.iter().find(|g| g.id == group_id) else {
114            return;
115        };
116        for client_id in &group.clients {
117            if let Some(routing) = Self::build_routing(&s, client_id)
118                && let Some(tx) = senders.get(client_id)
119            {
120                let _ = tx.send(routing);
121            }
122        }
123    }
124
125    /// Send routing updates to all connected sessions.
126    async fn push_routing_all(&self) {
127        let s = self.shared_state.lock().await;
128        let senders = self.routing_senders.lock().await;
129        for group in &s.groups {
130            for client_id in &group.clients {
131                if let Some(routing) = Self::build_routing(&s, client_id)
132                    && let Some(tx) = senders.get(client_id)
133                {
134                    let _ = tx.send(routing);
135                }
136            }
137        }
138    }
139
140    /// Get codec header for a stream. Returns None if not registered.
141    async fn codec_header_for(&self, stream_id: &str) -> Option<StreamCodecInfo> {
142        self.codec_headers.lock().await.get(stream_id).cloned()
143    }
144}
145
146// ── Session server (public API) ───────────────────────────────
147
/// Manages all streaming client sessions.
///
/// Owns the listening port and the context shared (via `Arc`) with every
/// per-client task.
pub struct SessionServer {
    /// TCP port the stream server listens on.
    port: u16,
    /// Context shared by all client sessions.
    ctx: Arc<SessionContext>,
}
153
/// Outbound custom message to a specific client.
///
/// Compared by value (type ID and payload bytes).
#[cfg(feature = "custom-protocol")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CustomOutbound {
    /// Message type ID (9+).
    pub type_id: u16,
    /// Raw payload.
    pub payload: Vec<u8>,
}
163
164impl SessionServer {
165    /// Create a new session server.
166    pub fn new(
167        port: u16,
168        buffer_ms: i32,
169        auth: Option<Arc<dyn crate::auth::AuthValidator>>,
170        shared_state: Arc<tokio::sync::Mutex<crate::state::ServerState>>,
171        default_stream: String,
172        send_audio_to_muted: bool,
173    ) -> Self {
174        Self {
175            port,
176            ctx: Arc::new(SessionContext {
177                buffer_ms,
178                auth,
179                send_audio_to_muted,
180                settings_senders: Mutex::new(HashMap::new()),
181                #[cfg(feature = "custom-protocol")]
182                custom_senders: Mutex::new(HashMap::new()),
183                routing_senders: Mutex::new(HashMap::new()),
184                codec_headers: Mutex::new(HashMap::new()),
185                shared_state,
186                default_stream,
187            }),
188        }
189    }
190
191    /// Register a stream's codec header (called during stream setup).
192    pub async fn register_stream_codec(&self, stream_id: &str, codec: &str, header: &[u8]) {
193        self.ctx.codec_headers.lock().await.insert(
194            stream_id.to_string(),
195            StreamCodecInfo {
196                codec: codec.to_string(),
197                header: header.to_vec(),
198            },
199        );
200    }
201
202    /// Push a settings update to a specific streaming client.
203    pub async fn push_settings(&self, update: ClientSettingsUpdate) {
204        let senders = self.ctx.settings_senders.lock().await;
205        if let Some(tx) = senders.get(&update.client_id) {
206            let _ = tx.send(update).await;
207        }
208    }
209
210    /// Update routing for a single client.
211    pub async fn update_routing_for_client(&self, client_id: &str) {
212        self.ctx.push_routing(client_id).await;
213    }
214
215    /// Update routing for all clients in a group.
216    pub async fn update_routing_for_group(&self, group_id: &str) {
217        self.ctx.push_routing_for_group(group_id).await;
218    }
219
220    /// Update routing for all connected sessions (after structural changes).
221    pub async fn update_routing_all(&self) {
222        self.ctx.push_routing_all().await;
223    }
224
225    /// Run the session server — accepts connections and spawns per-client tasks.
226    pub async fn run(
227        &self,
228        chunk_rx: broadcast::Sender<WireChunkData>,
229        event_tx: mpsc::Sender<ServerEvent>,
230    ) -> Result<()> {
231        let listener = TcpListener::bind(format!("0.0.0.0:{}", self.port)).await?;
232        tracing::info!(port = self.port, "Stream server listening");
233
234        loop {
235            let (stream, peer) = listener.accept().await?;
236            stream.set_nodelay(true).ok();
237            let ka = socket2::TcpKeepalive::new().with_time(std::time::Duration::from_secs(10));
238            let sock = socket2::SockRef::from(&stream);
239            sock.set_tcp_keepalive(&ka).ok();
240            tracing::info!(%peer, "Client connecting");
241
242            let chunk_sub = chunk_rx.subscribe();
243            let ctx = Arc::clone(&self.ctx);
244            let event_tx = event_tx.clone();
245
246            tokio::spawn(async move {
247                let result = handle_client(stream, chunk_sub, &ctx, event_tx).await;
248                if let Err(e) = result {
249                    tracing::debug!(%peer, error = %e, "Client session ended");
250                }
251            });
252        }
253    }
254
255    /// Send a custom binary protocol message to a specific client.
256    #[cfg(feature = "custom-protocol")]
257    pub async fn send_custom(&self, client_id: &str, type_id: u16, payload: Vec<u8>) {
258        let senders = self.ctx.custom_senders.lock().await;
259        if let Some(tx) = senders.get(client_id) {
260            let _ = tx.send(CustomOutbound { type_id, payload }).await;
261        }
262    }
263}
264
265// ── Client handler ────────────────────────────────────────────
266
267async fn handle_client(
268    mut stream: TcpStream,
269    chunk_rx: broadcast::Receiver<WireChunkData>,
270    ctx: &SessionContext,
271    event_tx: mpsc::Sender<ServerEvent>,
272) -> Result<()> {
273    let hello_msg = read_frame_from(&mut stream).await?;
274    let hello_id = hello_msg.base.id;
275    let hello = match hello_msg.payload {
276        MessagePayload::Hello(h) => h,
277        _ => anyhow::bail!("expected Hello, got {:?}", hello_msg.base.msg_type),
278    };
279
280    let client_id = hello.id.clone();
281    tracing::info!(id = %client_id, name = %hello.host_name, mac = %hello.mac, "Client hello");
282
283    if let Some(validator) = &ctx.auth {
284        validate_auth(validator.as_ref(), &hello, &mut stream, &client_id).await?;
285    }
286
287    // Register channels
288    let (settings_tx, settings_rx) = mpsc::channel(16);
289    #[cfg(feature = "custom-protocol")]
290    let (custom_tx, custom_rx) = mpsc::channel(64);
291
292    ctx.settings_senders
293        .lock()
294        .await
295        .insert(client_id.clone(), settings_tx);
296    #[cfg(feature = "custom-protocol")]
297    ctx.custom_senders
298        .lock()
299        .await
300        .insert(client_id.clone(), custom_tx);
301
302    // Register in state + build initial routing
303    let initial_stream_id;
304    let initial_routing;
305    let client_settings;
306    {
307        let mut s = ctx.shared_state.lock().await;
308        let c = s.get_or_create_client(&client_id, &hello.host_name, &hello.mac);
309        c.connected = true;
310        client_settings = ServerSettings {
311            buffer_ms: ctx.buffer_ms,
312            latency: c.config.latency,
313            volume: c.config.volume.percent,
314            muted: c.config.volume.muted,
315        };
316        s.group_for_client(&client_id, &ctx.default_stream);
317
318        initial_routing =
319            SessionContext::build_routing(&s, &client_id).unwrap_or_else(|| SessionRouting {
320                stream_id: ctx.default_stream.clone(),
321                client_muted: false,
322                group_muted: false,
323            });
324        initial_stream_id = initial_routing.stream_id.clone();
325    }
326
327    let (routing_tx, routing_rx) = watch::channel(initial_routing);
328    ctx.routing_senders
329        .lock()
330        .await
331        .insert(client_id.clone(), routing_tx);
332
333    let _ = event_tx
334        .send(ServerEvent::ClientConnected {
335            id: client_id.clone(),
336            hello: hello.clone(),
337        })
338        .await;
339
340    // ServerSettings (refers_to must match Hello id for client's pending request)
341    let ss_frame = serialize_msg(
342        MessageType::ServerSettings,
343        &MessagePayload::ServerSettings(client_settings),
344        hello_id,
345    )?;
346    stream
347        .write_all(&ss_frame)
348        .await
349        .context("write server settings")?;
350
351    // CodecHeader for client's stream
352    match ctx.codec_header_for(&initial_stream_id).await {
353        Some(info) => {
354            send_msg(
355                &mut stream,
356                MessageType::CodecHeader,
357                &MessagePayload::CodecHeader(CodecHeader {
358                    codec: info.codec,
359                    payload: info.header,
360                }),
361            )
362            .await?;
363        }
364        None => {
365            tracing::warn!(stream = %initial_stream_id, client = %client_id, "No codec header registered for stream");
366        }
367    }
368
369    // Main loop
370    let result = session_loop(
371        &mut stream,
372        chunk_rx,
373        settings_rx,
374        routing_rx,
375        #[cfg(feature = "custom-protocol")]
376        custom_rx,
377        event_tx.clone(),
378        client_id.clone(),
379        ctx,
380    )
381    .await;
382
383    // Cleanup
384    ctx.settings_senders.lock().await.remove(&client_id);
385    ctx.routing_senders.lock().await.remove(&client_id);
386    #[cfg(feature = "custom-protocol")]
387    ctx.custom_senders.lock().await.remove(&client_id);
388    {
389        let mut s = ctx.shared_state.lock().await;
390        if let Some(c) = s.clients.get_mut(&client_id) {
391            c.connected = false;
392        }
393    }
394    let _ = event_tx
395        .send(ServerEvent::ClientDisconnected { id: client_id })
396        .await;
397
398    result
399}
400
401// ── Session loop (single function, cfg on custom-protocol arms) ──
402//
403// Custom-protocol outbound messages are drained via `try_recv` before each
404// `select!` iteration because `tokio::select!` doesn't support `#[cfg]` on arms.
405// This adds up to one select cycle of latency (~20ms at 48kHz) for custom
406// messages, which is acceptable for low-frequency control traffic.
407
408#[allow(clippy::too_many_arguments)]
409async fn session_loop(
410    stream: &mut TcpStream,
411    mut chunk_rx: broadcast::Receiver<WireChunkData>,
412    mut settings_rx: mpsc::Receiver<ClientSettingsUpdate>,
413    mut routing_rx: watch::Receiver<SessionRouting>,
414    #[cfg(feature = "custom-protocol")] mut custom_rx: mpsc::Receiver<CustomOutbound>,
415    event_tx: mpsc::Sender<ServerEvent>,
416    client_id: String,
417    ctx: &SessionContext,
418) -> Result<()> {
419    let (mut reader, mut writer) = stream.split();
420    let mut routing = routing_rx.borrow().clone();
421
422    loop {
423        // Drain pending custom outbound before blocking on select.
424        // tokio::select! doesn't support #[cfg] on arms, so custom messages
425        // are drained via try_recv. This adds up to one select cycle of latency
426        // (~20ms at 48kHz) which is fine for low-frequency control messages.
427        #[cfg(feature = "custom-protocol")]
428        while let Ok(msg) = custom_rx.try_recv() {
429            let frame = serialize_msg(
430                MessageType::Custom(msg.type_id),
431                &MessagePayload::Custom(msg.payload),
432                0,
433            )?;
434            writer.write_all(&frame).await.context("write custom")?;
435        }
436
437        tokio::select! {
438            chunk = chunk_rx.recv() => {
439                let chunk = match chunk {
440                    Ok(c) => c,
441                    Err(broadcast::error::RecvError::Lagged(n)) => {
442                        tracing::warn!(skipped = n, "Broadcast lagged");
443                        continue;
444                    }
445                    Err(broadcast::error::RecvError::Closed) => {
446                        tracing::warn!("Broadcast closed");
447                        anyhow::bail!("broadcast closed");
448                    }
449                };
450                if !should_send_chunk(&chunk, &routing, ctx.send_audio_to_muted) {
451                    continue;
452                }
453                write_chunk(&mut writer, chunk).await?;
454            }
455            Ok(()) = routing_rx.changed() => {
456                let new = routing_rx.borrow().clone();
457                if new.stream_id != routing.stream_id {
458                    tracing::debug!(old = %routing.stream_id, new = %new.stream_id, "Stream switch");
459                    if let Some(info) = ctx.codec_header_for(&new.stream_id).await {
460                        let frame = serialize_msg(
461                            MessageType::CodecHeader,
462                            &MessagePayload::CodecHeader(CodecHeader {
463                                codec: info.codec,
464                                payload: info.header,
465                            }),
466                            0,
467                        )?;
468                        writer.write_all(&frame).await.context("write codec header")?;
469                    }
470                }
471                routing = new;
472            }
473            msg = read_frame_from(&mut reader) => {
474                let msg = msg?;
475                match msg.payload {
476                    MessagePayload::Time(_t) => {
477                        // latency = server_received - client_sent (c2s one-way estimate)
478                        let latency = msg.base.received - msg.base.sent;
479                        let frame = serialize_msg(
480                            MessageType::Time,
481                            &MessagePayload::Time(Time { latency }),
482                            msg.base.id,
483                        )?;
484                        writer.write_all(&frame).await.context("write time")?;
485                    }
486                    MessagePayload::ClientInfo(info) => {
487                        {
488                            let mut s = ctx.shared_state.lock().await;
489                            if let Some(c) = s.clients.get_mut(&client_id) {
490                                c.config.volume.percent = info.volume;
491                                c.config.volume.muted = info.muted;
492                            }
493                        }
494                        let _ = event_tx.send(ServerEvent::ClientVolumeChanged {
495                            client_id: client_id.clone(),
496                            volume: info.volume,
497                            muted: info.muted,
498                        }).await;
499                    }
500                    #[cfg(feature = "custom-protocol")]
501                    MessagePayload::Custom(payload) => {
502                        if let MessageType::Custom(type_id) = msg.base.msg_type {
503                            let _ = event_tx.send(ServerEvent::CustomMessage {
504                                client_id: client_id.clone(),
505                                message: snapcast_proto::CustomMessage::new(type_id, payload),
506                            }).await;
507                        }
508                    }
509                    _ => {}
510                }
511            }
512            update = settings_rx.recv() => {
513                let Some(update) = update else { continue };
514                write_settings(&mut writer, update).await?;
515            }
516        }
517    }
518}
519
520// ── Helpers ───────────────────────────────────────────────────
521
522/// Decide whether to send a chunk to this session.
523#[inline]
524fn should_send_chunk(
525    chunk: &WireChunkData,
526    routing: &SessionRouting,
527    send_audio_to_muted: bool,
528) -> bool {
529    if chunk.stream_id != routing.stream_id {
530        return false;
531    }
532    if !send_audio_to_muted && (routing.client_muted || routing.group_muted) {
533        return false;
534    }
535    true
536}
537
538async fn write_chunk<W: AsyncWriteExt + Unpin>(writer: &mut W, chunk: WireChunkData) -> Result<()> {
539    let wc = WireChunk {
540        timestamp: Timeval::from_usec(chunk.timestamp_usec),
541        payload: chunk.data,
542    };
543    let frame = serialize_msg(MessageType::WireChunk, &MessagePayload::WireChunk(wc), 0)?;
544    writer.write_all(&frame).await.context("write chunk")
545}
546
547async fn write_settings<W: AsyncWriteExt + Unpin>(
548    writer: &mut W,
549    update: ClientSettingsUpdate,
550) -> Result<()> {
551    let ss = ServerSettings {
552        buffer_ms: update.buffer_ms,
553        latency: update.latency,
554        volume: update.volume,
555        muted: update.muted,
556    };
557    let frame = serialize_msg(
558        MessageType::ServerSettings,
559        &MessagePayload::ServerSettings(ss),
560        0,
561    )?;
562    writer.write_all(&frame).await.context("write settings")?;
563    tracing::debug!(
564        volume = update.volume,
565        latency = update.latency,
566        "Pushed settings"
567    );
568    Ok(())
569}
570
571async fn validate_auth(
572    validator: &dyn crate::auth::AuthValidator,
573    hello: &snapcast_proto::message::hello::Hello,
574    stream: &mut TcpStream,
575    client_id: &str,
576) -> Result<()> {
577    let auth_result = match &hello.auth {
578        Some(a) => validator.validate(&a.scheme, &a.param),
579        None => Err(crate::auth::AuthError::Unauthorized(
580            "Authentication required".into(),
581        )),
582    };
583    match auth_result {
584        Ok(result) => {
585            if !result
586                .permissions
587                .iter()
588                .any(|p| p == crate::auth::PERM_STREAMING)
589            {
590                let err = snapcast_proto::message::error::Error {
591                    code: 403,
592                    message: "Forbidden".into(),
593                    error: "Permission 'Streaming' missing".into(),
594                };
595                send_msg(stream, MessageType::Error, &MessagePayload::Error(err)).await?;
596                anyhow::bail!("Client {client_id}: missing Streaming permission");
597            }
598            tracing::info!(id = %client_id, user = %result.username, "Authenticated");
599            Ok(())
600        }
601        Err(e) => {
602            let err = snapcast_proto::message::error::Error {
603                code: e.code() as u32,
604                message: e.message().to_string(),
605                error: e.message().to_string(),
606            };
607            send_msg(stream, MessageType::Error, &MessagePayload::Error(err)).await?;
608            anyhow::bail!("Client {client_id}: {e}");
609        }
610    }
611}
612
613fn serialize_msg(
614    msg_type: MessageType,
615    payload: &MessagePayload,
616    refers_to: u16,
617) -> Result<Vec<u8>> {
618    let mut base = BaseMessage {
619        msg_type,
620        id: 0,
621        refers_to,
622        sent: now_timeval(),
623        received: Timeval::default(),
624        size: 0,
625    };
626    factory::serialize(&mut base, payload).map_err(|e| anyhow::anyhow!("serialize: {e}"))
627}
628
629async fn send_msg(
630    stream: &mut TcpStream,
631    msg_type: MessageType,
632    payload: &MessagePayload,
633) -> Result<()> {
634    let frame = serialize_msg(msg_type, payload, 0)?;
635    stream.write_all(&frame).await.context("write message")
636}
637
638async fn read_frame_from<R: AsyncReadExt + Unpin>(reader: &mut R) -> Result<TypedMessage> {
639    const MAX_PAYLOAD_SIZE: u32 = 2 * 1024 * 1024; // 2 MiB
640
641    let mut header_buf = [0u8; BaseMessage::HEADER_SIZE];
642    reader
643        .read_exact(&mut header_buf)
644        .await
645        .context("read header")?;
646    let mut base =
647        BaseMessage::read_from(&mut &header_buf[..]).map_err(|e| anyhow::anyhow!("parse: {e}"))?;
648    base.received = now_timeval();
649    anyhow::ensure!(
650        base.size <= MAX_PAYLOAD_SIZE,
651        "payload too large: {} bytes",
652        base.size
653    );
654    let mut payload_buf = vec![0u8; base.size as usize];
655    if !payload_buf.is_empty() {
656        reader
657            .read_exact(&mut payload_buf)
658            .await
659            .context("read payload")?;
660    }
661    factory::deserialize(base, &payload_buf).map_err(|e| anyhow::anyhow!("deserialize: {e}"))
662}
663
664fn now_timeval() -> Timeval {
665    Timeval::from_usec(now_usec())
666}
667
668// ── Tests ─────────────────────────────────────────────────────
669
#[cfg(test)]
mod tests {
    use super::*;

    /// A 64-byte silent chunk on the given stream.
    fn make_chunk(stream_id: &str) -> WireChunkData {
        WireChunkData {
            stream_id: stream_id.to_owned(),
            timestamp_usec: 0,
            data: vec![0u8; 64],
        }
    }

    /// Routing snapshot with the given stream and mute flags.
    fn make_routing(stream_id: &str, client_muted: bool, group_muted: bool) -> SessionRouting {
        SessionRouting {
            stream_id: stream_id.to_owned(),
            client_muted,
            group_muted,
        }
    }

    // ── should_send_chunk ─────────────────────────────────────

    #[test]
    fn matching_stream_unmuted_sends() {
        let r = make_routing("z1", false, false);
        assert!(should_send_chunk(&make_chunk("z1"), &r, false));
    }

    #[test]
    fn wrong_stream_skips() {
        let r = make_routing("z1", false, false);
        assert!(!should_send_chunk(&make_chunk("z2"), &r, false));
    }

    #[test]
    fn client_muted_skips() {
        let r = make_routing("z1", true, false);
        assert!(!should_send_chunk(&make_chunk("z1"), &r, false));
    }

    #[test]
    fn group_muted_skips() {
        let r = make_routing("z1", false, true);
        assert!(!should_send_chunk(&make_chunk("z1"), &r, false));
    }

    #[test]
    fn send_audio_to_muted_overrides() {
        let r = make_routing("z1", true, true);
        assert!(should_send_chunk(&make_chunk("z1"), &r, true));
    }

    #[test]
    fn wrong_stream_ignores_send_audio_to_muted() {
        let r = make_routing("z1", false, false);
        assert!(!should_send_chunk(&make_chunk("z2"), &r, true));
    }

    // ── build_routing ─────────────────────────────────────────

    #[test]
    fn build_routing_finds_client_in_group() {
        let mut state = crate::state::ServerState::default();
        state.get_or_create_client("c1", "host", "mac");
        state.group_for_client("c1", "stream1");

        let built = SessionContext::build_routing(&state, "c1").unwrap();
        assert_eq!(built.stream_id, "stream1");
        assert!(!built.client_muted);
        assert!(!built.group_muted);
    }

    #[test]
    fn build_routing_reflects_mute() {
        let mut state = crate::state::ServerState::default();
        state
            .get_or_create_client("c1", "host", "mac")
            .config
            .volume
            .muted = true;
        state.group_for_client("c1", "stream1");

        // Mute the (first) group that contains the client.
        let owning_group = state
            .groups
            .iter_mut()
            .find(|g| g.clients.contains(&"c1".to_string()));
        if let Some(g) = owning_group {
            g.muted = true;
        }

        let built = SessionContext::build_routing(&state, "c1").unwrap();
        assert!(built.client_muted);
        assert!(built.group_muted);
    }

    #[test]
    fn build_routing_returns_none_for_unknown_client() {
        let state = crate::state::ServerState::default();
        let built = SessionContext::build_routing(&state, "unknown");
        assert!(built.is_none());
    }

    // ── watch integration ─────────────────────────────────────

    #[test]
    fn routing_watch_delivers_updates() {
        let (tx, rx) = watch::channel(make_routing("z1", false, false));
        assert_eq!(rx.borrow().stream_id, "z1");

        tx.send(make_routing("z2", true, false)).unwrap();
        assert_eq!(rx.borrow().stream_id, "z2");
        assert!(rx.borrow().client_muted);
    }

    #[test]
    fn unmute_cycle() {
        let muted = make_routing("z1", true, false);
        let unmuted = make_routing("z1", false, false);
        assert!(!should_send_chunk(&make_chunk("z1"), &muted, false));
        assert!(should_send_chunk(&make_chunk("z1"), &unmuted, false));
    }

    #[test]
    fn stream_switch_changes_filter() {
        let before = make_routing("z1", false, false);
        let after = make_routing("z2", false, false);
        assert!(should_send_chunk(&make_chunk("z1"), &before, false));
        assert!(!should_send_chunk(&make_chunk("z1"), &after, false));
        assert!(should_send_chunk(&make_chunk("z2"), &after, false));
    }
}