Skip to main content

uni_query_functions/
similar_to.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! `similar_to()` expression function — unified similarity scoring.
5//!
6//! Dispatches to vector cosine similarity, BM25 full-text scoring, or
7//! hybrid fusion based on schema types. Returns a float in `[0, 1]`
8//! where higher means more similar.
9//!
10//! This is a **point computation** (score one bound node), not a search
11//! (top-K index scan). It works in WHERE, RETURN, WITH, ORDER BY, and
12//! Locy rule bodies.
13
14use anyhow::Result;
15use uni_common::Value;
16use uni_common::core::schema::{DistanceMetric, Schema};
17
18use crate::fusion;
19
/// Named error types for `similar_to()` validation failures.
///
/// Each variant's `#[error(...)]` string is the user-facing message; the
/// fields are interpolated into it.
#[derive(Debug, thiserror::Error)]
pub enum SimilarToError {
    /// The `label.property` pair has neither a vector index nor a full-text
    /// index (returned by `resolve_source_type`).
    #[error("similar_to: property '{label}.{property}' has no vector or full-text index")]
    NoIndex { label: String, property: String },

    /// A vector query was paired with an FTS-indexed source
    /// (returned by `validate_pair`). `source_index` identifies the source.
    #[error(
        "similar_to: source {source_index} is FTS-indexed but query is a vector (FTS cannot score against vectors)"
    )]
    TypeMismatch { source_index: usize },

    /// A string query was paired with a vector source whose index has no
    /// embedding config (returned by `validate_pair`).
    #[error(
        "similar_to: source {source_index} is a vector property but query is a string, and the index has no embedding config for auto-embedding"
    )]
    NoEmbeddingConfig { source_index: usize },

    /// The number of weights differs from the number of sources
    /// (returned by `validate_options`).
    #[error("similar_to: weights length ({weights_len}) != sources length ({sources_len})")]
    WeightsLengthMismatch {
        weights_len: usize,
        sources_len: usize,
    },

    /// The weights do not sum to 1.0 within the tolerance applied by
    /// `validate_options`; `sum` is the observed total.
    #[error("similar_to: weights must sum to 1.0 (got {sum})")]
    WeightsNotNormalized { sum: f32 },

    /// The `method` option was a string other than `'rrf'` or `'weighted'`
    /// (returned by `parse_options`).
    #[error("similar_to: unknown method '{method}', expected 'rrf' or 'weighted'")]
    InvalidMethod { method: String },

    /// Generic malformed-input error carrying a free-form message. Used by
    /// `parse_options` for badly typed options and by `value_to_vec` for
    /// non-numeric list elements.
    #[error("similar_to: {message}")]
    InvalidOption { message: String },

    /// The two vectors being compared have different lengths `a` and `b`.
    #[error("similar_to: vector dimensions mismatch: {a} vs {b}")]
    DimensionMismatch { a: usize, b: usize },

    /// A value that is neither a `Value::Vector` nor a `Value::List` was
    /// supplied where a vector was required; `actual` is its debug rendering.
    #[error("similar_to: expected vector or list of numbers, got {actual}")]
    InvalidVectorValue { actual: String },

    /// Weighted fusion was requested without a `weights` option
    /// (returned by `fuse_scores`).
    #[error("similar_to: weighted fusion requires 'weights' option")]
    WeightsRequired,

    /// Wrong number of call arguments.
    // NOTE(review): not constructed in this file — presumably raised by the
    // expression-evaluation caller; confirm.
    #[error("similar_to takes 2 or 3 arguments (sources, queries [, options]), got {count}")]
    InvalidArity { count: usize },

    /// No graph execution context was available.
    // NOTE(review): not constructed in this file — presumably raised by the
    // expression-evaluation caller; confirm.
    #[error("similar_to requires GraphExecutionContext")]
    NoGraphContext,
}
66
/// Fusion method for multi-source scoring.
///
/// Selected through the `method` option: `parse_options` maps the string
/// `"rrf"` to [`FusionMethod::Rrf`] and `"weighted"` to
/// [`FusionMethod::Weighted`].
#[derive(Debug, Clone, Default, PartialEq)]
pub enum FusionMethod {
    /// Reciprocal Rank Fusion (default). Falls back to equal-weight
    /// fusion in point-computation context.
    #[default]
    Rrf,
    /// Weighted sum of per-source scores. Requires the `weights` option;
    /// `fuse_scores` returns `SimilarToError::WeightsRequired` without it.
    Weighted,
}
77
/// Options for `similar_to()` controlling fusion and scoring behavior.
///
/// Built by `parse_options` from a user-supplied map and checked against the
/// number of sources by `validate_options`.
#[derive(Debug, Clone)]
pub struct SimilarToOptions {
    /// Fusion algorithm when multiple sources are present.
    pub method: FusionMethod,
    /// Per-source weights for weighted fusion. Must sum to 1.0.
    /// `None` when the `weights` option was not supplied.
    pub weights: Option<Vec<f32>>,
    /// RRF constant k (default 60).
    // NOTE(review): parsed and stored here, but `fuse_scores` in this file
    // does not pass it on to the RRF point-fusion helper — confirm whether
    // another caller consumes it.
    pub k: usize,
    /// BM25 saturation constant for FTS normalization (default 1.0).
    // NOTE(review): presumably forwarded as the `fts_k` argument of
    // `normalize_bm25` by the caller; that wiring is not in this file.
    pub fts_k: f32,
}
90
91impl Default for SimilarToOptions {
92    fn default() -> Self {
93        Self {
94            method: FusionMethod::Rrf,
95            weights: None,
96            k: 60,
97            fts_k: 1.0,
98        }
99    }
100}
101
102/// Parse options from a `Value::Map`.
103pub fn parse_options(value: &Value) -> Result<SimilarToOptions, SimilarToError> {
104    let map = match value {
105        Value::Map(m) => m,
106        Value::Null => return Ok(SimilarToOptions::default()),
107        _ => {
108            return Err(SimilarToError::InvalidOption {
109                message: format!("options must be a map, got {:?}", value),
110            });
111        }
112    };
113
114    let mut opts = SimilarToOptions::default();
115
116    if let Some(method_val) = map.get("method") {
117        match method_val.as_str() {
118            Some("rrf") => opts.method = FusionMethod::Rrf,
119            Some("weighted") => opts.method = FusionMethod::Weighted,
120            Some(other) => {
121                return Err(SimilarToError::InvalidMethod {
122                    method: other.to_string(),
123                });
124            }
125            None => {
126                return Err(SimilarToError::InvalidOption {
127                    message: "'method' must be a string ('rrf' or 'weighted')".to_string(),
128                });
129            }
130        }
131    }
132
133    if let Some(weights_val) = map.get("weights") {
134        match weights_val {
135            Value::List(list) => {
136                let weights: Result<Vec<f32>, SimilarToError> = list
137                    .iter()
138                    .map(|v| {
139                        v.as_f64()
140                            .map(|f| f as f32)
141                            .ok_or_else(|| SimilarToError::InvalidOption {
142                                message: "weight must be a number".to_string(),
143                            })
144                    })
145                    .collect();
146                opts.weights = Some(weights?);
147            }
148            _ => {
149                return Err(SimilarToError::InvalidOption {
150                    message: "'weights' must be a list of numbers".to_string(),
151                });
152            }
153        }
154    }
155
156    if let Some(k_val) = map.get("k") {
157        opts.k = k_val
158            .as_i64()
159            .ok_or_else(|| SimilarToError::InvalidOption {
160                message: "'k' must be an integer".to_string(),
161            })? as usize;
162    }
163
164    if let Some(fts_k_val) = map.get("fts_k") {
165        opts.fts_k = fts_k_val
166            .as_f64()
167            .ok_or_else(|| SimilarToError::InvalidOption {
168                message: "'fts_k' must be a number".to_string(),
169            })? as f32;
170    }
171
172    Ok(opts)
173}
174
/// What type of source a property represents.
///
/// Produced by `resolve_source_type` from the schema and consumed by
/// `validate_pair` to check query/source compatibility.
#[derive(Debug, Clone)]
pub enum SourceType {
    /// Vector property with a vector index.
    Vector {
        /// Distance metric copied from the vector index configuration.
        metric: DistanceMetric,
        /// Whether the vector index carries an embedding config. When
        /// `false`, `validate_pair` rejects string queries against this
        /// source.
        has_embedding_config: bool,
    },
    /// String property with a full-text index.
    Fts,
}
186
187/// Resolve the source type for a property given the schema.
188pub fn resolve_source_type(
189    schema: &Schema,
190    label: &str,
191    property: &str,
192) -> Result<SourceType, SimilarToError> {
193    // Check vector index first
194    if let Some(vec_config) = schema.vector_index_for_property(label, property) {
195        return Ok(SourceType::Vector {
196            metric: vec_config.metric.clone(),
197            has_embedding_config: vec_config.embedding_config.is_some(),
198        });
199    }
200
201    // Check full-text index
202    if schema
203        .fulltext_index_for_property(label, property)
204        .is_some()
205    {
206        return Ok(SourceType::Fts);
207    }
208
209    Err(SimilarToError::NoIndex {
210        label: label.to_string(),
211        property: property.to_string(),
212    })
213}
214
215/// Compute cosine similarity between two vectors, returning a score in [-1, 1].
216///
217/// Arithmetic is performed in f64 (see `cosine_similarity_inner`) and the
218/// clamped result is narrowed to f32.
219pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32, SimilarToError> {
220    cosine_similarity_inner(a, b).map(|sim| sim as f32)
221}
222
223/// Score two vectors using the specified distance metric, returning a similarity
224/// score where higher means more similar.
225///
226/// - **Cosine**: raw cosine similarity in \[-1, 1\] (delegates to [`cosine_similarity`]).
227/// - **L2**: `1 / (1 + d²)` where d² is squared Euclidean distance; range (0, 1\].
228/// - **Dot**: raw dot product (for normalised vectors equals cosine similarity).
229pub fn score_vectors(a: &[f32], b: &[f32], metric: &DistanceMetric) -> Result<f32, SimilarToError> {
230    if a.len() != b.len() {
231        return Err(SimilarToError::DimensionMismatch {
232            a: a.len(),
233            b: b.len(),
234        });
235    }
236    let distance = metric.compute_distance(a, b);
237    match metric {
238        DistanceMetric::Cosine => cosine_similarity(a, b),
239        // compute_distance returns -dot (LanceDB convention: lower = more similar).
240        // Negate to recover the actual dot product as a similarity score.
241        DistanceMetric::Dot => Ok(-distance),
242        // L2 and all other metrics (#[non_exhaustive]): normalise via calculate_score.
243        _ => Ok(calculate_score(distance, metric)),
244    }
245}
246
247/// Convert a raw distance value into a normalised similarity score.
248///
249/// The conversion depends on the distance metric:
250/// - **Cosine**: `(2 - d) / 2` (LanceDB cosine distance ranges 0..2)
251/// - **Dot**: pass-through (already a similarity measure)
252/// - **L2** and others: `1 / (1 + d)`
253pub fn calculate_score(distance: f32, metric: &DistanceMetric) -> f32 {
254    match metric {
255        DistanceMetric::Cosine => (2.0 - distance) / 2.0,
256        DistanceMetric::Dot => distance,
257        _ => 1.0 / (1.0 + distance),
258    }
259}
260
/// Normalize a BM25 score to [0, 1] using a saturation function.
///
/// `normalized = score / (score + fts_k)` where `fts_k` defaults to 1.0.
///
/// The result is guaranteed to lie in `[0, 1]` even for degenerate inputs:
///
/// - a non-positive or NaN `score` yields `0.0`;
/// - an infinite positive `score` yields `1.0` (the saturation limit);
/// - a NaN `fts_k` yields `0.0`;
/// - a negative `fts_k`, which could otherwise push the quotient below 0 or
///   above 1 (or divide by zero), has its result clamped into `[0, 1]`.
///
/// For a positive finite `score` and a non-negative `fts_k` the value is the
/// plain quotient, unchanged by the clamp.
pub fn normalize_bm25(score: f32, fts_k: f32) -> f32 {
    // NaN fails every ordered comparison, so it has to be tested explicitly.
    if score.is_nan() || score <= 0.0 {
        return 0.0;
    }
    // inf / (inf + k) would evaluate to NaN; use the limit value instead.
    if score.is_infinite() {
        return 1.0;
    }
    let normalized = score / (score + fts_k);
    if normalized.is_nan() {
        // Only reachable when `fts_k` itself is NaN.
        return 0.0;
    }
    normalized.clamp(0.0, 1.0)
}
270
271/// Compute pure vector-vs-vector similarity (no storage access needed).
272///
273/// Both values must be `Value::List` of numbers or `Value::Vector`.
274///
275/// Uses f64 arithmetic throughout when both inputs are `Value::List`, preserving
276/// full precision for property-based vectors (e.g. in TCK and unit tests). For
277/// `Value::Vector` (pre-indexed f32 data) it falls back to the f32 path.
278pub fn eval_similar_to_pure(v1: &Value, v2: &Value) -> Result<Value> {
279    // NULL propagates: a NULL operand yields NULL, not an error. `values_to_array`
280    // renders `Value::Null` as an Arrow null, matching 3VL semantics.
281    if matches!(v1, Value::Null) || matches!(v2, Value::Null) {
282        return Ok(Value::Null);
283    }
284    // Fast path: at least one input is a List — use f64 to avoid f32 precision loss.
285    let has_list = matches!(v1, Value::List(_)) || matches!(v2, Value::List(_));
286    let f64_vecs = has_list
287        .then(|| value_to_f64_vec(v1).ok().zip(value_to_f64_vec(v2).ok()))
288        .flatten();
289    if let Some((vec1, vec2)) = f64_vecs {
290        let sim = cosine_similarity_inner(&vec1, &vec2)?;
291        return Ok(Value::Float(sim));
292    }
293    // Fallback: f32 path for Value::Vector (indexed data already in f32).
294    let vec1 = value_to_f32_vec(v1)?;
295    let vec2 = value_to_f32_vec(v2)?;
296    let sim = cosine_similarity(&vec1, &vec2)?;
297    Ok(Value::Float(sim as f64))
298}
299
300/// Compute the raw cosine similarity (in f64) between two equal-length vectors,
301/// clamped to [-1, 1]. Returns 0 when either vector has zero magnitude.
302///
303/// Generic over the element type so both the f32 and f64 callers share one body;
304/// arithmetic is always performed in f64 to preserve precision.
305fn cosine_similarity_inner<T: Copy + Into<f64>>(a: &[T], b: &[T]) -> Result<f64, SimilarToError> {
306    if a.len() != b.len() {
307        return Err(SimilarToError::DimensionMismatch {
308            a: a.len(),
309            b: b.len(),
310        });
311    }
312    let mut dot = 0.0f64;
313    let mut mag1 = 0.0f64;
314    let mut mag2 = 0.0f64;
315    for (&x, &y) in a.iter().zip(b.iter()) {
316        let (x, y): (f64, f64) = (x.into(), y.into());
317        dot += x * y;
318        mag1 += x * x;
319        mag2 += y * y;
320    }
321    let mag1 = mag1.sqrt();
322    let mag2 = mag2.sqrt();
323    if mag1 == 0.0 || mag2 == 0.0 {
324        return Ok(0.0);
325    }
326    Ok((dot / (mag1 * mag2)).clamp(-1.0, 1.0))
327}
328
329/// Convert a Value to a numeric vector for vector operations.
330///
331/// Generic over the element type: `cast` maps the `f64` source value onto the
332/// target representation (identity for `f64`, narrowing for `f32`). List elements
333/// are read as full-precision `f64` before casting, while `Value::Vector` data is
334/// already `f32` and is widened to `f64` first.
335fn value_to_vec<T>(v: &Value, cast: impl Fn(f64) -> T) -> Result<Vec<T>, SimilarToError> {
336    match v {
337        Value::Vector(vec) => Ok(vec.iter().map(|&x| cast(x as f64)).collect()),
338        Value::List(list) => list
339            .iter()
340            .map(|v| {
341                v.as_f64()
342                    .map(&cast)
343                    .ok_or_else(|| SimilarToError::InvalidOption {
344                        message: "vector element must be a number".to_string(),
345                    })
346            })
347            .collect(),
348        _ => Err(SimilarToError::InvalidVectorValue {
349            actual: format!("{v:?}"),
350        }),
351    }
352}
353
354/// Convert a Value to a `Vec<f64>` for high-precision vector operations.
355fn value_to_f64_vec(v: &Value) -> Result<Vec<f64>, SimilarToError> {
356    value_to_vec(v, |f| f)
357}
358
359/// Convert a Value to a `Vec<f32>` for vector operations.
360pub fn value_to_f32_vec(v: &Value) -> Result<Vec<f32>, SimilarToError> {
361    value_to_vec(v, |f| f as f32)
362}
363
364/// Validate options against the number of sources.
365pub fn validate_options(opts: &SimilarToOptions, num_sources: usize) -> Result<(), SimilarToError> {
366    if let Some(ref weights) = opts.weights {
367        if weights.len() != num_sources {
368            return Err(SimilarToError::WeightsLengthMismatch {
369                weights_len: weights.len(),
370                sources_len: num_sources,
371            });
372        }
373        let sum: f32 = weights.iter().sum();
374        if (sum - 1.0).abs() > 0.01 {
375            return Err(SimilarToError::WeightsNotNormalized { sum });
376        }
377    }
378    Ok(())
379}
380
381/// Validate per-pair type compatibility.
382///
383/// Returns an error if a Vector query is paired with an FTS source,
384/// or a String query is paired with a Vector source that has no embedding config.
385pub fn validate_pair(
386    source_type: &SourceType,
387    query_is_vector: bool,
388    query_is_string: bool,
389    source_index: usize,
390) -> Result<(), SimilarToError> {
391    match source_type {
392        SourceType::Fts if query_is_vector => Err(SimilarToError::TypeMismatch { source_index }),
393        SourceType::Vector {
394            has_embedding_config: false,
395            ..
396        } if query_is_string => Err(SimilarToError::NoEmbeddingConfig { source_index }),
397        _ => Ok(()),
398    }
399}
400
401/// Fuse multiple per-source scores into a single score.
402pub fn fuse_scores(scores: &[f32], opts: &SimilarToOptions) -> Result<f32, SimilarToError> {
403    if scores.len() == 1 {
404        return Ok(scores[0]);
405    }
406
407    match opts.method {
408        FusionMethod::Weighted => {
409            let weights = opts
410                .weights
411                .as_ref()
412                .ok_or(SimilarToError::WeightsRequired)?;
413            Ok(fusion::fuse_weighted_multi(scores, weights))
414        }
415        FusionMethod::Rrf => {
416            // In point-computation context, RRF degenerates to equal-weight fusion.
417            // The caller (similar_to_expr.rs) already emits QueryWarning::RrfPointContext
418            // unconditionally when method == Rrf && num_sources > 1.
419            let (score, _) = fusion::fuse_rrf_point(scores);
420            Ok(score)
421        }
422    }
423}
424
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Asserts `|actual - expected| < tol`, using whatever float type the
    /// operands already have.
    macro_rules! assert_close {
        ($actual:expr, $expected:expr, $tol:expr) => {
            assert!(($actual - $expected).abs() < $tol)
        };
    }

    /// Builds a `Value::Map` from `(key, value)` pairs.
    fn options_map(entries: Vec<(&str, Value)>) -> Value {
        let map: HashMap<String, Value> = entries
            .into_iter()
            .map(|(key, value)| (key.to_string(), value))
            .collect();
        Value::Map(map)
    }

    /// Unwraps a `Value::Float`, panicking on any other variant.
    fn expect_float(value: Value) -> f64 {
        match value {
            Value::Float(f) => f,
            _ => panic!("Expected Float"),
        }
    }

    /// A vector source using the cosine metric.
    fn cosine_source(has_embedding_config: bool) -> SourceType {
        SourceType::Vector {
            metric: DistanceMetric::Cosine,
            has_embedding_config,
        }
    }

    /// Default options carrying only the given weights.
    fn options_with_weights(weights: Vec<f32>) -> SimilarToOptions {
        SimilarToOptions {
            weights: Some(weights),
            ..Default::default()
        }
    }

    // -----------------------------------------------------------------------
    // parse_options() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_parse_options_default() {
        let opts = parse_options(&Value::Null).unwrap();
        assert_eq!(opts.method, FusionMethod::Rrf);
        assert_eq!(opts.k, 60);
        assert_close!(opts.fts_k, 1.0, 1e-6);
        assert!(opts.weights.is_none());
    }

    #[test]
    fn test_parse_options_weighted() {
        let raw = options_map(vec![
            ("method", Value::String("weighted".to_string())),
            (
                "weights",
                Value::List(vec![Value::Float(0.7), Value::Float(0.3)]),
            ),
        ]);
        let opts = parse_options(&raw).unwrap();
        assert_eq!(opts.method, FusionMethod::Weighted);
        let weights = opts.weights.unwrap();
        assert_close!(weights[0], 0.7, 1e-6);
        assert_close!(weights[1], 0.3, 1e-6);
    }

    #[test]
    fn test_parse_options_rrf_with_k() {
        let raw = options_map(vec![
            ("method", Value::String("rrf".to_string())),
            ("k", Value::Int(30)),
        ]);
        let opts = parse_options(&raw).unwrap();
        assert_eq!(opts.method, FusionMethod::Rrf);
        assert_eq!(opts.k, 30);
    }

    #[test]
    fn test_parse_options_fts_k() {
        let raw = options_map(vec![("fts_k", Value::Float(2.0))]);
        let opts = parse_options(&raw).unwrap();
        assert_close!(opts.fts_k, 2.0, 1e-6);
    }

    #[test]
    fn test_parse_options_invalid_method() {
        let raw = options_map(vec![("method", Value::String("invalid".to_string()))]);
        assert!(parse_options(&raw).is_err());
    }

    // -----------------------------------------------------------------------
    // cosine_similarity() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(&v, &v).unwrap();
        assert_close!(sim, 1.0, 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b).unwrap();
        assert_close!(sim, 0.0, 1e-6);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = cosine_similarity(&a, &b).unwrap();
        assert_close!(sim, -1.0, 1e-6);
    }

    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!(cosine_similarity(&a, &b).is_err());
    }

    // -----------------------------------------------------------------------
    // normalize_bm25() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_normalize_bm25() {
        assert_close!(normalize_bm25(0.0, 1.0), 0.0, 1e-6);
        assert_close!(normalize_bm25(1.0, 1.0), 0.5, 1e-6);
        assert_close!(normalize_bm25(9.0, 1.0), 0.9, 1e-6);
        assert_close!(normalize_bm25(99.0, 1.0), 0.99, 1e-4);
    }

    #[test]
    fn test_normalize_bm25_custom_k() {
        assert_close!(normalize_bm25(2.0, 2.0), 0.5, 1e-6);
    }

    // -----------------------------------------------------------------------
    // eval_similar_to_pure() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_eval_similar_to_pure() {
        let v1 = Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        let v2 = Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        let f = expect_float(eval_similar_to_pure(&v1, &v2).unwrap());
        assert_close!(f, 1.0, 1e-6);
    }

    #[test]
    fn test_eval_similar_to_pure_vector_type() {
        let v1 = Value::Vector(vec![1.0, 0.0]);
        let v2 = Value::Vector(vec![0.0, 1.0]);
        let f = expect_float(eval_similar_to_pure(&v1, &v2).unwrap());
        assert_close!(f, 0.0, 1e-6);
    }

    // -----------------------------------------------------------------------
    // validate_options() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_validate_options_weights_length() {
        let opts = options_with_weights(vec![0.5]);
        assert!(validate_options(&opts, 2).is_err());
    }

    #[test]
    fn test_validate_options_weights_sum() {
        let opts = options_with_weights(vec![0.5, 0.3]);
        assert!(validate_options(&opts, 2).is_err());
    }

    #[test]
    fn test_validate_options_ok() {
        let opts = options_with_weights(vec![0.7, 0.3]);
        assert!(validate_options(&opts, 2).is_ok());
    }

    // -----------------------------------------------------------------------
    // validate_pair() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_validate_pair_fts_vector_query() {
        assert!(validate_pair(&SourceType::Fts, true, false, 0).is_err());
    }

    #[test]
    fn test_validate_pair_vector_string_no_embed() {
        let st = cosine_source(false);
        assert!(validate_pair(&st, false, true, 0).is_err());
    }

    #[test]
    fn test_validate_pair_vector_string_with_embed() {
        let st = cosine_source(true);
        assert!(validate_pair(&st, false, true, 0).is_ok());
    }

    #[test]
    fn test_validate_pair_vector_vector() {
        let st = cosine_source(false);
        assert!(validate_pair(&st, true, false, 0).is_ok());
    }

    #[test]
    fn test_validate_pair_fts_string() {
        assert!(validate_pair(&SourceType::Fts, false, true, 0).is_ok());
    }

    // -----------------------------------------------------------------------
    // fuse_scores() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_fuse_scores_single() {
        let opts = SimilarToOptions::default();
        let score = fuse_scores(&[0.8], &opts).unwrap();
        assert_close!(score, 0.8, 1e-6);
    }

    #[test]
    fn test_fuse_scores_weighted() {
        let opts = SimilarToOptions {
            method: FusionMethod::Weighted,
            weights: Some(vec![0.7, 0.3]),
            ..Default::default()
        };
        let score = fuse_scores(&[0.8, 0.6], &opts).unwrap();
        assert_close!(score, 0.74, 1e-6);
    }

    #[test]
    fn test_fuse_scores_rrf_fallback() {
        let opts = SimilarToOptions::default();
        let score = fuse_scores(&[0.8, 0.6], &opts).unwrap();
        // RRF in point context falls back to equal weights: (0.8 + 0.6) / 2 = 0.7
        assert_close!(score, 0.7, 1e-6);
    }

    // -----------------------------------------------------------------------
    // score_vectors() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_score_vectors_cosine_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::Cosine).unwrap();
        assert_close!(score, 1.0, 1e-6);
    }

    #[test]
    fn test_score_vectors_cosine_matches_raw() {
        // score_vectors with Cosine delegates to cosine_similarity
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.8, 0.6, 0.0];
        let raw = cosine_similarity(&a, &b).unwrap();
        let scored = score_vectors(&a, &b, &DistanceMetric::Cosine).unwrap();
        assert_close!(raw, scored, 1e-6);
    }

    #[test]
    fn test_score_vectors_l2() {
        // [1,0,0] vs [0,1,0]: L2 squared distance = 2, score = 1/(1+2) ≈ 0.333
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let score = score_vectors(&a, &b, &DistanceMetric::L2).unwrap();
        assert_close!(score, 1.0 / 3.0, 1e-5);
    }

    #[test]
    fn test_score_vectors_l2_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::L2).unwrap();
        assert_close!(score, 1.0, 1e-6);
    }

    #[test]
    fn test_score_vectors_dot() {
        // [1,0,0] dot [0.8,0.6,0] = 0.8
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.8, 0.6, 0.0];
        let score = score_vectors(&a, &b, &DistanceMetric::Dot).unwrap();
        assert_close!(score, 0.8, 1e-6);
    }

    #[test]
    fn test_score_vectors_dot_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::Dot).unwrap();
        assert_close!(score, 1.0, 1e-6);
    }

    #[test]
    fn test_score_vectors_dimension_mismatch() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!(score_vectors(&a, &b, &DistanceMetric::Cosine).is_err());
    }
}