1use std::collections::HashSet;
53use std::time::{Duration, SystemTime, UNIX_EPOCH};
54
55use serde::{Deserialize, Serialize};
56
57use crate::filter_ir::{AuthCapabilities, AuthScope};
58
/// A signed, self-describing grant of access to one or more namespaces.
///
/// All fields are public. Changing any field that feeds the signed payload after
/// signing makes `TokenSigner::verify` reject the token. Field order is kept as-is
/// because it is also the serde serialization order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CapabilityToken {
    /// Token format version; `TokenSigner::verify` rejects anything other than
    /// `CapabilityToken::CURRENT_VERSION`.
    pub version: u8,

    /// Identifier of this token; `TokenValidator` uses it as the revocation key.
    pub token_id: String,

    /// Namespaces the bearer may access (matched exactly by `is_namespace_allowed`).
    pub allowed_namespaces: Vec<String>,

    /// Optional tenant the token is scoped to.
    pub tenant_id: Option<String>,

    /// Optional project the token is scoped to.
    pub project_id: Option<String>,

    /// Operations the bearer is allowed to perform.
    pub capabilities: TokenCapabilities,

    /// Issue time, Unix seconds (set by `TokenBuilder::build_unsigned`).
    pub issued_at: u64,

    /// Expiry time, Unix seconds (issue time plus the builder's validity).
    pub expires_at: u64,

    /// ACL tags copied into the resulting `AuthScope`.
    pub acl_tags: Vec<String>,

    /// Signature bytes; empty until `TokenSigner::sign` fills them in.
    pub signature: Vec<u8>,
}
98
/// Boolean permission flags carried by a `CapabilityToken`.
///
/// `Default` grants nothing (every flag `false`). All five flags are covered by the
/// token signature. `can_delegate` is not carried into `AuthCapabilities` by
/// `CapabilityToken::to_auth_scope`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenCapabilities {
    /// May read data.
    pub can_read: bool,
    /// May write data.
    pub can_write: bool,
    /// May delete data.
    pub can_delete: bool,
    /// May perform administrative operations.
    pub can_admin: bool,
    /// Delegation flag; `TokenBuilder::full_access` leaves it `false`.
    pub can_delegate: bool,
}
113
114impl CapabilityToken {
115 pub const CURRENT_VERSION: u8 = 1;
117
118 pub fn is_expired(&self) -> bool {
120 let now = SystemTime::now()
121 .duration_since(UNIX_EPOCH)
122 .map(|d| d.as_secs())
123 .unwrap_or(0);
124 now > self.expires_at
125 }
126
127 pub fn is_namespace_allowed(&self, namespace: &str) -> bool {
129 self.allowed_namespaces.iter().any(|ns| ns == namespace)
130 }
131
132 pub fn to_auth_scope(&self) -> AuthScope {
134 AuthScope {
135 allowed_namespaces: self.allowed_namespaces.clone(),
136 tenant_id: self.tenant_id.clone(),
137 project_id: self.project_id.clone(),
138 expires_at: Some(self.expires_at),
139 capabilities: AuthCapabilities {
140 can_read: self.capabilities.can_read,
141 can_write: self.capabilities.can_write,
142 can_delete: self.capabilities.can_delete,
143 can_admin: self.capabilities.can_admin,
144 },
145 acl_tags: self.acl_tags.clone(),
146 }
147 }
148
149 pub fn remaining_validity(&self) -> Option<Duration> {
151 let now = SystemTime::now()
152 .duration_since(UNIX_EPOCH)
153 .map(|d| d.as_secs())
154 .unwrap_or(0);
155
156 if now >= self.expires_at {
157 None
158 } else {
159 Some(Duration::from_secs(self.expires_at - now))
160 }
161 }
162}
163
/// Fluent builder for `CapabilityToken`s; see `TokenBuilder::new` for the defaults.
pub struct TokenBuilder {
    /// Namespaces granted, in the order they were added.
    namespaces: Vec<String>,
    /// Optional tenant scope.
    tenant_id: Option<String>,
    /// Optional project scope.
    project_id: Option<String>,
    /// Permission flags to embed in the token.
    capabilities: TokenCapabilities,
    /// Lifetime added to the issue time; `build_unsigned` uses only the whole seconds.
    validity: Duration,
    /// ACL tags to embed in the token.
    acl_tags: Vec<String>,
}
177
178impl TokenBuilder {
179 pub fn new(namespace: impl Into<String>) -> Self {
181 Self {
182 namespaces: vec![namespace.into()],
183 tenant_id: None,
184 project_id: None,
185 capabilities: TokenCapabilities {
186 can_read: true,
187 ..Default::default()
188 },
189 validity: Duration::from_secs(3600), acl_tags: Vec::new(),
191 }
192 }
193
194 pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
196 self.namespaces.push(namespace.into());
197 self
198 }
199
200 pub fn with_tenant(mut self, tenant_id: impl Into<String>) -> Self {
202 self.tenant_id = Some(tenant_id.into());
203 self
204 }
205
206 pub fn with_project(mut self, project_id: impl Into<String>) -> Self {
208 self.project_id = Some(project_id.into());
209 self
210 }
211
212 pub fn can_read(mut self) -> Self {
214 self.capabilities.can_read = true;
215 self
216 }
217
218 pub fn can_write(mut self) -> Self {
220 self.capabilities.can_write = true;
221 self
222 }
223
224 pub fn can_delete(mut self) -> Self {
226 self.capabilities.can_delete = true;
227 self
228 }
229
230 pub fn can_admin(mut self) -> Self {
232 self.capabilities.can_admin = true;
233 self
234 }
235
236 pub fn full_access(mut self) -> Self {
238 self.capabilities = TokenCapabilities {
239 can_read: true,
240 can_write: true,
241 can_delete: true,
242 can_admin: true,
243 can_delegate: false,
244 };
245 self
246 }
247
248 pub fn valid_for(mut self, duration: Duration) -> Self {
250 self.validity = duration;
251 self
252 }
253
254 pub fn with_acl_tags(mut self, tags: Vec<String>) -> Self {
256 self.acl_tags = tags;
257 self
258 }
259
260 pub fn build_unsigned(self) -> CapabilityToken {
262 let now = SystemTime::now()
263 .duration_since(UNIX_EPOCH)
264 .map(|d| d.as_secs())
265 .unwrap_or(0);
266
267 CapabilityToken {
268 version: CapabilityToken::CURRENT_VERSION,
269 token_id: generate_token_id(),
270 allowed_namespaces: self.namespaces,
271 tenant_id: self.tenant_id,
272 project_id: self.project_id,
273 capabilities: self.capabilities,
274 issued_at: now,
275 expires_at: now + self.validity.as_secs(),
276 acl_tags: self.acl_tags,
277 signature: Vec::new(),
278 }
279 }
280}
281
/// Generates a token id of the form `tok_<hex>`.
///
/// The hex value is the current Unix time in nanoseconds. It is forced strictly
/// above the last id handed out in this process, so two tokens built in the same
/// clock tick never share an id. A shared id would let revoking one token revoke
/// the other. Uniqueness is per process only; ids are predictable and are not
/// secrets.
fn generate_token_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};

    // Last id value issued by this process.
    static LAST_ID: AtomicU64 = AtomicU64::new(0);

    let now_nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| u64::try_from(d.as_nanos()).unwrap_or(u64::MAX))
        .unwrap_or(0);

    let next_after = |last: u64| now_nanos.max(last.saturating_add(1));
    let previous = LAST_ID
        .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |last| {
            Some(next_after(last))
        })
        // The closure always returns `Some`, so the `Err` arm is never taken.
        .unwrap_or_else(|last| last);

    format!("tok_{:x}", next_after(previous))
}
293
/// Signs and verifies `CapabilityToken` signatures using a shared secret.
///
/// This type does not derive `Debug`, so the secret cannot be printed through `{:?}`.
pub struct TokenSigner {
    /// Raw secret bytes copied from the constructor argument.
    // NOTE(review): held in a plain Vec and not wiped on drop; consider zeroizing if that matters here.
    secret: Vec<u8>,
}
303
304impl TokenSigner {
305 pub fn new(secret: impl AsRef<[u8]>) -> Self {
307 Self {
308 secret: secret.as_ref().to_vec(),
309 }
310 }
311
312 pub fn sign(&self, token: &mut CapabilityToken) {
314 let payload = self.compute_payload(token);
315 token.signature = self.hmac_sha256(&payload);
316 }
317
318 pub fn verify(&self, token: &CapabilityToken) -> Result<(), TokenError> {
320 if token.version != CapabilityToken::CURRENT_VERSION {
322 return Err(TokenError::UnsupportedVersion(token.version));
323 }
324
325 if token.is_expired() {
327 return Err(TokenError::Expired);
328 }
329
330 let payload = self.compute_payload(token);
332 let expected = self.hmac_sha256(&payload);
333
334 if !constant_time_eq(&token.signature, &expected) {
335 return Err(TokenError::InvalidSignature);
336 }
337
338 Ok(())
339 }
340
341 fn compute_payload(&self, token: &CapabilityToken) -> Vec<u8> {
343 let mut payload = Vec::new();
345
346 payload.push(token.version);
347 payload.extend(token.token_id.as_bytes());
348
349 for ns in &token.allowed_namespaces {
350 payload.extend(ns.as_bytes());
351 payload.push(0); }
353
354 if let Some(ref tenant) = token.tenant_id {
355 payload.extend(tenant.as_bytes());
356 }
357 payload.push(0);
358
359 if let Some(ref project) = token.project_id {
360 payload.extend(project.as_bytes());
361 }
362 payload.push(0);
363
364 let caps = (token.capabilities.can_read as u8)
366 | ((token.capabilities.can_write as u8) << 1)
367 | ((token.capabilities.can_delete as u8) << 2)
368 | ((token.capabilities.can_admin as u8) << 3)
369 | ((token.capabilities.can_delegate as u8) << 4);
370 payload.push(caps);
371
372 payload.extend(&token.issued_at.to_le_bytes());
373 payload.extend(&token.expires_at.to_le_bytes());
374
375 for tag in &token.acl_tags {
376 payload.extend(tag.as_bytes());
377 payload.push(0);
378 }
379
380 payload
381 }
382
383 fn hmac_sha256(&self, data: &[u8]) -> Vec<u8> {
385 use std::collections::hash_map::DefaultHasher;
388 use std::hash::{Hash, Hasher};
389
390 let mut hasher = DefaultHasher::new();
393 self.secret.hash(&mut hasher);
394 data.hash(&mut hasher);
395 let h1 = hasher.finish();
396
397 let mut hasher2 = DefaultHasher::new();
398 h1.hash(&mut hasher2);
399 self.secret.hash(&mut hasher2);
400 let h2 = hasher2.finish();
401
402 let mut result = Vec::with_capacity(16);
403 result.extend(&h1.to_le_bytes());
404 result.extend(&h2.to_le_bytes());
405 result
406 }
407}
408
/// Compares two byte slices without branching on where their contents differ.
///
/// Slices of different length compare unequal immediately. Only the byte contents
/// are examined in a data-independent way: every byte pair is XOR-ed and OR-folded
/// into one accumulator.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let accumulated = a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (left, right)| acc | (left ^ right));
        accumulated == 0
    } else {
        false
    }
}
421
/// Reasons a capability token can be rejected.
#[derive(Debug, Clone, thiserror::Error)]
pub enum TokenError {
    /// `CapabilityToken::is_expired` returned true during `TokenSigner::verify`.
    #[error("token has expired")]
    Expired,

    /// The stored signature did not match the one recomputed by `TokenSigner::verify`.
    #[error("invalid signature")]
    InvalidSignature,

    /// The token's `version` field is not `CapabilityToken::CURRENT_VERSION`.
    #[error("unsupported token version: {0}")]
    UnsupportedVersion(u8),

    /// The token id is on the validator's revocation list.
    #[error("token revoked")]
    Revoked,

    /// The requested namespace is not granted by the token.
    // NOTE(review): not produced by any code visible here; presumably raised by callers.
    #[error("namespace not allowed: {0}")]
    NamespaceNotAllowed(String),

    /// The token lacks a capability required by the operation.
    // NOTE(review): not produced by any code visible here; presumably raised by callers.
    #[error("insufficient capabilities")]
    InsufficientCapabilities,
}
443
/// Set of revoked token ids, shared behind an `RwLock` so `&self` methods can mutate it.
///
/// No method visible here removes entries, so the set only grows.
pub struct RevocationList {
    /// Revoked `CapabilityToken::token_id` values.
    revoked: std::sync::RwLock<HashSet<String>>,
}
453
454impl RevocationList {
455 pub fn new() -> Self {
457 Self {
458 revoked: std::sync::RwLock::new(HashSet::new()),
459 }
460 }
461
462 pub fn revoke(&self, token_id: &str) {
464 self.revoked.write().unwrap().insert(token_id.to_string());
465 }
466
467 pub fn is_revoked(&self, token_id: &str) -> bool {
469 self.revoked.read().unwrap().contains(token_id)
470 }
471
472 pub fn count(&self) -> usize {
474 self.revoked.read().unwrap().len()
475 }
476}
477
478impl Default for RevocationList {
479 fn default() -> Self {
480 Self::new()
481 }
482}
483
/// Issues signed tokens and validates presented ones against a signer and a
/// revocation list.
pub struct TokenValidator {
    /// Signs issued tokens and verifies presented ones.
    signer: TokenSigner,
    /// Token ids that must be rejected even if correctly signed.
    revocation_list: RevocationList,
}
493
494impl TokenValidator {
495 pub fn new(secret: impl AsRef<[u8]>) -> Self {
497 Self {
498 signer: TokenSigner::new(secret),
499 revocation_list: RevocationList::new(),
500 }
501 }
502
503 pub fn issue(&self, builder: TokenBuilder) -> CapabilityToken {
505 let mut token = builder.build_unsigned();
506 self.signer.sign(&mut token);
507 token
508 }
509
510 pub fn validate(&self, token: &CapabilityToken) -> Result<AuthScope, TokenError> {
512 if self.revocation_list.is_revoked(&token.token_id) {
514 return Err(TokenError::Revoked);
515 }
516
517 self.signer.verify(token)?;
519
520 Ok(token.to_auth_scope())
522 }
523
524 pub fn revoke(&self, token_id: &str) {
526 self.revocation_list.revoke(token_id);
527 }
528}
529
/// Newtype wrapper around an ACL tag name.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct AclTag(String);
544
545impl AclTag {
546 pub fn new(tag: impl Into<String>) -> Self {
548 Self(tag.into())
549 }
550
551 pub fn name(&self) -> &str {
553 &self.0
554 }
555}
556
/// Inverted index from ACL tag name to the document ids carrying that tag.
#[derive(Debug, Default)]
pub struct AclTagIndex {
    /// Tag name -> document ids, in insertion order (duplicates are kept).
    tag_to_docs: std::collections::HashMap<String, Vec<u64>>,
}
566
567impl AclTagIndex {
568 pub fn new() -> Self {
570 Self::default()
571 }
572
573 pub fn add_tag(&mut self, doc_id: u64, tag: &str) {
575 self.tag_to_docs
576 .entry(tag.to_string())
577 .or_default()
578 .push(doc_id);
579 }
580
581 pub fn docs_with_tag(&self, tag: &str) -> &[u64] {
583 self.tag_to_docs.get(tag).map(|v| v.as_slice()).unwrap_or(&[])
584 }
585
586 pub fn accessible_docs(&self, allowed_tags: &[String]) -> Vec<u64> {
588 let mut result = HashSet::new();
589 for tag in allowed_tags {
590 if let Some(docs) = self.tag_to_docs.get(tag) {
591 result.extend(docs.iter().copied());
592 }
593 }
594 result.into_iter().collect()
595 }
596}
597
/// Unit tests for token building, signing, expiry, scope conversion, revocation and
/// the ACL tag index.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder methods accumulate namespaces and capabilities as requested.
    #[test]
    fn test_token_builder() {
        let token = TokenBuilder::new("production")
            .with_namespace("staging")
            .with_tenant("acme")
            .can_read()
            .can_write()
            .valid_for(Duration::from_secs(3600))
            .build_unsigned();

        assert_eq!(token.allowed_namespaces.len(), 2);
        assert_eq!(token.tenant_id, Some("acme".to_string()));
        assert!(token.capabilities.can_read);
        assert!(token.capabilities.can_write);
        assert!(!token.capabilities.can_delete);
    }

    /// A signed token verifies; mutating a signed field afterwards breaks verification.
    #[test]
    fn test_token_signing_and_verification() {
        let signer = TokenSigner::new("super_secret_key");

        let mut token = TokenBuilder::new("production")
            .can_read()
            .valid_for(Duration::from_secs(3600))
            .build_unsigned();

        signer.sign(&mut token);
        assert!(!token.signature.is_empty());

        assert!(signer.verify(&token).is_ok());

        // Tamper with a signed field after signing.
        token.allowed_namespaces.push("hacked".to_string());
        assert!(signer.verify(&token).is_err());
    }

    /// A token whose `expires_at` is in the past reports itself as expired.
    #[test]
    fn test_token_expiry() {
        let mut token = TokenBuilder::new("production")
            .valid_for(Duration::from_secs(3600))
            .build_unsigned();

        // Force expiry without waiting.
        token.expires_at = 0;

        assert!(token.is_expired());
    }

    /// `to_auth_scope` carries namespaces, tenant, capabilities and ACL tags across.
    #[test]
    fn test_token_to_auth_scope() {
        let token = TokenBuilder::new("production")
            .with_tenant("acme")
            .can_read()
            .can_write()
            .with_acl_tags(vec!["public".to_string(), "internal".to_string()])
            .build_unsigned();

        let scope = token.to_auth_scope();
        assert!(scope.is_namespace_allowed("production"));
        assert!(!scope.is_namespace_allowed("staging"));
        assert_eq!(scope.tenant_id, Some("acme".to_string()));
        assert!(scope.capabilities.can_read);
        assert!(scope.capabilities.can_write);
        assert_eq!(scope.acl_tags.len(), 2);
    }

    /// A valid token is rejected with `Revoked` once its id is revoked.
    #[test]
    fn test_revocation() {
        let validator = TokenValidator::new("secret");

        let token = validator.issue(
            TokenBuilder::new("production")
                .can_read()
                .valid_for(Duration::from_secs(3600))
        );

        assert!(validator.validate(&token).is_ok());

        validator.revoke(&token.token_id);

        assert!(matches!(
            validator.validate(&token),
            Err(TokenError::Revoked)
        ));
    }

    /// Per-tag lookups count inserted ids; the union across tags has no duplicates.
    #[test]
    fn test_acl_tag_index() {
        let mut index = AclTagIndex::new();

        index.add_tag(1, "public");
        index.add_tag(2, "public");
        index.add_tag(3, "internal");
        index.add_tag(4, "confidential");

        assert_eq!(index.docs_with_tag("public").len(), 2);
        assert_eq!(index.docs_with_tag("internal").len(), 1);

        let accessible = index.accessible_docs(&["public".to_string(), "internal".to_string()]);
        assert_eq!(accessible.len(), 3);
    }
}