1#![allow(clippy::cast_possible_truncation)]
7
8use base64::Engine;
9use ring::hmac;
10use ring::rand::{SecureRandom, SystemRandom};
11use std::collections::HashMap;
12use std::num::NonZeroU32;
13use std::time::{SystemTime, UNIX_EPOCH};
14use subtle::ConstantTimeEq;
15
16use crate::error::{Result, ShieldError};
17
18fn generate_keystream(key: &[u8], nonce: &[u8], length: usize) -> Vec<u8> {
20 let mut keystream = Vec::with_capacity(length.div_ceil(32) * 32);
21 let num_blocks = length.div_ceil(32);
22
23 for i in 0..num_blocks {
24 let counter = (i as u32).to_le_bytes();
25 let mut data = Vec::with_capacity(key.len() + nonce.len() + 4);
26 data.extend_from_slice(key);
27 data.extend_from_slice(nonce);
28 data.extend_from_slice(&counter);
29
30 let hash = ring::digest::digest(&ring::digest::SHA256, &data);
31 keystream.extend_from_slice(hash.as_ref());
32 }
33
34 keystream.truncate(length);
35 keystream
36}
37
/// Public profile of a registered user, as returned by `IdentityProvider::register`.
#[derive(Clone)]
pub struct Identity {
    /// Unique identifier the user registered with.
    pub user_id: String,
    /// Human-readable name; `register` falls back to `user_id` when none is given.
    pub display_name: String,
    /// Per-user value computed by `register` as SHA-256 of the provider's "verify" sub-key
    /// concatenated with `user_id`.
    pub verification_key: [u8; 32],
    /// Caller-supplied free-form attributes.
    pub attributes: HashMap<String, String>,
    /// Registration time in seconds since the Unix epoch.
    pub created_at: u64,
}
47
/// An authenticated session decoded from a validated token.
#[derive(Clone, Debug)]
pub struct Session {
    /// Identifier of the user the session belongs to.
    pub user_id: String,
    /// Permission strings granted to this session.
    pub permissions: Vec<String>,
    /// Expiry in seconds since the Unix epoch; `None` means the session never expires.
    pub expires_at: Option<u64>,
    /// Free-form attributes attached to the session.
    pub attributes: HashMap<String, String>,
}

impl Session {
    /// Returns `true` once the current wall-clock time is strictly later than `expires_at`.
    ///
    /// A session without an expiry never expires. If the system clock reads earlier than
    /// the Unix epoch, "now" is treated as 0 instead of panicking.
    #[must_use]
    pub fn is_expired(&self) -> bool {
        self.expires_at.is_some_and(|expires| {
            let now = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map_or(0, |elapsed| elapsed.as_secs());
            now > expires
        })
    }

    /// Returns `true` if `permission` exactly matches one of the session's permissions.
    #[must_use]
    pub fn has_permission(&self, permission: &str) -> bool {
        // Compare borrowed strings so no temporary `String` is allocated per lookup.
        self.permissions.iter().any(|granted| granted == permission)
    }
}
79
80struct UserData {
82 password_hash: [u8; 32],
83 salt: [u8; 16],
84 identity: Identity,
85}
86
87pub struct IdentityProvider {
89 master_key: [u8; 32],
90 token_ttl: u64,
91 users: HashMap<String, UserData>,
92}
93
94impl IdentityProvider {
95 const ITERATIONS: u32 = 100_000;
96
97 #[must_use]
99 pub fn new(master_key: [u8; 32], token_ttl: u64) -> Self {
100 Self {
101 master_key,
102 token_ttl: if token_ttl == 0 { 3600 } else { token_ttl },
103 users: HashMap::new(),
104 }
105 }
106
107 fn derive_key(&self, purpose: &str) -> [u8; 32] {
109 let mut data = Vec::with_capacity(32 + purpose.len());
110 data.extend_from_slice(&self.master_key);
111 data.extend_from_slice(purpose.as_bytes());
112 let hash = ring::digest::digest(&ring::digest::SHA256, &data);
113 let mut key = [0u8; 32];
114 key.copy_from_slice(hash.as_ref());
115 key
116 }
117
118 pub fn register(
120 &mut self,
121 user_id: &str,
122 password: &str,
123 display_name: Option<&str>,
124 attributes: HashMap<String, String>,
125 ) -> Result<Identity> {
126 if self.users.contains_key(user_id) {
127 return Err(ShieldError::UserExists(user_id.to_string()));
128 }
129
130 let rng = SystemRandom::new();
131 let mut salt = [0u8; 16];
132 rng.fill(&mut salt).map_err(|_| ShieldError::RandomFailed)?;
133
134 let mut password_hash = [0u8; 32];
135 ring::pbkdf2::derive(
136 ring::pbkdf2::PBKDF2_HMAC_SHA256,
137 NonZeroU32::new(Self::ITERATIONS).unwrap(),
138 &salt,
139 password.as_bytes(),
140 &mut password_hash,
141 );
142
143 let verify_key = self.derive_key("verify");
145 let mut vk_data = Vec::with_capacity(32 + user_id.len());
146 vk_data.extend_from_slice(&verify_key);
147 vk_data.extend_from_slice(user_id.as_bytes());
148 let vk_hash = ring::digest::digest(&ring::digest::SHA256, &vk_data);
149 let mut verification_key = [0u8; 32];
150 verification_key.copy_from_slice(vk_hash.as_ref());
151
152 let now = SystemTime::now()
153 .duration_since(UNIX_EPOCH)
154 .unwrap()
155 .as_secs();
156
157 let identity = Identity {
158 user_id: user_id.to_string(),
159 display_name: display_name.unwrap_or(user_id).to_string(),
160 verification_key,
161 attributes, created_at: now,
163 };
164
165 self.users.insert(
166 user_id.to_string(),
167 UserData {
168 password_hash,
169 salt,
170 identity: identity.clone(),
171 },
172 );
173
174 Ok(identity)
175 }
176
177 #[must_use]
179 pub fn authenticate(
180 &self,
181 user_id: &str,
182 password: &str,
183 permissions: &[String],
184 ttl: Option<u64>,
185 ) -> Option<String> {
186 let user = self.users.get(user_id)?;
187
188 let mut password_hash = [0u8; 32];
189 ring::pbkdf2::derive(
190 ring::pbkdf2::PBKDF2_HMAC_SHA256,
191 NonZeroU32::new(Self::ITERATIONS).unwrap(),
192 &user.salt,
193 password.as_bytes(),
194 &mut password_hash,
195 );
196
197 if password_hash.ct_eq(&user.password_hash).unwrap_u8() != 1 {
198 return None;
199 }
200
201 Some(self.create_token(user_id, permissions, ttl.unwrap_or(self.token_ttl)))
202 }
203
204 fn create_token(&self, user_id: &str, permissions: &[String], ttl: u64) -> String {
206 let rng = SystemRandom::new();
207 let mut nonce = [0u8; 16];
208 rng.fill(&mut nonce).unwrap();
209
210 let now = SystemTime::now()
211 .duration_since(UNIX_EPOCH)
212 .unwrap()
213 .as_secs();
214 let expires_at = now + ttl;
215
216 let user_id_bytes = user_id.as_bytes();
218 let perms_json = serde_json::to_string(permissions).unwrap();
219 let perms_bytes = perms_json.as_bytes();
220
221 let mut token_data = Vec::new();
222 token_data.extend_from_slice(&(user_id_bytes.len() as u16).to_le_bytes());
223 token_data.extend_from_slice(user_id_bytes);
224 token_data.extend_from_slice(&(perms_bytes.len() as u16).to_le_bytes());
225 token_data.extend_from_slice(perms_bytes);
226 token_data.extend_from_slice(&expires_at.to_le_bytes());
227
228 let key = self.derive_key("session");
230 let keystream = generate_keystream(&key, &nonce, token_data.len());
231 let encrypted: Vec<u8> = token_data
232 .iter()
233 .zip(keystream.iter())
234 .map(|(p, k)| p ^ k)
235 .collect();
236
237 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, &key);
239 let mut hmac_data = Vec::with_capacity(16 + encrypted.len());
240 hmac_data.extend_from_slice(&nonce);
241 hmac_data.extend_from_slice(&encrypted);
242 let tag = hmac::sign(&hmac_key, &hmac_data);
243
244 let mut result = Vec::with_capacity(16 + encrypted.len() + 16);
245 result.extend_from_slice(&nonce);
246 result.extend_from_slice(&encrypted);
247 result.extend_from_slice(&tag.as_ref()[..16]);
248
249 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&result)
250 }
251
252 #[must_use]
254 pub fn validate_token(&self, token: &str) -> Option<Session> {
255 let data = base64::engine::general_purpose::URL_SAFE_NO_PAD
256 .decode(token)
257 .ok()?;
258
259 if data.len() < 34 {
260 return None;
261 }
262
263 let nonce = &data[..16];
264 let encrypted = &data[16..data.len() - 16];
265 let mac = &data[data.len() - 16..];
266
267 let key = self.derive_key("session");
268
269 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, &key);
271 let mut hmac_data = Vec::with_capacity(16 + encrypted.len());
272 hmac_data.extend_from_slice(nonce);
273 hmac_data.extend_from_slice(encrypted);
274 let expected_tag = hmac::sign(&hmac_key, &hmac_data);
275
276 if mac.ct_eq(&expected_tag.as_ref()[..16]).unwrap_u8() != 1 {
277 return None;
278 }
279
280 let keystream = generate_keystream(&key, nonce, encrypted.len());
282 let token_data: Vec<u8> = encrypted
283 .iter()
284 .zip(keystream.iter())
285 .map(|(c, k)| c ^ k)
286 .collect();
287
288 let user_id_len = u16::from_le_bytes([token_data[0], token_data[1]]) as usize;
290 let user_id = String::from_utf8(token_data[2..2 + user_id_len].to_vec()).ok()?;
291
292 let offset = 2 + user_id_len;
293 let perms_len = u16::from_le_bytes([token_data[offset], token_data[offset + 1]]) as usize;
294 let perms_json =
295 String::from_utf8(token_data[offset + 2..offset + 2 + perms_len].to_vec()).ok()?;
296 let permissions: Vec<String> = serde_json::from_str(&perms_json).ok()?;
297
298 let exp_offset = offset + 2 + perms_len;
299 let expires_at =
300 u64::from_le_bytes(token_data[exp_offset..exp_offset + 8].try_into().ok()?);
301
302 let session = Session {
303 user_id,
304 permissions,
305 expires_at: Some(expires_at),
306 attributes: HashMap::new(),
307 };
308
309 if session.is_expired() {
310 return None;
311 }
312
313 Some(session)
314 }
315
316 #[must_use]
318 pub fn create_service_token(
319 &self,
320 session_token: &str,
321 service: &str,
322 permissions: &[String],
323 ttl: u64,
324 ) -> Option<String> {
325 let session = self.validate_token(session_token)?;
326
327 let rng = SystemRandom::new();
328 let mut nonce = [0u8; 16];
329 rng.fill(&mut nonce).ok()?;
330
331 let now = SystemTime::now()
332 .duration_since(UNIX_EPOCH)
333 .unwrap()
334 .as_secs();
335 let expires_at = now + ttl;
336
337 let user_id_bytes = session.user_id.as_bytes();
339 let service_bytes = service.as_bytes();
340 let perms_json = serde_json::to_string(permissions).unwrap();
341 let perms_bytes = perms_json.as_bytes();
342
343 let mut token_data = Vec::new();
344 token_data.extend_from_slice(&(user_id_bytes.len() as u16).to_le_bytes());
345 token_data.extend_from_slice(user_id_bytes);
346 token_data.extend_from_slice(&(service_bytes.len() as u16).to_le_bytes());
347 token_data.extend_from_slice(service_bytes);
348 token_data.extend_from_slice(&(perms_bytes.len() as u16).to_le_bytes());
349 token_data.extend_from_slice(perms_bytes);
350 token_data.extend_from_slice(&expires_at.to_le_bytes());
351
352 let key = self.derive_key(&format!("service:{service}"));
354 let keystream = generate_keystream(&key, &nonce, token_data.len());
355 let encrypted: Vec<u8> = token_data
356 .iter()
357 .zip(keystream.iter())
358 .map(|(p, k)| p ^ k)
359 .collect();
360
361 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, &key);
362 let mut hmac_data = Vec::with_capacity(16 + encrypted.len());
363 hmac_data.extend_from_slice(&nonce);
364 hmac_data.extend_from_slice(&encrypted);
365 let tag = hmac::sign(&hmac_key, &hmac_data);
366
367 let mut result = Vec::with_capacity(16 + encrypted.len() + 16);
368 result.extend_from_slice(&nonce);
369 result.extend_from_slice(&encrypted);
370 result.extend_from_slice(&tag.as_ref()[..16]);
371
372 Some(base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&result))
373 }
374
375 #[must_use]
377 pub fn validate_service_token(&self, token: &str, service: &str) -> Option<Session> {
378 let data = base64::engine::general_purpose::URL_SAFE_NO_PAD
379 .decode(token)
380 .ok()?;
381
382 if data.len() < 34 {
383 return None;
384 }
385
386 let nonce = &data[..16];
387 let encrypted = &data[16..data.len() - 16];
388 let mac = &data[data.len() - 16..];
389
390 let key = self.derive_key(&format!("service:{service}"));
391
392 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, &key);
394 let mut hmac_data = Vec::with_capacity(16 + encrypted.len());
395 hmac_data.extend_from_slice(nonce);
396 hmac_data.extend_from_slice(encrypted);
397 let expected_tag = hmac::sign(&hmac_key, &hmac_data);
398
399 if mac.ct_eq(&expected_tag.as_ref()[..16]).unwrap_u8() != 1 {
400 return None;
401 }
402
403 let keystream = generate_keystream(&key, nonce, encrypted.len());
405 let token_data: Vec<u8> = encrypted
406 .iter()
407 .zip(keystream.iter())
408 .map(|(c, k)| c ^ k)
409 .collect();
410
411 let mut offset = 0;
413 let user_id_len = u16::from_le_bytes([token_data[offset], token_data[offset + 1]]) as usize;
414 offset += 2;
415 let user_id = String::from_utf8(token_data[offset..offset + user_id_len].to_vec()).ok()?;
416 offset += user_id_len;
417
418 let service_len = u16::from_le_bytes([token_data[offset], token_data[offset + 1]]) as usize;
419 offset += 2;
420 let token_service =
421 String::from_utf8(token_data[offset..offset + service_len].to_vec()).ok()?;
422 offset += service_len;
423
424 if token_service != service {
425 return None;
426 }
427
428 let perms_len = u16::from_le_bytes([token_data[offset], token_data[offset + 1]]) as usize;
429 offset += 2;
430 let perms_json = String::from_utf8(token_data[offset..offset + perms_len].to_vec()).ok()?;
431 let permissions: Vec<String> = serde_json::from_str(&perms_json).ok()?;
432 offset += perms_len;
433
434 let expires_at = u64::from_le_bytes(token_data[offset..offset + 8].try_into().ok()?);
435
436 let session = Session {
437 user_id,
438 permissions,
439 expires_at: Some(expires_at),
440 attributes: HashMap::new(),
441 };
442
443 if session.is_expired() {
444 return None;
445 }
446
447 Some(session)
448 }
449
450 #[must_use]
452 pub fn refresh_token(&self, token: &str) -> Option<String> {
453 let session = self.validate_token(token)?;
454 Some(self.create_token(&session.user_id, &session.permissions, self.token_ttl))
455 }
456
457 #[must_use]
459 pub fn get_identity(&self, user_id: &str) -> Option<&Identity> {
460 self.users.get(user_id).map(|u| &u.identity)
461 }
462
463 pub fn revoke_user(&mut self, user_id: &str) {
465 self.users.remove(user_id);
466 }
467}
468
/// Symmetric encrypt/decrypt helper whose key is periodically rotated.
pub struct SecureSession {
    /// Root secret from which every versioned key is derived.
    master_key: [u8; 32],
    /// Minimum number of seconds between key rotations.
    rotation_interval: u64,
    /// How many versions older than the current one are kept for decryption.
    max_old_keys: usize,
    /// Version number of the key used for new ciphertexts.
    key_version: u32,
    /// Derived keys indexed by version.
    keys: HashMap<u32, [u8; 32]>,
    /// Time of the last rotation, in seconds since the Unix epoch.
    last_rotation: u64,
}
478
479impl SecureSession {
480 #[must_use]
482 pub fn new(master_key: [u8; 32], rotation_interval: u64, max_old_keys: usize) -> Self {
483 let now = SystemTime::now()
484 .duration_since(UNIX_EPOCH)
485 .unwrap()
486 .as_secs();
487
488 let key = Self::derive_session_key(&master_key, 1);
489 let mut keys = HashMap::new();
490 keys.insert(1, key);
491
492 Self {
493 master_key,
494 rotation_interval: if rotation_interval == 0 {
495 3600
496 } else {
497 rotation_interval
498 },
499 max_old_keys: if max_old_keys == 0 { 3 } else { max_old_keys },
500 key_version: 1,
501 keys,
502 last_rotation: now,
503 }
504 }
505
506 fn derive_session_key(master_key: &[u8; 32], version: u32) -> [u8; 32] {
507 let mut data = Vec::with_capacity(32 + 16);
508 data.extend_from_slice(master_key);
509 data.extend_from_slice(format!("session:{version}").as_bytes());
510 let hash = ring::digest::digest(&ring::digest::SHA256, &data);
511 let mut key = [0u8; 32];
512 key.copy_from_slice(hash.as_ref());
513 key
514 }
515
516 fn maybe_rotate(&mut self) {
517 let now = SystemTime::now()
518 .duration_since(UNIX_EPOCH)
519 .unwrap()
520 .as_secs();
521
522 if now - self.last_rotation >= self.rotation_interval {
523 self.key_version += 1;
524 let new_key = Self::derive_session_key(&self.master_key, self.key_version);
525 self.keys.insert(self.key_version, new_key);
526 self.last_rotation = now;
527
528 let mut versions: Vec<u32> = self.keys.keys().copied().collect();
530 versions.sort_by(|a, b| b.cmp(a));
531 for v in versions.into_iter().skip(self.max_old_keys + 1) {
532 self.keys.remove(&v);
533 }
534 }
535 }
536
537 pub fn encrypt(&mut self, data: &[u8]) -> Result<Vec<u8>> {
539 self.maybe_rotate();
540
541 let key = self.keys.get(&self.key_version).unwrap();
542 let rng = SystemRandom::new();
543 let mut nonce = [0u8; 16];
544 rng.fill(&mut nonce)
545 .map_err(|_| ShieldError::RandomFailed)?;
546
547 let keystream = generate_keystream(key, &nonce, data.len());
548 let ciphertext: Vec<u8> = data
549 .iter()
550 .zip(keystream.iter())
551 .map(|(p, k)| p ^ k)
552 .collect();
553
554 let version_bytes = self.key_version.to_le_bytes();
555
556 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
557 let mut hmac_data = Vec::with_capacity(4 + 16 + ciphertext.len());
558 hmac_data.extend_from_slice(&version_bytes);
559 hmac_data.extend_from_slice(&nonce);
560 hmac_data.extend_from_slice(&ciphertext);
561 let tag = hmac::sign(&hmac_key, &hmac_data);
562
563 let mut result = Vec::with_capacity(4 + 16 + ciphertext.len() + 16);
564 result.extend_from_slice(&version_bytes);
565 result.extend_from_slice(&nonce);
566 result.extend_from_slice(&ciphertext);
567 result.extend_from_slice(&tag.as_ref()[..16]);
568
569 Ok(result)
570 }
571
572 pub fn decrypt(&mut self, encrypted: &[u8]) -> Option<Vec<u8>> {
574 self.maybe_rotate();
575
576 if encrypted.len() < 36 {
577 return None;
578 }
579
580 let version = u32::from_le_bytes(encrypted[..4].try_into().ok()?);
581 let nonce = &encrypted[4..20];
582 let ciphertext = &encrypted[20..encrypted.len() - 16];
583 let mac = &encrypted[encrypted.len() - 16..];
584
585 let key = self.keys.get(&version)?;
586
587 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
589 let expected_tag = hmac::sign(&hmac_key, &encrypted[..encrypted.len() - 16]);
590
591 if mac.ct_eq(&expected_tag.as_ref()[..16]).unwrap_u8() != 1 {
592 return None;
593 }
594
595 let keystream = generate_keystream(key, nonce, ciphertext.len());
596 let plaintext: Vec<u8> = ciphertext
597 .iter()
598 .zip(keystream.iter())
599 .map(|(c, k)| c ^ k)
600 .collect();
601
602 Some(plaintext)
603 }
604
605 #[must_use]
607 pub fn key_version(&self) -> u32 {
608 self.key_version
609 }
610}
611
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider (zero master key, 1h TTL) with `alice` registered under `password`.
    fn provider_with_alice(password: &str) -> IdentityProvider {
        let mut provider = IdentityProvider::new([0u8; 32], 3600);
        provider
            .register("alice", password, None, HashMap::new())
            .unwrap();
        provider
    }

    /// Returns a fresh session token for `alice`, whose password is "password".
    fn alice_token(provider: &IdentityProvider) -> String {
        provider
            .authenticate("alice", "password", &[], None)
            .unwrap()
    }

    #[test]
    fn test_register_user() {
        let mut provider = IdentityProvider::new([0u8; 32], 3600);
        let identity = provider
            .register("alice", "password123", Some("Alice Smith"), HashMap::new())
            .unwrap();

        assert_eq!(identity.user_id, "alice");
        assert_eq!(identity.display_name, "Alice Smith");
    }

    #[test]
    fn test_register_default_display_name() {
        let provider = provider_with_alice("password");
        assert_eq!(provider.get_identity("alice").unwrap().display_name, "alice");
    }

    #[test]
    fn test_register_duplicate() {
        let mut provider = provider_with_alice("password");
        assert!(provider
            .register("alice", "password2", None, HashMap::new())
            .is_err());
    }

    #[test]
    fn test_authenticate() {
        let provider = provider_with_alice("password123");
        assert!(provider
            .authenticate("alice", "password123", &[], None)
            .is_some());
    }

    #[test]
    fn test_authenticate_wrong_password() {
        let provider = provider_with_alice("password123");
        assert!(provider
            .authenticate("alice", "wrongpassword", &[], None)
            .is_none());
    }

    #[test]
    fn test_authenticate_unknown_user() {
        let provider = provider_with_alice("password");
        assert!(provider.authenticate("bob", "password", &[], None).is_none());
    }

    #[test]
    fn test_validate_token() {
        let provider = provider_with_alice("password");
        let token = alice_token(&provider);

        let session = provider.validate_token(&token);
        assert!(session.is_some());
        assert_eq!(session.unwrap().user_id, "alice");
    }

    #[test]
    fn test_validate_token_rejects_garbage() {
        let provider = provider_with_alice("password");
        assert!(provider.validate_token("").is_none());
        assert!(provider.validate_token("not-a-token!").is_none());
        assert!(provider.validate_token("AAAA").is_none());
    }

    #[test]
    fn test_validate_token_rejects_tampering() {
        let provider = provider_with_alice("password");
        let token = alice_token(&provider);

        let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
        let mut raw = engine.decode(&token).unwrap();
        raw[20] ^= 0x01;
        assert!(provider.validate_token(&engine.encode(&raw)).is_none());
    }

    #[test]
    fn test_refresh_token() {
        let provider = provider_with_alice("password");
        let refreshed = provider.refresh_token(&alice_token(&provider)).unwrap();
        assert_eq!(provider.validate_token(&refreshed).unwrap().user_id, "alice");
    }

    #[test]
    fn test_revoke_user_removes_identity() {
        let mut provider = provider_with_alice("password");
        assert!(provider.get_identity("alice").is_some());
        provider.revoke_user("alice");
        assert!(provider.get_identity("alice").is_none());
    }

    #[test]
    fn test_service_token() {
        let provider = provider_with_alice("password");
        let session_token = alice_token(&provider);

        let service_token = provider
            .create_service_token(
                &session_token,
                "api.example.com",
                &["read".to_string()],
                300,
            )
            .unwrap();

        let session = provider.validate_service_token(&service_token, "api.example.com");
        assert!(session.is_some());
        assert_eq!(session.as_ref().unwrap().user_id, "alice");
        assert!(session.unwrap().has_permission("read"));
    }

    #[test]
    fn test_service_token_wrong_service() {
        let provider = provider_with_alice("password");
        let session_token = alice_token(&provider);
        let service_token = provider
            .create_service_token(&session_token, "api.example.com", &[], 300)
            .unwrap();

        let session = provider.validate_service_token(&service_token, "other.example.com");
        assert!(session.is_none());
    }

    #[test]
    fn test_secure_session() {
        let mut session = SecureSession::new([0u8; 32], 3600, 3);
        let plaintext = b"session data";
        let encrypted = session.encrypt(plaintext).unwrap();
        let decrypted = session.decrypt(&encrypted).unwrap();
        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_secure_session_initial_version() {
        let session = SecureSession::new([0u8; 32], 3600, 3);
        assert_eq!(session.key_version(), 1);
    }

    #[test]
    fn test_secure_session_tampered() {
        let mut session = SecureSession::new([0u8; 32], 3600, 3);
        let mut encrypted = session.encrypt(b"data").unwrap();
        encrypted[20] ^= 0xFF;
        assert!(session.decrypt(&encrypted).is_none());
    }

    #[test]
    fn test_secure_session_short_input() {
        let mut session = SecureSession::new([0u8; 32], 3600, 3);
        assert!(session.decrypt(&[0u8; 35]).is_none());
        assert!(session.decrypt(&[]).is_none());
    }

    #[test]
    fn test_secure_session_unknown_version() {
        let mut session = SecureSession::new([0u8; 32], 3600, 3);
        let mut encrypted = session.encrypt(b"data").unwrap();
        encrypted[..4].copy_from_slice(&99u32.to_le_bytes());
        assert!(session.decrypt(&encrypted).is_none());
    }
}