// zoey_storage_sql/vector_search.rs

use std::fmt::Write as _;

use sqlx::PgPool;
use zoey_core::{types::*, Result};

6pub struct VectorSearch {
8 pool: PgPool,
9 embedding_dimension: usize,
10}
11
12impl VectorSearch {
13 pub fn new(pool: PgPool, embedding_dimension: usize) -> Self {
15 Self {
16 pool,
17 embedding_dimension,
18 }
19 }
20
21 pub async fn initialize(&self) -> Result<()> {
23 sqlx::query("CREATE EXTENSION IF NOT EXISTS vector")
25 .execute(&self.pool)
26 .await?;
27
28 sqlx::query(&format!(
30 "ALTER TABLE memories ADD COLUMN IF NOT EXISTS embedding vector({})",
31 self.embedding_dimension
32 ))
33 .execute(&self.pool)
34 .await?;
35
36 sqlx::query(
38 "CREATE INDEX IF NOT EXISTS memories_embedding_idx
39 ON memories USING hnsw (embedding vector_cosine_ops)",
40 )
41 .execute(&self.pool)
42 .await?;
43
44 Ok(())
45 }
46
47 pub async fn search_by_embedding(&self, params: SearchMemoriesParams) -> Result<Vec<Memory>> {
49 let _embedding_json = serde_json::to_string(¶ms.embedding)?;
50
51 let mut query = format!(
52 "SELECT id, entity_id, agent_id, room_id, content, embedding, metadata, created_at, unique_flag,
53 1 - (embedding <=> $1::vector) as similarity
54 FROM memories
55 WHERE 1=1"
56 );
57
58 let mut bind_count = 1;
59
60 if params.agent_id.is_some() {
61 bind_count += 1;
62 query.push_str(&format!(" AND agent_id = ${}", bind_count));
63 }
64
65 if params.room_id.is_some() {
66 bind_count += 1;
67 query.push_str(&format!(" AND room_id = ${}", bind_count));
68 }
69
70 if let Some(_threshold) = params.threshold {
71 bind_count += 1;
72 query.push_str(&format!(
73 " AND (1 - (embedding <=> $1::vector)) >= ${}",
74 bind_count
75 ));
76 }
77
78 query.push_str(&format!(
79 " ORDER BY embedding <=> $1::vector LIMIT {}",
80 params.count
81 ));
82
83 tracing::debug!("Executing vector search query");
85
86 Ok(vec![])
88 }
89
90 pub async fn add_embedding(&self, memory_id: UUID, embedding: Vec<f32>) -> Result<()> {
92 let embedding_json = serde_json::to_string(&embedding)?;
93
94 sqlx::query("UPDATE memories SET embedding = $1::vector WHERE id = $2")
95 .bind(embedding_json)
96 .bind(memory_id)
97 .execute(&self.pool)
98 .await?;
99
100 Ok(())
101 }
102
103 pub async fn get_similar_memories(
105 &self,
106 _memory_id: UUID,
107 count: usize,
108 threshold: Option<f32>,
109 ) -> Result<Vec<Memory>> {
110 let mut query = String::from(
111 "SELECT m2.id, m2.entity_id, m2.agent_id, m2.room_id, m2.content,
112 m2.embedding, m2.metadata, m2.created_at, m2.unique_flag,
113 1 - (m1.embedding <=> m2.embedding) as similarity
114 FROM memories m1, memories m2
115 WHERE m1.id = $1 AND m2.id != $1",
116 );
117
118 if let Some(t) = threshold {
119 query.push_str(&format!(
120 " AND (1 - (m1.embedding <=> m2.embedding)) >= {}",
121 t
122 ));
123 }
124
125 query.push_str(&format!(
126 " ORDER BY m1.embedding <=> m2.embedding LIMIT {}",
127 count
128 ));
129
130 Ok(vec![])
132 }
133
134 pub async fn batch_add_embeddings(&self, embeddings: Vec<(UUID, Vec<f32>)>) -> Result<()> {
136 let mut tx = self.pool.begin().await?;
138
139 for (memory_id, embedding) in embeddings {
140 let embedding_json = serde_json::to_string(&embedding)?;
141
142 sqlx::query("UPDATE memories SET embedding = $1::vector WHERE id = $2")
143 .bind(embedding_json)
144 .bind(memory_id)
145 .execute(&mut *tx)
146 .await?;
147 }
148
149 tx.commit().await?;
150
151 Ok(())
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `VectorSearch` can be shared across threads and
    /// async tasks (e.g. behind an `Arc` on a multi-threaded runtime).
    ///
    /// The query methods need a live PostgreSQL database with pgvector, so they
    /// are not exercised here.
    #[test]
    fn vector_search_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<VectorSearch>();
    }
}