1use std::collections::HashMap;
14use std::time::{SystemTime, UNIX_EPOCH};
15
16use serde::{Deserialize, Serialize};
17use tonic::{Request, Status};
18
/// A token string paired with a static label describing what kind of token it is.
///
/// `Debug` is implemented by hand (not derived) so the secret `value` is never
/// written to logs or panic messages; only `kind` is shown.
#[derive(Clone)]
pub struct RawToken {
    /// The token text itself. Treat as a secret.
    pub value: String,
    /// Static label for the token type.
    /// NOTE(review): presumably something like "bearer" — confirm with the producers.
    pub kind: &'static str,
}

impl std::fmt::Debug for RawToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RawToken")
            .field("value", &"<redacted>")
            .field("kind", &self.kind)
            .finish()
    }
}
32
33#[derive(Clone, Debug, Serialize, Deserialize)]
38pub struct AuthCtx {
39 pub subject: String,
41 pub issuer: String,
42 pub audience: String,
43 pub scopes: Vec<String>,
44 pub kind: PrincipalKind,
45 pub raw_token: String,
47 pub expires_at: f64,
53 #[serde(default)]
56 pub extra: HashMap<String, serde_json::Value>,
57}
58
59#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
60#[serde(rename_all = "lowercase")]
61pub enum PrincipalKind {
62 User,
63 Service,
64 Agent,
65 Anonymous,
67}
68
69impl AuthCtx {
70 pub fn anonymous() -> Self {
72 Self {
73 subject: String::new(),
74 issuer: String::new(),
75 audience: String::new(),
76 scopes: Vec::new(),
77 kind: PrincipalKind::Anonymous,
78 raw_token: String::new(),
79 expires_at: 0.0,
80 extra: HashMap::new(),
81 }
82 }
83
84 pub fn from_bearer(token: impl Into<String>) -> Self {
88 let token = token.into();
89 Self {
90 raw_token: token,
91 kind: PrincipalKind::User,
92 ..Self::anonymous()
93 }
94 }
95
96 pub fn from<T>(req: &Request<T>) -> Self {
100 req.extensions()
101 .get::<AuthCtx>()
102 .cloned()
103 .unwrap_or_else(Self::anonymous)
104 }
105
106 pub fn propagate<T>(&self, req: &mut Request<T>) {
109 if self.raw_token.is_empty() {
110 return;
111 }
112 if let Ok(value) = format!("Bearer {}", self.raw_token).parse() {
113 req.metadata_mut().insert("authorization", value);
114 }
115 }
116
117 #[allow(clippy::result_large_err)] pub fn require_scope(&self, scope: &str) -> Result<(), Status> {
121 if self.scopes.iter().any(|s| s == scope) {
122 Ok(())
123 } else {
124 Err(AuthError::InsufficientScope {
125 required: scope.into(),
126 }
127 .into())
128 }
129 }
130
131 pub fn is_anonymous(&self) -> bool {
132 matches!(self.kind, PrincipalKind::Anonymous)
133 }
134
135 pub fn expires_at_systime(&self) -> SystemTime {
138 if self.expires_at <= 0.0 {
139 UNIX_EPOCH
140 } else {
141 UNIX_EPOCH + std::time::Duration::from_secs_f64(self.expires_at)
142 }
143 }
144
145 pub fn set_expires_at_systime(&mut self, t: SystemTime) {
148 self.expires_at = t
149 .duration_since(UNIX_EPOCH)
150 .map(|d| d.as_secs_f64())
151 .unwrap_or(0.0);
152 }
153}
154
155#[derive(Debug, thiserror::Error)]
156pub enum AuthError {
157 #[error("no token in request")]
158 MissingToken,
159 #[error("token signature invalid")]
160 Signature,
161 #[error("token expired")]
162 Expired,
163 #[error("audience mismatch: expected {expected}, got {got}")]
164 Audience { expected: String, got: String },
165 #[error("issuer mismatch: expected {expected}, got {got}")]
166 Issuer { expected: String, got: String },
167 #[error("token verification failed: {0}")]
168 Verification(String),
169 #[error("insufficient scope: required {required}")]
170 InsufficientScope { required: String },
171 #[error("configuration error: {0}")]
172 Config(String),
173 #[error("transport error contacting auth backend: {0}")]
174 Transport(String),
175}
176
177impl From<AuthError> for Status {
178 fn from(e: AuthError) -> Status {
179 match e {
180 AuthError::InsufficientScope { .. } => Status::permission_denied(e.to_string()),
181 AuthError::Config(_) | AuthError::Transport(_) => Status::internal(e.to_string()),
182 _ => Status::unauthenticated(e.to_string()),
183 }
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn anonymous_authctx_is_anonymous() {
        let a = AuthCtx::anonymous();
        assert!(a.is_anonymous());
        assert_eq!(a.kind, PrincipalKind::Anonymous);
    }

    #[test]
    fn from_bearer_carries_token() {
        let a = AuthCtx::from_bearer("abc.def.ghi");
        assert_eq!(a.raw_token, "abc.def.ghi");
        assert_eq!(a.kind, PrincipalKind::User);
    }

    #[test]
    fn from_request_without_extension_is_anonymous() {
        let req = Request::new(());
        assert!(AuthCtx::from(&req).is_anonymous());
    }

    #[test]
    fn from_request_returns_inserted_ctx() {
        let mut req = Request::new(());
        req.extensions_mut().insert(AuthCtx::from_bearer("abc.def.ghi"));
        let got = AuthCtx::from(&req);
        assert_eq!(got.raw_token, "abc.def.ghi");
        assert_eq!(got.kind, PrincipalKind::User);
    }

    #[test]
    fn propagate_writes_authorization_header() {
        let a = AuthCtx::from_bearer("abc.def.ghi");
        let mut req = Request::new(());
        a.propagate(&mut req);
        let v = req.metadata().get("authorization").unwrap();
        assert_eq!(v.to_str().unwrap(), "Bearer abc.def.ghi");
    }

    #[test]
    fn propagate_replaces_existing_authorization() {
        let mut req = Request::new(());
        req.metadata_mut()
            .insert("authorization", "Bearer old".parse().unwrap());
        AuthCtx::from_bearer("new").propagate(&mut req);
        let v = req.metadata().get("authorization").unwrap();
        assert_eq!(v.to_str().unwrap(), "Bearer new");
    }

    #[test]
    fn propagate_anonymous_is_noop() {
        let a = AuthCtx::anonymous();
        let mut req = Request::new(());
        a.propagate(&mut req);
        assert!(req.metadata().get("authorization").is_none());
    }

    #[test]
    fn require_scope_ok_when_present() {
        let mut a = AuthCtx::anonymous();
        a.scopes = vec!["read:billing".into()];
        assert!(a.require_scope("read:billing").is_ok());
    }

    #[test]
    fn require_scope_err_when_missing() {
        let a = AuthCtx::anonymous();
        let err = a.require_scope("admin").unwrap_err();
        assert_eq!(err.code(), tonic::Code::PermissionDenied);
    }

    #[test]
    fn auth_error_maps_to_correct_status() {
        let s: Status = AuthError::Signature.into();
        assert_eq!(s.code(), tonic::Code::Unauthenticated);

        let s: Status = AuthError::Expired.into();
        assert_eq!(s.code(), tonic::Code::Unauthenticated);

        let s: Status = AuthError::InsufficientScope {
            required: "admin".into(),
        }
        .into();
        assert_eq!(s.code(), tonic::Code::PermissionDenied);

        let s: Status = AuthError::Config("missing env".into()).into();
        assert_eq!(s.code(), tonic::Code::Internal);

        let s: Status = AuthError::Transport("connection refused".into()).into();
        assert_eq!(s.code(), tonic::Code::Internal);
    }

    #[test]
    fn expires_at_round_trips_through_systime() {
        // Whole seconds are exactly representable in f64, so equality is safe here.
        let t = UNIX_EPOCH + std::time::Duration::from_secs(1_735_689_600);
        let mut a = AuthCtx::anonymous();
        a.set_expires_at_systime(t);
        assert_eq!(a.expires_at, 1_735_689_600.0);
        assert_eq!(a.expires_at_systime(), t);
    }

    #[test]
    fn non_positive_expiry_maps_to_epoch() {
        let mut a = AuthCtx::anonymous();
        assert_eq!(a.expires_at_systime(), UNIX_EPOCH);
        a.expires_at = -5.0;
        assert_eq!(a.expires_at_systime(), UNIX_EPOCH);
    }

    #[test]
    fn pre_epoch_systime_is_stored_as_zero() {
        let mut a = AuthCtx::anonymous();
        a.set_expires_at_systime(UNIX_EPOCH - std::time::Duration::from_secs(10));
        assert_eq!(a.expires_at, 0.0);
    }

    // Other-language services read this JSON, so field names and the
    // lowercase `kind` encoding are part of the contract.
    #[test]
    fn authctx_json_shape_is_stable_for_polyglot_consumers() {
        let mut ctx = AuthCtx::anonymous();
        ctx.subject = "alice".into();
        ctx.issuer = "https://issuer.example".into();
        ctx.audience = "my-svc".into();
        ctx.scopes = vec!["read:billing".into(), "write:billing".into()];
        ctx.kind = PrincipalKind::User;
        ctx.raw_token = "abc.def.ghi".into();
        ctx.expires_at = 1_735_689_600.0;
        ctx.extra
            .insert("tenant_id".into(), serde_json::json!("acme"));

        let v = serde_json::to_value(&ctx).unwrap();
        for f in [
            "subject",
            "issuer",
            "audience",
            "scopes",
            "kind",
            "raw_token",
            "expires_at",
            "extra",
        ] {
            assert!(
                v.get(f).is_some(),
                "missing field `{f}` in serialized AuthCtx JSON shape"
            );
        }
        assert!(v["expires_at"].is_number());
        assert_eq!(v["kind"], serde_json::json!("user"));
    }

    #[test]
    fn authctx_deserializes_without_extra() {
        let v = serde_json::json!({
            "subject": "svc-a",
            "issuer": "",
            "audience": "",
            "scopes": [],
            "kind": "service",
            "raw_token": "",
            "expires_at": 0.0
        });
        let ctx: AuthCtx = serde_json::from_value(v).unwrap();
        assert!(ctx.extra.is_empty());
        assert_eq!(ctx.kind, PrincipalKind::Service);
        assert_eq!(ctx.subject, "svc-a");
    }
}