// shodh_redb/vector_ops.rs

1use alloc::collections::BinaryHeap;
2use alloc::vec;
3use alloc::vec::Vec;
4use core::fmt::{self, Debug};
5
6use crate::vector::SQVec;
7
8#[cfg(all(target_arch = "x86_64", feature = "std"))]
9mod simd_x86;
10
/// Portable `f32::sqrt` that works in `no_std` on wasm32.
///
/// With the `std` feature this delegates to the intrinsic-backed
/// `f32::sqrt()`. Without `std`, a bit-trick initial estimate is refined by
/// a fixed number of Newton-Raphson steps, which suffices for full f32
/// precision.
#[inline]
fn sqrt_f32(x: f32) -> f32 {
    #[cfg(feature = "std")]
    {
        x.sqrt()
    }
    #[cfg(not(feature = "std"))]
    {
        // Domain errors: the square root of a negative (or NaN) is NaN.
        if x.is_nan() || x < 0.0 {
            return f32::NAN;
        }
        // sqrt(0) = 0 and sqrt(+inf) = +inf pass straight through.
        if x == 0.0 || x.is_infinite() {
            return x;
        }
        // Subnormals (x < MIN_POSITIVE) have a biased exponent of 0, which
        // makes the bit-manipulation initial guess wildly inaccurate. Scale
        // into the normal range by 2^24, take the root there, undo with
        // sqrt(2^24) = 4096.
        if x < f32::MIN_POSITIVE {
            return sqrt_f32(x * 16_777_216.0) / 4096.0;
        }
        // Roughly halve the exponent via the raw bit pattern to get a first
        // estimate (fast inverse-sqrt trick variant).
        let mut est = f32::from_bits((x.to_bits() >> 1) + 0x1FC0_0000);
        // Five Newton-Raphson refinements reach full f32 precision.
        for _ in 0..5 {
            est = 0.5 * (est + x / est);
        }
        est
    }
}
50
51// ---------------------------------------------------------------------------
52// Distance metric enum
53// ---------------------------------------------------------------------------
54
/// Specifies the distance metric for vector similarity search.
///
/// Lower distance values indicate more similar vectors for all metrics.
///
/// Variants are plain `Copy` tags; the actual math is dispatched in
/// [`DistanceMetric::compute`].
///
/// # Usage
///
/// ```rust,ignore
/// use shodh_redb::{DistanceMetric, FixedVec, TableDefinition, ReadableTable};
///
/// let query = [1.0f32, 0.0, 0.0];
/// let metric = DistanceMetric::Cosine;
///
/// // Scan and rank vectors
/// for entry in table.iter()? {
///     let (key, guard) = entry?;
///     let vec = guard.value();
///     let dist = metric.compute(&query, &vec);
///     println!("{}: {}", key.value(), dist);
/// }
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Cosine distance: `1.0 - cosine_similarity(a, b)`. Range `[0.0, 2.0]`.
    Cosine,
    /// Squared Euclidean distance: `sum((a_i - b_i)^2)`. Range `[0.0, inf)`.
    EuclideanSq,
    /// Dot product distance: `-dot_product(a, b)`. Negate so lower = more similar.
    /// Use with L2-normalized vectors for equivalent cosine ranking without sqrt.
    DotProduct,
    /// Manhattan (L1) distance: `sum(|a_i - b_i|)`. Range `[0.0, inf)`.
    Manhattan,
}
87
88impl DistanceMetric {
89    /// Computes the distance between two f32 vectors using this metric.
90    ///
91    /// Lower values indicate more similar vectors for all metrics.
92    ///
93    /// Returns [`f32::MAX`] when the vectors have mismatched dimensions
94    /// (truncated or corrupted data) to prevent garbage results from being
95    /// promoted to the top of nearest-neighbor heaps. Also returns
96    /// [`f32::MAX`] when the computed distance is NaN (e.g. due to NaN
97    /// elements in the input vectors) to avoid silent NaN propagation
98    /// through search results.
99    ///
100    /// NaN distances are replaced with `f32::MAX` (treated as maximally
101    /// distant) so that `BinaryHeap` ordering is never corrupted.
102    #[inline]
103    pub fn compute(&self, a: &[f32], b: &[f32]) -> f32 {
104        if a.len() != b.len() {
105            return f32::MAX;
106        }
107        let d = match self {
108            Self::Cosine => cosine_distance(a, b),
109            Self::EuclideanSq => euclidean_distance_sq(a, b),
110            Self::DotProduct => -dot_product(a, b),
111            Self::Manhattan => manhattan_distance(a, b),
112        };
113        if d.is_nan() { f32::MAX } else { d }
114    }
115}
116
117impl fmt::Display for DistanceMetric {
118    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
119        match self {
120            Self::Cosine => f.write_str("cosine"),
121            Self::EuclideanSq => f.write_str("euclidean_sq"),
122            Self::DotProduct => f.write_str("dot_product"),
123            Self::Manhattan => f.write_str("manhattan"),
124        }
125    }
126}
127
128// ---------------------------------------------------------------------------
129// Core distance functions
130// ---------------------------------------------------------------------------
131
/// Computes the dot product of two f32 slices.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`. Callers that need graceful mismatch
/// handling should use [`DistanceMetric::compute`] instead.
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "dot_product: dimension mismatch");
    // Runtime dispatch: on x86_64 builds with `std`, use the AVX2 kernel
    // when the CPU supports it. On any other target (or without `std`) this
    // whole block compiles away and only the scalar fallback remains.
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 detected; slices have equal length (asserted above).
            return unsafe { simd_x86::dot_product_avx2(a, b) };
        }
    }
    dot_product_scalar(a, b)
}
150
#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Sequential accumulation of pairwise products (zip stops at the
    // shorter slice; callers assert equal lengths).
    let mut acc = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
155
/// Computes the squared Euclidean distance between two f32 slices.
///
/// Returns the sum of squared element-wise differences. Take the square root
/// for the actual Euclidean distance, but the squared form is sufficient for
/// nearest-neighbor comparisons and avoids the sqrt cost.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
#[inline]
pub fn euclidean_distance_sq(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(
        a.len(),
        b.len(),
        "euclidean_distance_sq: dimension mismatch"
    );
    // Runtime dispatch: AVX2 kernel on capable x86_64/std builds, otherwise
    // the scalar loop below. The cfg block vanishes on other targets.
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 detected; slices have equal length (asserted above).
            return unsafe { simd_x86::euclidean_distance_sq_avx2(a, b) };
        }
    }
    euclidean_distance_sq_scalar(a, b)
}
181
#[inline]
fn euclidean_distance_sq_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Accumulate squared differences element by element.
    let mut acc = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        let diff = x - y;
        acc += diff * diff;
    }
    acc
}
186
/// Computes the cosine similarity between two f32 slices.
///
/// Returns a value in `[-1.0, 1.0]` where 1.0 means identical direction,
/// 0.0 means orthogonal, and -1.0 means opposite direction.
///
/// Returns 0.0 if either vector has zero magnitude.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
#[inline]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "cosine_similarity: dimension mismatch");
    // Runtime dispatch: AVX2 kernel on capable x86_64/std builds, otherwise
    // the scalar loop below. The cfg block vanishes on other targets.
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 detected; slices have equal length (asserted above).
            return unsafe { simd_x86::cosine_similarity_avx2(a, b) };
        }
    }
    cosine_similarity_scalar(a, b)
}
209
210#[inline]
211fn cosine_similarity_scalar(a: &[f32], b: &[f32]) -> f32 {
212    let mut dot = 0.0f32;
213    let mut norm_a = 0.0f32;
214    let mut norm_b = 0.0f32;
215    for i in 0..a.len() {
216        let x = a[i];
217        let y = b[i];
218        dot += x * y;
219        norm_a += x * x;
220        norm_b += y * y;
221    }
222    let denom = sqrt_f32(norm_a) * sqrt_f32(norm_b);
223    if denom == 0.0 {
224        0.0
225    } else {
226        (dot / denom).clamp(-1.0, 1.0)
227    }
228}
229
230/// Computes the cosine distance between two f32 slices.
231///
232/// Defined as `1.0 - cosine_similarity(a, b)`, returning a value in `[0.0, 2.0]`
233/// where 0.0 means identical direction and 2.0 means opposite direction.
234///
235/// Returns 1.0 if either vector has zero magnitude.
236///
237/// # Panics
238///
239/// Panics if `a.len() != b.len()`.
240#[inline]
241pub fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
242    1.0 - cosine_similarity(a, b)
243}
244
/// Computes the Manhattan (L1) distance between two f32 slices.
///
/// Returns the sum of absolute element-wise differences.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
#[inline]
pub fn manhattan_distance(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "manhattan_distance: dimension mismatch");
    // Runtime dispatch: AVX2 kernel on capable x86_64/std builds, otherwise
    // the scalar loop below. The cfg block vanishes on other targets.
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 detected; slices have equal length (asserted above).
            return unsafe { simd_x86::manhattan_distance_avx2(a, b) };
        }
    }
    manhattan_distance_scalar(a, b)
}
264
#[inline]
fn manhattan_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    // Sum of absolute differences, accumulated in order.
    let mut acc = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        acc += (x - y).abs();
    }
    acc
}
269
/// Computes the Hamming distance between two byte slices interpreted as binary vectors.
///
/// Counts the number of bits that differ between `a` and `b`. Useful for binary
/// embeddings (e.g., Cohere binary, Matryoshka quantized vectors).
///
/// If lengths differ, computes over the shorter length.
#[inline]
pub fn hamming_distance(a: &[u8], b: &[u8]) -> u32 {
    // Trim both slices to the common prefix so the kernels can assume
    // equal lengths.
    let len = a.len().min(b.len());
    let a = &a[..len];
    let b = &b[..len];
    // Runtime dispatch: AVX2 kernel on capable x86_64/std builds, otherwise
    // the scalar popcount loop below.
    #[cfg(all(target_arch = "x86_64", feature = "std"))]
    {
        if is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 detected; slices are trimmed to equal length above.
            return unsafe { simd_x86::hamming_distance_avx2(a, b) };
        }
    }
    hamming_distance_scalar(a, b)
}
290
#[inline]
fn hamming_distance_scalar(a: &[u8], b: &[u8]) -> u32 {
    // XOR exposes the differing bits of each byte pair; popcount tallies them.
    let mut differing = 0u32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        differing += (x ^ y).count_ones();
    }
    differing
}
298
299// ---------------------------------------------------------------------------
300// Normalization
301// ---------------------------------------------------------------------------
302
303/// Computes the L2 (Euclidean) norm of a vector.
304///
305/// Returns `sqrt(sum(x_i^2))`.
306#[inline]
307pub fn l2_norm(v: &[f32]) -> f32 {
308    sqrt_f32(v.iter().map(|x| x * x).sum::<f32>())
309}
310
311/// Normalizes a vector to unit length (L2 norm = 1.0) in place.
312///
313/// After normalization, `dot_product(v, v) ~= 1.0` and `cosine_similarity`
314/// reduces to a simple `dot_product`, which is significantly faster.
315///
316/// If the vector has zero magnitude, it is left unchanged. For vectors with
317/// extremely large elements (where the raw norm overflows to infinity), the
318/// vector is first scaled down by the maximum absolute element value before
319/// computing the norm to avoid producing a zero vector.
320#[inline]
321pub fn l2_normalize(v: &mut [f32]) {
322    let norm = l2_norm(v);
323    if norm.is_finite() && norm > 0.0 {
324        let inv = 1.0 / norm;
325        for x in v.iter_mut() {
326            *x *= inv;
327        }
328    } else if !norm.is_finite() {
329        // Norm overflowed to Inf. Scale down by max absolute value first,
330        // then normalize the scaled vector to get correct unit direction.
331        let max_abs = v.iter().fold(0.0f32, |acc, &x| {
332            let a = x.abs();
333            if a > acc { a } else { acc }
334        });
335        if max_abs == 0.0 || !max_abs.is_finite() {
336            return;
337        }
338        let inv_max = 1.0 / max_abs;
339        for x in v.iter_mut() {
340            *x *= inv_max;
341        }
342        let scaled_norm = l2_norm(v);
343        if scaled_norm.is_finite() && scaled_norm > 0.0 {
344            let inv = 1.0 / scaled_norm;
345            for x in v.iter_mut() {
346                *x *= inv;
347            }
348        }
349    }
350}
351
352/// Returns a new L2-normalized copy of the input vector.
353///
354/// If the input has zero magnitude, returns a zero vector of the same length.
355#[inline]
356pub fn l2_normalized(v: &[f32]) -> Vec<f32> {
357    let mut out = v.to_vec();
358    l2_normalize(&mut out);
359    out
360}
361
362// ---------------------------------------------------------------------------
363// Quantization
364// ---------------------------------------------------------------------------
365
/// Converts an f32 vector to a binary quantized representation.
///
/// Each f32 dimension is mapped to a single bit: 1 if positive, 0 otherwise
/// (zero, negative, and NaN all map to 0). Bits are packed MSB-first within
/// each byte, and the output length is `ceil(input.len() / 8)`.
///
/// This gives 32x compression over f32 storage. Use
/// [`hamming_distance`] to compare binary vectors.
///
/// # Example
///
/// ```rust,ignore
/// let v = [1.0f32, -0.5, 0.3, -0.1, 0.0, 0.7, -0.2, 0.9];
/// let bq = shodh_redb::quantize_binary(&v);
/// // bit pattern: [1,0,1,0, 0,1,0,1] = 0b10100101 = 0xA5
/// assert_eq!(bq, vec![0xA5]);
/// ```
pub fn quantize_binary(v: &[f32]) -> Vec<u8> {
    // Each group of up to 8 dimensions becomes one output byte; the last
    // group may be partial, leaving its trailing bits at 0.
    v.chunks(8)
        .map(|group| {
            group.iter().enumerate().fold(0u8, |byte, (bit, &val)| {
                if val > 0.0 {
                    byte | (1 << (7 - bit)) // MSB-first within the byte
                } else {
                    byte
                }
            })
        })
        .collect()
}
395
396/// Scalar-quantizes an f32 vector to u8 codes with min/max scale factors.
397///
398/// Maps each f32 value linearly to the `[0, 255]` range based on the vector's
399/// min and max values. Returns an [`SQVec`] containing the scale factors and codes.
400///
401/// This gives approximately 4x compression over f32 storage with bounded
402/// quantization error of `(max - min) / 510` per dimension.
403///
404/// # Example
405///
406/// ```rust,ignore
407/// let v = [0.0f32, 0.5, 1.0, 0.25];
408/// let sq: SQVec<4> = shodh_redb::quantize_scalar(&v);
409/// assert_eq!(sq.min_val, 0.0);
410/// assert_eq!(sq.max_val, 1.0);
411/// assert_eq!(sq.codes[2], 255); // max maps to 255
412/// ```
413pub fn quantize_scalar<const N: usize>(v: &[f32; N]) -> SQVec<N> {
414    let mut min_val = f32::INFINITY;
415    let mut max_val = f32::NEG_INFINITY;
416    for &x in v {
417        if x < min_val {
418            min_val = x;
419        }
420        if x > max_val {
421            max_val = x;
422        }
423    }
424
425    let mut codes = [0u8; N];
426    let range = max_val - min_val;
427    if !range.is_finite() {
428        // Input contained NaN or Inf -- quantization is meaningless.
429        // Return zero codes with clamped min/max so dequantize produces 0.0.
430        return SQVec {
431            min_val: 0.0,
432            max_val: 0.0,
433            codes,
434        };
435    }
436    if range >= f32::MIN_POSITIVE {
437        let inv_range = 255.0 / range;
438        if inv_range.is_finite() {
439            for (i, &x) in v.iter().enumerate() {
440                // Quantize to [0, 255]: value is guaranteed non-negative and <= 255.5
441                // because x is clamped within [min_val, max_val].
442                #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
443                let q = ((x - min_val) * inv_range + 0.5) as u8;
444                codes[i] = q;
445            }
446        }
447        // else: inv_range is Inf/NaN from a subnormal range; treat as zero-range
448    }
449    // If range < MIN_POSITIVE all codes stay 0, and dequantize returns min_val for all
450
451    SQVec {
452        min_val,
453        max_val,
454        codes,
455    }
456}
457
/// Dequantizes an [`SQVec`] back to an array of f32 values.
///
/// This is a convenience wrapper around [`SQVec::dequantize`]; it exists so
/// callers can use a free-function API symmetric with [`quantize_scalar`].
#[inline]
pub fn dequantize_scalar<const N: usize>(sq: &SQVec<N>) -> [f32; N] {
    sq.dequantize()
}
465
466/// Computes approximate squared Euclidean distance between an f32 query and
467/// a scalar-quantized vector, dequantizing on the fly.
468///
469/// This avoids materializing the full f32 vector when doing distance
470/// comparisons during search, reducing memory bandwidth.
471///
472/// Returns [`f32::MAX`] if the scale factors are non-finite or the result
473/// is NaN, preventing silent corruption of search rankings.
474#[inline]
475pub fn sq_euclidean_distance_sq<const N: usize>(query: &[f32; N], sq: &SQVec<N>) -> f32 {
476    let range = sq.max_val - sq.min_val;
477    if !range.is_finite() {
478        return f32::MAX;
479    }
480    if range == 0.0 {
481        // All codes dequantize to min_val -- compute exact per-dimension distances.
482        let d: f32 = query
483            .iter()
484            .map(|&q| {
485                let diff = q - sq.min_val;
486                diff * diff
487            })
488            .sum();
489        return if d.is_nan() { f32::MAX } else { d };
490    }
491    let scale = range / 255.0;
492    let mut sum = 0.0f32;
493    for (i, &q) in query.iter().enumerate() {
494        let dequant = sq.min_val + f32::from(sq.codes[i]) * scale;
495        let diff = q - dequant;
496        sum += diff * diff;
497    }
498    if sum.is_nan() { f32::MAX } else { sum }
499}
500
501/// Computes approximate dot product between an f32 query and a scalar-quantized
502/// vector, dequantizing on the fly.
503///
504/// Returns `0.0` if the scale factors are non-finite. Returns the raw dot
505/// product otherwise (callers negate for distance ranking via
506/// [`DistanceMetric::DotProduct`]).
507#[inline]
508pub fn sq_dot_product<const N: usize>(query: &[f32; N], sq: &SQVec<N>) -> f32 {
509    let range = sq.max_val - sq.min_val;
510    if !range.is_finite() {
511        return 0.0;
512    }
513    if range == 0.0 {
514        let d = query.iter().sum::<f32>() * sq.min_val;
515        return if d.is_nan() { 0.0 } else { d };
516    }
517    let scale = range / 255.0;
518    let mut sum = 0.0f32;
519    for (i, &q) in query.iter().enumerate() {
520        let dequant = sq.min_val + f32::from(sq.codes[i]) * scale;
521        sum += q * dequant;
522    }
523    if sum.is_nan() { 0.0 } else { sum }
524}
525
526// ---------------------------------------------------------------------------
527// Top-K scan
528// ---------------------------------------------------------------------------
529
/// A scored result from a nearest-neighbor search.
///
/// Ordering (`Eq`/`Ord`) is defined over `distance` via `f32::total_cmp`
/// in the impls below, which is what lets `Neighbor` live in a `BinaryHeap`.
#[derive(Debug, Clone)]
pub struct Neighbor<K> {
    /// The key of the matching row.
    pub key: K,
    /// The distance from the query vector (lower = more similar).
    pub distance: f32,
}
538
// Equality and ordering are defined purely on `distance`, using the bit
// pattern / IEEE total order so that `Eq` and `Ord`'s contracts hold even
// for NaN. `key` deliberately does not participate.
impl<K> PartialEq for Neighbor<K> {
    fn eq(&self, other: &Self) -> bool {
        // Treat NaN as equal to NaN for heap consistency (IEEE NaN != NaN breaks Eq).
        // Bitwise comparison also distinguishes -0.0 from +0.0, matching
        // `total_cmp` in the Ord impl below so eq and cmp agree.
        self.distance.to_bits() == other.distance.to_bits()
    }
}

impl<K> Eq for Neighbor<K> {}

impl<K> PartialOrd for Neighbor<K> {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        // Delegate to the total order in `cmp`; never returns None.
        Some(self.cmp(other))
    }
}

impl<K> Ord for Neighbor<K> {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        // BinaryHeap is a max-heap; we want the *largest* distance at the top
        // so we can efficiently evict it. NaN sorts as greater-than everything
        // so NaN entries sit at the heap root and are evicted first.
        self.distance.total_cmp(&other.distance)
    }
}
562
563/// Brute-force top-k nearest neighbor scan over an iterator of `(key, vector)` pairs.
564///
565/// Returns up to `k` nearest neighbors sorted by ascending distance (closest first).
566/// The distance function should return lower values for more similar vectors.
567///
568/// This is the fundamental building block for vector search. Higher-level index
569/// structures (IVF, HNSW) use this for scanning candidate shortlists.
570///
571/// # Example
572///
573/// ```rust,ignore
574/// use shodh_redb::{nearest_k, DistanceMetric, FixedVec, ReadableTable};
575///
576/// let query = [1.0f32, 0.0, 0.0, 0.0];
577/// let metric = DistanceMetric::Cosine;
578///
579/// let results = nearest_k(
580///     table.iter()?.map(|r| {
581///         let (k, v) = r.unwrap();
582///         (k.value(), v.value())
583///     }),
584///     &query,
585///     10,
586///     |a, b| metric.compute(a, b),
587/// );
588///
589/// for neighbor in &results {
590///     println!("key={}, distance={}", neighbor.key, neighbor.distance);
591/// }
592/// ```
593pub fn nearest_k<K, I, F>(iter: I, query: &[f32], k: usize, distance_fn: F) -> Vec<Neighbor<K>>
594where
595    I: Iterator<Item = (K, Vec<f32>)>,
596    F: Fn(&[f32], &[f32]) -> f32,
597{
598    if k == 0 {
599        return Vec::new();
600    }
601
602    // Max-heap of size k: the root is the worst (largest distance) candidate.
603    // When we find something better, we pop the worst and push the new one.
604    let mut heap: BinaryHeap<Neighbor<K>> = BinaryHeap::with_capacity(k + 1);
605
606    for (key, vec) in iter {
607        let dist = distance_fn(query, &vec);
608        if heap.len() < k {
609            heap.push(Neighbor {
610                key,
611                distance: dist,
612            });
613        } else if heap
614            .peek()
615            .is_some_and(|worst| dist.total_cmp(&worst.distance).is_lt())
616        {
617            heap.pop();
618            heap.push(Neighbor {
619                key,
620                distance: dist,
621            });
622        }
623    }
624
625    let mut results: Vec<Neighbor<K>> = heap.into_vec();
626    results.sort_by(|a, b| a.distance.total_cmp(&b.distance));
627    results
628}
629
630/// Brute-force top-k scan with a fixed-size array query (zero-copy variant).
631///
632/// Same as [`nearest_k`] but takes `[f32; N]` vectors from the iterator,
633/// avoiding the `Vec<f32>` allocation overhead for fixed-dimension tables.
634pub fn nearest_k_fixed<K, I, F, const N: usize>(
635    iter: I,
636    query: &[f32; N],
637    k: usize,
638    distance_fn: F,
639) -> Vec<Neighbor<K>>
640where
641    I: Iterator<Item = (K, [f32; N])>,
642    F: Fn(&[f32], &[f32]) -> f32,
643{
644    if k == 0 {
645        return Vec::new();
646    }
647
648    let mut heap: BinaryHeap<Neighbor<K>> = BinaryHeap::with_capacity(k + 1);
649
650    for (key, vec) in iter {
651        let dist = distance_fn(query.as_slice(), vec.as_slice());
652        if heap.len() < k {
653            heap.push(Neighbor {
654                key,
655                distance: dist,
656            });
657        } else if heap
658            .peek()
659            .is_some_and(|worst| dist.total_cmp(&worst.distance).is_lt())
660        {
661            heap.pop();
662            heap.push(Neighbor {
663                key,
664                distance: dist,
665            });
666        }
667    }
668
669    let mut results: Vec<Neighbor<K>> = heap.into_vec();
670    results.sort_by(|a, b| a.distance.total_cmp(&b.distance));
671
672    results
673}
674
675// ---------------------------------------------------------------------------
676// LE byte helpers
677// ---------------------------------------------------------------------------
678
/// Writes a slice of f32 values as little-endian bytes into a destination buffer.
///
/// Useful for populating `insert_reserve` buffers when using `FixedVec<N>`.
///
/// If the buffer is smaller than `values.len() * 4`, writes only as many
/// complete f32 values as fit. If `values` is shorter, only those values
/// are written.
#[inline]
pub fn write_f32_le(dest: &mut [u8], values: &[f32]) {
    // Clamp to however many whole f32s both buffers can accommodate.
    let count = (dest.len() / 4).min(values.len());
    #[cfg(target_endian = "little")]
    {
        // Fast path: one bulk memcpy instead of a per-element loop.
        let byte_len = count * 4;
        // SAFETY: On LE targets, f32 memory layout matches LE bytes.
        // `count` ensures we don't read past `values` or write past `dest`.
        unsafe {
            core::ptr::copy_nonoverlapping(
                values.as_ptr().cast::<u8>(),
                dest.as_mut_ptr(),
                byte_len,
            );
        }
    }
    #[cfg(not(target_endian = "little"))]
    {
        // Big-endian fallback: serialize each value explicitly via to_le_bytes.
        for (i, val) in values.iter().enumerate().take(count) {
            let start = i * 4;
            dest[start..start + 4].copy_from_slice(&val.to_le_bytes());
        }
    }
}
710
/// Reads little-endian f32 values from a byte slice.
///
/// If `src.len()` is not a multiple of 4, trailing bytes are ignored.
#[inline]
pub fn read_f32_le(src: &[u8]) -> Vec<f32> {
    // Round down to a whole number of 4-byte f32 values.
    let usable = src.len() - (src.len() % 4);
    let count = usable / 4;
    #[cfg(target_endian = "little")]
    {
        // Fast path: allocate the output, then bulk-copy the raw bytes in.
        let mut result = vec![0.0f32; count];
        // SAFETY: On LE targets, f32 byte representation matches memory layout.
        // `usable` = count * 4, both buffers are valid for that length.
        unsafe {
            core::ptr::copy_nonoverlapping(src.as_ptr(), result.as_mut_ptr().cast::<u8>(), usable);
        }
        result
    }
    #[cfg(not(target_endian = "little"))]
    {
        // Big-endian fallback: decode each 4-byte group via from_le_bytes.
        let mut result = Vec::with_capacity(count);
        for i in 0..count {
            let start = i * 4;
            // Infallible: start + 4 <= usable <= src.len() by loop bounds.
            let bytes: [u8; 4] = src[start..start + 4].try_into().unwrap_or([0u8; 4]);
            result.push(f32::from_le_bytes(bytes));
        }
        result
    }
}
740
#[cfg(test)]
#[allow(
    clippy::float_cmp,
    clippy::cast_precision_loss,
    clippy::cast_possible_truncation
)]
mod tests {
    use super::*;

    /// Dimensions that exercise all tail-loop edge cases:
    /// 1 (pure tail), 7 (tail=7), 8 (exact chunk), 15, 16, 31, 32, 128, 384, 768
    const DIMS: &[usize] = &[1, 7, 8, 15, 16, 31, 32, 128, 384, 768];

    // Deterministic fixtures: two lines with different slopes and offsets so
    // dot products and distances are non-trivial (and sign changes occur in a).
    fn make_vecs(dim: usize) -> (Vec<f32>, Vec<f32>) {
        let a: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.1 - 5.0).collect();
        let b: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.2 + 1.0).collect();
        (a, b)
    }

    // Relative-tolerance comparison: `tol` is scaled by |expected| (floored
    // at 1.0) so large magnitudes don't trip fixed absolute epsilons.
    fn assert_close(actual: f32, expected: f32, tol: f32, label: &str, dim: usize) {
        let diff = (actual - expected).abs();
        let scale = expected.abs().max(1.0);
        assert!(
            diff < tol * scale,
            "{label} dim={dim}: expected={expected}, actual={actual}, diff={diff}"
        );
    }

    // The *_matches_scalar tests compare the dispatching public function
    // (which may pick the AVX2 kernel) against the scalar reference.

    #[test]
    fn dot_product_matches_scalar() {
        for &dim in DIMS {
            let (a, b) = make_vecs(dim);
            let scalar = dot_product_scalar(&a, &b);
            let result = dot_product(&a, &b);
            assert_close(result, scalar, 1e-5, "dot_product", dim);
        }
    }

    #[test]
    fn euclidean_distance_sq_matches_scalar() {
        for &dim in DIMS {
            let (a, b) = make_vecs(dim);
            let scalar = euclidean_distance_sq_scalar(&a, &b);
            let result = euclidean_distance_sq(&a, &b);
            assert_close(result, scalar, 1e-5, "euclidean_distance_sq", dim);
        }
    }

    #[test]
    fn cosine_similarity_matches_scalar() {
        for &dim in DIMS {
            let (a, b) = make_vecs(dim);
            let scalar = cosine_similarity_scalar(&a, &b);
            let result = cosine_similarity(&a, &b);
            assert_close(result, scalar, 1e-5, "cosine_similarity", dim);
        }
    }

    #[test]
    fn manhattan_distance_matches_scalar() {
        for &dim in DIMS {
            let (a, b) = make_vecs(dim);
            let scalar = manhattan_distance_scalar(&a, &b);
            let result = manhattan_distance(&a, &b);
            assert_close(result, scalar, 1e-5, "manhattan_distance", dim);
        }
    }

    #[test]
    fn hamming_distance_matches_scalar() {
        for dim in [1usize, 7, 8, 15, 16, 31, 32, 64, 128, 256] {
            let a: Vec<u8> = (0..dim).map(|i| (i * 37 + 13) as u8).collect();
            let b: Vec<u8> = (0..dim).map(|i| (i * 53 + 7) as u8).collect();
            let scalar = hamming_distance_scalar(&a, &b);
            let result = hamming_distance(&a, &b);
            assert_eq!(
                result, scalar,
                "hamming_distance dim={dim}: scalar={scalar}, simd={result}"
            );
        }
    }

    #[test]
    fn dot_product_zero_vectors() {
        let a = vec![0.0f32; 128];
        let b = vec![0.0f32; 128];
        assert_eq!(dot_product(&a, &b), 0.0);
    }

    // Zero-magnitude input is the documented 0.0-similarity special case.
    #[test]
    fn cosine_similarity_zero_vector() {
        let a = vec![0.0f32; 32];
        let b = vec![1.0f32; 32];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    #[test]
    fn cosine_similarity_identical() {
        let a: Vec<f32> = (0..64).map(|i| (i as f32) * 0.3 + 0.1).collect();
        let result = cosine_similarity(&a, &a);
        assert!(
            (result - 1.0).abs() < 1e-6,
            "identical vectors: sim={result}"
        );
    }

    #[test]
    fn cosine_similarity_opposite() {
        let a: Vec<f32> = (0..64).map(|i| (i as f32) * 0.3 + 0.1).collect();
        let b: Vec<f32> = a.iter().map(|x| -x).collect();
        let result = cosine_similarity(&a, &b);
        assert!(
            (result - (-1.0)).abs() < 1e-6,
            "opposite vectors: sim={result}"
        );
    }

    #[test]
    fn hamming_distance_known_pattern() {
        // 0xFF ^ 0x00 = 0xFF -> 8 bits per byte
        let a = vec![0xFF_u8; 32];
        let b = vec![0x00_u8; 32];
        assert_eq!(hamming_distance(&a, &b), 32 * 8);
    }

    #[test]
    fn hamming_distance_identical() {
        let a = vec![0xAB_u8; 64];
        assert_eq!(hamming_distance(&a, &a), 0);
    }

    #[test]
    fn euclidean_distance_sq_identical() {
        let a: Vec<f32> = (0..128).map(|i| i as f32).collect();
        assert_eq!(euclidean_distance_sq(&a, &a), 0.0);
    }

    #[test]
    fn manhattan_distance_identical() {
        let a: Vec<f32> = (0..128).map(|i| i as f32).collect();
        assert_eq!(manhattan_distance(&a, &a), 0.0);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn dot_product_dimension_mismatch_panics() {
        let a = vec![1.0f32; 10];
        let b = vec![1.0f32; 11];
        dot_product(&a, &b);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn euclidean_dimension_mismatch_panics() {
        let a = vec![1.0f32; 10];
        let b = vec![1.0f32; 11];
        euclidean_distance_sq(&a, &b);
    }

    // `compute` (unlike the free functions) must degrade to f32::MAX rather
    // than panic or propagate NaN.
    #[test]
    fn distance_metric_nan_returns_max() {
        let a = [1.0f32, f32::NAN, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        let d = DistanceMetric::EuclideanSq.compute(&a, &b);
        assert_eq!(d, f32::MAX);
    }

    #[test]
    fn distance_metric_mismatch_returns_max() {
        let a = [1.0f32, 2.0];
        let b = [1.0f32, 2.0, 3.0];
        let d = DistanceMetric::Cosine.compute(&a, &b);
        assert_eq!(d, f32::MAX);
    }

    #[test]
    fn write_read_f32_le_roundtrip() {
        let values: Vec<f32> = (0..100).map(|i| (i as f32) * 0.123 - 6.0).collect();
        let mut buf = vec![0u8; values.len() * 4];
        write_f32_le(&mut buf, &values);
        let decoded = read_f32_le(&buf);
        assert_eq!(decoded, values);
    }
}