ruvector_core/index/
flat.rs1use crate::distance::distance;
4use crate::error::Result;
5use crate::index::VectorIndex;
6use crate::types::{DistanceMetric, SearchResult, VectorId};
7use dashmap::DashMap;
8
9#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
10use rayon::prelude::*;
11
12pub struct FlatIndex {
14 vectors: DashMap<VectorId, Vec<f32>>,
15 metric: DistanceMetric,
16 _dimensions: usize,
17}
18
19impl FlatIndex {
20 pub fn new(dimensions: usize, metric: DistanceMetric) -> Self {
22 Self {
23 vectors: DashMap::new(),
24 metric,
25 _dimensions: dimensions,
26 }
27 }
28}
29
30impl VectorIndex for FlatIndex {
31 fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()> {
32 self.vectors.insert(id, vector);
33 Ok(())
34 }
35
36 fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
37 if k == 0 {
38 return Ok(vec![]);
39 }
40
41 #[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
43 let mut results: Vec<_> = self
44 .vectors
45 .iter()
46 .par_bridge()
47 .map(|entry| {
48 let id = entry.key().clone();
49 let vector = entry.value();
50 let dist = distance(query, vector, self.metric)?;
51 Ok((id, dist))
52 })
53 .collect::<Result<Vec<_>>>()?;
54
55 #[cfg(any(not(feature = "parallel"), target_arch = "wasm32"))]
56 let mut results: Vec<_> = self
57 .vectors
58 .iter()
59 .map(|entry| {
60 let id = entry.key().clone();
61 let vector = entry.value();
62 let dist = distance(query, vector, self.metric)?;
63 Ok((id, dist))
64 })
65 .collect::<Result<Vec<_>>>()?;
66
67 results.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
70 results.truncate(k);
71
72 Ok(results
73 .into_iter()
74 .map(|(id, score)| SearchResult {
75 id,
76 score,
77 vector: None,
78 metadata: None,
79 })
80 .collect())
81 }
82
83 fn remove(&mut self, id: &VectorId) -> Result<bool> {
84 Ok(self.vectors.remove(id).is_some())
85 }
86
87 fn len(&self) -> usize {
88 self.vectors.len()
89 }
90}
91
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 3-dimensional Euclidean index holding the given `(id, vector)` pairs.
    fn index_with(entries: &[(&str, [f32; 3])]) -> Result<FlatIndex> {
        let mut index = FlatIndex::new(3, DistanceMetric::Euclidean);
        for &(id, vector) in entries {
            index.add(id.to_string(), vector.to_vec())?;
        }
        Ok(index)
    }

    #[test]
    fn test_flat_index() -> Result<()> {
        let index = index_with(&[
            ("v1", [1.0, 0.0, 0.0]),
            ("v2", [0.0, 1.0, 0.0]),
            ("v3", [0.0, 0.0, 1.0]),
        ])?;

        let results = index.search(&[1.0, 0.0, 0.0], 2)?;

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "v1");
        assert!(results[0].score < 0.01);
        Ok(())
    }

    #[test]
    fn test_flat_index_k_zero() -> Result<()> {
        let index = index_with(&[("v1", [1.0, 0.0, 0.0])])?;

        let results = index.search(&[1.0, 0.0, 0.0], 0)?;
        assert!(results.is_empty(), "k=0 must return empty results");
        Ok(())
    }

    #[test]
    fn test_flat_index_results_sorted() -> Result<()> {
        let mut index = FlatIndex::new(3, DistanceMetric::Euclidean);
        for i in 1usize..=10 {
            index.add(format!("v{}", i), vec![i as f32, 0.0, 0.0])?;
        }

        let results = index.search(&[0.0, 0.0, 0.0], 5)?;

        assert_eq!(results.len(), 5);
        for window in results.windows(2) {
            assert!(
                window[0].score <= window[1].score,
                "Results must be sorted ascending by distance"
            );
        }
        assert_eq!(results[0].id, "v1");
        Ok(())
    }

    #[test]
    fn test_flat_index_k_exceeds_len() -> Result<()> {
        let index = index_with(&[("v1", [1.0, 0.0, 0.0]), ("v2", [0.0, 1.0, 0.0])])?;

        let results = index.search(&[1.0, 0.0, 0.0], 10)?;

        assert_eq!(results.len(), 2, "k larger than the index returns every vector");
        assert_eq!(results[0].id, "v1");
        assert_eq!(results[1].id, "v2");
        Ok(())
    }

    #[test]
    fn test_flat_index_remove_and_len() -> Result<()> {
        let mut index = index_with(&[("v1", [1.0, 0.0, 0.0]), ("v2", [0.0, 1.0, 0.0])])?;
        assert_eq!(index.len(), 2);

        let v1 = "v1".to_string();
        assert!(index.remove(&v1)?, "removing a present id reports true");
        assert!(!index.remove(&v1)?, "removing an absent id reports false");
        assert_eq!(index.len(), 1);

        let results = index.search(&[1.0, 0.0, 0.0], 5)?;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "v2");
        Ok(())
    }

    #[test]
    fn test_flat_index_add_replaces_existing_id() -> Result<()> {
        let mut index = index_with(&[("v1", [0.0, 1.0, 0.0])])?;
        index.add("v1".to_string(), vec![1.0, 0.0, 0.0])?;
        assert_eq!(index.len(), 1, "re-adding an id must not duplicate it");

        let results = index.search(&[1.0, 0.0, 0.0], 1)?;
        assert_eq!(results[0].id, "v1");
        assert!(results[0].score < 0.01, "search must see the replacement vector");
        Ok(())
    }
}
149}