1use crate::{
2 ChunkInput, Result, RuntimeConfig, SearchRequest, SearchResult, SqlRite, SqlRiteError,
3};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::{HashMap, HashSet};
7use std::fs::{self, OpenOptions};
8use std::io::Write;
9use std::path::{Path, PathBuf};
10use std::sync::Mutex;
11
/// Minimum number of key-material bytes accepted by [`TenantKey::new`].
const MINIMUM_TENANT_KEY_BYTES: usize = 16;
13
/// Operations an [`AccessPolicy`] can authorize and that appear in audit events.
///
/// Serialized in `snake_case` (e.g. `sql_admin`, `delete_tenant`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AccessOperation {
    /// Writing chunks for a tenant.
    Ingest,
    /// Searching a tenant's chunks.
    Query,
    /// Presumably administrative SQL access; not exercised by this module
    /// beyond the RBAC role tables — TODO confirm against callers.
    SqlAdmin,
    /// Deleting the chunks tagged with a tenant.
    DeleteTenant,
}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct AccessContext {
25 pub actor_id: String,
26 pub tenant_id: String,
27 pub roles: Vec<String>,
28}
29
30impl AccessContext {
31 pub fn new(actor_id: impl Into<String>, tenant_id: impl Into<String>) -> Self {
32 Self {
33 actor_id: actor_id.into(),
34 tenant_id: tenant_id.into(),
35 roles: Vec::new(),
36 }
37 }
38
39 pub fn with_roles(mut self, roles: Vec<String>) -> Self {
40 self.roles = roles;
41 self
42 }
43
44 fn is_admin(&self) -> bool {
45 self.roles.iter().any(|role| role == "admin")
46 }
47}
48
/// Decides whether a caller may perform an operation against a tenant.
///
/// Returning `Ok(())` allows the operation; returning an error denies it (the
/// implementations in this module return `InvalidTenantId` or
/// `AuthorizationDenied`).
pub trait AccessPolicy: Send + Sync {
    /// Authorizes `operation` by `context` against `target_tenant`.
    fn authorize(
        &self,
        context: &AccessContext,
        operation: AccessOperation,
        target_tenant: &str,
    ) -> Result<()>;
}
57
58#[derive(Debug, Default, Clone)]
59pub struct AllowAllPolicy;
60
61impl AccessPolicy for AllowAllPolicy {
62 fn authorize(
63 &self,
64 context: &AccessContext,
65 _operation: AccessOperation,
66 target_tenant: &str,
67 ) -> Result<()> {
68 if context.tenant_id.trim().is_empty() || target_tenant.trim().is_empty() {
69 return Err(SqlRiteError::InvalidTenantId);
70 }
71 if context.tenant_id != target_tenant && !context.is_admin() {
72 return Err(SqlRiteError::AuthorizationDenied(format!(
73 "tenant `{}` cannot access tenant `{}`",
74 context.tenant_id, target_tenant
75 )));
76 }
77 Ok(())
78 }
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct RbacPolicyConfig {
83 pub role_permissions: HashMap<String, Vec<AccessOperation>>,
84 pub cross_tenant_roles: Vec<String>,
85}
86
87impl Default for RbacPolicyConfig {
88 fn default() -> Self {
89 Self {
90 role_permissions: HashMap::from([
91 ("reader".to_string(), vec![AccessOperation::Query]),
92 (
93 "writer".to_string(),
94 vec![AccessOperation::Query, AccessOperation::Ingest],
95 ),
96 (
97 "tenant_admin".to_string(),
98 vec![
99 AccessOperation::Query,
100 AccessOperation::Ingest,
101 AccessOperation::DeleteTenant,
102 ],
103 ),
104 (
105 "admin".to_string(),
106 vec![
107 AccessOperation::Query,
108 AccessOperation::Ingest,
109 AccessOperation::SqlAdmin,
110 AccessOperation::DeleteTenant,
111 ],
112 ),
113 ]),
114 cross_tenant_roles: vec!["admin".to_string()],
115 }
116 }
117}
118
119#[derive(Debug, Clone)]
120pub struct RbacPolicy {
121 role_permissions: HashMap<String, HashSet<AccessOperation>>,
122 cross_tenant_roles: HashSet<String>,
123}
124
125impl Default for RbacPolicy {
126 fn default() -> Self {
127 Self::from_config(RbacPolicyConfig::default())
128 }
129}
130
131impl RbacPolicy {
132 pub fn from_config(config: RbacPolicyConfig) -> Self {
133 let role_permissions = config
134 .role_permissions
135 .into_iter()
136 .map(|(role, operations)| (role, operations.into_iter().collect()))
137 .collect();
138 let cross_tenant_roles = config.cross_tenant_roles.into_iter().collect();
139 Self {
140 role_permissions,
141 cross_tenant_roles,
142 }
143 }
144
145 pub fn to_config(&self) -> RbacPolicyConfig {
146 let mut role_permissions = HashMap::new();
147 for (role, operations) in &self.role_permissions {
148 let mut values = operations.iter().copied().collect::<Vec<_>>();
149 values.sort_by_key(|operation| match operation {
150 AccessOperation::Query => 0,
151 AccessOperation::Ingest => 1,
152 AccessOperation::SqlAdmin => 2,
153 AccessOperation::DeleteTenant => 3,
154 });
155 role_permissions.insert(role.clone(), values);
156 }
157
158 let mut cross_tenant_roles = self.cross_tenant_roles.iter().cloned().collect::<Vec<_>>();
159 cross_tenant_roles.sort();
160 RbacPolicyConfig {
161 role_permissions,
162 cross_tenant_roles,
163 }
164 }
165
166 pub fn load_from_json_file(path: impl AsRef<Path>) -> Result<Self> {
167 let path = path.as_ref();
168 if !path.exists() {
169 return Ok(Self::default());
170 }
171 let payload = fs::read_to_string(path)?;
172 let config = serde_json::from_str::<RbacPolicyConfig>(&payload)
173 .map_err(|e| SqlRiteError::UnsupportedOperation(e.to_string()))?;
174 Ok(Self::from_config(config))
175 }
176
177 pub fn save_to_json_file(&self, path: impl AsRef<Path>) -> Result<()> {
178 let path = path.as_ref();
179 if let Some(parent) = path.parent()
180 && !parent.as_os_str().is_empty()
181 {
182 fs::create_dir_all(parent)?;
183 }
184
185 let payload = serde_json::to_string_pretty(&self.to_config())?;
186 let temp = path.with_extension("tmp");
187 fs::write(&temp, payload)?;
188 fs::rename(temp, path)?;
189 Ok(())
190 }
191
192 pub fn role_names(&self) -> Vec<String> {
193 let mut roles = self.role_permissions.keys().cloned().collect::<Vec<_>>();
194 roles.sort();
195 roles
196 }
197}
198
199impl AccessPolicy for RbacPolicy {
200 fn authorize(
201 &self,
202 context: &AccessContext,
203 operation: AccessOperation,
204 target_tenant: &str,
205 ) -> Result<()> {
206 if context.tenant_id.trim().is_empty() || target_tenant.trim().is_empty() {
207 return Err(SqlRiteError::InvalidTenantId);
208 }
209
210 let allowed = context.roles.iter().any(|role| {
211 self.role_permissions
212 .get(role)
213 .is_some_and(|operations| operations.contains(&operation))
214 });
215 if !allowed {
216 return Err(SqlRiteError::AuthorizationDenied(format!(
217 "actor `{}` lacks role permission for {:?}",
218 context.actor_id, operation
219 )));
220 }
221
222 if context.tenant_id != target_tenant
223 && !context
224 .roles
225 .iter()
226 .any(|role| self.cross_tenant_roles.contains(role))
227 {
228 return Err(SqlRiteError::AuthorizationDenied(format!(
229 "tenant `{}` cannot access tenant `{}`",
230 context.tenant_id, target_tenant
231 )));
232 }
233
234 Ok(())
235 }
236}
237
238#[derive(Debug, Clone, Serialize, Deserialize)]
239pub struct TenantKey {
240 pub key_id: String,
241 pub material: Vec<u8>,
242}
243
244impl TenantKey {
245 pub fn new(key_id: impl Into<String>, material: impl AsRef<[u8]>) -> Result<Self> {
246 let key_id = key_id.into();
247 let material = material.as_ref().to_vec();
248 if key_id.trim().is_empty() || material.len() < MINIMUM_TENANT_KEY_BYTES {
249 return Err(SqlRiteError::UnsupportedOperation(format!(
250 "tenant key_id/material are required and key material must be at least {MINIMUM_TENANT_KEY_BYTES} bytes"
251 )));
252 }
253 Ok(Self { key_id, material })
254 }
255}
256
/// Source of per-tenant encryption keys.
pub trait TenantKeyRegistry: Send + Sync {
    /// The key used to encrypt new data for `tenant_id`, if one is configured.
    fn active_key(&self, tenant_id: &str) -> Option<TenantKey>;
    /// Any key registered for `tenant_id` under `key_id`, active or not.
    /// Key rotation uses it to decrypt values written with older keys.
    fn key_by_id(&self, tenant_id: &str, key_id: &str) -> Option<TenantKey>;
}
261
/// [`TenantKeyRegistry`] that keeps keys in memory behind a mutex. It can be
/// saved to and loaded from a JSON file.
#[derive(Debug, Default)]
pub struct InMemoryTenantKeyRegistry {
    /// Tenant id -> that tenant's keys and active key id.
    keys: Mutex<HashMap<String, TenantKeyState>>,
}

/// Keys registered for a single tenant.
#[derive(Debug, Default, Clone)]
struct TenantKeyState {
    /// Id of the key returned by `active_key`; `None` until a key is activated.
    active_key_id: Option<String>,
    /// Key id -> key.
    keys: HashMap<String, TenantKey>,
}
272
273impl InMemoryTenantKeyRegistry {
274 pub fn new() -> Self {
275 Self::default()
276 }
277
278 pub fn set_active_key(&self, tenant_id: &str, key: TenantKey) -> Result<()> {
279 let mut guard = self.keys.lock().map_err(|_| {
280 SqlRiteError::UnsupportedOperation("tenant key registry mutex poisoned".to_string())
281 })?;
282 let state = guard.entry(tenant_id.to_string()).or_default();
283 state.active_key_id = Some(key.key_id.clone());
284 state.keys.insert(key.key_id.clone(), key);
285 Ok(())
286 }
287
288 pub fn set_key(&self, tenant_id: &str, key: TenantKey, make_active: bool) -> Result<()> {
289 let mut guard = self.keys.lock().map_err(|_| {
290 SqlRiteError::UnsupportedOperation("tenant key registry mutex poisoned".to_string())
291 })?;
292 let state = guard.entry(tenant_id.to_string()).or_default();
293 if make_active {
294 state.active_key_id = Some(key.key_id.clone());
295 }
296 state.keys.insert(key.key_id.clone(), key);
297 Ok(())
298 }
299
300 pub fn load_from_json_file(path: impl AsRef<Path>) -> Result<Self> {
301 let path = path.as_ref();
302 if !path.exists() {
303 return Ok(Self::new());
304 }
305 let payload = fs::read_to_string(path)?;
306 let serializable = serde_json::from_str::<SerializableTenantKeyRegistry>(&payload)
307 .map_err(|e| SqlRiteError::UnsupportedOperation(e.to_string()))?;
308
309 let mut tenants = HashMap::new();
310 for (tenant_id, state) in serializable.tenants {
311 let mut keys = HashMap::new();
312 for key in state.keys {
313 keys.insert(key.key_id.clone(), key);
314 }
315 tenants.insert(
316 tenant_id,
317 TenantKeyState {
318 active_key_id: state.active_key_id,
319 keys,
320 },
321 );
322 }
323 Ok(Self {
324 keys: Mutex::new(tenants),
325 })
326 }
327
328 pub fn save_to_json_file(&self, path: impl AsRef<Path>) -> Result<()> {
329 let path = path.as_ref();
330 if let Some(parent) = path.parent()
331 && !parent.as_os_str().is_empty()
332 {
333 fs::create_dir_all(parent)?;
334 }
335
336 let guard = self.keys.lock().map_err(|_| {
337 SqlRiteError::UnsupportedOperation("tenant key registry mutex poisoned".to_string())
338 })?;
339 let tenants = guard
340 .iter()
341 .map(|(tenant_id, state)| {
342 let keys = state.keys.values().cloned().collect::<Vec<_>>();
343 (
344 tenant_id.clone(),
345 SerializableTenantKeyState {
346 active_key_id: state.active_key_id.clone(),
347 keys,
348 },
349 )
350 })
351 .collect::<HashMap<_, _>>();
352
353 let payload = serde_json::to_string_pretty(&SerializableTenantKeyRegistry { tenants })?;
354 let temp = path.with_extension("tmp");
355 fs::write(&temp, payload)?;
356 fs::rename(temp, path)?;
357 Ok(())
358 }
359}
360
/// On-disk JSON shape of one tenant's keys. Keys are stored as a list, while
/// the in-memory form uses a map.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SerializableTenantKeyState {
    active_key_id: Option<String>,
    keys: Vec<TenantKey>,
}

/// On-disk JSON shape of [`InMemoryTenantKeyRegistry`]: tenant id -> key state.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SerializableTenantKeyRegistry {
    tenants: HashMap<String, SerializableTenantKeyState>,
}
371
372impl TenantKeyRegistry for InMemoryTenantKeyRegistry {
373 fn active_key(&self, tenant_id: &str) -> Option<TenantKey> {
374 let guard = self.keys.lock().ok()?;
375 let state = guard.get(tenant_id)?;
376 let active = state.active_key_id.as_ref()?;
377 state.keys.get(active).cloned()
378 }
379
380 fn key_by_id(&self, tenant_id: &str, key_id: &str) -> Option<TenantKey> {
381 let guard = self.keys.lock().ok()?;
382 guard.get(tenant_id)?.keys.get(key_id).cloned()
383 }
384}
385
/// One audit record; [`SecureSqlRite`] emits these for the operations it audits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditEvent {
    /// Milliseconds since the Unix epoch when the event was recorded.
    pub unix_ms: u64,
    /// Actor that performed the operation.
    pub actor_id: String,
    /// Tenant of the acting context. For deletes, the target tenant is in `detail`.
    pub tenant_id: String,
    /// Operation that was attempted.
    pub operation: AccessOperation,
    /// Outcome flag set by the caller. `SecureSqlRite` passes `result.is_ok()`
    /// for completed operations and `false` for the denials it audits.
    pub allowed: bool,
    /// Operation-specific details (e.g. counts or a denial reason).
    pub detail: Value,
}

/// Sink for audit events.
pub trait AuditLogger: Send + Sync {
    /// Records `event`. `SecureSqlRite` propagates a logging error to the
    /// caller of the audited operation.
    fn log(&self, event: &AuditEvent) -> Result<()>;
}
399
/// Filters used by [`export_audit_events`]. Every `None` field matches all events.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AuditQuery {
    /// Exact actor id to match.
    pub actor_id: Option<String>,
    /// Exact (acting) tenant id to match.
    pub tenant_id: Option<String>,
    /// Operation to match.
    pub operation: Option<AccessOperation>,
    /// `allowed` flag to match.
    pub allowed: Option<bool>,
    /// Inclusive lower bound on `unix_ms`.
    pub from_unix_ms: Option<u64>,
    /// Inclusive upper bound on `unix_ms`.
    pub to_unix_ms: Option<u64>,
    /// Maximum number of events to export; the first matches in file order are kept.
    pub limit: Option<usize>,
}

/// Output format for [`export_audit_events`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuditExportFormat {
    /// A single pretty-printed JSON array.
    Json,
    /// One compact JSON object per line.
    Jsonl,
}

/// Summary returned by [`export_audit_events`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditExportReport {
    /// Audit log that was read.
    pub source_path: PathBuf,
    /// File written, if an output path was given.
    pub output_path: Option<PathBuf>,
    pub matched_events: usize,
    pub exported_events: usize,
    /// `"json"` or `"jsonl"`.
    pub format: String,
    /// The query that was applied.
    pub filters: AuditQuery,
}

/// Result of scanning, and possibly re-encrypting, one tenant's encrypted metadata field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyRotationReport {
    pub tenant_id: String,
    /// Metadata key that holds the `enc:v1:` values.
    pub metadata_field: String,
    /// Key every value should end up encrypted with.
    pub target_key_id: String,
    /// Chunks of the tenant that were visited.
    pub scanned_chunks: usize,
    /// Scanned chunks whose field holds a parseable `enc:v1:` value.
    pub encrypted_chunks: usize,
    /// Chunks re-encrypted by this call (always 0 for inspection).
    pub rotated_chunks: usize,
    /// Encrypted chunks on the target key, counting chunks rotated in this call.
    pub target_key_matches: usize,
    /// Other key ids found during the scan.
    pub stale_key_ids: Vec<String>,
    /// `encrypted_chunks == target_key_matches`.
    pub verified_all_target_key: bool,
}
439
440#[derive(Debug)]
441pub struct JsonlAuditLogger {
442 path: PathBuf,
443 redacted_fields: HashSet<String>,
444 lock: Mutex<()>,
445}
446
447impl JsonlAuditLogger {
448 pub fn new(
449 path: impl AsRef<Path>,
450 redacted_fields: impl IntoIterator<Item = String>,
451 ) -> Result<Self> {
452 let path = path.as_ref().to_path_buf();
453 if let Some(parent) = path.parent()
454 && !parent.as_os_str().is_empty()
455 {
456 fs::create_dir_all(parent)?;
457 }
458 Ok(Self {
459 path,
460 redacted_fields: redacted_fields.into_iter().collect(),
461 lock: Mutex::new(()),
462 })
463 }
464
465 fn redact(&self, detail: &Value) -> Value {
466 match detail {
467 Value::Object(map) => {
468 let mut copy = map.clone();
469 for key in &self.redacted_fields {
470 if copy.contains_key(key) {
471 copy.insert(key.clone(), Value::String("[REDACTED]".to_string()));
472 }
473 }
474 Value::Object(copy)
475 }
476 _ => detail.clone(),
477 }
478 }
479}
480
481impl AuditQuery {
482 fn matches(&self, event: &AuditEvent) -> bool {
483 if let Some(actor_id) = &self.actor_id
484 && &event.actor_id != actor_id
485 {
486 return false;
487 }
488 if let Some(tenant_id) = &self.tenant_id
489 && &event.tenant_id != tenant_id
490 {
491 return false;
492 }
493 if let Some(operation) = self.operation
494 && event.operation != operation
495 {
496 return false;
497 }
498 if let Some(allowed) = self.allowed
499 && event.allowed != allowed
500 {
501 return false;
502 }
503 if let Some(from_unix_ms) = self.from_unix_ms
504 && event.unix_ms < from_unix_ms
505 {
506 return false;
507 }
508 if let Some(to_unix_ms) = self.to_unix_ms
509 && event.unix_ms > to_unix_ms
510 {
511 return false;
512 }
513 true
514 }
515}
516
517impl AuditLogger for JsonlAuditLogger {
518 fn log(&self, event: &AuditEvent) -> Result<()> {
519 let _guard = self.lock.lock().map_err(|_| {
520 SqlRiteError::UnsupportedOperation("audit logger mutex poisoned".to_string())
521 })?;
522 let mut file = OpenOptions::new()
523 .create(true)
524 .append(true)
525 .open(&self.path)?;
526
527 let serialized = serde_json::to_string(&AuditEvent {
528 detail: self.redact(&event.detail),
529 ..event.clone()
530 })?;
531 file.write_all(serialized.as_bytes())?;
532 file.write_all(b"\n")?;
533 Ok(())
534 }
535}
536
/// Wrapper around [`SqlRite`] whose methods authorize each call through an
/// [`AccessPolicy`]. Data is scoped to the caller's tenant through the
/// `tenant` metadata field, and outcomes are recorded with an [`AuditLogger`].
pub struct SecureSqlRite<P: AccessPolicy, A: AuditLogger> {
    /// Underlying database; `db()` / `into_inner()` expose it unchecked.
    db: SqlRite,
    policy: P,
    audit_logger: A,
}
542
543impl<P: AccessPolicy, A: AuditLogger> SecureSqlRite<P, A> {
544 pub fn open_with_config(
545 path: impl AsRef<Path>,
546 runtime: RuntimeConfig,
547 policy: P,
548 audit_logger: A,
549 ) -> Result<Self> {
550 Ok(Self {
551 db: SqlRite::open_with_config(path, runtime)?,
552 policy,
553 audit_logger,
554 })
555 }
556
557 pub fn from_db(db: SqlRite, policy: P, audit_logger: A) -> Self {
558 Self {
559 db,
560 policy,
561 audit_logger,
562 }
563 }
564
565 pub fn ingest_chunks(&self, context: &AccessContext, chunks: &[ChunkInput]) -> Result<()> {
566 self.policy
567 .authorize(context, AccessOperation::Ingest, &context.tenant_id)?;
568
569 let mut enriched = Vec::with_capacity(chunks.len());
570 for chunk in chunks {
571 let metadata = merge_tenant_metadata(&chunk.metadata, &context.tenant_id);
572 enriched.push(ChunkInput {
573 id: chunk.id.clone(),
574 doc_id: chunk.doc_id.clone(),
575 content: chunk.content.clone(),
576 embedding: chunk.embedding.clone(),
577 metadata,
578 source: chunk.source.clone(),
579 });
580 }
581
582 let result = self.db.ingest_chunks(&enriched);
583 self.audit(
584 context,
585 AccessOperation::Ingest,
586 result.is_ok(),
587 serde_json::json!({
588 "chunk_count": chunks.len(),
589 }),
590 )?;
591 result
592 }
593
594 pub fn ingest_chunks_with_encryption<R: TenantKeyRegistry>(
595 &self,
596 context: &AccessContext,
597 chunks: &[ChunkInput],
598 key_registry: &R,
599 sensitive_metadata_fields: &[&str],
600 ) -> Result<()> {
601 self.policy
602 .authorize(context, AccessOperation::Ingest, &context.tenant_id)?;
603 let active_key = key_registry.active_key(&context.tenant_id).ok_or_else(|| {
604 SqlRiteError::UnsupportedOperation(format!(
605 "no active key configured for tenant `{}`",
606 context.tenant_id
607 ))
608 })?;
609
610 let mut encrypted_chunks = Vec::with_capacity(chunks.len());
611 for chunk in chunks {
612 let mut metadata = merge_tenant_metadata(&chunk.metadata, &context.tenant_id);
613 encrypt_metadata_fields(
614 &mut metadata,
615 &active_key,
616 &context.tenant_id,
617 sensitive_metadata_fields,
618 )?;
619
620 encrypted_chunks.push(ChunkInput {
621 id: chunk.id.clone(),
622 doc_id: chunk.doc_id.clone(),
623 content: chunk.content.clone(),
624 embedding: chunk.embedding.clone(),
625 metadata,
626 source: chunk.source.clone(),
627 });
628 }
629
630 self.db.ingest_chunks(&encrypted_chunks)
631 }
632
633 pub fn search(
634 &self,
635 context: &AccessContext,
636 mut request: SearchRequest,
637 ) -> Result<Vec<SearchResult>> {
638 self.policy
639 .authorize(context, AccessOperation::Query, &context.tenant_id)?;
640
641 if let Some(existing) = request.metadata_filters.get("tenant")
642 && existing != &context.tenant_id
643 {
644 self.audit(
645 context,
646 AccessOperation::Query,
647 false,
648 serde_json::json!({"reason": "tenant filter mismatch"}),
649 )?;
650 return Err(SqlRiteError::AuthorizationDenied(
651 "tenant filter mismatch".to_string(),
652 ));
653 }
654
655 request
656 .metadata_filters
657 .insert("tenant".to_string(), context.tenant_id.clone());
658 let result = self.db.search(request);
659
660 self.audit(
661 context,
662 AccessOperation::Query,
663 result.is_ok(),
664 serde_json::json!({
665 "result_count": result.as_ref().map(|items| items.len()).unwrap_or(0),
666 }),
667 )?;
668 result
669 }
670
671 pub fn delete_tenant_data(&self, context: &AccessContext, tenant_id: &str) -> Result<usize> {
672 self.policy
673 .authorize(context, AccessOperation::DeleteTenant, tenant_id)?;
674
675 let result = self.db.delete_chunks_by_metadata("tenant", tenant_id);
676 self.audit(
677 context,
678 AccessOperation::DeleteTenant,
679 result.is_ok(),
680 serde_json::json!({
681 "target_tenant": tenant_id,
682 "deleted": result.as_ref().copied().unwrap_or(0),
683 }),
684 )?;
685 result
686 }
687
688 pub fn db(&self) -> &SqlRite {
689 &self.db
690 }
691
692 pub fn into_inner(self) -> SqlRite {
693 self.db
694 }
695
696 fn audit(
697 &self,
698 context: &AccessContext,
699 operation: AccessOperation,
700 allowed: bool,
701 detail: Value,
702 ) -> Result<()> {
703 self.audit_logger.log(&AuditEvent {
704 unix_ms: now_unix_ms(),
705 actor_id: context.actor_id.clone(),
706 tenant_id: context.tenant_id.clone(),
707 operation,
708 allowed,
709 detail,
710 })
711 }
712}
713
714pub fn rotate_tenant_encryption_key<R: TenantKeyRegistry>(
715 db: &SqlRite,
716 tenant_id: &str,
717 metadata_field: &str,
718 key_registry: &R,
719 new_key_id: &str,
720) -> Result<usize> {
721 Ok(rotate_tenant_encryption_key_with_report(
722 db,
723 tenant_id,
724 metadata_field,
725 key_registry,
726 new_key_id,
727 )?
728 .rotated_chunks)
729}
730
731pub fn rotate_tenant_encryption_key_with_report<R: TenantKeyRegistry>(
732 db: &SqlRite,
733 tenant_id: &str,
734 metadata_field: &str,
735 key_registry: &R,
736 new_key_id: &str,
737) -> Result<KeyRotationReport> {
738 let new_key = key_registry
739 .key_by_id(tenant_id, new_key_id)
740 .ok_or_else(|| SqlRiteError::UnsupportedOperation("new key not found".to_string()))?;
741
742 let mut updated = 0usize;
743 let mut offset = 0usize;
744 const PAGE_SIZE: usize = 256;
745 let mut scanned_chunks = 0usize;
746 let mut encrypted_chunks = 0usize;
747 let mut target_key_matches = 0usize;
748 let mut stale_key_ids = HashSet::new();
749
750 loop {
751 let page = db.list_chunks_page(offset, PAGE_SIZE, Some(tenant_id))?;
752 if page.is_empty() {
753 break;
754 }
755
756 for chunk in &page {
757 scanned_chunks = scanned_chunks.saturating_add(1);
758 let mut metadata = chunk.metadata.clone();
759 let Some(encrypted_value) = metadata.get(metadata_field).and_then(Value::as_str) else {
760 continue;
761 };
762 let Some((old_key_id, cipher_hex)) = parse_encrypted_value(encrypted_value) else {
763 continue;
764 };
765 encrypted_chunks = encrypted_chunks.saturating_add(1);
766 if old_key_id == new_key_id {
767 target_key_matches = target_key_matches.saturating_add(1);
768 continue;
769 }
770 stale_key_ids.insert(old_key_id.to_string());
771
772 let old_key = key_registry
773 .key_by_id(tenant_id, old_key_id)
774 .ok_or_else(|| {
775 SqlRiteError::UnsupportedOperation(format!(
776 "old key `{old_key_id}` not found for tenant `{tenant_id}`"
777 ))
778 })?;
779
780 let plaintext = decrypt_with_key(cipher_hex, &old_key.material)?;
781 let rotated = encrypt_with_key(&plaintext, tenant_id, &new_key);
782 if let Value::Object(ref mut map) = metadata {
783 map.insert(metadata_field.to_string(), Value::String(rotated));
784 map.insert(
785 "tenant_key_id".to_string(),
786 Value::String(new_key.key_id.clone()),
787 );
788 }
789
790 db.update_chunk_metadata(&chunk.id, &metadata)?;
791 updated += 1;
792 }
793
794 offset += page.len();
795 }
796
797 let stale_key_ids = stale_key_ids.into_iter().collect::<Vec<_>>();
798 Ok(KeyRotationReport {
799 tenant_id: tenant_id.to_string(),
800 metadata_field: metadata_field.to_string(),
801 target_key_id: new_key.key_id,
802 scanned_chunks,
803 encrypted_chunks,
804 rotated_chunks: updated,
805 target_key_matches: target_key_matches.saturating_add(updated),
806 stale_key_ids,
807 verified_all_target_key: encrypted_chunks == target_key_matches.saturating_add(updated),
808 })
809}
810
811pub fn inspect_tenant_key_rotation<R: TenantKeyRegistry>(
812 db: &SqlRite,
813 tenant_id: &str,
814 metadata_field: &str,
815 key_registry: &R,
816 target_key_id: &str,
817) -> Result<KeyRotationReport> {
818 let target_key = key_registry
819 .key_by_id(tenant_id, target_key_id)
820 .ok_or_else(|| SqlRiteError::UnsupportedOperation("target key not found".to_string()))?;
821
822 let mut offset = 0usize;
823 const PAGE_SIZE: usize = 256;
824 let mut scanned_chunks = 0usize;
825 let mut encrypted_chunks = 0usize;
826 let mut target_key_matches = 0usize;
827 let mut stale_key_ids = HashSet::new();
828
829 loop {
830 let page = db.list_chunks_page(offset, PAGE_SIZE, Some(tenant_id))?;
831 if page.is_empty() {
832 break;
833 }
834
835 for chunk in &page {
836 scanned_chunks = scanned_chunks.saturating_add(1);
837 let Some(encrypted_value) = chunk.metadata.get(metadata_field).and_then(Value::as_str)
838 else {
839 continue;
840 };
841 let Some((key_id, _)) = parse_encrypted_value(encrypted_value) else {
842 continue;
843 };
844 encrypted_chunks = encrypted_chunks.saturating_add(1);
845 if key_id == target_key.key_id {
846 target_key_matches = target_key_matches.saturating_add(1);
847 } else {
848 stale_key_ids.insert(key_id.to_string());
849 }
850 }
851
852 offset += page.len();
853 }
854
855 Ok(KeyRotationReport {
856 tenant_id: tenant_id.to_string(),
857 metadata_field: metadata_field.to_string(),
858 target_key_id: target_key.key_id,
859 scanned_chunks,
860 encrypted_chunks,
861 rotated_chunks: 0,
862 target_key_matches,
863 stale_key_ids: stale_key_ids.into_iter().collect(),
864 verified_all_target_key: encrypted_chunks == target_key_matches,
865 })
866}
867
868pub fn read_audit_events(path: impl AsRef<Path>) -> Result<Vec<AuditEvent>> {
869 let path = path.as_ref();
870 if !path.exists() {
871 return Ok(Vec::new());
872 }
873 let payload = fs::read_to_string(path)?;
874 payload
875 .lines()
876 .filter(|line| !line.trim().is_empty())
877 .map(|line| {
878 serde_json::from_str::<AuditEvent>(line)
879 .map_err(|e| SqlRiteError::UnsupportedOperation(e.to_string()))
880 })
881 .collect()
882}
883
884pub fn export_audit_events(
885 source_path: impl AsRef<Path>,
886 query: &AuditQuery,
887 output_path: Option<&Path>,
888 format: AuditExportFormat,
889) -> Result<AuditExportReport> {
890 let source_path = source_path.as_ref().to_path_buf();
891 let mut events = read_audit_events(&source_path)?
892 .into_iter()
893 .filter(|event| query.matches(event))
894 .collect::<Vec<_>>();
895
896 if let Some(limit) = query.limit {
897 events.truncate(limit);
898 }
899
900 if let Some(path) = output_path {
901 if let Some(parent) = path.parent()
902 && !parent.as_os_str().is_empty()
903 {
904 fs::create_dir_all(parent)?;
905 }
906 let payload = match format {
907 AuditExportFormat::Json => serde_json::to_string_pretty(&events)?,
908 AuditExportFormat::Jsonl => events
909 .iter()
910 .map(serde_json::to_string)
911 .collect::<std::result::Result<Vec<_>, _>>()?
912 .join("\n"),
913 };
914 fs::write(
915 path,
916 if payload.is_empty() {
917 payload
918 } else {
919 format!("{payload}\n")
920 },
921 )?;
922 }
923
924 Ok(AuditExportReport {
925 source_path,
926 output_path: output_path.map(Path::to_path_buf),
927 matched_events: events.len(),
928 exported_events: events.len(),
929 format: match format {
930 AuditExportFormat::Json => "json".to_string(),
931 AuditExportFormat::Jsonl => "jsonl".to_string(),
932 },
933 filters: query.clone(),
934 })
935}
936
937fn merge_tenant_metadata(metadata: &Value, tenant_id: &str) -> Value {
938 match metadata {
939 Value::Object(map) => {
940 let mut merged = map.clone();
941 merged.insert("tenant".to_string(), Value::String(tenant_id.to_string()));
942 Value::Object(merged)
943 }
944 _ => {
945 let mut merged = serde_json::Map::new();
946 merged.insert("tenant".to_string(), Value::String(tenant_id.to_string()));
947 Value::Object(merged)
948 }
949 }
950}
951
952fn encrypt_metadata_fields(
953 metadata: &mut Value,
954 key: &TenantKey,
955 tenant_id: &str,
956 sensitive_metadata_fields: &[&str],
957) -> Result<()> {
958 let Some(map) = metadata.as_object_mut() else {
959 return Err(SqlRiteError::UnsupportedOperation(
960 "metadata must be a json object for encryption".to_string(),
961 ));
962 };
963
964 for field in sensitive_metadata_fields {
965 let Some(raw) = map.get(*field).and_then(Value::as_str) else {
966 continue;
967 };
968 map.insert(
969 (*field).to_string(),
970 Value::String(encrypt_with_key(raw, tenant_id, key)),
971 );
972 }
973 map.insert(
974 "tenant_key_id".to_string(),
975 Value::String(key.key_id.clone()),
976 );
977 Ok(())
978}
979
980fn encrypt_with_key(plaintext: &str, tenant_id: &str, key: &TenantKey) -> String {
981 let mut scoped = Vec::new();
982 scoped.extend_from_slice(tenant_id.as_bytes());
983 scoped.push(0);
984 scoped.extend_from_slice(plaintext.as_bytes());
985 let cipher = xor_with_key(&scoped, &key.material);
986 format!("enc:v1:{}:{}", key.key_id, hex_encode(&cipher))
987}
988
989fn decrypt_with_key(cipher_hex: &str, key_material: &[u8]) -> Result<String> {
990 let cipher = hex_decode(cipher_hex)?;
991 let plain_scoped = xor_with_key(&cipher, key_material);
992 let Some(separator_idx) = plain_scoped.iter().position(|byte| *byte == 0) else {
993 return Err(SqlRiteError::UnsupportedOperation(
994 "invalid encrypted payload format".to_string(),
995 ));
996 };
997 let plaintext = &plain_scoped[(separator_idx + 1)..];
998 String::from_utf8(plaintext.to_vec()).map_err(|_| {
999 SqlRiteError::UnsupportedOperation("invalid utf8 in decrypted payload".to_string())
1000 })
1001}
1002
/// XORs `input` with `key` repeated cyclically.
///
/// Panics if `key` is empty and `input` is not, because the key index
/// reduces modulo `key.len()`.
fn xor_with_key(input: &[u8], key: &[u8]) -> Vec<u8> {
    let period = key.len();
    let mut out = Vec::with_capacity(input.len());
    for (position, &value) in input.iter().enumerate() {
        out.push(value ^ key[position % period]);
    }
    out
}
1010
/// Splits an `enc:v1:<key_id>:<payload>` string into `(key_id, payload)`.
///
/// `payload` is everything after the key id's colon and may itself contain
/// `:`. Returns `None` for any other prefix or when `key_id` or `payload` is
/// empty.
fn parse_encrypted_value(value: &str) -> Option<(&str, &str)> {
    let rest = value.strip_prefix("enc:v1:")?;
    let (key_id, payload) = rest.split_once(':')?;
    let well_formed = !key_id.is_empty() && !payload.is_empty();
    well_formed.then_some((key_id, payload))
}
1023
/// Encodes `bytes` as lowercase hexadecimal, two digits per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // Look digits up in a table; `format!` here used to allocate a
        // temporary String for every byte.
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
1031
1032fn hex_decode(value: &str) -> Result<Vec<u8>> {
1033 if !value.len().is_multiple_of(2) {
1034 return Err(SqlRiteError::UnsupportedOperation(
1035 "invalid hex payload length".to_string(),
1036 ));
1037 }
1038 let mut out = Vec::with_capacity(value.len() / 2);
1039 for idx in (0..value.len()).step_by(2) {
1040 let byte = u8::from_str_radix(&value[idx..idx + 2], 16)
1041 .map_err(|_| SqlRiteError::UnsupportedOperation("invalid hex payload".to_string()))?;
1042 out.push(byte);
1043 }
1044 Ok(out)
1045}
1046
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock is set before the epoch. The `u128`
/// millisecond count saturates at `u64::MAX` instead of being silently
/// truncated by an `as` cast.
fn now_unix_ms() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|elapsed| u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX))
        .unwrap_or(0)
}
1053
1054#[cfg(test)]
1055mod tests {
1056 use super::*;
1057 use crate::{ChunkInput, RuntimeConfig, SearchRequest, SqlRite};
1058 use serde_json::json;
1059 use tempfile::tempdir;
1060
1061 #[test]
1062 fn secure_wrapper_enforces_tenant_filter() -> Result<()> {
1063 let db = SqlRite::open_in_memory_with_config(RuntimeConfig::default())?;
1064 let tmp = tempdir()?;
1065 let logger = JsonlAuditLogger::new(tmp.path().join("audit.jsonl"), Vec::<String>::new())?;
1066 let secure = SecureSqlRite::from_db(db, AllowAllPolicy, logger);
1067
1068 let ctx_acme = AccessContext::new("user-1", "acme");
1069 secure.ingest_chunks(
1070 &ctx_acme,
1071 &[ChunkInput {
1072 id: "c1".to_string(),
1073 doc_id: "d1".to_string(),
1074 content: "tenant scoped".to_string(),
1075 embedding: vec![1.0, 0.0],
1076 metadata: json!({}),
1077 source: None,
1078 }],
1079 )?;
1080
1081 let ctx_beta = AccessContext::new("user-2", "beta");
1082 let beta_results = secure.search(
1083 &ctx_beta,
1084 SearchRequest {
1085 query_text: Some("tenant".to_string()),
1086 top_k: 5,
1087 ..Default::default()
1088 },
1089 )?;
1090 assert!(beta_results.is_empty());
1091
1092 let acme_results = secure.search(
1093 &ctx_acme,
1094 SearchRequest {
1095 query_text: Some("tenant".to_string()),
1096 top_k: 5,
1097 ..Default::default()
1098 },
1099 )?;
1100 assert_eq!(acme_results.len(), 1);
1101 Ok(())
1102 }
1103
1104 #[test]
1105 fn non_admin_cannot_delete_other_tenant() -> Result<()> {
1106 let db = SqlRite::open_in_memory_with_config(RuntimeConfig::default())?;
1107 let tmp = tempdir()?;
1108 let logger = JsonlAuditLogger::new(tmp.path().join("audit.jsonl"), Vec::<String>::new())?;
1109 let secure = SecureSqlRite::from_db(db, AllowAllPolicy, logger);
1110
1111 let err = secure
1112 .delete_tenant_data(&AccessContext::new("u1", "acme"), "beta")
1113 .expect_err("cross tenant delete should fail");
1114 assert!(matches!(err, SqlRiteError::AuthorizationDenied(_)));
1115 Ok(())
1116 }
1117
1118 #[test]
1119 fn encrypted_ingest_and_key_rotation_workflow() -> Result<()> {
1120 let db = SqlRite::open_in_memory_with_config(RuntimeConfig::default())?;
1121 let tmp = tempdir()?;
1122 let logger = JsonlAuditLogger::new(tmp.path().join("audit.jsonl"), Vec::<String>::new())?;
1123 let secure = SecureSqlRite::from_db(db, AllowAllPolicy, logger);
1124
1125 let key_registry = InMemoryTenantKeyRegistry::new();
1126 key_registry.set_active_key("acme", TenantKey::new("k1", b"secret-key-00001")?)?;
1127 key_registry.set_active_key("acme", TenantKey::new("k2", b"secret-key-00002")?)?;
1128
1129 let ctx = AccessContext::new("user-enc", "acme");
1130 secure.ingest_chunks_with_encryption(
1131 &ctx,
1132 &[ChunkInput {
1133 id: "c-sec".to_string(),
1134 doc_id: "d-sec".to_string(),
1135 content: "sensitive chunk".to_string(),
1136 embedding: vec![1.0, 0.0],
1137 metadata: json!({"secret_payload": "highly-sensitive"}),
1138 source: None,
1139 }],
1140 &key_registry,
1141 &["secret_payload"],
1142 )?;
1143
1144 let before = secure
1145 .db()
1146 .list_chunks_page(0, 10, Some("acme"))?
1147 .into_iter()
1148 .next()
1149 .expect("chunk exists");
1150 let before_payload = before
1151 .metadata
1152 .get("secret_payload")
1153 .and_then(Value::as_str)
1154 .unwrap_or_default()
1155 .to_string();
1156 assert!(before_payload.starts_with("enc:v1:"));
1157
1158 let rotated = rotate_tenant_encryption_key(
1159 secure.db(),
1160 "acme",
1161 "secret_payload",
1162 &key_registry,
1163 "k1",
1164 )?;
1165 assert_eq!(rotated, 1);
1166
1167 let after = secure
1168 .db()
1169 .list_chunks_page(0, 10, Some("acme"))?
1170 .into_iter()
1171 .next()
1172 .expect("chunk exists");
1173 let after_payload = after
1174 .metadata
1175 .get("secret_payload")
1176 .and_then(Value::as_str)
1177 .unwrap_or_default()
1178 .to_string();
1179 assert!(after_payload.starts_with("enc:v1:k1:"));
1180 Ok(())
1181 }
1182
1183 #[test]
1184 fn key_registry_persists_to_disk() -> Result<()> {
1185 let tmp = tempdir()?;
1186 let path = tmp.path().join("tenant_keys.json");
1187 let registry = InMemoryTenantKeyRegistry::new();
1188 registry.set_active_key("acme", TenantKey::new("k1", b"material-0000001")?)?;
1189 registry.set_key("acme", TenantKey::new("k2", b"material-0000002")?, false)?;
1190 registry.save_to_json_file(&path)?;
1191
1192 let restored = InMemoryTenantKeyRegistry::load_from_json_file(&path)?;
1193 assert!(restored.active_key("acme").is_some());
1194 assert!(restored.key_by_id("acme", "k2").is_some());
1195 Ok(())
1196 }
1197
1198 #[test]
1199 fn rbac_policy_enforces_role_permissions_and_cross_tenant_rules() -> Result<()> {
1200 let policy = RbacPolicy::default();
1201
1202 let reader = AccessContext::new("reader-1", "acme").with_roles(vec!["reader".to_string()]);
1203 policy.authorize(&reader, AccessOperation::Query, "acme")?;
1204
1205 let ingest_err = policy
1206 .authorize(&reader, AccessOperation::Ingest, "acme")
1207 .expect_err("reader ingest should be denied");
1208 assert!(matches!(ingest_err, SqlRiteError::AuthorizationDenied(_)));
1209
1210 let cross_tenant_err = policy
1211 .authorize(&reader, AccessOperation::Query, "beta")
1212 .expect_err("reader cross-tenant query should be denied");
1213 assert!(matches!(
1214 cross_tenant_err,
1215 SqlRiteError::AuthorizationDenied(_)
1216 ));
1217
1218 let admin = AccessContext::new("admin-1", "acme").with_roles(vec!["admin".to_string()]);
1219 policy.authorize(&admin, AccessOperation::SqlAdmin, "beta")?;
1220 Ok(())
1221 }
1222
1223 #[test]
1224 fn rbac_policy_round_trips_to_disk() -> Result<()> {
1225 let tmp = tempdir()?;
1226 let path = tmp.path().join("rbac-policy.json");
1227 let policy = RbacPolicy::default();
1228 policy.save_to_json_file(&path)?;
1229
1230 let restored = RbacPolicy::load_from_json_file(&path)?;
1231 assert_eq!(restored.role_names(), policy.role_names());
1232 Ok(())
1233 }
1234
1235 #[test]
1236 fn tenant_key_requires_minimum_length() {
1237 let error = TenantKey::new("k-short", b"too-short")
1238 .expect_err("short key material should be rejected");
1239 assert!(matches!(error, SqlRiteError::UnsupportedOperation(_)));
1240 }
1241
1242 #[test]
1243 fn audit_export_filters_and_writes_jsonl() -> Result<()> {
1244 let tmp = tempdir()?;
1245 let path = tmp.path().join("audit.jsonl");
1246 let logger = JsonlAuditLogger::new(&path, Vec::<String>::new())?;
1247 logger.log(&AuditEvent {
1248 unix_ms: 10,
1249 actor_id: "reader-1".to_string(),
1250 tenant_id: "acme".to_string(),
1251 operation: AccessOperation::Query,
1252 allowed: true,
1253 detail: json!({"path":"/v1/query"}),
1254 })?;
1255 logger.log(&AuditEvent {
1256 unix_ms: 20,
1257 actor_id: "admin-1".to_string(),
1258 tenant_id: "acme".to_string(),
1259 operation: AccessOperation::SqlAdmin,
1260 allowed: false,
1261 detail: json!({"path":"/v1/sql"}),
1262 })?;
1263
1264 let output = tmp.path().join("export.jsonl");
1265 let report = export_audit_events(
1266 &path,
1267 &AuditQuery {
1268 actor_id: Some("reader-1".to_string()),
1269 ..AuditQuery::default()
1270 },
1271 Some(&output),
1272 AuditExportFormat::Jsonl,
1273 )?;
1274 assert_eq!(report.matched_events, 1);
1275 let exported = fs::read_to_string(output)?;
1276 assert!(exported.contains("reader-1"));
1277 assert!(!exported.contains("admin-1"));
1278 Ok(())
1279 }
1280
1281 #[test]
1282 fn inspect_rotation_reports_stale_keys() -> Result<()> {
1283 let db = SqlRite::open_in_memory_with_config(RuntimeConfig::default())?;
1284 let tmp = tempdir()?;
1285 let logger = JsonlAuditLogger::new(tmp.path().join("audit.jsonl"), Vec::<String>::new())?;
1286 let secure = SecureSqlRite::from_db(db, AllowAllPolicy, logger);
1287
1288 let key_registry = InMemoryTenantKeyRegistry::new();
1289 key_registry.set_active_key("acme", TenantKey::new("k1", b"secret-key-00001")?)?;
1290 key_registry.set_key("acme", TenantKey::new("k2", b"secret-key-00002")?, false)?;
1291
1292 let ctx = AccessContext::new("user-enc", "acme");
1293 secure.ingest_chunks_with_encryption(
1294 &ctx,
1295 &[ChunkInput {
1296 id: "c-sec".to_string(),
1297 doc_id: "d-sec".to_string(),
1298 content: "sensitive chunk".to_string(),
1299 embedding: vec![1.0, 0.0],
1300 metadata: json!({"secret_payload": "highly-sensitive"}),
1301 source: None,
1302 }],
1303 &key_registry,
1304 &["secret_payload"],
1305 )?;
1306
1307 let report = inspect_tenant_key_rotation(
1308 secure.db(),
1309 "acme",
1310 "secret_payload",
1311 &key_registry,
1312 "k2",
1313 )?;
1314 assert_eq!(report.encrypted_chunks, 1);
1315 assert_eq!(report.target_key_matches, 0);
1316 assert_eq!(report.stale_key_ids, vec!["k1".to_string()]);
1317 assert!(!report.verified_all_target_key);
1318 Ok(())
1319 }
1320}