// snapcast_client/connection/mod.rs

#[cfg(feature = "websocket")]
pub mod ws;
#[cfg(feature = "tls")]
pub mod wss;

use std::collections::{HashMap, VecDeque};
use std::time::Duration;

use anyhow::{Context, Result};
use snapcast_proto::MessageType;
use snapcast_proto::message::base::BaseMessage;
use snapcast_proto::message::factory::{self, MessagePayload, TypedMessage};
use snapcast_proto::types::Timeval;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::oneshot;

20async fn read_frame<R: AsyncReadExt + Unpin>(reader: &mut R) -> Result<TypedMessage> {
22 let mut header_buf = [0u8; BaseMessage::HEADER_SIZE];
24 reader
25 .read_exact(&mut header_buf)
26 .await
27 .context("reading base message header")?;
28
29 let mut base = BaseMessage::read_from(&mut &header_buf[..])
30 .map_err(|e| anyhow::anyhow!("parsing header: {e}"))?;
31
32 base.received = steady_time_of_day();
34
35 let mut payload_buf = vec![0u8; base.size as usize];
37 if !payload_buf.is_empty() {
38 reader
39 .read_exact(&mut payload_buf)
40 .await
41 .context("reading payload")?;
42 }
43
44 factory::deserialize(base, &payload_buf).map_err(|e| anyhow::anyhow!("deserializing: {e}"))
45}
46
47async fn write_frame<W: AsyncWriteExt + Unpin>(
49 writer: &mut W,
50 base: &mut BaseMessage,
51 payload: &MessagePayload,
52) -> Result<()> {
53 let frame =
54 factory::serialize(base, payload).map_err(|e| anyhow::anyhow!("serializing: {e}"))?;
55 writer.write_all(&frame).await.context("writing frame")?;
56 Ok(())
57}
58
59struct PendingRequest {
61 tx: oneshot::Sender<TypedMessage>,
62}
63
64pub struct TcpConnection {
66 stream: Option<TcpStream>,
67 host: String,
68 port: u16,
69 pending: HashMap<u16, PendingRequest>,
70 next_id: u16,
71}
72
73impl TcpConnection {
74 pub fn new(host: &str, port: u16) -> Self {
76 Self {
77 stream: None,
78 host: host.to_string(),
79 port,
80 pending: HashMap::new(),
81 next_id: 1,
82 }
83 }
84
85 pub async fn connect(&mut self) -> Result<()> {
87 let addr = format!("{}:{}", self.host, self.port);
88 let stream = TcpStream::connect(&addr)
89 .await
90 .with_context(|| format!("connecting to {addr}"))?;
91 self.stream = Some(stream);
92 self.pending.clear();
93 self.next_id = 1;
94 Ok(())
95 }
96
97 pub fn disconnect(&mut self) {
99 self.stream = None;
100 self.pending.clear();
101 }
102
103 fn stream_mut(&mut self) -> Result<&mut TcpStream> {
104 self.stream.as_mut().context("not connected")
105 }
106
107 pub async fn send(&mut self, msg_type: MessageType, payload: &MessagePayload) -> Result<()> {
109 let stream = self.stream_mut()?;
110 let mut base = BaseMessage {
111 msg_type,
112 id: 0,
113 refers_to: 0,
114 sent: Timeval::default(),
115 received: Timeval::default(),
116 size: 0,
117 };
118 stamp_sent(&mut base);
119 write_frame(stream, &mut base, payload).await
120 }
121
122 pub async fn send_request(
124 &mut self,
125 msg_type: MessageType,
126 payload: &MessagePayload,
127 timeout: Duration,
128 ) -> Result<TypedMessage> {
129 let id = self.next_id;
130 self.next_id = self.next_id.wrapping_add(1);
131
132 let (tx, rx) = oneshot::channel();
133 self.pending.insert(id, PendingRequest { tx });
134
135 let stream = self.stream_mut()?;
136 let mut base = BaseMessage {
137 msg_type,
138 id,
139 refers_to: 0,
140 sent: Timeval::default(),
141 received: Timeval::default(),
142 size: 0,
143 };
144 stamp_sent(&mut base);
145 write_frame(stream, &mut base, payload).await?;
146
147 tokio::time::timeout(timeout, rx)
148 .await
149 .context("request timed out")?
150 .context("response channel closed")
151 }
152
153 pub async fn recv(&mut self) -> Result<TypedMessage> {
156 loop {
157 let stream = self.stream_mut()?;
158 let msg = read_frame(stream).await?;
159
160 if msg.base.refers_to != 0
161 && let Some(pending) = self.pending.remove(&msg.base.refers_to)
162 {
163 let _ = pending.tx.send(msg);
164 continue;
165 }
166 return Ok(msg);
167 }
168 }
169}
170
171fn stamp_sent(base: &mut BaseMessage) {
172 let tv = steady_time_of_day();
173 base.sent = tv;
174}
175
176fn steady_time_of_day() -> Timeval {
180 let usec = monotonic_usec();
185 Timeval {
186 sec: (usec / 1_000_000) as i32,
187 usec: (usec % 1_000_000) as i32,
188 }
189}
190
191#[allow(unsafe_code)] fn monotonic_usec() -> i64 {
195 #[cfg(target_os = "macos")]
196 {
197 unsafe extern "C" {
200 fn mach_continuous_time() -> u64;
201 fn mach_timebase_info(info: *mut MachTimebaseInfo) -> i32;
202 }
203 #[repr(C)]
204 struct MachTimebaseInfo {
205 numer: u32,
206 denom: u32,
207 }
208 static TIMEBASE: std::sync::OnceLock<(u32, u32)> = std::sync::OnceLock::new();
209 let (numer, denom) = *TIMEBASE.get_or_init(|| {
210 let mut info = MachTimebaseInfo { numer: 0, denom: 0 };
211 unsafe {
212 mach_timebase_info(&mut info);
213 }
214 (info.numer, info.denom)
215 });
216 let ticks = unsafe { mach_continuous_time() };
217 let nanos = ticks as i128 * numer as i128 / denom as i128;
218 (nanos / 1_000) as i64
219 }
220 #[cfg(all(unix, not(target_os = "macos")))]
221 {
222 let mut ts = libc::timespec {
223 tv_sec: 0,
224 tv_nsec: 0,
225 };
226 unsafe {
228 libc::clock_gettime(libc::CLOCK_MONOTONIC, &mut ts);
229 }
230 ts.tv_sec * 1_000_000 + ts.tv_nsec / 1_000
231 }
232 #[cfg(not(unix))]
233 {
234 let now = std::time::SystemTime::now()
235 .duration_since(std::time::UNIX_EPOCH)
236 .unwrap_or_default();
237 now.as_micros() as i64
238 }
239}
240
241pub fn now_usec() -> i64 {
243 monotonic_usec()
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use snapcast_proto::message::time::Time;

    /// Header with the given routing fields and `sent` stamp; `received` and
    /// `size` start zeroed.
    fn header(msg_type: MessageType, id: u16, refers_to: u16, sent: Timeval) -> BaseMessage {
        BaseMessage {
            msg_type,
            id,
            refers_to,
            sent,
            received: Timeval::default(),
            size: 0,
        }
    }

    /// A Time frame survives a write/read round trip with header and payload intact.
    #[tokio::test]
    async fn write_and_read_frame() {
        let payload = MessagePayload::Time(Time {
            latency: Timeval { sec: 0, usec: 1234 },
        });
        let mut base = header(MessageType::Time, 42, 0, Timeval { sec: 1, usec: 0 });

        let mut wire = Vec::new();
        write_frame(&mut wire, &mut base, &payload).await.unwrap();
        // Exactly one header plus the fixed-size Time body.
        assert_eq!(wire.len(), BaseMessage::HEADER_SIZE + Time::SIZE as usize);

        let mut reader = std::io::Cursor::new(&wire);
        let msg = read_frame(&mut reader).await.unwrap();
        assert_eq!(msg.base.msg_type, MessageType::Time);
        assert_eq!(msg.base.id, 42);
        let MessagePayload::Time(time) = msg.payload else {
            panic!("expected Time");
        };
        assert_eq!(time.latency.usec, 1234);
    }

    /// An Error reply keeps its `refers_to` and its variable-length payload.
    #[tokio::test]
    async fn write_and_read_error_frame() {
        use snapcast_proto::message::error::Error;

        let payload = MessagePayload::Error(Error {
            code: 401,
            error: "Unauthorized".into(),
            message: "bad auth".into(),
        });
        let mut base = header(MessageType::Error, 0, 7, Timeval::default());

        let mut wire = Vec::new();
        write_frame(&mut wire, &mut base, &payload).await.unwrap();

        let mut reader = std::io::Cursor::new(&wire);
        let msg = read_frame(&mut reader).await.unwrap();
        assert_eq!(msg.base.refers_to, 7);
        let MessagePayload::Error(err) = msg.payload else {
            panic!("expected Error");
        };
        assert_eq!(err.code, 401);
        assert_eq!(err.error, "Unauthorized");
    }

    /// Back-to-back frames are read out in order without bleeding into each other.
    #[tokio::test]
    async fn write_and_read_multiple_frames() {
        let frames = [
            (MessageType::Time, MessagePayload::Time(Time::default())),
            (
                MessageType::ClientInfo,
                MessagePayload::ClientInfo(snapcast_proto::message::client_info::ClientInfo {
                    volume: 80,
                    muted: false,
                }),
            ),
        ];

        let mut wire = Vec::new();
        for (msg_type, payload) in &frames {
            let mut base = header(*msg_type, 0, 0, Timeval::default());
            write_frame(&mut wire, &mut base, payload).await.unwrap();
        }

        let mut reader = std::io::Cursor::new(&wire);
        for (expected_type, _) in &frames {
            let msg = read_frame(&mut reader).await.unwrap();
            assert_eq!(msg.base.msg_type, *expected_type);
        }
    }

    /// A fresh connection remembers its target and holds no socket yet.
    #[test]
    fn tcp_connection_new() {
        let conn = TcpConnection::new("localhost", 1704);
        assert!(conn.stream.is_none());
        assert_eq!((conn.host.as_str(), conn.port), ("localhost", 1704));
    }
}