1use anyhow::Result;
15use uni_common::Value;
16use uni_common::core::schema::{DistanceMetric, Schema};
17
18use crate::fusion;
19
/// Errors produced while parsing, validating, or evaluating a `similar_to` call.
#[derive(Debug, thiserror::Error)]
pub enum SimilarToError {
    /// The `label.property` source has neither a vector index nor a full-text index.
    #[error("similar_to: property '{label}.{property}' has no vector or full-text index")]
    NoIndex { label: String, property: String },

    /// A vector query was paired with a full-text-indexed source.
    #[error(
        "similar_to: source {source_index} is FTS-indexed but query is a vector (FTS cannot score against vectors)"
    )]
    TypeMismatch { source_index: usize },

    /// A string query was paired with a vector source whose index has no embedding config.
    #[error(
        "similar_to: source {source_index} is a vector property but query is a string, and the index has no embedding config for auto-embedding"
    )]
    NoEmbeddingConfig { source_index: usize },

    /// The number of supplied weights differs from the number of sources.
    #[error("similar_to: weights length ({weights_len}) != sources length ({sources_len})")]
    WeightsLengthMismatch {
        /// Number of entries in the `weights` option.
        weights_len: usize,
        /// Number of sources the weights were checked against.
        sources_len: usize,
    },

    /// The supplied weights do not sum to 1.0 (within the tolerance applied by `validate_options`).
    #[error("similar_to: weights must sum to 1.0 (got {sum})")]
    WeightsNotNormalized { sum: f32 },

    /// The `method` option was a string other than `rrf` or `weighted`.
    #[error("similar_to: unknown method '{method}', expected 'rrf' or 'weighted'")]
    InvalidMethod { method: String },

    /// An option (or a vector list element) had an unexpected type; `message` carries the details.
    #[error("similar_to: {message}")]
    InvalidOption { message: String },

    /// Two vectors being compared have different lengths.
    #[error("similar_to: vector dimensions mismatch: {a} vs {b}")]
    DimensionMismatch { a: usize, b: usize },

    /// A value that should have been a vector or a list was something else;
    /// `actual` is the value's `Debug` rendering.
    #[error("similar_to: expected vector or list of numbers, got {actual}")]
    InvalidVectorValue { actual: String },

    /// Weighted fusion was requested but no `weights` option was supplied.
    #[error("similar_to: weighted fusion requires 'weights' option")]
    WeightsRequired,

    /// `similar_to` was called with a number of arguments other than 2 or 3.
    #[error("similar_to takes 2 or 3 arguments (sources, queries [, options]), got {count}")]
    InvalidArity { count: usize },

    /// `similar_to` was evaluated without a `GraphExecutionContext`.
    #[error("similar_to requires GraphExecutionContext")]
    NoGraphContext,
}
66
/// Strategy, chosen through the `method` option, for combining per-source scores.
#[derive(Debug, Clone, Default, PartialEq)]
pub enum FusionMethod {
    /// Selected by `method: 'rrf'`; this is the default variant.
    /// NOTE(review): presumably reciprocal rank fusion — confirm against `crate::fusion`.
    #[default]
    Rrf,
    /// Selected by `method: 'weighted'`; `fuse_scores` requires the `weights` option for it.
    Weighted,
}
77
/// Options accepted as the optional third argument of `similar_to`.
#[derive(Debug, Clone)]
pub struct SimilarToOptions {
    /// Fusion strategy, set from the `method` option.
    pub method: FusionMethod,
    /// Per-source weights from the `weights` option, if one was supplied.
    pub weights: Option<Vec<f32>>,
    /// Value of the `k` option (60 by default).
    /// NOTE(review): presumably the RRF constant; `fuse_scores` does not consume it — confirm usage.
    pub k: usize,
    /// Value of the `fts_k` option (1.0 by default).
    /// NOTE(review): presumably forwarded as `fts_k` to `normalize_bm25` — confirm at the call site.
    pub fts_k: f32,
}
90
91impl Default for SimilarToOptions {
92 fn default() -> Self {
93 Self {
94 method: FusionMethod::Rrf,
95 weights: None,
96 k: 60,
97 fts_k: 1.0,
98 }
99 }
100}
101
102pub fn parse_options(value: &Value) -> Result<SimilarToOptions, SimilarToError> {
104 let map = match value {
105 Value::Map(m) => m,
106 Value::Null => return Ok(SimilarToOptions::default()),
107 _ => {
108 return Err(SimilarToError::InvalidOption {
109 message: format!("options must be a map, got {:?}", value),
110 });
111 }
112 };
113
114 let mut opts = SimilarToOptions::default();
115
116 if let Some(method_val) = map.get("method") {
117 match method_val.as_str() {
118 Some("rrf") => opts.method = FusionMethod::Rrf,
119 Some("weighted") => opts.method = FusionMethod::Weighted,
120 Some(other) => {
121 return Err(SimilarToError::InvalidMethod {
122 method: other.to_string(),
123 });
124 }
125 None => {
126 return Err(SimilarToError::InvalidOption {
127 message: "'method' must be a string ('rrf' or 'weighted')".to_string(),
128 });
129 }
130 }
131 }
132
133 if let Some(weights_val) = map.get("weights") {
134 match weights_val {
135 Value::List(list) => {
136 let weights: Result<Vec<f32>, SimilarToError> = list
137 .iter()
138 .map(|v| {
139 v.as_f64()
140 .map(|f| f as f32)
141 .ok_or_else(|| SimilarToError::InvalidOption {
142 message: "weight must be a number".to_string(),
143 })
144 })
145 .collect();
146 opts.weights = Some(weights?);
147 }
148 _ => {
149 return Err(SimilarToError::InvalidOption {
150 message: "'weights' must be a list of numbers".to_string(),
151 });
152 }
153 }
154 }
155
156 if let Some(k_val) = map.get("k") {
157 opts.k = k_val
158 .as_i64()
159 .ok_or_else(|| SimilarToError::InvalidOption {
160 message: "'k' must be an integer".to_string(),
161 })? as usize;
162 }
163
164 if let Some(fts_k_val) = map.get("fts_k") {
165 opts.fts_k = fts_k_val
166 .as_f64()
167 .ok_or_else(|| SimilarToError::InvalidOption {
168 message: "'fts_k' must be a number".to_string(),
169 })? as f32;
170 }
171
172 Ok(opts)
173}
174
/// Kind of index backing a `similar_to` source property, as resolved from the schema.
#[derive(Debug, Clone)]
pub enum SourceType {
    /// The property has a vector index.
    Vector {
        /// Distance metric cloned from the vector index configuration.
        metric: DistanceMetric,
        /// Whether the index configuration carries an embedding config;
        /// `validate_pair` rejects string queries when this is `false`.
        has_embedding_config: bool,
    },
    /// The property has a full-text index (and no vector index was found for it).
    Fts,
}
186
187pub fn resolve_source_type(
189 schema: &Schema,
190 label: &str,
191 property: &str,
192) -> Result<SourceType, SimilarToError> {
193 if let Some(vec_config) = schema.vector_index_for_property(label, property) {
195 return Ok(SourceType::Vector {
196 metric: vec_config.metric.clone(),
197 has_embedding_config: vec_config.embedding_config.is_some(),
198 });
199 }
200
201 if schema
203 .fulltext_index_for_property(label, property)
204 .is_some()
205 {
206 return Ok(SourceType::Fts);
207 }
208
209 Err(SimilarToError::NoIndex {
210 label: label.to_string(),
211 property: property.to_string(),
212 })
213}
214
215pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32, SimilarToError> {
220 cosine_similarity_inner(a, b).map(|sim| sim as f32)
221}
222
223pub fn score_vectors(a: &[f32], b: &[f32], metric: &DistanceMetric) -> Result<f32, SimilarToError> {
230 if a.len() != b.len() {
231 return Err(SimilarToError::DimensionMismatch {
232 a: a.len(),
233 b: b.len(),
234 });
235 }
236 let distance = metric.compute_distance(a, b);
237 match metric {
238 DistanceMetric::Cosine => cosine_similarity(a, b),
239 DistanceMetric::Dot => Ok(-distance),
242 _ => Ok(calculate_score(distance, metric)),
244 }
245}
246
247pub fn calculate_score(distance: f32, metric: &DistanceMetric) -> f32 {
254 match metric {
255 DistanceMetric::Cosine => (2.0 - distance) / 2.0,
256 DistanceMetric::Dot => distance,
257 _ => 1.0 / (1.0 + distance),
258 }
259}
260
/// Squashes a BM25 score with the saturation formula `score / (score + fts_k)`.
///
/// Scores that are zero or negative map to `0.0`. A NaN score fails the
/// `<= 0.0` test and therefore propagates through the division as NaN.
pub fn normalize_bm25(score: f32, fts_k: f32) -> f32 {
    let denominator = score + fts_k;
    if score <= 0.0 {
        0.0
    } else {
        score / denominator
    }
}
270
271pub fn eval_similar_to_pure(v1: &Value, v2: &Value) -> Result<Value> {
279 let has_list = matches!(v1, Value::List(_)) || matches!(v2, Value::List(_));
281 let f64_vecs = has_list
282 .then(|| value_to_f64_vec(v1).ok().zip(value_to_f64_vec(v2).ok()))
283 .flatten();
284 if let Some((vec1, vec2)) = f64_vecs {
285 let sim = cosine_similarity_inner(&vec1, &vec2)?;
286 return Ok(Value::Float(sim));
287 }
288 let vec1 = value_to_f32_vec(v1)?;
290 let vec2 = value_to_f32_vec(v2)?;
291 let sim = cosine_similarity(&vec1, &vec2)?;
292 Ok(Value::Float(sim as f64))
293}
294
295fn cosine_similarity_inner<T: Copy + Into<f64>>(a: &[T], b: &[T]) -> Result<f64, SimilarToError> {
301 if a.len() != b.len() {
302 return Err(SimilarToError::DimensionMismatch {
303 a: a.len(),
304 b: b.len(),
305 });
306 }
307 let mut dot = 0.0f64;
308 let mut mag1 = 0.0f64;
309 let mut mag2 = 0.0f64;
310 for (&x, &y) in a.iter().zip(b.iter()) {
311 let (x, y): (f64, f64) = (x.into(), y.into());
312 dot += x * y;
313 mag1 += x * x;
314 mag2 += y * y;
315 }
316 let mag1 = mag1.sqrt();
317 let mag2 = mag2.sqrt();
318 if mag1 == 0.0 || mag2 == 0.0 {
319 return Ok(0.0);
320 }
321 Ok((dot / (mag1 * mag2)).clamp(-1.0, 1.0))
322}
323
324fn value_to_vec<T>(v: &Value, cast: impl Fn(f64) -> T) -> Result<Vec<T>, SimilarToError> {
331 match v {
332 Value::Vector(vec) => Ok(vec.iter().map(|&x| cast(x as f64)).collect()),
333 Value::List(list) => list
334 .iter()
335 .map(|v| {
336 v.as_f64()
337 .map(&cast)
338 .ok_or_else(|| SimilarToError::InvalidOption {
339 message: "vector element must be a number".to_string(),
340 })
341 })
342 .collect(),
343 _ => Err(SimilarToError::InvalidVectorValue {
344 actual: format!("{v:?}"),
345 }),
346 }
347}
348
349fn value_to_f64_vec(v: &Value) -> Result<Vec<f64>, SimilarToError> {
351 value_to_vec(v, |f| f)
352}
353
354pub fn value_to_f32_vec(v: &Value) -> Result<Vec<f32>, SimilarToError> {
356 value_to_vec(v, |f| f as f32)
357}
358
359pub fn validate_options(opts: &SimilarToOptions, num_sources: usize) -> Result<(), SimilarToError> {
361 if let Some(ref weights) = opts.weights {
362 if weights.len() != num_sources {
363 return Err(SimilarToError::WeightsLengthMismatch {
364 weights_len: weights.len(),
365 sources_len: num_sources,
366 });
367 }
368 let sum: f32 = weights.iter().sum();
369 if (sum - 1.0).abs() > 0.01 {
370 return Err(SimilarToError::WeightsNotNormalized { sum });
371 }
372 }
373 Ok(())
374}
375
376pub fn validate_pair(
381 source_type: &SourceType,
382 query_is_vector: bool,
383 query_is_string: bool,
384 source_index: usize,
385) -> Result<(), SimilarToError> {
386 match source_type {
387 SourceType::Fts if query_is_vector => Err(SimilarToError::TypeMismatch { source_index }),
388 SourceType::Vector {
389 has_embedding_config: false,
390 ..
391 } if query_is_string => Err(SimilarToError::NoEmbeddingConfig { source_index }),
392 _ => Ok(()),
393 }
394}
395
396pub fn fuse_scores(scores: &[f32], opts: &SimilarToOptions) -> Result<f32, SimilarToError> {
398 if scores.len() == 1 {
399 return Ok(scores[0]);
400 }
401
402 match opts.method {
403 FusionMethod::Weighted => {
404 let weights = opts
405 .weights
406 .as_ref()
407 .ok_or(SimilarToError::WeightsRequired)?;
408 Ok(fusion::fuse_weighted_multi(scores, weights))
409 }
410 FusionMethod::Rrf => {
411 let (score, _) = fusion::fuse_rrf_point(scores);
415 Ok(score)
416 }
417 }
418}
419
// Unit tests for option parsing, validation, similarity scoring, and score fusion.
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    // --- parse_options ---

    // `Value::Null` must yield the documented defaults.
    #[test]
    fn test_parse_options_default() {
        let opts = parse_options(&Value::Null).unwrap();
        assert_eq!(opts.method, FusionMethod::Rrf);
        assert_eq!(opts.k, 60);
        assert!((opts.fts_k - 1.0).abs() < 1e-6);
        assert!(opts.weights.is_none());
    }

    // `method: "weighted"` together with a weights list is parsed into f32 weights.
    #[test]
    fn test_parse_options_weighted() {
        let mut map = HashMap::new();
        map.insert("method".to_string(), Value::String("weighted".to_string()));
        map.insert(
            "weights".to_string(),
            Value::List(vec![Value::Float(0.7), Value::Float(0.3)]),
        );
        let opts = parse_options(&Value::Map(map)).unwrap();
        assert_eq!(opts.method, FusionMethod::Weighted);
        let weights = opts.weights.unwrap();
        assert!((weights[0] - 0.7).abs() < 1e-6);
        assert!((weights[1] - 0.3).abs() < 1e-6);
    }

    // An integer `k` overrides the default of 60.
    #[test]
    fn test_parse_options_rrf_with_k() {
        let mut map = HashMap::new();
        map.insert("method".to_string(), Value::String("rrf".to_string()));
        map.insert("k".to_string(), Value::Int(30));
        let opts = parse_options(&Value::Map(map)).unwrap();
        assert_eq!(opts.method, FusionMethod::Rrf);
        assert_eq!(opts.k, 30);
    }

    // A float `fts_k` overrides the default of 1.0.
    #[test]
    fn test_parse_options_fts_k() {
        let mut map = HashMap::new();
        map.insert("fts_k".to_string(), Value::Float(2.0));
        let opts = parse_options(&Value::Map(map)).unwrap();
        assert!((opts.fts_k - 2.0).abs() < 1e-6);
    }

    // Method names other than "rrf" and "weighted" are rejected.
    #[test]
    fn test_parse_options_invalid_method() {
        let mut map = HashMap::new();
        map.insert("method".to_string(), Value::String("invalid".to_string()));
        assert!(parse_options(&Value::Map(map)).is_err());
    }

    // --- cosine_similarity ---

    // A vector compared with itself has similarity 1.
    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let sim = cosine_similarity(&v, &v).unwrap();
        assert!((sim - 1.0).abs() < 1e-6);
    }

    // Orthogonal vectors have similarity 0.
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b).unwrap();
        assert!((sim - 0.0).abs() < 1e-6);
    }

    // Opposite vectors have similarity -1.
    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = cosine_similarity(&a, &b).unwrap();
        assert!((sim - (-1.0)).abs() < 1e-6);
    }

    // Vectors of different lengths produce an error.
    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!(cosine_similarity(&a, &b).is_err());
    }

    // --- normalize_bm25 ---

    // With fts_k = 1 the mapping is score / (score + 1).
    #[test]
    fn test_normalize_bm25() {
        assert!((normalize_bm25(0.0, 1.0) - 0.0).abs() < 1e-6);
        assert!((normalize_bm25(1.0, 1.0) - 0.5).abs() < 1e-6);
        assert!((normalize_bm25(9.0, 1.0) - 0.9).abs() < 1e-6);
        assert!((normalize_bm25(99.0, 1.0) - 0.99).abs() < 1e-4);
    }

    // A score equal to fts_k maps to 0.5.
    #[test]
    fn test_normalize_bm25_custom_k() {
        assert!((normalize_bm25(2.0, 2.0) - 0.5).abs() < 1e-6);
    }

    // --- eval_similar_to_pure ---

    // List inputs: identical lists give similarity 1.
    #[test]
    fn test_eval_similar_to_pure() {
        let v1 = Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        let v2 = Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        let result = eval_similar_to_pure(&v1, &v2).unwrap();
        match result {
            Value::Float(f) => assert!((f - 1.0).abs() < 1e-6),
            _ => panic!("Expected Float"),
        }
    }

    // `Value::Vector` inputs: orthogonal vectors give similarity 0.
    #[test]
    fn test_eval_similar_to_pure_vector_type() {
        let v1 = Value::Vector(vec![1.0, 0.0]);
        let v2 = Value::Vector(vec![0.0, 1.0]);
        let result = eval_similar_to_pure(&v1, &v2).unwrap();
        match result {
            Value::Float(f) => assert!((f - 0.0).abs() < 1e-6),
            _ => panic!("Expected Float"),
        }
    }

    // --- validate_options ---

    // One weight for two sources is a length mismatch.
    #[test]
    fn test_validate_options_weights_length() {
        let opts = SimilarToOptions {
            weights: Some(vec![0.5]),
            ..Default::default()
        };
        assert!(validate_options(&opts, 2).is_err());
    }

    // Weights summing to 0.8 fall outside the normalization tolerance.
    #[test]
    fn test_validate_options_weights_sum() {
        let opts = SimilarToOptions {
            weights: Some(vec![0.5, 0.3]),
            ..Default::default()
        };
        assert!(validate_options(&opts, 2).is_err());
    }

    // Two weights summing to 1.0 for two sources are accepted.
    #[test]
    fn test_validate_options_ok() {
        let opts = SimilarToOptions {
            weights: Some(vec![0.7, 0.3]),
            ..Default::default()
        };
        assert!(validate_options(&opts, 2).is_ok());
    }

    // --- validate_pair ---

    // A vector query against a full-text source is rejected.
    #[test]
    fn test_validate_pair_fts_vector_query() {
        assert!(validate_pair(&SourceType::Fts, true, false, 0).is_err());
    }

    // A string query against a vector source without an embedding config is rejected.
    #[test]
    fn test_validate_pair_vector_string_no_embed() {
        let st = SourceType::Vector {
            metric: DistanceMetric::Cosine,
            has_embedding_config: false,
        };
        assert!(validate_pair(&st, false, true, 0).is_err());
    }

    // A string query against a vector source with an embedding config is accepted.
    #[test]
    fn test_validate_pair_vector_string_with_embed() {
        let st = SourceType::Vector {
            metric: DistanceMetric::Cosine,
            has_embedding_config: true,
        };
        assert!(validate_pair(&st, false, true, 0).is_ok());
    }

    // A vector query against a vector source needs no embedding config.
    #[test]
    fn test_validate_pair_vector_vector() {
        let st = SourceType::Vector {
            metric: DistanceMetric::Cosine,
            has_embedding_config: false,
        };
        assert!(validate_pair(&st, true, false, 0).is_ok());
    }

    // A string query against a full-text source is accepted.
    #[test]
    fn test_validate_pair_fts_string() {
        assert!(validate_pair(&SourceType::Fts, false, true, 0).is_ok());
    }

    // --- fuse_scores ---

    // A single score is passed through unchanged.
    #[test]
    fn test_fuse_scores_single() {
        let opts = SimilarToOptions::default();
        let score = fuse_scores(&[0.8], &opts).unwrap();
        assert!((score - 0.8).abs() < 1e-6);
    }

    // Weighted fusion: 0.7 * 0.8 + 0.3 * 0.6 = 0.74.
    #[test]
    fn test_fuse_scores_weighted() {
        let opts = SimilarToOptions {
            method: FusionMethod::Weighted,
            weights: Some(vec![0.7, 0.3]),
            ..Default::default()
        };
        let score = fuse_scores(&[0.8, 0.6], &opts).unwrap();
        assert!((score - 0.74).abs() < 1e-6);
    }

    // Default (Rrf) options on two scores: the expected 0.7 equals the mean of the inputs.
    // NOTE(review): presumably `fusion::fuse_rrf_point` falls back to an average — confirm.
    #[test]
    fn test_fuse_scores_rrf_fallback() {
        let opts = SimilarToOptions::default();
        let score = fuse_scores(&[0.8, 0.6], &opts).unwrap();
        assert!((score - 0.7).abs() < 1e-6);
    }

    // --- score_vectors ---

    // Cosine metric: identical vectors score 1.
    #[test]
    fn test_score_vectors_cosine_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::Cosine).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
    }

    // Cosine metric: `score_vectors` agrees with calling `cosine_similarity` directly.
    #[test]
    fn test_score_vectors_cosine_matches_raw() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.8, 0.6, 0.0];
        let raw = cosine_similarity(&a, &b).unwrap();
        let scored = score_vectors(&a, &b, &DistanceMetric::Cosine).unwrap();
        assert!((raw - scored).abs() < 1e-6);
    }

    // L2 metric: the expected 1/3 corresponds to 1 / (1 + d) with d = 2.
    // NOTE(review): presumably `compute_distance` returns the squared L2 distance — confirm.
    #[test]
    fn test_score_vectors_l2() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let score = score_vectors(&a, &b, &DistanceMetric::L2).unwrap();
        assert!((score - 1.0 / 3.0).abs() < 1e-5);
    }

    // L2 metric: identical vectors score 1.
    #[test]
    fn test_score_vectors_l2_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::L2).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
    }

    // Dot metric: the expected 0.8 is the dot product of `a` and `b`.
    #[test]
    fn test_score_vectors_dot() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.8, 0.6, 0.0];
        let score = score_vectors(&a, &b, &DistanceMetric::Dot).unwrap();
        assert!((score - 0.8).abs() < 1e-6);
    }

    // Dot metric: a unit vector with itself scores 1.
    #[test]
    fn test_score_vectors_dot_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::Dot).unwrap();
        assert!((score - 1.0).abs() < 1e-6);
    }

    // Mismatched lengths are rejected before any scoring happens.
    #[test]
    fn test_score_vectors_dimension_mismatch() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!(score_vectors(&a, &b, &DistanceMetric::Cosine).is_err());
    }
}