tuitbot_server/
account.rs1use std::sync::Arc;
8
9use axum::extract::FromRequestParts;
10use axum::http::request::Parts;
11use axum::http::StatusCode;
12use axum::response::{IntoResponse, Response};
13use serde_json::json;
14use tuitbot_core::storage::accounts::{self, DEFAULT_ACCOUNT_ID};
15
16use crate::state::AppState;
17
18#[derive(Debug, Clone)]
20pub struct AccountContext {
21 pub account_id: String,
23 pub role: Role,
25}
26
27#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
29#[serde(rename_all = "lowercase")]
30pub enum Role {
31 Admin,
32 Approver,
33 Viewer,
34}
35
36impl Role {
37 pub fn can_read(self) -> bool {
39 true
40 }
41
42 pub fn can_approve(self) -> bool {
44 matches!(self, Role::Admin | Role::Approver)
45 }
46
47 pub fn can_mutate(self) -> bool {
49 matches!(self, Role::Admin)
50 }
51}
52
53impl std::fmt::Display for Role {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 match self {
56 Role::Admin => write!(f, "admin"),
57 Role::Approver => write!(f, "approver"),
58 Role::Viewer => write!(f, "viewer"),
59 }
60 }
61}
62
63impl std::str::FromStr for Role {
64 type Err = String;
65
66 fn from_str(s: &str) -> Result<Self, Self::Err> {
67 match s {
68 "admin" => Ok(Role::Admin),
69 "approver" => Ok(Role::Approver),
70 "viewer" => Ok(Role::Viewer),
71 other => Err(format!("unknown role: {other}")),
72 }
73 }
74}
75
76pub struct AccountError {
78 pub status: StatusCode,
79 pub message: String,
80}
81
82impl IntoResponse for AccountError {
83 fn into_response(self) -> Response {
84 (self.status, axum::Json(json!({"error": self.message}))).into_response()
85 }
86}
87
88impl FromRequestParts<Arc<AppState>> for AccountContext {
89 type Rejection = AccountError;
90
91 fn from_request_parts(
96 parts: &mut Parts,
97 state: &Arc<AppState>,
98 ) -> impl std::future::Future<Output = Result<Self, Self::Rejection>> + Send {
99 let account_id = parts
100 .headers
101 .get("x-account-id")
102 .and_then(|v| v.to_str().ok())
103 .unwrap_or(DEFAULT_ACCOUNT_ID)
104 .to_string();
105
106 let db = state.db.clone();
107
108 async move {
109 if account_id == DEFAULT_ACCOUNT_ID {
111 return Ok(AccountContext {
112 account_id,
113 role: Role::Admin,
114 });
115 }
116
117 let exists = accounts::account_exists(&db, &account_id)
119 .await
120 .map_err(|e| AccountError {
121 status: StatusCode::INTERNAL_SERVER_ERROR,
122 message: format!("failed to validate account: {e}"),
123 })?;
124
125 if !exists {
126 return Err(AccountError {
127 status: StatusCode::NOT_FOUND,
128 message: format!("account not found: {account_id}"),
129 });
130 }
131
132 let role_str = accounts::get_role(&db, &account_id, "dashboard")
134 .await
135 .map_err(|e| AccountError {
136 status: StatusCode::INTERNAL_SERVER_ERROR,
137 message: format!("failed to resolve role: {e}"),
138 })?;
139
140 let role = role_str
141 .as_deref()
142 .unwrap_or("viewer")
143 .parse::<Role>()
144 .unwrap_or(Role::Viewer);
145
146 Ok(AccountContext { account_id, role })
147 }
148 }
149}
150
151pub fn require_approve(ctx: &AccountContext) -> Result<(), AccountError> {
153 if ctx.role.can_approve() {
154 Ok(())
155 } else {
156 Err(AccountError {
157 status: StatusCode::FORBIDDEN,
158 message: "approver or admin role required".to_string(),
159 })
160 }
161}
162
163pub fn require_mutate(ctx: &AccountContext) -> Result<(), AccountError> {
165 if ctx.role.can_mutate() {
166 Ok(())
167 } else {
168 Err(AccountError {
169 status: StatusCode::FORBIDDEN,
170 message: "admin role required".to_string(),
171 })
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use axum::response::IntoResponse;
    use http_body_util::BodyExt;

    /// Every role paired with its canonical lowercase spelling.
    const NAMED_ROLES: [(Role, &str); 3] = [
        (Role::Admin, "admin"),
        (Role::Approver, "approver"),
        (Role::Viewer, "viewer"),
    ];

    /// Builds a context for a throwaway account that holds `role`.
    fn ctx_with(role: Role) -> AccountContext {
        AccountContext {
            account_id: "test".into(),
            role,
        }
    }

    /// `(can_read, can_approve, can_mutate)` for a role.
    fn permissions(role: Role) -> (bool, bool, bool) {
        (role.can_read(), role.can_approve(), role.can_mutate())
    }

    // ---- permission matrix ----

    #[test]
    fn admin_can_read_approve_mutate() {
        assert_eq!(permissions(Role::Admin), (true, true, true));
    }

    #[test]
    fn approver_can_read_approve_not_mutate() {
        assert_eq!(permissions(Role::Approver), (true, true, false));
    }

    #[test]
    fn viewer_can_read_only() {
        assert_eq!(permissions(Role::Viewer), (true, false, false));
    }

    // ---- string conversions ----

    #[test]
    fn role_display() {
        for (role, name) in NAMED_ROLES {
            assert_eq!(role.to_string(), name);
        }
    }

    #[test]
    fn role_from_str_valid() {
        for (role, name) in NAMED_ROLES {
            assert_eq!(name.parse::<Role>().unwrap(), role);
        }
    }

    #[test]
    fn role_from_str_invalid() {
        let message = "superuser".parse::<Role>().unwrap_err();
        assert!(message.contains("unknown role"));
    }

    #[test]
    fn role_serde_roundtrip() {
        for (role, name) in NAMED_ROLES {
            let encoded = serde_json::to_string(&role).unwrap();
            assert_eq!(encoded, format!("\"{name}\""));
            let decoded: Role = serde_json::from_str(&encoded).unwrap();
            assert_eq!(decoded, role);
        }
    }

    // ---- guards ----

    #[test]
    fn require_approve_admin_ok() {
        assert!(require_approve(&ctx_with(Role::Admin)).is_ok());
    }

    #[test]
    fn require_approve_approver_ok() {
        assert!(require_approve(&ctx_with(Role::Approver)).is_ok());
    }

    #[test]
    fn require_approve_viewer_rejected() {
        let rejection = require_approve(&ctx_with(Role::Viewer)).unwrap_err();
        assert_eq!(rejection.status, StatusCode::FORBIDDEN);
        assert!(rejection.message.contains("approver"));
    }

    #[test]
    fn require_mutate_admin_ok() {
        assert!(require_mutate(&ctx_with(Role::Admin)).is_ok());
    }

    #[test]
    fn require_mutate_approver_rejected() {
        let rejection = require_mutate(&ctx_with(Role::Approver)).unwrap_err();
        assert_eq!(rejection.status, StatusCode::FORBIDDEN);
        assert!(rejection.message.contains("admin"));
    }

    #[test]
    fn require_mutate_viewer_rejected() {
        let rejection = require_mutate(&ctx_with(Role::Viewer)).unwrap_err();
        assert_eq!(rejection.status, StatusCode::FORBIDDEN);
    }

    // ---- HTTP rendering ----

    #[tokio::test]
    async fn account_error_into_response() {
        let response = AccountError {
            status: StatusCode::NOT_FOUND,
            message: "not here".into(),
        }
        .into_response();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);

        let raw = response.into_body().collect().await.unwrap().to_bytes();
        let body: serde_json::Value = serde_json::from_slice(&raw).unwrap();
        assert_eq!(body["error"], "not here");
    }

    #[tokio::test]
    async fn account_error_forbidden_response() {
        let response = AccountError {
            status: StatusCode::FORBIDDEN,
            message: "denied".into(),
        }
        .into_response();
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }
}