1use anyhow::Result;
15use uni_common::Value;
16use uni_common::core::schema::{DistanceMetric, Schema};
17
18use crate::fusion;
19
20#[derive(Debug, thiserror::Error)]
22pub enum SimilarToError {
23 #[error("similar_to: property '{label}.{property}' has no vector or full-text index")]
24 NoIndex { label: String, property: String },
25
26 #[error(
27 "similar_to: source {source_index} is FTS-indexed but query is a vector (FTS cannot score against vectors)"
28 )]
29 TypeMismatch { source_index: usize },
30
31 #[error(
32 "similar_to: source {source_index} is a vector property but query is a string, and the index has no embedding config for auto-embedding"
33 )]
34 NoEmbeddingConfig { source_index: usize },
35
36 #[error("similar_to: weights length ({weights_len}) != sources length ({sources_len})")]
37 WeightsLengthMismatch {
38 weights_len: usize,
39 sources_len: usize,
40 },
41
42 #[error("similar_to: weights must sum to 1.0 (got {sum})")]
43 WeightsNotNormalized { sum: f32 },
44
45 #[error("similar_to: unknown method '{method}', expected 'rrf' or 'weighted'")]
46 InvalidMethod { method: String },
47
48 #[error("similar_to: {message}")]
49 InvalidOption { message: String },
50
51 #[error("similar_to: vector dimensions mismatch: {a} vs {b}")]
52 DimensionMismatch { a: usize, b: usize },
53
54 #[error("similar_to: expected vector or list of numbers, got {actual}")]
55 InvalidVectorValue { actual: String },
56
57 #[error("similar_to: weighted fusion requires 'weights' option")]
58 WeightsRequired,
59
60 #[error("similar_to takes 2 or 3 arguments (sources, queries [, options]), got {count}")]
61 InvalidArity { count: usize },
62
63 #[error("similar_to requires GraphExecutionContext")]
64 NoGraphContext,
65}
66
/// Strategy used to combine per-source scores, selected by the `method`
/// option (`"rrf"` or `"weighted"`).
#[derive(Clone, Debug, Default, PartialEq)]
pub enum FusionMethod {
    /// RRF fusion (the default); scoring is delegated to `fusion::fuse_rrf_point`.
    #[default]
    Rrf,
    /// Weighted combination using `SimilarToOptions::weights`; scoring is
    /// delegated to `fusion::fuse_weighted_multi`.
    Weighted,
}
77
78#[derive(Debug, Clone)]
80pub struct SimilarToOptions {
81 pub method: FusionMethod,
83 pub weights: Option<Vec<f32>>,
85 pub k: usize,
87 pub fts_k: f32,
89}
90
91impl Default for SimilarToOptions {
92 fn default() -> Self {
93 Self {
94 method: FusionMethod::Rrf,
95 weights: None,
96 k: 60,
97 fts_k: 1.0,
98 }
99 }
100}
101
102pub fn parse_options(value: &Value) -> Result<SimilarToOptions, SimilarToError> {
104 let map = match value {
105 Value::Map(m) => m,
106 Value::Null => return Ok(SimilarToOptions::default()),
107 _ => {
108 return Err(SimilarToError::InvalidOption {
109 message: format!("options must be a map, got {:?}", value),
110 });
111 }
112 };
113
114 let mut opts = SimilarToOptions::default();
115
116 if let Some(method_val) = map.get("method") {
117 match method_val.as_str() {
118 Some("rrf") => opts.method = FusionMethod::Rrf,
119 Some("weighted") => opts.method = FusionMethod::Weighted,
120 Some(other) => {
121 return Err(SimilarToError::InvalidMethod {
122 method: other.to_string(),
123 });
124 }
125 None => {
126 return Err(SimilarToError::InvalidOption {
127 message: "'method' must be a string ('rrf' or 'weighted')".to_string(),
128 });
129 }
130 }
131 }
132
133 if let Some(weights_val) = map.get("weights") {
134 match weights_val {
135 Value::List(list) => {
136 let weights: Result<Vec<f32>, SimilarToError> = list
137 .iter()
138 .map(|v| {
139 v.as_f64()
140 .map(|f| f as f32)
141 .ok_or_else(|| SimilarToError::InvalidOption {
142 message: "weight must be a number".to_string(),
143 })
144 })
145 .collect();
146 opts.weights = Some(weights?);
147 }
148 _ => {
149 return Err(SimilarToError::InvalidOption {
150 message: "'weights' must be a list of numbers".to_string(),
151 });
152 }
153 }
154 }
155
156 if let Some(k_val) = map.get("k") {
157 opts.k = k_val
158 .as_i64()
159 .ok_or_else(|| SimilarToError::InvalidOption {
160 message: "'k' must be an integer".to_string(),
161 })? as usize;
162 }
163
164 if let Some(fts_k_val) = map.get("fts_k") {
165 opts.fts_k = fts_k_val
166 .as_f64()
167 .ok_or_else(|| SimilarToError::InvalidOption {
168 message: "'fts_k' must be a number".to_string(),
169 })? as f32;
170 }
171
172 Ok(opts)
173}
174
175#[derive(Debug, Clone)]
177pub enum SourceType {
178 Vector {
180 metric: DistanceMetric,
181 has_embedding_config: bool,
182 },
183 Fts,
185}
186
187pub fn resolve_source_type(
189 schema: &Schema,
190 label: &str,
191 property: &str,
192) -> Result<SourceType, SimilarToError> {
193 if let Some(vec_config) = schema.vector_index_for_property(label, property) {
195 return Ok(SourceType::Vector {
196 metric: vec_config.metric.clone(),
197 has_embedding_config: vec_config.embedding_config.is_some(),
198 });
199 }
200
201 if schema
203 .fulltext_index_for_property(label, property)
204 .is_some()
205 {
206 return Ok(SourceType::Fts);
207 }
208
209 Err(SimilarToError::NoIndex {
210 label: label.to_string(),
211 property: property.to_string(),
212 })
213}
214
215pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32, SimilarToError> {
220 cosine_similarity_inner(a, b).map(|sim| sim as f32)
221}
222
223pub fn score_vectors(a: &[f32], b: &[f32], metric: &DistanceMetric) -> Result<f32, SimilarToError> {
230 if a.len() != b.len() {
231 return Err(SimilarToError::DimensionMismatch {
232 a: a.len(),
233 b: b.len(),
234 });
235 }
236 let distance = metric.compute_distance(a, b);
237 match metric {
238 DistanceMetric::Cosine => cosine_similarity(a, b),
239 DistanceMetric::Dot => Ok(-distance),
242 _ => Ok(calculate_score(distance, metric)),
244 }
245}
246
247pub fn maxsim(
262 query: &[Vec<f32>],
263 doc: &[Vec<f32>],
264 metric: &DistanceMetric,
265) -> Result<f32, SimilarToError> {
266 let mut total = 0.0_f32;
267 for q in query {
268 let mut best: Option<f32> = None;
269 for d in doc {
270 let sim = score_vectors(q, d, metric)?;
271 best = Some(best.map_or(sim, |b| b.max(sim)));
272 }
273 total += best.unwrap_or(0.0);
275 }
276 Ok(total)
277}
278
279pub fn calculate_score(distance: f32, metric: &DistanceMetric) -> f32 {
286 match metric {
287 DistanceMetric::Cosine => (2.0 - distance) / 2.0,
288 DistanceMetric::Dot => distance,
289 _ => 1.0 / (1.0 + distance),
290 }
291}
292
/// Squashes a raw BM25 `score` into `[0, 1)` as `score / (score + fts_k)`.
///
/// Non-positive scores map to 0.0. A NaN score is not `<= 0.0`, so it flows
/// through the ratio and comes back as NaN.
pub fn normalize_bm25(score: f32, fts_k: f32) -> f32 {
    let non_positive = score <= 0.0;
    if non_positive {
        0.0
    } else {
        score / (fts_k + score)
    }
}
302
303pub fn eval_similar_to_pure(v1: &Value, v2: &Value) -> Result<Value> {
311 if matches!(v1, Value::Null) || matches!(v2, Value::Null) {
314 return Ok(Value::Null);
315 }
316 let has_list = matches!(v1, Value::List(_)) || matches!(v2, Value::List(_));
318 let f64_vecs = has_list
319 .then(|| value_to_f64_vec(v1).ok().zip(value_to_f64_vec(v2).ok()))
320 .flatten();
321 if let Some((vec1, vec2)) = f64_vecs {
322 let sim = cosine_similarity_inner(&vec1, &vec2)?;
323 return Ok(Value::Float(sim));
324 }
325 let vec1 = value_to_f32_vec(v1)?;
327 let vec2 = value_to_f32_vec(v2)?;
328 let sim = cosine_similarity(&vec1, &vec2)?;
329 Ok(Value::Float(sim as f64))
330}
331
332pub fn eval_sparse_similar_to_pure(v1: &Value, v2: &Value) -> Result<Value> {
349 if matches!(v1, Value::Null) || matches!(v2, Value::Null) {
350 return Ok(Value::Null);
351 }
352 let a = value_to_sparse(v1)?;
353 let b = value_to_sparse(v2)?;
354 Ok(Value::Float(f64::from(uni_sparse_vector::ops::sparse_dot(
355 &a, &b,
356 ))))
357}
358
359fn value_to_sparse(v: &Value) -> Result<uni_sparse_vector::SparseVector> {
362 match v {
363 Value::SparseVector { indices, values } => {
364 uni_sparse_vector::SparseVector::new(indices.clone(), values.clone())
365 .map_err(|e| anyhow::anyhow!("sparse_similar_to: invalid sparse vector: {e}"))
366 }
367 Value::Map(m) => {
368 let as_list = |key: &str| -> Result<&Vec<Value>> {
369 match m.get(key) {
370 Some(Value::List(l)) => Ok(l),
371 _ => Err(anyhow::anyhow!(
372 "sparse_similar_to: map operand missing '{key}' list"
373 )),
374 }
375 };
376 let indices: Vec<u32> = as_list("indices")?
377 .iter()
378 .map(|x| x.as_i64().map(|i| i as u32))
379 .collect::<Option<_>>()
380 .ok_or_else(|| anyhow::anyhow!("sparse_similar_to: 'indices' must be integers"))?;
381 let values: Vec<f32> = as_list("values")?
382 .iter()
383 .map(|x| x.as_f64().map(|f| f as f32))
384 .collect::<Option<_>>()
385 .ok_or_else(|| anyhow::anyhow!("sparse_similar_to: 'values' must be numbers"))?;
386 uni_sparse_vector::SparseVector::from_pairs(indices.into_iter().zip(values).collect())
389 .map_err(|e| anyhow::anyhow!("sparse_similar_to: invalid sparse map: {e}"))
390 }
391 _ => Err(anyhow::anyhow!(
392 "sparse_similar_to arguments must be sparse vectors or {{indices, values}} maps"
393 )),
394 }
395}
396
397fn cosine_similarity_inner<T: Copy + Into<f64>>(a: &[T], b: &[T]) -> Result<f64, SimilarToError> {
403 if a.len() != b.len() {
404 return Err(SimilarToError::DimensionMismatch {
405 a: a.len(),
406 b: b.len(),
407 });
408 }
409 let mut dot = 0.0f64;
410 let mut mag1 = 0.0f64;
411 let mut mag2 = 0.0f64;
412 for (&x, &y) in a.iter().zip(b.iter()) {
413 let (x, y): (f64, f64) = (x.into(), y.into());
414 dot += x * y;
415 mag1 += x * x;
416 mag2 += y * y;
417 }
418 let mag1 = mag1.sqrt();
419 let mag2 = mag2.sqrt();
420 if mag1 == 0.0 || mag2 == 0.0 {
421 return Ok(0.0);
422 }
423 Ok((dot / (mag1 * mag2)).clamp(-1.0, 1.0))
424}
425
426fn value_to_vec<T>(v: &Value, cast: impl Fn(f64) -> T) -> Result<Vec<T>, SimilarToError> {
433 match v {
434 Value::Vector(vec) => Ok(vec.iter().map(|&x| cast(x as f64)).collect()),
435 Value::List(list) => list
436 .iter()
437 .map(|v| {
438 v.as_f64()
439 .map(&cast)
440 .ok_or_else(|| SimilarToError::InvalidOption {
441 message: "vector element must be a number".to_string(),
442 })
443 })
444 .collect(),
445 _ => Err(SimilarToError::InvalidVectorValue {
446 actual: format!("{v:?}"),
447 }),
448 }
449}
450
451fn value_to_f64_vec(v: &Value) -> Result<Vec<f64>, SimilarToError> {
453 value_to_vec(v, |f| f)
454}
455
456pub fn value_to_f32_vec(v: &Value) -> Result<Vec<f32>, SimilarToError> {
458 value_to_vec(v, |f| f as f32)
459}
460
461pub fn validate_options(opts: &SimilarToOptions, num_sources: usize) -> Result<(), SimilarToError> {
463 if let Some(ref weights) = opts.weights {
464 if weights.len() != num_sources {
465 return Err(SimilarToError::WeightsLengthMismatch {
466 weights_len: weights.len(),
467 sources_len: num_sources,
468 });
469 }
470 let sum: f32 = weights.iter().sum();
471 if (sum - 1.0).abs() > 0.01 {
472 return Err(SimilarToError::WeightsNotNormalized { sum });
473 }
474 }
475 Ok(())
476}
477
478pub fn validate_pair(
483 source_type: &SourceType,
484 query_is_vector: bool,
485 query_is_string: bool,
486 source_index: usize,
487) -> Result<(), SimilarToError> {
488 match source_type {
489 SourceType::Fts if query_is_vector => Err(SimilarToError::TypeMismatch { source_index }),
490 SourceType::Vector {
491 has_embedding_config: false,
492 ..
493 } if query_is_string => Err(SimilarToError::NoEmbeddingConfig { source_index }),
494 _ => Ok(()),
495 }
496}
497
498pub fn fuse_scores(scores: &[f32], opts: &SimilarToOptions) -> Result<f32, SimilarToError> {
500 if scores.len() == 1 {
501 return Ok(scores[0]);
502 }
503
504 match opts.method {
505 FusionMethod::Weighted => {
506 let weights = opts
507 .weights
508 .as_ref()
509 .ok_or(SimilarToError::WeightsRequired)?;
510 Ok(fusion::fuse_weighted_multi(scores, weights))
511 }
512 FusionMethod::Rrf => {
513 let (score, _) = fusion::fuse_rrf_point(scores);
517 Ok(score)
518 }
519 }
520}
521
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    #[test]
    fn test_parse_options_default() {
        let opts = parse_options(&Value::Null).unwrap();
        assert_eq!(opts.method, FusionMethod::Rrf);
        assert_eq!(opts.k, 60);
        assert!((opts.fts_k - 1.0).abs() < 1e-6);
        assert!(opts.weights.is_none());
    }

    #[test]
    fn test_maxsim_hand_computed() {
        let query = vec![vec![1.0_f32, 0.0], vec![0.0_f32, 1.0]];
        let doc = vec![vec![1.0_f32, 0.0], vec![0.5_f32, 0.5]];
        let score = maxsim(&query, &doc, &DistanceMetric::Dot).unwrap();
        assert!((score - 1.5).abs() < 1e-6, "got {score}");
    }

    #[test]
    fn test_maxsim_edge_cases() {
        let metric = DistanceMetric::Cosine;
        assert_eq!(maxsim(&[], &[vec![1.0_f32, 0.0]], &metric).unwrap(), 0.0);
        let empty_doc: Vec<Vec<f32>> = vec![];
        assert_eq!(
            maxsim(&[vec![1.0_f32, 0.0]], &empty_doc, &metric).unwrap(),
            0.0
        );
        let err = maxsim(&[vec![1.0_f32, 0.0]], &[vec![1.0_f32, 0.0, 0.0]], &metric);
        assert!(matches!(err, Err(SimilarToError::DimensionMismatch { .. })));
    }

    #[test]
    fn test_maxsim_metric_changes_score() {
        let q = vec![vec![2.0_f32, 0.0]];
        let d = vec![vec![3.0_f32, 0.0]];
        let dot = maxsim(&q, &d, &DistanceMetric::Dot).unwrap();
        let cos = maxsim(&q, &d, &DistanceMetric::Cosine).unwrap();
        assert!((dot - 6.0).abs() < 1e-6, "dot got {dot}");
        assert!((cos - 1.0).abs() < 1e-6, "cosine got {cos}");
    }

    #[test]
    fn test_parse_options_weighted() {
        let mut map = HashMap::new();
        map.insert("method".to_string(), Value::String("weighted".to_string()));
        map.insert(
            "weights".to_string(),
            Value::List(vec![Value::Float(0.7), Value::Float(0.3)]),
        );
        let opts = parse_options(&Value::Map(map)).unwrap();
        assert_eq!(opts.method, FusionMethod::Weighted);
        let weights = opts.weights.unwrap();
        assert!((weights[0] - 0.7).abs() < 1e-6);
        assert!((weights[1] - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_parse_options_rrf_with_k() {
        let mut map = HashMap::new();
        map.insert("method".to_string(), Value::String("rrf".to_string()));
        map.insert("k".to_string(), Value::Int(30));
        let opts = parse_options(&Value::Map(map)).unwrap();
        assert_eq!(opts.method, FusionMethod::Rrf);
        assert_eq!(opts.k, 30);
    }

    #[test]
    fn test_parse_options_fts_k() {
        let mut map = HashMap::new();
        map.insert("fts_k".to_string(), Value::Float(2.0));
        let opts = parse_options(&Value::Map(map)).unwrap();
        assert!((opts.fts_k - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_parse_options_invalid_method() {
        let mut map = HashMap::new();
        map.insert("method".to_string(), Value::String("invalid".to_string()));
        assert!(parse_options(&Value::Map(map)).is_err());
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(&v, &v).unwrap();
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b).unwrap();
        assert!((sim - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = cosine_similarity(&a, &b).unwrap();
        assert!((sim - (-1.0)).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!(cosine_similarity(&a, &b).is_err());
    }

    #[test]
    fn test_normalize_bm25() {
        assert!((normalize_bm25(0.0, 1.0) - 0.0).abs() < 1e-6);
        assert!((normalize_bm25(1.0, 1.0) - 0.5).abs() < 1e-6);
        assert!((normalize_bm25(9.0, 1.0) - 0.9).abs() < 1e-6);
        assert!((normalize_bm25(99.0, 1.0) - 0.99).abs() < 1e-4);
    }

    #[test]
    fn test_normalize_bm25_custom_k() {
        assert!((normalize_bm25(2.0, 2.0) - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_eval_similar_to_pure() {
        let v1 = Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        let v2 = Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        let result = eval_similar_to_pure(&v1, &v2).unwrap();
        match result {
            Value::Float(f) => assert!((f - 1.0).abs() < 1e-6),
            _ => panic!("Expected Float"),
        }
    }

    #[test]
    fn test_eval_similar_to_pure_vector_type() {
        let v1 = Value::Vector(vec![1.0, 0.0]);
        let v2 = Value::Vector(vec![0.0, 1.0]);
        let result = eval_similar_to_pure(&v1, &v2).unwrap();
        match result {
            Value::Float(f) => assert!((f - 0.0).abs() < 1e-6),
            _ => panic!("Expected Float"),
        }
    }

    #[test]
    fn test_validate_options_weights_length() {
        let opts = SimilarToOptions {
            weights: Some(vec![0.5]),
            ..Default::default()
        };
        assert!(validate_options(&opts, 2).is_err());
    }

    #[test]
    fn test_validate_options_weights_sum() {
        let opts = SimilarToOptions {
            weights: Some(vec![0.5, 0.3]),
            ..Default::default()
        };
        assert!(validate_options(&opts, 2).is_err());
    }

    #[test]
    fn test_validate_options_ok() {
        let opts = SimilarToOptions {
            weights: Some(vec![0.7, 0.3]),
            ..Default::default()
        };
        assert!(validate_options(&opts, 2).is_ok());
    }

    #[test]
    fn test_validate_pair_fts_vector_query() {
        assert!(validate_pair(&SourceType::Fts, true, false, 0).is_err());
    }

    #[test]
    fn test_validate_pair_vector_string_no_embed() {
        let st = SourceType::Vector {
            metric: DistanceMetric::Cosine,
            has_embedding_config: false,
        };
        assert!(validate_pair(&st, false, true, 0).is_err());
    }

    #[test]
    fn test_validate_pair_vector_string_with_embed() {
        let st = SourceType::Vector {
            metric: DistanceMetric::Cosine,
            has_embedding_config: true,
        };
        assert!(validate_pair(&st, false, true, 0).is_ok());
    }

    #[test]
    fn test_validate_pair_vector_vector() {
        let st = SourceType::Vector {
            metric: DistanceMetric::Cosine,
            has_embedding_config: false,
        };
        assert!(validate_pair(&st, true, false, 0).is_ok());
    }

    #[test]
    fn test_validate_pair_fts_string() {
        assert!(validate_pair(&SourceType::Fts, false, true, 0).is_ok());
    }

    #[test]
    fn test_fuse_scores_single() {
        let opts = SimilarToOptions::default();
        let score = fuse_scores(&[0.8], &opts).unwrap();
        assert!((score - 0.8).abs() < 1e-6);
    }

    #[test]
    fn test_fuse_scores_weighted() {
        let opts = SimilarToOptions {
            method: FusionMethod::Weighted,
            weights: Some(vec![0.7, 0.3]),
            ..Default::default()
        };
        let score = fuse_scores(&[0.8, 0.6], &opts).unwrap();
        assert!((score - 0.74).abs() < 1e-6);
    }

    #[test]
    fn test_fuse_scores_rrf_fallback() {
        let opts = SimilarToOptions::default();
        let score = fuse_scores(&[0.8, 0.6], &opts).unwrap();
        assert!((score - 0.7).abs() < 1e-6);
    }

    #[test]
    fn test_score_vectors_cosine_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::Cosine).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_score_vectors_cosine_matches_raw() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.8, 0.6, 0.0];
        let raw = cosine_similarity(&a, &b).unwrap();
        let scored = score_vectors(&a, &b, &DistanceMetric::Cosine).unwrap();
        assert!((raw - scored).abs() < 1e-6);
    }

    #[test]
    fn test_score_vectors_l2() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let score = score_vectors(&a, &b, &DistanceMetric::L2).unwrap();
        assert!((score - 1.0 / 3.0).abs() < 1e-5);
    }

    #[test]
    fn test_score_vectors_l2_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::L2).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_score_vectors_dot() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.8, 0.6, 0.0];
        let score = score_vectors(&a, &b, &DistanceMetric::Dot).unwrap();
        assert!((score - 0.8).abs() < 1e-6);
    }

    #[test]
    fn test_score_vectors_dot_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::Dot).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_score_vectors_dimension_mismatch() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!(score_vectors(&a, &b, &DistanceMetric::Cosine).is_err());
    }

    // --- Additional coverage for existing behaviour ---------------------------

    #[test]
    fn test_parse_options_rejects_non_map() {
        let err = parse_options(&Value::Int(5)).unwrap_err();
        assert!(matches!(err, SimilarToError::InvalidOption { .. }));
    }

    #[test]
    fn test_parse_options_invalid_method_reports_name() {
        let mut map = HashMap::new();
        map.insert("method".to_string(), Value::String("invalid".to_string()));
        match parse_options(&Value::Map(map)) {
            Err(SimilarToError::InvalidMethod { method }) => assert_eq!(method, "invalid"),
            other => panic!("expected InvalidMethod, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_options_non_string_method() {
        let mut map = HashMap::new();
        map.insert("method".to_string(), Value::Int(1));
        assert!(matches!(
            parse_options(&Value::Map(map)),
            Err(SimilarToError::InvalidOption { .. })
        ));
    }

    #[test]
    fn test_parse_options_weights_must_be_list_of_numbers() {
        let mut not_list = HashMap::new();
        not_list.insert("weights".to_string(), Value::Float(0.5));
        assert!(matches!(
            parse_options(&Value::Map(not_list)),
            Err(SimilarToError::InvalidOption { .. })
        ));

        let mut bad_element = HashMap::new();
        bad_element.insert("weights".to_string(), Value::List(vec![Value::Null]));
        assert!(matches!(
            parse_options(&Value::Map(bad_element)),
            Err(SimilarToError::InvalidOption { .. })
        ));
    }

    #[test]
    fn test_parse_options_non_numeric_k_and_fts_k() {
        let mut k = HashMap::new();
        k.insert("k".to_string(), Value::String("ten".to_string()));
        assert!(matches!(
            parse_options(&Value::Map(k)),
            Err(SimilarToError::InvalidOption { .. })
        ));

        let mut fts_k = HashMap::new();
        fts_k.insert("fts_k".to_string(), Value::String("one".to_string()));
        assert!(matches!(
            parse_options(&Value::Map(fts_k)),
            Err(SimilarToError::InvalidOption { .. })
        ));
    }

    #[test]
    fn test_validate_options_without_weights() {
        let opts = SimilarToOptions::default();
        assert!(validate_options(&opts, 3).is_ok());
    }

    #[test]
    fn test_fuse_scores_weighted_requires_weights() {
        let opts = SimilarToOptions {
            method: FusionMethod::Weighted,
            ..Default::default()
        };
        assert!(matches!(
            fuse_scores(&[0.8, 0.6], &opts),
            Err(SimilarToError::WeightsRequired)
        ));
    }

    #[test]
    fn test_normalize_bm25_negative_score() {
        assert_eq!(normalize_bm25(-3.0, 1.0), 0.0);
    }

    #[test]
    fn test_calculate_score_per_metric() {
        assert!((calculate_score(0.0, &DistanceMetric::Cosine) - 1.0).abs() < 1e-6);
        assert!((calculate_score(2.0, &DistanceMetric::Cosine) - 0.0).abs() < 1e-6);
        assert!((calculate_score(0.8, &DistanceMetric::Dot) - 0.8).abs() < 1e-6);
        assert!((calculate_score(1.0, &DistanceMetric::L2) - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_eval_similar_to_pure_null() {
        let v = Value::List(vec![Value::Float(1.0)]);
        assert!(matches!(
            eval_similar_to_pure(&Value::Null, &v).unwrap(),
            Value::Null
        ));
        assert!(matches!(
            eval_similar_to_pure(&v, &Value::Null).unwrap(),
            Value::Null
        ));
    }

    #[test]
    fn test_eval_similar_to_pure_mixed_list_and_vector() {
        let v1 = Value::Vector(vec![1.0, 0.0]);
        let v2 = Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        match eval_similar_to_pure(&v1, &v2).unwrap() {
            Value::Float(f) => assert!((f - 1.0).abs() < 1e-9),
            other => panic!("Expected Float, got {other:?}"),
        }
    }

    #[test]
    fn test_eval_similar_to_pure_errors() {
        let short = Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        let long = Value::List(vec![
            Value::Float(1.0),
            Value::Float(0.0),
            Value::Float(0.0),
        ]);
        assert!(eval_similar_to_pure(&short, &long).is_err());
        let text = Value::String("not a vector".to_string());
        assert!(eval_similar_to_pure(&short, &text).is_err());
    }

    #[test]
    fn test_eval_sparse_similar_to_pure() {
        assert!(matches!(
            eval_sparse_similar_to_pure(&Value::Null, &Value::Null).unwrap(),
            Value::Null
        ));

        let mut m = HashMap::new();
        m.insert(
            "indices".to_string(),
            Value::List(vec![Value::Int(1), Value::Int(3)]),
        );
        m.insert(
            "values".to_string(),
            Value::List(vec![Value::Float(2.0), Value::Float(0.5)]),
        );
        let sparse = Value::Map(m);
        match eval_sparse_similar_to_pure(&sparse, &sparse).unwrap() {
            Value::Float(f) => assert!((f - 4.25).abs() < 1e-6, "got {f}"),
            other => panic!("Expected Float, got {other:?}"),
        }

        assert!(eval_sparse_similar_to_pure(&sparse, &Value::Float(1.0)).is_err());
    }
}