Skip to main content

tandem_memory/
db.rs

1#![allow(clippy::all)]
2
3// Database Layer Module
4// SQLite + sqlite-vec for vector storage
5
6use crate::types::{
7    ClearFileIndexResult, GlobalMemoryRecord, GlobalMemorySearchHit, GlobalMemoryWriteResult,
8    KnowledgeCoverageRecord, KnowledgeItemRecord, KnowledgeItemStatus, KnowledgePromotionRequest,
9    KnowledgePromotionResult, KnowledgeSpaceRecord, MemoryChunk, MemoryConfig, MemoryError,
10    MemoryResult, MemoryStats, MemoryTenantScope, MemoryTier, ProjectMemoryStats,
11    SourceObjectLifecycleRecord, SourceObjectLifecycleState, DEFAULT_EMBEDDING_DIMENSION,
12};
13use chrono::{DateTime, Utc};
14use rusqlite::{ffi::sqlite3_auto_extension, params, Connection, OptionalExtension, Row};
15use sqlite_vec::sqlite3_vec_init;
16use std::collections::HashSet;
17use std::path::Path;
18use std::sync::{Arc, LazyLock};
19use std::time::Duration;
20use tokio::sync::Mutex;
21
/// Tuple of one optional text value followed by five optional integers.
// NOTE(review): presumably the row shape returned by a project index status
// query; that query lives in the included impl parts, so the meaning of each
// column is not visible from this file — confirm there.
type ProjectIndexStatusRow = (
    Option<String>,
    Option<i64>,
    Option<i64>,
    Option<i64>,
    Option<i64>,
    Option<i64>,
);
30
/// Database connection manager
///
/// Owns the SQLite connection together with the crypto provider used to
/// encrypt and decrypt stored fields.
pub struct MemoryDatabase {
    /// SQLite connection shared behind an async (`tokio`) mutex; methods lock
    /// it for the duration of each query.
    conn: Arc<Mutex<Connection>>,
    // NOTE(review): presumably the on-disk location this connection was
    // opened from — confirm against the constructor in the included parts.
    db_path: std::path::PathBuf,
    /// Provider whose `encrypt_field` / `decrypt_field` are applied to stored
    /// content (for example layer content and chunk content/metadata).
    crypto: crate::crypto::MemoryCryptoProvider,
    /// Per-instance strict tenant enforcement flag.
    // NOTE(review): per the docs on `STRICT_TENANT_ENFORCEMENT_DEFAULT`, new
    // instances inherit the process-wide default — confirm in the constructor.
    strict_tenant_enforcement: std::sync::atomic::AtomicBool,
}
38
/// Process-wide async mutex, created lazily on first access.
// NOTE(review): presumably held while a database's schema is being
// initialised so concurrent opens do not race; the call sites are in the
// included impl parts — confirm there.
static SCHEMA_INIT_LOCK: LazyLock<Mutex<()>> = LazyLock::new(|| Mutex::new(()));
40
/// Process-wide default for strict tenant enforcement.
///
/// The host (engine `serve` in hosted/enterprise auth modes) sets this once at
/// startup, before any database is opened. Every new `MemoryDatabase`
/// inherits the value, which keeps the many ad-hoc construction sites in
/// tandem-server fail-closed without each of them having to thread a flag.
static STRICT_TENANT_ENFORCEMENT_DEFAULT: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(false);

/// Turn strict tenant enforcement on or off for every `MemoryDatabase`
/// opened after this call.
///
/// In strict mode, reads and writes that carry the local-implicit tenant
/// scope are rejected rather than landing in the shared "local" partition.
pub fn set_strict_tenant_enforcement_default(enabled: bool) {
    use std::sync::atomic::Ordering;

    STRICT_TENANT_ENFORCEMENT_DEFAULT.store(enabled, Ordering::SeqCst);
}

/// Read the current process-wide strict tenant enforcement default.
pub fn strict_tenant_enforcement_default() -> bool {
    use std::sync::atomic::Ordering;

    STRICT_TENANT_ENFORCEMENT_DEFAULT.load(Ordering::SeqCst)
}
60
// `include!` pastes these files into this module verbatim, so their contents
// share the imports, statics, and private helpers declared in this file.
// NOTE(review): judging by the directory name these hold further
// `MemoryDatabase` impl code — confirm in the part files.
include!("memory_database_impl_parts/part01.rs");
include!("memory_database_impl_parts/part02.rs");
63
64/// Convert a database row to a MemoryChunk
65fn row_to_chunk(
66    row: &Row,
67    tier: MemoryTier,
68    crypto: &crate::crypto::MemoryCryptoProvider,
69) -> Result<MemoryChunk, rusqlite::Error> {
70    let map_decrypt_err = |err: crate::types::MemoryError| {
71        rusqlite::Error::FromSqlConversionFailure(1, rusqlite::types::Type::Text, Box::new(err))
72    };
73    let id: String = row.get(0)?;
74    let content_raw: String = row.get(1)?;
75    let content = crypto
76        .decrypt_field(&content_raw)
77        .map_err(map_decrypt_err)?;
78    let (session_id, project_id, source_idx, created_at_idx, token_count_idx, metadata_idx) =
79        match tier {
80            MemoryTier::Session => (
81                Some(row.get(2)?),
82                row.get(3)?,
83                4usize,
84                5usize,
85                6usize,
86                7usize,
87            ),
88            MemoryTier::Project => (
89                row.get(2)?,
90                Some(row.get(3)?),
91                4usize,
92                5usize,
93                6usize,
94                7usize,
95            ),
96            MemoryTier::Global => (None, None, 2usize, 3usize, 4usize, 5usize),
97        };
98
99    let source: String = row.get(source_idx)?;
100    let created_at_str: String = row.get(created_at_idx)?;
101    let token_count: i64 = row.get(token_count_idx)?;
102    let metadata_raw: Option<String> = row.get(metadata_idx)?;
103    let metadata_str = match metadata_raw {
104        Some(s) if !s.is_empty() => Some(crypto.decrypt_field(&s).map_err(map_decrypt_err)?),
105        other => other,
106    };
107
108    let created_at = DateTime::parse_from_rfc3339(&created_at_str)
109        .map_err(|e| {
110            rusqlite::Error::FromSqlConversionFailure(5, rusqlite::types::Type::Text, Box::new(e))
111        })?
112        .with_timezone(&Utc);
113
114    let metadata = metadata_str
115        .filter(|s| !s.is_empty())
116        .and_then(|s| serde_json::from_str(&s).ok());
117
118    let source_path = row.get::<_, Option<String>>("source_path").ok().flatten();
119    let source_mtime = row.get::<_, Option<i64>>("source_mtime").ok().flatten();
120    let source_size = row.get::<_, Option<i64>>("source_size").ok().flatten();
121    let source_hash = row.get::<_, Option<String>>("source_hash").ok().flatten();
122    let tenant_scope = MemoryTenantScope {
123        org_id: row
124            .get::<_, Option<String>>("tenant_org_id")
125            .ok()
126            .flatten()
127            .filter(|value| !value.trim().is_empty())
128            .unwrap_or_else(|| LOCAL_TENANT_ORG_ID.to_string()),
129        workspace_id: row
130            .get::<_, Option<String>>("tenant_workspace_id")
131            .ok()
132            .flatten()
133            .filter(|value| !value.trim().is_empty())
134            .unwrap_or_else(|| LOCAL_TENANT_WORKSPACE_ID.to_string()),
135        deployment_id: row
136            .get::<_, Option<String>>("tenant_deployment_id")
137            .ok()
138            .flatten()
139            .filter(|value| !value.trim().is_empty()),
140    };
141
142    Ok(MemoryChunk {
143        id,
144        content,
145        tier,
146        session_id,
147        project_id,
148        source,
149        source_path,
150        source_mtime,
151        source_size,
152        source_hash,
153        tenant_scope,
154        created_at,
155        token_count,
156        metadata,
157    })
158}
159
160fn require_scope_id<'a>(tier: MemoryTier, scope: Option<&'a str>) -> MemoryResult<&'a str> {
161    scope
162        .filter(|value| !value.trim().is_empty())
163        .ok_or_else(|| {
164            crate::types::MemoryError::InvalidConfig(match tier {
165                MemoryTier::Session => "tier=session requires session_id".to_string(),
166                MemoryTier::Project => "tier=project requires project_id".to_string(),
167                MemoryTier::Global => "tier=global does not require a scope id".to_string(),
168            })
169        })
170}
171
/// Org id substituted when a row or record carries no (or a blank) tenant
/// org id.
const LOCAL_TENANT_ORG_ID: &str = "local";
/// Workspace id substituted when a row or record carries no (or a blank)
/// tenant workspace id.
const LOCAL_TENANT_WORKSPACE_ID: &str = "local";
174
/// Build the SQL predicate that matches a row's tenant scope against three
/// consecutive numbered parameters.
///
/// `prefix` is the table name or alias qualifying the columns. Parameter
/// `?first_param` is compared with the org id, `?first_param + 1` with the
/// workspace id and `?first_param + 2` with the deployment id. Both sides of
/// the deployment comparison go through `IFNULL(.., '')`, so a NULL
/// deployment id matches a NULL parameter.
fn tenant_scope_matches_sql_clause(prefix: &str, first_param: usize) -> String {
    let org_param = first_param;
    let workspace_param = first_param + 1;
    let deployment_param = first_param + 2;

    [
        format!("{prefix}.tenant_org_id = ?{org_param}"),
        format!("{prefix}.tenant_workspace_id = ?{workspace_param}"),
        format!("IFNULL({prefix}.tenant_deployment_id, '') = IFNULL(?{deployment_param}, '')"),
    ]
    .join(" AND ")
}
182
183fn global_memory_record_tenant_scope(
184    record: &GlobalMemoryRecord,
185) -> (String, String, Option<String>) {
186    record
187        .provenance
188        .as_ref()
189        .and_then(|value| value.get("tenant_context"))
190        .and_then(memory_tenant_scope_from_value)
191        .unwrap_or_else(|| {
192            (
193                LOCAL_TENANT_ORG_ID.to_string(),
194                LOCAL_TENANT_WORKSPACE_ID.to_string(),
195                None,
196            )
197        })
198}
199
200fn memory_tenant_scope_from_value(
201    value: &serde_json::Value,
202) -> Option<(String, String, Option<String>)> {
203    let org_id = value.get("org_id")?.as_str()?.to_string();
204    let workspace_id = value.get("workspace_id")?.as_str()?.to_string();
205    let deployment_id = value
206        .get("deployment_id")
207        .and_then(|value| value.as_str())
208        .map(ToString::to_string);
209    Some((org_id, workspace_id, deployment_id))
210}
211
212fn row_to_global_record(row: &Row) -> Result<GlobalMemoryRecord, rusqlite::Error> {
213    let metadata_str: Option<String> = row.get(12)?;
214    let provenance_str: Option<String> = row.get(13)?;
215    Ok(GlobalMemoryRecord {
216        id: row.get(0)?,
217        user_id: row.get(1)?,
218        source_type: row.get(2)?,
219        content: row.get(3)?,
220        content_hash: row.get(4)?,
221        run_id: row.get(5)?,
222        session_id: row.get(6)?,
223        message_id: row.get(7)?,
224        tool_name: row.get(8)?,
225        project_tag: row.get(9)?,
226        channel_tag: row.get(10)?,
227        host_tag: row.get(11)?,
228        metadata: metadata_str
229            .filter(|s| !s.is_empty())
230            .and_then(|s| serde_json::from_str(&s).ok()),
231        provenance: provenance_str
232            .filter(|s| !s.is_empty())
233            .and_then(|s| serde_json::from_str(&s).ok()),
234        redaction_status: row.get(14)?,
235        redaction_count: row.get::<_, i64>(15)? as u32,
236        visibility: row.get(16)?,
237        demoted: row.get::<_, i64>(17)? != 0,
238        score_boost: row.get(18)?,
239        created_at_ms: row.get::<_, i64>(19)? as u64,
240        updated_at_ms: row.get::<_, i64>(20)? as u64,
241        expires_at_ms: row.get::<_, Option<i64>>(21)?.map(|v| v as u64),
242    })
243}
244
245fn row_to_source_object_lifecycle(
246    row: &Row,
247) -> Result<SourceObjectLifecycleRecord, rusqlite::Error> {
248    let metadata_str: Option<String> = row.get("metadata")?;
249    let resource_ref_str: String = row.get("resource_ref")?;
250    let tenant_scope = MemoryTenantScope {
251        org_id: row.get("tenant_org_id")?,
252        workspace_id: row.get("tenant_workspace_id")?,
253        deployment_id: row
254            .get::<_, Option<String>>("tenant_deployment_id")?
255            .filter(|value| !value.is_empty()),
256    };
257    let tier = match row.get::<_, String>("tier")?.as_str() {
258        "session" => MemoryTier::Session,
259        "project" => MemoryTier::Project,
260        _ => MemoryTier::Global,
261    };
262    Ok(SourceObjectLifecycleRecord {
263        source_object_id: row.get("source_object_id")?,
264        tenant_scope,
265        source_binding_id: row.get("source_binding_id")?,
266        connector_id: row.get("connector_id")?,
267        state: SourceObjectLifecycleState::parse(&row.get::<_, String>("state")?),
268        tier,
269        session_id: row.get("session_id")?,
270        project_id: row.get("project_id")?,
271        import_namespace: row.get("import_namespace")?,
272        indexed_path: row.get("indexed_path")?,
273        native_object_id: row.get("native_object_id")?,
274        resource_ref: serde_json::from_str(&resource_ref_str).unwrap_or(serde_json::Value::Null),
275        data_class: row.get("data_class")?,
276        content_hash: row.get("content_hash")?,
277        source_hash: row.get("source_hash")?,
278        first_seen_at_ms: row.get::<_, i64>("first_seen_at_ms")? as u64,
279        last_seen_at_ms: row.get::<_, i64>("last_seen_at_ms")? as u64,
280        tombstoned_at_ms: row
281            .get::<_, Option<i64>>("tombstoned_at_ms")?
282            .map(|value| value as u64),
283        metadata: metadata_str
284            .filter(|value| !value.is_empty())
285            .and_then(|value| serde_json::from_str(&value).ok()),
286    })
287}
288
289impl MemoryDatabase {
290    pub async fn get_node_by_uri(
291        &self,
292        uri: &str,
293    ) -> MemoryResult<Option<crate::types::MemoryNode>> {
294        let conn = self.conn.lock().await;
295        let mut stmt = conn.prepare(
296            "SELECT id, uri, parent_uri, node_type, created_at, updated_at, metadata
297             FROM memory_nodes WHERE uri = ?1",
298        )?;
299
300        let result = stmt.query_row(params![uri], |row| {
301            let node_type_str: String = row.get(3)?;
302            let node_type = node_type_str
303                .parse()
304                .unwrap_or(crate::types::NodeType::File);
305            let metadata_str: Option<String> = row.get(6)?;
306            Ok(crate::types::MemoryNode {
307                id: row.get(0)?,
308                uri: row.get(1)?,
309                parent_uri: row.get(2)?,
310                node_type,
311                created_at: row.get::<_, String>(4)?.parse().unwrap_or_default(),
312                updated_at: row.get::<_, String>(5)?.parse().unwrap_or_default(),
313                metadata: metadata_str.and_then(|s| serde_json::from_str(&s).ok()),
314            })
315        });
316
317        match result {
318            Ok(node) => Ok(Some(node)),
319            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
320            Err(e) => Err(MemoryError::Database(e)),
321        }
322    }
323
324    pub async fn create_node(
325        &self,
326        uri: &str,
327        parent_uri: Option<&str>,
328        node_type: crate::types::NodeType,
329        metadata: Option<&serde_json::Value>,
330    ) -> MemoryResult<String> {
331        let id = uuid::Uuid::new_v4().to_string();
332        let now = Utc::now().to_rfc3339();
333        let metadata_str = metadata.map(|m| serde_json::to_string(m)).transpose()?;
334
335        let conn = self.conn.lock().await;
336        conn.execute(
337            "INSERT INTO memory_nodes (id, uri, parent_uri, node_type, created_at, updated_at, metadata)
338             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
339            params![id, uri, parent_uri, node_type.to_string(), now, now, metadata_str],
340        )?;
341
342        Ok(id)
343    }
344
345    pub async fn list_directory(&self, uri: &str) -> MemoryResult<Vec<crate::types::MemoryNode>> {
346        let conn = self.conn.lock().await;
347        let mut stmt = conn.prepare(
348            "SELECT id, uri, parent_uri, node_type, created_at, updated_at, metadata
349             FROM memory_nodes WHERE parent_uri = ?1 ORDER BY node_type DESC, uri ASC",
350        )?;
351
352        let rows = stmt.query_map(params![uri], |row| {
353            let node_type_str: String = row.get(3)?;
354            let node_type = node_type_str
355                .parse()
356                .unwrap_or(crate::types::NodeType::File);
357            let metadata_str: Option<String> = row.get(6)?;
358            Ok(crate::types::MemoryNode {
359                id: row.get(0)?,
360                uri: row.get(1)?,
361                parent_uri: row.get(2)?,
362                node_type,
363                created_at: row.get::<_, String>(4)?.parse().unwrap_or_default(),
364                updated_at: row.get::<_, String>(5)?.parse().unwrap_or_default(),
365                metadata: metadata_str.and_then(|s| serde_json::from_str(&s).ok()),
366            })
367        })?;
368
369        rows.collect::<Result<Vec<_>, _>>()
370            .map_err(MemoryError::Database)
371    }
372
373    pub async fn get_layer(
374        &self,
375        node_id: &str,
376        layer_type: crate::types::LayerType,
377    ) -> MemoryResult<Option<crate::types::MemoryLayer>> {
378        let conn = self.conn.lock().await;
379        let mut stmt = conn.prepare(
380            "SELECT id, node_id, layer_type, content, token_count, embedding_id, created_at, source_chunk_id
381             FROM memory_layers WHERE node_id = ?1 AND layer_type = ?2"
382        )?;
383
384        let result = stmt.query_row(params![node_id, layer_type.to_string()], |row| {
385            let layer_type_str: String = row.get(2)?;
386            let layer_type = layer_type_str
387                .parse()
388                .unwrap_or(crate::types::LayerType::L2);
389            Ok(crate::types::MemoryLayer {
390                id: row.get(0)?,
391                node_id: row.get(1)?,
392                layer_type,
393                content: row.get(3)?,
394                token_count: row.get(4)?,
395                embedding_id: row.get(5)?,
396                created_at: row.get::<_, String>(6)?.parse().unwrap_or_default(),
397                source_chunk_id: row.get(7)?,
398            })
399        });
400
401        match result {
402            Ok(mut layer) => {
403                layer.content = self.crypto.decrypt_field(&layer.content)?;
404                Ok(Some(layer))
405            }
406            Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
407            Err(e) => Err(MemoryError::Database(e)),
408        }
409    }
410
411    pub async fn create_layer(
412        &self,
413        node_id: &str,
414        layer_type: crate::types::LayerType,
415        content: &str,
416        token_count: i64,
417        source_chunk_id: Option<&str>,
418    ) -> MemoryResult<String> {
419        let id = uuid::Uuid::new_v4().to_string();
420        let now = Utc::now().to_rfc3339();
421        let content_stored = self.crypto.encrypt_field(content)?;
422
423        let conn = self.conn.lock().await;
424        conn.execute(
425            "INSERT INTO memory_layers (id, node_id, layer_type, content, token_count, created_at, source_chunk_id)
426             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
427            params![id, node_id, layer_type.to_string(), content_stored, token_count, now, source_chunk_id],
428        )?;
429
430        Ok(id)
431    }
432
433    pub async fn get_children_tree(
434        &self,
435        parent_uri: &str,
436        max_depth: usize,
437    ) -> MemoryResult<Vec<crate::types::TreeNode>> {
438        if max_depth == 0 {
439            return Ok(Vec::new());
440        }
441
442        let children = self.list_directory(parent_uri).await?;
443        let mut tree_nodes = Vec::new();
444
445        for child in children {
446            let layer_summary = self.get_layer_summary(&child.id).await?;
447
448            let grandchildren = if child.node_type == crate::types::NodeType::Directory {
449                Box::pin(self.get_children_tree(&child.uri, max_depth.saturating_sub(1))).await?
450            } else {
451                Vec::new()
452            };
453
454            tree_nodes.push(crate::types::TreeNode {
455                node: child,
456                children: grandchildren,
457                layer_summary,
458            });
459        }
460
461        Ok(tree_nodes)
462    }
463
464    async fn get_layer_summary(
465        &self,
466        node_id: &str,
467    ) -> MemoryResult<Option<crate::types::LayerSummary>> {
468        let l0 = self.get_layer(node_id, crate::types::LayerType::L0).await?;
469        let l1 = self.get_layer(node_id, crate::types::LayerType::L1).await?;
470        let has_l2 = self
471            .get_layer(node_id, crate::types::LayerType::L2)
472            .await?
473            .is_some();
474
475        if l0.is_none() && l1.is_none() && !has_l2 {
476            return Ok(None);
477        }
478
479        Ok(Some(crate::types::LayerSummary {
480            l0_preview: l0.map(|l| truncate_string(&l.content, 100)),
481            l1_preview: l1.map(|l| truncate_string(&l.content, 200)),
482            has_l2,
483        }))
484    }
485
486    pub async fn node_exists(&self, uri: &str) -> MemoryResult<bool> {
487        let conn = self.conn.lock().await;
488        let count: i64 = conn.query_row(
489            "SELECT COUNT(*) FROM memory_nodes WHERE uri = ?1",
490            params![uri],
491            |row| row.get(0),
492        )?;
493        Ok(count > 0)
494    }
495}
496
497fn row_to_knowledge_space(row: &Row) -> Result<KnowledgeSpaceRecord, rusqlite::Error> {
498    let scope = row
499        .get::<_, String>(1)?
500        .parse()
501        .unwrap_or(tandem_orchestrator::KnowledgeScope::Project);
502    let trust_level = row
503        .get::<_, String>(6)?
504        .parse()
505        .unwrap_or(tandem_orchestrator::KnowledgeTrustLevel::Promoted);
506    let metadata = row
507        .get::<_, Option<String>>(7)?
508        .and_then(|raw| serde_json::from_str(&raw).ok());
509    Ok(KnowledgeSpaceRecord {
510        id: row.get(0)?,
511        scope,
512        project_id: row.get(2)?,
513        namespace: row.get(3)?,
514        title: row.get(4)?,
515        description: row.get(5)?,
516        trust_level,
517        metadata,
518        created_at_ms: row.get::<_, i64>(8)? as u64,
519        updated_at_ms: row.get::<_, i64>(9)? as u64,
520    })
521}
522
523fn row_to_knowledge_item(row: &Row) -> Result<KnowledgeItemRecord, rusqlite::Error> {
524    let trust_level = row
525        .get::<_, String>(8)?
526        .parse()
527        .unwrap_or(tandem_orchestrator::KnowledgeTrustLevel::Promoted);
528    let status = row
529        .get::<_, String>(9)?
530        .parse()
531        .unwrap_or(KnowledgeItemStatus::Working);
532    let payload = row
533        .get::<_, String>(7)
534        .ok()
535        .and_then(|raw| serde_json::from_str(&raw).ok())
536        .unwrap_or(serde_json::Value::Null);
537    let artifact_refs = row
538        .get::<_, String>(11)
539        .ok()
540        .and_then(|raw| serde_json::from_str(&raw).ok())
541        .unwrap_or_default();
542    let source_memory_ids = row
543        .get::<_, String>(12)
544        .ok()
545        .and_then(|raw| serde_json::from_str(&raw).ok())
546        .unwrap_or_default();
547    let metadata = row
548        .get::<_, Option<String>>(14)?
549        .and_then(|raw| serde_json::from_str(&raw).ok());
550    Ok(KnowledgeItemRecord {
551        id: row.get(0)?,
552        space_id: row.get(1)?,
553        coverage_key: row.get(2)?,
554        dedupe_key: row.get(3)?,
555        item_type: row.get(4)?,
556        title: row.get(5)?,
557        summary: row.get(6)?,
558        payload,
559        trust_level,
560        status,
561        run_id: row.get(10)?,
562        artifact_refs,
563        source_memory_ids,
564        freshness_expires_at_ms: row.get::<_, Option<i64>>(13)?.map(|value| value as u64),
565        metadata,
566        created_at_ms: row.get::<_, i64>(15)? as u64,
567        updated_at_ms: row.get::<_, i64>(16)? as u64,
568    })
569}
570
571fn row_to_knowledge_coverage(row: &Row) -> Result<KnowledgeCoverageRecord, rusqlite::Error> {
572    let metadata = row
573        .get::<_, Option<String>>(7)?
574        .and_then(|raw| serde_json::from_str(&raw).ok());
575    Ok(KnowledgeCoverageRecord {
576        coverage_key: row.get(0)?,
577        space_id: row.get(1)?,
578        latest_item_id: row.get(2)?,
579        latest_dedupe_key: row.get(3)?,
580        last_seen_at_ms: row.get::<_, i64>(4)? as u64,
581        last_promoted_at_ms: row.get::<_, Option<i64>>(5)?.map(|value| value as u64),
582        freshness_expires_at_ms: row.get::<_, Option<i64>>(6)?.map(|value| value as u64),
583        metadata,
584    })
585}
586
/// Shorten `s` so that it fits in `max_len` bytes, appending `...` when
/// anything was cut.
///
/// Strings of at most `max_len` bytes are returned unchanged. Longer strings
/// keep their first `max_len - 3` bytes, rounded down to the nearest UTF-8
/// character boundary so multi-byte text never causes a slicing panic,
/// followed by `...`. For `max_len < 3` the result is just `...`.
fn truncate_string(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }

    // `cut` is strictly below `s.len()` here; walk back until it no longer
    // splits a character. Index 0 is always a boundary, so this terminates.
    let mut cut = max_len.saturating_sub(3);
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }

    format!("{}...", &s[..cut])
}
594
/// Turn free-form user text into an FTS query of quoted terms joined by `OR`.
///
/// The input is split on whitespace. Each token has leading and trailing
/// characters stripped unless they are ASCII alphanumerics, `_` or `-`;
/// tokens that end up empty are dropped. Every remaining token is wrapped in
/// double quotes so it is treated as a literal string, and any double quote
/// left inside a token is doubled so it cannot terminate the string early and
/// corrupt the query. When no usable token remains the result is `""` (an
/// empty quoted string).
fn build_fts_query(query: &str) -> String {
    let tokens = query
        .split_whitespace()
        .filter_map(|tok| {
            let cleaned =
                tok.trim_matches(|c: char| !c.is_ascii_alphanumeric() && c != '_' && c != '-');
            if cleaned.is_empty() {
                None
            } else {
                // Trimming only touches the ends; escape quotes in the middle.
                Some(format!("\"{}\"", cleaned.replace('"', "\"\"")))
            }
        })
        .collect::<Vec<_>>();
    if tokens.is_empty() {
        "\"\"".to_string()
    } else {
        tokens.join(" OR ")
    }
}
614
// NOTE(review): judging by the file name this pulls in the module's tests;
// it is included textually, so it can reach the private items defined in
// this file — confirm in the included file.
include!("memory_database_impl_parts/db_tests.rs");