Skip to main content

uni_query_functions/
similar_to.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! `similar_to()` expression function — unified similarity scoring.
5//!
6//! Dispatches to vector cosine similarity, BM25 full-text scoring, or
7//! hybrid fusion based on schema types. Returns a float in `[0, 1]`
8//! where higher means more similar.
9//!
10//! This is a **point computation** (score one bound node), not a search
11//! (top-K index scan). It works in WHERE, RETURN, WITH, ORDER BY, and
12//! Locy rule bodies.
13
14use anyhow::Result;
15use uni_common::Value;
16use uni_common::core::schema::{DistanceMetric, Schema};
17
18use crate::fusion;
19
20/// Named error types for `similar_to()` validation failures.
21#[derive(Debug, thiserror::Error)]
22pub enum SimilarToError {
23    #[error("similar_to: property '{label}.{property}' has no vector or full-text index")]
24    NoIndex { label: String, property: String },
25
26    #[error(
27        "similar_to: source {source_index} is FTS-indexed but query is a vector (FTS cannot score against vectors)"
28    )]
29    TypeMismatch { source_index: usize },
30
31    #[error(
32        "similar_to: source {source_index} is a vector property but query is a string, and the index has no embedding config for auto-embedding"
33    )]
34    NoEmbeddingConfig { source_index: usize },
35
36    #[error("similar_to: weights length ({weights_len}) != sources length ({sources_len})")]
37    WeightsLengthMismatch {
38        weights_len: usize,
39        sources_len: usize,
40    },
41
42    #[error("similar_to: weights must sum to 1.0 (got {sum})")]
43    WeightsNotNormalized { sum: f32 },
44
45    #[error("similar_to: unknown method '{method}', expected 'rrf' or 'weighted'")]
46    InvalidMethod { method: String },
47
48    #[error("similar_to: {message}")]
49    InvalidOption { message: String },
50
51    #[error("similar_to: vector dimensions mismatch: {a} vs {b}")]
52    DimensionMismatch { a: usize, b: usize },
53
54    #[error("similar_to: expected vector or list of numbers, got {actual}")]
55    InvalidVectorValue { actual: String },
56
57    #[error("similar_to: weighted fusion requires 'weights' option")]
58    WeightsRequired,
59
60    #[error("similar_to takes 2 or 3 arguments (sources, queries [, options]), got {count}")]
61    InvalidArity { count: usize },
62
63    #[error("similar_to requires GraphExecutionContext")]
64    NoGraphContext,
65}
66
/// How per-source scores are combined when more than one source is scored.
#[derive(Clone, Debug, Default, PartialEq)]
pub enum FusionMethod {
    /// Reciprocal Rank Fusion; this is the default variant. In
    /// point-computation context it falls back to equal-weight fusion.
    #[default]
    Rrf,
    /// Weighted sum of the per-source scores.
    Weighted,
}
77
78/// Options for `similar_to()` controlling fusion and scoring behavior.
79#[derive(Debug, Clone)]
80pub struct SimilarToOptions {
81    /// Fusion algorithm when multiple sources are present.
82    pub method: FusionMethod,
83    /// Per-source weights for weighted fusion. Must sum to 1.0.
84    pub weights: Option<Vec<f32>>,
85    /// RRF constant k (default 60).
86    pub k: usize,
87    /// BM25 saturation constant for FTS normalization (default 1.0).
88    pub fts_k: f32,
89}
90
91impl Default for SimilarToOptions {
92    fn default() -> Self {
93        Self {
94            method: FusionMethod::Rrf,
95            weights: None,
96            k: 60,
97            fts_k: 1.0,
98        }
99    }
100}
101
102/// Parse options from a `Value::Map`.
103pub fn parse_options(value: &Value) -> Result<SimilarToOptions, SimilarToError> {
104    let map = match value {
105        Value::Map(m) => m,
106        Value::Null => return Ok(SimilarToOptions::default()),
107        _ => {
108            return Err(SimilarToError::InvalidOption {
109                message: format!("options must be a map, got {:?}", value),
110            });
111        }
112    };
113
114    let mut opts = SimilarToOptions::default();
115
116    if let Some(method_val) = map.get("method") {
117        match method_val.as_str() {
118            Some("rrf") => opts.method = FusionMethod::Rrf,
119            Some("weighted") => opts.method = FusionMethod::Weighted,
120            Some(other) => {
121                return Err(SimilarToError::InvalidMethod {
122                    method: other.to_string(),
123                });
124            }
125            None => {
126                return Err(SimilarToError::InvalidOption {
127                    message: "'method' must be a string ('rrf' or 'weighted')".to_string(),
128                });
129            }
130        }
131    }
132
133    if let Some(weights_val) = map.get("weights") {
134        match weights_val {
135            Value::List(list) => {
136                let weights: Result<Vec<f32>, SimilarToError> = list
137                    .iter()
138                    .map(|v| {
139                        v.as_f64()
140                            .map(|f| f as f32)
141                            .ok_or_else(|| SimilarToError::InvalidOption {
142                                message: "weight must be a number".to_string(),
143                            })
144                    })
145                    .collect();
146                opts.weights = Some(weights?);
147            }
148            _ => {
149                return Err(SimilarToError::InvalidOption {
150                    message: "'weights' must be a list of numbers".to_string(),
151                });
152            }
153        }
154    }
155
156    if let Some(k_val) = map.get("k") {
157        opts.k = k_val
158            .as_i64()
159            .ok_or_else(|| SimilarToError::InvalidOption {
160                message: "'k' must be an integer".to_string(),
161            })? as usize;
162    }
163
164    if let Some(fts_k_val) = map.get("fts_k") {
165        opts.fts_k = fts_k_val
166            .as_f64()
167            .ok_or_else(|| SimilarToError::InvalidOption {
168                message: "'fts_k' must be a number".to_string(),
169            })? as f32;
170    }
171
172    Ok(opts)
173}
174
175/// What type of source a property represents.
176#[derive(Debug, Clone)]
177pub enum SourceType {
178    /// Vector property with a vector index.
179    Vector {
180        metric: DistanceMetric,
181        has_embedding_config: bool,
182    },
183    /// String property with a full-text index.
184    Fts,
185}
186
187/// Resolve the source type for a property given the schema.
188pub fn resolve_source_type(
189    schema: &Schema,
190    label: &str,
191    property: &str,
192) -> Result<SourceType, SimilarToError> {
193    // Check vector index first
194    if let Some(vec_config) = schema.vector_index_for_property(label, property) {
195        return Ok(SourceType::Vector {
196            metric: vec_config.metric.clone(),
197            has_embedding_config: vec_config.embedding_config.is_some(),
198        });
199    }
200
201    // Check full-text index
202    if schema
203        .fulltext_index_for_property(label, property)
204        .is_some()
205    {
206        return Ok(SourceType::Fts);
207    }
208
209    Err(SimilarToError::NoIndex {
210        label: label.to_string(),
211        property: property.to_string(),
212    })
213}
214
215/// Compute cosine similarity between two vectors, returning a score in [-1, 1].
216///
217/// Arithmetic is performed in f64 (see `cosine_similarity_inner`) and the
218/// clamped result is narrowed to f32.
219pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32, SimilarToError> {
220    cosine_similarity_inner(a, b).map(|sim| sim as f32)
221}
222
223/// Score two vectors using the specified distance metric, returning a similarity
224/// score where higher means more similar.
225///
226/// - **Cosine**: raw cosine similarity in \[-1, 1\] (delegates to [`cosine_similarity`]).
227/// - **L2**: `1 / (1 + d²)` where d² is squared Euclidean distance; range (0, 1\].
228/// - **Dot**: raw dot product (for normalised vectors equals cosine similarity).
229pub fn score_vectors(a: &[f32], b: &[f32], metric: &DistanceMetric) -> Result<f32, SimilarToError> {
230    if a.len() != b.len() {
231        return Err(SimilarToError::DimensionMismatch {
232            a: a.len(),
233            b: b.len(),
234        });
235    }
236    let distance = metric.compute_distance(a, b);
237    match metric {
238        DistanceMetric::Cosine => cosine_similarity(a, b),
239        // compute_distance returns -dot (LanceDB convention: lower = more similar).
240        // Negate to recover the actual dot product as a similarity score.
241        DistanceMetric::Dot => Ok(-distance),
242        // L2 and all other metrics (#[non_exhaustive]): normalise via calculate_score.
243        _ => Ok(calculate_score(distance, metric)),
244    }
245}
246
247/// Compute the MaxSim (late-interaction / ColBERT) score between a query and a
248/// document, each represented as a set of token vectors.
249///
250/// MaxSim is `Σ_i max_j score_vectors(query_i, doc_j, metric)`: for every query
251/// token take its best match across all document tokens, then sum. Per-token
252/// scoring reuses [`score_vectors`], so the metric's normalisation matches the
253/// rest of the engine. Default metric for late interaction is `Cosine`.
254///
255/// An empty query, or a query token with no document tokens to match against,
256/// contributes `0` rather than erroring.
257///
258/// # Errors
259/// Returns [`SimilarToError::DimensionMismatch`] if a query token and a document
260/// token differ in length.
261pub fn maxsim(
262    query: &[Vec<f32>],
263    doc: &[Vec<f32>],
264    metric: &DistanceMetric,
265) -> Result<f32, SimilarToError> {
266    let mut total = 0.0_f32;
267    for q in query {
268        let mut best: Option<f32> = None;
269        for d in doc {
270            let sim = score_vectors(q, d, metric)?;
271            best = Some(best.map_or(sim, |b| b.max(sim)));
272        }
273        // `None` means the document had no tokens: that query token scores 0.
274        total += best.unwrap_or(0.0);
275    }
276    Ok(total)
277}
278
279/// Convert a raw distance value into a normalised similarity score.
280///
281/// The conversion depends on the distance metric:
282/// - **Cosine**: `(2 - d) / 2` (LanceDB cosine distance ranges 0..2)
283/// - **Dot**: pass-through (already a similarity measure)
284/// - **L2** and others: `1 / (1 + d)`
285pub fn calculate_score(distance: f32, metric: &DistanceMetric) -> f32 {
286    match metric {
287        DistanceMetric::Cosine => (2.0 - distance) / 2.0,
288        DistanceMetric::Dot => distance,
289        _ => 1.0 / (1.0 + distance),
290    }
291}
292
/// Normalize a BM25 score to [0, 1] using a saturation function.
///
/// `normalized = score / (score + fts_k)` where `fts_k` defaults to 1.0.
///
/// The result is always inside `[0, 1]`:
/// - non-positive or NaN scores map to `0`;
/// - an infinite score maps to `1` (the naive formula would give `inf / inf = NaN`);
/// - if the formula still produces NaN (e.g. a NaN `fts_k`) the result is `0`;
/// - any other out-of-range value (possible with a negative `fts_k`) is clamped.
pub fn normalize_bm25(score: f32, fts_k: f32) -> f32 {
    if score.is_nan() || score <= 0.0 {
        return 0.0;
    }
    if score.is_infinite() {
        // Saturation limit of score / (score + k) as score grows without bound.
        return 1.0;
    }
    let normalized = score / (score + fts_k);
    if normalized.is_nan() {
        return 0.0;
    }
    normalized.clamp(0.0, 1.0)
}
302
303/// Compute pure vector-vs-vector similarity (no storage access needed).
304///
305/// Both values must be `Value::List` of numbers or `Value::Vector`.
306///
307/// Uses f64 arithmetic throughout when both inputs are `Value::List`, preserving
308/// full precision for property-based vectors (e.g. in TCK and unit tests). For
309/// `Value::Vector` (pre-indexed f32 data) it falls back to the f32 path.
310pub fn eval_similar_to_pure(v1: &Value, v2: &Value) -> Result<Value> {
311    // NULL propagates: a NULL operand yields NULL, not an error. `values_to_array`
312    // renders `Value::Null` as an Arrow null, matching 3VL semantics.
313    if matches!(v1, Value::Null) || matches!(v2, Value::Null) {
314        return Ok(Value::Null);
315    }
316    // Fast path: at least one input is a List — use f64 to avoid f32 precision loss.
317    let has_list = matches!(v1, Value::List(_)) || matches!(v2, Value::List(_));
318    let f64_vecs = has_list
319        .then(|| value_to_f64_vec(v1).ok().zip(value_to_f64_vec(v2).ok()))
320        .flatten();
321    if let Some((vec1, vec2)) = f64_vecs {
322        let sim = cosine_similarity_inner(&vec1, &vec2)?;
323        return Ok(Value::Float(sim));
324    }
325    // Fallback: f32 path for Value::Vector (indexed data already in f32).
326    let vec1 = value_to_f32_vec(v1)?;
327    let vec2 = value_to_f32_vec(v2)?;
328    let sim = cosine_similarity(&vec1, &vec2)?;
329    Ok(Value::Float(sim as f64))
330}
331
332/// Compute the raw cosine similarity (in f64) between two equal-length vectors,
333/// clamped to [-1, 1]. Returns 0 when either vector has zero magnitude.
334///
335/// Generic over the element type so both the f32 and f64 callers share one body;
336/// arithmetic is always performed in f64 to preserve precision.
337fn cosine_similarity_inner<T: Copy + Into<f64>>(a: &[T], b: &[T]) -> Result<f64, SimilarToError> {
338    if a.len() != b.len() {
339        return Err(SimilarToError::DimensionMismatch {
340            a: a.len(),
341            b: b.len(),
342        });
343    }
344    let mut dot = 0.0f64;
345    let mut mag1 = 0.0f64;
346    let mut mag2 = 0.0f64;
347    for (&x, &y) in a.iter().zip(b.iter()) {
348        let (x, y): (f64, f64) = (x.into(), y.into());
349        dot += x * y;
350        mag1 += x * x;
351        mag2 += y * y;
352    }
353    let mag1 = mag1.sqrt();
354    let mag2 = mag2.sqrt();
355    if mag1 == 0.0 || mag2 == 0.0 {
356        return Ok(0.0);
357    }
358    Ok((dot / (mag1 * mag2)).clamp(-1.0, 1.0))
359}
360
361/// Convert a Value to a numeric vector for vector operations.
362///
363/// Generic over the element type: `cast` maps the `f64` source value onto the
364/// target representation (identity for `f64`, narrowing for `f32`). List elements
365/// are read as full-precision `f64` before casting, while `Value::Vector` data is
366/// already `f32` and is widened to `f64` first.
367fn value_to_vec<T>(v: &Value, cast: impl Fn(f64) -> T) -> Result<Vec<T>, SimilarToError> {
368    match v {
369        Value::Vector(vec) => Ok(vec.iter().map(|&x| cast(x as f64)).collect()),
370        Value::List(list) => list
371            .iter()
372            .map(|v| {
373                v.as_f64()
374                    .map(&cast)
375                    .ok_or_else(|| SimilarToError::InvalidOption {
376                        message: "vector element must be a number".to_string(),
377                    })
378            })
379            .collect(),
380        _ => Err(SimilarToError::InvalidVectorValue {
381            actual: format!("{v:?}"),
382        }),
383    }
384}
385
386/// Convert a Value to a `Vec<f64>` for high-precision vector operations.
387fn value_to_f64_vec(v: &Value) -> Result<Vec<f64>, SimilarToError> {
388    value_to_vec(v, |f| f)
389}
390
391/// Convert a Value to a `Vec<f32>` for vector operations.
392pub fn value_to_f32_vec(v: &Value) -> Result<Vec<f32>, SimilarToError> {
393    value_to_vec(v, |f| f as f32)
394}
395
396/// Validate options against the number of sources.
397pub fn validate_options(opts: &SimilarToOptions, num_sources: usize) -> Result<(), SimilarToError> {
398    if let Some(ref weights) = opts.weights {
399        if weights.len() != num_sources {
400            return Err(SimilarToError::WeightsLengthMismatch {
401                weights_len: weights.len(),
402                sources_len: num_sources,
403            });
404        }
405        let sum: f32 = weights.iter().sum();
406        if (sum - 1.0).abs() > 0.01 {
407            return Err(SimilarToError::WeightsNotNormalized { sum });
408        }
409    }
410    Ok(())
411}
412
413/// Validate per-pair type compatibility.
414///
415/// Returns an error if a Vector query is paired with an FTS source,
416/// or a String query is paired with a Vector source that has no embedding config.
417pub fn validate_pair(
418    source_type: &SourceType,
419    query_is_vector: bool,
420    query_is_string: bool,
421    source_index: usize,
422) -> Result<(), SimilarToError> {
423    match source_type {
424        SourceType::Fts if query_is_vector => Err(SimilarToError::TypeMismatch { source_index }),
425        SourceType::Vector {
426            has_embedding_config: false,
427            ..
428        } if query_is_string => Err(SimilarToError::NoEmbeddingConfig { source_index }),
429        _ => Ok(()),
430    }
431}
432
433/// Fuse multiple per-source scores into a single score.
434pub fn fuse_scores(scores: &[f32], opts: &SimilarToOptions) -> Result<f32, SimilarToError> {
435    if scores.len() == 1 {
436        return Ok(scores[0]);
437    }
438
439    match opts.method {
440        FusionMethod::Weighted => {
441            let weights = opts
442                .weights
443                .as_ref()
444                .ok_or(SimilarToError::WeightsRequired)?;
445            Ok(fusion::fuse_weighted_multi(scores, weights))
446        }
447        FusionMethod::Rrf => {
448            // In point-computation context, RRF degenerates to equal-weight fusion.
449            // The caller (similar_to_expr.rs) already emits QueryWarning::RrfPointContext
450            // unconditionally when method == Rrf && num_sources > 1.
451            let (score, _) = fusion::fuse_rrf_point(scores);
452            Ok(score)
453        }
454    }
455}
456
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Absolute tolerance used by most floating-point comparisons below.
    const EPS: f32 = 1e-6;

    /// `true` when `actual` lies within [`EPS`] of `expected`.
    fn near(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPS
    }

    /// Build an options `Value::Map` from `(key, value)` pairs.
    fn options_map(entries: Vec<(&str, Value)>) -> Value {
        let map: HashMap<String, Value> = entries
            .into_iter()
            .map(|(key, value)| (key.to_string(), value))
            .collect();
        Value::Map(map)
    }

    /// A vector source using the cosine metric.
    fn cosine_source(has_embedding_config: bool) -> SourceType {
        SourceType::Vector {
            metric: DistanceMetric::Cosine,
            has_embedding_config,
        }
    }

    // -----------------------------------------------------------------------
    // parse_options() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_parse_options_default() {
        let parsed = parse_options(&Value::Null).unwrap();
        assert_eq!(parsed.method, FusionMethod::Rrf);
        assert_eq!(parsed.k, 60);
        assert!(near(parsed.fts_k, 1.0));
        assert!(parsed.weights.is_none());
    }

    #[test]
    fn test_parse_options_weighted() {
        let input = options_map(vec![
            ("method", Value::String("weighted".to_string())),
            (
                "weights",
                Value::List(vec![Value::Float(0.7), Value::Float(0.3)]),
            ),
        ]);
        let parsed = parse_options(&input).unwrap();
        assert_eq!(parsed.method, FusionMethod::Weighted);
        let weights = parsed.weights.unwrap();
        assert!(near(weights[0], 0.7));
        assert!(near(weights[1], 0.3));
    }

    #[test]
    fn test_parse_options_rrf_with_k() {
        let input = options_map(vec![
            ("method", Value::String("rrf".to_string())),
            ("k", Value::Int(30)),
        ]);
        let parsed = parse_options(&input).unwrap();
        assert_eq!(parsed.method, FusionMethod::Rrf);
        assert_eq!(parsed.k, 30);
    }

    #[test]
    fn test_parse_options_fts_k() {
        let input = options_map(vec![("fts_k", Value::Float(2.0))]);
        let parsed = parse_options(&input).unwrap();
        assert!(near(parsed.fts_k, 2.0));
    }

    #[test]
    fn test_parse_options_invalid_method() {
        let input = options_map(vec![("method", Value::String("invalid".to_string()))]);
        assert!(parse_options(&input).is_err());
    }

    // -----------------------------------------------------------------------
    // maxsim() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_maxsim_hand_computed() {
        // query = {e0, e1}; doc = {e0, [0.5,0.5]}. With Dot:
        //   q0=e0 -> max(e0·e0=1, e0·[.5,.5]=0.5) = 1.0
        //   q1=e1 -> max(e1·e0=0, e1·[.5,.5]=0.5) = 0.5
        // MaxSim = 1.5
        let query = vec![vec![1.0_f32, 0.0], vec![0.0_f32, 1.0]];
        let doc = vec![vec![1.0_f32, 0.0], vec![0.5_f32, 0.5]];
        let score = maxsim(&query, &doc, &DistanceMetric::Dot).unwrap();
        assert!(near(score, 1.5), "got {score}");
    }

    #[test]
    fn test_maxsim_edge_cases() {
        let metric = DistanceMetric::Cosine;

        // An empty query has nothing to sum.
        let empty_query_score = maxsim(&[], &[vec![1.0_f32, 0.0]], &metric).unwrap();
        assert_eq!(empty_query_score, 0.0);

        // With an empty document every query token scores 0.
        let empty_doc: Vec<Vec<f32>> = vec![];
        let empty_doc_score = maxsim(&[vec![1.0_f32, 0.0]], &empty_doc, &metric).unwrap();
        assert_eq!(empty_doc_score, 0.0);

        // A query token and a doc token of different lengths is an error.
        let mismatched = maxsim(&[vec![1.0_f32, 0.0]], &[vec![1.0_f32, 0.0, 0.0]], &metric);
        assert!(matches!(
            mismatched,
            Err(SimilarToError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn test_maxsim_metric_changes_score() {
        // Non-unit vectors so Dot and Cosine diverge: q=[2,0], d=[3,0].
        //   Dot    = 2*3 = 6.0
        //   Cosine = (2*3) / (|2|*|3|) = 1.0
        let q = vec![vec![2.0_f32, 0.0]];
        let d = vec![vec![3.0_f32, 0.0]];
        let dot = maxsim(&q, &d, &DistanceMetric::Dot).unwrap();
        let cos = maxsim(&q, &d, &DistanceMetric::Cosine).unwrap();
        assert!(near(dot, 6.0), "dot got {dot}");
        assert!(near(cos, 1.0), "cosine got {cos}");
    }

    // -----------------------------------------------------------------------
    // cosine_similarity() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![1.0, 0.0, 0.0];
        assert!(near(cosine_similarity(&v, &v).unwrap(), 1.0));
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let x_axis = vec![1.0, 0.0];
        let y_axis = vec![0.0, 1.0];
        assert!(near(cosine_similarity(&x_axis, &y_axis).unwrap(), 0.0));
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let forward = vec![1.0, 0.0];
        let backward = vec![-1.0, 0.0];
        assert!(near(cosine_similarity(&forward, &backward).unwrap(), -1.0));
    }

    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        let short = vec![1.0, 0.0];
        let long = vec![1.0, 0.0, 0.0];
        assert!(cosine_similarity(&short, &long).is_err());
    }

    // -----------------------------------------------------------------------
    // normalize_bm25() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_normalize_bm25() {
        assert!(near(normalize_bm25(0.0, 1.0), 0.0));
        assert!(near(normalize_bm25(1.0, 1.0), 0.5));
        assert!(near(normalize_bm25(9.0, 1.0), 0.9));
        assert!((normalize_bm25(99.0, 1.0) - 0.99).abs() < 1e-4);
    }

    #[test]
    fn test_normalize_bm25_custom_k() {
        assert!(near(normalize_bm25(2.0, 2.0), 0.5));
    }

    // -----------------------------------------------------------------------
    // eval_similar_to_pure() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_eval_similar_to_pure() {
        let lhs = Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        let rhs = Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        if let Value::Float(f) = eval_similar_to_pure(&lhs, &rhs).unwrap() {
            assert!((f - 1.0).abs() < 1e-6);
        } else {
            panic!("Expected Float");
        }
    }

    #[test]
    fn test_eval_similar_to_pure_vector_type() {
        let lhs = Value::Vector(vec![1.0, 0.0]);
        let rhs = Value::Vector(vec![0.0, 1.0]);
        if let Value::Float(f) = eval_similar_to_pure(&lhs, &rhs).unwrap() {
            assert!((f - 0.0).abs() < 1e-6);
        } else {
            panic!("Expected Float");
        }
    }

    // -----------------------------------------------------------------------
    // validate_options() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_validate_options_weights_length() {
        let opts = SimilarToOptions {
            weights: Some(vec![0.5]),
            ..Default::default()
        };
        assert!(validate_options(&opts, 2).is_err());
    }

    #[test]
    fn test_validate_options_weights_sum() {
        let opts = SimilarToOptions {
            weights: Some(vec![0.5, 0.3]),
            ..Default::default()
        };
        assert!(validate_options(&opts, 2).is_err());
    }

    #[test]
    fn test_validate_options_ok() {
        let opts = SimilarToOptions {
            weights: Some(vec![0.7, 0.3]),
            ..Default::default()
        };
        assert!(validate_options(&opts, 2).is_ok());
    }

    // -----------------------------------------------------------------------
    // validate_pair() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_validate_pair_fts_vector_query() {
        assert!(validate_pair(&SourceType::Fts, true, false, 0).is_err());
    }

    #[test]
    fn test_validate_pair_vector_string_no_embed() {
        assert!(validate_pair(&cosine_source(false), false, true, 0).is_err());
    }

    #[test]
    fn test_validate_pair_vector_string_with_embed() {
        assert!(validate_pair(&cosine_source(true), false, true, 0).is_ok());
    }

    #[test]
    fn test_validate_pair_vector_vector() {
        assert!(validate_pair(&cosine_source(false), true, false, 0).is_ok());
    }

    #[test]
    fn test_validate_pair_fts_string() {
        assert!(validate_pair(&SourceType::Fts, false, true, 0).is_ok());
    }

    // -----------------------------------------------------------------------
    // fuse_scores() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_fuse_scores_single() {
        let opts = SimilarToOptions::default();
        assert!(near(fuse_scores(&[0.8], &opts).unwrap(), 0.8));
    }

    #[test]
    fn test_fuse_scores_weighted() {
        let opts = SimilarToOptions {
            method: FusionMethod::Weighted,
            weights: Some(vec![0.7, 0.3]),
            ..Default::default()
        };
        assert!(near(fuse_scores(&[0.8, 0.6], &opts).unwrap(), 0.74));
    }

    #[test]
    fn test_fuse_scores_rrf_fallback() {
        let opts = SimilarToOptions::default();
        // RRF in point context falls back to equal weights: (0.8 + 0.6) / 2 = 0.7
        assert!(near(fuse_scores(&[0.8, 0.6], &opts).unwrap(), 0.7));
    }

    // -----------------------------------------------------------------------
    // score_vectors() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_score_vectors_cosine_identical() {
        let v = vec![1.0, 0.0, 0.0];
        assert!(near(
            score_vectors(&v, &v, &DistanceMetric::Cosine).unwrap(),
            1.0
        ));
    }

    #[test]
    fn test_score_vectors_cosine_matches_raw() {
        // For Cosine, score_vectors hands off to cosine_similarity.
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.8, 0.6, 0.0];
        let raw = cosine_similarity(&a, &b).unwrap();
        let scored = score_vectors(&a, &b, &DistanceMetric::Cosine).unwrap();
        assert!(near(raw, scored));
    }

    #[test]
    fn test_score_vectors_l2() {
        // [1,0,0] vs [0,1,0]: L2 squared distance = 2, score = 1/(1+2) ≈ 0.333
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let score = score_vectors(&a, &b, &DistanceMetric::L2).unwrap();
        assert!((score - 1.0 / 3.0).abs() < 1e-5);
    }

    #[test]
    fn test_score_vectors_l2_identical() {
        let v = vec![1.0, 0.0, 0.0];
        assert!(near(score_vectors(&v, &v, &DistanceMetric::L2).unwrap(), 1.0));
    }

    #[test]
    fn test_score_vectors_dot() {
        // [1,0,0] dot [0.8,0.6,0] = 0.8
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.8, 0.6, 0.0];
        assert!(near(
            score_vectors(&a, &b, &DistanceMetric::Dot).unwrap(),
            0.8
        ));
    }

    #[test]
    fn test_score_vectors_dot_identical() {
        let v = vec![1.0, 0.0, 0.0];
        assert!(near(
            score_vectors(&v, &v, &DistanceMetric::Dot).unwrap(),
            1.0
        ));
    }

    #[test]
    fn test_score_vectors_dimension_mismatch() {
        let short = vec![1.0, 0.0];
        let long = vec![1.0, 0.0, 0.0];
        assert!(score_vectors(&short, &long, &DistanceMetric::Cosine).is_err());
    }
}