// steel_memory_lib/storage/vector.rs

use std::path::{Path, PathBuf};
use rusqlite::{Connection, params};
use crate::types::{Drawer, SearchResult};

/// Drawer store backed by a single SQLite table (`drawers`).
///
/// Embedding vectors are stored as JSON text and scored in-process. Only the database
/// path is held; each operation opens its own connection.
#[derive(Debug, Clone)]
pub struct VectorStorage {
    /// Location of the SQLite database file.
    db_path: PathBuf,
}

9impl VectorStorage {
10    pub fn new(db_path: &Path) -> anyhow::Result<Self> {
11        let conn = open_conn(db_path)?;
12        conn.execute_batch(
13            "PRAGMA journal_mode=WAL;
14             CREATE TABLE IF NOT EXISTS drawers (
15                 id TEXT PRIMARY KEY,
16                 content TEXT NOT NULL,
17                 wing TEXT NOT NULL,
18                 room TEXT NOT NULL,
19                 source_file TEXT NOT NULL DEFAULT 'mcp',
20                 source_mtime INTEGER NOT NULL DEFAULT 0,
21                 chunk_index INTEGER NOT NULL DEFAULT 0,
22                 added_by TEXT NOT NULL DEFAULT 'mcp',
23                 filed_at TEXT NOT NULL,
24                 hall TEXT DEFAULT '',
25                 topic TEXT DEFAULT '',
26                 drawer_type TEXT DEFAULT '',
27                 agent TEXT DEFAULT '',
28                 date TEXT DEFAULT '',
29                 importance REAL DEFAULT 3.0,
30                 vector TEXT NOT NULL
31             );",
32        )?;
33        Ok(Self { db_path: db_path.to_path_buf() })
34    }
35
36    pub fn add_drawer(&self, drawer: &Drawer, vector: &[f32]) -> anyhow::Result<()> {
37        let conn = open_conn(&self.db_path)?;
38        let vec_json = serde_json::to_string(vector)?;
39        conn.execute(
40            "INSERT OR REPLACE INTO drawers
41             (id, content, wing, room, source_file, source_mtime, chunk_index, added_by,
42              filed_at, hall, topic, drawer_type, agent, date, importance, vector)
43             VALUES (?1,?2,?3,?4,?5,?6,?7,?8,?9,?10,?11,?12,?13,?14,?15,?16)",
44            params![
45                drawer.id, drawer.content, drawer.wing, drawer.room,
46                drawer.source_file, drawer.source_mtime, drawer.chunk_index, drawer.added_by,
47                drawer.filed_at, drawer.hall, drawer.topic, drawer.drawer_type,
48                drawer.agent, drawer.date, drawer.importance, vec_json
49            ],
50        )?;
51        Ok(())
52    }
53
54    pub fn delete_drawer(&self, id: &str) -> anyhow::Result<usize> {
55        let conn = open_conn(&self.db_path)?;
56        let n = conn.execute("DELETE FROM drawers WHERE id = ?1", params![id])?;
57        Ok(n)
58    }
59
60    pub fn search(
61        &self,
62        query_vec: &[f32],
63        limit: usize,
64        wing: Option<&str>,
65        room: Option<&str>,
66    ) -> anyhow::Result<Vec<SearchResult>> {
67        let conn = open_conn(&self.db_path)?;
68        let mut stmt = conn.prepare(
69            "SELECT id, content, wing, room, source_file, source_mtime, chunk_index,
70                    added_by, filed_at, hall, topic, drawer_type, agent, date, importance, vector
71             FROM drawers",
72        )?;
73        let rows = stmt.query_map([], |row| {
74            Ok((
75                row_to_drawer(row)?,
76                row.get::<_, String>(15)?,
77            ))
78        })?;
79
80        let mut results: Vec<SearchResult> = Vec::new();
81        for row in rows {
82            let (drawer, vec_json) = row?;
83            if wing.is_some_and(|w| drawer.wing != w) {
84                continue;
85            }
86            if room.is_some_and(|r| drawer.room != r) {
87                continue;
88            }
89            let stored_vec: Vec<f32> = serde_json::from_str(&vec_json).unwrap_or_default();
90            let sim = cosine_similarity(query_vec, &stored_vec);
91            results.push(SearchResult { drawer, similarity: sim });
92        }
93        results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
94        results.truncate(limit);
95        Ok(results)
96    }
97
98    pub fn get_all(
99        &self,
100        wing: Option<&str>,
101        room: Option<&str>,
102        limit: usize,
103    ) -> anyhow::Result<Vec<Drawer>> {
104        let conn = open_conn(&self.db_path)?;
105        let base = "SELECT id,content,wing,room,source_file,source_mtime,chunk_index,added_by,\
106                    filed_at,hall,topic,drawer_type,agent,date,importance,vector FROM drawers";
107        let mut conditions: Vec<String> = Vec::new();
108        let mut param_values: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
109
110        if let Some(w) = wing {
111            conditions.push(format!("wing=?{}", param_values.len() + 1));
112            param_values.push(Box::new(w.to_string()));
113        }
114        if let Some(r) = room {
115            conditions.push(format!("room=?{}", param_values.len() + 1));
116            param_values.push(Box::new(r.to_string()));
117        }
118        let limit_idx = param_values.len() + 1;
119        param_values.push(Box::new(limit as i64));
120
121        let sql = if conditions.is_empty() {
122            format!("{} ORDER BY importance DESC LIMIT ?{}", base, limit_idx)
123        } else {
124            format!("{} WHERE {} ORDER BY importance DESC LIMIT ?{}", base, conditions.join(" AND "), limit_idx)
125        };
126
127        let mut stmt = conn.prepare(&sql)?;
128        let refs: Vec<&dyn rusqlite::ToSql> = param_values.iter().map(|p| p.as_ref()).collect();
129        let drawers = stmt.query_map(refs.as_slice(), row_to_drawer)?
130            .filter_map(|r| r.ok())
131            .collect();
132        Ok(drawers)
133    }
134
135    pub fn count(&self) -> anyhow::Result<usize> {
136        let conn = open_conn(&self.db_path)?;
137        let n: i64 = conn.query_row("SELECT COUNT(*) FROM drawers", [], |row| row.get(0))?;
138        Ok(n as usize)
139    }
140
141    pub fn list_wings(&self) -> anyhow::Result<Vec<(String, usize)>> {
142        let conn = open_conn(&self.db_path)?;
143        let mut stmt = conn.prepare("SELECT wing, COUNT(*) as cnt FROM drawers GROUP BY wing ORDER BY cnt DESC")?;
144        let rows = stmt.query_map([], |row| {
145            Ok((row.get::<_, String>(0)?, row.get::<_, i64>(1)? as usize))
146        })?.filter_map(|r| r.ok()).collect();
147        Ok(rows)
148    }
149
150    pub fn list_rooms(&self, wing: Option<&str>) -> anyhow::Result<Vec<(String, String, usize)>> {
151        let conn = open_conn(&self.db_path)?;
152        let rows = if let Some(w) = wing {
153            let mut stmt = conn.prepare(
154                "SELECT wing, room, COUNT(*) as cnt FROM drawers WHERE wing=?1 GROUP BY wing, room ORDER BY cnt DESC"
155            )?;
156            stmt.query_map(params![w], |row| {
157                Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?, row.get::<_, i64>(2)? as usize))
158            })?.filter_map(|r| r.ok()).collect()
159        } else {
160            let mut stmt = conn.prepare(
161                "SELECT wing, room, COUNT(*) as cnt FROM drawers GROUP BY wing, room ORDER BY cnt DESC"
162            )?;
163            stmt.query_map([], |row| {
164                Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?, row.get::<_, i64>(2)? as usize))
165            })?.filter_map(|r| r.ok()).collect()
166        };
167        Ok(rows)
168    }
169
170    pub fn get_taxonomy(&self) -> anyhow::Result<serde_json::Value> {
171        let rooms = self.list_rooms(None)?;
172        let mut map = serde_json::Map::new();
173        for (wing, room, count) in rooms {
174            let wing_entry = map.entry(wing).or_insert_with(|| serde_json::json!({}));
175            if let serde_json::Value::Object(wing_map) = wing_entry {
176                wing_map.insert(room, serde_json::json!(count));
177            }
178        }
179        Ok(serde_json::Value::Object(map))
180    }
181}
182
183fn open_conn(path: &Path) -> anyhow::Result<Connection> {
184    let conn = Connection::open(path)?;
185    conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA foreign_keys=ON;")?;
186    Ok(conn)
187}
188
189fn row_to_drawer(row: &rusqlite::Row<'_>) -> rusqlite::Result<Drawer> {
190    Ok(Drawer {
191        id: row.get(0)?,
192        content: row.get(1)?,
193        wing: row.get(2)?,
194        room: row.get(3)?,
195        source_file: row.get(4)?,
196        source_mtime: row.get(5)?,
197        chunk_index: row.get(6)?,
198        added_by: row.get(7)?,
199        filed_at: row.get(8)?,
200        hall: row.get(9)?,
201        topic: row.get(10)?,
202        drawer_type: row.get(11)?,
203        agent: row.get(12)?,
204        date: row.get(13)?,
205        importance: row.get(14)?,
206    })
207}
208
/// Returns the cosine similarity of `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` when the similarity is undefined or would be misleading:
/// - the slices have different lengths (e.g. vectors from different embedding models);
/// - either vector has zero magnitude (this includes empty slices);
/// - the computation yields NaN (e.g. a NaN component).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Zipping mismatched slices would drop the tail of the longer one from the dot
    // product while still counting it in that vector's norm, yielding a bogus score.
    if a.len() != b.len() {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    let sim = dot / (norm_a * norm_b);
    // `clamp` passes NaN through unchanged. Report it as "no similarity" so callers can
    // rank results with ordinary comparisons.
    if sim.is_nan() {
        0.0
    } else {
        sim.clamp(-1.0, 1.0)
    }
}