Skip to main content

uni_query_functions/
similar_to.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! `similar_to()` expression function — unified similarity scoring.
5//!
6//! Dispatches to vector cosine similarity, BM25 full-text scoring, or
7//! hybrid fusion based on schema types. Returns a float in `[0, 1]`
8//! where higher means more similar.
9//!
10//! This is a **point computation** (score one bound node), not a search
11//! (top-K index scan). It works in WHERE, RETURN, WITH, ORDER BY, and
12//! Locy rule bodies.
13
14use anyhow::Result;
15use uni_common::Value;
16use uni_common::core::schema::{DistanceMetric, Schema};
17
18use crate::fusion;
19
/// Named error types for `similar_to()` validation failures.
///
/// The `Display` text of every variant comes from its `#[error(...)]`
/// attribute (derived through `thiserror`).
#[derive(Debug, thiserror::Error)]
pub enum SimilarToError {
    /// The property has neither a vector index nor a full-text index
    /// (returned by `resolve_source_type`).
    #[error("similar_to: property '{label}.{property}' has no vector or full-text index")]
    NoIndex { label: String, property: String },

    /// A vector query was paired with a full-text source
    /// (returned by `validate_pair`).
    #[error(
        "similar_to: source {source_index} is FTS-indexed but query is a vector (FTS cannot score against vectors)"
    )]
    TypeMismatch { source_index: usize },

    /// A string query was paired with a vector source whose index has no
    /// embedding config (returned by `validate_pair`).
    #[error(
        "similar_to: source {source_index} is a vector property but query is a string, and the index has no embedding config for auto-embedding"
    )]
    NoEmbeddingConfig { source_index: usize },

    /// The `weights` option does not have one entry per source
    /// (returned by `validate_options`).
    #[error("similar_to: weights length ({weights_len}) != sources length ({sources_len})")]
    WeightsLengthMismatch {
        weights_len: usize,
        sources_len: usize,
    },

    /// The `weights` option does not sum to 1.0 within the tolerance applied
    /// by `validate_options`.
    #[error("similar_to: weights must sum to 1.0 (got {sum})")]
    WeightsNotNormalized { sum: f32 },

    /// The `method` option is a string other than `'rrf'` or `'weighted'`
    /// (returned by `parse_options`).
    #[error("similar_to: unknown method '{method}', expected 'rrf' or 'weighted'")]
    InvalidMethod { method: String },

    /// A malformed option value; `message` carries the detail. Also used by
    /// `value_to_vec` when a list element is not a number.
    #[error("similar_to: {message}")]
    InvalidOption { message: String },

    /// The two vectors being compared have different lengths.
    #[error("similar_to: vector dimensions mismatch: {a} vs {b}")]
    DimensionMismatch { a: usize, b: usize },

    /// A value that is neither a vector nor a list was supplied where a
    /// vector was expected; `actual` is the value's `Debug` rendering.
    #[error("similar_to: expected vector or list of numbers, got {actual}")]
    InvalidVectorValue { actual: String },

    /// Weighted fusion was requested without a `weights` option
    /// (returned by `fuse_scores`).
    #[error("similar_to: weighted fusion requires 'weights' option")]
    WeightsRequired,

    /// Wrong number of call arguments.
    // NOTE(review): not constructed in this file; presumably raised by the
    // expression-level caller — confirm.
    #[error("similar_to takes 2 or 3 arguments (sources, queries [, options]), got {count}")]
    InvalidArity { count: usize },

    /// No graph execution context was available.
    // NOTE(review): not constructed in this file; presumably raised by the
    // expression-level caller — confirm.
    #[error("similar_to requires GraphExecutionContext")]
    NoGraphContext,
}
66
/// Fusion method for multi-source scoring.
///
/// The enum carries no payload, so equality is total: it derives `Eq` in
/// addition to `PartialEq`, which lets callers use it wherever a full
/// equivalence relation is required.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum FusionMethod {
    /// Reciprocal Rank Fusion (default). Falls back to equal-weight
    /// fusion in point-computation context.
    #[default]
    Rrf,
    /// Weighted sum of per-source scores.
    Weighted,
}
77
78/// Options for `similar_to()` controlling fusion and scoring behavior.
79#[derive(Debug, Clone)]
80pub struct SimilarToOptions {
81    /// Fusion algorithm when multiple sources are present.
82    pub method: FusionMethod,
83    /// Per-source weights for weighted fusion. Must sum to 1.0.
84    pub weights: Option<Vec<f32>>,
85    /// RRF constant k (default 60).
86    pub k: usize,
87    /// BM25 saturation constant for FTS normalization (default 1.0).
88    pub fts_k: f32,
89}
90
91impl Default for SimilarToOptions {
92    fn default() -> Self {
93        Self {
94            method: FusionMethod::Rrf,
95            weights: None,
96            k: 60,
97            fts_k: 1.0,
98        }
99    }
100}
101
102/// Parse options from a `Value::Map`.
103pub fn parse_options(value: &Value) -> Result<SimilarToOptions, SimilarToError> {
104    let map = match value {
105        Value::Map(m) => m,
106        Value::Null => return Ok(SimilarToOptions::default()),
107        _ => {
108            return Err(SimilarToError::InvalidOption {
109                message: format!("options must be a map, got {:?}", value),
110            });
111        }
112    };
113
114    let mut opts = SimilarToOptions::default();
115
116    if let Some(method_val) = map.get("method") {
117        match method_val.as_str() {
118            Some("rrf") => opts.method = FusionMethod::Rrf,
119            Some("weighted") => opts.method = FusionMethod::Weighted,
120            Some(other) => {
121                return Err(SimilarToError::InvalidMethod {
122                    method: other.to_string(),
123                });
124            }
125            None => {
126                return Err(SimilarToError::InvalidOption {
127                    message: "'method' must be a string ('rrf' or 'weighted')".to_string(),
128                });
129            }
130        }
131    }
132
133    if let Some(weights_val) = map.get("weights") {
134        match weights_val {
135            Value::List(list) => {
136                let weights: Result<Vec<f32>, SimilarToError> = list
137                    .iter()
138                    .map(|v| {
139                        v.as_f64()
140                            .map(|f| f as f32)
141                            .ok_or_else(|| SimilarToError::InvalidOption {
142                                message: "weight must be a number".to_string(),
143                            })
144                    })
145                    .collect();
146                opts.weights = Some(weights?);
147            }
148            _ => {
149                return Err(SimilarToError::InvalidOption {
150                    message: "'weights' must be a list of numbers".to_string(),
151                });
152            }
153        }
154    }
155
156    if let Some(k_val) = map.get("k") {
157        opts.k = k_val
158            .as_i64()
159            .ok_or_else(|| SimilarToError::InvalidOption {
160                message: "'k' must be an integer".to_string(),
161            })? as usize;
162    }
163
164    if let Some(fts_k_val) = map.get("fts_k") {
165        opts.fts_k = fts_k_val
166            .as_f64()
167            .ok_or_else(|| SimilarToError::InvalidOption {
168                message: "'fts_k' must be a number".to_string(),
169            })? as f32;
170    }
171
172    Ok(opts)
173}
174
/// What type of source a property represents.
///
/// Produced by `resolve_source_type` and checked against the query type by
/// `validate_pair`.
#[derive(Debug, Clone)]
pub enum SourceType {
    /// Vector property with a vector index.
    Vector {
        /// Distance metric cloned from the vector index configuration.
        metric: DistanceMetric,
        /// Whether the index has an embedding config; `validate_pair`
        /// rejects string queries when this is `false`.
        has_embedding_config: bool,
    },
    /// String property with a full-text index.
    Fts,
}
186
187/// Resolve the source type for a property given the schema.
188pub fn resolve_source_type(
189    schema: &Schema,
190    label: &str,
191    property: &str,
192) -> Result<SourceType, SimilarToError> {
193    // Check vector index first
194    if let Some(vec_config) = schema.vector_index_for_property(label, property) {
195        return Ok(SourceType::Vector {
196            metric: vec_config.metric.clone(),
197            has_embedding_config: vec_config.embedding_config.is_some(),
198        });
199    }
200
201    // Check full-text index
202    if schema
203        .fulltext_index_for_property(label, property)
204        .is_some()
205    {
206        return Ok(SourceType::Fts);
207    }
208
209    Err(SimilarToError::NoIndex {
210        label: label.to_string(),
211        property: property.to_string(),
212    })
213}
214
215/// Compute cosine similarity between two vectors, returning a score in [-1, 1].
216///
217/// Arithmetic is performed in f64 (see `cosine_similarity_inner`) and the
218/// clamped result is narrowed to f32.
219pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32, SimilarToError> {
220    cosine_similarity_inner(a, b).map(|sim| sim as f32)
221}
222
223/// Score two vectors using the specified distance metric, returning a similarity
224/// score where higher means more similar.
225///
226/// - **Cosine**: raw cosine similarity in \[-1, 1\] (delegates to [`cosine_similarity`]).
227/// - **L2**: `1 / (1 + d²)` where d² is squared Euclidean distance; range (0, 1\].
228/// - **Dot**: raw dot product (for normalised vectors equals cosine similarity).
229pub fn score_vectors(a: &[f32], b: &[f32], metric: &DistanceMetric) -> Result<f32, SimilarToError> {
230    if a.len() != b.len() {
231        return Err(SimilarToError::DimensionMismatch {
232            a: a.len(),
233            b: b.len(),
234        });
235    }
236    let distance = metric.compute_distance(a, b);
237    match metric {
238        DistanceMetric::Cosine => cosine_similarity(a, b),
239        // compute_distance returns -dot (LanceDB convention: lower = more similar).
240        // Negate to recover the actual dot product as a similarity score.
241        DistanceMetric::Dot => Ok(-distance),
242        // L2 and all other metrics (#[non_exhaustive]): normalise via calculate_score.
243        _ => Ok(calculate_score(distance, metric)),
244    }
245}
246
247/// Compute the MaxSim (late-interaction / ColBERT) score between a query and a
248/// document, each represented as a set of token vectors.
249///
250/// MaxSim is `Σ_i max_j score_vectors(query_i, doc_j, metric)`: for every query
251/// token take its best match across all document tokens, then sum. Per-token
252/// scoring reuses [`score_vectors`], so the metric's normalisation matches the
253/// rest of the engine. Default metric for late interaction is `Cosine`.
254///
255/// An empty query, or a query token with no document tokens to match against,
256/// contributes `0` rather than erroring.
257///
258/// # Errors
259/// Returns [`SimilarToError::DimensionMismatch`] if a query token and a document
260/// token differ in length.
261pub fn maxsim(
262    query: &[Vec<f32>],
263    doc: &[Vec<f32>],
264    metric: &DistanceMetric,
265) -> Result<f32, SimilarToError> {
266    let mut total = 0.0_f32;
267    for q in query {
268        let mut best: Option<f32> = None;
269        for d in doc {
270            let sim = score_vectors(q, d, metric)?;
271            best = Some(best.map_or(sim, |b| b.max(sim)));
272        }
273        // `None` means the document had no tokens: that query token scores 0.
274        total += best.unwrap_or(0.0);
275    }
276    Ok(total)
277}
278
279/// Convert a raw distance value into a normalised similarity score.
280///
281/// The conversion depends on the distance metric:
282/// - **Cosine**: `(2 - d) / 2` (LanceDB cosine distance ranges 0..2)
283/// - **Dot**: pass-through (already a similarity measure)
284/// - **L2** and others: `1 / (1 + d)`
285pub fn calculate_score(distance: f32, metric: &DistanceMetric) -> f32 {
286    match metric {
287        DistanceMetric::Cosine => (2.0 - distance) / 2.0,
288        DistanceMetric::Dot => distance,
289        _ => 1.0 / (1.0 + distance),
290    }
291}
292
/// Normalize a BM25 score to [0, 1] using a saturation function.
///
/// `normalized = score / (score + fts_k)` where `fts_k` defaults to 1.0.
///
/// The result is guaranteed to lie in [0, 1] for every input:
/// - a NaN or non-positive `score` maps to `0.0`;
/// - an infinite `score` saturates to `1.0`;
/// - a NaN `fts_k` maps to `0.0`, and any other out-of-range quotient (only
///   reachable with a negative `fts_k`) is clamped into [0, 1].
///
/// For a finite positive `score` and a non-negative `fts_k` the value is
/// exactly the plain saturation formula.
pub fn normalize_bm25(score: f32, fts_k: f32) -> f32 {
    if score.is_nan() || score <= 0.0 {
        return 0.0;
    }
    // `inf / (inf + k)` would be NaN; an unbounded score is fully saturated.
    if score.is_infinite() {
        return 1.0;
    }
    let normalized = score / (score + fts_k);
    if normalized.is_nan() {
        // `score` is finite here, so only a NaN `fts_k` reaches this branch.
        return 0.0;
    }
    normalized.clamp(0.0, 1.0)
}
302
303/// Compute pure vector-vs-vector similarity (no storage access needed).
304///
305/// Both values must be `Value::List` of numbers or `Value::Vector`.
306///
307/// Uses f64 arithmetic throughout when both inputs are `Value::List`, preserving
308/// full precision for property-based vectors (e.g. in TCK and unit tests). For
309/// `Value::Vector` (pre-indexed f32 data) it falls back to the f32 path.
310pub fn eval_similar_to_pure(v1: &Value, v2: &Value) -> Result<Value> {
311    // NULL propagates: a NULL operand yields NULL, not an error. `values_to_array`
312    // renders `Value::Null` as an Arrow null, matching 3VL semantics.
313    if matches!(v1, Value::Null) || matches!(v2, Value::Null) {
314        return Ok(Value::Null);
315    }
316    // Fast path: at least one input is a List — use f64 to avoid f32 precision loss.
317    let has_list = matches!(v1, Value::List(_)) || matches!(v2, Value::List(_));
318    let f64_vecs = has_list
319        .then(|| value_to_f64_vec(v1).ok().zip(value_to_f64_vec(v2).ok()))
320        .flatten();
321    if let Some((vec1, vec2)) = f64_vecs {
322        let sim = cosine_similarity_inner(&vec1, &vec2)?;
323        return Ok(Value::Float(sim));
324    }
325    // Fallback: f32 path for Value::Vector (indexed data already in f32).
326    let vec1 = value_to_f32_vec(v1)?;
327    let vec2 = value_to_f32_vec(v2)?;
328    let sim = cosine_similarity(&vec1, &vec2)?;
329    Ok(Value::Float(sim as f64))
330}
331
332/// Compute the sparse dot product between two sparse vectors (no storage access).
333///
334/// Each operand may be a `Value::SparseVector` or the structural
335/// `{indices: [...], values: [...]}` `Value::Map` — the latter is how a sparse
336/// param arrives at a scalar UDF, where the Arrow `Struct` is decoded without
337/// schema context. A `NULL` operand propagates to `NULL` (3VL semantics,
338/// matching [`eval_similar_to_pure`]). Non-overlapping term sets yield
339/// `Value::Float(0.0)`.
340///
341/// This is the SPLADE/learned-sparse analogue of [`eval_similar_to_pure`]:
342/// dot product is the canonical sparse similarity, so no normalization is
343/// applied (higher means more similar, unbounded above).
344///
345/// # Errors
346/// Returns an error if either non-null operand is neither a sparse vector nor a
347/// well-formed `{indices, values}` map, or if its weights are non-finite.
348pub fn eval_sparse_similar_to_pure(v1: &Value, v2: &Value) -> Result<Value> {
349    if matches!(v1, Value::Null) || matches!(v2, Value::Null) {
350        return Ok(Value::Null);
351    }
352    let a = value_to_sparse(v1)?;
353    let b = value_to_sparse(v2)?;
354    Ok(Value::Float(f64::from(uni_sparse_vector::ops::sparse_dot(
355        &a, &b,
356    ))))
357}
358
359/// Reconstruct a kernel [`uni_sparse_vector::SparseVector`] from either a
360/// `Value::SparseVector` or a `{indices, values}` `Value::Map`.
361fn value_to_sparse(v: &Value) -> Result<uni_sparse_vector::SparseVector> {
362    match v {
363        Value::SparseVector { indices, values } => {
364            uni_sparse_vector::SparseVector::new(indices.clone(), values.clone())
365                .map_err(|e| anyhow::anyhow!("sparse_similar_to: invalid sparse vector: {e}"))
366        }
367        Value::Map(m) => {
368            let as_list = |key: &str| -> Result<&Vec<Value>> {
369                match m.get(key) {
370                    Some(Value::List(l)) => Ok(l),
371                    _ => Err(anyhow::anyhow!(
372                        "sparse_similar_to: map operand missing '{key}' list"
373                    )),
374                }
375            };
376            let indices: Vec<u32> = as_list("indices")?
377                .iter()
378                .map(|x| x.as_i64().map(|i| i as u32))
379                .collect::<Option<_>>()
380                .ok_or_else(|| anyhow::anyhow!("sparse_similar_to: 'indices' must be integers"))?;
381            let values: Vec<f32> = as_list("values")?
382                .iter()
383                .map(|x| x.as_f64().map(|f| f as f32))
384                .collect::<Option<_>>()
385                .ok_or_else(|| anyhow::anyhow!("sparse_similar_to: 'values' must be numbers"))?;
386            // `from_pairs` canonicalizes (sort + sum) since a map carries no
387            // ascending-order guarantee.
388            uni_sparse_vector::SparseVector::from_pairs(indices.into_iter().zip(values).collect())
389                .map_err(|e| anyhow::anyhow!("sparse_similar_to: invalid sparse map: {e}"))
390        }
391        _ => Err(anyhow::anyhow!(
392            "sparse_similar_to arguments must be sparse vectors or {{indices, values}} maps"
393        )),
394    }
395}
396
397/// Compute the raw cosine similarity (in f64) between two equal-length vectors,
398/// clamped to [-1, 1]. Returns 0 when either vector has zero magnitude.
399///
400/// Generic over the element type so both the f32 and f64 callers share one body;
401/// arithmetic is always performed in f64 to preserve precision.
402fn cosine_similarity_inner<T: Copy + Into<f64>>(a: &[T], b: &[T]) -> Result<f64, SimilarToError> {
403    if a.len() != b.len() {
404        return Err(SimilarToError::DimensionMismatch {
405            a: a.len(),
406            b: b.len(),
407        });
408    }
409    let mut dot = 0.0f64;
410    let mut mag1 = 0.0f64;
411    let mut mag2 = 0.0f64;
412    for (&x, &y) in a.iter().zip(b.iter()) {
413        let (x, y): (f64, f64) = (x.into(), y.into());
414        dot += x * y;
415        mag1 += x * x;
416        mag2 += y * y;
417    }
418    let mag1 = mag1.sqrt();
419    let mag2 = mag2.sqrt();
420    if mag1 == 0.0 || mag2 == 0.0 {
421        return Ok(0.0);
422    }
423    Ok((dot / (mag1 * mag2)).clamp(-1.0, 1.0))
424}
425
426/// Convert a Value to a numeric vector for vector operations.
427///
428/// Generic over the element type: `cast` maps the `f64` source value onto the
429/// target representation (identity for `f64`, narrowing for `f32`). List elements
430/// are read as full-precision `f64` before casting, while `Value::Vector` data is
431/// already `f32` and is widened to `f64` first.
432fn value_to_vec<T>(v: &Value, cast: impl Fn(f64) -> T) -> Result<Vec<T>, SimilarToError> {
433    match v {
434        Value::Vector(vec) => Ok(vec.iter().map(|&x| cast(x as f64)).collect()),
435        Value::List(list) => list
436            .iter()
437            .map(|v| {
438                v.as_f64()
439                    .map(&cast)
440                    .ok_or_else(|| SimilarToError::InvalidOption {
441                        message: "vector element must be a number".to_string(),
442                    })
443            })
444            .collect(),
445        _ => Err(SimilarToError::InvalidVectorValue {
446            actual: format!("{v:?}"),
447        }),
448    }
449}
450
451/// Convert a Value to a `Vec<f64>` for high-precision vector operations.
452fn value_to_f64_vec(v: &Value) -> Result<Vec<f64>, SimilarToError> {
453    value_to_vec(v, |f| f)
454}
455
456/// Convert a Value to a `Vec<f32>` for vector operations.
457pub fn value_to_f32_vec(v: &Value) -> Result<Vec<f32>, SimilarToError> {
458    value_to_vec(v, |f| f as f32)
459}
460
461/// Validate options against the number of sources.
462pub fn validate_options(opts: &SimilarToOptions, num_sources: usize) -> Result<(), SimilarToError> {
463    if let Some(ref weights) = opts.weights {
464        if weights.len() != num_sources {
465            return Err(SimilarToError::WeightsLengthMismatch {
466                weights_len: weights.len(),
467                sources_len: num_sources,
468            });
469        }
470        let sum: f32 = weights.iter().sum();
471        if (sum - 1.0).abs() > 0.01 {
472            return Err(SimilarToError::WeightsNotNormalized { sum });
473        }
474    }
475    Ok(())
476}
477
478/// Validate per-pair type compatibility.
479///
480/// Returns an error if a Vector query is paired with an FTS source,
481/// or a String query is paired with a Vector source that has no embedding config.
482pub fn validate_pair(
483    source_type: &SourceType,
484    query_is_vector: bool,
485    query_is_string: bool,
486    source_index: usize,
487) -> Result<(), SimilarToError> {
488    match source_type {
489        SourceType::Fts if query_is_vector => Err(SimilarToError::TypeMismatch { source_index }),
490        SourceType::Vector {
491            has_embedding_config: false,
492            ..
493        } if query_is_string => Err(SimilarToError::NoEmbeddingConfig { source_index }),
494        _ => Ok(()),
495    }
496}
497
498/// Fuse multiple per-source scores into a single score.
499pub fn fuse_scores(scores: &[f32], opts: &SimilarToOptions) -> Result<f32, SimilarToError> {
500    if scores.len() == 1 {
501        return Ok(scores[0]);
502    }
503
504    match opts.method {
505        FusionMethod::Weighted => {
506            let weights = opts
507                .weights
508                .as_ref()
509                .ok_or(SimilarToError::WeightsRequired)?;
510            Ok(fusion::fuse_weighted_multi(scores, weights))
511        }
512        FusionMethod::Rrf => {
513            // In point-computation context, RRF degenerates to equal-weight fusion.
514            // The caller (similar_to_expr.rs) already emits QueryWarning::RrfPointContext
515            // unconditionally when method == Rrf && num_sources > 1.
516            let (score, _) = fusion::fuse_rrf_point(scores);
517            Ok(score)
518        }
519    }
520}
521
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    // -----------------------------------------------------------------------
    // Shared helpers
    // -----------------------------------------------------------------------

    /// `true` when `actual` lies strictly within `tol` of `expected`.
    fn approx(actual: f32, expected: f32, tol: f32) -> bool {
        (actual - expected).abs() < tol
    }

    /// Build an options `Value::Map` from `(key, value)` pairs.
    fn options_map(entries: Vec<(&str, Value)>) -> Value {
        let map: HashMap<String, Value> = entries
            .into_iter()
            .map(|(key, value)| (key.to_string(), value))
            .collect();
        Value::Map(map)
    }

    /// Default options with only `weights` overridden.
    fn with_weights(weights: Vec<f32>) -> SimilarToOptions {
        SimilarToOptions {
            weights: Some(weights),
            ..Default::default()
        }
    }

    /// A cosine-metric vector source with the given embedding-config flag.
    fn cosine_source(has_embedding_config: bool) -> SourceType {
        SourceType::Vector {
            metric: DistanceMetric::Cosine,
            has_embedding_config,
        }
    }

    /// Unwrap a `Value::Float`, panicking on any other variant.
    fn expect_float(value: Value) -> f64 {
        match value {
            Value::Float(f) => f,
            _ => panic!("Expected Float"),
        }
    }

    // -----------------------------------------------------------------------
    // parse_options() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_parse_options_default() {
        let opts = parse_options(&Value::Null).unwrap();
        assert_eq!(opts.method, FusionMethod::Rrf);
        assert_eq!(opts.k, 60);
        assert!(approx(opts.fts_k, 1.0, 1e-6));
        assert!(opts.weights.is_none());
    }

    #[test]
    fn test_parse_options_weighted() {
        let raw = options_map(vec![
            ("method", Value::String("weighted".to_string())),
            (
                "weights",
                Value::List(vec![Value::Float(0.7), Value::Float(0.3)]),
            ),
        ]);
        let opts = parse_options(&raw).unwrap();
        assert_eq!(opts.method, FusionMethod::Weighted);
        let weights = opts.weights.unwrap();
        assert!(approx(weights[0], 0.7, 1e-6));
        assert!(approx(weights[1], 0.3, 1e-6));
    }

    #[test]
    fn test_parse_options_rrf_with_k() {
        let raw = options_map(vec![
            ("method", Value::String("rrf".to_string())),
            ("k", Value::Int(30)),
        ]);
        let opts = parse_options(&raw).unwrap();
        assert_eq!(opts.method, FusionMethod::Rrf);
        assert_eq!(opts.k, 30);
    }

    #[test]
    fn test_parse_options_fts_k() {
        let raw = options_map(vec![("fts_k", Value::Float(2.0))]);
        let opts = parse_options(&raw).unwrap();
        assert!(approx(opts.fts_k, 2.0, 1e-6));
    }

    #[test]
    fn test_parse_options_invalid_method() {
        let raw = options_map(vec![("method", Value::String("invalid".to_string()))]);
        assert!(parse_options(&raw).is_err());
    }

    // -----------------------------------------------------------------------
    // cosine_similarity() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_cosine_similarity_identical() {
        let unit = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(&unit, &unit).unwrap();
        assert!(approx(sim, 1.0, 1e-6));
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let x_axis = vec![1.0, 0.0];
        let y_axis = vec![0.0, 1.0];
        let sim = cosine_similarity(&x_axis, &y_axis).unwrap();
        assert!(approx(sim, 0.0, 1e-6));
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let forward = vec![1.0, 0.0];
        let backward = vec![-1.0, 0.0];
        let sim = cosine_similarity(&forward, &backward).unwrap();
        assert!(approx(sim, -1.0, 1e-6));
    }

    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        let two_d = vec![1.0, 0.0];
        let three_d = vec![1.0, 0.0, 0.0];
        assert!(cosine_similarity(&two_d, &three_d).is_err());
    }

    // -----------------------------------------------------------------------
    // normalize_bm25() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_normalize_bm25() {
        assert!(approx(normalize_bm25(0.0, 1.0), 0.0, 1e-6));
        assert!(approx(normalize_bm25(1.0, 1.0), 0.5, 1e-6));
        assert!(approx(normalize_bm25(9.0, 1.0), 0.9, 1e-6));
        assert!(approx(normalize_bm25(99.0, 1.0), 0.99, 1e-4));
    }

    #[test]
    fn test_normalize_bm25_custom_k() {
        assert!(approx(normalize_bm25(2.0, 2.0), 0.5, 1e-6));
    }

    // -----------------------------------------------------------------------
    // eval_similar_to_pure() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_eval_similar_to_pure() {
        let v1 = Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        let v2 = Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        let sim = expect_float(eval_similar_to_pure(&v1, &v2).unwrap());
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_eval_similar_to_pure_vector_type() {
        let v1 = Value::Vector(vec![1.0, 0.0]);
        let v2 = Value::Vector(vec![0.0, 1.0]);
        let sim = expect_float(eval_similar_to_pure(&v1, &v2).unwrap());
        assert!((sim - 0.0).abs() < 1e-6);
    }

    // -----------------------------------------------------------------------
    // validate_options() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_validate_options_weights_length() {
        assert!(validate_options(&with_weights(vec![0.5]), 2).is_err());
    }

    #[test]
    fn test_validate_options_weights_sum() {
        assert!(validate_options(&with_weights(vec![0.5, 0.3]), 2).is_err());
    }

    #[test]
    fn test_validate_options_ok() {
        assert!(validate_options(&with_weights(vec![0.7, 0.3]), 2).is_ok());
    }

    // -----------------------------------------------------------------------
    // validate_pair() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_validate_pair_fts_vector_query() {
        assert!(validate_pair(&SourceType::Fts, true, false, 0).is_err());
    }

    #[test]
    fn test_validate_pair_vector_string_no_embed() {
        assert!(validate_pair(&cosine_source(false), false, true, 0).is_err());
    }

    #[test]
    fn test_validate_pair_vector_string_with_embed() {
        assert!(validate_pair(&cosine_source(true), false, true, 0).is_ok());
    }

    #[test]
    fn test_validate_pair_vector_vector() {
        assert!(validate_pair(&cosine_source(false), true, false, 0).is_ok());
    }

    #[test]
    fn test_validate_pair_fts_string() {
        assert!(validate_pair(&SourceType::Fts, false, true, 0).is_ok());
    }

    // -----------------------------------------------------------------------
    // fuse_scores() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_fuse_scores_single() {
        let fused = fuse_scores(&[0.8], &SimilarToOptions::default()).unwrap();
        assert!(approx(fused, 0.8, 1e-6));
    }

    #[test]
    fn test_fuse_scores_weighted() {
        let opts = SimilarToOptions {
            method: FusionMethod::Weighted,
            weights: Some(vec![0.7, 0.3]),
            ..Default::default()
        };
        let fused = fuse_scores(&[0.8, 0.6], &opts).unwrap();
        assert!(approx(fused, 0.74, 1e-6));
    }

    #[test]
    fn test_fuse_scores_rrf_fallback() {
        let fused = fuse_scores(&[0.8, 0.6], &SimilarToOptions::default()).unwrap();
        // RRF in point context falls back to equal weights: (0.8 + 0.6) / 2 = 0.7
        assert!(approx(fused, 0.7, 1e-6));
    }

    // -----------------------------------------------------------------------
    // score_vectors() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_score_vectors_cosine_identical() {
        let unit = vec![1.0, 0.0, 0.0];
        let score = score_vectors(&unit, &unit, &DistanceMetric::Cosine).unwrap();
        assert!(approx(score, 1.0, 1e-6));
    }

    #[test]
    fn test_score_vectors_cosine_matches_raw() {
        // score_vectors with Cosine delegates to cosine_similarity
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.8, 0.6, 0.0];
        let raw = cosine_similarity(&a, &b).unwrap();
        let scored = score_vectors(&a, &b, &DistanceMetric::Cosine).unwrap();
        assert!(approx(raw, scored, 1e-6));
    }

    #[test]
    fn test_score_vectors_l2() {
        // [1,0,0] vs [0,1,0]: L2 squared distance = 2, score = 1/(1+2) ≈ 0.333
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let score = score_vectors(&a, &b, &DistanceMetric::L2).unwrap();
        assert!(approx(score, 1.0 / 3.0, 1e-5));
    }

    #[test]
    fn test_score_vectors_l2_identical() {
        let unit = vec![1.0, 0.0, 0.0];
        let score = score_vectors(&unit, &unit, &DistanceMetric::L2).unwrap();
        assert!(approx(score, 1.0, 1e-6));
    }

    #[test]
    fn test_score_vectors_dot() {
        // [1,0,0] dot [0.8,0.6,0] = 0.8
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.8, 0.6, 0.0];
        let score = score_vectors(&a, &b, &DistanceMetric::Dot).unwrap();
        assert!(approx(score, 0.8, 1e-6));
    }

    #[test]
    fn test_score_vectors_dot_identical() {
        let unit = vec![1.0, 0.0, 0.0];
        let score = score_vectors(&unit, &unit, &DistanceMetric::Dot).unwrap();
        assert!(approx(score, 1.0, 1e-6));
    }

    #[test]
    fn test_score_vectors_dimension_mismatch() {
        let two_d = vec![1.0, 0.0];
        let three_d = vec![1.0, 0.0, 0.0];
        assert!(score_vectors(&two_d, &three_d, &DistanceMetric::Cosine).is_err());
    }

    // -----------------------------------------------------------------------
    // maxsim() tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_maxsim_hand_computed() {
        // query = {e0, e1}; doc = {e0, [0.5,0.5]}. With Dot:
        //   q0=e0 -> max(e0·e0=1, e0·[.5,.5]=0.5) = 1.0
        //   q1=e1 -> max(e1·e0=0, e1·[.5,.5]=0.5) = 0.5
        // MaxSim = 1.5
        let query = vec![vec![1.0_f32, 0.0], vec![0.0_f32, 1.0]];
        let doc = vec![vec![1.0_f32, 0.0], vec![0.5_f32, 0.5]];
        let score = maxsim(&query, &doc, &DistanceMetric::Dot).unwrap();
        assert!((score - 1.5).abs() < 1e-6, "got {score}");
    }

    #[test]
    fn test_maxsim_edge_cases() {
        let metric = DistanceMetric::Cosine;

        // Empty query -> 0.
        let no_query = maxsim(&[], &[vec![1.0_f32, 0.0]], &metric);
        assert_eq!(no_query.unwrap(), 0.0);

        // Empty doc -> every query token scores 0 -> 0.
        let empty_doc: Vec<Vec<f32>> = vec![];
        let no_doc = maxsim(&[vec![1.0_f32, 0.0]], &empty_doc, &metric);
        assert_eq!(no_doc.unwrap(), 0.0);

        // Dimension mismatch between a query token and a doc token -> error.
        let err = maxsim(&[vec![1.0_f32, 0.0]], &[vec![1.0_f32, 0.0, 0.0]], &metric);
        assert!(matches!(err, Err(SimilarToError::DimensionMismatch { .. })));
    }

    #[test]
    fn test_maxsim_metric_changes_score() {
        // Non-unit vectors so Dot and Cosine diverge: q=[2,0], d=[3,0].
        //   Dot    = 2*3 = 6.0
        //   Cosine = (2*3) / (|2|*|3|) = 1.0
        let q = vec![vec![2.0_f32, 0.0]];
        let d = vec![vec![3.0_f32, 0.0]];
        let dot = maxsim(&q, &d, &DistanceMetric::Dot).unwrap();
        let cos = maxsim(&q, &d, &DistanceMetric::Cosine).unwrap();
        assert!((dot - 6.0).abs() < 1e-6, "dot got {dot}");
        assert!((cos - 1.0).abs() < 1e-6, "cosine got {cos}");
    }
}