Skip to main content

snapcast_client/connection/
mod.rs

1//! Connection layer.
2//!
3//! TCP is the supported Snapcast audio transport. The WebSocket modules are
4//! kept feature-gated for future interoperability work, but they are not
5//! selected by [`SnapConnection::new`] until the server and client can speak a
6//! verified binary audio-streaming WebSocket contract.
7
8#[cfg(feature = "websocket")]
9pub mod ws;
10#[cfg(feature = "tls")]
11pub mod wss;
12
13use std::collections::HashMap;
14use std::time::Duration;
15
16use anyhow::{Context, Result};
17use snapcast_proto::MessageType;
18use snapcast_proto::message::base::BaseMessage;
19use snapcast_proto::message::factory::{self, MessagePayload, TypedMessage};
20use snapcast_proto::types::Timeval;
21use tokio::io::{AsyncReadExt, AsyncWriteExt};
22use tokio::net::TcpStream;
23use tokio::sync::oneshot;
24
25/// Read a complete frame (header + payload) from an async reader.
26async fn read_frame<R: AsyncReadExt + Unpin>(reader: &mut R) -> Result<TypedMessage> {
27    // Read 26-byte header
28    let mut header_buf = [0u8; BaseMessage::HEADER_SIZE];
29    reader
30        .read_exact(&mut header_buf)
31        .await
32        .context("reading base message header")?;
33
34    let mut base = BaseMessage::read_from(&mut &header_buf[..])
35        .map_err(|e| anyhow::anyhow!("parsing header: {e}"))?;
36
37    // Stamp received time using steady clock (matching C++ steadytimeofday)
38    base.received = steady_time_of_day();
39    ensure_payload_size(base.size)?;
40
41    // Read payload
42    let mut payload_buf = vec![0u8; base.size as usize];
43    if !payload_buf.is_empty() {
44        reader
45            .read_exact(&mut payload_buf)
46            .await
47            .context("reading payload")?;
48    }
49
50    factory::deserialize(base, &payload_buf).map_err(|e| anyhow::anyhow!("deserializing: {e}"))
51}
52
53pub(crate) fn ensure_payload_size(size: u32) -> Result<()> {
54    anyhow::ensure!(
55        size <= snapcast_proto::DEFAULT_MAX_PAYLOAD_SIZE,
56        "payload too large: {size} bytes"
57    );
58    Ok(())
59}
60
61/// Write a complete frame (header + payload) to an async writer.
62async fn write_frame<W: AsyncWriteExt + Unpin>(
63    writer: &mut W,
64    base: &mut BaseMessage,
65    payload: &MessagePayload,
66) -> Result<()> {
67    let frame =
68        factory::serialize(base, payload).map_err(|e| anyhow::anyhow!("serializing: {e}"))?;
69    writer.write_all(&frame).await.context("writing frame")?;
70    Ok(())
71}
72
73/// Pending request waiting for a response.
74struct PendingRequest {
75    tx: oneshot::Sender<TypedMessage>,
76}
77
78/// TCP connection to a snapserver.
79pub struct TcpConnection {
80    stream: Option<TcpStream>,
81    host: String,
82    port: u16,
83    pending: HashMap<u16, PendingRequest>,
84    next_id: u16,
85}
86
87/// Unified connection over supported transports.
88pub enum SnapConnection {
89    /// Plain TCP connection.
90    Tcp(TcpConnection),
91    #[cfg(feature = "websocket")]
92    /// WebSocket (non-secure) connection.
93    Ws(ws::WsConnection),
94    #[cfg(feature = "tls")]
95    /// WebSocket over TLS (secure) connection.
96    Wss(wss::WssConnection),
97}
98
99impl SnapConnection {
100    /// Create a new connection based on the scheme.
101    pub fn new(scheme: &str, host: &str, port: u16) -> Result<Self> {
102        match scheme {
103            snapcast_proto::SCHEME_TCP => Ok(Self::Tcp(TcpConnection::new(host, port))),
104            snapcast_proto::SCHEME_WS | snapcast_proto::SCHEME_WSS => anyhow::bail!(
105                "websocket audio transport is not supported yet; use tcp:// for Snapcast audio"
106            ),
107            _ => anyhow::bail!("unsupported scheme: {scheme}"),
108        }
109    }
110
111    /// Establish the connection.
112    pub async fn connect(&mut self) -> Result<()> {
113        match self {
114            Self::Tcp(c) => c.connect().await,
115            #[cfg(feature = "websocket")]
116            Self::Ws(c) => c.connect().await,
117            #[cfg(feature = "tls")]
118            Self::Wss(c) => c.connect().await,
119        }
120    }
121
122    /// Close the connection.
123    pub fn disconnect(&mut self) {
124        match self {
125            Self::Tcp(c) => c.disconnect(),
126            #[cfg(feature = "websocket")]
127            Self::Ws(c) => c.disconnect(),
128            #[cfg(feature = "tls")]
129            Self::Wss(c) => c.disconnect(),
130        }
131    }
132
133    /// Send a message.
134    pub async fn send(&mut self, msg_type: MessageType, payload: &MessagePayload) -> Result<()> {
135        match self {
136            Self::Tcp(c) => c.send(msg_type, payload).await,
137            #[cfg(feature = "websocket")]
138            Self::Ws(c) => c.send(msg_type, payload).await,
139            #[cfg(feature = "tls")]
140            Self::Wss(c) => c.send(msg_type, payload).await,
141        }
142    }
143
144    /// Receive the next message.
145    pub async fn recv(&mut self) -> Result<TypedMessage> {
146        match self {
147            Self::Tcp(c) => c.recv().await,
148            #[cfg(feature = "websocket")]
149            Self::Ws(c) => c.recv().await,
150            #[cfg(feature = "tls")]
151            Self::Wss(c) => c.recv().await,
152        }
153    }
154}
155
156impl TcpConnection {
157    /// Create a new connection to the given host and port.
158    pub fn new(host: &str, port: u16) -> Self {
159        Self {
160            stream: None,
161            host: host.to_string(),
162            port,
163            pending: HashMap::new(),
164            next_id: 1,
165        }
166    }
167
168    /// Establish the TCP connection.
169    pub async fn connect(&mut self) -> Result<()> {
170        let addr = format!("{}:{}", self.host, self.port);
171        let stream = TcpStream::connect(&addr)
172            .await
173            .with_context(|| format!("connecting to {addr}"))?;
174        self.stream = Some(stream);
175        self.pending.clear();
176        self.next_id = 1;
177        Ok(())
178    }
179
180    /// Close the connection.
181    pub fn disconnect(&mut self) {
182        self.stream = None;
183        self.pending.clear();
184    }
185
186    fn stream_mut(&mut self) -> Result<&mut TcpStream> {
187        self.stream.as_mut().context("not connected")
188    }
189
190    /// Send a message without waiting for a response.
191    pub async fn send(&mut self, msg_type: MessageType, payload: &MessagePayload) -> Result<()> {
192        let stream = self.stream_mut()?;
193        let mut base = BaseMessage {
194            msg_type,
195            id: 0,
196            refers_to: 0,
197            sent: Timeval::default(),
198            received: Timeval::default(),
199            size: 0,
200        };
201        stamp_sent(&mut base);
202        write_frame(stream, &mut base, payload).await
203    }
204
205    /// Send a request and wait for the response (matched by `refersTo`).
206    pub async fn send_request(
207        &mut self,
208        msg_type: MessageType,
209        payload: &MessagePayload,
210        timeout: Duration,
211    ) -> Result<TypedMessage> {
212        let id = self.next_id;
213        self.next_id = self.next_id.wrapping_add(1);
214
215        let (tx, rx) = oneshot::channel();
216        self.pending.insert(id, PendingRequest { tx });
217
218        let stream = self.stream_mut()?;
219        let mut base = BaseMessage {
220            msg_type,
221            id,
222            refers_to: 0,
223            sent: Timeval::default(),
224            received: Timeval::default(),
225            size: 0,
226        };
227        stamp_sent(&mut base);
228        write_frame(stream, &mut base, payload).await?;
229
230        tokio::time::timeout(timeout, rx)
231            .await
232            .context("request timed out")?
233            .context("response channel closed")
234    }
235
236    /// Receive the next message. If it's a response to a pending request,
237    /// deliver it to the waiting caller and receive again.
238    pub async fn recv(&mut self) -> Result<TypedMessage> {
239        loop {
240            let stream = self.stream_mut()?;
241            let msg = read_frame(stream).await?;
242
243            if msg.base.refers_to != 0
244                && let Some(pending) = self.pending.remove(&msg.base.refers_to)
245            {
246                let _ = pending.tx.send(msg);
247                continue;
248            }
249            return Ok(msg);
250        }
251    }
252}
253
254pub(super) fn stamp_sent(base: &mut BaseMessage) {
255    let tv = steady_time_of_day();
256    base.sent = tv;
257}
258
259/// Matches the C++ `chronos::steadytimeofday` — monotonic clock time.
260/// On macOS/Linux, `Instant` is based on `CLOCK_MONOTONIC` which counts
261/// seconds since boot, matching the C++ snapserver's clock domain.
262pub(super) fn steady_time_of_day() -> Timeval {
263    // Instant::now().duration_since(EPOCH) gives time since first call.
264    // We need time since boot. On Unix, Instant uses CLOCK_MONOTONIC
265    // which starts at boot. We can get this via the elapsed time from
266    // a known-early Instant.
267    let usec = monotonic_usec();
268    Timeval {
269        sec: (usec / 1_000_000) as i32,
270        usec: (usec % 1_000_000) as i32,
271    }
272}
273
274/// Microseconds since boot (monotonic clock).
275/// Uses the same clock source as C++ std::chrono::steady_clock.
276#[allow(unsafe_code)] // FFI: mach_continuous_time (macOS), clock_gettime (Linux)
277fn monotonic_usec() -> i64 {
278    #[cfg(target_os = "macos")]
279    {
280        // macOS: C++ steady_clock uses mach_continuous_time, not CLOCK_MONOTONIC.
281        // These differ by ~2s on macOS. We must match the server's clock exactly.
282        unsafe extern "C" {
283            fn mach_continuous_time() -> u64;
284            fn mach_timebase_info(info: *mut MachTimebaseInfo) -> i32;
285        }
286        #[repr(C)]
287        struct MachTimebaseInfo {
288            numer: u32,
289            denom: u32,
290        }
291        static TIMEBASE: std::sync::OnceLock<(u32, u32)> = std::sync::OnceLock::new();
292        let (numer, denom) = *TIMEBASE.get_or_init(|| {
293            let mut info = MachTimebaseInfo { numer: 0, denom: 0 };
294            unsafe {
295                mach_timebase_info(&mut info);
296            }
297            (info.numer, info.denom)
298        });
299        let ticks = unsafe { mach_continuous_time() };
300        let nanos = ticks as i128 * numer as i128 / denom as i128;
301        (nanos / 1_000) as i64
302    }
303    #[cfg(all(unix, not(target_os = "macos")))]
304    {
305        let mut ts = libc::timespec {
306            tv_sec: 0,
307            tv_nsec: 0,
308        };
309        // SAFETY: clock_gettime with CLOCK_MONOTONIC is always safe
310        unsafe {
311            libc::clock_gettime(libc::CLOCK_MONOTONIC, &mut ts);
312        }
313        ts.tv_sec * 1_000_000 + ts.tv_nsec / 1_000
314    }
315    #[cfg(not(unix))]
316    {
317        let now = std::time::SystemTime::now()
318            .duration_since(std::time::UNIX_EPOCH)
319            .unwrap_or_default();
320        now.as_micros() as i64
321    }
322}
323
324/// Current time in microseconds using the steady clock.
325pub fn now_usec() -> i64 {
326    monotonic_usec()
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use snapcast_proto::message::time::Time;

    /// Test frame read/write with in-memory buffers (no network needed).
    #[tokio::test]
    async fn write_and_read_frame() {
        let payload = MessagePayload::Time(Time {
            latency: Timeval { sec: 0, usec: 1234 },
        });
        let mut base = BaseMessage {
            msg_type: MessageType::Time,
            id: 42,
            refers_to: 0,
            sent: Timeval { sec: 1, usec: 0 },
            received: Timeval::default(),
            size: 0,
        };

        // Write to buffer
        let mut buf = Vec::new();
        write_frame(&mut buf, &mut base, &payload).await.unwrap();

        // Size should be header + payload
        assert_eq!(buf.len(), BaseMessage::HEADER_SIZE + Time::SIZE as usize);

        // Read back
        let mut cursor = std::io::Cursor::new(&buf);
        let msg = read_frame(&mut cursor).await.unwrap();
        assert_eq!(msg.base.msg_type, MessageType::Time);
        assert_eq!(msg.base.id, 42);
        match msg.payload {
            MessagePayload::Time(t) => assert_eq!(t.latency.usec, 1234),
            _ => panic!("expected Time"),
        }
    }

    #[tokio::test]
    async fn write_and_read_error_frame() {
        use snapcast_proto::message::error::Error;

        let payload = MessagePayload::Error(Error {
            code: 401,
            error: "Unauthorized".into(),
            message: "bad auth".into(),
        });
        let mut base = BaseMessage {
            msg_type: MessageType::Error,
            id: 0,
            refers_to: 7,
            sent: Timeval::default(),
            received: Timeval::default(),
            size: 0,
        };

        let mut buf = Vec::new();
        write_frame(&mut buf, &mut base, &payload).await.unwrap();

        let mut cursor = std::io::Cursor::new(&buf);
        let msg = read_frame(&mut cursor).await.unwrap();
        assert_eq!(msg.base.refers_to, 7);
        match msg.payload {
            MessagePayload::Error(e) => {
                assert_eq!(e.code, 401);
                assert_eq!(e.error, "Unauthorized");
            }
            _ => panic!("expected Error"),
        }
    }

    #[tokio::test]
    async fn write_and_read_multiple_frames() {
        let frames: Vec<(MessageType, MessagePayload)> = vec![
            (MessageType::Time, MessagePayload::Time(Time::default())),
            (
                MessageType::ClientInfo,
                MessagePayload::ClientInfo(snapcast_proto::message::client_info::ClientInfo {
                    volume: 80,
                    muted: false,
                }),
            ),
        ];

        let mut buf = Vec::new();
        for (mt, payload) in &frames {
            let mut base = BaseMessage {
                msg_type: *mt,
                id: 0,
                refers_to: 0,
                sent: Timeval::default(),
                received: Timeval::default(),
                size: 0,
            };
            write_frame(&mut buf, &mut base, payload).await.unwrap();
        }

        // Read both back
        let mut cursor = std::io::Cursor::new(&buf);
        let msg1 = read_frame(&mut cursor).await.unwrap();
        assert_eq!(msg1.base.msg_type, MessageType::Time);
        let msg2 = read_frame(&mut cursor).await.unwrap();
        assert_eq!(msg2.base.msg_type, MessageType::ClientInfo);
    }

    #[test]
    fn tcp_connection_new() {
        let conn = TcpConnection::new("localhost", 1704);
        assert!(conn.stream.is_none());
        assert_eq!(conn.host, "localhost");
        assert_eq!(conn.port, 1704);
    }

    #[test]
    fn rejects_oversized_payload() {
        let too_large = snapcast_proto::DEFAULT_MAX_PAYLOAD_SIZE + 1;
        assert!(ensure_payload_size(too_large).is_err());
    }

    /// The limit is inclusive, and an empty payload is always allowed.
    #[test]
    fn accepts_payload_up_to_limit() {
        assert!(ensure_payload_size(0).is_ok());
        assert!(ensure_payload_size(snapcast_proto::DEFAULT_MAX_PAYLOAD_SIZE).is_ok());
    }

    #[test]
    fn rejects_websocket_audio_scheme() {
        assert!(SnapConnection::new(snapcast_proto::SCHEME_WS, "localhost", 1780).is_err());
        assert!(SnapConnection::new(snapcast_proto::SCHEME_WSS, "localhost", 1788).is_err());
    }

    #[test]
    fn accepts_tcp_and_rejects_unknown_scheme() {
        assert!(matches!(
            SnapConnection::new(snapcast_proto::SCHEME_TCP, "localhost", 1704),
            Ok(SnapConnection::Tcp(_))
        ));
        assert!(SnapConnection::new("http", "localhost", 1704).is_err());
    }

    /// Every I/O entry point must fail cleanly before `connect` is called.
    #[tokio::test]
    async fn io_without_connect_fails() {
        let mut conn = TcpConnection::new("localhost", 1704);
        let payload = MessagePayload::Time(Time::default());
        assert!(conn.send(MessageType::Time, &payload).await.is_err());
        assert!(conn.recv().await.is_err());
        assert!(
            conn.send_request(MessageType::Time, &payload, Duration::from_millis(10))
                .await
                .is_err()
        );
    }
}