Skip to main content

sqlrite/
security.rs

1use crate::{
2    ChunkInput, Result, RuntimeConfig, SearchRequest, SearchResult, SqlRite, SqlRiteError,
3};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::{HashMap, HashSet};
7use std::fs::{self, OpenOptions};
8use std::io::Write;
9use std::path::{Path, PathBuf};
10use std::sync::Mutex;
11
/// Minimum length, in bytes, of key material accepted by `TenantKey::new`.
const MINIMUM_TENANT_KEY_BYTES: usize = 16;
13
/// Operations an [`AccessPolicy`] can authorize; also recorded on every [`AuditEvent`].
///
/// Serialized in `snake_case` (`ingest`, `query`, `sql_admin`, `delete_tenant`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AccessOperation {
    /// Writing chunks (`SecureSqlRite::ingest_chunks*`).
    Ingest,
    /// Searching chunks (`SecureSqlRite::search`).
    Query,
    /// Administrative SQL access; granted by the default RBAC config but not checked by any
    /// `SecureSqlRite` method in this file.
    SqlAdmin,
    /// Deleting every chunk tagged with a tenant (`SecureSqlRite::delete_tenant_data`).
    DeleteTenant,
}
22
/// Identity of a caller: who is acting, on behalf of which tenant, and with which roles.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccessContext {
    /// Identifier of the acting user or service; copied into audit events.
    pub actor_id: String,
    /// Tenant the actor belongs to. `SecureSqlRite` scopes ingest and search to this tenant.
    pub tenant_id: String,
    /// Role names looked up in the policy (e.g. `reader`, `writer`, `admin`).
    pub roles: Vec<String>,
}
29
30impl AccessContext {
31    pub fn new(actor_id: impl Into<String>, tenant_id: impl Into<String>) -> Self {
32        Self {
33            actor_id: actor_id.into(),
34            tenant_id: tenant_id.into(),
35            roles: Vec::new(),
36        }
37    }
38
39    pub fn with_roles(mut self, roles: Vec<String>) -> Self {
40        self.roles = roles;
41        self
42    }
43
44    fn is_admin(&self) -> bool {
45        self.roles.iter().any(|role| role == "admin")
46    }
47}
48
/// Decides whether an actor may perform an operation against a tenant.
///
/// Return `Ok(())` to allow the request or an error to reject it. The implementations in this
/// file use `SqlRiteError::InvalidTenantId` and `SqlRiteError::AuthorizationDenied`. The
/// `Send + Sync` bound lets one policy be shared across threads.
pub trait AccessPolicy: Send + Sync {
    /// Checks whether `context` may perform `operation` on `target_tenant`.
    fn authorize(
        &self,
        context: &AccessContext,
        operation: AccessOperation,
        target_tenant: &str,
    ) -> Result<()>;
}
57
/// Minimal policy that permits every operation inside the caller's own tenant.
///
/// Despite the name it is not unconditional: blank tenant ids are rejected, and reaching
/// another tenant requires the `admin` role.
#[derive(Debug, Default, Clone)]
pub struct AllowAllPolicy;
60
61impl AccessPolicy for AllowAllPolicy {
62    fn authorize(
63        &self,
64        context: &AccessContext,
65        _operation: AccessOperation,
66        target_tenant: &str,
67    ) -> Result<()> {
68        if context.tenant_id.trim().is_empty() || target_tenant.trim().is_empty() {
69            return Err(SqlRiteError::InvalidTenantId);
70        }
71        if context.tenant_id != target_tenant && !context.is_admin() {
72            return Err(SqlRiteError::AuthorizationDenied(format!(
73                "tenant `{}` cannot access tenant `{}`",
74                context.tenant_id, target_tenant
75            )));
76        }
77        Ok(())
78    }
79}
80
/// Serializable form of an [`RbacPolicy`].
///
/// Records which operations each role grants and which roles may reach tenants other than
/// their own.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RbacPolicyConfig {
    /// Role name -> operations granted to anyone holding that role.
    pub role_permissions: HashMap<String, Vec<AccessOperation>>,
    /// Roles whose holders may target a tenant other than `AccessContext::tenant_id`.
    pub cross_tenant_roles: Vec<String>,
}
86
87impl Default for RbacPolicyConfig {
88    fn default() -> Self {
89        Self {
90            role_permissions: HashMap::from([
91                ("reader".to_string(), vec![AccessOperation::Query]),
92                (
93                    "writer".to_string(),
94                    vec![AccessOperation::Query, AccessOperation::Ingest],
95                ),
96                (
97                    "tenant_admin".to_string(),
98                    vec![
99                        AccessOperation::Query,
100                        AccessOperation::Ingest,
101                        AccessOperation::DeleteTenant,
102                    ],
103                ),
104                (
105                    "admin".to_string(),
106                    vec![
107                        AccessOperation::Query,
108                        AccessOperation::Ingest,
109                        AccessOperation::SqlAdmin,
110                        AccessOperation::DeleteTenant,
111                    ],
112                ),
113            ]),
114            cross_tenant_roles: vec!["admin".to_string()],
115        }
116    }
117}
118
/// Role-based access policy.
///
/// An actor may perform an operation if any of its roles grants it. It may target a foreign
/// tenant only if it also holds a cross-tenant role.
///
/// Permissions are stored as sets for lookup. Convert with [`RbacPolicy::from_config`] and
/// [`RbacPolicy::to_config`].
#[derive(Debug, Clone)]
pub struct RbacPolicy {
    role_permissions: HashMap<String, HashSet<AccessOperation>>,
    cross_tenant_roles: HashSet<String>,
}
124
125impl Default for RbacPolicy {
126    fn default() -> Self {
127        Self::from_config(RbacPolicyConfig::default())
128    }
129}
130
131impl RbacPolicy {
132    pub fn from_config(config: RbacPolicyConfig) -> Self {
133        let role_permissions = config
134            .role_permissions
135            .into_iter()
136            .map(|(role, operations)| (role, operations.into_iter().collect()))
137            .collect();
138        let cross_tenant_roles = config.cross_tenant_roles.into_iter().collect();
139        Self {
140            role_permissions,
141            cross_tenant_roles,
142        }
143    }
144
145    pub fn to_config(&self) -> RbacPolicyConfig {
146        let mut role_permissions = HashMap::new();
147        for (role, operations) in &self.role_permissions {
148            let mut values = operations.iter().copied().collect::<Vec<_>>();
149            values.sort_by_key(|operation| match operation {
150                AccessOperation::Query => 0,
151                AccessOperation::Ingest => 1,
152                AccessOperation::SqlAdmin => 2,
153                AccessOperation::DeleteTenant => 3,
154            });
155            role_permissions.insert(role.clone(), values);
156        }
157
158        let mut cross_tenant_roles = self.cross_tenant_roles.iter().cloned().collect::<Vec<_>>();
159        cross_tenant_roles.sort();
160        RbacPolicyConfig {
161            role_permissions,
162            cross_tenant_roles,
163        }
164    }
165
166    pub fn load_from_json_file(path: impl AsRef<Path>) -> Result<Self> {
167        let path = path.as_ref();
168        if !path.exists() {
169            return Ok(Self::default());
170        }
171        let payload = fs::read_to_string(path)?;
172        let config = serde_json::from_str::<RbacPolicyConfig>(&payload)
173            .map_err(|e| SqlRiteError::UnsupportedOperation(e.to_string()))?;
174        Ok(Self::from_config(config))
175    }
176
177    pub fn save_to_json_file(&self, path: impl AsRef<Path>) -> Result<()> {
178        let path = path.as_ref();
179        if let Some(parent) = path.parent()
180            && !parent.as_os_str().is_empty()
181        {
182            fs::create_dir_all(parent)?;
183        }
184
185        let payload = serde_json::to_string_pretty(&self.to_config())?;
186        let temp = path.with_extension("tmp");
187        fs::write(&temp, payload)?;
188        fs::rename(temp, path)?;
189        Ok(())
190    }
191
192    pub fn role_names(&self) -> Vec<String> {
193        let mut roles = self.role_permissions.keys().cloned().collect::<Vec<_>>();
194        roles.sort();
195        roles
196    }
197}
198
199impl AccessPolicy for RbacPolicy {
200    fn authorize(
201        &self,
202        context: &AccessContext,
203        operation: AccessOperation,
204        target_tenant: &str,
205    ) -> Result<()> {
206        if context.tenant_id.trim().is_empty() || target_tenant.trim().is_empty() {
207            return Err(SqlRiteError::InvalidTenantId);
208        }
209
210        let allowed = context.roles.iter().any(|role| {
211            self.role_permissions
212                .get(role)
213                .is_some_and(|operations| operations.contains(&operation))
214        });
215        if !allowed {
216            return Err(SqlRiteError::AuthorizationDenied(format!(
217                "actor `{}` lacks role permission for {:?}",
218                context.actor_id, operation
219            )));
220        }
221
222        if context.tenant_id != target_tenant
223            && !context
224                .roles
225                .iter()
226                .any(|role| self.cross_tenant_roles.contains(role))
227        {
228            return Err(SqlRiteError::AuthorizationDenied(format!(
229                "tenant `{}` cannot access tenant `{}`",
230                context.tenant_id, target_tenant
231            )));
232        }
233
234        Ok(())
235    }
236}
237
238#[derive(Debug, Clone, Serialize, Deserialize)]
239pub struct TenantKey {
240    pub key_id: String,
241    pub material: Vec<u8>,
242}
243
244impl TenantKey {
245    pub fn new(key_id: impl Into<String>, material: impl AsRef<[u8]>) -> Result<Self> {
246        let key_id = key_id.into();
247        let material = material.as_ref().to_vec();
248        if key_id.trim().is_empty() || material.len() < MINIMUM_TENANT_KEY_BYTES {
249            return Err(SqlRiteError::UnsupportedOperation(format!(
250                "tenant key_id/material are required and key material must be at least {MINIMUM_TENANT_KEY_BYTES} bytes"
251            )));
252        }
253        Ok(Self { key_id, material })
254    }
255}
256
/// Source of per-tenant encryption keys.
///
/// Lookups return owned copies. `None` means the tenant or key is not available.
pub trait TenantKeyRegistry: Send + Sync {
    /// Key that new ciphertexts for `tenant_id` should use, if one is active.
    fn active_key(&self, tenant_id: &str) -> Option<TenantKey>;
    /// Key `key_id` for `tenant_id`, active or not (needed to decrypt older ciphertexts).
    fn key_by_id(&self, tenant_id: &str, key_id: &str) -> Option<TenantKey>;
}
261
/// In-process [`TenantKeyRegistry`] that can be loaded from and saved to a JSON file.
///
/// All tenant state sits behind a single `Mutex`.
#[derive(Debug, Default)]
pub struct InMemoryTenantKeyRegistry {
    /// tenant id -> that tenant's keys and active key id.
    keys: Mutex<HashMap<String, TenantKeyState>>,
}
266
/// Keys known for a single tenant.
#[derive(Debug, Default, Clone)]
struct TenantKeyState {
    /// Id of the key used for new encryptions; `None` until one is made active.
    active_key_id: Option<String>,
    /// key id -> key. Inactive keys are kept so older ciphertexts can still be decrypted.
    keys: HashMap<String, TenantKey>,
}
272
273impl InMemoryTenantKeyRegistry {
274    pub fn new() -> Self {
275        Self::default()
276    }
277
278    pub fn set_active_key(&self, tenant_id: &str, key: TenantKey) -> Result<()> {
279        let mut guard = self.keys.lock().map_err(|_| {
280            SqlRiteError::UnsupportedOperation("tenant key registry mutex poisoned".to_string())
281        })?;
282        let state = guard.entry(tenant_id.to_string()).or_default();
283        state.active_key_id = Some(key.key_id.clone());
284        state.keys.insert(key.key_id.clone(), key);
285        Ok(())
286    }
287
288    pub fn set_key(&self, tenant_id: &str, key: TenantKey, make_active: bool) -> Result<()> {
289        let mut guard = self.keys.lock().map_err(|_| {
290            SqlRiteError::UnsupportedOperation("tenant key registry mutex poisoned".to_string())
291        })?;
292        let state = guard.entry(tenant_id.to_string()).or_default();
293        if make_active {
294            state.active_key_id = Some(key.key_id.clone());
295        }
296        state.keys.insert(key.key_id.clone(), key);
297        Ok(())
298    }
299
300    pub fn load_from_json_file(path: impl AsRef<Path>) -> Result<Self> {
301        let path = path.as_ref();
302        if !path.exists() {
303            return Ok(Self::new());
304        }
305        let payload = fs::read_to_string(path)?;
306        let serializable = serde_json::from_str::<SerializableTenantKeyRegistry>(&payload)
307            .map_err(|e| SqlRiteError::UnsupportedOperation(e.to_string()))?;
308
309        let mut tenants = HashMap::new();
310        for (tenant_id, state) in serializable.tenants {
311            let mut keys = HashMap::new();
312            for key in state.keys {
313                keys.insert(key.key_id.clone(), key);
314            }
315            tenants.insert(
316                tenant_id,
317                TenantKeyState {
318                    active_key_id: state.active_key_id,
319                    keys,
320                },
321            );
322        }
323        Ok(Self {
324            keys: Mutex::new(tenants),
325        })
326    }
327
328    pub fn save_to_json_file(&self, path: impl AsRef<Path>) -> Result<()> {
329        let path = path.as_ref();
330        if let Some(parent) = path.parent()
331            && !parent.as_os_str().is_empty()
332        {
333            fs::create_dir_all(parent)?;
334        }
335
336        let guard = self.keys.lock().map_err(|_| {
337            SqlRiteError::UnsupportedOperation("tenant key registry mutex poisoned".to_string())
338        })?;
339        let tenants = guard
340            .iter()
341            .map(|(tenant_id, state)| {
342                let keys = state.keys.values().cloned().collect::<Vec<_>>();
343                (
344                    tenant_id.clone(),
345                    SerializableTenantKeyState {
346                        active_key_id: state.active_key_id.clone(),
347                        keys,
348                    },
349                )
350            })
351            .collect::<HashMap<_, _>>();
352
353        let payload = serde_json::to_string_pretty(&SerializableTenantKeyRegistry { tenants })?;
354        let temp = path.with_extension("tmp");
355        fs::write(&temp, payload)?;
356        fs::rename(temp, path)?;
357        Ok(())
358    }
359}
360
/// JSON form of one tenant's key state.
///
/// Keys are stored as a list (each carries its own `key_id`) and include raw key material.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SerializableTenantKeyState {
    active_key_id: Option<String>,
    keys: Vec<TenantKey>,
}
366
/// JSON form of an `InMemoryTenantKeyRegistry`: tenant id -> key state.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SerializableTenantKeyRegistry {
    tenants: HashMap<String, SerializableTenantKeyState>,
}
371
372impl TenantKeyRegistry for InMemoryTenantKeyRegistry {
373    fn active_key(&self, tenant_id: &str) -> Option<TenantKey> {
374        let guard = self.keys.lock().ok()?;
375        let state = guard.get(tenant_id)?;
376        let active = state.active_key_id.as_ref()?;
377        state.keys.get(active).cloned()
378    }
379
380    fn key_by_id(&self, tenant_id: &str, key_id: &str) -> Option<TenantKey> {
381        let guard = self.keys.lock().ok()?;
382        guard.get(tenant_id)?.keys.get(key_id).cloned()
383    }
384}
385
/// One audit record: who did what, for which tenant, when, and the outcome flag.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditEvent {
    /// Timestamp in Unix milliseconds, as produced by `now_unix_ms`.
    pub unix_ms: u64,
    /// Actor from the request's `AccessContext`.
    pub actor_id: String,
    /// Tenant of the acting context.
    ///
    /// `SecureSqlRite` records a deletion's target tenant in `detail`, not here.
    pub tenant_id: String,
    /// Operation that was attempted.
    pub operation: AccessOperation,
    /// Outcome flag; `false` for rejected or failed requests.
    pub allowed: bool,
    /// Operation-specific JSON details; loggers may redact keys before persisting.
    pub detail: Value,
}
395
/// Sink for [`AuditEvent`]s. The `Send + Sync` bound lets one logger be shared across threads.
pub trait AuditLogger: Send + Sync {
    /// Persists `event`, returning an error if it could not be recorded.
    fn log(&self, event: &AuditEvent) -> Result<()>;
}
399
/// Filter used when reading or exporting audit events. A `None` field matches every event.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AuditQuery {
    /// Exact match on `AuditEvent::actor_id`.
    pub actor_id: Option<String>,
    /// Exact match on `AuditEvent::tenant_id`.
    pub tenant_id: Option<String>,
    /// Exact match on `AuditEvent::operation`.
    pub operation: Option<AccessOperation>,
    /// Exact match on `AuditEvent::allowed`.
    pub allowed: Option<bool>,
    /// Inclusive lower bound on `AuditEvent::unix_ms`.
    pub from_unix_ms: Option<u64>,
    /// Inclusive upper bound on `AuditEvent::unix_ms`.
    pub to_unix_ms: Option<u64>,
    /// Maximum number of matching events to export; the first ones in file order are kept.
    pub limit: Option<usize>,
}
410
/// Output layout for `export_audit_events`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuditExportFormat {
    /// A single pretty-printed JSON array.
    Json,
    /// One compact JSON object per line.
    Jsonl,
}
416
/// Summary returned by `export_audit_events`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditExportReport {
    /// Audit log that was read.
    pub source_path: PathBuf,
    /// File that was written, or `None` when no output path was given (nothing is written).
    pub output_path: Option<PathBuf>,
    /// Number of events reported as matching `filters`.
    pub matched_events: usize,
    /// Number of events included in the export, after `filters.limit` is applied.
    pub exported_events: usize,
    /// `"json"` or `"jsonl"`.
    pub format: String,
    /// The query that was applied.
    pub filters: AuditQuery,
}
426
/// Outcome of rotating (or only inspecting) one encrypted metadata field for a tenant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyRotationReport {
    /// Tenant whose chunks were scanned.
    pub tenant_id: String,
    /// Metadata field holding the encrypted value.
    pub metadata_field: String,
    /// Key the field is (or should end up) encrypted with.
    pub target_key_id: String,
    /// Number of the tenant's chunks examined.
    pub scanned_chunks: usize,
    /// Number of scanned chunks whose field parsed as an encrypted value.
    pub encrypted_chunks: usize,
    /// Number of chunks re-encrypted; always 0 for `inspect_tenant_key_rotation`.
    pub rotated_chunks: usize,
    /// Number of encrypted chunks under the target key (after rotation, for a rotation report).
    pub target_key_matches: usize,
    /// Distinct non-target key ids that were encountered.
    pub stale_key_ids: Vec<String>,
    /// Whether every encrypted chunk counted is under the target key.
    pub verified_all_target_key: bool,
}
439
/// [`AuditLogger`] that appends one JSON object per line to a file.
///
/// Configured keys in each event's `detail` are masked before writing.
#[derive(Debug)]
pub struct JsonlAuditLogger {
    /// Log file; opened in append mode on every write.
    path: PathBuf,
    /// Keys in `detail` whose values are replaced with `"[REDACTED]"`.
    redacted_fields: HashSet<String>,
    /// Serializes writes through this logger instance. It is in-process only and does not
    /// coordinate with other processes.
    lock: Mutex<()>,
}
446
447impl JsonlAuditLogger {
448    pub fn new(
449        path: impl AsRef<Path>,
450        redacted_fields: impl IntoIterator<Item = String>,
451    ) -> Result<Self> {
452        let path = path.as_ref().to_path_buf();
453        if let Some(parent) = path.parent()
454            && !parent.as_os_str().is_empty()
455        {
456            fs::create_dir_all(parent)?;
457        }
458        Ok(Self {
459            path,
460            redacted_fields: redacted_fields.into_iter().collect(),
461            lock: Mutex::new(()),
462        })
463    }
464
465    fn redact(&self, detail: &Value) -> Value {
466        match detail {
467            Value::Object(map) => {
468                let mut copy = map.clone();
469                for key in &self.redacted_fields {
470                    if copy.contains_key(key) {
471                        copy.insert(key.clone(), Value::String("[REDACTED]".to_string()));
472                    }
473                }
474                Value::Object(copy)
475            }
476            _ => detail.clone(),
477        }
478    }
479}
480
481impl AuditQuery {
482    fn matches(&self, event: &AuditEvent) -> bool {
483        if let Some(actor_id) = &self.actor_id
484            && &event.actor_id != actor_id
485        {
486            return false;
487        }
488        if let Some(tenant_id) = &self.tenant_id
489            && &event.tenant_id != tenant_id
490        {
491            return false;
492        }
493        if let Some(operation) = self.operation
494            && event.operation != operation
495        {
496            return false;
497        }
498        if let Some(allowed) = self.allowed
499            && event.allowed != allowed
500        {
501            return false;
502        }
503        if let Some(from_unix_ms) = self.from_unix_ms
504            && event.unix_ms < from_unix_ms
505        {
506            return false;
507        }
508        if let Some(to_unix_ms) = self.to_unix_ms
509            && event.unix_ms > to_unix_ms
510        {
511            return false;
512        }
513        true
514    }
515}
516
517impl AuditLogger for JsonlAuditLogger {
518    fn log(&self, event: &AuditEvent) -> Result<()> {
519        let _guard = self.lock.lock().map_err(|_| {
520            SqlRiteError::UnsupportedOperation("audit logger mutex poisoned".to_string())
521        })?;
522        let mut file = OpenOptions::new()
523            .create(true)
524            .append(true)
525            .open(&self.path)?;
526
527        let serialized = serde_json::to_string(&AuditEvent {
528            detail: self.redact(&event.detail),
529            ..event.clone()
530        })?;
531        file.write_all(serialized.as_bytes())?;
532        file.write_all(b"\n")?;
533        Ok(())
534    }
535}
536
/// Wrapper around a [`SqlRite`] database for tenant-scoped access.
///
/// It authorizes its operations through an [`AccessPolicy`], confines them to the caller's
/// tenant, and writes audit events through an [`AuditLogger`]. `db()` and `into_inner()` give
/// unchecked access to the underlying database.
pub struct SecureSqlRite<P: AccessPolicy, A: AuditLogger> {
    db: SqlRite,
    policy: P,
    audit_logger: A,
}
542
543impl<P: AccessPolicy, A: AuditLogger> SecureSqlRite<P, A> {
544    pub fn open_with_config(
545        path: impl AsRef<Path>,
546        runtime: RuntimeConfig,
547        policy: P,
548        audit_logger: A,
549    ) -> Result<Self> {
550        Ok(Self {
551            db: SqlRite::open_with_config(path, runtime)?,
552            policy,
553            audit_logger,
554        })
555    }
556
557    pub fn from_db(db: SqlRite, policy: P, audit_logger: A) -> Self {
558        Self {
559            db,
560            policy,
561            audit_logger,
562        }
563    }
564
565    pub fn ingest_chunks(&self, context: &AccessContext, chunks: &[ChunkInput]) -> Result<()> {
566        self.policy
567            .authorize(context, AccessOperation::Ingest, &context.tenant_id)?;
568
569        let mut enriched = Vec::with_capacity(chunks.len());
570        for chunk in chunks {
571            let metadata = merge_tenant_metadata(&chunk.metadata, &context.tenant_id);
572            enriched.push(ChunkInput {
573                id: chunk.id.clone(),
574                doc_id: chunk.doc_id.clone(),
575                content: chunk.content.clone(),
576                embedding: chunk.embedding.clone(),
577                metadata,
578                source: chunk.source.clone(),
579            });
580        }
581
582        let result = self.db.ingest_chunks(&enriched);
583        self.audit(
584            context,
585            AccessOperation::Ingest,
586            result.is_ok(),
587            serde_json::json!({
588                "chunk_count": chunks.len(),
589            }),
590        )?;
591        result
592    }
593
594    pub fn ingest_chunks_with_encryption<R: TenantKeyRegistry>(
595        &self,
596        context: &AccessContext,
597        chunks: &[ChunkInput],
598        key_registry: &R,
599        sensitive_metadata_fields: &[&str],
600    ) -> Result<()> {
601        self.policy
602            .authorize(context, AccessOperation::Ingest, &context.tenant_id)?;
603        let active_key = key_registry.active_key(&context.tenant_id).ok_or_else(|| {
604            SqlRiteError::UnsupportedOperation(format!(
605                "no active key configured for tenant `{}`",
606                context.tenant_id
607            ))
608        })?;
609
610        let mut encrypted_chunks = Vec::with_capacity(chunks.len());
611        for chunk in chunks {
612            let mut metadata = merge_tenant_metadata(&chunk.metadata, &context.tenant_id);
613            encrypt_metadata_fields(
614                &mut metadata,
615                &active_key,
616                &context.tenant_id,
617                sensitive_metadata_fields,
618            )?;
619
620            encrypted_chunks.push(ChunkInput {
621                id: chunk.id.clone(),
622                doc_id: chunk.doc_id.clone(),
623                content: chunk.content.clone(),
624                embedding: chunk.embedding.clone(),
625                metadata,
626                source: chunk.source.clone(),
627            });
628        }
629
630        self.db.ingest_chunks(&encrypted_chunks)
631    }
632
633    pub fn search(
634        &self,
635        context: &AccessContext,
636        mut request: SearchRequest,
637    ) -> Result<Vec<SearchResult>> {
638        self.policy
639            .authorize(context, AccessOperation::Query, &context.tenant_id)?;
640
641        if let Some(existing) = request.metadata_filters.get("tenant")
642            && existing != &context.tenant_id
643        {
644            self.audit(
645                context,
646                AccessOperation::Query,
647                false,
648                serde_json::json!({"reason": "tenant filter mismatch"}),
649            )?;
650            return Err(SqlRiteError::AuthorizationDenied(
651                "tenant filter mismatch".to_string(),
652            ));
653        }
654
655        request
656            .metadata_filters
657            .insert("tenant".to_string(), context.tenant_id.clone());
658        let result = self.db.search(request);
659
660        self.audit(
661            context,
662            AccessOperation::Query,
663            result.is_ok(),
664            serde_json::json!({
665                "result_count": result.as_ref().map(|items| items.len()).unwrap_or(0),
666            }),
667        )?;
668        result
669    }
670
671    pub fn delete_tenant_data(&self, context: &AccessContext, tenant_id: &str) -> Result<usize> {
672        self.policy
673            .authorize(context, AccessOperation::DeleteTenant, tenant_id)?;
674
675        let result = self.db.delete_chunks_by_metadata("tenant", tenant_id);
676        self.audit(
677            context,
678            AccessOperation::DeleteTenant,
679            result.is_ok(),
680            serde_json::json!({
681                "target_tenant": tenant_id,
682                "deleted": result.as_ref().copied().unwrap_or(0),
683            }),
684        )?;
685        result
686    }
687
688    pub fn db(&self) -> &SqlRite {
689        &self.db
690    }
691
692    pub fn into_inner(self) -> SqlRite {
693        self.db
694    }
695
696    fn audit(
697        &self,
698        context: &AccessContext,
699        operation: AccessOperation,
700        allowed: bool,
701        detail: Value,
702    ) -> Result<()> {
703        self.audit_logger.log(&AuditEvent {
704            unix_ms: now_unix_ms(),
705            actor_id: context.actor_id.clone(),
706            tenant_id: context.tenant_id.clone(),
707            operation,
708            allowed,
709            detail,
710        })
711    }
712}
713
714pub fn rotate_tenant_encryption_key<R: TenantKeyRegistry>(
715    db: &SqlRite,
716    tenant_id: &str,
717    metadata_field: &str,
718    key_registry: &R,
719    new_key_id: &str,
720) -> Result<usize> {
721    Ok(rotate_tenant_encryption_key_with_report(
722        db,
723        tenant_id,
724        metadata_field,
725        key_registry,
726        new_key_id,
727    )?
728    .rotated_chunks)
729}
730
731pub fn rotate_tenant_encryption_key_with_report<R: TenantKeyRegistry>(
732    db: &SqlRite,
733    tenant_id: &str,
734    metadata_field: &str,
735    key_registry: &R,
736    new_key_id: &str,
737) -> Result<KeyRotationReport> {
738    let new_key = key_registry
739        .key_by_id(tenant_id, new_key_id)
740        .ok_or_else(|| SqlRiteError::UnsupportedOperation("new key not found".to_string()))?;
741
742    let mut updated = 0usize;
743    let mut offset = 0usize;
744    const PAGE_SIZE: usize = 256;
745    let mut scanned_chunks = 0usize;
746    let mut encrypted_chunks = 0usize;
747    let mut target_key_matches = 0usize;
748    let mut stale_key_ids = HashSet::new();
749
750    loop {
751        let page = db.list_chunks_page(offset, PAGE_SIZE, Some(tenant_id))?;
752        if page.is_empty() {
753            break;
754        }
755
756        for chunk in &page {
757            scanned_chunks = scanned_chunks.saturating_add(1);
758            let mut metadata = chunk.metadata.clone();
759            let Some(encrypted_value) = metadata.get(metadata_field).and_then(Value::as_str) else {
760                continue;
761            };
762            let Some((old_key_id, cipher_hex)) = parse_encrypted_value(encrypted_value) else {
763                continue;
764            };
765            encrypted_chunks = encrypted_chunks.saturating_add(1);
766            if old_key_id == new_key_id {
767                target_key_matches = target_key_matches.saturating_add(1);
768                continue;
769            }
770            stale_key_ids.insert(old_key_id.to_string());
771
772            let old_key = key_registry
773                .key_by_id(tenant_id, old_key_id)
774                .ok_or_else(|| {
775                    SqlRiteError::UnsupportedOperation(format!(
776                        "old key `{old_key_id}` not found for tenant `{tenant_id}`"
777                    ))
778                })?;
779
780            let plaintext = decrypt_with_key(cipher_hex, &old_key.material)?;
781            let rotated = encrypt_with_key(&plaintext, tenant_id, &new_key);
782            if let Value::Object(ref mut map) = metadata {
783                map.insert(metadata_field.to_string(), Value::String(rotated));
784                map.insert(
785                    "tenant_key_id".to_string(),
786                    Value::String(new_key.key_id.clone()),
787                );
788            }
789
790            db.update_chunk_metadata(&chunk.id, &metadata)?;
791            updated += 1;
792        }
793
794        offset += page.len();
795    }
796
797    let stale_key_ids = stale_key_ids.into_iter().collect::<Vec<_>>();
798    Ok(KeyRotationReport {
799        tenant_id: tenant_id.to_string(),
800        metadata_field: metadata_field.to_string(),
801        target_key_id: new_key.key_id,
802        scanned_chunks,
803        encrypted_chunks,
804        rotated_chunks: updated,
805        target_key_matches: target_key_matches.saturating_add(updated),
806        stale_key_ids,
807        verified_all_target_key: encrypted_chunks == target_key_matches.saturating_add(updated),
808    })
809}
810
811pub fn inspect_tenant_key_rotation<R: TenantKeyRegistry>(
812    db: &SqlRite,
813    tenant_id: &str,
814    metadata_field: &str,
815    key_registry: &R,
816    target_key_id: &str,
817) -> Result<KeyRotationReport> {
818    let target_key = key_registry
819        .key_by_id(tenant_id, target_key_id)
820        .ok_or_else(|| SqlRiteError::UnsupportedOperation("target key not found".to_string()))?;
821
822    let mut offset = 0usize;
823    const PAGE_SIZE: usize = 256;
824    let mut scanned_chunks = 0usize;
825    let mut encrypted_chunks = 0usize;
826    let mut target_key_matches = 0usize;
827    let mut stale_key_ids = HashSet::new();
828
829    loop {
830        let page = db.list_chunks_page(offset, PAGE_SIZE, Some(tenant_id))?;
831        if page.is_empty() {
832            break;
833        }
834
835        for chunk in &page {
836            scanned_chunks = scanned_chunks.saturating_add(1);
837            let Some(encrypted_value) = chunk.metadata.get(metadata_field).and_then(Value::as_str)
838            else {
839                continue;
840            };
841            let Some((key_id, _)) = parse_encrypted_value(encrypted_value) else {
842                continue;
843            };
844            encrypted_chunks = encrypted_chunks.saturating_add(1);
845            if key_id == target_key.key_id {
846                target_key_matches = target_key_matches.saturating_add(1);
847            } else {
848                stale_key_ids.insert(key_id.to_string());
849            }
850        }
851
852        offset += page.len();
853    }
854
855    Ok(KeyRotationReport {
856        tenant_id: tenant_id.to_string(),
857        metadata_field: metadata_field.to_string(),
858        target_key_id: target_key.key_id,
859        scanned_chunks,
860        encrypted_chunks,
861        rotated_chunks: 0,
862        target_key_matches,
863        stale_key_ids: stale_key_ids.into_iter().collect(),
864        verified_all_target_key: encrypted_chunks == target_key_matches,
865    })
866}
867
868pub fn read_audit_events(path: impl AsRef<Path>) -> Result<Vec<AuditEvent>> {
869    let path = path.as_ref();
870    if !path.exists() {
871        return Ok(Vec::new());
872    }
873    let payload = fs::read_to_string(path)?;
874    payload
875        .lines()
876        .filter(|line| !line.trim().is_empty())
877        .map(|line| {
878            serde_json::from_str::<AuditEvent>(line)
879                .map_err(|e| SqlRiteError::UnsupportedOperation(e.to_string()))
880        })
881        .collect()
882}
883
884pub fn export_audit_events(
885    source_path: impl AsRef<Path>,
886    query: &AuditQuery,
887    output_path: Option<&Path>,
888    format: AuditExportFormat,
889) -> Result<AuditExportReport> {
890    let source_path = source_path.as_ref().to_path_buf();
891    let mut events = read_audit_events(&source_path)?
892        .into_iter()
893        .filter(|event| query.matches(event))
894        .collect::<Vec<_>>();
895
896    if let Some(limit) = query.limit {
897        events.truncate(limit);
898    }
899
900    if let Some(path) = output_path {
901        if let Some(parent) = path.parent()
902            && !parent.as_os_str().is_empty()
903        {
904            fs::create_dir_all(parent)?;
905        }
906        let payload = match format {
907            AuditExportFormat::Json => serde_json::to_string_pretty(&events)?,
908            AuditExportFormat::Jsonl => events
909                .iter()
910                .map(serde_json::to_string)
911                .collect::<std::result::Result<Vec<_>, _>>()?
912                .join("\n"),
913        };
914        fs::write(
915            path,
916            if payload.is_empty() {
917                payload
918            } else {
919                format!("{payload}\n")
920            },
921        )?;
922    }
923
924    Ok(AuditExportReport {
925        source_path,
926        output_path: output_path.map(Path::to_path_buf),
927        matched_events: events.len(),
928        exported_events: events.len(),
929        format: match format {
930            AuditExportFormat::Json => "json".to_string(),
931            AuditExportFormat::Jsonl => "jsonl".to_string(),
932        },
933        filters: query.clone(),
934    })
935}
936
937fn merge_tenant_metadata(metadata: &Value, tenant_id: &str) -> Value {
938    match metadata {
939        Value::Object(map) => {
940            let mut merged = map.clone();
941            merged.insert("tenant".to_string(), Value::String(tenant_id.to_string()));
942            Value::Object(merged)
943        }
944        _ => {
945            let mut merged = serde_json::Map::new();
946            merged.insert("tenant".to_string(), Value::String(tenant_id.to_string()));
947            Value::Object(merged)
948        }
949    }
950}
951
952fn encrypt_metadata_fields(
953    metadata: &mut Value,
954    key: &TenantKey,
955    tenant_id: &str,
956    sensitive_metadata_fields: &[&str],
957) -> Result<()> {
958    let Some(map) = metadata.as_object_mut() else {
959        return Err(SqlRiteError::UnsupportedOperation(
960            "metadata must be a json object for encryption".to_string(),
961        ));
962    };
963
964    for field in sensitive_metadata_fields {
965        let Some(raw) = map.get(*field).and_then(Value::as_str) else {
966            continue;
967        };
968        map.insert(
969            (*field).to_string(),
970            Value::String(encrypt_with_key(raw, tenant_id, key)),
971        );
972    }
973    map.insert(
974        "tenant_key_id".to_string(),
975        Value::String(key.key_id.clone()),
976    );
977    Ok(())
978}
979
980fn encrypt_with_key(plaintext: &str, tenant_id: &str, key: &TenantKey) -> String {
981    let mut scoped = Vec::new();
982    scoped.extend_from_slice(tenant_id.as_bytes());
983    scoped.push(0);
984    scoped.extend_from_slice(plaintext.as_bytes());
985    let cipher = xor_with_key(&scoped, &key.material);
986    format!("enc:v1:{}:{}", key.key_id, hex_encode(&cipher))
987}
988
989fn decrypt_with_key(cipher_hex: &str, key_material: &[u8]) -> Result<String> {
990    let cipher = hex_decode(cipher_hex)?;
991    let plain_scoped = xor_with_key(&cipher, key_material);
992    let Some(separator_idx) = plain_scoped.iter().position(|byte| *byte == 0) else {
993        return Err(SqlRiteError::UnsupportedOperation(
994            "invalid encrypted payload format".to_string(),
995        ));
996    };
997    let plaintext = &plain_scoped[(separator_idx + 1)..];
998    String::from_utf8(plaintext.to_vec()).map_err(|_| {
999        SqlRiteError::UnsupportedOperation("invalid utf8 in decrypted payload".to_string())
1000    })
1001}
1002
/// XORs every byte of `input` with `key`, repeating the key as often as needed.
///
/// The operation is its own inverse: applying it twice with the same key
/// returns the original bytes.
///
/// # Panics
/// Panics (division by zero) if `key` is empty while `input` is non-empty.
fn xor_with_key(input: &[u8], key: &[u8]) -> Vec<u8> {
    let mut output = Vec::with_capacity(input.len());
    for (position, plain) in input.iter().enumerate() {
        let pad = key[position % key.len()];
        output.push(plain ^ pad);
    }
    output
}
1010
/// Splits an `enc:v1:<key_id>:<payload>` string into `(key_id, payload)`.
///
/// Returns `None` unless the value starts with the `enc:v1:` marker and both
/// the key id and the payload are non-empty. The key id ends at the next `':'`;
/// the payload is everything after it, including any further colons.
fn parse_encrypted_value(value: &str) -> Option<(&str, &str)> {
    let remainder = value.strip_prefix("enc:v1:")?;
    let (key_id, payload) = remainder.split_once(':')?;
    (!key_id.is_empty() && !payload.is_empty()).then_some((key_id, payload))
}
1023
/// Encodes `bytes` as a lowercase hexadecimal string, two characters per byte.
///
/// Looks each nibble up in a static digit table. The previous
/// `push_str(&format!("{byte:02x}"))` built a temporary `String` for every byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
1031
1032fn hex_decode(value: &str) -> Result<Vec<u8>> {
1033    if !value.len().is_multiple_of(2) {
1034        return Err(SqlRiteError::UnsupportedOperation(
1035            "invalid hex payload length".to_string(),
1036        ));
1037    }
1038    let mut out = Vec::with_capacity(value.len() / 2);
1039    for idx in (0..value.len()).step_by(2) {
1040        let byte = u8::from_str_radix(&value[idx..idx + 2], 16)
1041            .map_err(|_| SqlRiteError::UnsupportedOperation("invalid hex payload".to_string()))?;
1042        out.push(byte);
1043    }
1044    Ok(out)
1045}
1046
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Returns 0 when the system clock reports a time before 1970-01-01.
fn now_unix_ms() -> u64 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    match since_epoch {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
1053
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ChunkInput, RuntimeConfig, SearchRequest, SqlRite};
    use serde_json::json;
    use tempfile::tempdir;

    /// A chunk ingested for one tenant must be invisible to another tenant's search.
    #[test]
    fn secure_wrapper_enforces_tenant_filter() -> Result<()> {
        let db = SqlRite::open_in_memory_with_config(RuntimeConfig::default())?;
        let tmp = tempdir()?;
        let logger = JsonlAuditLogger::new(tmp.path().join("audit.jsonl"), Vec::<String>::new())?;
        let secure = SecureSqlRite::from_db(db, AllowAllPolicy, logger);

        let ctx_acme = AccessContext::new("user-1", "acme");
        secure.ingest_chunks(
            &ctx_acme,
            &[ChunkInput {
                id: "c1".to_string(),
                doc_id: "d1".to_string(),
                content: "tenant scoped".to_string(),
                embedding: vec![1.0, 0.0],
                metadata: json!({}),
                source: None,
            }],
        )?;

        let ctx_beta = AccessContext::new("user-2", "beta");
        let beta_results = secure.search(
            &ctx_beta,
            SearchRequest {
                query_text: Some("tenant".to_string()),
                top_k: 5,
                ..Default::default()
            },
        )?;
        assert!(beta_results.is_empty());

        let acme_results = secure.search(
            &ctx_acme,
            SearchRequest {
                query_text: Some("tenant".to_string()),
                top_k: 5,
                ..Default::default()
            },
        )?;
        assert_eq!(acme_results.len(), 1);
        Ok(())
    }

    /// A non-admin actor of one tenant is denied deleting another tenant's data.
    #[test]
    fn non_admin_cannot_delete_other_tenant() -> Result<()> {
        let db = SqlRite::open_in_memory_with_config(RuntimeConfig::default())?;
        let tmp = tempdir()?;
        let logger = JsonlAuditLogger::new(tmp.path().join("audit.jsonl"), Vec::<String>::new())?;
        let secure = SecureSqlRite::from_db(db, AllowAllPolicy, logger);

        let err = secure
            .delete_tenant_data(&AccessContext::new("u1", "acme"), "beta")
            .expect_err("cross tenant delete should fail");
        assert!(matches!(err, SqlRiteError::AuthorizationDenied(_)));
        Ok(())
    }

    /// Sensitive fields are stored encrypted under the active key (k2), then
    /// re-encrypted under k1 by rotation. Plaintext never reaches storage.
    #[test]
    fn encrypted_ingest_and_key_rotation_workflow() -> Result<()> {
        let db = SqlRite::open_in_memory_with_config(RuntimeConfig::default())?;
        let tmp = tempdir()?;
        let logger = JsonlAuditLogger::new(tmp.path().join("audit.jsonl"), Vec::<String>::new())?;
        let secure = SecureSqlRite::from_db(db, AllowAllPolicy, logger);

        let key_registry = InMemoryTenantKeyRegistry::new();
        key_registry.set_active_key("acme", TenantKey::new("k1", b"secret-key-00001")?)?;
        key_registry.set_active_key("acme", TenantKey::new("k2", b"secret-key-00002")?)?;

        let ctx = AccessContext::new("user-enc", "acme");
        secure.ingest_chunks_with_encryption(
            &ctx,
            &[ChunkInput {
                id: "c-sec".to_string(),
                doc_id: "d-sec".to_string(),
                content: "sensitive chunk".to_string(),
                embedding: vec![1.0, 0.0],
                metadata: json!({"secret_payload": "highly-sensitive"}),
                source: None,
            }],
            &key_registry,
            &["secret_payload"],
        )?;

        let before = secure
            .db()
            .list_chunks_page(0, 10, Some("acme"))?
            .into_iter()
            .next()
            .expect("chunk exists");
        let before_payload = before
            .metadata
            .get("secret_payload")
            .and_then(Value::as_str)
            .unwrap_or_default()
            .to_string();
        // k2 was the most recently activated key, so ingest must have used it.
        assert!(before_payload.starts_with("enc:v1:k2:"));
        assert!(!before_payload.contains("highly-sensitive"));

        let rotated = rotate_tenant_encryption_key(
            secure.db(),
            "acme",
            "secret_payload",
            &key_registry,
            "k1",
        )?;
        assert_eq!(rotated, 1);

        let after = secure
            .db()
            .list_chunks_page(0, 10, Some("acme"))?
            .into_iter()
            .next()
            .expect("chunk exists");
        let after_payload = after
            .metadata
            .get("secret_payload")
            .and_then(Value::as_str)
            .unwrap_or_default()
            .to_string();
        assert!(after_payload.starts_with("enc:v1:k1:"));
        assert!(!after_payload.contains("highly-sensitive"));
        Ok(())
    }

    /// Active and inactive tenant keys survive a save/load round trip through JSON.
    #[test]
    fn key_registry_persists_to_disk() -> Result<()> {
        let tmp = tempdir()?;
        let path = tmp.path().join("tenant_keys.json");
        let registry = InMemoryTenantKeyRegistry::new();
        registry.set_active_key("acme", TenantKey::new("k1", b"material-0000001")?)?;
        registry.set_key("acme", TenantKey::new("k2", b"material-0000002")?, false)?;
        registry.save_to_json_file(&path)?;

        let restored = InMemoryTenantKeyRegistry::load_from_json_file(&path)?;
        assert!(restored.active_key("acme").is_some());
        assert!(restored.key_by_id("acme", "k2").is_some());
        Ok(())
    }

    /// The default RBAC policy limits readers to same-tenant queries and lets
    /// admins run SQL admin across tenants.
    #[test]
    fn rbac_policy_enforces_role_permissions_and_cross_tenant_rules() -> Result<()> {
        let policy = RbacPolicy::default();

        let reader = AccessContext::new("reader-1", "acme").with_roles(vec!["reader".to_string()]);
        policy.authorize(&reader, AccessOperation::Query, "acme")?;

        let ingest_err = policy
            .authorize(&reader, AccessOperation::Ingest, "acme")
            .expect_err("reader ingest should be denied");
        assert!(matches!(ingest_err, SqlRiteError::AuthorizationDenied(_)));

        let cross_tenant_err = policy
            .authorize(&reader, AccessOperation::Query, "beta")
            .expect_err("reader cross-tenant query should be denied");
        assert!(matches!(
            cross_tenant_err,
            SqlRiteError::AuthorizationDenied(_)
        ));

        let admin = AccessContext::new("admin-1", "acme").with_roles(vec!["admin".to_string()]);
        policy.authorize(&admin, AccessOperation::SqlAdmin, "beta")?;
        Ok(())
    }

    /// The default RBAC policy keeps its role set across a JSON save/load.
    #[test]
    fn rbac_policy_round_trips_to_disk() -> Result<()> {
        let tmp = tempdir()?;
        let path = tmp.path().join("rbac-policy.json");
        let policy = RbacPolicy::default();
        policy.save_to_json_file(&path)?;

        let restored = RbacPolicy::load_from_json_file(&path)?;
        assert_eq!(restored.role_names(), policy.role_names());
        Ok(())
    }

    /// Key material shorter than the minimum length is rejected.
    #[test]
    fn tenant_key_requires_minimum_length() {
        let error = TenantKey::new("k-short", b"too-short")
            .expect_err("short key material should be rejected");
        assert!(matches!(error, SqlRiteError::UnsupportedOperation(_)));
    }

    /// Exporting with an actor filter writes only that actor's events as JSONL.
    #[test]
    fn audit_export_filters_and_writes_jsonl() -> Result<()> {
        let tmp = tempdir()?;
        let path = tmp.path().join("audit.jsonl");
        let logger = JsonlAuditLogger::new(&path, Vec::<String>::new())?;
        logger.log(&AuditEvent {
            unix_ms: 10,
            actor_id: "reader-1".to_string(),
            tenant_id: "acme".to_string(),
            operation: AccessOperation::Query,
            allowed: true,
            detail: json!({"path":"/v1/query"}),
        })?;
        logger.log(&AuditEvent {
            unix_ms: 20,
            actor_id: "admin-1".to_string(),
            tenant_id: "acme".to_string(),
            operation: AccessOperation::SqlAdmin,
            allowed: false,
            detail: json!({"path":"/v1/sql"}),
        })?;

        let output = tmp.path().join("export.jsonl");
        let report = export_audit_events(
            &path,
            &AuditQuery {
                actor_id: Some("reader-1".to_string()),
                ..AuditQuery::default()
            },
            Some(&output),
            AuditExportFormat::Jsonl,
        )?;
        assert_eq!(report.matched_events, 1);
        let exported = fs::read_to_string(output)?;
        assert!(exported.contains("reader-1"));
        assert!(!exported.contains("admin-1"));
        Ok(())
    }

    /// Rotation inspection reports chunks still encrypted under a non-target key.
    #[test]
    fn inspect_rotation_reports_stale_keys() -> Result<()> {
        let db = SqlRite::open_in_memory_with_config(RuntimeConfig::default())?;
        let tmp = tempdir()?;
        let logger = JsonlAuditLogger::new(tmp.path().join("audit.jsonl"), Vec::<String>::new())?;
        let secure = SecureSqlRite::from_db(db, AllowAllPolicy, logger);

        let key_registry = InMemoryTenantKeyRegistry::new();
        key_registry.set_active_key("acme", TenantKey::new("k1", b"secret-key-00001")?)?;
        key_registry.set_key("acme", TenantKey::new("k2", b"secret-key-00002")?, false)?;

        let ctx = AccessContext::new("user-enc", "acme");
        secure.ingest_chunks_with_encryption(
            &ctx,
            &[ChunkInput {
                id: "c-sec".to_string(),
                doc_id: "d-sec".to_string(),
                content: "sensitive chunk".to_string(),
                embedding: vec![1.0, 0.0],
                metadata: json!({"secret_payload": "highly-sensitive"}),
                source: None,
            }],
            &key_registry,
            &["secret_payload"],
        )?;

        let report = inspect_tenant_key_rotation(
            secure.db(),
            "acme",
            "secret_payload",
            &key_registry,
            "k2",
        )?;
        assert_eq!(report.encrypted_chunks, 1);
        assert_eq!(report.target_key_matches, 0);
        assert_eq!(report.stale_key_ids, vec!["k1".to_string()]);
        assert!(!report.verified_all_target_key);
        Ok(())
    }
}