// zoey_storage_sql/vector_search.rs
//! Vector search implementation using pgvector

3use zoey_core::{types::*, Result};
4use sqlx::PgPool;
5
6/// Vector search operations for PostgreSQL with pgvector
7pub struct VectorSearch {
8    pool: PgPool,
9    embedding_dimension: usize,
10}
11
12impl VectorSearch {
13    /// Create a new vector search instance
14    pub fn new(pool: PgPool, embedding_dimension: usize) -> Self {
15        Self {
16            pool,
17            embedding_dimension,
18        }
19    }
20
21    /// Initialize pgvector extension
22    pub async fn initialize(&self) -> Result<()> {
23        // Enable pgvector extension
24        sqlx::query("CREATE EXTENSION IF NOT EXISTS vector")
25            .execute(&self.pool)
26            .await?;
27
28        // Create vector column if not exists (would be part of schema migration)
29        sqlx::query(&format!(
30            "ALTER TABLE memories ADD COLUMN IF NOT EXISTS embedding vector({})",
31            self.embedding_dimension
32        ))
33        .execute(&self.pool)
34        .await?;
35
36        // Create HNSW index for fast similarity search
37        sqlx::query(
38            "CREATE INDEX IF NOT EXISTS memories_embedding_idx 
39             ON memories USING hnsw (embedding vector_cosine_ops)",
40        )
41        .execute(&self.pool)
42        .await?;
43
44        Ok(())
45    }
46
47    /// Search memories by embedding similarity
48    pub async fn search_by_embedding(&self, params: SearchMemoriesParams) -> Result<Vec<Memory>> {
49        let _embedding_json = serde_json::to_string(&params.embedding)?;
50
51        let mut query = format!(
52            "SELECT id, entity_id, agent_id, room_id, content, embedding, metadata, created_at, unique_flag,
53                    1 - (embedding <=> $1::vector) as similarity
54             FROM memories
55             WHERE 1=1"
56        );
57
58        let mut bind_count = 1;
59
60        if params.agent_id.is_some() {
61            bind_count += 1;
62            query.push_str(&format!(" AND agent_id = ${}", bind_count));
63        }
64
65        if params.room_id.is_some() {
66            bind_count += 1;
67            query.push_str(&format!(" AND room_id = ${}", bind_count));
68        }
69
70        if let Some(_threshold) = params.threshold {
71            bind_count += 1;
72            query.push_str(&format!(
73                " AND (1 - (embedding <=> $1::vector)) >= ${}",
74                bind_count
75            ));
76        }
77
78        query.push_str(&format!(
79            " ORDER BY embedding <=> $1::vector LIMIT {}",
80            params.count
81        ));
82
83        // Execute query (simplified - would need proper parameter binding)
84        tracing::debug!("Executing vector search query");
85
86        // In real implementation, would properly bind all parameters
87        Ok(vec![])
88    }
89
90    /// Add embedding to existing memory
91    pub async fn add_embedding(&self, memory_id: UUID, embedding: Vec<f32>) -> Result<()> {
92        let embedding_json = serde_json::to_string(&embedding)?;
93
94        sqlx::query("UPDATE memories SET embedding = $1::vector WHERE id = $2")
95            .bind(embedding_json)
96            .bind(memory_id)
97            .execute(&self.pool)
98            .await?;
99
100        Ok(())
101    }
102
103    /// Get similar memories for a given memory
104    pub async fn get_similar_memories(
105        &self,
106        _memory_id: UUID,
107        count: usize,
108        threshold: Option<f32>,
109    ) -> Result<Vec<Memory>> {
110        let mut query = String::from(
111            "SELECT m2.id, m2.entity_id, m2.agent_id, m2.room_id, m2.content, 
112                    m2.embedding, m2.metadata, m2.created_at, m2.unique_flag,
113                    1 - (m1.embedding <=> m2.embedding) as similarity
114             FROM memories m1, memories m2
115             WHERE m1.id = $1 AND m2.id != $1",
116        );
117
118        if let Some(t) = threshold {
119            query.push_str(&format!(
120                " AND (1 - (m1.embedding <=> m2.embedding)) >= {}",
121                t
122            ));
123        }
124
125        query.push_str(&format!(
126            " ORDER BY m1.embedding <=> m2.embedding LIMIT {}",
127            count
128        ));
129
130        // In real implementation, would execute and return results
131        Ok(vec![])
132    }
133
134    /// Batch insert embeddings
135    pub async fn batch_add_embeddings(&self, embeddings: Vec<(UUID, Vec<f32>)>) -> Result<()> {
136        // Use a transaction for batch operations
137        let mut tx = self.pool.begin().await?;
138
139        for (memory_id, embedding) in embeddings {
140            let embedding_json = serde_json::to_string(&embedding)?;
141
142            sqlx::query("UPDATE memories SET embedding = $1::vector WHERE id = $2")
143                .bind(embedding_json)
144                .bind(memory_id)
145                .execute(&mut *tx)
146                .await?;
147        }
148
149        tx.commit().await?;
150
151        Ok(())
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time checks that need no database.
    ///
    /// Real query tests would require PostgreSQL with the pgvector extension.
    #[test]
    fn test_vector_search_creation() {
        // `VectorSearch` is `Send + Sync`, so it can be shared across threads and tasks.
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<VectorSearch>();

        // Pin the constructor's signature.
        let _constructor: fn(PgPool, usize) -> VectorSearch = VectorSearch::new;
    }
}