sql_rs/vector/
collection.rs1use super::{DistanceMetric, HnswIndex, SearchResult, Vector};
2use crate::{Result, SqlRsError};
3use std::collections::HashMap;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct VectorCollection {
8 name: String,
9 dimension: usize,
10 metric: DistanceMetric,
11 index: HnswIndex,
12 vectors: HashMap<String, Vector>,
13 id_mapping: HashMap<usize, String>,
14}
15
16impl VectorCollection {
17 pub fn new(name: impl Into<String>, dimension: usize, metric: DistanceMetric) -> Self {
18 Self {
19 name: name.into(),
20 dimension,
21 metric,
22 index: HnswIndex::new(metric),
23 vectors: HashMap::new(),
24 id_mapping: HashMap::new(),
25 }
26 }
27
28 pub fn add(&mut self, vector: Vector) -> Result<()> {
29 if vector.dimension() != self.dimension {
30 return Err(SqlRsError::Vector(format!(
31 "Expected dimension {}, got {}",
32 self.dimension,
33 vector.dimension()
34 )));
35 }
36
37 let internal_id = self.index.add(vector.embedding.clone());
38 self.id_mapping.insert(internal_id, vector.id.clone());
39 self.vectors.insert(vector.id.clone(), vector);
40
41 Ok(())
42 }
43
44 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
45 if query.len() != self.dimension {
46 return Err(SqlRsError::Vector(format!(
47 "Expected dimension {}, got {}",
48 self.dimension,
49 query.len()
50 )));
51 }
52
53 let results = self.index.search(query, k);
54
55 let mut search_results = Vec::new();
56 for (internal_id, distance) in results {
57 if let Some(vector_id) = self.id_mapping.get(&internal_id) {
58 if let Some(vector) = self.vectors.get(vector_id) {
59 search_results.push(SearchResult {
60 id: vector.id.clone(),
61 distance,
62 metadata: vector.metadata.clone(),
63 });
64 }
65 }
66 }
67
68 Ok(search_results)
69 }
70
71 pub fn get(&self, id: &str) -> Option<&Vector> {
72 self.vectors.get(id)
73 }
74
75 pub fn len(&self) -> usize {
76 self.vectors.len()
77 }
78
79 pub fn is_empty(&self) -> bool {
80 self.vectors.is_empty()
81 }
82
83 pub fn name(&self) -> &str {
84 &self.name
85 }
86
87 pub fn dimension(&self) -> usize {
88 self.dimension
89 }
90
91 pub fn metric(&self) -> DistanceMetric {
92 self.metric
93 }
94}