Skip to main content

uni_query_functions/
similar_to.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! `similar_to()` expression function — unified similarity scoring.
5//!
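//! Dispatches to vector similarity (cosine, L2 or dot product, per the index
//! metric), BM25 full-text scoring, or hybrid fusion based on schema types.
//! Returns a float where higher means more similar: normalised BM25 and L2
//! scores lie in `[0, 1]`, raw cosine similarity can be as low as `-1`, and
//! dot-product scores are unbounded.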
9//!
10//! This is a **point computation** (score one bound node), not a search
11//! (top-K index scan). It works in WHERE, RETURN, WITH, ORDER BY, and
12//! Locy rule bodies.
13
14use anyhow::Result;
15use uni_common::Value;
16use uni_common::core::schema::{DistanceMetric, Schema};
17
18use crate::fusion;
19
20/// Named error types for `similar_to()` validation failures.
21#[derive(Debug, thiserror::Error)]
22pub enum SimilarToError {
23    #[error("similar_to: property '{label}.{property}' has no vector or full-text index")]
24    NoIndex { label: String, property: String },
25
26    #[error(
27        "similar_to: source {source_index} is FTS-indexed but query is a vector (FTS cannot score against vectors)"
28    )]
29    TypeMismatch { source_index: usize },
30
31    #[error(
32        "similar_to: source {source_index} is a vector property but query is a string, and the index has no embedding config for auto-embedding"
33    )]
34    NoEmbeddingConfig { source_index: usize },
35
36    #[error("similar_to: weights length ({weights_len}) != sources length ({sources_len})")]
37    WeightsLengthMismatch {
38        weights_len: usize,
39        sources_len: usize,
40    },
41
42    #[error("similar_to: weights must sum to 1.0 (got {sum})")]
43    WeightsNotNormalized { sum: f32 },
44
45    #[error("similar_to: unknown method '{method}', expected 'rrf' or 'weighted'")]
46    InvalidMethod { method: String },
47
48    #[error("similar_to: {message}")]
49    InvalidOption { message: String },
50
51    #[error("similar_to: vector dimensions mismatch: {a} vs {b}")]
52    DimensionMismatch { a: usize, b: usize },
53
54    #[error("similar_to: expected vector or list of numbers, got {actual}")]
55    InvalidVectorValue { actual: String },
56
57    #[error("similar_to: weighted fusion requires 'weights' option")]
58    WeightsRequired,
59
60    #[error("similar_to takes 2 or 3 arguments (sources, queries [, options]), got {count}")]
61    InvalidArity { count: usize },
62
63    #[error("similar_to requires GraphExecutionContext")]
64    NoGraphContext,
65}
66
/// Fusion method used to combine per-source scores when `similar_to()` is
/// given more than one source.
///
/// The enum carries no data, so it is `Copy`/`Eq`/`Hash`: it can be read out
/// of a borrowed [`SimilarToOptions`] by value and used as a map/set key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum FusionMethod {
    /// Reciprocal Rank Fusion (default). In a point-computation context there
    /// is no ranking to fuse, so it falls back to equal-weight fusion.
    #[default]
    Rrf,
    /// Weighted sum of per-source scores (requires the `weights` option).
    Weighted,
}
77
78/// Options for `similar_to()` controlling fusion and scoring behavior.
79#[derive(Debug, Clone)]
80pub struct SimilarToOptions {
81    /// Fusion algorithm when multiple sources are present.
82    pub method: FusionMethod,
83    /// Per-source weights for weighted fusion. Must sum to 1.0.
84    pub weights: Option<Vec<f32>>,
85    /// RRF constant k (default 60).
86    pub k: usize,
87    /// BM25 saturation constant for FTS normalization (default 1.0).
88    pub fts_k: f32,
89}
90
91impl Default for SimilarToOptions {
92    fn default() -> Self {
93        Self {
94            method: FusionMethod::Rrf,
95            weights: None,
96            k: 60,
97            fts_k: 1.0,
98        }
99    }
100}
101
102/// Parse options from a `Value::Map`.
103pub fn parse_options(value: &Value) -> Result<SimilarToOptions, SimilarToError> {
104    let map = match value {
105        Value::Map(m) => m,
106        Value::Null => return Ok(SimilarToOptions::default()),
107        _ => {
108            return Err(SimilarToError::InvalidOption {
109                message: format!("options must be a map, got {:?}", value),
110            });
111        }
112    };
113
114    let mut opts = SimilarToOptions::default();
115
116    if let Some(method_val) = map.get("method") {
117        match method_val.as_str() {
118            Some("rrf") => opts.method = FusionMethod::Rrf,
119            Some("weighted") => opts.method = FusionMethod::Weighted,
120            Some(other) => {
121                return Err(SimilarToError::InvalidMethod {
122                    method: other.to_string(),
123                });
124            }
125            None => {
126                return Err(SimilarToError::InvalidOption {
127                    message: "'method' must be a string ('rrf' or 'weighted')".to_string(),
128                });
129            }
130        }
131    }
132
133    if let Some(weights_val) = map.get("weights") {
134        match weights_val {
135            Value::List(list) => {
136                let weights: Result<Vec<f32>, SimilarToError> = list
137                    .iter()
138                    .map(|v| {
139                        v.as_f64()
140                            .map(|f| f as f32)
141                            .ok_or_else(|| SimilarToError::InvalidOption {
142                                message: "weight must be a number".to_string(),
143                            })
144                    })
145                    .collect();
146                opts.weights = Some(weights?);
147            }
148            _ => {
149                return Err(SimilarToError::InvalidOption {
150                    message: "'weights' must be a list of numbers".to_string(),
151                });
152            }
153        }
154    }
155
156    if let Some(k_val) = map.get("k") {
157        opts.k = k_val
158            .as_i64()
159            .ok_or_else(|| SimilarToError::InvalidOption {
160                message: "'k' must be an integer".to_string(),
161            })? as usize;
162    }
163
164    if let Some(fts_k_val) = map.get("fts_k") {
165        opts.fts_k = fts_k_val
166            .as_f64()
167            .ok_or_else(|| SimilarToError::InvalidOption {
168                message: "'fts_k' must be a number".to_string(),
169            })? as f32;
170    }
171
172    Ok(opts)
173}
174
175/// What type of source a property represents.
176#[derive(Debug, Clone)]
177pub enum SourceType {
178    /// Vector property with a vector index.
179    Vector {
180        metric: DistanceMetric,
181        has_embedding_config: bool,
182    },
183    /// String property with a full-text index.
184    Fts,
185}
186
187/// Resolve the source type for a property given the schema.
188pub fn resolve_source_type(
189    schema: &Schema,
190    label: &str,
191    property: &str,
192) -> Result<SourceType, SimilarToError> {
193    // Check vector index first
194    if let Some(vec_config) = schema.vector_index_for_property(label, property) {
195        return Ok(SourceType::Vector {
196            metric: vec_config.metric.clone(),
197            has_embedding_config: vec_config.embedding_config.is_some(),
198        });
199    }
200
201    // Check full-text index
202    if schema
203        .fulltext_index_for_property(label, property)
204        .is_some()
205    {
206        return Ok(SourceType::Fts);
207    }
208
209    Err(SimilarToError::NoIndex {
210        label: label.to_string(),
211        property: property.to_string(),
212    })
213}
214
215/// Compute cosine similarity between two vectors, returning a score in [-1, 1].
216///
217/// Arithmetic is performed in f64 (see `cosine_similarity_inner`) and the
218/// clamped result is narrowed to f32.
219pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32, SimilarToError> {
220    cosine_similarity_inner(a, b).map(|sim| sim as f32)
221}
222
223/// Score two vectors using the specified distance metric, returning a similarity
224/// score where higher means more similar.
225///
226/// - **Cosine**: raw cosine similarity in \[-1, 1\] (delegates to [`cosine_similarity`]).
227/// - **L2**: `1 / (1 + d²)` where d² is squared Euclidean distance; range (0, 1\].
228/// - **Dot**: raw dot product (for normalised vectors equals cosine similarity).
229pub fn score_vectors(a: &[f32], b: &[f32], metric: &DistanceMetric) -> Result<f32, SimilarToError> {
230    if a.len() != b.len() {
231        return Err(SimilarToError::DimensionMismatch {
232            a: a.len(),
233            b: b.len(),
234        });
235    }
236    let distance = metric.compute_distance(a, b);
237    match metric {
238        DistanceMetric::Cosine => cosine_similarity(a, b),
239        // compute_distance returns -dot (LanceDB convention: lower = more similar).
240        // Negate to recover the actual dot product as a similarity score.
241        DistanceMetric::Dot => Ok(-distance),
242        // L2 and all other metrics (#[non_exhaustive]): normalise via calculate_score.
243        _ => Ok(calculate_score(distance, metric)),
244    }
245}
246
247/// Convert a raw distance value into a normalised similarity score.
248///
249/// The conversion depends on the distance metric:
250/// - **Cosine**: `(2 - d) / 2` (LanceDB cosine distance ranges 0..2)
251/// - **Dot**: pass-through (already a similarity measure)
252/// - **L2** and others: `1 / (1 + d)`
253pub fn calculate_score(distance: f32, metric: &DistanceMetric) -> f32 {
254    match metric {
255        DistanceMetric::Cosine => (2.0 - distance) / 2.0,
256        DistanceMetric::Dot => distance,
257        _ => 1.0 / (1.0 + distance),
258    }
259}
260
/// Normalise a raw BM25 score into `[0, 1]` with a saturation curve.
///
/// `normalized = score / (score + fts_k)`, so a score equal to `fts_k`
/// (default 1.0) maps to 0.5. `fts_k` is expected to be positive and finite.
///
/// Edge cases:
/// - non-positive and NaN scores map to `0.0`;
/// - an infinite score saturates to `1.0` (the plain formula would compute
///   `inf / inf = NaN`).
pub fn normalize_bm25(score: f32, fts_k: f32) -> f32 {
    // `score <= 0.0` alone is false for NaN, which would otherwise propagate.
    if score.is_nan() || score <= 0.0 {
        return 0.0;
    }
    if score.is_infinite() {
        return 1.0;
    }
    score / (score + fts_k)
}
270
271/// Compute pure vector-vs-vector similarity (no storage access needed).
272///
273/// Both values must be `Value::List` of numbers or `Value::Vector`.
274///
275/// Uses f64 arithmetic throughout when both inputs are `Value::List`, preserving
276/// full precision for property-based vectors (e.g. in TCK and unit tests). For
277/// `Value::Vector` (pre-indexed f32 data) it falls back to the f32 path.
278pub fn eval_similar_to_pure(v1: &Value, v2: &Value) -> Result<Value> {
279    // Fast path: at least one input is a List — use f64 to avoid f32 precision loss.
280    let has_list = matches!(v1, Value::List(_)) || matches!(v2, Value::List(_));
281    let f64_vecs = has_list
282        .then(|| value_to_f64_vec(v1).ok().zip(value_to_f64_vec(v2).ok()))
283        .flatten();
284    if let Some((vec1, vec2)) = f64_vecs {
285        let sim = cosine_similarity_inner(&vec1, &vec2)?;
286        return Ok(Value::Float(sim));
287    }
288    // Fallback: f32 path for Value::Vector (indexed data already in f32).
289    let vec1 = value_to_f32_vec(v1)?;
290    let vec2 = value_to_f32_vec(v2)?;
291    let sim = cosine_similarity(&vec1, &vec2)?;
292    Ok(Value::Float(sim as f64))
293}
294
295/// Compute the raw cosine similarity (in f64) between two equal-length vectors,
296/// clamped to [-1, 1]. Returns 0 when either vector has zero magnitude.
297///
298/// Generic over the element type so both the f32 and f64 callers share one body;
299/// arithmetic is always performed in f64 to preserve precision.
300fn cosine_similarity_inner<T: Copy + Into<f64>>(a: &[T], b: &[T]) -> Result<f64, SimilarToError> {
301    if a.len() != b.len() {
302        return Err(SimilarToError::DimensionMismatch {
303            a: a.len(),
304            b: b.len(),
305        });
306    }
307    let mut dot = 0.0f64;
308    let mut mag1 = 0.0f64;
309    let mut mag2 = 0.0f64;
310    for (&x, &y) in a.iter().zip(b.iter()) {
311        let (x, y): (f64, f64) = (x.into(), y.into());
312        dot += x * y;
313        mag1 += x * x;
314        mag2 += y * y;
315    }
316    let mag1 = mag1.sqrt();
317    let mag2 = mag2.sqrt();
318    if mag1 == 0.0 || mag2 == 0.0 {
319        return Ok(0.0);
320    }
321    Ok((dot / (mag1 * mag2)).clamp(-1.0, 1.0))
322}
323
324/// Convert a Value to a numeric vector for vector operations.
325///
326/// Generic over the element type: `cast` maps the `f64` source value onto the
327/// target representation (identity for `f64`, narrowing for `f32`). List elements
328/// are read as full-precision `f64` before casting, while `Value::Vector` data is
329/// already `f32` and is widened to `f64` first.
330fn value_to_vec<T>(v: &Value, cast: impl Fn(f64) -> T) -> Result<Vec<T>, SimilarToError> {
331    match v {
332        Value::Vector(vec) => Ok(vec.iter().map(|&x| cast(x as f64)).collect()),
333        Value::List(list) => list
334            .iter()
335            .map(|v| {
336                v.as_f64()
337                    .map(&cast)
338                    .ok_or_else(|| SimilarToError::InvalidOption {
339                        message: "vector element must be a number".to_string(),
340                    })
341            })
342            .collect(),
343        _ => Err(SimilarToError::InvalidVectorValue {
344            actual: format!("{v:?}"),
345        }),
346    }
347}
348
349/// Convert a Value to a `Vec<f64>` for high-precision vector operations.
350fn value_to_f64_vec(v: &Value) -> Result<Vec<f64>, SimilarToError> {
351    value_to_vec(v, |f| f)
352}
353
354/// Convert a Value to a `Vec<f32>` for vector operations.
355pub fn value_to_f32_vec(v: &Value) -> Result<Vec<f32>, SimilarToError> {
356    value_to_vec(v, |f| f as f32)
357}
358
359/// Validate options against the number of sources.
360pub fn validate_options(opts: &SimilarToOptions, num_sources: usize) -> Result<(), SimilarToError> {
361    if let Some(ref weights) = opts.weights {
362        if weights.len() != num_sources {
363            return Err(SimilarToError::WeightsLengthMismatch {
364                weights_len: weights.len(),
365                sources_len: num_sources,
366            });
367        }
368        let sum: f32 = weights.iter().sum();
369        if (sum - 1.0).abs() > 0.01 {
370            return Err(SimilarToError::WeightsNotNormalized { sum });
371        }
372    }
373    Ok(())
374}
375
376/// Validate per-pair type compatibility.
377///
378/// Returns an error if a Vector query is paired with an FTS source,
379/// or a String query is paired with a Vector source that has no embedding config.
380pub fn validate_pair(
381    source_type: &SourceType,
382    query_is_vector: bool,
383    query_is_string: bool,
384    source_index: usize,
385) -> Result<(), SimilarToError> {
386    match source_type {
387        SourceType::Fts if query_is_vector => Err(SimilarToError::TypeMismatch { source_index }),
388        SourceType::Vector {
389            has_embedding_config: false,
390            ..
391        } if query_is_string => Err(SimilarToError::NoEmbeddingConfig { source_index }),
392        _ => Ok(()),
393    }
394}
395
396/// Fuse multiple per-source scores into a single score.
397pub fn fuse_scores(scores: &[f32], opts: &SimilarToOptions) -> Result<f32, SimilarToError> {
398    if scores.len() == 1 {
399        return Ok(scores[0]);
400    }
401
402    match opts.method {
403        FusionMethod::Weighted => {
404            let weights = opts
405                .weights
406                .as_ref()
407                .ok_or(SimilarToError::WeightsRequired)?;
408            Ok(fusion::fuse_weighted_multi(scores, weights))
409        }
410        FusionMethod::Rrf => {
411            // In point-computation context, RRF degenerates to equal-weight fusion.
412            // The caller (similar_to_expr.rs) already emits QueryWarning::RrfPointContext
413            // unconditionally when method == Rrf && num_sources > 1.
414            let (score, _) = fusion::fuse_rrf_point(scores);
415            Ok(score)
416        }
417    }
418}
419
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Panic unless `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: impl Into<f64>, expected: f64, tol: f64) {
        let actual: f64 = actual.into();
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected} (±{tol}), got {actual}"
        );
    }

    /// Wrap `(key, value)` pairs into a `Value::Map` options argument.
    fn options(entries: Vec<(&str, Value)>) -> Value {
        let map: HashMap<String, Value> = entries
            .into_iter()
            .map(|(key, val)| (key.to_string(), val))
            .collect();
        Value::Map(map)
    }

    /// Unwrap a `Value::Float`, panicking on any other variant.
    fn float_of(value: Value) -> f64 {
        match value {
            Value::Float(f) => f,
            other => panic!("Expected Float, got {other:?}"),
        }
    }

    /// A cosine vector source with or without an embedding config.
    fn cosine_source(has_embedding_config: bool) -> SourceType {
        SourceType::Vector {
            metric: DistanceMetric::Cosine,
            has_embedding_config,
        }
    }

    /// Default options with the given explicit weights.
    fn with_weights(weights: &[f32]) -> SimilarToOptions {
        SimilarToOptions {
            weights: Some(weights.to_vec()),
            ..Default::default()
        }
    }

    // --- parse_options -----------------------------------------------------

    #[test]
    fn test_parse_options_default() {
        let parsed = parse_options(&Value::Null).unwrap();
        assert_eq!(parsed.method, FusionMethod::Rrf);
        assert_eq!(parsed.k, 60);
        assert_close(parsed.fts_k, 1.0, 1e-6);
        assert!(parsed.weights.is_none());
    }

    #[test]
    fn test_parse_options_weighted() {
        let parsed = parse_options(&options(vec![
            ("method", Value::String("weighted".to_string())),
            (
                "weights",
                Value::List(vec![Value::Float(0.7), Value::Float(0.3)]),
            ),
        ]))
        .unwrap();
        assert_eq!(parsed.method, FusionMethod::Weighted);
        let weights = parsed.weights.expect("weights should be parsed");
        assert_close(weights[0], 0.7, 1e-6);
        assert_close(weights[1], 0.3, 1e-6);
    }

    #[test]
    fn test_parse_options_rrf_with_k() {
        let parsed = parse_options(&options(vec![
            ("method", Value::String("rrf".to_string())),
            ("k", Value::Int(30)),
        ]))
        .unwrap();
        assert_eq!(parsed.method, FusionMethod::Rrf);
        assert_eq!(parsed.k, 30);
    }

    #[test]
    fn test_parse_options_fts_k() {
        let parsed = parse_options(&options(vec![("fts_k", Value::Float(2.0))])).unwrap();
        assert_close(parsed.fts_k, 2.0, 1e-6);
    }

    #[test]
    fn test_parse_options_invalid_method() {
        let bad = options(vec![("method", Value::String("invalid".to_string()))]);
        assert!(parse_options(&bad).is_err());
    }

    // --- cosine_similarity -------------------------------------------------

    #[test]
    fn test_cosine_similarity_identical() {
        let v = [1.0f32, 0.0, 0.0];
        assert_close(cosine_similarity(&v, &v).unwrap(), 1.0, 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let sim = cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).unwrap();
        assert_close(sim, 0.0, 1e-6);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let sim = cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]).unwrap();
        assert_close(sim, -1.0, 1e-6);
    }

    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        assert!(cosine_similarity(&[1.0, 0.0], &[1.0, 0.0, 0.0]).is_err());
    }

    // --- normalize_bm25 ----------------------------------------------------

    #[test]
    fn test_normalize_bm25() {
        assert_close(normalize_bm25(0.0, 1.0), 0.0, 1e-6);
        assert_close(normalize_bm25(1.0, 1.0), 0.5, 1e-6);
        assert_close(normalize_bm25(9.0, 1.0), 0.9, 1e-6);
        assert_close(normalize_bm25(99.0, 1.0), 0.99, 1e-4);
    }

    #[test]
    fn test_normalize_bm25_custom_k() {
        assert_close(normalize_bm25(2.0, 2.0), 0.5, 1e-6);
    }

    // --- eval_similar_to_pure ----------------------------------------------

    #[test]
    fn test_eval_similar_to_pure() {
        let v1 = Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        let v2 = Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        assert_close(float_of(eval_similar_to_pure(&v1, &v2).unwrap()), 1.0, 1e-6);
    }

    #[test]
    fn test_eval_similar_to_pure_vector_type() {
        let v1 = Value::Vector(vec![1.0, 0.0]);
        let v2 = Value::Vector(vec![0.0, 1.0]);
        assert_close(float_of(eval_similar_to_pure(&v1, &v2).unwrap()), 0.0, 1e-6);
    }

    // --- validate_options --------------------------------------------------

    #[test]
    fn test_validate_options_weights_length() {
        assert!(validate_options(&with_weights(&[0.5]), 2).is_err());
    }

    #[test]
    fn test_validate_options_weights_sum() {
        assert!(validate_options(&with_weights(&[0.5, 0.3]), 2).is_err());
    }

    #[test]
    fn test_validate_options_ok() {
        assert!(validate_options(&with_weights(&[0.7, 0.3]), 2).is_ok());
    }

    // --- validate_pair -----------------------------------------------------

    #[test]
    fn test_validate_pair_fts_vector_query() {
        assert!(validate_pair(&SourceType::Fts, true, false, 0).is_err());
    }

    #[test]
    fn test_validate_pair_vector_string_no_embed() {
        assert!(validate_pair(&cosine_source(false), false, true, 0).is_err());
    }

    #[test]
    fn test_validate_pair_vector_string_with_embed() {
        assert!(validate_pair(&cosine_source(true), false, true, 0).is_ok());
    }

    #[test]
    fn test_validate_pair_vector_vector() {
        assert!(validate_pair(&cosine_source(false), true, false, 0).is_ok());
    }

    #[test]
    fn test_validate_pair_fts_string() {
        assert!(validate_pair(&SourceType::Fts, false, true, 0).is_ok());
    }

    // --- fuse_scores -------------------------------------------------------

    #[test]
    fn test_fuse_scores_single() {
        let score = fuse_scores(&[0.8], &SimilarToOptions::default()).unwrap();
        assert_close(score, 0.8, 1e-6);
    }

    #[test]
    fn test_fuse_scores_weighted() {
        let opts = SimilarToOptions {
            method: FusionMethod::Weighted,
            weights: Some(vec![0.7, 0.3]),
            ..Default::default()
        };
        assert_close(fuse_scores(&[0.8, 0.6], &opts).unwrap(), 0.74, 1e-6);
    }

    #[test]
    fn test_fuse_scores_rrf_fallback() {
        // RRF in point context falls back to equal weights: (0.8 + 0.6) / 2 = 0.7
        let score = fuse_scores(&[0.8, 0.6], &SimilarToOptions::default()).unwrap();
        assert_close(score, 0.7, 1e-6);
    }

    // --- score_vectors -----------------------------------------------------

    #[test]
    fn test_score_vectors_cosine_identical() {
        let v = [1.0f32, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::Cosine).unwrap();
        assert_close(score, 1.0, 1e-6);
    }

    #[test]
    fn test_score_vectors_cosine_matches_raw() {
        // score_vectors with Cosine delegates to cosine_similarity.
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.8f32, 0.6, 0.0];
        let raw = cosine_similarity(&a, &b).unwrap();
        let scored = score_vectors(&a, &b, &DistanceMetric::Cosine).unwrap();
        assert_close(scored, f64::from(raw), 1e-6);
    }

    #[test]
    fn test_score_vectors_l2() {
        // [1,0,0] vs [0,1,0]: squared L2 distance = 2, score = 1/(1+2) ≈ 0.333
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0];
        let score = score_vectors(&a, &b, &DistanceMetric::L2).unwrap();
        assert_close(score, 1.0 / 3.0, 1e-5);
    }

    #[test]
    fn test_score_vectors_l2_identical() {
        let v = [1.0f32, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::L2).unwrap();
        assert_close(score, 1.0, 1e-6);
    }

    #[test]
    fn test_score_vectors_dot() {
        // [1,0,0] · [0.8,0.6,0] = 0.8
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.8f32, 0.6, 0.0];
        let score = score_vectors(&a, &b, &DistanceMetric::Dot).unwrap();
        assert_close(score, 0.8, 1e-6);
    }

    #[test]
    fn test_score_vectors_dot_identical() {
        let v = [1.0f32, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::Dot).unwrap();
        assert_close(score, 1.0, 1e-6);
    }

    #[test]
    fn test_score_vectors_dimension_mismatch() {
        let a = [1.0f32, 0.0];
        let b = [1.0f32, 0.0, 0.0];
        assert!(score_vectors(&a, &b, &DistanceMetric::Cosine).is_err());
    }
}