Skip to main content

snapcast_server/
auth.rs

1//! Streaming client authentication.
2//!
3//! Implement [`AuthValidator`] for custom authentication (database, LDAP, etc.)
4//! or use [`StaticAuthValidator`] for config-file-based users/roles.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use subtle::ConstantTimeEq;
10
11/// Constant-time byte comparison to prevent timing attacks.
12fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
13    a.ct_eq(b).into()
14}
15
/// Result of successful authentication.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthResult {
    /// Authenticated username.
    pub username: String,
    /// Granted permissions (e.g. "Streaming", "Control").
    pub permissions: Vec<String>,
}

impl AuthResult {
    /// Whether `permission` is among the granted [`permissions`](Self::permissions).
    ///
    /// The comparison is exact and case-sensitive.
    #[must_use]
    pub fn has_permission(&self, permission: &str) -> bool {
        self.permissions.iter().any(|p| p == permission)
    }
}
24
/// Authentication error.
///
/// Each variant carries a human-readable message; [`AuthError::code`] maps the
/// variant to an HTTP-style status code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthError {
    /// 401 — invalid or missing credentials.
    Unauthorized(String),
    /// 403 — authenticated but lacking required permission.
    Forbidden(String),
}

impl std::fmt::Display for AuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Unauthorized(msg) => write!(f, "Unauthorized: {msg}"),
            Self::Forbidden(msg) => write!(f, "Forbidden: {msg}"),
        }
    }
}

impl std::error::Error for AuthError {}

impl AuthError {
    /// HTTP-style error code: `401` for [`Unauthorized`](Self::Unauthorized),
    /// `403` for [`Forbidden`](Self::Forbidden).
    #[must_use]
    pub fn code(&self) -> i32 {
        match self {
            Self::Unauthorized(_) => 401,
            Self::Forbidden(_) => 403,
        }
    }

    /// The message carried by the error, without the `Unauthorized: ` /
    /// `Forbidden: ` prefix that the [`Display`](std::fmt::Display) impl adds.
    #[must_use]
    pub fn message(&self) -> &str {
        match self {
            Self::Unauthorized(msg) | Self::Forbidden(msg) => msg,
        }
    }
}
61
62/// Trait for validating streaming client credentials.
63///
64/// The server calls [`validate`](AuthValidator::validate) after receiving a Hello message.
65/// Return [`AuthResult`] on success or [`AuthError`] on failure.
66///
67/// # Example: Custom validator
68///
69/// ```
70/// use snapcast_server::auth::{AuthValidator, AuthResult, AuthError};
71///
72/// struct MyValidator;
73///
74/// impl AuthValidator for MyValidator {
75///     fn validate(&self, scheme: &str, param: &str) -> Result<AuthResult, AuthError> {
76///         // Look up in database, LDAP, etc.
77///         Ok(AuthResult {
78///             username: "user".into(),
79///             permissions: vec!["Streaming".into()],
80///         })
81///     }
82/// }
83/// ```
84/// Trait for validating streaming client credentials.
85pub trait AuthValidator: Send + Sync {
86    /// Validate credentials from the Hello message's auth field.
87    fn validate(&self, scheme: &str, param: &str) -> Result<AuthResult, AuthError>;
88}
89
/// Permission required for streaming clients.
///
/// [`StaticAuthValidator`] only reports a user's permissions; it does not check
/// for this one itself. NOTE(review): presumably the server checks it after
/// validation and answers with [`AuthError::Forbidden`] when it is missing —
/// confirm against the caller.
pub const PERM_STREAMING: &str = "Streaming";

/// A role with named permissions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Role {
    /// Role name, referenced by [`User::role`].
    pub name: String,
    /// Granted permissions.
    pub permissions: Vec<String>,
}
101
/// A user with credentials and role assignment.
///
/// The `Debug` output redacts the password so a `User` can be logged safely.
#[derive(Clone)]
pub struct User {
    /// Username.
    pub name: String,
    /// Password (plaintext — hashing is the deployer's responsibility).
    pub password: String,
    /// Role name, resolved against [`Role::name`].
    pub role: String,
}

impl std::fmt::Debug for User {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The password is stored in plaintext; keep it out of logs and panic messages.
        f.debug_struct("User")
            .field("name", &self.name)
            .field("password", &format_args!("<redacted>"))
            .field("role", &self.role)
            .finish()
    }
}
112
113/// Config-file-based authentication matching the C++ implementation.
114///
115/// Validates Basic auth (`base64(user:password)`) against a static user/role list.
116#[derive(Debug, Clone)]
117pub struct StaticAuthValidator {
118    users: HashMap<String, (String, Arc<Role>)>, // name → (password, role)
119}
120
121impl StaticAuthValidator {
122    /// Create from user and role lists.
123    pub fn new(users: Vec<User>, roles: Vec<Role>) -> Self {
124        let role_map: HashMap<String, Arc<Role>> = roles
125            .into_iter()
126            .map(|r| (r.name.clone(), Arc::new(r)))
127            .collect();
128        let empty_role = Arc::new(Role {
129            name: String::new(),
130            permissions: vec![],
131        });
132        let user_map = users
133            .into_iter()
134            .map(|u| {
135                let role = role_map
136                    .get(&u.role)
137                    .cloned()
138                    .unwrap_or_else(|| empty_role.clone());
139                (u.name, (u.password, role))
140            })
141            .collect();
142        Self { users: user_map }
143    }
144}
145
146impl AuthValidator for StaticAuthValidator {
147    fn validate(&self, scheme: &str, param: &str) -> Result<AuthResult, AuthError> {
148        if !scheme.eq_ignore_ascii_case("basic") {
149            return Err(AuthError::Unauthorized(format!(
150                "Unsupported auth scheme: {scheme}"
151            )));
152        }
153
154        // Decode base64(user:password)
155        use base64::Engine;
156        let decoded = base64::engine::general_purpose::STANDARD
157            .decode(param)
158            .map_err(|_| AuthError::Unauthorized("Invalid base64".into()))?;
159        let decoded = String::from_utf8(decoded)
160            .map_err(|_| AuthError::Unauthorized("Invalid UTF-8".into()))?;
161        let (username, password) = decoded
162            .split_once(':')
163            .ok_or_else(|| AuthError::Unauthorized("Expected user:password".into()))?;
164
165        let (stored_pw, role) = self
166            .users
167            .get(username)
168            .ok_or_else(|| AuthError::Unauthorized("Unknown user".into()))?;
169
170        if !constant_time_eq(stored_pw.as_bytes(), password.as_bytes()) {
171            return Err(AuthError::Unauthorized("Wrong password".into()));
172        }
173
174        Ok(AuthResult {
175            username: username.to_string(),
176            permissions: role.permissions.clone(),
177        })
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    fn test_validator() -> StaticAuthValidator {
        StaticAuthValidator::new(
            vec![
                User {
                    name: "admin".into(),
                    password: "secret".into(),
                    role: "full".into(),
                },
                User {
                    name: "player".into(),
                    password: "play".into(),
                    role: "streaming".into(),
                },
            ],
            vec![
                Role {
                    name: "full".into(),
                    permissions: vec!["Streaming".into(), "Control".into()],
                },
                Role {
                    name: "streaming".into(),
                    permissions: vec!["Streaming".into()],
                },
            ],
        )
    }

    fn basic(user: &str, pass: &str) -> String {
        use base64::Engine;
        base64::engine::general_purpose::STANDARD.encode(format!("{user}:{pass}"))
    }

    #[test]
    fn valid_credentials() {
        let v = test_validator();
        let result = v.validate("Basic", &basic("admin", "secret")).unwrap();
        assert_eq!(result.username, "admin");
        assert!(result.permissions.contains(&"Streaming".into()));
        assert!(result.permissions.contains(&"Control".into()));
    }

    #[test]
    fn wrong_password() {
        let v = test_validator();
        let err = v.validate("Basic", &basic("admin", "wrong")).unwrap_err();
        assert_eq!(err.code(), 401);
    }

    #[test]
    fn near_miss_passwords_are_rejected() {
        let v = test_validator();
        // Shorter prefix, and a suffix that only differs after the real password.
        for pass in ["secre", "secret:", "Secret", ""] {
            let err = v.validate("Basic", &basic("admin", pass)).unwrap_err();
            assert_eq!(err.code(), 401, "password {pass:?} must be rejected");
        }
    }

    #[test]
    fn unknown_user() {
        let v = test_validator();
        let err = v.validate("Basic", &basic("nobody", "x")).unwrap_err();
        assert_eq!(err.code(), 401);
    }

    #[test]
    fn unsupported_scheme() {
        let v = test_validator();
        let err = v.validate("Bearer", "token123").unwrap_err();
        assert_eq!(err.code(), 401);
    }

    #[test]
    fn scheme_is_case_insensitive() {
        let v = test_validator();
        assert!(v.validate("basic", &basic("player", "play")).is_ok());
        assert!(v.validate("BASIC", &basic("player", "play")).is_ok());
    }

    #[test]
    fn malformed_credentials() {
        use base64::Engine;
        let v = test_validator();
        let engine = &base64::engine::general_purpose::STANDARD;

        // Not base64 at all.
        assert_eq!(v.validate("Basic", "!!!").unwrap_err().code(), 401);
        // Valid base64, but no `user:password` separator.
        let no_colon = engine.encode("admin");
        assert_eq!(v.validate("Basic", &no_colon).unwrap_err().code(), 401);
        // Valid base64, but not UTF-8.
        let not_utf8 = engine.encode([0xff_u8, 0xfe]);
        assert_eq!(v.validate("Basic", &not_utf8).unwrap_err().code(), 401);
    }

    #[test]
    fn streaming_only_user() {
        let v = test_validator();
        let result = v.validate("Basic", &basic("player", "play")).unwrap();
        assert_eq!(result.username, "player");
        assert!(result.permissions.contains(&"Streaming".into()));
        assert!(!result.permissions.contains(&"Control".into()));
    }

    #[test]
    fn unknown_role_grants_no_permissions() {
        let v = StaticAuthValidator::new(
            vec![User {
                name: "guest".into(),
                password: "pw".into(),
                role: "missing".into(),
            }],
            vec![],
        );
        let result = v.validate("Basic", &basic("guest", "pw")).unwrap();
        assert_eq!(result.username, "guest");
        assert!(result.permissions.is_empty());
    }

    #[test]
    fn error_code_and_display() {
        let forbidden = AuthError::Forbidden("no streaming".into());
        assert_eq!(forbidden.code(), 403);
        assert_eq!(forbidden.message(), "no streaming");
        assert_eq!(forbidden.to_string(), "Forbidden: no streaming");
        assert_eq!(
            AuthError::Unauthorized("x".into()).to_string(),
            "Unauthorized: x"
        );
    }
}