ruvector_core/index/
flat.rs1use crate::distance::distance;
4use crate::error::Result;
5use crate::index::VectorIndex;
6use crate::types::{DistanceMetric, SearchResult, VectorId};
7use dashmap::DashMap;
8
9#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
10use rayon::prelude::*;
11
12pub struct FlatIndex {
14 vectors: DashMap<VectorId, Vec<f32>>,
15 metric: DistanceMetric,
16 dimensions: usize,
17}
18
19impl FlatIndex {
20 pub fn new(dimensions: usize, metric: DistanceMetric) -> Self {
22 Self {
23 vectors: DashMap::new(),
24 metric,
25 dimensions,
26 }
27 }
28}
29
30impl VectorIndex for FlatIndex {
31 fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()> {
32 self.vectors.insert(id, vector);
33 Ok(())
34 }
35
36 fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
37 #[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
39 let mut results: Vec<_> = self
40 .vectors
41 .iter()
42 .par_bridge()
43 .map(|entry| {
44 let id = entry.key().clone();
45 let vector = entry.value();
46 let dist = distance(query, vector, self.metric)?;
47 Ok((id, dist))
48 })
49 .collect::<Result<Vec<_>>>()?;
50
51 #[cfg(any(not(feature = "parallel"), target_arch = "wasm32"))]
52 let mut results: Vec<_> = self
53 .vectors
54 .iter()
55 .map(|entry| {
56 let id = entry.key().clone();
57 let vector = entry.value();
58 let dist = distance(query, vector, self.metric)?;
59 Ok((id, dist))
60 })
61 .collect::<Result<Vec<_>>>()?;
62
63 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
65 results.truncate(k);
66
67 Ok(results
68 .into_iter()
69 .map(|(id, score)| SearchResult {
70 id,
71 score,
72 vector: None,
73 metadata: None,
74 })
75 .collect())
76 }
77
78 fn remove(&mut self, id: &VectorId) -> Result<bool> {
79 Ok(self.vectors.remove(id).is_some())
80 }
81
82 fn len(&self) -> usize {
83 self.vectors.len()
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 3-d Euclidean index holding the three unit basis vectors.
    fn basis_index() -> Result<FlatIndex> {
        let mut index = FlatIndex::new(3, DistanceMetric::Euclidean);
        index.add("v1".to_string(), vec![1.0, 0.0, 0.0])?;
        index.add("v2".to_string(), vec![0.0, 1.0, 0.0])?;
        index.add("v3".to_string(), vec![0.0, 0.0, 1.0])?;
        Ok(index)
    }

    #[test]
    fn test_flat_index() -> Result<()> {
        let index = basis_index()?;

        let query: [f32; 3] = [1.0, 0.0, 0.0];
        let results = index.search(&query, 2)?;

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "v1");
        assert!(results[0].score < 0.01);
        assert!(results[0].score <= results[1].score);

        Ok(())
    }

    #[test]
    fn search_clamps_k_to_index_size() -> Result<()> {
        let index = basis_index()?;
        let query: [f32; 3] = [1.0, 0.0, 0.0];

        let all = index.search(&query, 10)?;
        assert_eq!(all.len(), 3);
        assert_eq!(all[0].id, "v1");
        assert!(all.windows(2).all(|w| w[0].score <= w[1].score));

        assert!(index.search(&query, 0)?.is_empty());

        Ok(())
    }

    #[test]
    fn search_on_empty_index_returns_nothing() -> Result<()> {
        let index = FlatIndex::new(3, DistanceMetric::Euclidean);
        let query: [f32; 3] = [1.0, 0.0, 0.0];

        assert!(index.search(&query, 5)?.is_empty());

        Ok(())
    }

    #[test]
    fn remove_updates_len_and_search() -> Result<()> {
        let mut index = basis_index()?;
        assert_eq!(index.len(), 3);

        assert!(index.remove(&"v1".to_string())?);
        assert!(!index.remove(&"v1".to_string())?);
        assert_eq!(index.len(), 2);

        let query: [f32; 3] = [1.0, 0.0, 0.0];
        let results = index.search(&query, 3)?;
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.id != "v1"));

        Ok(())
    }
}