1use std::collections::HashMap;
7use std::sync::Arc;
8
9use subtle::ConstantTimeEq;
10
11fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
13 a.ct_eq(b).into()
14}
15
/// Outcome of a successful authentication: who the caller is and what they may do.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthResult {
    /// Name of the authenticated user.
    pub username: String,
    /// Permission names granted through the user's role (e.g. `"Streaming"`).
    pub permissions: Vec<String>,
}
24
/// Reason an authentication or authorization attempt was rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthError {
    /// Credentials were missing, malformed, or not accepted (status code 401).
    Unauthorized(String),
    /// Access was denied (status code 403) — presumably used when an authenticated
    /// user lacks a required permission; confirm at call sites.
    Forbidden(String),
}

impl std::fmt::Display for AuthError {
    /// Formats as `"<Kind>: <message>"`, e.g. `"Unauthorized: Invalid base64"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Unauthorized(msg) => write!(f, "Unauthorized: {msg}"),
            Self::Forbidden(msg) => write!(f, "Forbidden: {msg}"),
        }
    }
}

impl std::error::Error for AuthError {}

impl AuthError {
    /// Numeric status code for this error: `401` for `Unauthorized`, `403` for `Forbidden`.
    #[must_use]
    pub fn code(&self) -> i32 {
        match self {
            Self::Unauthorized(_) => 401,
            Self::Forbidden(_) => 403,
        }
    }

    /// The human-readable detail message, without the kind prefix that `Display` adds.
    #[must_use]
    pub fn message(&self) -> &str {
        match self {
            Self::Unauthorized(msg) | Self::Forbidden(msg) => msg,
        }
    }
}
61
62pub trait AuthValidator: Send + Sync {
86 fn validate(&self, scheme: &str, param: &str) -> Result<AuthResult, AuthError>;
88}
89
/// Name of the streaming permission as it appears in a role's permission list.
pub const PERM_STREAMING: &str = "Streaming";
92
/// A named set of permissions that users reference by name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Role {
    /// Role name; users select this role by setting `User::role` to it.
    pub name: String,
    /// Permission names granted to every user holding this role.
    pub permissions: Vec<String>,
}
101
/// A configured user account.
#[derive(Clone)]
pub struct User {
    /// Login name, matched against the user-id part of Basic credentials.
    pub name: String,
    /// Plaintext password compared against the submitted one. Never shown by `Debug`.
    pub password: String,
    /// Name of the `Role` that supplies this user's permissions.
    pub role: String,
}

impl std::fmt::Debug for User {
    /// Hand-written instead of derived so that `{:?}` output (logs, panic messages,
    /// `dbg!`) never contains the plaintext password.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("User")
            .field("name", &self.name)
            .field("password", &"<redacted>")
            .field("role", &self.role)
            .finish()
    }
}
112
113#[derive(Debug, Clone)]
117pub struct StaticAuthValidator {
118 users: HashMap<String, (String, Arc<Role>)>, }
120
121impl StaticAuthValidator {
122 pub fn new(users: Vec<User>, roles: Vec<Role>) -> Self {
124 let role_map: HashMap<String, Arc<Role>> = roles
125 .into_iter()
126 .map(|r| (r.name.clone(), Arc::new(r)))
127 .collect();
128 let empty_role = Arc::new(Role {
129 name: String::new(),
130 permissions: vec![],
131 });
132 let user_map = users
133 .into_iter()
134 .map(|u| {
135 let role = role_map
136 .get(&u.role)
137 .cloned()
138 .unwrap_or_else(|| empty_role.clone());
139 (u.name, (u.password, role))
140 })
141 .collect();
142 Self { users: user_map }
143 }
144}
145
146impl AuthValidator for StaticAuthValidator {
147 fn validate(&self, scheme: &str, param: &str) -> Result<AuthResult, AuthError> {
148 if !scheme.eq_ignore_ascii_case("basic") {
149 return Err(AuthError::Unauthorized(format!(
150 "Unsupported auth scheme: {scheme}"
151 )));
152 }
153
154 use base64::Engine;
156 let decoded = base64::engine::general_purpose::STANDARD
157 .decode(param)
158 .map_err(|_| AuthError::Unauthorized("Invalid base64".into()))?;
159 let decoded = String::from_utf8(decoded)
160 .map_err(|_| AuthError::Unauthorized("Invalid UTF-8".into()))?;
161 let (username, password) = decoded
162 .split_once(':')
163 .ok_or_else(|| AuthError::Unauthorized("Expected user:password".into()))?;
164
165 let (stored_pw, role) = self
166 .users
167 .get(username)
168 .ok_or_else(|| AuthError::Unauthorized("Unknown user".into()))?;
169
170 if !constant_time_eq(stored_pw.as_bytes(), password.as_bytes()) {
171 return Err(AuthError::Unauthorized("Wrong password".into()));
172 }
173
174 Ok(AuthResult {
175 username: username.to_string(),
176 permissions: role.permissions.clone(),
177 })
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Two users: `admin` (role `full`: Streaming + Control) and `player`
    /// (role `streaming`: Streaming only).
    fn test_validator() -> StaticAuthValidator {
        StaticAuthValidator::new(
            vec![
                User {
                    name: "admin".into(),
                    password: "secret".into(),
                    role: "full".into(),
                },
                User {
                    name: "player".into(),
                    password: "play".into(),
                    role: "streaming".into(),
                },
            ],
            vec![
                Role {
                    name: "full".into(),
                    permissions: vec![PERM_STREAMING.into(), "Control".into()],
                },
                Role {
                    name: "streaming".into(),
                    permissions: vec![PERM_STREAMING.into()],
                },
            ],
        )
    }

    /// Encodes `user:pass` as the parameter of a `Basic` credential.
    fn basic(user: &str, pass: &str) -> String {
        use base64::Engine;
        base64::engine::general_purpose::STANDARD.encode(format!("{user}:{pass}"))
    }

    /// Returns whether `perms` contains the permission named `perm`.
    fn has(perms: &[String], perm: &str) -> bool {
        perms.iter().any(|p| p == perm)
    }

    #[test]
    fn valid_credentials() {
        let v = test_validator();
        let result = v.validate("Basic", &basic("admin", "secret")).unwrap();
        assert_eq!(result.username, "admin");
        assert!(has(&result.permissions, PERM_STREAMING));
        assert!(has(&result.permissions, "Control"));
    }

    #[test]
    fn wrong_password() {
        let v = test_validator();
        let err = v.validate("Basic", &basic("admin", "wrong")).unwrap_err();
        assert_eq!(err.code(), 401);
    }

    #[test]
    fn unknown_user() {
        let v = test_validator();
        let err = v.validate("Basic", &basic("nobody", "x")).unwrap_err();
        assert_eq!(err.code(), 401);
    }

    #[test]
    fn unsupported_scheme() {
        let v = test_validator();
        let err = v.validate("Bearer", "token123").unwrap_err();
        assert_eq!(err.code(), 401);
    }

    #[test]
    fn streaming_only_user() {
        let v = test_validator();
        let result = v.validate("Basic", &basic("player", "play")).unwrap();
        assert_eq!(result.username, "player");
        assert!(has(&result.permissions, PERM_STREAMING));
        assert!(!has(&result.permissions, "Control"));
    }

    #[test]
    fn scheme_is_case_insensitive() {
        let v = test_validator();
        for scheme in ["basic", "BASIC", "bAsIc"] {
            let result = v.validate(scheme, &basic("admin", "secret")).unwrap();
            assert_eq!(result.username, "admin");
        }
    }

    #[test]
    fn invalid_base64_is_unauthorized() {
        let err = test_validator()
            .validate("Basic", "not base64!")
            .unwrap_err();
        assert_eq!(err.code(), 401);
    }

    #[test]
    fn invalid_utf8_is_unauthorized() {
        use base64::Engine;
        let param = base64::engine::general_purpose::STANDARD.encode([0xff_u8, 0xfe, b':', b'x']);
        let err = test_validator().validate("Basic", &param).unwrap_err();
        assert_eq!(err.code(), 401);
    }

    #[test]
    fn missing_colon_is_unauthorized() {
        use base64::Engine;
        let param = base64::engine::general_purpose::STANDARD.encode("adminsecret");
        let err = test_validator().validate("Basic", &param).unwrap_err();
        assert_eq!(err.code(), 401);
    }

    #[test]
    fn password_may_contain_colon() {
        let v = StaticAuthValidator::new(
            vec![User {
                name: "cam".into(),
                password: "a:b:c".into(),
                role: "streaming".into(),
            }],
            vec![Role {
                name: "streaming".into(),
                permissions: vec![PERM_STREAMING.into()],
            }],
        );
        let result = v.validate("Basic", &basic("cam", "a:b:c")).unwrap();
        assert_eq!(result.username, "cam");
        assert!(has(&result.permissions, PERM_STREAMING));
    }

    #[test]
    fn unknown_role_grants_no_permissions() {
        let v = StaticAuthValidator::new(
            vec![User {
                name: "ghost".into(),
                password: "pw".into(),
                role: "missing".into(),
            }],
            vec![],
        );
        let result = v.validate("Basic", &basic("ghost", "pw")).unwrap();
        assert_eq!(result.username, "ghost");
        assert!(result.permissions.is_empty());
    }
}