ruvector_core/index/
flat.rs

1//! Flat (brute-force) index for baseline and small datasets
2
3use crate::distance::distance;
4use crate::error::Result;
5use crate::index::VectorIndex;
6use crate::types::{DistanceMetric, SearchResult, VectorId};
7use dashmap::DashMap;
8
9#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
10use rayon::prelude::*;
11
12/// Flat index using brute-force search
13pub struct FlatIndex {
14    vectors: DashMap<VectorId, Vec<f32>>,
15    metric: DistanceMetric,
16    dimensions: usize,
17}
18
19impl FlatIndex {
20    /// Create a new flat index
21    pub fn new(dimensions: usize, metric: DistanceMetric) -> Self {
22        Self {
23            vectors: DashMap::new(),
24            metric,
25            dimensions,
26        }
27    }
28}
29
30impl VectorIndex for FlatIndex {
31    fn add(&mut self, id: VectorId, vector: Vec<f32>) -> Result<()> {
32        self.vectors.insert(id, vector);
33        Ok(())
34    }
35
36    fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
37        // Distance calculation - parallel on native, sequential on WASM
38        #[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
39        let mut results: Vec<_> = self
40            .vectors
41            .iter()
42            .par_bridge()
43            .map(|entry| {
44                let id = entry.key().clone();
45                let vector = entry.value();
46                let dist = distance(query, vector, self.metric)?;
47                Ok((id, dist))
48            })
49            .collect::<Result<Vec<_>>>()?;
50
51        #[cfg(any(not(feature = "parallel"), target_arch = "wasm32"))]
52        let mut results: Vec<_> = self
53            .vectors
54            .iter()
55            .map(|entry| {
56                let id = entry.key().clone();
57                let vector = entry.value();
58                let dist = distance(query, vector, self.metric)?;
59                Ok((id, dist))
60            })
61            .collect::<Result<Vec<_>>>()?;
62
63        // Sort by distance and take top k
64        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
65        results.truncate(k);
66
67        Ok(results
68            .into_iter()
69            .map(|(id, score)| SearchResult {
70                id,
71                score,
72                vector: None,
73                metadata: None,
74            })
75            .collect())
76    }
77
78    fn remove(&mut self, id: &VectorId) -> Result<bool> {
79        Ok(self.vectors.remove(id).is_some())
80    }
81
82    fn len(&self) -> usize {
83        self.vectors.len()
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserts the three unit axes and checks that querying with the first
    /// axis returns two hits, with `v1` ranked first at (near) zero distance.
    #[test]
    fn test_flat_index() -> Result<()> {
        let mut index = FlatIndex::new(3, DistanceMetric::Euclidean);

        let fixtures = [
            ("v1", [1.0_f32, 0.0, 0.0]),
            ("v2", [0.0_f32, 1.0, 0.0]),
            ("v3", [0.0_f32, 0.0, 1.0]),
        ];
        for (name, coords) in fixtures.iter() {
            index.add(name.to_string(), coords.to_vec())?;
        }

        let query = [1.0_f32, 0.0, 0.0];
        let hits = index.search(&query, 2)?;

        assert_eq!(hits.len(), 2);
        let nearest = &hits[0];
        assert_eq!(nearest.id, "v1");
        assert!(nearest.score < 0.01);

        Ok(())
    }
}
108}