1use std::collections::HashMap;
4use std::sync::Arc;
5
6use anyhow::{Context, Result};
7use snapcast_proto::MessageType;
8use snapcast_proto::message::base::BaseMessage;
9use snapcast_proto::message::codec_header::CodecHeader;
10use snapcast_proto::message::factory::{self, MessagePayload, TypedMessage};
11use snapcast_proto::message::server_settings::ServerSettings;
12use snapcast_proto::message::time::Time;
13use snapcast_proto::message::wire_chunk::WireChunk;
14use snapcast_proto::types::Timeval;
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::net::{TcpListener, TcpStream};
17use tokio::sync::{Mutex, broadcast, mpsc};
18
19use crate::ClientSettingsUpdate;
20use crate::ServerEvent;
21use crate::stream::manager::WireChunkData;
22use crate::time::now_usec;
23
/// Snapshot of a stream client as recorded by the session server.
#[derive(Debug, Clone)]
pub struct ClientInfo {
    /// Identifier the client announced in its Hello message; also the registry key.
    pub id: String,
    /// Host name the client announced in its Hello message.
    pub host_name: String,
    /// MAC address string the client announced in its Hello message.
    pub mac: String,
    /// `true` while a session for this id is running; reset to `false` when it ends.
    pub connected: bool,
}
36
/// TCP server that accepts stream clients and runs one session task per connection.
pub struct SessionServer {
    /// TCP port the listener binds on (all IPv4 interfaces).
    port: u16,
    /// Buffer size in milliseconds sent to every client in its initial `ServerSettings`.
    buffer_ms: i32,
    /// Optional credential validator; when present, clients must authenticate in Hello.
    auth: Option<Arc<dyn crate::auth::AuthValidator>>,
    /// Every client seen so far, keyed by client id (entries persist after disconnect).
    clients: Arc<Mutex<HashMap<String, ClientInfo>>>,
    /// Per-client channels used to push settings updates into a running session.
    settings_senders: Arc<Mutex<HashMap<String, mpsc::Sender<ClientSettingsUpdate>>>>,
    /// Per-client channels used to push custom protocol messages into a running session.
    #[cfg(feature = "custom-protocol")]
    custom_senders: Arc<Mutex<HashMap<String, mpsc::Sender<CustomOutbound>>>>,
}
47
/// A custom protocol message queued for delivery to a single client.
#[cfg(feature = "custom-protocol")]
#[derive(Debug, Clone)]
pub struct CustomOutbound {
    /// Custom message type id, sent on the wire as `MessageType::Custom(type_id)`.
    pub type_id: u16,
    /// Raw payload bytes of the custom message.
    pub payload: Vec<u8>,
}
57
58impl SessionServer {
59 pub fn new(
61 port: u16,
62 buffer_ms: i32,
63 auth: Option<Arc<dyn crate::auth::AuthValidator>>,
64 ) -> Self {
65 Self {
66 port,
67 buffer_ms,
68 auth,
69 clients: Arc::new(Mutex::new(HashMap::new())),
70 settings_senders: Arc::new(Mutex::new(HashMap::new())),
71 #[cfg(feature = "custom-protocol")]
72 custom_senders: Arc::new(Mutex::new(HashMap::new())),
73 }
74 }
75
76 pub async fn push_settings(&self, update: ClientSettingsUpdate) {
78 let senders = self.settings_senders.lock().await;
79 if let Some(tx) = senders.get(&update.client_id) {
80 let _ = tx.send(update).await;
81 }
82 }
83
84 pub async fn run(
86 &self,
87 chunk_rx: broadcast::Sender<WireChunkData>,
88 codec: String,
89 codec_header: Vec<u8>,
90 event_tx: mpsc::Sender<ServerEvent>,
91 ) -> Result<()> {
92 let listener = TcpListener::bind(format!("0.0.0.0:{}", self.port)).await?;
93 tracing::info!(port = self.port, "Stream server listening");
94
95 loop {
96 let (stream, peer) = listener.accept().await?;
97 tracing::info!(%peer, "Client connecting");
98
99 let chunk_sub = chunk_rx.subscribe();
100 let clients = Arc::clone(&self.clients);
101 let settings_senders = Arc::clone(&self.settings_senders);
102 #[cfg(feature = "custom-protocol")]
103 let custom_senders = Arc::clone(&self.custom_senders);
104 let event_tx = event_tx.clone();
105 let buffer_ms = self.buffer_ms;
106 let auth = self.auth.clone();
107 let codec = codec.clone();
108 let codec_header = codec_header.clone();
109
110 tokio::spawn(async move {
111 let (settings_tx, settings_rx) = mpsc::channel(16);
112 #[cfg(feature = "custom-protocol")]
113 let (custom_tx, custom_rx) = mpsc::channel(64);
114 let result = handle_client(
115 stream,
116 chunk_sub,
117 settings_rx,
118 #[cfg(feature = "custom-protocol")]
119 custom_rx,
120 &clients,
121 &settings_senders,
122 #[cfg(feature = "custom-protocol")]
123 &custom_senders,
124 settings_tx,
125 #[cfg(feature = "custom-protocol")]
126 custom_tx,
127 event_tx,
128 auth.as_deref(),
129 buffer_ms,
130 &codec,
131 &codec_header,
132 )
133 .await;
134 if let Err(e) = result {
135 tracing::debug!(%peer, error = %e, "Client session ended");
136 }
137 });
138 }
139 }
140
141 pub async fn connected_clients(&self) -> Vec<ClientInfo> {
143 self.clients
144 .lock()
145 .await
146 .values()
147 .filter(|c| c.connected)
148 .cloned()
149 .collect()
150 }
151
152 #[cfg(feature = "custom-protocol")]
154 pub async fn send_custom(&self, client_id: &str, type_id: u16, payload: Vec<u8>) {
155 let senders = self.custom_senders.lock().await;
156 if let Some(tx) = senders.get(client_id) {
157 let _ = tx.send(CustomOutbound { type_id, payload }).await;
158 }
159 }
160}
161
162#[allow(clippy::too_many_arguments)]
163async fn handle_client(
164 mut stream: TcpStream,
165 chunk_rx: broadcast::Receiver<WireChunkData>,
166 settings_rx: mpsc::Receiver<ClientSettingsUpdate>,
167 #[cfg(feature = "custom-protocol")] custom_rx: mpsc::Receiver<CustomOutbound>,
168 clients: &Mutex<HashMap<String, ClientInfo>>,
169 settings_senders: &Mutex<HashMap<String, mpsc::Sender<ClientSettingsUpdate>>>,
170 #[cfg(feature = "custom-protocol")] custom_senders: &Mutex<
171 HashMap<String, mpsc::Sender<CustomOutbound>>,
172 >,
173 settings_tx: mpsc::Sender<ClientSettingsUpdate>,
174 #[cfg(feature = "custom-protocol")] custom_tx: mpsc::Sender<CustomOutbound>,
175 event_tx: mpsc::Sender<ServerEvent>,
176 auth: Option<&dyn crate::auth::AuthValidator>,
177 buffer_ms: i32,
178 codec: &str,
179 codec_header: &[u8],
180) -> Result<()> {
181 let hello_msg = read_frame_from(&mut stream).await?;
183 let hello = match hello_msg.payload {
184 MessagePayload::Hello(h) => h,
185 _ => anyhow::bail!("expected Hello, got {:?}", hello_msg.base.msg_type),
186 };
187
188 let client_id = hello.id.clone();
189 tracing::info!(id = %client_id, name = %hello.host_name, mac = %hello.mac, "Client hello");
190
191 if let Some(validator) = auth {
193 let auth_result = match &hello.auth {
194 Some(a) => validator.validate(&a.scheme, &a.param),
195 None => Err(crate::auth::AuthError::Unauthorized(
196 "Authentication required".into(),
197 )),
198 };
199 match auth_result {
200 Ok(result) => {
201 if !result
202 .permissions
203 .iter()
204 .any(|p| p == crate::auth::PERM_STREAMING)
205 {
206 let err = snapcast_proto::message::error::Error {
207 code: 403,
208 message: "Forbidden".into(),
209 error: "Permission 'Streaming' missing".into(),
210 };
211 send_msg(&mut stream, MessageType::Error, &MessagePayload::Error(err)).await?;
212 anyhow::bail!("Client {client_id}: missing Streaming permission");
213 }
214 tracing::info!(id = %client_id, user = %result.username, "Authenticated");
215 }
216 Err(e) => {
217 let err = snapcast_proto::message::error::Error {
218 code: e.code() as u32,
219 message: e.message().to_string(),
220 error: e.message().to_string(),
221 };
222 send_msg(&mut stream, MessageType::Error, &MessagePayload::Error(err)).await?;
223 anyhow::bail!("Client {client_id}: {e}");
224 }
225 }
226 }
227
228 {
230 clients.lock().await.insert(
231 client_id.clone(),
232 ClientInfo {
233 id: client_id.clone(),
234 host_name: hello.host_name.clone(),
235 mac: hello.mac.clone(),
236 connected: true,
237 },
238 );
239 settings_senders
240 .lock()
241 .await
242 .insert(client_id.clone(), settings_tx);
243 #[cfg(feature = "custom-protocol")]
244 custom_senders
245 .lock()
246 .await
247 .insert(client_id.clone(), custom_tx);
248 }
249
250 let _ = event_tx
251 .send(ServerEvent::ClientConnected {
252 id: client_id.clone(),
253 name: hello.host_name.clone(),
254 })
255 .await;
256
257 let ss = ServerSettings {
259 buffer_ms,
260 latency: 0,
261 volume: 100,
262 muted: false,
263 };
264 send_msg(
265 &mut stream,
266 MessageType::ServerSettings,
267 &MessagePayload::ServerSettings(ss),
268 )
269 .await?;
270
271 let ch = CodecHeader {
273 codec: codec.to_string(),
274 payload: codec_header.to_vec(),
275 };
276 send_msg(
277 &mut stream,
278 MessageType::CodecHeader,
279 &MessagePayload::CodecHeader(ch),
280 )
281 .await?;
282
283 let result = session_loop(
285 &mut stream,
286 chunk_rx,
287 settings_rx,
288 #[cfg(feature = "custom-protocol")]
289 custom_rx,
290 #[cfg(feature = "custom-protocol")]
291 event_tx.clone(),
292 #[cfg(feature = "custom-protocol")]
293 client_id.clone(),
294 )
295 .await;
296
297 {
299 let mut map = clients.lock().await;
300 if let Some(c) = map.get_mut(&client_id) {
301 c.connected = false;
302 }
303 }
304 settings_senders.lock().await.remove(&client_id);
305 #[cfg(feature = "custom-protocol")]
306 custom_senders.lock().await.remove(&client_id);
307 let _ = event_tx
308 .send(ServerEvent::ClientDisconnected { id: client_id })
309 .await;
310
311 result
312}
313
314async fn session_loop(
315 stream: &mut TcpStream,
316 mut chunk_rx: broadcast::Receiver<WireChunkData>,
317 mut settings_rx: mpsc::Receiver<ClientSettingsUpdate>,
318 #[cfg(feature = "custom-protocol")] mut custom_rx: mpsc::Receiver<CustomOutbound>,
319 #[cfg(feature = "custom-protocol")] event_tx: mpsc::Sender<ServerEvent>,
320 #[cfg(feature = "custom-protocol")] client_id: String,
321) -> Result<()> {
322 let (mut reader, mut writer) = stream.split();
323
324 #[cfg(not(feature = "custom-protocol"))]
325 let (mut custom_rx, _event_tx, _client_id): (mpsc::Receiver<()>, Option<()>, String) = {
326 let (_tx, rx) = mpsc::channel(1);
327 (rx, None, String::new())
328 };
329
330 loop {
331 tokio::select! {
332 chunk = chunk_rx.recv() => {
333 let chunk = chunk.context("broadcast closed")?;
334 let ts_usec = chunk.timestamp_usec;
335 let wc = WireChunk {
336 timestamp: Timeval::from_usec(ts_usec),
337 payload: chunk.data,
338 };
339 let frame = serialize_msg(MessageType::WireChunk, &MessagePayload::WireChunk(wc), 0)?;
340 writer.write_all(&frame).await.context("write chunk")?;
341 }
342 msg = read_frame_from(&mut reader) => {
343 let msg = msg?;
344 match msg.payload {
345 MessagePayload::Time(t) => {
346 let response = Time { latency: t.latency };
347 let frame = serialize_msg(MessageType::Time, &MessagePayload::Time(response), msg.base.id)?;
348 writer.write_all(&frame).await.context("write time")?;
349 }
350 #[cfg(feature = "custom-protocol")]
351 MessagePayload::Custom(payload) => {
352 if let MessageType::Custom(type_id) = msg.base.msg_type {
353 let _ = event_tx.send(ServerEvent::CustomMessage {
354 client_id: client_id.clone(),
355 message: snapcast_proto::CustomMessage::new(type_id, payload),
356 }).await;
357 }
358 }
359 _ => {}
360 }
361 }
362 update = settings_rx.recv() => {
363 let Some(update) = update else { continue };
364 let ss = ServerSettings {
365 buffer_ms: update.buffer_ms,
366 latency: update.latency,
367 volume: update.volume,
368 muted: update.muted,
369 };
370 let frame = serialize_msg(
371 MessageType::ServerSettings,
372 &MessagePayload::ServerSettings(ss),
373 0,
374 )?;
375 writer.write_all(&frame).await.context("write settings")?;
376 tracing::debug!(volume = update.volume, latency = update.latency, "Pushed settings to client");
377 }
378 outbound = custom_rx.recv() => {
379 #[cfg(feature = "custom-protocol")]
380 if let Some(msg) = outbound {
381 let frame = serialize_msg(
382 MessageType::Custom(msg.type_id),
383 &MessagePayload::Custom(msg.payload),
384 0,
385 )?;
386 writer.write_all(&frame).await.context("write custom")?;
387 }
388 #[cfg(not(feature = "custom-protocol"))]
389 let _ = outbound;
390 }
391 }
392 }
393}
394
395fn serialize_msg(
396 msg_type: MessageType,
397 payload: &MessagePayload,
398 refers_to: u16,
399) -> Result<Vec<u8>> {
400 let mut base = BaseMessage {
401 msg_type,
402 id: 0,
403 refers_to,
404 sent: now_timeval(),
405 received: Timeval::default(),
406 size: 0,
407 };
408 factory::serialize(&mut base, payload).map_err(|e| anyhow::anyhow!("serialize: {e}"))
409}
410
411async fn send_msg(
412 stream: &mut TcpStream,
413 msg_type: MessageType,
414 payload: &MessagePayload,
415) -> Result<()> {
416 let frame = serialize_msg(msg_type, payload, 0)?;
417 stream.write_all(&frame).await.context("write message")
418}
419
420async fn read_frame_from<R: AsyncReadExt + Unpin>(reader: &mut R) -> Result<TypedMessage> {
421 let mut header_buf = [0u8; BaseMessage::HEADER_SIZE];
422 reader
423 .read_exact(&mut header_buf)
424 .await
425 .context("read header")?;
426 let mut base =
427 BaseMessage::read_from(&mut &header_buf[..]).map_err(|e| anyhow::anyhow!("parse: {e}"))?;
428 base.received = now_timeval();
429 let mut payload_buf = vec![0u8; base.size as usize];
430 if !payload_buf.is_empty() {
431 reader
432 .read_exact(&mut payload_buf)
433 .await
434 .context("read payload")?;
435 }
436 factory::deserialize(base, &payload_buf).map_err(|e| anyhow::anyhow!("deserialize: {e}"))
437}
438
439fn now_timeval() -> Timeval {
440 Timeval::from_usec(now_usec())
441}