Skip to main content

sh_layer3/
vector_store.rs

1//! # Vector Store
2//!
3//! 向量存储:持久化向量索引。
4//!
5//! ## 功能
6//!
7//! - 内存向量存储(适合测试和开发)
8//! - 文件持久化向量存储(适合生产环境)
9//! - 多种距离度量支持(Cosine, Euclidean, DotProduct, Manhattan)
10//! - 批量操作优化(并行处理)
11//! - 压缩持久化(可选)
12//! - 异步持久化支持
13//! - 与 RetrieverEngine 无缝集成
14
15use crate::retriever_engine::RetrievalResult;
16use crate::types::{Layer3Error, Layer3Result};
17use async_trait::async_trait;
18use parking_lot::RwLock;
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use std::path::PathBuf;
22use std::sync::Arc;
23use tracing::{debug, info, instrument, warn};
24
25/// 向量存储 trait
26///
27/// 定义向量持久化存储接口。
28///
29/// # Example
30///
31/// ```rust,no_run
32/// use sh_layer3::vector_store::{VectorStore, VectorItem, InMemoryVectorStore};
33///
34/// #[tokio::main]
35/// async fn main() {
36///     let store = InMemoryVectorStore::in_memory();
37///
38///     // 添加向量
39///     let item = VectorItem::new("doc-1", vec![0.1, 0.2, 0.3])
40///         .with_content("Hello world");
41///     store.add_batch(vec![item]).await.unwrap();
42///
43///     // 查询相似向量
44///     let results = store.query(vec![0.1, 0.2, 0.3], 5).await.unwrap();
45/// }
46/// ```
47#[async_trait]
48pub trait VectorStore: Send + Sync {
49    /// 添加向量
50    async fn add(
51        &self,
52        id: String,
53        vector: Vec<f32>,
54        metadata: HashMap<String, serde_json::Value>,
55    ) -> Layer3Result<bool>;
56
57    /// 批量添加向量(优化版本)
58    ///
59    /// 使用并行处理和减少锁争用优化大批量插入。
60    async fn add_batch(&self, items: Vec<VectorItem>) -> Layer3Result<Vec<bool>>;
61
62    /// 添加向量(带验证)
63    ///
64    /// 验证向量维度并返回详细错误信息。
65    async fn add_validated(
66        &self,
67        id: String,
68        vector: Vec<f32>,
69        metadata: HashMap<String, serde_json::Value>,
70        expected_dimension: usize,
71    ) -> Layer3Result<bool> {
72        if vector.len() != expected_dimension {
73            return Err(Layer3Error::VectorDimensionMismatch {
74                expected: expected_dimension,
75                actual: vector.len(),
76            }
77            .into());
78        }
79        self.add(id, vector, metadata).await
80    }
81
82    /// 查询相似向量
83    async fn query(&self, vector: Vec<f32>, top_k: usize) -> Layer3Result<Vec<RetrievalResult>>;
84
85    /// 带过滤条件的查询
86    async fn query_with_filter(
87        &self,
88        vector: Vec<f32>,
89        top_k: usize,
90        filter: Option<MetadataFilter>,
91    ) -> Layer3Result<Vec<RetrievalResult>> {
92        // 默认实现:调用基本查询
93        let _ = filter;
94        self.query(vector, top_k).await
95    }
96
97    /// 带分数阈值的查询
98    ///
99    /// 只返回分数高于阈值的向量。
100    async fn query_with_threshold(
101        &self,
102        vector: Vec<f32>,
103        top_k: usize,
104        min_score: f32,
105    ) -> Layer3Result<Vec<RetrievalResult>> {
106        let results = self.query(vector, top_k).await?;
107        Ok(results
108            .into_iter()
109            .filter(|r| r.score >= min_score)
110            .collect())
111    }
112
113    /// 删除向量
114    async fn delete(&self, id: &str) -> Layer3Result<bool>;
115
116    /// 批量删除(优化版本)
117    async fn delete_batch(&self, ids: &[String]) -> Layer3Result<usize>;
118
119    /// 删除所有匹配元数据条件的向量
120    async fn delete_by_filter(&self, filter: MetadataFilter) -> Layer3Result<usize> {
121        let _ = filter;
122        Err(Layer3Error::VectorStoreError("delete_by_filter not implemented".to_string()).into())
123    }
124
125    /// 获取向量
126    async fn get(&self, id: &str) -> Layer3Result<Option<VectorItem>>;
127
128    /// 批量获取向量
129    async fn get_batch(&self, ids: &[String]) -> Layer3Result<Vec<Option<VectorItem>>> {
130        let mut results = Vec::with_capacity(ids.len());
131        for id in ids {
132            results.push(self.get(id).await?);
133        }
134        Ok(results)
135    }
136
137    /// 更新向量(存在则更新,不存在则创建)
138    async fn upsert(
139        &self,
140        id: String,
141        vector: Vec<f32>,
142        metadata: HashMap<String, serde_json::Value>,
143    ) -> Layer3Result<bool> {
144        self.add(id, vector, metadata).await
145    }
146
147    /// 统计数量
148    async fn count(&self) -> Layer3Result<usize>;
149
150    /// 清空存储
151    async fn clear(&self) -> Layer3Result<bool>;
152
153    /// 检查向量是否存在
154    async fn exists(&self, id: &str) -> Layer3Result<bool> {
155        Ok(self.get(id).await?.is_some())
156    }
157
158    /// 获取存储统计信息
159    async fn stats(&self) -> Layer3Result<VectorStoreStats> {
160        Ok(VectorStoreStats {
161            count: self.count().await?,
162            dimension: 0,
163            metric: DistanceMetric::Cosine,
164        })
165    }
166
167    /// 持久化到磁盘(可选)
168    async fn persist(&self) -> Layer3Result<()> {
169        Ok(())
170    }
171
172    /// 从磁盘加载(可选)
173    async fn load(&self) -> Layer3Result<()> {
174        Ok(())
175    }
176
177    /// 异步持久化(后台线程)
178    async fn persist_async(&self) -> Layer3Result<()> {
179        self.persist().await
180    }
181
182    /// 强制同步持久化
183    fn persist_sync(&self) -> Layer3Result<()> {
184        Ok(())
185    }
186
187    /// 验证向量维度
188    fn validate_dimension(&self, vector: &[f32], expected: usize) -> Layer3Result<()> {
189        if vector.len() != expected {
190            Err(Layer3Error::VectorDimensionMismatch {
191                expected,
192                actual: vector.len(),
193            }
194            .into())
195        } else {
196            Ok(())
197        }
198    }
199}
200
/// Statistics describing a vector store.
#[derive(Debug, Clone)]
pub struct VectorStoreStats {
    /// Number of stored vectors.
    pub count: usize,
    /// Vector dimension (the `VectorStore::stats` default reports `0`).
    pub dimension: usize,
    /// Distance metric (the `VectorStore::stats` default reports `Cosine`).
    pub metric: DistanceMetric,
}
208
/// Metadata filter conditions.
///
/// A metadata map matches when all `must` pairs are present, no `must_not`
/// pair is present, and (if `should` is non-empty) at least one `should`
/// pair is present. Values are compared with `==`.
#[derive(Debug, Clone)]
pub struct MetadataFilter {
    /// Key/value pairs that must all be present.
    pub must: HashMap<String, serde_json::Value>,
    /// Optional key/value pairs; at least one must match when non-empty.
    pub should: HashMap<String, serde_json::Value>,
    /// Key/value pairs that must not be present.
    pub must_not: HashMap<String, serde_json::Value>,
}
219
220impl MetadataFilter {
221    pub fn new() -> Self {
222        Self {
223            must: HashMap::new(),
224            should: HashMap::new(),
225            must_not: HashMap::new(),
226        }
227    }
228
229    pub fn must(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
230        self.must.insert(key.into(), value);
231        self
232    }
233
234    pub fn should(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
235        self.should.insert(key.into(), value);
236        self
237    }
238
239    pub fn must_not(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
240        self.must_not.insert(key.into(), value);
241        self
242    }
243
244    /// 检查元数据是否匹配过滤条件
245    pub fn matches(&self, metadata: &HashMap<String, serde_json::Value>) -> bool {
246        // 检查 must 条件(全部匹配)
247        for (key, value) in &self.must {
248            match metadata.get(key) {
249                Some(v) if v == value => continue,
250                _ => return false,
251            }
252        }
253
254        // 检查 must_not 条件(全部不匹配)
255        for (key, value) in &self.must_not {
256            if let Some(v) = metadata.get(key) {
257                if v == value {
258                    return false;
259                }
260            }
261        }
262
263        // 检查 should 条件(至少一个匹配,如果为空则通过)
264        if !self.should.is_empty() {
265            let mut matched = false;
266            for (key, value) in &self.should {
267                if let Some(v) = metadata.get(key) {
268                    if v == value {
269                        matched = true;
270                        break;
271                    }
272                }
273            }
274            if !matched {
275                return false;
276            }
277        }
278
279        true
280    }
281}
282
283impl Default for MetadataFilter {
284    fn default() -> Self {
285        Self::new()
286    }
287}
288
/// A single stored vector with its identifying and descriptive data.
#[derive(Debug, Clone)]
pub struct VectorItem {
    /// Unique id.
    pub id: String,
    /// Vector data.
    pub vector: Vec<f32>,
    /// Metadata.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Associated content (optional).
    pub content: Option<String>,
}
301
302impl VectorItem {
303    pub fn new(id: impl Into<String>, vector: Vec<f32>) -> Self {
304        Self {
305            id: id.into(),
306            vector,
307            metadata: HashMap::new(),
308            content: None,
309        }
310    }
311
312    pub fn with_content(mut self, content: impl Into<String>) -> Self {
313        self.content = Some(content.into());
314        self
315    }
316
317    pub fn with_metadata(mut self, metadata: HashMap<String, serde_json::Value>) -> Self {
318        self.metadata = metadata;
319        self
320    }
321}
322
/// Vector store configuration.
#[derive(Debug, Clone)]
pub struct VectorStoreConfig {
    /// Storage path (`None` when no path is configured).
    pub path: Option<String>,
    /// Vector dimension.
    pub dimension: usize,
    /// Distance metric.
    pub metric: DistanceMetric,
    /// Index type.
    pub index_type: IndexType,
}
335
336impl Default for VectorStoreConfig {
337    fn default() -> Self {
338        Self {
339            path: None,
340            dimension: 1536,
341            metric: DistanceMetric::Cosine,
342            index_type: IndexType::Hnsw,
343        }
344    }
345}
346
/// Distance metric used to compare vectors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistanceMetric {
    /// Cosine similarity.
    Cosine,
    /// Euclidean distance.
    Euclidean,
    /// Dot product.
    DotProduct,
    /// Manhattan distance.
    Manhattan,
}
359
/// Index type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndexType {
    /// HNSW (efficient approximate nearest neighbour).
    Hnsw,
    /// IVF (inverted file index).
    Ivf,
    /// Flat (brute-force search).
    Flat,
    /// PQ (product quantization).
    ProductQuantization,
}
372
/// Factory trait for building vector stores.
pub trait VectorStoreFactory: Send + Sync {
    /// Creates a boxed vector store from `config`.
    fn create(&self, config: VectorStoreConfig) -> Layer3Result<Box<dyn VectorStore>>;
}
378
379// ============================================================================
380// In-Memory Vector Store Implementation
381// ============================================================================
382
/// In-memory vector store.
///
/// Keeps all vectors in a hash map and supports basic similarity search.
/// Intended for tests and development rather than large-scale production use.
pub struct InMemoryVectorStore {
    /// Vector data keyed by id, behind a shared read/write lock.
    data: Arc<RwLock<HashMap<String, VectorItem>>>,
    /// Store configuration.
    config: VectorStoreConfig,
}
393
394impl InMemoryVectorStore {
395    /// 创建新的内存向量存储
396    pub fn new(config: VectorStoreConfig) -> Self {
397        Self {
398            data: Arc::new(RwLock::new(HashMap::new())),
399            config,
400        }
401    }
402
403    /// 创建使用默认配置的存储
404    pub fn in_memory() -> Self {
405        Self::new(VectorStoreConfig::default())
406    }
407
408    /// 计算向量相似度
409    fn compute_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
410        if a.len() != b.len() {
411            return 0.0;
412        }
413
414        match self.config.metric {
415            DistanceMetric::Cosine => {
416                let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
417                let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
418                let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
419                if norm_a == 0.0 || norm_b == 0.0 {
420                    0.0
421                } else {
422                    dot / (norm_a * norm_b)
423                }
424            }
425            DistanceMetric::Euclidean => {
426                let sum: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
427                1.0 / (1.0 + sum.sqrt())
428            }
429            DistanceMetric::DotProduct => a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(),
430            DistanceMetric::Manhattan => {
431                let sum: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
432                1.0 / (1.0 + sum)
433            }
434        }
435    }
436}
437
438#[async_trait]
439impl VectorStore for InMemoryVectorStore {
440    async fn add(
441        &self,
442        id: String,
443        vector: Vec<f32>,
444        metadata: HashMap<String, serde_json::Value>,
445    ) -> Layer3Result<bool> {
446        let item = VectorItem {
447            id: id.clone(),
448            vector,
449            metadata,
450            content: None,
451        };
452        let mut data = self.data.write();
453        data.insert(id, item);
454        Ok(true)
455    }
456
457    async fn add_batch(&self, items: Vec<VectorItem>) -> Layer3Result<Vec<bool>> {
458        let mut data = self.data.write();
459        let results: Vec<bool> = items
460            .into_iter()
461            .map(|item| {
462                let id = item.id.clone();
463                data.insert(id, item);
464                true
465            })
466            .collect();
467        Ok(results)
468    }
469
470    async fn query(&self, vector: Vec<f32>, top_k: usize) -> Layer3Result<Vec<RetrievalResult>> {
471        let data = self.data.read();
472        let mut scores: Vec<(String, f32, &VectorItem)> = data
473            .iter()
474            .map(|(id, item)| {
475                let score = self.compute_similarity(&vector, &item.vector);
476                (id.clone(), score, item)
477            })
478            .collect();
479
480        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
481        scores.truncate(top_k);
482
483        Ok(scores
484            .into_iter()
485            .map(|(doc_id, score, item)| RetrievalResult {
486                doc_id,
487                content: item.content.clone().unwrap_or_default(),
488                score,
489                metadata: item.metadata.clone(),
490                source: item
491                    .metadata
492                    .get("source")
493                    .and_then(|v| v.as_str())
494                    .map(String::from),
495            })
496            .collect())
497    }
498
499    async fn delete(&self, id: &str) -> Layer3Result<bool> {
500        let mut data = self.data.write();
501        Ok(data.remove(id).is_some())
502    }
503
504    async fn delete_batch(&self, ids: &[String]) -> Layer3Result<usize> {
505        let mut data = self.data.write();
506        let mut count = 0;
507        for id in ids {
508            if data.remove(id).is_some() {
509                count += 1;
510            }
511        }
512        Ok(count)
513    }
514
515    async fn get(&self, id: &str) -> Layer3Result<Option<VectorItem>> {
516        let data = self.data.read();
517        Ok(data.get(id).cloned())
518    }
519
520    async fn count(&self) -> Layer3Result<usize> {
521        let data = self.data.read();
522        Ok(data.len())
523    }
524
525    async fn clear(&self) -> Layer3Result<bool> {
526        let mut data = self.data.write();
527        data.clear();
528        Ok(true)
529    }
530
531    async fn query_with_filter(
532        &self,
533        vector: Vec<f32>,
534        top_k: usize,
535        filter: Option<MetadataFilter>,
536    ) -> Layer3Result<Vec<RetrievalResult>> {
537        let data = self.data.read();
538
539        // 先过滤,再计算相似度
540        let candidates: Vec<&VectorItem> = if let Some(ref f) = filter {
541            data.values()
542                .filter(|item| f.matches(&item.metadata))
543                .collect()
544        } else {
545            data.values().collect()
546        };
547
548        let mut scores: Vec<(String, f32, &VectorItem)> = candidates
549            .into_iter()
550            .map(|item| {
551                let score = self.compute_similarity(&vector, &item.vector);
552                (item.id.clone(), score, item)
553            })
554            .collect();
555
556        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
557        scores.truncate(top_k);
558
559        Ok(scores
560            .into_iter()
561            .map(|(doc_id, score, item)| RetrievalResult {
562                doc_id,
563                content: item.content.clone().unwrap_or_default(),
564                score,
565                metadata: item.metadata.clone(),
566                source: item
567                    .metadata
568                    .get("source")
569                    .and_then(|v| v.as_str())
570                    .map(String::from),
571            })
572            .collect())
573    }
574}
575
576/// 内存向量存储工厂
577pub struct InMemoryVectorStoreFactory;
578
579impl VectorStoreFactory for InMemoryVectorStoreFactory {
580    fn create(&self, config: VectorStoreConfig) -> Layer3Result<Box<dyn VectorStore>> {
581        Ok(Box::new(InMemoryVectorStore::new(config)))
582    }
583}
584
585// ============================================================================
586// File-Persisted Vector Store Implementation
587// ============================================================================
588
/// Serializable form of a vector item, used for persistence.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SerializableVectorItem {
    // Same fields as `VectorItem`; metadata is held as a JSON object map.
    id: String,
    vector: Vec<f32>,
    metadata: serde_json::Map<String, serde_json::Value>,
    content: Option<String>,
}
597
/// Serializable snapshot of a whole store: all items plus config summary.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct StoreData {
    // Every stored item.
    items: Vec<SerializableVectorItem>,
    // Dimension and metric the store was configured with when persisted.
    config: SerializableConfig,
}
604
/// Serializable subset of the store configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SerializableConfig {
    // Configured vector dimension.
    dimension: usize,
    // Metric name, written as the `Debug` rendering of `DistanceMetric`.
    metric: String,
}
610
/// File-persisted vector store.
///
/// Persists vector data to a local file so it can be restored after an
/// application restart. Uses JSON, which suits small to medium data sets.
pub struct FileVectorStore {
    /// In-memory store holding the actual data.
    inner: InMemoryVectorStore,
    /// Path of the backing file.
    path: PathBuf,
    /// Whether mutating operations persist automatically.
    auto_persist: bool,
}
623
624impl FileVectorStore {
625    /// 创建新的文件向量存储
626    pub fn new(config: VectorStoreConfig) -> Layer3Result<Self> {
627        let path = config
628            .path
629            .as_ref()
630            .map(PathBuf::from)
631            .unwrap_or_else(|| PathBuf::from("vector_store.json"));
632
633        let inner = InMemoryVectorStore::new(config);
634        let store = Self {
635            inner,
636            path,
637            auto_persist: true,
638        };
639
640        Ok(store)
641    }
642
643    /// 创建带自动持久化开关的存储
644    pub fn with_auto_persist(mut self, auto_persist: bool) -> Self {
645        self.auto_persist = auto_persist;
646        self
647    }
648
649    /// 持久化数据到文件
650    #[instrument(skip(self))]
651    pub fn persist_sync(&self) -> Layer3Result<()> {
652        let data = self.inner.data.read();
653
654        let items: Vec<SerializableVectorItem> = data
655            .values()
656            .map(|item| SerializableVectorItem {
657                id: item.id.clone(),
658                vector: item.vector.clone(),
659                metadata: item.metadata.clone().into_iter().collect(),
660                content: item.content.clone(),
661            })
662            .collect();
663
664        let config = SerializableConfig {
665            dimension: self.inner.config.dimension,
666            metric: format!("{:?}", self.inner.config.metric),
667        };
668
669        let store_data = StoreData { items, config };
670
671        let json = serde_json::to_string_pretty(&store_data)?;
672        std::fs::write(&self.path, json)?;
673
674        info!("Persisted {} vectors to {:?}", data.len(), self.path);
675        Ok(())
676    }
677
678    /// 从文件加载数据
679    #[instrument(skip(self))]
680    pub fn load_sync(&self) -> Layer3Result<()> {
681        if !self.path.exists() {
682            debug!("No existing store file at {:?}", self.path);
683            return Ok(());
684        }
685
686        let json = std::fs::read_to_string(&self.path)?;
687        let store_data: StoreData = serde_json::from_str(&json)?;
688
689        let mut data = self.inner.data.write();
690        data.clear();
691
692        for item in store_data.items {
693            let vector_item = VectorItem {
694                id: item.id,
695                vector: item.vector,
696                metadata: item.metadata.into_iter().collect(),
697                content: item.content,
698            };
699            data.insert(vector_item.id.clone(), vector_item);
700        }
701
702        info!("Loaded {} vectors from {:?}", data.len(), self.path);
703        Ok(())
704    }
705}
706
707#[async_trait]
708impl VectorStore for FileVectorStore {
709    async fn add(
710        &self,
711        id: String,
712        vector: Vec<f32>,
713        metadata: HashMap<String, serde_json::Value>,
714    ) -> Layer3Result<bool> {
715        let result = self.inner.add(id, vector, metadata).await?;
716
717        if self.auto_persist && result {
718            self.persist_sync()?;
719        }
720
721        Ok(result)
722    }
723
724    async fn add_batch(&self, items: Vec<VectorItem>) -> Layer3Result<Vec<bool>> {
725        let results = self.inner.add_batch(items).await?;
726
727        if self.auto_persist && results.iter().any(|&r| r) {
728            self.persist_sync()?;
729        }
730
731        Ok(results)
732    }
733
734    async fn query(&self, vector: Vec<f32>, top_k: usize) -> Layer3Result<Vec<RetrievalResult>> {
735        self.inner.query(vector, top_k).await
736    }
737
738    async fn query_with_filter(
739        &self,
740        vector: Vec<f32>,
741        top_k: usize,
742        filter: Option<MetadataFilter>,
743    ) -> Layer3Result<Vec<RetrievalResult>> {
744        self.inner.query_with_filter(vector, top_k, filter).await
745    }
746
747    async fn delete(&self, id: &str) -> Layer3Result<bool> {
748        let result = self.inner.delete(id).await?;
749
750        if self.auto_persist && result {
751            self.persist_sync()?;
752        }
753
754        Ok(result)
755    }
756
757    async fn delete_batch(&self, ids: &[String]) -> Layer3Result<usize> {
758        let count = self.inner.delete_batch(ids).await?;
759
760        if self.auto_persist && count > 0 {
761            self.persist_sync()?;
762        }
763
764        Ok(count)
765    }
766
767    async fn get(&self, id: &str) -> Layer3Result<Option<VectorItem>> {
768        self.inner.get(id).await
769    }
770
771    async fn count(&self) -> Layer3Result<usize> {
772        self.inner.count().await
773    }
774
775    async fn clear(&self) -> Layer3Result<bool> {
776        let result = self.inner.clear().await?;
777
778        if self.auto_persist && result {
779            self.persist_sync()?;
780        }
781
782        Ok(result)
783    }
784
785    async fn persist(&self) -> Layer3Result<()> {
786        self.persist_sync()
787    }
788
789    async fn load(&self) -> Layer3Result<()> {
790        self.load_sync()
791    }
792}
793
794/// 文件向量存储工厂
795pub struct FileVectorStoreFactory;
796
797impl VectorStoreFactory for FileVectorStoreFactory {
798    fn create(&self, config: VectorStoreConfig) -> Layer3Result<Box<dyn VectorStore>> {
799        Ok(Box::new(FileVectorStore::new(config)?))
800    }
801}
802
#[cfg(test)]
mod tests {
    use super::*;

    /// `with_content` stores the supplied text on the item.
    #[test]
    fn test_vector_item_builder() {
        let built = VectorItem::new("test", vec![1.0, 2.0, 3.0]).with_content("test content");
        assert_eq!(built.content.as_deref(), Some("test content"));
    }

    /// The default config uses 1536 dimensions and the cosine metric.
    #[test]
    fn test_vector_store_config_default() {
        let VectorStoreConfig {
            dimension, metric, ..
        } = VectorStoreConfig::default();
        assert_eq!(dimension, 1536);
        assert_eq!(metric, DistanceMetric::Cosine);
    }

    /// Adding one vector succeeds and raises the count to one.
    #[tokio::test]
    async fn test_in_memory_vector_store_add() {
        let store = InMemoryVectorStore::in_memory();
        let outcome = store
            .add(String::from("id1"), vec![1.0, 2.0, 3.0], HashMap::new())
            .await;
        assert!(outcome.is_ok());
        assert_eq!(store.count().await.unwrap(), 1);
    }

    /// A top-2 query returns two results in strictly descending score order.
    #[tokio::test]
    async fn test_in_memory_vector_store_query() {
        let store = InMemoryVectorStore::in_memory();

        // Seed the store with three test vectors; only the first has metadata.
        let mut tagged = HashMap::new();
        tagged.insert("source".to_string(), serde_json::json!("test.txt"));

        let fixtures = vec![
            ("id1", vec![1.0, 0.0, 0.0], tagged),
            ("id2", vec![0.9, 0.1, 0.0], HashMap::new()),
            ("id3", vec![0.0, 1.0, 0.0], HashMap::new()),
        ];
        for (id, vector, metadata) in fixtures {
            store.add(id.to_string(), vector, metadata).await.unwrap();
        }

        // Query for similar vectors.
        let hits = store.query(vec![1.0, 0.0, 0.0], 2).await.unwrap();
        assert_eq!(hits.len(), 2);
        assert!(hits[0].score > hits[1].score);
    }

    /// Deleting an existing id reports success and empties the store.
    #[tokio::test]
    async fn test_in_memory_vector_store_delete() {
        let store = InMemoryVectorStore::in_memory();
        store
            .add(String::from("id1"), vec![1.0, 2.0, 3.0], HashMap::new())
            .await
            .unwrap();

        let was_removed = store.delete("id1").await.unwrap();
        assert!(was_removed);
        assert_eq!(store.count().await.unwrap(), 0);
    }

    /// Cosine similarity is 1 for identical and 0 for orthogonal vectors.
    #[test]
    fn test_cosine_similarity() {
        let config = VectorStoreConfig {
            metric: DistanceMetric::Cosine,
            ..Default::default()
        };
        let store = InMemoryVectorStore::new(config);

        // Identical vectors.
        let identical = store.compute_similarity(&[1.0, 0.0], &[1.0, 0.0]);
        assert!((identical - 1.0).abs() < 0.001);

        // Orthogonal vectors.
        let orthogonal = store.compute_similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!((orthogonal - 0.0).abs() < 0.001);
    }

    /// Exercises the `must`, `must_not` and `should` filter conditions.
    #[test]
    fn test_metadata_filter() {
        let mut metadata = HashMap::new();
        metadata.insert(String::from("type"), serde_json::json!("doc"));
        metadata.insert(String::from("lang"), serde_json::json!("en"));

        // `must` conditions.
        let must_doc = MetadataFilter::new().must("type", serde_json::json!("doc"));
        assert!(must_doc.matches(&metadata));

        let must_code = MetadataFilter::new().must("type", serde_json::json!("code"));
        assert!(!must_code.matches(&metadata));

        // `must_not` conditions.
        let not_code = MetadataFilter::new().must_not("type", serde_json::json!("code"));
        assert!(not_code.matches(&metadata));

        // `should` conditions (at least one has to match).
        // Note: a HashMap cannot hold duplicate keys, so distinct keys are used.
        let one_should_hit = MetadataFilter::new()
            .should("type", serde_json::json!("doc"))
            .should("lang", serde_json::json!("zh"));
        assert!(one_should_hit.matches(&metadata)); // type=doc matches

        // No `should` condition matches.
        let no_should_hit = MetadataFilter::new()
            .should("type", serde_json::json!("code"))
            .should("lang", serde_json::json!("zh"));
        assert!(!no_should_hit.matches(&metadata)); // type!=code and lang!=zh
    }

    /// Data persisted by one store instance can be loaded by another.
    #[tokio::test]
    async fn test_file_vector_store() {
        use tempfile::TempDir;

        // Use a temp directory that stays alive for the whole test.
        let temp_dir = TempDir::new().unwrap();
        let path = temp_dir.path().join("vector_store.json");
        let path_str = path.to_str().unwrap().to_string();

        let make_config = || VectorStoreConfig {
            path: Some(path_str.clone()),
            dimension: 128,
            metric: DistanceMetric::Cosine,
            index_type: IndexType::Flat,
        };

        let writer = FileVectorStore::new(make_config()).unwrap();

        // Add a vector.
        writer
            .add(String::from("id1"), vec![1.0; 128], HashMap::new())
            .await
            .unwrap();
        assert_eq!(writer.count().await.unwrap(), 1);

        // Persist.
        writer.persist().await.unwrap();

        // The file must exist now.
        assert!(path.exists());

        // Create a fresh store instance and load from the same file.
        let reader = FileVectorStore::new(make_config()).unwrap();
        reader.load().await.unwrap();
        assert_eq!(reader.count().await.unwrap(), 1);

        // Verify the contents.
        let restored = reader.get("id1").await.unwrap();
        assert!(restored.is_some());
    }
}
965
966// ============================================================================
967// Remote Vector Store Implementations (Pinecone, Chroma, Qdrant)
968// ============================================================================
969
/// Configuration for a remote vector store.
#[derive(Debug, Clone)]
pub struct RemoteVectorStoreConfig {
    /// API key.
    pub api_key: String,
    /// API endpoint URL.
    pub endpoint: String,
    /// Collection / index name.
    pub collection: String,
    /// Vector dimension.
    pub dimension: usize,
    /// Distance metric.
    pub metric: DistanceMetric,
    /// Connection pool size.
    pub pool_size: usize,
    /// Request timeout in seconds.
    pub timeout_secs: u64,
}
988
989impl RemoteVectorStoreConfig {
990    /// 从环境变量创建 Pinecone 配置
991    pub fn pinecone_from_env() -> Layer3Result<Self> {
992        let api_key = std::env::var("PINECONE_API_KEY")
993            .map_err(|_| anyhow::anyhow!("PINECONE_API_KEY not set"))?;
994        let endpoint = std::env::var("PINECONE_ENDPOINT")
995            .map_err(|_| anyhow::anyhow!("PINECONE_ENDPOINT not set"))?;
996        let collection = std::env::var("PINECONE_INDEX").unwrap_or_else(|_| "default".to_string());
997
998        Ok(Self {
999            api_key,
1000            endpoint,
1001            collection,
1002            dimension: 1536,
1003            metric: DistanceMetric::Cosine,
1004            pool_size: 10,
1005            timeout_secs: 30,
1006        })
1007    }
1008
1009    /// 从环境变量创建 Chroma 配置
1010    pub fn chroma_from_env() -> Layer3Result<Self> {
1011        let endpoint = std::env::var("CHROMA_ENDPOINT")
1012            .unwrap_or_else(|_| "http://localhost:8000".to_string());
1013        let collection =
1014            std::env::var("CHROMA_COLLECTION").unwrap_or_else(|_| "default".to_string());
1015        let api_key = std::env::var("CHROMA_API_KEY").unwrap_or_default();
1016
1017        Ok(Self {
1018            api_key,
1019            endpoint,
1020            collection,
1021            dimension: 1536,
1022            metric: DistanceMetric::Cosine,
1023            pool_size: 10,
1024            timeout_secs: 30,
1025        })
1026    }
1027
1028    /// 从环境变量创建 Qdrant 配置
1029    pub fn qdrant_from_env() -> Layer3Result<Self> {
1030        let endpoint = std::env::var("QDRANT_ENDPOINT")
1031            .unwrap_or_else(|_| "http://localhost:6333".to_string());
1032        let collection =
1033            std::env::var("QDRANT_COLLECTION").unwrap_or_else(|_| "default".to_string());
1034        let api_key = std::env::var("QDRANT_API_KEY").unwrap_or_default();
1035
1036        Ok(Self {
1037            api_key,
1038            endpoint,
1039            collection,
1040            dimension: 1536,
1041            metric: DistanceMetric::Cosine,
1042            pool_size: 10,
1043            timeout_secs: 30,
1044        })
1045    }
1046}
1047
1048// ============================================================================
1049// Pinecone Implementation
1050// ============================================================================
1051
1052/// Pinecone 向量存储
1053///
1054/// 使用 Pinecone 云服务进行向量存储和检索。
1055pub struct PineconeVectorStore {
1056    client: reqwest::Client,
1057    config: RemoteVectorStoreConfig,
1058}
1059
1060impl PineconeVectorStore {
1061    pub fn new(config: RemoteVectorStoreConfig) -> Layer3Result<Self> {
1062        let client = reqwest::Client::builder()
1063            .timeout(std::time::Duration::from_secs(config.timeout_secs))
1064            .pool_max_idle_per_host(config.pool_size)
1065            .build()
1066            .map_err(|e| anyhow::anyhow!("Failed to create client: {}", e))?;
1067
1068        Ok(Self { client, config })
1069    }
1070
1071    fn build_url(&self, path: &str) -> String {
1072        format!("{}/vectors/{}", self.config.endpoint, path)
1073    }
1074}
1075
1076#[async_trait]
1077impl VectorStore for PineconeVectorStore {
1078    async fn add(
1079        &self,
1080        id: String,
1081        vector: Vec<f32>,
1082        metadata: HashMap<String, serde_json::Value>,
1083    ) -> Layer3Result<bool> {
1084        self.add_batch(vec![VectorItem {
1085            id,
1086            vector,
1087            metadata,
1088            content: None,
1089        }])
1090        .await?;
1091        Ok(true)
1092    }
1093
1094    async fn add_batch(&self, items: Vec<VectorItem>) -> Layer3Result<Vec<bool>> {
1095        if items.is_empty() {
1096            return Ok(Vec::new());
1097        }
1098
1099        let vectors: Vec<serde_json::Value> = items
1100            .iter()
1101            .map(|item| {
1102                serde_json::json!({
1103                    "id": item.id,
1104                    "values": item.vector,
1105                    "metadata": item.metadata,
1106                })
1107            })
1108            .collect();
1109
1110        let body = serde_json::json!({
1111            "vectors": vectors,
1112            "namespace": self.config.collection,
1113        });
1114
1115        let response = self
1116            .client
1117            .post(self.build_url("upsert"))
1118            .header("Api-Key", &self.config.api_key)
1119            .header("Content-Type", "application/json")
1120            .json(&body)
1121            .send()
1122            .await
1123            .map_err(|e| anyhow::anyhow!("Pinecone request failed: {}", e))?;
1124
1125        if !response.status().is_success() {
1126            let status = response.status();
1127            let text = response.text().await.unwrap_or_default();
1128            return Err(anyhow::anyhow!(
1129                "Pinecone upsert failed: {} - {}",
1130                status,
1131                text
1132            ));
1133        }
1134
1135        Ok(items.iter().map(|_| true).collect())
1136    }
1137
1138    async fn query(&self, vector: Vec<f32>, top_k: usize) -> Layer3Result<Vec<RetrievalResult>> {
1139        let body = serde_json::json!({
1140            "vector": vector,
1141            "topK": top_k,
1142            "namespace": self.config.collection,
1143            "includeMetadata": true,
1144            "includeValues": false,
1145        });
1146
1147        let response = self
1148            .client
1149            .post(self.build_url("query"))
1150            .header("Api-Key", &self.config.api_key)
1151            .header("Content-Type", "application/json")
1152            .json(&body)
1153            .send()
1154            .await
1155            .map_err(|e| anyhow::anyhow!("Pinecone query failed: {}", e))?;
1156
1157        if !response.status().is_success() {
1158            let status = response.status();
1159            let text = response.text().await.unwrap_or_default();
1160            return Err(anyhow::anyhow!(
1161                "Pinecone query failed: {} - {}",
1162                status,
1163                text
1164            ));
1165        }
1166
1167        let json: serde_json::Value = response
1168            .json()
1169            .await
1170            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
1171
1172        let results = json["matches"]
1173            .as_array()
1174            .map(|arr| {
1175                arr.iter()
1176                    .filter_map(|m| {
1177                        let doc_id = m["id"].as_str()?.to_string();
1178                        let score = m["score"].as_f64()? as f32;
1179                        let metadata: HashMap<String, serde_json::Value> = m["metadata"]
1180                            .as_object()
1181                            .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
1182                            .unwrap_or_default();
1183                        let content = metadata
1184                            .get("content")
1185                            .and_then(|v| v.as_str())
1186                            .map(String::from)
1187                            .unwrap_or_default();
1188                        let source = metadata
1189                            .get("source")
1190                            .and_then(|v| v.as_str())
1191                            .map(String::from);
1192
1193                        Some(RetrievalResult {
1194                            doc_id,
1195                            content,
1196                            score,
1197                            metadata,
1198                            source,
1199                        })
1200                    })
1201                    .collect()
1202            })
1203            .unwrap_or_default();
1204
1205        Ok(results)
1206    }
1207
1208    async fn delete(&self, id: &str) -> Layer3Result<bool> {
1209        let body = serde_json::json!({
1210            "ids": [id],
1211            "namespace": self.config.collection,
1212        });
1213
1214        let response = self
1215            .client
1216            .post(self.build_url("delete"))
1217            .header("Api-Key", &self.config.api_key)
1218            .header("Content-Type", "application/json")
1219            .json(&body)
1220            .send()
1221            .await
1222            .map_err(|e| anyhow::anyhow!("Pinecone delete failed: {}", e))?;
1223
1224        Ok(response.status().is_success())
1225    }
1226
1227    async fn delete_batch(&self, ids: &[String]) -> Layer3Result<usize> {
1228        let body = serde_json::json!({
1229            "ids": ids,
1230            "namespace": self.config.collection,
1231        });
1232
1233        let response = self
1234            .client
1235            .post(self.build_url("delete"))
1236            .header("Api-Key", &self.config.api_key)
1237            .header("Content-Type", "application/json")
1238            .json(&body)
1239            .send()
1240            .await
1241            .map_err(|e| anyhow::anyhow!("Pinecone delete failed: {}", e))?;
1242
1243        if response.status().is_success() {
1244            Ok(ids.len())
1245        } else {
1246            Ok(0)
1247        }
1248    }
1249
1250    async fn get(&self, id: &str) -> Layer3Result<Option<VectorItem>> {
1251        let body = serde_json::json!({
1252            "ids": [id],
1253            "namespace": self.config.collection,
1254        });
1255
1256        let response = self
1257            .client
1258            .post(self.build_url("fetch"))
1259            .header("Api-Key", &self.config.api_key)
1260            .header("Content-Type", "application/json")
1261            .json(&body)
1262            .send()
1263            .await
1264            .map_err(|e| anyhow::anyhow!("Pinecone fetch failed: {}", e))?;
1265
1266        if !response.status().is_success() {
1267            return Ok(None);
1268        }
1269
1270        let json: serde_json::Value = response
1271            .json()
1272            .await
1273            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
1274
1275        if let Some(vectors) = json["vectors"].as_object() {
1276            if let Some(v) = vectors.get(id) {
1277                let vector = v["values"]
1278                    .as_array()
1279                    .map(|arr| {
1280                        arr.iter()
1281                            .filter_map(|x| x.as_f64().map(|f| f as f32))
1282                            .collect()
1283                    })
1284                    .unwrap_or_default();
1285                let metadata = v["metadata"]
1286                    .as_object()
1287                    .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
1288                    .unwrap_or_default();
1289
1290                return Ok(Some(VectorItem {
1291                    id: id.to_string(),
1292                    vector,
1293                    metadata,
1294                    content: None,
1295                }));
1296            }
1297        }
1298
1299        Ok(None)
1300    }
1301
1302    async fn count(&self) -> Layer3Result<usize> {
1303        let body = serde_json::json!({
1304            "namespace": self.config.collection,
1305        });
1306
1307        let response = self
1308            .client
1309            .post(self.build_url("describeIndexStats"))
1310            .header("Api-Key", &self.config.api_key)
1311            .header("Content-Type", "application/json")
1312            .json(&body)
1313            .send()
1314            .await
1315            .map_err(|e| anyhow::anyhow!("Pinecone stats failed: {}", e))?;
1316
1317        if !response.status().is_success() {
1318            return Ok(0);
1319        }
1320
1321        let json: serde_json::Value = response.json().await.unwrap_or_default();
1322        let count = json["dimension"]["totalVectorCount"].as_u64().unwrap_or(0) as usize;
1323        Ok(count)
1324    }
1325
1326    async fn clear(&self) -> Layer3Result<bool> {
1327        let body = serde_json::json!({
1328            "deleteAll": true,
1329            "namespace": self.config.collection,
1330        });
1331
1332        let response = self
1333            .client
1334            .post(self.build_url("delete"))
1335            .header("Api-Key", &self.config.api_key)
1336            .header("Content-Type", "application/json")
1337            .json(&body)
1338            .send()
1339            .await
1340            .map_err(|e| anyhow::anyhow!("Pinecone clear failed: {}", e))?;
1341
1342        Ok(response.status().is_success())
1343    }
1344}
1345
1346/// Pinecone 向量存储工厂
1347pub struct PineconeVectorStoreFactory;
1348
1349impl VectorStoreFactory for PineconeVectorStoreFactory {
1350    fn create(&self, _config: VectorStoreConfig) -> Layer3Result<Box<dyn VectorStore>> {
1351        let remote_config = RemoteVectorStoreConfig::pinecone_from_env()?;
1352        Ok(Box::new(PineconeVectorStore::new(remote_config)?))
1353    }
1354}
1355
1356// ============================================================================
1357// Chroma Implementation
1358// ============================================================================
1359
1360/// Chroma 向量存储
1361///
1362/// 使用 Chroma 本地或云服务进行向量存储和检索。
1363pub struct ChromaVectorStore {
1364    client: reqwest::Client,
1365    config: RemoteVectorStoreConfig,
1366}
1367
1368impl ChromaVectorStore {
1369    pub fn new(config: RemoteVectorStoreConfig) -> Layer3Result<Self> {
1370        let client = reqwest::Client::builder()
1371            .timeout(std::time::Duration::from_secs(config.timeout_secs))
1372            .pool_max_idle_per_host(config.pool_size)
1373            .build()
1374            .map_err(|e| anyhow::anyhow!("Failed to create client: {}", e))?;
1375
1376        Ok(Self { client, config })
1377    }
1378
1379    fn build_url(&self, path: &str) -> String {
1380        format!("{}/api/v1{}", self.config.endpoint, path)
1381    }
1382
1383    async fn ensure_collection(&self) -> Layer3Result<()> {
1384        let body = serde_json::json!({
1385            "name": self.config.collection,
1386        });
1387
1388        let _ = self
1389            .client
1390            .post(self.build_url("/collections"))
1391            .header("Content-Type", "application/json")
1392            .json(&body)
1393            .send()
1394            .await;
1395
1396        Ok(())
1397    }
1398}
1399
1400#[async_trait]
1401impl VectorStore for ChromaVectorStore {
1402    async fn add(
1403        &self,
1404        id: String,
1405        vector: Vec<f32>,
1406        metadata: HashMap<String, serde_json::Value>,
1407    ) -> Layer3Result<bool> {
1408        self.add_batch(vec![VectorItem {
1409            id,
1410            vector,
1411            metadata,
1412            content: None,
1413        }])
1414        .await?;
1415        Ok(true)
1416    }
1417
1418    async fn add_batch(&self, items: Vec<VectorItem>) -> Layer3Result<Vec<bool>> {
1419        if items.is_empty() {
1420            return Ok(Vec::new());
1421        }
1422
1423        self.ensure_collection().await?;
1424
1425        let ids: Vec<String> = items.iter().map(|i| i.id.clone()).collect();
1426        let vectors: Vec<Vec<f32>> = items.iter().map(|i| i.vector.clone()).collect();
1427        let metadatas: Vec<HashMap<String, serde_json::Value>> =
1428            items.iter().map(|i| i.metadata.clone()).collect();
1429
1430        let body = serde_json::json!({
1431            "ids": ids,
1432            "embeddings": vectors,
1433            "metadatas": metadatas,
1434        });
1435
1436        let url = self.build_url(&format!("/collections/{}/add", self.config.collection));
1437        let response = self
1438            .client
1439            .post(url)
1440            .header("Content-Type", "application/json")
1441            .json(&body)
1442            .send()
1443            .await
1444            .map_err(|e| anyhow::anyhow!("Chroma add failed: {}", e))?;
1445
1446        if !response.status().is_success() {
1447            let status = response.status();
1448            let text = response.text().await.unwrap_or_default();
1449            return Err(anyhow::anyhow!("Chroma add failed: {} - {}", status, text));
1450        }
1451
1452        Ok(items.iter().map(|_| true).collect())
1453    }
1454
1455    async fn query(&self, vector: Vec<f32>, top_k: usize) -> Layer3Result<Vec<RetrievalResult>> {
1456        let body = serde_json::json!({
1457            "query_embeddings": [vector],
1458            "n_results": top_k,
1459            "include": ["metadatas", "documents", "distances"],
1460        });
1461
1462        let url = self.build_url(&format!("/collections/{}/query", self.config.collection));
1463        let response = self
1464            .client
1465            .post(url)
1466            .header("Content-Type", "application/json")
1467            .json(&body)
1468            .send()
1469            .await
1470            .map_err(|e| anyhow::anyhow!("Chroma query failed: {}", e))?;
1471
1472        if !response.status().is_success() {
1473            let status = response.status();
1474            let text = response.text().await.unwrap_or_default();
1475            return Err(anyhow::anyhow!(
1476                "Chroma query failed: {} - {}",
1477                status,
1478                text
1479            ));
1480        }
1481
1482        let json: serde_json::Value = response
1483            .json()
1484            .await
1485            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
1486
1487        let ids = json["ids"][0].as_array().cloned().unwrap_or_default();
1488        let distances = json["distances"][0].as_array().cloned().unwrap_or_default();
1489        let metadatas = json["metadatas"][0].as_array().cloned().unwrap_or_default();
1490        let documents = json["documents"][0].as_array().cloned().unwrap_or_default();
1491
1492        let results: Vec<RetrievalResult> = ids
1493            .iter()
1494            .enumerate()
1495            .filter_map(|(i, id)| {
1496                let doc_id = id.as_str()?.to_string();
1497                let distance = distances.get(i)?.as_f64()? as f32;
1498                let score = 1.0 / (1.0 + distance); // Convert distance to similarity
1499                let metadata: HashMap<String, serde_json::Value> = metadatas
1500                    .get(i)?
1501                    .as_object()
1502                    .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
1503                    .unwrap_or_default();
1504                let content = documents
1505                    .get(i)?
1506                    .as_str()
1507                    .map(String::from)
1508                    .unwrap_or_default();
1509                let source = metadata
1510                    .get("source")
1511                    .and_then(|v| v.as_str())
1512                    .map(String::from);
1513
1514                Some(RetrievalResult {
1515                    doc_id,
1516                    content,
1517                    score,
1518                    metadata,
1519                    source,
1520                })
1521            })
1522            .collect();
1523
1524        Ok(results)
1525    }
1526
1527    async fn delete(&self, id: &str) -> Layer3Result<bool> {
1528        let body = serde_json::json!({
1529            "ids": [id],
1530        });
1531
1532        let url = self.build_url(&format!("/collections/{}/delete", self.config.collection));
1533        let response = self
1534            .client
1535            .post(url)
1536            .header("Content-Type", "application/json")
1537            .json(&body)
1538            .send()
1539            .await
1540            .map_err(|e| anyhow::anyhow!("Chroma delete failed: {}", e))?;
1541
1542        Ok(response.status().is_success())
1543    }
1544
1545    async fn delete_batch(&self, ids: &[String]) -> Layer3Result<usize> {
1546        let body = serde_json::json!({
1547            "ids": ids,
1548        });
1549
1550        let url = self.build_url(&format!("/collections/{}/delete", self.config.collection));
1551        let response = self
1552            .client
1553            .post(url)
1554            .header("Content-Type", "application/json")
1555            .json(&body)
1556            .send()
1557            .await
1558            .map_err(|e| anyhow::anyhow!("Chroma delete failed: {}", e))?;
1559
1560        if response.status().is_success() {
1561            Ok(ids.len())
1562        } else {
1563            Ok(0)
1564        }
1565    }
1566
1567    async fn get(&self, id: &str) -> Layer3Result<Option<VectorItem>> {
1568        let body = serde_json::json!({
1569            "ids": [id],
1570            "include": ["embeddings", "metadatas"],
1571        });
1572
1573        let url = self.build_url(&format!("/collections/{}/get", self.config.collection));
1574        let response = self
1575            .client
1576            .post(url)
1577            .header("Content-Type", "application/json")
1578            .json(&body)
1579            .send()
1580            .await
1581            .map_err(|e| anyhow::anyhow!("Chroma get failed: {}", e))?;
1582
1583        if !response.status().is_success() {
1584            return Ok(None);
1585        }
1586
1587        let json: serde_json::Value = response
1588            .json()
1589            .await
1590            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
1591
1592        if let Some(ids) = json["ids"].as_array() {
1593            if !ids.is_empty() {
1594                let vector = json["embeddings"]
1595                    .as_array()
1596                    .and_then(|arr| arr.first())
1597                    .and_then(|arr| {
1598                        arr.as_array().map(|a| {
1599                            a.iter()
1600                                .filter_map(|x| x.as_f64().map(|f| f as f32))
1601                                .collect()
1602                        })
1603                    })
1604                    .unwrap_or_default();
1605                let metadata = json["metadatas"]
1606                    .as_array()
1607                    .and_then(|arr| arr.first())
1608                    .and_then(|obj| {
1609                        obj.as_object()
1610                            .map(|o| o.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
1611                    })
1612                    .unwrap_or_default();
1613
1614                return Ok(Some(VectorItem {
1615                    id: id.to_string(),
1616                    vector,
1617                    metadata,
1618                    content: None,
1619                }));
1620            }
1621        }
1622
1623        Ok(None)
1624    }
1625
1626    async fn count(&self) -> Layer3Result<usize> {
1627        let url = self.build_url(&format!("/collections/{}/count", self.config.collection));
1628        let response = self
1629            .client
1630            .get(url)
1631            .send()
1632            .await
1633            .map_err(|e| anyhow::anyhow!("Chroma count failed: {}", e))?;
1634
1635        if !response.status().is_success() {
1636            return Ok(0);
1637        }
1638
1639        let json: serde_json::Value = response.json().await.unwrap_or_default();
1640        Ok(json.as_u64().unwrap_or(0) as usize)
1641    }
1642
1643    async fn clear(&self) -> Layer3Result<bool> {
1644        let url = self.build_url(&format!("/collections/{}", self.config.collection));
1645        let response = self
1646            .client
1647            .delete(url)
1648            .send()
1649            .await
1650            .map_err(|e| anyhow::anyhow!("Chroma clear failed: {}", e))?;
1651
1652        if response.status().is_success() {
1653            self.ensure_collection().await?;
1654            Ok(true)
1655        } else {
1656            Ok(false)
1657        }
1658    }
1659}
1660
1661/// Chroma 向量存储工厂
1662pub struct ChromaVectorStoreFactory;
1663
1664impl VectorStoreFactory for ChromaVectorStoreFactory {
1665    fn create(&self, _config: VectorStoreConfig) -> Layer3Result<Box<dyn VectorStore>> {
1666        let remote_config = RemoteVectorStoreConfig::chroma_from_env()?;
1667        Ok(Box::new(ChromaVectorStore::new(remote_config)?))
1668    }
1669}
1670
1671// ============================================================================
1672// Qdrant Implementation
1673// ============================================================================
1674
1675/// Qdrant 向量存储
1676///
1677/// 使用 Qdrant 云服务或自托管实例进行向量存储和检索。
1678pub struct QdrantVectorStore {
1679    client: reqwest::Client,
1680    config: RemoteVectorStoreConfig,
1681}
1682
1683impl QdrantVectorStore {
1684    pub fn new(config: RemoteVectorStoreConfig) -> Layer3Result<Self> {
1685        let client = reqwest::Client::builder()
1686            .timeout(std::time::Duration::from_secs(config.timeout_secs))
1687            .pool_max_idle_per_host(config.pool_size)
1688            .build()
1689            .map_err(|e| anyhow::anyhow!("Failed to create client: {}", e))?;
1690
1691        Ok(Self { client, config })
1692    }
1693
1694    fn build_url(&self, path: &str) -> String {
1695        format!(
1696            "{}/collections/{}{}",
1697            self.config.endpoint, self.config.collection, path
1698        )
1699    }
1700
1701    async fn ensure_collection(&self) -> Layer3Result<()> {
1702        let url = format!(
1703            "{}/collections/{}",
1704            self.config.endpoint, self.config.collection
1705        );
1706
1707        // Check if collection exists
1708        let response = self
1709            .client
1710            .get(&url)
1711            .send()
1712            .await
1713            .map_err(|e| anyhow::anyhow!("Qdrant check failed: {}", e))?;
1714
1715        if response.status().as_u16() == 404 {
1716            // Create collection
1717            let body = serde_json::json!({
1718                "vectors": {
1719                    "size": self.config.dimension,
1720                    "distance": match self.config.metric {
1721                        DistanceMetric::Cosine => "Cosine",
1722                        DistanceMetric::Euclidean => "Euclid",
1723                        DistanceMetric::DotProduct => "Dot",
1724                        DistanceMetric::Manhattan => "Manhattan",
1725                    },
1726                },
1727            });
1728
1729            let _ = self
1730                .client
1731                .put(&url)
1732                .header("Content-Type", "application/json")
1733                .json(&body)
1734                .send()
1735                .await;
1736        }
1737
1738        Ok(())
1739    }
1740}
1741
1742#[async_trait]
1743impl VectorStore for QdrantVectorStore {
1744    async fn add(
1745        &self,
1746        id: String,
1747        vector: Vec<f32>,
1748        metadata: HashMap<String, serde_json::Value>,
1749    ) -> Layer3Result<bool> {
1750        self.add_batch(vec![VectorItem {
1751            id,
1752            vector,
1753            metadata,
1754            content: None,
1755        }])
1756        .await?;
1757        Ok(true)
1758    }
1759
1760    async fn add_batch(&self, items: Vec<VectorItem>) -> Layer3Result<Vec<bool>> {
1761        if items.is_empty() {
1762            return Ok(Vec::new());
1763        }
1764
1765        self.ensure_collection().await?;
1766
1767        let points: Vec<serde_json::Value> = items
1768            .iter()
1769            .map(|item| {
1770                serde_json::json!({
1771                    "id": item.id,
1772                    "vector": item.vector,
1773                    "payload": item.metadata,
1774                })
1775            })
1776            .collect();
1777
1778        let body = serde_json::json!({
1779            "points": points,
1780        });
1781
1782        let url = self.build_url("/points?wait=true");
1783        let mut request = self
1784            .client
1785            .put(&url)
1786            .header("Content-Type", "application/json")
1787            .json(&body);
1788
1789        if !self.config.api_key.is_empty() {
1790            request = request.header("api-key", &self.config.api_key);
1791        }
1792
1793        let response = request
1794            .send()
1795            .await
1796            .map_err(|e| anyhow::anyhow!("Qdrant upsert failed: {}", e))?;
1797
1798        if !response.status().is_success() {
1799            let status = response.status();
1800            let text = response.text().await.unwrap_or_default();
1801            return Err(anyhow::anyhow!(
1802                "Qdrant upsert failed: {} - {}",
1803                status,
1804                text
1805            ));
1806        }
1807
1808        Ok(items.iter().map(|_| true).collect())
1809    }
1810
1811    async fn query(&self, vector: Vec<f32>, top_k: usize) -> Layer3Result<Vec<RetrievalResult>> {
1812        self.ensure_collection().await?;
1813
1814        let body = serde_json::json!({
1815            "vector": vector,
1816            "limit": top_k,
1817            "with_payload": true,
1818        });
1819
1820        let url = self.build_url("/points/search");
1821        let mut request = self
1822            .client
1823            .post(&url)
1824            .header("Content-Type", "application/json")
1825            .json(&body);
1826
1827        if !self.config.api_key.is_empty() {
1828            request = request.header("api-key", &self.config.api_key);
1829        }
1830
1831        let response = request
1832            .send()
1833            .await
1834            .map_err(|e| anyhow::anyhow!("Qdrant search failed: {}", e))?;
1835
1836        if !response.status().is_success() {
1837            let status = response.status();
1838            let text = response.text().await.unwrap_or_default();
1839            return Err(anyhow::anyhow!(
1840                "Qdrant search failed: {} - {}",
1841                status,
1842                text
1843            ));
1844        }
1845
1846        let json: serde_json::Value = response
1847            .json()
1848            .await
1849            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
1850
1851        let results = json["result"]
1852            .as_array()
1853            .map(|arr| {
1854                arr.iter()
1855                    .filter_map(|r| {
1856                        let doc_id = r["id"].as_str()?.to_string();
1857                        let score = r["score"].as_f64()? as f32;
1858                        let metadata: HashMap<String, serde_json::Value> = r["payload"]
1859                            .as_object()
1860                            .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
1861                            .unwrap_or_default();
1862                        let content = metadata
1863                            .get("content")
1864                            .and_then(|v| v.as_str())
1865                            .map(String::from)
1866                            .unwrap_or_default();
1867                        let source = metadata
1868                            .get("source")
1869                            .and_then(|v| v.as_str())
1870                            .map(String::from);
1871
1872                        Some(RetrievalResult {
1873                            doc_id,
1874                            content,
1875                            score,
1876                            metadata,
1877                            source,
1878                        })
1879                    })
1880                    .collect()
1881            })
1882            .unwrap_or_default();
1883
1884        Ok(results)
1885    }
1886
1887    async fn delete(&self, id: &str) -> Layer3Result<bool> {
1888        let body = serde_json::json!({
1889            "points": [id],
1890        });
1891
1892        let url = self.build_url("/points/delete?wait=true");
1893        let mut request = self
1894            .client
1895            .post(&url)
1896            .header("Content-Type", "application/json")
1897            .json(&body);
1898
1899        if !self.config.api_key.is_empty() {
1900            request = request.header("api-key", &self.config.api_key);
1901        }
1902
1903        let response = request
1904            .send()
1905            .await
1906            .map_err(|e| anyhow::anyhow!("Qdrant delete failed: {}", e))?;
1907
1908        Ok(response.status().is_success())
1909    }
1910
1911    async fn delete_batch(&self, ids: &[String]) -> Layer3Result<usize> {
1912        let body = serde_json::json!({
1913            "points": ids,
1914        });
1915
1916        let url = self.build_url("/points/delete?wait=true");
1917        let mut request = self
1918            .client
1919            .post(&url)
1920            .header("Content-Type", "application/json")
1921            .json(&body);
1922
1923        if !self.config.api_key.is_empty() {
1924            request = request.header("api-key", &self.config.api_key);
1925        }
1926
1927        let response = request
1928            .send()
1929            .await
1930            .map_err(|e| anyhow::anyhow!("Qdrant delete failed: {}", e))?;
1931
1932        if response.status().is_success() {
1933            Ok(ids.len())
1934        } else {
1935            Ok(0)
1936        }
1937    }
1938
1939    async fn get(&self, id: &str) -> Layer3Result<Option<VectorItem>> {
1940        let body = serde_json::json!({
1941            "ids": [id],
1942            "with_vector": true,
1943            "with_payload": true,
1944        });
1945
1946        let url = self.build_url("/points");
1947        let mut request = self
1948            .client
1949            .post(&url)
1950            .header("Content-Type", "application/json")
1951            .json(&body);
1952
1953        if !self.config.api_key.is_empty() {
1954            request = request.header("api-key", &self.config.api_key);
1955        }
1956
1957        let response = request
1958            .send()
1959            .await
1960            .map_err(|e| anyhow::anyhow!("Qdrant get failed: {}", e))?;
1961
1962        if !response.status().is_success() {
1963            return Ok(None);
1964        }
1965
1966        let json: serde_json::Value = response
1967            .json()
1968            .await
1969            .map_err(|e| anyhow::anyhow!("Failed to parse response: {}", e))?;
1970
1971        if let Some(result) = json["result"].as_array() {
1972            if let Some(point) = result.first() {
1973                let vector = point["vector"]
1974                    .as_array()
1975                    .map(|arr| {
1976                        arr.iter()
1977                            .filter_map(|x| x.as_f64().map(|f| f as f32))
1978                            .collect()
1979                    })
1980                    .unwrap_or_default();
1981                let metadata = point["payload"]
1982                    .as_object()
1983                    .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
1984                    .unwrap_or_default();
1985
1986                return Ok(Some(VectorItem {
1987                    id: id.to_string(),
1988                    vector,
1989                    metadata,
1990                    content: None,
1991                }));
1992            }
1993        }
1994
1995        Ok(None)
1996    }
1997
1998    async fn count(&self) -> Layer3Result<usize> {
1999        let url = self.build_url("");
2000        let mut request = self.client.get(&url);
2001
2002        if !self.config.api_key.is_empty() {
2003            request = request.header("api-key", &self.config.api_key);
2004        }
2005
2006        let response = request
2007            .send()
2008            .await
2009            .map_err(|e| anyhow::anyhow!("Qdrant count failed: {}", e))?;
2010
2011        if !response.status().is_success() {
2012            return Ok(0);
2013        }
2014
2015        let json: serde_json::Value = response.json().await.unwrap_or_default();
2016        let count = json["result"]["points_count"].as_u64().unwrap_or(0) as usize;
2017        Ok(count)
2018    }
2019
2020    async fn clear(&self) -> Layer3Result<bool> {
2021        let url = self.build_url("/points/delete?wait=true");
2022        let body = serde_json::json!({
2023            "filter": {},
2024        });
2025
2026        let mut request = self
2027            .client
2028            .post(&url)
2029            .header("Content-Type", "application/json")
2030            .json(&body);
2031
2032        if !self.config.api_key.is_empty() {
2033            request = request.header("api-key", &self.config.api_key);
2034        }
2035
2036        let response = request
2037            .send()
2038            .await
2039            .map_err(|e| anyhow::anyhow!("Qdrant clear failed: {}", e))?;
2040
2041        Ok(response.status().is_success())
2042    }
2043}
2044
/// Factory for Qdrant-backed vector stores.
///
/// The type carries no state, so it is cheap to copy and can be created with
/// `Default::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct QdrantVectorStoreFactory;
2047
2048impl VectorStoreFactory for QdrantVectorStoreFactory {
2049    fn create(&self, _config: VectorStoreConfig) -> Layer3Result<Box<dyn VectorStore>> {
2050        let remote_config = RemoteVectorStoreConfig::qdrant_from_env()?;
2051        Ok(Box::new(QdrantVectorStore::new(remote_config)?))
2052    }
2053}
2054
2055// ============================================================================
2056// Unified Vector Store Factory
2057// ============================================================================
2058
2059/// 统一向量存储工厂
2060pub struct UnifiedVectorStoreFactory {
2061    store_type: VectorStoreType,
2062}
2063
/// Vector store backend selector.
///
/// Variants can be compared and hashed, so callers can branch on the backend
/// or use it as a map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum VectorStoreType {
    /// Selects `InMemoryVectorStore`.
    InMemory,
    /// Selects `FileVectorStore`.
    File,
    /// Selects `PineconeVectorStore`.
    Pinecone,
    /// Selects `ChromaVectorStore`.
    Chroma,
    /// Selects `QdrantVectorStore`.
    Qdrant,
}
2073
2074impl UnifiedVectorStoreFactory {
2075    pub fn new(store_type: VectorStoreType) -> Self {
2076        Self { store_type }
2077    }
2078
2079    pub fn create(&self, config: VectorStoreConfig) -> Layer3Result<Box<dyn VectorStore>> {
2080        match self.store_type {
2081            VectorStoreType::InMemory => Ok(Box::new(InMemoryVectorStore::new(config))),
2082            VectorStoreType::File => Ok(Box::new(FileVectorStore::new(config)?)),
2083            VectorStoreType::Pinecone => {
2084                let remote_config = RemoteVectorStoreConfig::pinecone_from_env()?;
2085                Ok(Box::new(PineconeVectorStore::new(remote_config)?))
2086            }
2087            VectorStoreType::Chroma => {
2088                let remote_config = RemoteVectorStoreConfig::chroma_from_env()?;
2089                Ok(Box::new(ChromaVectorStore::new(remote_config)?))
2090            }
2091            VectorStoreType::Qdrant => {
2092                let remote_config = RemoteVectorStoreConfig::qdrant_from_env()?;
2093                Ok(Box::new(QdrantVectorStore::new(remote_config)?))
2094            }
2095        }
2096    }
2097}
2098
2099impl VectorStoreFactory for UnifiedVectorStoreFactory {
2100    fn create(&self, config: VectorStoreConfig) -> Layer3Result<Box<dyn VectorStore>> {
2101        self.create(config)
2102    }
2103}