1use std::collections::HashMap;
4use std::sync::Arc;
5
6use anyhow::{Context, Result};
7use snapcast_proto::MessageType;
8use snapcast_proto::message::base::BaseMessage;
9use snapcast_proto::message::codec_header::CodecHeader;
10use snapcast_proto::message::factory::{self, MessagePayload, TypedMessage};
11use snapcast_proto::message::server_settings::ServerSettings;
12use snapcast_proto::message::time::Time;
13use snapcast_proto::message::wire_chunk::WireChunk;
14use snapcast_proto::types::Timeval;
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::net::{TcpListener, TcpStream};
17use tokio::sync::{Mutex, broadcast, mpsc};
18
19use crate::ClientSettingsUpdate;
20use crate::ServerEvent;
21use crate::WireChunkData;
22use crate::time::now_usec;
23
/// Registry entry for a client that has completed the `Hello` handshake.
///
/// Entries are not removed when the client goes away; instead `connected`
/// is set to `false`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClientInfo {
    /// Client id from its `Hello` message; also the key in the client registry.
    pub id: String,
    /// Host name reported in the client's `Hello` message.
    pub host_name: String,
    /// MAC address reported in the client's `Hello` message.
    pub mac: String,
    /// Whether the client currently has an active session.
    pub connected: bool,
}
36
37pub struct SessionServer {
39 port: u16,
40 buffer_ms: i32,
41 auth: Option<Arc<dyn crate::auth::AuthValidator>>,
42 clients: Arc<Mutex<HashMap<String, ClientInfo>>>,
43 settings_senders: Arc<Mutex<HashMap<String, mpsc::Sender<ClientSettingsUpdate>>>>,
44 #[cfg(feature = "custom-protocol")]
45 custom_senders: Arc<Mutex<HashMap<String, mpsc::Sender<CustomOutbound>>>>,
46}
47
/// A custom-protocol message queued for delivery to a single client.
#[cfg(feature = "custom-protocol")]
#[derive(Debug, Clone)]
pub struct CustomOutbound {
    /// Custom message type id; framed as `MessageType::Custom(type_id)`.
    pub type_id: u16,
    /// Raw message body; framed as `MessagePayload::Custom(payload)`.
    pub payload: Vec<u8>,
}
57
58impl SessionServer {
59 pub fn new(
61 port: u16,
62 buffer_ms: i32,
63 auth: Option<Arc<dyn crate::auth::AuthValidator>>,
64 ) -> Self {
65 Self {
66 port,
67 buffer_ms,
68 auth,
69 clients: Arc::new(Mutex::new(HashMap::new())),
70 settings_senders: Arc::new(Mutex::new(HashMap::new())),
71 #[cfg(feature = "custom-protocol")]
72 custom_senders: Arc::new(Mutex::new(HashMap::new())),
73 }
74 }
75
76 pub async fn push_settings(&self, update: ClientSettingsUpdate) {
78 let senders = self.settings_senders.lock().await;
79 if let Some(tx) = senders.get(&update.client_id) {
80 let _ = tx.send(update).await;
81 }
82 }
83
84 pub async fn run(
86 &self,
87 chunk_rx: broadcast::Sender<WireChunkData>,
88 codec: String,
89 codec_header: Vec<u8>,
90 event_tx: mpsc::Sender<ServerEvent>,
91 ) -> Result<()> {
92 let listener = TcpListener::bind(format!("0.0.0.0:{}", self.port)).await?;
93 tracing::info!(port = self.port, "Stream server listening");
94
95 loop {
96 let (stream, peer) = listener.accept().await?;
97 tracing::info!(%peer, "Client connecting");
98
99 let chunk_sub = chunk_rx.subscribe();
100 let clients = Arc::clone(&self.clients);
101 let settings_senders = Arc::clone(&self.settings_senders);
102 #[cfg(feature = "custom-protocol")]
103 let custom_senders = Arc::clone(&self.custom_senders);
104 let event_tx = event_tx.clone();
105 let buffer_ms = self.buffer_ms;
106 let auth = self.auth.clone();
107 let codec = codec.clone();
108 let codec_header = codec_header.clone();
109
110 tokio::spawn(async move {
111 let (settings_tx, settings_rx) = mpsc::channel(16);
112 #[cfg(feature = "custom-protocol")]
113 let (custom_tx, custom_rx) = mpsc::channel(64);
114 let result = handle_client(
115 stream,
116 chunk_sub,
117 settings_rx,
118 #[cfg(feature = "custom-protocol")]
119 custom_rx,
120 &clients,
121 &settings_senders,
122 #[cfg(feature = "custom-protocol")]
123 &custom_senders,
124 settings_tx,
125 #[cfg(feature = "custom-protocol")]
126 custom_tx,
127 event_tx,
128 auth.as_deref(),
129 buffer_ms,
130 &codec,
131 &codec_header,
132 )
133 .await;
134 if let Err(e) = result {
135 tracing::debug!(%peer, error = %e, "Client session ended");
136 }
137 });
138 }
139 }
140
141 pub async fn connected_clients(&self) -> Vec<ClientInfo> {
143 self.clients
144 .lock()
145 .await
146 .values()
147 .filter(|c| c.connected)
148 .cloned()
149 .collect()
150 }
151
152 #[cfg(feature = "custom-protocol")]
154 pub async fn send_custom(&self, client_id: &str, type_id: u16, payload: Vec<u8>) {
155 let senders = self.custom_senders.lock().await;
156 if let Some(tx) = senders.get(client_id) {
157 let _ = tx.send(CustomOutbound { type_id, payload }).await;
158 }
159 }
160}
161
162#[allow(clippy::too_many_arguments)]
163async fn handle_client(
164 mut stream: TcpStream,
165 chunk_rx: broadcast::Receiver<WireChunkData>,
166 settings_rx: mpsc::Receiver<ClientSettingsUpdate>,
167 #[cfg(feature = "custom-protocol")] custom_rx: mpsc::Receiver<CustomOutbound>,
168 clients: &Mutex<HashMap<String, ClientInfo>>,
169 settings_senders: &Mutex<HashMap<String, mpsc::Sender<ClientSettingsUpdate>>>,
170 #[cfg(feature = "custom-protocol")] custom_senders: &Mutex<
171 HashMap<String, mpsc::Sender<CustomOutbound>>,
172 >,
173 settings_tx: mpsc::Sender<ClientSettingsUpdate>,
174 #[cfg(feature = "custom-protocol")] custom_tx: mpsc::Sender<CustomOutbound>,
175 event_tx: mpsc::Sender<ServerEvent>,
176 auth: Option<&dyn crate::auth::AuthValidator>,
177 buffer_ms: i32,
178 codec: &str,
179 codec_header: &[u8],
180) -> Result<()> {
181 let hello_msg = read_frame_from(&mut stream).await?;
183 let hello = match hello_msg.payload {
184 MessagePayload::Hello(h) => h,
185 _ => anyhow::bail!("expected Hello, got {:?}", hello_msg.base.msg_type),
186 };
187
188 let client_id = hello.id.clone();
189 tracing::info!(id = %client_id, name = %hello.host_name, mac = %hello.mac, "Client hello");
190
191 if let Some(validator) = auth {
193 let auth_result = match &hello.auth {
194 Some(a) => validator.validate(&a.scheme, &a.param),
195 None => Err(crate::auth::AuthError::Unauthorized(
196 "Authentication required".into(),
197 )),
198 };
199 match auth_result {
200 Ok(result) => {
201 if !result
202 .permissions
203 .iter()
204 .any(|p| p == crate::auth::PERM_STREAMING)
205 {
206 let err = snapcast_proto::message::error::Error {
207 code: 403,
208 message: "Forbidden".into(),
209 error: "Permission 'Streaming' missing".into(),
210 };
211 send_msg(&mut stream, MessageType::Error, &MessagePayload::Error(err)).await?;
212 anyhow::bail!("Client {client_id}: missing Streaming permission");
213 }
214 tracing::info!(id = %client_id, user = %result.username, "Authenticated");
215 }
216 Err(e) => {
217 let err = snapcast_proto::message::error::Error {
218 code: e.code() as u32,
219 message: e.message().to_string(),
220 error: e.message().to_string(),
221 };
222 send_msg(&mut stream, MessageType::Error, &MessagePayload::Error(err)).await?;
223 anyhow::bail!("Client {client_id}: {e}");
224 }
225 }
226 }
227
228 {
230 clients.lock().await.insert(
231 client_id.clone(),
232 ClientInfo {
233 id: client_id.clone(),
234 host_name: hello.host_name.clone(),
235 mac: hello.mac.clone(),
236 connected: true,
237 },
238 );
239 settings_senders
240 .lock()
241 .await
242 .insert(client_id.clone(), settings_tx);
243 #[cfg(feature = "custom-protocol")]
244 custom_senders
245 .lock()
246 .await
247 .insert(client_id.clone(), custom_tx);
248 }
249
250 let _ = event_tx
251 .send(ServerEvent::ClientConnected {
252 id: client_id.clone(),
253 name: hello.host_name.clone(),
254 mac: hello.mac.clone(),
255 })
256 .await;
257
258 let ss = ServerSettings {
260 buffer_ms,
261 latency: 0,
262 volume: 100,
263 muted: false,
264 };
265 send_msg(
266 &mut stream,
267 MessageType::ServerSettings,
268 &MessagePayload::ServerSettings(ss),
269 )
270 .await?;
271
272 let ch = CodecHeader {
274 codec: codec.to_string(),
275 payload: codec_header.to_vec(),
276 };
277 send_msg(
278 &mut stream,
279 MessageType::CodecHeader,
280 &MessagePayload::CodecHeader(ch),
281 )
282 .await?;
283
284 let result = session_loop(
286 &mut stream,
287 chunk_rx,
288 settings_rx,
289 #[cfg(feature = "custom-protocol")]
290 custom_rx,
291 #[cfg(feature = "custom-protocol")]
292 event_tx.clone(),
293 #[cfg(feature = "custom-protocol")]
294 client_id.clone(),
295 )
296 .await;
297
298 {
300 let mut map = clients.lock().await;
301 if let Some(c) = map.get_mut(&client_id) {
302 c.connected = false;
303 }
304 }
305 settings_senders.lock().await.remove(&client_id);
306 #[cfg(feature = "custom-protocol")]
307 custom_senders.lock().await.remove(&client_id);
308 let _ = event_tx
309 .send(ServerEvent::ClientDisconnected { id: client_id })
310 .await;
311
312 result
313}
314
315async fn session_loop(
316 stream: &mut TcpStream,
317 mut chunk_rx: broadcast::Receiver<WireChunkData>,
318 mut settings_rx: mpsc::Receiver<ClientSettingsUpdate>,
319 #[cfg(feature = "custom-protocol")] mut custom_rx: mpsc::Receiver<CustomOutbound>,
320 #[cfg(feature = "custom-protocol")] event_tx: mpsc::Sender<ServerEvent>,
321 #[cfg(feature = "custom-protocol")] client_id: String,
322) -> Result<()> {
323 let (mut reader, mut writer) = stream.split();
324
325 #[cfg(not(feature = "custom-protocol"))]
326 let (mut custom_rx, _event_tx, _client_id): (mpsc::Receiver<()>, Option<()>, String) = {
327 let (_tx, rx) = mpsc::channel(1);
328 (rx, None, String::new())
329 };
330
331 loop {
332 tokio::select! {
333 chunk = chunk_rx.recv() => {
334 let chunk = chunk.context("broadcast closed")?;
335 let ts_usec = chunk.timestamp_usec;
336 let wc = WireChunk {
337 timestamp: Timeval::from_usec(ts_usec),
338 payload: chunk.data,
339 };
340 let frame = serialize_msg(MessageType::WireChunk, &MessagePayload::WireChunk(wc), 0)?;
341 writer.write_all(&frame).await.context("write chunk")?;
342 }
343 msg = read_frame_from(&mut reader) => {
344 let msg = msg?;
345 match msg.payload {
346 MessagePayload::Time(t) => {
347 let response = Time { latency: t.latency };
348 let frame = serialize_msg(MessageType::Time, &MessagePayload::Time(response), msg.base.id)?;
349 writer.write_all(&frame).await.context("write time")?;
350 }
351 #[cfg(feature = "custom-protocol")]
352 MessagePayload::Custom(payload) => {
353 if let MessageType::Custom(type_id) = msg.base.msg_type {
354 let _ = event_tx.send(ServerEvent::CustomMessage {
355 client_id: client_id.clone(),
356 message: snapcast_proto::CustomMessage::new(type_id, payload),
357 }).await;
358 }
359 }
360 _ => {}
361 }
362 }
363 update = settings_rx.recv() => {
364 let Some(update) = update else { continue };
365 let ss = ServerSettings {
366 buffer_ms: update.buffer_ms,
367 latency: update.latency,
368 volume: update.volume,
369 muted: update.muted,
370 };
371 let frame = serialize_msg(
372 MessageType::ServerSettings,
373 &MessagePayload::ServerSettings(ss),
374 0,
375 )?;
376 writer.write_all(&frame).await.context("write settings")?;
377 tracing::debug!(volume = update.volume, latency = update.latency, "Pushed settings to client");
378 }
379 outbound = custom_rx.recv() => {
380 #[cfg(feature = "custom-protocol")]
381 if let Some(msg) = outbound {
382 let frame = serialize_msg(
383 MessageType::Custom(msg.type_id),
384 &MessagePayload::Custom(msg.payload),
385 0,
386 )?;
387 writer.write_all(&frame).await.context("write custom")?;
388 }
389 #[cfg(not(feature = "custom-protocol"))]
390 let _ = outbound;
391 }
392 }
393 }
394}
395
396fn serialize_msg(
397 msg_type: MessageType,
398 payload: &MessagePayload,
399 refers_to: u16,
400) -> Result<Vec<u8>> {
401 let mut base = BaseMessage {
402 msg_type,
403 id: 0,
404 refers_to,
405 sent: now_timeval(),
406 received: Timeval::default(),
407 size: 0,
408 };
409 factory::serialize(&mut base, payload).map_err(|e| anyhow::anyhow!("serialize: {e}"))
410}
411
412async fn send_msg(
413 stream: &mut TcpStream,
414 msg_type: MessageType,
415 payload: &MessagePayload,
416) -> Result<()> {
417 let frame = serialize_msg(msg_type, payload, 0)?;
418 stream.write_all(&frame).await.context("write message")
419}
420
421async fn read_frame_from<R: AsyncReadExt + Unpin>(reader: &mut R) -> Result<TypedMessage> {
422 let mut header_buf = [0u8; BaseMessage::HEADER_SIZE];
423 reader
424 .read_exact(&mut header_buf)
425 .await
426 .context("read header")?;
427 let mut base =
428 BaseMessage::read_from(&mut &header_buf[..]).map_err(|e| anyhow::anyhow!("parse: {e}"))?;
429 base.received = now_timeval();
430 let mut payload_buf = vec![0u8; base.size as usize];
431 if !payload_buf.is_empty() {
432 reader
433 .read_exact(&mut payload_buf)
434 .await
435 .context("read payload")?;
436 }
437 factory::deserialize(base, &payload_buf).map_err(|e| anyhow::anyhow!("deserialize: {e}"))
438}
439
440fn now_timeval() -> Timeval {
441 Timeval::from_usec(now_usec())
442}