Skip to main content

uni_query/query/
similar_to.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! `similar_to()` expression function — unified similarity scoring.
5//!
6//! Dispatches to vector cosine similarity, BM25 full-text scoring, or
7//! hybrid fusion based on schema types. Returns a float in `[0, 1]`
8//! where higher means more similar.
9//!
10//! This is a **point computation** (score one bound node), not a search
11//! (top-K index scan). It works in WHERE, RETURN, WITH, ORDER BY, and
12//! Locy rule bodies.
13
14use anyhow::Result;
15use uni_common::Value;
16use uni_common::core::schema::{DistanceMetric, Schema};
17
18use crate::query::df_graph::common::calculate_score;
19use crate::query::fusion;
20
/// Named error types for `similar_to()` validation failures.
///
/// Every variant's message begins with `similar_to` so the failing function
/// is identifiable when the error surfaces inside a larger query error.
#[derive(Debug, thiserror::Error)]
pub enum SimilarToError {
    /// The property has neither a vector index nor a full-text index
    /// (returned by `resolve_source_type`).
    #[error("similar_to: property '{label}.{property}' has no vector or full-text index")]
    NoIndex { label: String, property: String },

    /// A vector query was paired with an FTS-indexed source
    /// (returned by `validate_pair`).
    #[error(
        "similar_to: source {source_index} is FTS-indexed but query is a vector (FTS cannot score against vectors)"
    )]
    TypeMismatch { source_index: usize },

    /// A string query was paired with a vector source whose index has no
    /// embedding config (returned by `validate_pair`).
    #[error(
        "similar_to: source {source_index} is a vector property but query is a string, and the index has no embedding config for auto-embedding"
    )]
    NoEmbeddingConfig { source_index: usize },

    /// The `weights` option does not have one entry per source
    /// (returned by `validate_options`).
    #[error("similar_to: weights length ({weights_len}) != sources length ({sources_len})")]
    WeightsLengthMismatch {
        weights_len: usize,
        sources_len: usize,
    },

    /// The `weights` option does not sum to 1.0 within the tolerance applied
    /// by `validate_options`.
    #[error("similar_to: weights must sum to 1.0 (got {sum})")]
    WeightsNotNormalized { sum: f32 },

    /// The `method` option is a string other than `'rrf'` or `'weighted'`
    /// (returned by `parse_options`).
    #[error("similar_to: unknown method '{method}', expected 'rrf' or 'weighted'")]
    InvalidMethod { method: String },

    /// Generic option/element validation failure carrying a free-form message.
    /// Used by `parse_options` and by the vector conversion helpers when a
    /// list element is not a number.
    #[error("similar_to: {message}")]
    InvalidOption { message: String },

    /// The two vectors being compared have different lengths.
    #[error("similar_to: vector dimensions mismatch: {a} vs {b}")]
    DimensionMismatch { a: usize, b: usize },

    /// A value that is neither a vector nor a list was supplied where a
    /// vector was expected; `actual` holds the value's debug rendering.
    #[error("similar_to: expected vector or list of numbers, got {actual}")]
    InvalidVectorValue { actual: String },

    /// Weighted fusion was requested without a `weights` option
    /// (returned by `fuse_scores`).
    #[error("similar_to: weighted fusion requires 'weights' option")]
    WeightsRequired,

    /// Wrong number of call arguments.
    /// NOTE(review): not constructed in this file — presumably raised by the
    /// expression-level caller; confirm there.
    #[error("similar_to takes 2 or 3 arguments (sources, queries [, options]), got {count}")]
    InvalidArity { count: usize },

    /// No graph execution context was available.
    /// NOTE(review): not constructed in this file — presumably raised by the
    /// expression-level caller; confirm there.
    #[error("similar_to requires GraphExecutionContext")]
    NoGraphContext,
}
67
/// Fusion method for multi-source scoring.
///
/// Selected through the `method` option: `parse_options` maps the string
/// `'rrf'` to [`FusionMethod::Rrf`] and `'weighted'` to
/// [`FusionMethod::Weighted`].
#[derive(Debug, Clone, Default, PartialEq)]
pub enum FusionMethod {
    /// Reciprocal Rank Fusion (default). Falls back to equal-weight
    /// fusion in point-computation context.
    #[default]
    Rrf,
    /// Weighted sum of per-source scores. `fuse_scores` requires the
    /// `weights` option to be present when this method is selected.
    Weighted,
}
78
/// Options for `similar_to()` controlling fusion and scoring behavior.
///
/// Built by `parse_options` from the optional third argument; the `Default`
/// impl supplies the value of every key that is not given.
#[derive(Debug, Clone)]
pub struct SimilarToOptions {
    /// Fusion algorithm when multiple sources are present.
    pub method: FusionMethod,
    /// Per-source weights for weighted fusion. Must sum to 1.0
    /// (length and sum are checked by `validate_options`).
    pub weights: Option<Vec<f32>>,
    /// RRF constant k (default 60).
    pub k: usize,
    /// BM25 saturation constant for FTS normalization (default 1.0).
    pub fts_k: f32,
}
91
92impl Default for SimilarToOptions {
93    fn default() -> Self {
94        Self {
95            method: FusionMethod::Rrf,
96            weights: None,
97            k: 60,
98            fts_k: 1.0,
99        }
100    }
101}
102
103/// Parse options from a `Value::Map`.
104pub fn parse_options(value: &Value) -> Result<SimilarToOptions, SimilarToError> {
105    let map = match value {
106        Value::Map(m) => m,
107        Value::Null => return Ok(SimilarToOptions::default()),
108        _ => {
109            return Err(SimilarToError::InvalidOption {
110                message: format!("options must be a map, got {:?}", value),
111            });
112        }
113    };
114
115    let mut opts = SimilarToOptions::default();
116
117    if let Some(method_val) = map.get("method") {
118        match method_val.as_str() {
119            Some("rrf") => opts.method = FusionMethod::Rrf,
120            Some("weighted") => opts.method = FusionMethod::Weighted,
121            Some(other) => {
122                return Err(SimilarToError::InvalidMethod {
123                    method: other.to_string(),
124                });
125            }
126            None => {
127                return Err(SimilarToError::InvalidOption {
128                    message: "'method' must be a string ('rrf' or 'weighted')".to_string(),
129                });
130            }
131        }
132    }
133
134    if let Some(weights_val) = map.get("weights") {
135        match weights_val {
136            Value::List(list) => {
137                let weights: Result<Vec<f32>, SimilarToError> = list
138                    .iter()
139                    .map(|v| {
140                        v.as_f64()
141                            .map(|f| f as f32)
142                            .ok_or_else(|| SimilarToError::InvalidOption {
143                                message: "weight must be a number".to_string(),
144                            })
145                    })
146                    .collect();
147                opts.weights = Some(weights?);
148            }
149            _ => {
150                return Err(SimilarToError::InvalidOption {
151                    message: "'weights' must be a list of numbers".to_string(),
152                });
153            }
154        }
155    }
156
157    if let Some(k_val) = map.get("k") {
158        opts.k = k_val
159            .as_i64()
160            .ok_or_else(|| SimilarToError::InvalidOption {
161                message: "'k' must be an integer".to_string(),
162            })? as usize;
163    }
164
165    if let Some(fts_k_val) = map.get("fts_k") {
166        opts.fts_k = fts_k_val
167            .as_f64()
168            .ok_or_else(|| SimilarToError::InvalidOption {
169                message: "'fts_k' must be a number".to_string(),
170            })? as f32;
171    }
172
173    Ok(opts)
174}
175
/// What type of source a property represents.
///
/// Produced by `resolve_source_type` from the schema and consumed by
/// `validate_pair` to check query/source compatibility.
#[derive(Debug, Clone)]
pub enum SourceType {
    /// Vector property with a vector index.
    Vector {
        /// Distance metric configured on the vector index.
        metric: DistanceMetric,
        /// Whether the index carries an embedding config; `validate_pair`
        /// rejects string queries against this source when it is `false`.
        has_embedding_config: bool,
    },
    /// String property with a full-text index.
    Fts,
}
187
188/// Resolve the source type for a property given the schema.
189pub fn resolve_source_type(
190    schema: &Schema,
191    label: &str,
192    property: &str,
193) -> Result<SourceType, SimilarToError> {
194    // Check vector index first
195    if let Some(vec_config) = schema.vector_index_for_property(label, property) {
196        return Ok(SourceType::Vector {
197            metric: vec_config.metric.clone(),
198            has_embedding_config: vec_config.embedding_config.is_some(),
199        });
200    }
201
202    // Check full-text index
203    if schema
204        .fulltext_index_for_property(label, property)
205        .is_some()
206    {
207        return Ok(SourceType::Fts);
208    }
209
210    Err(SimilarToError::NoIndex {
211        label: label.to_string(),
212        property: property.to_string(),
213    })
214}
215
216/// Compute cosine similarity between two vectors, returning a score in [0, 1].
217pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32, SimilarToError> {
218    if a.len() != b.len() {
219        return Err(SimilarToError::DimensionMismatch {
220            a: a.len(),
221            b: b.len(),
222        });
223    }
224
225    let mut dot = 0.0f64;
226    let mut mag1 = 0.0f64;
227    let mut mag2 = 0.0f64;
228    for (x, y) in a.iter().zip(b.iter()) {
229        let x = *x as f64;
230        let y = *y as f64;
231        dot += x * y;
232        mag1 += x * x;
233        mag2 += y * y;
234    }
235    let mag1 = mag1.sqrt();
236    let mag2 = mag2.sqrt();
237
238    if mag1 == 0.0 || mag2 == 0.0 {
239        return Ok(0.0);
240    }
241
242    // Cosine similarity in [-1, 1], map to [0, 1]
243    let sim = (dot / (mag1 * mag2)) as f32;
244    Ok(sim.clamp(-1.0, 1.0))
245}
246
247/// Score two vectors using the specified distance metric, returning a similarity
248/// score where higher means more similar.
249///
250/// - **Cosine**: raw cosine similarity in \[-1, 1\] (delegates to [`cosine_similarity`]).
251/// - **L2**: `1 / (1 + d²)` where d² is squared Euclidean distance; range (0, 1\].
252/// - **Dot**: raw dot product (for normalised vectors equals cosine similarity).
253pub fn score_vectors(a: &[f32], b: &[f32], metric: &DistanceMetric) -> Result<f32, SimilarToError> {
254    if a.len() != b.len() {
255        return Err(SimilarToError::DimensionMismatch {
256            a: a.len(),
257            b: b.len(),
258        });
259    }
260    let distance = metric.compute_distance(a, b);
261    match metric {
262        DistanceMetric::Cosine => cosine_similarity(a, b),
263        // compute_distance returns -dot (LanceDB convention: lower = more similar).
264        // Negate to recover the actual dot product as a similarity score.
265        DistanceMetric::Dot => Ok(-distance),
266        // L2 and all other metrics (#[non_exhaustive]): normalise via calculate_score.
267        _ => Ok(calculate_score(distance, metric)),
268    }
269}
270
/// Normalize a BM25 score to [0, 1] using a saturation function.
///
/// `normalized = score / (score + fts_k)` where `fts_k` defaults to 1.0.
///
/// Edge cases are mapped so the result always stays inside `[0, 1]`:
/// - a non-positive or NaN `score` yields `0.0`;
/// - an infinite `score` yields `1.0` (the limit of the saturation curve,
///   instead of the `inf / inf = NaN` the raw formula would give);
/// - a negative `fts_k` that pushes the ratio outside the range is clamped;
/// - a NaN `fts_k` yields `0.0`.
pub fn normalize_bm25(score: f32, fts_k: f32) -> f32 {
    if score.is_nan() || score <= 0.0 {
        return 0.0;
    }
    if score.is_infinite() {
        return 1.0;
    }
    let normalized = score / (score + fts_k);
    if normalized.is_nan() {
        0.0
    } else {
        // No-op for fts_k >= 0; guards the documented range otherwise.
        normalized.clamp(0.0, 1.0)
    }
}
280
281/// Compute pure vector-vs-vector similarity (no storage access needed).
282///
283/// Both values must be `Value::List` of numbers or `Value::Vector`.
284///
285/// Uses f64 arithmetic throughout when both inputs are `Value::List`, preserving
286/// full precision for property-based vectors (e.g. in TCK and unit tests). For
287/// `Value::Vector` (pre-indexed f32 data) it falls back to the f32 path.
288pub fn eval_similar_to_pure(v1: &Value, v2: &Value) -> Result<Value> {
289    // Fast path: at least one input is a List — use f64 to avoid f32 precision loss.
290    let has_list = matches!(v1, Value::List(_)) || matches!(v2, Value::List(_));
291    let f64_vecs = has_list
292        .then(|| value_to_f64_vec(v1).ok().zip(value_to_f64_vec(v2).ok()))
293        .flatten();
294    if let Some((vec1, vec2)) = f64_vecs {
295        let sim = cosine_similarity_f64(&vec1, &vec2)?;
296        return Ok(Value::Float(sim));
297    }
298    // Fallback: f32 path for Value::Vector (indexed data already in f32).
299    let vec1 = value_to_f32_vec(v1)?;
300    let vec2 = value_to_f32_vec(v2)?;
301    let sim = cosine_similarity(&vec1, &vec2)?;
302    Ok(Value::Float(sim as f64))
303}
304
305/// Compute cosine similarity between two f64 vectors, returning a score in [-1, 1].
306fn cosine_similarity_f64(a: &[f64], b: &[f64]) -> Result<f64, SimilarToError> {
307    if a.len() != b.len() {
308        return Err(SimilarToError::DimensionMismatch {
309            a: a.len(),
310            b: b.len(),
311        });
312    }
313    let mut dot = 0.0f64;
314    let mut mag1 = 0.0f64;
315    let mut mag2 = 0.0f64;
316    for (x, y) in a.iter().zip(b.iter()) {
317        dot += x * y;
318        mag1 += x * x;
319        mag2 += y * y;
320    }
321    let mag1 = mag1.sqrt();
322    let mag2 = mag2.sqrt();
323    if mag1 == 0.0 || mag2 == 0.0 {
324        return Ok(0.0);
325    }
326    Ok((dot / (mag1 * mag2)).clamp(-1.0, 1.0))
327}
328
329/// Convert a Value to a `Vec<f64>` for high-precision vector operations.
330fn value_to_f64_vec(v: &Value) -> Result<Vec<f64>, SimilarToError> {
331    match v {
332        Value::Vector(vec) => Ok(vec.iter().map(|&x| x as f64).collect()),
333        Value::List(list) => list
334            .iter()
335            .map(|v| {
336                v.as_f64().ok_or_else(|| SimilarToError::InvalidOption {
337                    message: "vector element must be a number".to_string(),
338                })
339            })
340            .collect(),
341        _ => Err(SimilarToError::InvalidVectorValue {
342            actual: format!("{v:?}"),
343        }),
344    }
345}
346
347/// Convert a Value to a `Vec<f32>` for vector operations.
348pub fn value_to_f32_vec(v: &Value) -> Result<Vec<f32>, SimilarToError> {
349    match v {
350        Value::Vector(vec) => Ok(vec.clone()),
351        Value::List(list) => list
352            .iter()
353            .map(|v| {
354                v.as_f64()
355                    .map(|f| f as f32)
356                    .ok_or_else(|| SimilarToError::InvalidOption {
357                        message: "vector element must be a number".to_string(),
358                    })
359            })
360            .collect(),
361        _ => Err(SimilarToError::InvalidVectorValue {
362            actual: format!("{v:?}"),
363        }),
364    }
365}
366
367/// Validate options against the number of sources.
368pub fn validate_options(opts: &SimilarToOptions, num_sources: usize) -> Result<(), SimilarToError> {
369    if let Some(ref weights) = opts.weights {
370        if weights.len() != num_sources {
371            return Err(SimilarToError::WeightsLengthMismatch {
372                weights_len: weights.len(),
373                sources_len: num_sources,
374            });
375        }
376        let sum: f32 = weights.iter().sum();
377        if (sum - 1.0).abs() > 0.01 {
378            return Err(SimilarToError::WeightsNotNormalized { sum });
379        }
380    }
381    Ok(())
382}
383
384/// Validate per-pair type compatibility.
385///
386/// Returns an error if a Vector query is paired with an FTS source,
387/// or a String query is paired with a Vector source that has no embedding config.
388pub fn validate_pair(
389    source_type: &SourceType,
390    query_is_vector: bool,
391    query_is_string: bool,
392    source_index: usize,
393) -> Result<(), SimilarToError> {
394    match source_type {
395        SourceType::Fts if query_is_vector => Err(SimilarToError::TypeMismatch { source_index }),
396        SourceType::Vector {
397            has_embedding_config: false,
398            ..
399        } if query_is_string => Err(SimilarToError::NoEmbeddingConfig { source_index }),
400        _ => Ok(()),
401    }
402}
403
404/// Fuse multiple per-source scores into a single score.
405pub fn fuse_scores(scores: &[f32], opts: &SimilarToOptions) -> Result<f32, SimilarToError> {
406    if scores.len() == 1 {
407        return Ok(scores[0]);
408    }
409
410    match opts.method {
411        FusionMethod::Weighted => {
412            let weights = opts
413                .weights
414                .as_ref()
415                .ok_or(SimilarToError::WeightsRequired)?;
416            Ok(fusion::fuse_weighted_multi(scores, weights))
417        }
418        FusionMethod::Rrf => {
419            // In point-computation context, RRF degenerates to equal-weight fusion.
420            // The caller (similar_to_expr.rs) already emits QueryWarning::RrfPointContext
421            // unconditionally when method == Rrf && num_sources > 1.
422            let (score, _) = fusion::fuse_rrf_point(scores);
423            Ok(score)
424        }
425    }
426}
427
// Unit tests for the pure helpers in this module: option parsing, cosine and
// metric-based vector scoring, BM25 normalization, option/pair validation and
// score fusion. None of them touch storage or a schema.
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    // -----------------------------------------------------------------------
    // parse_options() tests
    // -----------------------------------------------------------------------

    /// `Value::Null` yields the default options.
    #[test]
    fn test_parse_options_default() {
        let opts = parse_options(&Value::Null).unwrap();
        assert_eq!(opts.method, FusionMethod::Rrf);
        assert_eq!(opts.k, 60);
        assert!((opts.fts_k - 1.0).abs() < 1e-6);
        assert!(opts.weights.is_none());
    }

    /// `method: 'weighted'` plus a `weights` list is parsed into both fields.
    #[test]
    fn test_parse_options_weighted() {
        let mut map = HashMap::new();
        map.insert("method".to_string(), Value::String("weighted".to_string()));
        map.insert(
            "weights".to_string(),
            Value::List(vec![Value::Float(0.7), Value::Float(0.3)]),
        );
        let opts = parse_options(&Value::Map(map)).unwrap();
        assert_eq!(opts.method, FusionMethod::Weighted);
        let weights = opts.weights.unwrap();
        assert!((weights[0] - 0.7).abs() < 1e-6);
        assert!((weights[1] - 0.3).abs() < 1e-6);
    }

    /// An integer `k` overrides the default RRF constant.
    #[test]
    fn test_parse_options_rrf_with_k() {
        let mut map = HashMap::new();
        map.insert("method".to_string(), Value::String("rrf".to_string()));
        map.insert("k".to_string(), Value::Int(30));
        let opts = parse_options(&Value::Map(map)).unwrap();
        assert_eq!(opts.method, FusionMethod::Rrf);
        assert_eq!(opts.k, 30);
    }

    /// A numeric `fts_k` overrides the default saturation constant.
    #[test]
    fn test_parse_options_fts_k() {
        let mut map = HashMap::new();
        map.insert("fts_k".to_string(), Value::Float(2.0));
        let opts = parse_options(&Value::Map(map)).unwrap();
        assert!((opts.fts_k - 2.0).abs() < 1e-6);
    }

    /// An unknown method string is an error.
    #[test]
    fn test_parse_options_invalid_method() {
        let mut map = HashMap::new();
        map.insert("method".to_string(), Value::String("invalid".to_string()));
        assert!(parse_options(&Value::Map(map)).is_err());
    }

    // -----------------------------------------------------------------------
    // cosine_similarity() tests
    // -----------------------------------------------------------------------

    /// A vector compared with itself scores 1.
    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(&v, &v).unwrap();
        assert!((sim - 1.0).abs() < 1e-6);
    }

    /// Orthogonal vectors score 0.
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b).unwrap();
        assert!((sim - 0.0).abs() < 1e-6);
    }

    /// Opposite vectors score -1: the result is raw cosine, not rescaled to [0, 1].
    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = cosine_similarity(&a, &b).unwrap();
        assert!((sim - (-1.0)).abs() < 1e-6);
    }

    /// Vectors of different lengths are rejected.
    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!(cosine_similarity(&a, &b).is_err());
    }

    // -----------------------------------------------------------------------
    // normalize_bm25() tests
    // -----------------------------------------------------------------------

    /// With `fts_k = 1`, a score `s` maps to `s / (s + 1)`.
    #[test]
    fn test_normalize_bm25() {
        assert!((normalize_bm25(0.0, 1.0) - 0.0).abs() < 1e-6);
        assert!((normalize_bm25(1.0, 1.0) - 0.5).abs() < 1e-6);
        assert!((normalize_bm25(9.0, 1.0) - 0.9).abs() < 1e-6);
        assert!((normalize_bm25(99.0, 1.0) - 0.99).abs() < 1e-4);
    }

    /// A score equal to `fts_k` normalizes to 0.5.
    #[test]
    fn test_normalize_bm25_custom_k() {
        assert!((normalize_bm25(2.0, 2.0) - 0.5).abs() < 1e-6);
    }

    // -----------------------------------------------------------------------
    // eval_similar_to_pure() tests
    // -----------------------------------------------------------------------

    /// Two identical `Value::List` inputs score 1.
    #[test]
    fn test_eval_similar_to_pure() {
        let v1 = Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        let v2 = Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        let result = eval_similar_to_pure(&v1, &v2).unwrap();
        match result {
            Value::Float(f) => assert!((f - 1.0).abs() < 1e-6),
            _ => panic!("Expected Float"),
        }
    }

    /// Two orthogonal `Value::Vector` inputs score 0.
    #[test]
    fn test_eval_similar_to_pure_vector_type() {
        let v1 = Value::Vector(vec![1.0, 0.0]);
        let v2 = Value::Vector(vec![0.0, 1.0]);
        let result = eval_similar_to_pure(&v1, &v2).unwrap();
        match result {
            Value::Float(f) => assert!((f - 0.0).abs() < 1e-6),
            _ => panic!("Expected Float"),
        }
    }

    // -----------------------------------------------------------------------
    // validate_options() tests
    // -----------------------------------------------------------------------

    /// One weight for two sources is rejected.
    #[test]
    fn test_validate_options_weights_length() {
        let opts = SimilarToOptions {
            weights: Some(vec![0.5]),
            ..Default::default()
        };
        assert!(validate_options(&opts, 2).is_err());
    }

    /// Weights summing to 0.8 are rejected.
    #[test]
    fn test_validate_options_weights_sum() {
        let opts = SimilarToOptions {
            weights: Some(vec![0.5, 0.3]),
            ..Default::default()
        };
        assert!(validate_options(&opts, 2).is_err());
    }

    /// Two weights summing to 1.0 for two sources are accepted.
    #[test]
    fn test_validate_options_ok() {
        let opts = SimilarToOptions {
            weights: Some(vec![0.7, 0.3]),
            ..Default::default()
        };
        assert!(validate_options(&opts, 2).is_ok());
    }

    // -----------------------------------------------------------------------
    // validate_pair() tests
    // -----------------------------------------------------------------------

    /// FTS source + vector query is rejected.
    #[test]
    fn test_validate_pair_fts_vector_query() {
        assert!(validate_pair(&SourceType::Fts, true, false, 0).is_err());
    }

    /// Vector source without embedding config + string query is rejected.
    #[test]
    fn test_validate_pair_vector_string_no_embed() {
        let st = SourceType::Vector {
            metric: DistanceMetric::Cosine,
            has_embedding_config: false,
        };
        assert!(validate_pair(&st, false, true, 0).is_err());
    }

    /// Vector source with embedding config + string query is accepted.
    #[test]
    fn test_validate_pair_vector_string_with_embed() {
        let st = SourceType::Vector {
            metric: DistanceMetric::Cosine,
            has_embedding_config: true,
        };
        assert!(validate_pair(&st, false, true, 0).is_ok());
    }

    /// Vector source + vector query is accepted even without embedding config.
    #[test]
    fn test_validate_pair_vector_vector() {
        let st = SourceType::Vector {
            metric: DistanceMetric::Cosine,
            has_embedding_config: false,
        };
        assert!(validate_pair(&st, true, false, 0).is_ok());
    }

    /// FTS source + string query is accepted.
    #[test]
    fn test_validate_pair_fts_string() {
        assert!(validate_pair(&SourceType::Fts, false, true, 0).is_ok());
    }

    // -----------------------------------------------------------------------
    // fuse_scores() tests
    // -----------------------------------------------------------------------

    /// A single score is passed through unchanged.
    #[test]
    fn test_fuse_scores_single() {
        let opts = SimilarToOptions::default();
        let score = fuse_scores(&[0.8], &opts).unwrap();
        assert!((score - 0.8).abs() < 1e-6);
    }

    /// Weighted fusion: 0.7 * 0.8 + 0.3 * 0.6 = 0.74.
    #[test]
    fn test_fuse_scores_weighted() {
        let opts = SimilarToOptions {
            method: FusionMethod::Weighted,
            weights: Some(vec![0.7, 0.3]),
            ..Default::default()
        };
        let score = fuse_scores(&[0.8, 0.6], &opts).unwrap();
        assert!((score - 0.74).abs() < 1e-6);
    }

    /// Default (RRF) options with two scores produce their mean.
    #[test]
    fn test_fuse_scores_rrf_fallback() {
        let opts = SimilarToOptions::default();
        let score = fuse_scores(&[0.8, 0.6], &opts).unwrap();
        // RRF in point context falls back to equal weights: (0.8 + 0.6) / 2 = 0.7
        assert!((score - 0.7).abs() < 1e-6);
    }

    // -----------------------------------------------------------------------
    // score_vectors() tests
    // -----------------------------------------------------------------------

    /// Cosine metric: identical vectors score 1.
    #[test]
    fn test_score_vectors_cosine_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::Cosine).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_score_vectors_cosine_matches_raw() {
        // score_vectors with Cosine delegates to cosine_similarity
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.8, 0.6, 0.0];
        let raw = cosine_similarity(&a, &b).unwrap();
        let scored = score_vectors(&a, &b, &DistanceMetric::Cosine).unwrap();
        assert!((raw - scored).abs() < 1e-6);
    }

    #[test]
    fn test_score_vectors_l2() {
        // [1,0,0] vs [0,1,0]: L2 squared distance = 2, score = 1/(1+2) ≈ 0.333
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let score = score_vectors(&a, &b, &DistanceMetric::L2).unwrap();
        assert!((score - 1.0 / 3.0).abs() < 1e-5);
    }

    /// L2 metric: zero distance scores 1.
    #[test]
    fn test_score_vectors_l2_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::L2).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_score_vectors_dot() {
        // [1,0,0] dot [0.8,0.6,0] = 0.8
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.8, 0.6, 0.0];
        let score = score_vectors(&a, &b, &DistanceMetric::Dot).unwrap();
        assert!((score - 0.8).abs() < 1e-6);
    }

    /// Dot metric: a unit vector with itself scores 1.
    #[test]
    fn test_score_vectors_dot_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::Dot).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
    }

    /// Mismatched lengths are rejected before any metric is applied.
    #[test]
    fn test_score_vectors_dimension_mismatch() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!(score_vectors(&a, &b, &DistanceMetric::Cosine).is_err());
    }
}