1use std::collections::HashSet;
56use std::time::{Duration, SystemTime, UNIX_EPOCH};
57
58use serde::{Deserialize, Serialize};
59
60use crate::filter_ir::{AuthCapabilities, AuthScope};
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
70pub struct CapabilityToken {
71 pub version: u8,
73
74 pub token_id: String,
76
77 pub allowed_namespaces: Vec<String>,
79
80 pub tenant_id: Option<String>,
82
83 pub project_id: Option<String>,
85
86 pub capabilities: TokenCapabilities,
88
89 pub issued_at: u64,
91
92 pub expires_at: u64,
94
95 pub acl_tags: Vec<String>,
97
98 pub signature: Vec<u8>,
100}
101
102#[derive(Debug, Clone, Default, Serialize, Deserialize)]
104pub struct TokenCapabilities {
105 pub can_read: bool,
107 pub can_write: bool,
109 pub can_delete: bool,
111 pub can_admin: bool,
113 pub can_delegate: bool,
115}
116
117impl CapabilityToken {
118 pub const CURRENT_VERSION: u8 = 1;
120
121 pub fn is_expired(&self) -> bool {
123 let now = SystemTime::now()
124 .duration_since(UNIX_EPOCH)
125 .map(|d| d.as_secs())
126 .unwrap_or(0);
127 now > self.expires_at
128 }
129
130 pub fn is_namespace_allowed(&self, namespace: &str) -> bool {
132 self.allowed_namespaces.iter().any(|ns| ns == namespace)
133 }
134
135 pub fn to_auth_scope(&self) -> AuthScope {
137 AuthScope {
138 allowed_namespaces: self.allowed_namespaces.clone(),
139 tenant_id: self.tenant_id.clone(),
140 project_id: self.project_id.clone(),
141 expires_at: Some(self.expires_at),
142 capabilities: AuthCapabilities {
143 can_read: self.capabilities.can_read,
144 can_write: self.capabilities.can_write,
145 can_delete: self.capabilities.can_delete,
146 can_admin: self.capabilities.can_admin,
147 },
148 acl_tags: self.acl_tags.clone(),
149 }
150 }
151
152 pub fn remaining_validity(&self) -> Option<Duration> {
154 let now = SystemTime::now()
155 .duration_since(UNIX_EPOCH)
156 .map(|d| d.as_secs())
157 .unwrap_or(0);
158
159 if now >= self.expires_at {
160 None
161 } else {
162 Some(Duration::from_secs(self.expires_at - now))
163 }
164 }
165}
166
167pub struct TokenBuilder {
173 namespaces: Vec<String>,
174 tenant_id: Option<String>,
175 project_id: Option<String>,
176 capabilities: TokenCapabilities,
177 validity: Duration,
178 acl_tags: Vec<String>,
179}
180
181impl TokenBuilder {
182 pub fn new(namespace: impl Into<String>) -> Self {
184 Self {
185 namespaces: vec![namespace.into()],
186 tenant_id: None,
187 project_id: None,
188 capabilities: TokenCapabilities {
189 can_read: true,
190 ..Default::default()
191 },
192 validity: Duration::from_secs(3600), acl_tags: Vec::new(),
194 }
195 }
196
197 pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self {
199 self.namespaces.push(namespace.into());
200 self
201 }
202
203 pub fn with_tenant(mut self, tenant_id: impl Into<String>) -> Self {
205 self.tenant_id = Some(tenant_id.into());
206 self
207 }
208
209 pub fn with_project(mut self, project_id: impl Into<String>) -> Self {
211 self.project_id = Some(project_id.into());
212 self
213 }
214
215 pub fn can_read(mut self) -> Self {
217 self.capabilities.can_read = true;
218 self
219 }
220
221 pub fn can_write(mut self) -> Self {
223 self.capabilities.can_write = true;
224 self
225 }
226
227 pub fn can_delete(mut self) -> Self {
229 self.capabilities.can_delete = true;
230 self
231 }
232
233 pub fn can_admin(mut self) -> Self {
235 self.capabilities.can_admin = true;
236 self
237 }
238
239 pub fn full_access(mut self) -> Self {
241 self.capabilities = TokenCapabilities {
242 can_read: true,
243 can_write: true,
244 can_delete: true,
245 can_admin: true,
246 can_delegate: false,
247 };
248 self
249 }
250
251 pub fn valid_for(mut self, duration: Duration) -> Self {
253 self.validity = duration;
254 self
255 }
256
257 pub fn with_acl_tags(mut self, tags: Vec<String>) -> Self {
259 self.acl_tags = tags;
260 self
261 }
262
263 pub fn build_unsigned(self) -> CapabilityToken {
265 let now = SystemTime::now()
266 .duration_since(UNIX_EPOCH)
267 .map(|d| d.as_secs())
268 .unwrap_or(0);
269
270 CapabilityToken {
271 version: CapabilityToken::CURRENT_VERSION,
272 token_id: generate_token_id(),
273 allowed_namespaces: self.namespaces,
274 tenant_id: self.tenant_id,
275 project_id: self.project_id,
276 capabilities: self.capabilities,
277 issued_at: now,
278 expires_at: now + self.validity.as_secs(),
279 acl_tags: self.acl_tags,
280 signature: Vec::new(),
281 }
282 }
283}
284
/// Generates an identifier for a newly built token.
///
/// The id has the form `tok_<nanos>_<seq>_<salt>`, all in lowercase hex:
/// - `nanos`: issue time in nanoseconds since the Unix epoch (`0` if the clock reads
///   earlier than the epoch);
/// - `seq`: a process-wide counter, so two ids produced by this process never collide,
///   even within one clock tick. A timestamp alone repeats on coarse clocks and under
///   concurrent issuance, which would make revoking one token revoke another.
/// - `salt`: 64 bits from a freshly seeded
///   [`RandomState`](std::collections::hash_map::RandomState), so ids are hard to guess.
///   This is not a CSPRNG guarantee; token integrity still rests on the signature.
fn generate_token_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};

    static SEQUENCE: AtomicU64 = AtomicU64::new(0);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    // Relaxed ordering suffices: only the uniqueness of each fetched value matters.
    let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);

    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    hasher.write_u64(seq);
    let salt = hasher.finish();

    format!("tok_{:x}_{:x}_{:016x}", nanos, seq, salt)
}
296
/// Signs and verifies [`CapabilityToken`]s with a shared secret key.
///
/// The type derives no traits, and in particular not `Debug`, so the key cannot end up
/// in `{:?}` output. Keep it that way.
pub struct TokenSigner {
    /// Secret key bytes, copied from the argument to [`TokenSigner::new`].
    secret: Vec<u8>,
}
306
307impl TokenSigner {
308 pub fn new(secret: impl AsRef<[u8]>) -> Self {
310 Self {
311 secret: secret.as_ref().to_vec(),
312 }
313 }
314
315 pub fn sign(&self, token: &mut CapabilityToken) {
317 let payload = self.compute_payload(token);
318 token.signature = self.hmac_sha256(&payload);
319 }
320
321 pub fn verify(&self, token: &CapabilityToken) -> Result<(), TokenError> {
323 if token.version != CapabilityToken::CURRENT_VERSION {
325 return Err(TokenError::UnsupportedVersion(token.version));
326 }
327
328 if token.is_expired() {
330 return Err(TokenError::Expired);
331 }
332
333 let payload = self.compute_payload(token);
335 let expected = self.hmac_sha256(&payload);
336
337 if !constant_time_eq(&token.signature, &expected) {
338 return Err(TokenError::InvalidSignature);
339 }
340
341 Ok(())
342 }
343
344 fn compute_payload(&self, token: &CapabilityToken) -> Vec<u8> {
346 let mut payload = Vec::new();
348
349 payload.push(token.version);
350 payload.extend(token.token_id.as_bytes());
351
352 for ns in &token.allowed_namespaces {
353 payload.extend(ns.as_bytes());
354 payload.push(0); }
356
357 if let Some(ref tenant) = token.tenant_id {
358 payload.extend(tenant.as_bytes());
359 }
360 payload.push(0);
361
362 if let Some(ref project) = token.project_id {
363 payload.extend(project.as_bytes());
364 }
365 payload.push(0);
366
367 let caps = (token.capabilities.can_read as u8)
369 | ((token.capabilities.can_write as u8) << 1)
370 | ((token.capabilities.can_delete as u8) << 2)
371 | ((token.capabilities.can_admin as u8) << 3)
372 | ((token.capabilities.can_delegate as u8) << 4);
373 payload.push(caps);
374
375 payload.extend(&token.issued_at.to_le_bytes());
376 payload.extend(&token.expires_at.to_le_bytes());
377
378 for tag in &token.acl_tags {
379 payload.extend(tag.as_bytes());
380 payload.push(0);
381 }
382
383 payload
384 }
385
386 fn hmac_sha256(&self, data: &[u8]) -> Vec<u8> {
388 use std::collections::hash_map::DefaultHasher;
391 use std::hash::{Hash, Hasher};
392
393 let mut hasher = DefaultHasher::new();
396 self.secret.hash(&mut hasher);
397 data.hash(&mut hasher);
398 let h1 = hasher.finish();
399
400 let mut hasher2 = DefaultHasher::new();
401 h1.hash(&mut hasher2);
402 self.secret.hash(&mut hasher2);
403 let h2 = hasher2.finish();
404
405 let mut result = Vec::with_capacity(16);
406 result.extend(&h1.to_le_bytes());
407 result.extend(&h2.to_le_bytes());
408 result
409 }
410}
411
/// Compares two byte slices without short-circuiting on the first differing byte.
///
/// Slices of different lengths compare unequal immediately, so timing reveals only
/// whether the lengths match. For equal lengths, every byte pair is XOR-ed into an
/// accumulator before the single final comparison.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let accumulated = a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
        accumulated == 0
    } else {
        false
    }
}
424
425#[derive(Debug, Clone, thiserror::Error)]
427pub enum TokenError {
428 #[error("token has expired")]
429 Expired,
430
431 #[error("invalid signature")]
432 InvalidSignature,
433
434 #[error("unsupported token version: {0}")]
435 UnsupportedVersion(u8),
436
437 #[error("token revoked")]
438 Revoked,
439
440 #[error("namespace not allowed: {0}")]
441 NamespaceNotAllowed(String),
442
443 #[error("insufficient capabilities")]
444 InsufficientCapabilities,
445}
446
/// In-memory set of revoked token ids that can be shared between threads.
///
/// Entries are never removed, so the set grows with every revocation.
/// NOTE(review): consider pruning ids once their tokens have expired.
pub struct RevocationList {
    /// Revoked token ids behind a reader/writer lock; lookups take the read side.
    revoked: std::sync::RwLock<HashSet<String>>,
}

impl RevocationList {
    /// Creates an empty revocation list.
    pub fn new() -> Self {
        Self {
            revoked: std::sync::RwLock::new(HashSet::new()),
        }
    }

    /// Adds `token_id` to the list; revoking an already-revoked id is a no-op.
    pub fn revoke(&self, token_id: &str) {
        self.write_guard().insert(token_id.to_string());
    }

    /// Returns `true` if `token_id` has been revoked.
    pub fn is_revoked(&self, token_id: &str) -> bool {
        self.read_guard().contains(token_id)
    }

    /// Number of distinct revoked ids.
    pub fn count(&self) -> usize {
        self.read_guard().len()
    }

    /// Acquires the read lock, recovering the data if the lock is poisoned.
    ///
    /// The only write to the set is a single `HashSet::insert`, which leaves the set
    /// valid even if it unwinds. Recovering is therefore safe, whereas propagating the
    /// poison would turn every later `is_revoked` call (and so every token validation)
    /// into a panic.
    fn read_guard(&self) -> std::sync::RwLockReadGuard<HashSet<String>> {
        self.revoked
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Acquires the write lock, recovering from poisoning for the same reason as
    /// [`RevocationList::read_guard`].
    fn write_guard(&self) -> std::sync::RwLockWriteGuard<HashSet<String>> {
        self.revoked
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

impl Default for RevocationList {
    /// Same as [`RevocationList::new`]: an empty list.
    fn default() -> Self {
        Self::new()
    }
}
486
487pub struct TokenValidator {
493 signer: TokenSigner,
494 revocation_list: RevocationList,
495}
496
497impl TokenValidator {
498 pub fn new(secret: impl AsRef<[u8]>) -> Self {
500 Self {
501 signer: TokenSigner::new(secret),
502 revocation_list: RevocationList::new(),
503 }
504 }
505
506 pub fn issue(&self, builder: TokenBuilder) -> CapabilityToken {
508 let mut token = builder.build_unsigned();
509 self.signer.sign(&mut token);
510 token
511 }
512
513 pub fn validate(&self, token: &CapabilityToken) -> Result<AuthScope, TokenError> {
515 if self.revocation_list.is_revoked(&token.token_id) {
517 return Err(TokenError::Revoked);
518 }
519
520 self.signer.verify(token)?;
522
523 Ok(token.to_auth_scope())
525 }
526
527 pub fn revoke(&self, token_id: &str) {
529 self.revocation_list.revoke(token_id);
530 }
531}
532
533#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
546pub struct AclTag(String);
547
548impl AclTag {
549 pub fn new(tag: impl Into<String>) -> Self {
551 Self(tag.into())
552 }
553
554 pub fn name(&self) -> &str {
556 &self.0
557 }
558}
559
/// Inverted index from ACL tag name to the ids of the documents carrying that tag.
#[derive(Debug, Default)]
pub struct AclTagIndex {
    /// For each tag, the tagged document ids, kept sorted ascending and free of duplicates.
    tag_to_docs: std::collections::HashMap<String, Vec<u64>>,
}

impl AclTagIndex {
    /// Creates an empty index.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records that `doc_id` carries `tag`. Adding the same pair again is a no-op.
    pub fn add_tag(&mut self, doc_id: u64, tag: &str) {
        let docs = self.tag_to_docs.entry(tag.to_string()).or_default();
        // Keeping the list sorted lets a binary search detect duplicates. Ids that
        // arrive in increasing order are appended at the end without shifting.
        if let Err(position) = docs.binary_search(&doc_id) {
            docs.insert(position, doc_id);
        }
    }

    /// Returns the ids tagged with `tag` in ascending order, or an empty slice if the
    /// tag is unknown.
    pub fn docs_with_tag(&self, tag: &str) -> &[u64] {
        self.tag_to_docs
            .get(tag)
            .map(Vec::as_slice)
            .unwrap_or(&[])
    }

    /// Returns every document carrying at least one of `allowed_tags`, deduplicated
    /// and in ascending order, so the result is deterministic. Unknown tags are ignored.
    pub fn accessible_docs(&self, allowed_tags: &[String]) -> Vec<u64> {
        let mut docs: Vec<u64> = allowed_tags
            .iter()
            .filter_map(|tag| self.tag_to_docs.get(tag))
            .flat_map(|ids| ids.iter().copied())
            .collect();
        docs.sort_unstable();
        docs.dedup();
        docs
    }
}
600
#[cfg(test)]
mod tests {
    use super::*;

    /// Validity requested by most tests below.
    const ONE_HOUR: Duration = Duration::from_secs(3600);

    /// Converts string literals into the owned `String`s the APIs expect.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[test]
    fn test_token_builder() {
        let token = TokenBuilder::new("production")
            .with_namespace("staging")
            .with_tenant("acme")
            .can_read()
            .can_write()
            .valid_for(ONE_HOUR)
            .build_unsigned();

        let caps = &token.capabilities;
        assert_eq!(token.allowed_namespaces.len(), 2);
        assert_eq!(token.tenant_id.as_deref(), Some("acme"));
        assert!(caps.can_read);
        assert!(caps.can_write);
        assert!(!caps.can_delete);
    }

    #[test]
    fn test_token_signing_and_verification() {
        let signer = TokenSigner::new("super_secret_key");
        let mut token = TokenBuilder::new("production")
            .can_read()
            .valid_for(ONE_HOUR)
            .build_unsigned();

        signer.sign(&mut token);
        assert!(!token.signature.is_empty());
        assert!(signer.verify(&token).is_ok());

        // Widening the namespace list after signing must break the signature.
        token.allowed_namespaces.push("hacked".to_string());
        assert!(signer.verify(&token).is_err());
    }

    #[test]
    fn test_token_expiry() {
        let mut token = TokenBuilder::new("production")
            .valid_for(ONE_HOUR)
            .build_unsigned();

        // The Unix epoch is always in the past.
        token.expires_at = 0;
        assert!(token.is_expired());
    }

    #[test]
    fn test_token_to_auth_scope() {
        let scope = TokenBuilder::new("production")
            .with_tenant("acme")
            .can_read()
            .can_write()
            .with_acl_tags(owned(&["public", "internal"]))
            .build_unsigned()
            .to_auth_scope();

        assert!(scope.is_namespace_allowed("production"));
        assert!(!scope.is_namespace_allowed("staging"));
        assert_eq!(scope.tenant_id.as_deref(), Some("acme"));
        assert!(scope.capabilities.can_read);
        assert!(scope.capabilities.can_write);
        assert_eq!(scope.acl_tags.len(), 2);
    }

    #[test]
    fn test_revocation() {
        let validator = TokenValidator::new("secret");
        let token = validator.issue(
            TokenBuilder::new("production")
                .can_read()
                .valid_for(ONE_HOUR),
        );

        assert!(validator.validate(&token).is_ok());

        validator.revoke(&token.token_id);

        let outcome = validator.validate(&token);
        assert!(matches!(outcome, Err(TokenError::Revoked)));
    }

    #[test]
    fn test_acl_tag_index() {
        let mut index = AclTagIndex::new();
        let assignments = [
            (1, "public"),
            (2, "public"),
            (3, "internal"),
            (4, "confidential"),
        ];
        for &(doc, tag) in &assignments {
            index.add_tag(doc, tag);
        }

        assert_eq!(index.docs_with_tag("public").len(), 2);
        assert_eq!(index.docs_with_tag("internal").len(), 1);

        let accessible = index.accessible_docs(&owned(&["public", "internal"]));
        assert_eq!(accessible.len(), 3);
    }
}