1use std::sync::atomic::{AtomicBool, Ordering};
16
17use crate::context::Context;
18use crate::error::Error;
19use crate::http::{Request, Response};
20use crate::middleware::Next;
21
/// The authenticated caller attached to a request's context.
///
/// Derives `PartialEq`/`Eq` so identities can be compared directly, e.g. with
/// `assert_eq!` in tests or when checking one caller against another.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Identity {
    /// Identifier of the authenticated user.
    pub user_id: String,
    /// Whether the user passes admin-only checks such as `require_admin`.
    pub is_admin: bool,
}
27
/// Set once the production-mode notice has been printed, so that
/// `authenticate` emits it a single time per process instead of per request.
static PRODUCTION_WARNED: AtomicBool = AtomicBool::new(false);
31
/// Reports whether the process is configured as a production deployment.
///
/// Reads the `RUSTIO_ENV` environment variable on every call and returns
/// `true` when its value is `production` or `prod`. The comparison ignores
/// ASCII case and surrounding whitespace. An unset or non-Unicode variable
/// yields `false`.
pub fn in_production() -> bool {
    match std::env::var("RUSTIO_ENV") {
        Ok(value) => is_production_value(&value),
        Err(_) => false,
    }
}

/// Returns `true` when `value` names a production environment.
///
/// Surrounding whitespace is trimmed so that a stray space or newline (common
/// in `.env` files and shell scripts) cannot silently leave dev tokens
/// enabled. `eq_ignore_ascii_case` avoids allocating a lowercased copy.
fn is_production_value(value: &str) -> bool {
    let value = value.trim();
    ["production", "prod"]
        .iter()
        .any(|name| value.eq_ignore_ascii_case(name))
}
44
45pub async fn authenticate(mut req: Request, next: Next) -> Result<Response, Error> {
46 if in_production() {
47 if !PRODUCTION_WARNED.swap(true, Ordering::Relaxed) {
51 eprintln!(
52 "rustio_core::auth: RUSTIO_ENV={} — built-in dev tokens are disabled. \
53 Replace `authenticate` with your own middleware before accepting traffic.",
54 std::env::var("RUSTIO_ENV").unwrap_or_default()
55 );
56 }
57 return next.run(req).await;
60 }
61
62 if let Some(token) = bearer_token(&req) {
63 if let Some(identity) = dev_identity(token) {
64 req.ctx_mut().insert(identity);
65 }
66 }
67 next.run(req).await
68}
69
70pub fn bearer_token(req: &Request) -> Option<&str> {
71 req.headers()
72 .get("authorization")
73 .and_then(|v| v.to_str().ok())
74 .and_then(|s| s.strip_prefix("Bearer "))
75}
76
77fn dev_identity(token: &str) -> Option<Identity> {
78 match token {
79 "dev-admin" => Some(Identity {
80 user_id: String::from("admin"),
81 is_admin: true,
82 }),
83 "dev-user" => Some(Identity {
84 user_id: String::from("user"),
85 is_admin: false,
86 }),
87 _ => None,
88 }
89}
90
91pub fn identity(ctx: &Context) -> Option<&Identity> {
92 ctx.get::<Identity>()
93}
94
95pub fn require_auth(ctx: &Context) -> Result<&Identity, Error> {
96 identity(ctx).ok_or(Error::Unauthorized)
97}
98
99pub fn require_admin(ctx: &Context) -> Result<&Identity, Error> {
100 let id = require_auth(ctx)?;
101 if !id.is_admin {
102 return Err(Error::Forbidden);
103 }
104 Ok(id)
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fixture identity: the admin fixture is named `admin`, the
    /// regular one `user`.
    fn fixture(is_admin: bool) -> Identity {
        let user_id = if is_admin { "admin" } else { "user" };
        Identity {
            user_id: user_id.to_owned(),
            is_admin,
        }
    }

    /// Returns a fresh context with `who` attached.
    fn context_with(who: Identity) -> Context {
        let mut ctx = Context::new();
        ctx.insert(who);
        ctx
    }

    #[test]
    fn identity_returns_none_when_absent() {
        let empty = Context::new();
        assert!(identity(&empty).is_none());
    }

    #[test]
    fn identity_returns_reference_when_attached() {
        let ctx = context_with(fixture(false));
        let attached = identity(&ctx).expect("identity should be attached");
        assert_eq!(attached.user_id, "user");
    }

    #[test]
    fn require_auth_missing_returns_unauthorized() {
        let empty = Context::new();
        let outcome = require_auth(&empty);
        assert!(matches!(outcome, Err(Error::Unauthorized)));
    }

    #[test]
    fn require_auth_present_returns_identity() {
        let ctx = context_with(fixture(false));
        let who = require_auth(&ctx).unwrap();
        assert_eq!(who.user_id, "user");
        assert!(!who.is_admin);
    }

    #[test]
    fn require_admin_without_identity_returns_unauthorized() {
        let empty = Context::new();
        let outcome = require_admin(&empty);
        assert!(matches!(outcome, Err(Error::Unauthorized)));
    }

    #[test]
    fn require_admin_with_non_admin_returns_forbidden() {
        let ctx = context_with(fixture(false));
        let outcome = require_admin(&ctx);
        assert!(matches!(outcome, Err(Error::Forbidden)));
    }

    #[test]
    fn require_admin_with_admin_returns_identity() {
        let ctx = context_with(fixture(true));
        let who = require_admin(&ctx).unwrap();
        assert_eq!(who.user_id, "admin");
        assert!(who.is_admin);
    }

    #[test]
    fn dev_identity_rejects_unknown_tokens() {
        for &token in &["garbage", ""] {
            assert!(dev_identity(token).is_none());
        }
    }

    #[test]
    fn dev_identity_maps_known_tokens() {
        let admin = dev_identity("dev-admin").unwrap();
        let regular = dev_identity("dev-user").unwrap();
        assert!(admin.is_admin);
        assert!(!regular.is_admin);
    }

    #[test]
    fn in_production_detects_known_values() {
        // Mirrors the value check in `in_production` on a plain `Option`
        // rather than reading the process environment.
        fn detect(v: Option<&str>) -> bool {
            match v {
                Some(raw) => {
                    let lowered = raw.to_ascii_lowercase();
                    matches!(lowered.as_str(), "production" | "prod")
                }
                None => false,
            }
        }

        for &accepted in &["production", "PRODUCTION", "prod", "Prod"] {
            assert!(detect(Some(accepted)));
        }
        for &rejected in &["dev", "staging", ""] {
            assert!(!detect(Some(rejected)));
        }
        assert!(!detect(None));
    }
}