1use std::path::{Path, PathBuf};
2use rusqlite::{Connection, params};
3use crate::types::{Drawer, SearchResult};
4
/// SQLite-backed store of drawers and their embedding vectors.
///
/// Only the database path is kept; each operation opens its own connection.
#[derive(Debug, Clone)]
pub struct VectorStorage {
    /// Location of the SQLite database file.
    db_path: PathBuf,
}
8
9impl VectorStorage {
10 pub fn new(db_path: &Path) -> anyhow::Result<Self> {
11 let conn = open_conn(db_path)?;
12 conn.execute_batch(
13 "PRAGMA journal_mode=WAL;
14 CREATE TABLE IF NOT EXISTS drawers (
15 id TEXT PRIMARY KEY,
16 content TEXT NOT NULL,
17 wing TEXT NOT NULL,
18 room TEXT NOT NULL,
19 source_file TEXT NOT NULL DEFAULT 'mcp',
20 source_mtime INTEGER NOT NULL DEFAULT 0,
21 chunk_index INTEGER NOT NULL DEFAULT 0,
22 added_by TEXT NOT NULL DEFAULT 'mcp',
23 filed_at TEXT NOT NULL,
24 hall TEXT DEFAULT '',
25 topic TEXT DEFAULT '',
26 drawer_type TEXT DEFAULT '',
27 agent TEXT DEFAULT '',
28 date TEXT DEFAULT '',
29 importance REAL DEFAULT 3.0,
30 vector TEXT NOT NULL
31 );",
32 )?;
33 Ok(Self { db_path: db_path.to_path_buf() })
34 }
35
36 pub fn add_drawer(&self, drawer: &Drawer, vector: &[f32]) -> anyhow::Result<()> {
37 let conn = open_conn(&self.db_path)?;
38 let vec_json = serde_json::to_string(vector)?;
39 conn.execute(
40 "INSERT OR REPLACE INTO drawers
41 (id, content, wing, room, source_file, source_mtime, chunk_index, added_by,
42 filed_at, hall, topic, drawer_type, agent, date, importance, vector)
43 VALUES (?1,?2,?3,?4,?5,?6,?7,?8,?9,?10,?11,?12,?13,?14,?15,?16)",
44 params![
45 drawer.id, drawer.content, drawer.wing, drawer.room,
46 drawer.source_file, drawer.source_mtime, drawer.chunk_index, drawer.added_by,
47 drawer.filed_at, drawer.hall, drawer.topic, drawer.drawer_type,
48 drawer.agent, drawer.date, drawer.importance, vec_json
49 ],
50 )?;
51 Ok(())
52 }
53
54 pub fn delete_drawer(&self, id: &str) -> anyhow::Result<usize> {
55 let conn = open_conn(&self.db_path)?;
56 let n = conn.execute("DELETE FROM drawers WHERE id = ?1", params![id])?;
57 Ok(n)
58 }
59
60 pub fn search(
61 &self,
62 query_vec: &[f32],
63 limit: usize,
64 wing: Option<&str>,
65 room: Option<&str>,
66 ) -> anyhow::Result<Vec<SearchResult>> {
67 let conn = open_conn(&self.db_path)?;
68 let mut stmt = conn.prepare(
69 "SELECT id, content, wing, room, source_file, source_mtime, chunk_index,
70 added_by, filed_at, hall, topic, drawer_type, agent, date, importance, vector
71 FROM drawers",
72 )?;
73 let rows = stmt.query_map([], |row| {
74 Ok((
75 row_to_drawer(row)?,
76 row.get::<_, String>(15)?,
77 ))
78 })?;
79
80 let mut results: Vec<SearchResult> = Vec::new();
81 for row in rows {
82 let (drawer, vec_json) = row?;
83 if wing.is_some_and(|w| drawer.wing != w) {
84 continue;
85 }
86 if room.is_some_and(|r| drawer.room != r) {
87 continue;
88 }
89 let stored_vec: Vec<f32> = serde_json::from_str(&vec_json).unwrap_or_default();
90 let sim = cosine_similarity(query_vec, &stored_vec);
91 results.push(SearchResult { drawer, similarity: sim });
92 }
93 results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
94 results.truncate(limit);
95 Ok(results)
96 }
97
98 pub fn get_all(
99 &self,
100 wing: Option<&str>,
101 room: Option<&str>,
102 limit: usize,
103 ) -> anyhow::Result<Vec<Drawer>> {
104 let conn = open_conn(&self.db_path)?;
105 let base = "SELECT id,content,wing,room,source_file,source_mtime,chunk_index,added_by,\
106 filed_at,hall,topic,drawer_type,agent,date,importance,vector FROM drawers";
107 let mut conditions: Vec<String> = Vec::new();
108 let mut param_values: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
109
110 if let Some(w) = wing {
111 conditions.push(format!("wing=?{}", param_values.len() + 1));
112 param_values.push(Box::new(w.to_string()));
113 }
114 if let Some(r) = room {
115 conditions.push(format!("room=?{}", param_values.len() + 1));
116 param_values.push(Box::new(r.to_string()));
117 }
118 let limit_idx = param_values.len() + 1;
119 param_values.push(Box::new(limit as i64));
120
121 let sql = if conditions.is_empty() {
122 format!("{} ORDER BY importance DESC LIMIT ?{}", base, limit_idx)
123 } else {
124 format!("{} WHERE {} ORDER BY importance DESC LIMIT ?{}", base, conditions.join(" AND "), limit_idx)
125 };
126
127 let mut stmt = conn.prepare(&sql)?;
128 let refs: Vec<&dyn rusqlite::ToSql> = param_values.iter().map(|p| p.as_ref()).collect();
129 let drawers = stmt.query_map(refs.as_slice(), row_to_drawer)?
130 .filter_map(|r| r.ok())
131 .collect();
132 Ok(drawers)
133 }
134
135 pub fn count(&self) -> anyhow::Result<usize> {
136 let conn = open_conn(&self.db_path)?;
137 let n: i64 = conn.query_row("SELECT COUNT(*) FROM drawers", [], |row| row.get(0))?;
138 Ok(n as usize)
139 }
140
141 pub fn list_wings(&self) -> anyhow::Result<Vec<(String, usize)>> {
142 let conn = open_conn(&self.db_path)?;
143 let mut stmt = conn.prepare("SELECT wing, COUNT(*) as cnt FROM drawers GROUP BY wing ORDER BY cnt DESC")?;
144 let rows = stmt.query_map([], |row| {
145 Ok((row.get::<_, String>(0)?, row.get::<_, i64>(1)? as usize))
146 })?.filter_map(|r| r.ok()).collect();
147 Ok(rows)
148 }
149
150 pub fn list_rooms(&self, wing: Option<&str>) -> anyhow::Result<Vec<(String, String, usize)>> {
151 let conn = open_conn(&self.db_path)?;
152 let rows = if let Some(w) = wing {
153 let mut stmt = conn.prepare(
154 "SELECT wing, room, COUNT(*) as cnt FROM drawers WHERE wing=?1 GROUP BY wing, room ORDER BY cnt DESC"
155 )?;
156 stmt.query_map(params![w], |row| {
157 Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?, row.get::<_, i64>(2)? as usize))
158 })?.filter_map(|r| r.ok()).collect()
159 } else {
160 let mut stmt = conn.prepare(
161 "SELECT wing, room, COUNT(*) as cnt FROM drawers GROUP BY wing, room ORDER BY cnt DESC"
162 )?;
163 stmt.query_map([], |row| {
164 Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?, row.get::<_, i64>(2)? as usize))
165 })?.filter_map(|r| r.ok()).collect()
166 };
167 Ok(rows)
168 }
169
170 pub fn get_taxonomy(&self) -> anyhow::Result<serde_json::Value> {
171 let rooms = self.list_rooms(None)?;
172 let mut map = serde_json::Map::new();
173 for (wing, room, count) in rooms {
174 let wing_entry = map.entry(wing).or_insert_with(|| serde_json::json!({}));
175 if let serde_json::Value::Object(wing_map) = wing_entry {
176 wing_map.insert(room, serde_json::json!(count));
177 }
178 }
179 Ok(serde_json::Value::Object(map))
180 }
181}
182
183fn open_conn(path: &Path) -> anyhow::Result<Connection> {
184 let conn = Connection::open(path)?;
185 conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA foreign_keys=ON;")?;
186 Ok(conn)
187}
188
189fn row_to_drawer(row: &rusqlite::Row<'_>) -> rusqlite::Result<Drawer> {
190 Ok(Drawer {
191 id: row.get(0)?,
192 content: row.get(1)?,
193 wing: row.get(2)?,
194 room: row.get(3)?,
195 source_file: row.get(4)?,
196 source_mtime: row.get(5)?,
197 chunk_index: row.get(6)?,
198 added_by: row.get(7)?,
199 filed_at: row.get(8)?,
200 hall: row.get(9)?,
201 topic: row.get(10)?,
202 drawer_type: row.get(11)?,
203 agent: row.get(12)?,
204 date: row.get(13)?,
205 importance: row.get(14)?,
206 })
207}
208
/// Cosine similarity of `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// If the slices differ in length, the dot product covers only the overlapping prefix while
/// each norm covers its whole slice. This is equivalent to zero-padding the shorter vector.
///
/// Returns `0.0` ("unrelated") in two cases, so the value is always safe to sort on:
/// - either vector has zero norm (including empty input);
/// - the result is undefined because of NaN or infinite components.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Accumulate in f64: squaring large f32 components overflows to infinity (which turns
    // the result into NaN), and long f32 sums lose precision.
    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    let norm_a = a.iter().map(|&x| f64::from(x).powi(2)).sum::<f64>().sqrt();
    let norm_b = b.iter().map(|&x| f64::from(x).powi(2)).sum::<f64>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    let sim = dot / (norm_a * norm_b);
    if sim.is_nan() {
        // NaN/inf inputs slip past the zero-norm guard; never leak NaN to callers that sort.
        return 0.0;
    }
    sim.clamp(-1.0, 1.0) as f32
}