1use anyhow::Result;
15use uni_common::Value;
16use uni_common::core::schema::{DistanceMetric, Schema};
17
18use crate::fusion;
19
20#[derive(Debug, thiserror::Error)]
22pub enum SimilarToError {
23 #[error("similar_to: property '{label}.{property}' has no vector or full-text index")]
24 NoIndex { label: String, property: String },
25
26 #[error(
27 "similar_to: source {source_index} is FTS-indexed but query is a vector (FTS cannot score against vectors)"
28 )]
29 TypeMismatch { source_index: usize },
30
31 #[error(
32 "similar_to: source {source_index} is a vector property but query is a string, and the index has no embedding config for auto-embedding"
33 )]
34 NoEmbeddingConfig { source_index: usize },
35
36 #[error("similar_to: weights length ({weights_len}) != sources length ({sources_len})")]
37 WeightsLengthMismatch {
38 weights_len: usize,
39 sources_len: usize,
40 },
41
42 #[error("similar_to: weights must sum to 1.0 (got {sum})")]
43 WeightsNotNormalized { sum: f32 },
44
45 #[error("similar_to: unknown method '{method}', expected 'rrf' or 'weighted'")]
46 InvalidMethod { method: String },
47
48 #[error("similar_to: {message}")]
49 InvalidOption { message: String },
50
51 #[error("similar_to: vector dimensions mismatch: {a} vs {b}")]
52 DimensionMismatch { a: usize, b: usize },
53
54 #[error("similar_to: expected vector or list of numbers, got {actual}")]
55 InvalidVectorValue { actual: String },
56
57 #[error("similar_to: weighted fusion requires 'weights' option")]
58 WeightsRequired,
59
60 #[error("similar_to takes 2 or 3 arguments (sources, queries [, options]), got {count}")]
61 InvalidArity { count: usize },
62
63 #[error("similar_to requires GraphExecutionContext")]
64 NoGraphContext,
65}
66
/// Strategy used to combine the per-source scores of a multi-source
/// `similar_to` call into a single score.
///
/// The enum is fieldless, so it is `Copy` and has a total equality (`Eq`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum FusionMethod {
    /// Reciprocal-rank-style fusion; this is the default method.
    #[default]
    Rrf,
    /// Weighted combination of the scores; needs a `weights` option.
    Weighted,
}
77
78#[derive(Debug, Clone)]
80pub struct SimilarToOptions {
81 pub method: FusionMethod,
83 pub weights: Option<Vec<f32>>,
85 pub k: usize,
87 pub fts_k: f32,
89}
90
91impl Default for SimilarToOptions {
92 fn default() -> Self {
93 Self {
94 method: FusionMethod::Rrf,
95 weights: None,
96 k: 60,
97 fts_k: 1.0,
98 }
99 }
100}
101
102pub fn parse_options(value: &Value) -> Result<SimilarToOptions, SimilarToError> {
104 let map = match value {
105 Value::Map(m) => m,
106 Value::Null => return Ok(SimilarToOptions::default()),
107 _ => {
108 return Err(SimilarToError::InvalidOption {
109 message: format!("options must be a map, got {:?}", value),
110 });
111 }
112 };
113
114 let mut opts = SimilarToOptions::default();
115
116 if let Some(method_val) = map.get("method") {
117 match method_val.as_str() {
118 Some("rrf") => opts.method = FusionMethod::Rrf,
119 Some("weighted") => opts.method = FusionMethod::Weighted,
120 Some(other) => {
121 return Err(SimilarToError::InvalidMethod {
122 method: other.to_string(),
123 });
124 }
125 None => {
126 return Err(SimilarToError::InvalidOption {
127 message: "'method' must be a string ('rrf' or 'weighted')".to_string(),
128 });
129 }
130 }
131 }
132
133 if let Some(weights_val) = map.get("weights") {
134 match weights_val {
135 Value::List(list) => {
136 let weights: Result<Vec<f32>, SimilarToError> = list
137 .iter()
138 .map(|v| {
139 v.as_f64()
140 .map(|f| f as f32)
141 .ok_or_else(|| SimilarToError::InvalidOption {
142 message: "weight must be a number".to_string(),
143 })
144 })
145 .collect();
146 opts.weights = Some(weights?);
147 }
148 _ => {
149 return Err(SimilarToError::InvalidOption {
150 message: "'weights' must be a list of numbers".to_string(),
151 });
152 }
153 }
154 }
155
156 if let Some(k_val) = map.get("k") {
157 opts.k = k_val
158 .as_i64()
159 .ok_or_else(|| SimilarToError::InvalidOption {
160 message: "'k' must be an integer".to_string(),
161 })? as usize;
162 }
163
164 if let Some(fts_k_val) = map.get("fts_k") {
165 opts.fts_k = fts_k_val
166 .as_f64()
167 .ok_or_else(|| SimilarToError::InvalidOption {
168 message: "'fts_k' must be a number".to_string(),
169 })? as f32;
170 }
171
172 Ok(opts)
173}
174
175#[derive(Debug, Clone)]
177pub enum SourceType {
178 Vector {
180 metric: DistanceMetric,
181 has_embedding_config: bool,
182 },
183 Fts,
185}
186
187pub fn resolve_source_type(
189 schema: &Schema,
190 label: &str,
191 property: &str,
192) -> Result<SourceType, SimilarToError> {
193 if let Some(vec_config) = schema.vector_index_for_property(label, property) {
195 return Ok(SourceType::Vector {
196 metric: vec_config.metric.clone(),
197 has_embedding_config: vec_config.embedding_config.is_some(),
198 });
199 }
200
201 if schema
203 .fulltext_index_for_property(label, property)
204 .is_some()
205 {
206 return Ok(SourceType::Fts);
207 }
208
209 Err(SimilarToError::NoIndex {
210 label: label.to_string(),
211 property: property.to_string(),
212 })
213}
214
215pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32, SimilarToError> {
220 cosine_similarity_inner(a, b).map(|sim| sim as f32)
221}
222
223pub fn score_vectors(a: &[f32], b: &[f32], metric: &DistanceMetric) -> Result<f32, SimilarToError> {
230 if a.len() != b.len() {
231 return Err(SimilarToError::DimensionMismatch {
232 a: a.len(),
233 b: b.len(),
234 });
235 }
236 let distance = metric.compute_distance(a, b);
237 match metric {
238 DistanceMetric::Cosine => cosine_similarity(a, b),
239 DistanceMetric::Dot => Ok(-distance),
242 _ => Ok(calculate_score(distance, metric)),
244 }
245}
246
247pub fn maxsim(
262 query: &[Vec<f32>],
263 doc: &[Vec<f32>],
264 metric: &DistanceMetric,
265) -> Result<f32, SimilarToError> {
266 let mut total = 0.0_f32;
267 for q in query {
268 let mut best: Option<f32> = None;
269 for d in doc {
270 let sim = score_vectors(q, d, metric)?;
271 best = Some(best.map_or(sim, |b| b.max(sim)));
272 }
273 total += best.unwrap_or(0.0);
275 }
276 Ok(total)
277}
278
279pub fn calculate_score(distance: f32, metric: &DistanceMetric) -> f32 {
286 match metric {
287 DistanceMetric::Cosine => (2.0 - distance) / 2.0,
288 DistanceMetric::Dot => distance,
289 _ => 1.0 / (1.0 + distance),
290 }
291}
292
/// Squashes a BM25 score with `score / (score + fts_k)`.
///
/// Scores that are zero or negative map to `0.0`.
pub fn normalize_bm25(score: f32, fts_k: f32) -> f32 {
    let non_positive = score <= 0.0;
    if non_positive {
        0.0
    } else {
        score / (score + fts_k)
    }
}
302
303pub fn eval_similar_to_pure(v1: &Value, v2: &Value) -> Result<Value> {
311 if matches!(v1, Value::Null) || matches!(v2, Value::Null) {
314 return Ok(Value::Null);
315 }
316 let has_list = matches!(v1, Value::List(_)) || matches!(v2, Value::List(_));
318 let f64_vecs = has_list
319 .then(|| value_to_f64_vec(v1).ok().zip(value_to_f64_vec(v2).ok()))
320 .flatten();
321 if let Some((vec1, vec2)) = f64_vecs {
322 let sim = cosine_similarity_inner(&vec1, &vec2)?;
323 return Ok(Value::Float(sim));
324 }
325 let vec1 = value_to_f32_vec(v1)?;
327 let vec2 = value_to_f32_vec(v2)?;
328 let sim = cosine_similarity(&vec1, &vec2)?;
329 Ok(Value::Float(sim as f64))
330}
331
332fn cosine_similarity_inner<T: Copy + Into<f64>>(a: &[T], b: &[T]) -> Result<f64, SimilarToError> {
338 if a.len() != b.len() {
339 return Err(SimilarToError::DimensionMismatch {
340 a: a.len(),
341 b: b.len(),
342 });
343 }
344 let mut dot = 0.0f64;
345 let mut mag1 = 0.0f64;
346 let mut mag2 = 0.0f64;
347 for (&x, &y) in a.iter().zip(b.iter()) {
348 let (x, y): (f64, f64) = (x.into(), y.into());
349 dot += x * y;
350 mag1 += x * x;
351 mag2 += y * y;
352 }
353 let mag1 = mag1.sqrt();
354 let mag2 = mag2.sqrt();
355 if mag1 == 0.0 || mag2 == 0.0 {
356 return Ok(0.0);
357 }
358 Ok((dot / (mag1 * mag2)).clamp(-1.0, 1.0))
359}
360
361fn value_to_vec<T>(v: &Value, cast: impl Fn(f64) -> T) -> Result<Vec<T>, SimilarToError> {
368 match v {
369 Value::Vector(vec) => Ok(vec.iter().map(|&x| cast(x as f64)).collect()),
370 Value::List(list) => list
371 .iter()
372 .map(|v| {
373 v.as_f64()
374 .map(&cast)
375 .ok_or_else(|| SimilarToError::InvalidOption {
376 message: "vector element must be a number".to_string(),
377 })
378 })
379 .collect(),
380 _ => Err(SimilarToError::InvalidVectorValue {
381 actual: format!("{v:?}"),
382 }),
383 }
384}
385
386fn value_to_f64_vec(v: &Value) -> Result<Vec<f64>, SimilarToError> {
388 value_to_vec(v, |f| f)
389}
390
391pub fn value_to_f32_vec(v: &Value) -> Result<Vec<f32>, SimilarToError> {
393 value_to_vec(v, |f| f as f32)
394}
395
396pub fn validate_options(opts: &SimilarToOptions, num_sources: usize) -> Result<(), SimilarToError> {
398 if let Some(ref weights) = opts.weights {
399 if weights.len() != num_sources {
400 return Err(SimilarToError::WeightsLengthMismatch {
401 weights_len: weights.len(),
402 sources_len: num_sources,
403 });
404 }
405 let sum: f32 = weights.iter().sum();
406 if (sum - 1.0).abs() > 0.01 {
407 return Err(SimilarToError::WeightsNotNormalized { sum });
408 }
409 }
410 Ok(())
411}
412
413pub fn validate_pair(
418 source_type: &SourceType,
419 query_is_vector: bool,
420 query_is_string: bool,
421 source_index: usize,
422) -> Result<(), SimilarToError> {
423 match source_type {
424 SourceType::Fts if query_is_vector => Err(SimilarToError::TypeMismatch { source_index }),
425 SourceType::Vector {
426 has_embedding_config: false,
427 ..
428 } if query_is_string => Err(SimilarToError::NoEmbeddingConfig { source_index }),
429 _ => Ok(()),
430 }
431}
432
433pub fn fuse_scores(scores: &[f32], opts: &SimilarToOptions) -> Result<f32, SimilarToError> {
435 if scores.len() == 1 {
436 return Ok(scores[0]);
437 }
438
439 match opts.method {
440 FusionMethod::Weighted => {
441 let weights = opts
442 .weights
443 .as_ref()
444 .ok_or(SimilarToError::WeightsRequired)?;
445 Ok(fusion::fuse_weighted_multi(scores, weights))
446 }
447 FusionMethod::Rrf => {
448 let (score, _) = fusion::fuse_rrf_point(scores);
452 Ok(score)
453 }
454 }
455}
456
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Returns `true` when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f32, expected: f32, tol: f32) -> bool {
        (actual - expected).abs() < tol
    }

    /// Builds a `Value::Map` from string keys and values.
    fn options_map(entries: Vec<(&str, Value)>) -> Value {
        Value::Map(
            entries
                .into_iter()
                .map(|(key, val)| (key.to_string(), val))
                .collect::<HashMap<_, _>>(),
        )
    }

    /// Builds a cosine-metric vector source type.
    fn cosine_source(has_embedding_config: bool) -> SourceType {
        SourceType::Vector {
            metric: DistanceMetric::Cosine,
            has_embedding_config,
        }
    }

    /// Builds options carrying only the given weights.
    fn options_with_weights(weights: Vec<f32>) -> SimilarToOptions {
        SimilarToOptions {
            weights: Some(weights),
            ..Default::default()
        }
    }

    /// Extracts the payload of a `Value::Float`, panicking on anything else.
    fn expect_float(value: Value) -> f64 {
        match value {
            Value::Float(f) => f,
            _ => panic!("Expected Float"),
        }
    }

    #[test]
    fn test_parse_options_default() {
        let parsed = parse_options(&Value::Null).unwrap();
        assert_eq!(parsed.method, FusionMethod::Rrf);
        assert_eq!(parsed.k, 60);
        assert!(close(parsed.fts_k, 1.0, 1e-6));
        assert!(parsed.weights.is_none());
    }

    #[test]
    fn test_maxsim_hand_computed() {
        let query = vec![vec![1.0_f32, 0.0], vec![0.0_f32, 1.0]];
        let doc = vec![vec![1.0_f32, 0.0], vec![0.5_f32, 0.5]];
        let score = maxsim(&query, &doc, &DistanceMetric::Dot).unwrap();
        assert!(close(score, 1.5, 1e-6), "got {score}");
    }

    #[test]
    fn test_maxsim_edge_cases() {
        let metric = DistanceMetric::Cosine;

        // Empty query.
        assert_eq!(maxsim(&[], &[vec![1.0_f32, 0.0]], &metric).unwrap(), 0.0);

        // Empty document.
        let no_doc_vectors: Vec<Vec<f32>> = Vec::new();
        assert_eq!(
            maxsim(&[vec![1.0_f32, 0.0]], &no_doc_vectors, &metric).unwrap(),
            0.0
        );

        // Mismatched dimensions.
        let mismatched = maxsim(&[vec![1.0_f32, 0.0]], &[vec![1.0_f32, 0.0, 0.0]], &metric);
        assert!(matches!(
            mismatched,
            Err(SimilarToError::DimensionMismatch { .. })
        ));
    }

    #[test]
    fn test_maxsim_metric_changes_score() {
        let q = vec![vec![2.0_f32, 0.0]];
        let d = vec![vec![3.0_f32, 0.0]];
        let dot = maxsim(&q, &d, &DistanceMetric::Dot).unwrap();
        let cos = maxsim(&q, &d, &DistanceMetric::Cosine).unwrap();
        assert!(close(dot, 6.0, 1e-6), "dot got {dot}");
        assert!(close(cos, 1.0, 1e-6), "cosine got {cos}");
    }

    #[test]
    fn test_parse_options_weighted() {
        let raw = options_map(vec![
            ("method", Value::String("weighted".to_string())),
            (
                "weights",
                Value::List(vec![Value::Float(0.7), Value::Float(0.3)]),
            ),
        ]);
        let parsed = parse_options(&raw).unwrap();
        assert_eq!(parsed.method, FusionMethod::Weighted);
        let weights = parsed.weights.unwrap();
        assert!(close(weights[0], 0.7, 1e-6));
        assert!(close(weights[1], 0.3, 1e-6));
    }

    #[test]
    fn test_parse_options_rrf_with_k() {
        let raw = options_map(vec![
            ("method", Value::String("rrf".to_string())),
            ("k", Value::Int(30)),
        ]);
        let parsed = parse_options(&raw).unwrap();
        assert_eq!(parsed.method, FusionMethod::Rrf);
        assert_eq!(parsed.k, 30);
    }

    #[test]
    fn test_parse_options_fts_k() {
        let raw = options_map(vec![("fts_k", Value::Float(2.0))]);
        let parsed = parse_options(&raw).unwrap();
        assert!(close(parsed.fts_k, 2.0, 1e-6));
    }

    #[test]
    fn test_parse_options_invalid_method() {
        let raw = options_map(vec![("method", Value::String("invalid".to_string()))]);
        assert!(parse_options(&raw).is_err());
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let v = [1.0_f32, 0.0, 0.0];
        let sim = cosine_similarity(&v, &v).unwrap();
        assert!(close(sim, 1.0, 1e-6));
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let sim = cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]).unwrap();
        assert!(close(sim, 0.0, 1e-6));
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let sim = cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]).unwrap();
        assert!(close(sim, -1.0, 1e-6));
    }

    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        let outcome = cosine_similarity(&[1.0, 0.0], &[1.0, 0.0, 0.0]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_normalize_bm25() {
        assert!(close(normalize_bm25(0.0, 1.0), 0.0, 1e-6));
        assert!(close(normalize_bm25(1.0, 1.0), 0.5, 1e-6));
        assert!(close(normalize_bm25(9.0, 1.0), 0.9, 1e-6));
        assert!(close(normalize_bm25(99.0, 1.0), 0.99, 1e-4));
    }

    #[test]
    fn test_normalize_bm25_custom_k() {
        assert!(close(normalize_bm25(2.0, 2.0), 0.5, 1e-6));
    }

    #[test]
    fn test_eval_similar_to_pure() {
        let unit_x = || Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        let sim = expect_float(eval_similar_to_pure(&unit_x(), &unit_x()).unwrap());
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_eval_similar_to_pure_vector_type() {
        let v1 = Value::Vector(vec![1.0, 0.0]);
        let v2 = Value::Vector(vec![0.0, 1.0]);
        let sim = expect_float(eval_similar_to_pure(&v1, &v2).unwrap());
        assert!((sim - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_validate_options_weights_length() {
        let opts = options_with_weights(vec![0.5]);
        assert!(validate_options(&opts, 2).is_err());
    }

    #[test]
    fn test_validate_options_weights_sum() {
        let opts = options_with_weights(vec![0.5, 0.3]);
        assert!(validate_options(&opts, 2).is_err());
    }

    #[test]
    fn test_validate_options_ok() {
        let opts = options_with_weights(vec![0.7, 0.3]);
        assert!(validate_options(&opts, 2).is_ok());
    }

    #[test]
    fn test_validate_pair_fts_vector_query() {
        assert!(validate_pair(&SourceType::Fts, true, false, 0).is_err());
    }

    #[test]
    fn test_validate_pair_vector_string_no_embed() {
        assert!(validate_pair(&cosine_source(false), false, true, 0).is_err());
    }

    #[test]
    fn test_validate_pair_vector_string_with_embed() {
        assert!(validate_pair(&cosine_source(true), false, true, 0).is_ok());
    }

    #[test]
    fn test_validate_pair_vector_vector() {
        assert!(validate_pair(&cosine_source(false), true, false, 0).is_ok());
    }

    #[test]
    fn test_validate_pair_fts_string() {
        assert!(validate_pair(&SourceType::Fts, false, true, 0).is_ok());
    }

    #[test]
    fn test_fuse_scores_single() {
        let fused = fuse_scores(&[0.8], &SimilarToOptions::default()).unwrap();
        assert!(close(fused, 0.8, 1e-6));
    }

    #[test]
    fn test_fuse_scores_weighted() {
        let opts = SimilarToOptions {
            method: FusionMethod::Weighted,
            weights: Some(vec![0.7, 0.3]),
            ..Default::default()
        };
        let fused = fuse_scores(&[0.8, 0.6], &opts).unwrap();
        assert!(close(fused, 0.74, 1e-6));
    }

    #[test]
    fn test_fuse_scores_rrf_fallback() {
        let fused = fuse_scores(&[0.8, 0.6], &SimilarToOptions::default()).unwrap();
        assert!(close(fused, 0.7, 1e-6));
    }

    #[test]
    fn test_score_vectors_cosine_identical() {
        let v = [1.0_f32, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::Cosine).unwrap();
        assert!(close(score, 1.0, 1e-6));
    }

    #[test]
    fn test_score_vectors_cosine_matches_raw() {
        let a = [1.0_f32, 0.0, 0.0];
        let b = [0.8_f32, 0.6, 0.0];
        let raw = cosine_similarity(&a, &b).unwrap();
        let scored = score_vectors(&a, &b, &DistanceMetric::Cosine).unwrap();
        assert!(close(raw, scored, 1e-6));
    }

    #[test]
    fn test_score_vectors_l2() {
        let a = [1.0_f32, 0.0, 0.0];
        let b = [0.0_f32, 1.0, 0.0];
        let score = score_vectors(&a, &b, &DistanceMetric::L2).unwrap();
        assert!(close(score, 1.0 / 3.0, 1e-5));
    }

    #[test]
    fn test_score_vectors_l2_identical() {
        let v = [1.0_f32, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::L2).unwrap();
        assert!(close(score, 1.0, 1e-6));
    }

    #[test]
    fn test_score_vectors_dot() {
        let a = [1.0_f32, 0.0, 0.0];
        let b = [0.8_f32, 0.6, 0.0];
        let score = score_vectors(&a, &b, &DistanceMetric::Dot).unwrap();
        assert!(close(score, 0.8, 1e-6));
    }

    #[test]
    fn test_score_vectors_dot_identical() {
        let v = [1.0_f32, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::Dot).unwrap();
        assert!(close(score, 1.0, 1e-6));
    }

    #[test]
    fn test_score_vectors_dimension_mismatch() {
        let outcome = score_vectors(&[1.0, 0.0], &[1.0, 0.0, 0.0], &DistanceMetric::Cosine);
        assert!(outcome.is_err());
    }
}