Skip to main content

ruvector_core/index/
flat.rs

1//! Flat (brute-force) index for baseline and small datasets
2
3use crate::distance::distance;
4use crate::error::Result;
5use crate::index::VectorIndex;
6use crate::types::{DistanceMetric, SearchResult, VectorId};
7use dashmap::DashMap;
8
9#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
10use rayon::prelude::*;
11
12/// Flat index using brute-force search
13pub struct FlatIndex {
14    vectors: DashMap<VectorId, Vec<f32>>,
15    metric: DistanceMetric,
16    _dimensions: usize,
17}
18
19impl FlatIndex {
20    /// Create a new flat index
21    pub fn new(dimensions: usize, metric: DistanceMetric) -> Self {
22        Self {
23            vectors: DashMap::new(),
24            metric,
25            _dimensions: dimensions,
26        }
27    }
28}
29
30impl VectorIndex for FlatIndex {
31    fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()> {
32        self.vectors.insert(id, vector);
33        Ok(())
34    }
35
36    fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
37        if k == 0 {
38            return Ok(vec![]);
39        }
40
41        // Distance calculation - parallel on native, sequential on WASM
42        #[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
43        let mut results: Vec<_> = self
44            .vectors
45            .iter()
46            .par_bridge()
47            .map(|entry| {
48                let id = entry.key().clone();
49                let vector = entry.value();
50                let dist = distance(query, vector, self.metric)?;
51                Ok((id, dist))
52            })
53            .collect::<Result<Vec<_>>>()?;
54
55        #[cfg(any(not(feature = "parallel"), target_arch = "wasm32"))]
56        let mut results: Vec<_> = self
57            .vectors
58            .iter()
59            .map(|entry| {
60                let id = entry.key().clone();
61                let vector = entry.value();
62                let dist = distance(query, vector, self.metric)?;
63                Ok((id, dist))
64            })
65            .collect::<Result<Vec<_>>>()?;
66
67        // Sort by distance (ascending — closest first) and take top k.
68        // Use sort_unstable_by for better performance on large result sets.
69        results.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
70        results.truncate(k);
71
72        Ok(results
73            .into_iter()
74            .map(|(id, score)| SearchResult {
75                id,
76                score,
77                vector: None,
78                metadata: None,
79            })
80            .collect())
81    }
82
83    fn remove(&mut self, id: &VectorId) -> Result<bool> {
84        Ok(self.vectors.remove(id).is_some())
85    }
86
87    fn len(&self) -> usize {
88        self.vectors.len()
89    }
90}
91
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a 3-dimensional Euclidean index pre-loaded with `entries`.
    fn euclidean_index(entries: &[(&str, [f32; 3])]) -> Result<FlatIndex> {
        let mut index = FlatIndex::new(3, DistanceMetric::Euclidean);
        for (id, coords) in entries {
            index.add(String::from(*id), coords.to_vec())?;
        }
        Ok(index)
    }

    /// The nearest neighbour of a stored vector is that vector itself.
    #[test]
    fn test_flat_index() -> Result<()> {
        let index = euclidean_index(&[
            ("v1", [1.0, 0.0, 0.0]),
            ("v2", [0.0, 1.0, 0.0]),
            ("v3", [0.0, 0.0, 1.0]),
        ])?;

        let hits = index.search(&[1.0, 0.0, 0.0], 2)?;

        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].id, "v1");
        assert!(hits[0].score < 0.01);

        Ok(())
    }

    /// Asking for zero neighbours yields nothing, even on a non-empty index.
    #[test]
    fn test_flat_index_k_zero() -> Result<()> {
        let index = euclidean_index(&[("v1", [1.0, 0.0, 0.0])])?;

        let hits = index.search(&[1.0, 0.0, 0.0], 0)?;
        assert!(hits.is_empty(), "k=0 must return empty results");

        Ok(())
    }

    /// Results come back ordered by ascending distance.
    #[test]
    fn test_flat_index_results_sorted() -> Result<()> {
        let mut index = FlatIndex::new(3, DistanceMetric::Euclidean);

        // Points on the x-axis at distances 1..=10 from the origin.
        (1usize..=10)
            .try_for_each(|step| index.add(format!("v{}", step), vec![step as f32, 0.0, 0.0]))?;

        let origin = [0.0_f32; 3];
        let hits = index.search(&origin, 5)?;

        assert_eq!(hits.len(), 5);
        assert!(
            hits.windows(2).all(|pair| pair[0].score <= pair[1].score),
            "Results must be sorted ascending by distance"
        );
        // Closest is v1 (distance=1)
        assert_eq!(hits[0].id, "v1");

        Ok(())
    }
}