1use anyhow::Result;
15use uni_common::Value;
16use uni_common::core::schema::{DistanceMetric, Schema};
17
18use crate::fusion;
19
/// Errors produced while parsing, validating, and evaluating `similar_to` calls.
///
/// The `#[error(...)]` text on each variant is its `Display` message.
#[derive(Debug, thiserror::Error)]
pub enum SimilarToError {
    /// The `label.property` source has neither a vector index nor a full-text index.
    #[error("similar_to: property '{label}.{property}' has no vector or full-text index")]
    NoIndex { label: String, property: String },

    /// A vector query was paired with a full-text (FTS) indexed source.
    #[error(
        "similar_to: source {source_index} is FTS-indexed but query is a vector (FTS cannot score against vectors)"
    )]
    TypeMismatch { source_index: usize },

    /// A string query was paired with a vector source whose index has no embedding config.
    #[error(
        "similar_to: source {source_index} is a vector property but query is a string, and the index has no embedding config for auto-embedding"
    )]
    NoEmbeddingConfig { source_index: usize },

    /// The number of weights differs from the number of sources.
    #[error("similar_to: weights length ({weights_len}) != sources length ({sources_len})")]
    WeightsLengthMismatch {
        weights_len: usize,
        sources_len: usize,
    },

    /// The weights do not sum to 1.0; `sum` carries the actual total.
    #[error("similar_to: weights must sum to 1.0 (got {sum})")]
    WeightsNotNormalized { sum: f32 },

    /// The `method` option was a string other than `'rrf'` or `'weighted'`.
    #[error("similar_to: unknown method '{method}', expected 'rrf' or 'weighted'")]
    InvalidMethod { method: String },

    /// Generic option/argument problem; `message` holds the details.
    #[error("similar_to: {message}")]
    InvalidOption { message: String },

    /// Two vectors being compared have different lengths (`a` vs `b`).
    #[error("similar_to: vector dimensions mismatch: {a} vs {b}")]
    DimensionMismatch { a: usize, b: usize },

    /// A value that should be a vector or list of numbers was something else;
    /// `actual` is the `Debug` rendering of the offending value.
    #[error("similar_to: expected vector or list of numbers, got {actual}")]
    InvalidVectorValue { actual: String },

    /// Weighted fusion was requested without a `weights` option.
    #[error("similar_to: weighted fusion requires 'weights' option")]
    WeightsRequired,

    /// Wrong number of call arguments. Not raised in this part of the file.
    #[error("similar_to takes 2 or 3 arguments (sources, queries [, options]), got {count}")]
    InvalidArity { count: usize },

    /// No graph execution context was available. Not raised in this part of the file.
    #[error("similar_to requires GraphExecutionContext")]
    NoGraphContext,
}
66
/// Strategy used to combine several per-source scores into a single score.
#[derive(Debug, Clone, Default, PartialEq)]
pub enum FusionMethod {
    /// Selected by the option string `"rrf"`; this is the default variant.
    #[default]
    Rrf,
    /// Selected by the option string `"weighted"`; fusing more than one score
    /// with this method requires the `weights` option.
    Weighted,
}
77
/// Options accepted by `similar_to`, as produced by [`parse_options`].
#[derive(Debug, Clone)]
pub struct SimilarToOptions {
    /// How per-source scores are fused (defaults to [`FusionMethod::Rrf`]).
    pub method: FusionMethod,
    /// Optional per-source weights; [`validate_options`] requires one weight
    /// per source and a total within 0.01 of 1.0.
    pub weights: Option<Vec<f32>>,
    /// Integer `k` option (defaults to 60). Presumably the RRF rank constant;
    /// no function in this part of the file reads it — TODO confirm where it
    /// is consumed.
    pub k: usize,
    /// `fts_k` option (defaults to 1.0). Presumably passed by the caller as the
    /// `fts_k` argument of [`normalize_bm25`] — TODO confirm.
    pub fts_k: f32,
}
90
91impl Default for SimilarToOptions {
92 fn default() -> Self {
93 Self {
94 method: FusionMethod::Rrf,
95 weights: None,
96 k: 60,
97 fts_k: 1.0,
98 }
99 }
100}
101
102pub fn parse_options(value: &Value) -> Result<SimilarToOptions, SimilarToError> {
104 let map = match value {
105 Value::Map(m) => m,
106 Value::Null => return Ok(SimilarToOptions::default()),
107 _ => {
108 return Err(SimilarToError::InvalidOption {
109 message: format!("options must be a map, got {:?}", value),
110 });
111 }
112 };
113
114 let mut opts = SimilarToOptions::default();
115
116 if let Some(method_val) = map.get("method") {
117 match method_val.as_str() {
118 Some("rrf") => opts.method = FusionMethod::Rrf,
119 Some("weighted") => opts.method = FusionMethod::Weighted,
120 Some(other) => {
121 return Err(SimilarToError::InvalidMethod {
122 method: other.to_string(),
123 });
124 }
125 None => {
126 return Err(SimilarToError::InvalidOption {
127 message: "'method' must be a string ('rrf' or 'weighted')".to_string(),
128 });
129 }
130 }
131 }
132
133 if let Some(weights_val) = map.get("weights") {
134 match weights_val {
135 Value::List(list) => {
136 let weights: Result<Vec<f32>, SimilarToError> = list
137 .iter()
138 .map(|v| {
139 v.as_f64()
140 .map(|f| f as f32)
141 .ok_or_else(|| SimilarToError::InvalidOption {
142 message: "weight must be a number".to_string(),
143 })
144 })
145 .collect();
146 opts.weights = Some(weights?);
147 }
148 _ => {
149 return Err(SimilarToError::InvalidOption {
150 message: "'weights' must be a list of numbers".to_string(),
151 });
152 }
153 }
154 }
155
156 if let Some(k_val) = map.get("k") {
157 opts.k = k_val
158 .as_i64()
159 .ok_or_else(|| SimilarToError::InvalidOption {
160 message: "'k' must be an integer".to_string(),
161 })? as usize;
162 }
163
164 if let Some(fts_k_val) = map.get("fts_k") {
165 opts.fts_k = fts_k_val
166 .as_f64()
167 .ok_or_else(|| SimilarToError::InvalidOption {
168 message: "'fts_k' must be a number".to_string(),
169 })? as f32;
170 }
171
172 Ok(opts)
173}
174
/// How a `similar_to` source property is indexed, as resolved by
/// [`resolve_source_type`].
#[derive(Debug, Clone)]
pub enum SourceType {
    /// The property has a vector index.
    Vector {
        /// Distance metric configured on the vector index.
        metric: DistanceMetric,
        /// Whether the index carries an embedding config; [`validate_pair`]
        /// rejects string queries against vector sources when this is `false`.
        has_embedding_config: bool,
    },
    /// The property has a full-text index (and no vector index).
    Fts,
}
186
187pub fn resolve_source_type(
189 schema: &Schema,
190 label: &str,
191 property: &str,
192) -> Result<SourceType, SimilarToError> {
193 if let Some(vec_config) = schema.vector_index_for_property(label, property) {
195 return Ok(SourceType::Vector {
196 metric: vec_config.metric.clone(),
197 has_embedding_config: vec_config.embedding_config.is_some(),
198 });
199 }
200
201 if schema
203 .fulltext_index_for_property(label, property)
204 .is_some()
205 {
206 return Ok(SourceType::Fts);
207 }
208
209 Err(SimilarToError::NoIndex {
210 label: label.to_string(),
211 property: property.to_string(),
212 })
213}
214
215pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32, SimilarToError> {
220 cosine_similarity_inner(a, b).map(|sim| sim as f32)
221}
222
223pub fn score_vectors(a: &[f32], b: &[f32], metric: &DistanceMetric) -> Result<f32, SimilarToError> {
230 if a.len() != b.len() {
231 return Err(SimilarToError::DimensionMismatch {
232 a: a.len(),
233 b: b.len(),
234 });
235 }
236 let distance = metric.compute_distance(a, b);
237 match metric {
238 DistanceMetric::Cosine => cosine_similarity(a, b),
239 DistanceMetric::Dot => Ok(-distance),
242 _ => Ok(calculate_score(distance, metric)),
244 }
245}
246
247pub fn calculate_score(distance: f32, metric: &DistanceMetric) -> f32 {
254 match metric {
255 DistanceMetric::Cosine => (2.0 - distance) / 2.0,
256 DistanceMetric::Dot => distance,
257 _ => 1.0 / (1.0 + distance),
258 }
259}
260
/// Squashes a BM25 score with the saturation curve `score / (score + fts_k)`.
///
/// * Non-positive scores, and NaN, map to `0.0`.
/// * `+inf` maps to `1.0`, the limit of the curve; the plain formula would
///   yield `inf / inf = NaN`.
/// * Every other score goes through the formula. `fts_k` is not validated
///   here; the curve stays within `[0, 1)` only for a positive `fts_k`.
pub fn normalize_bm25(score: f32, fts_k: f32) -> f32 {
    // `NaN <= 0.0` is false, so NaN has to be caught explicitly.
    if score.is_nan() || score <= 0.0 {
        return 0.0;
    }
    // Only +inf can reach this point; -inf was handled above.
    if score.is_infinite() {
        return 1.0;
    }
    score / (score + fts_k)
}
270
271pub fn eval_similar_to_pure(v1: &Value, v2: &Value) -> Result<Value> {
279 if matches!(v1, Value::Null) || matches!(v2, Value::Null) {
282 return Ok(Value::Null);
283 }
284 let has_list = matches!(v1, Value::List(_)) || matches!(v2, Value::List(_));
286 let f64_vecs = has_list
287 .then(|| value_to_f64_vec(v1).ok().zip(value_to_f64_vec(v2).ok()))
288 .flatten();
289 if let Some((vec1, vec2)) = f64_vecs {
290 let sim = cosine_similarity_inner(&vec1, &vec2)?;
291 return Ok(Value::Float(sim));
292 }
293 let vec1 = value_to_f32_vec(v1)?;
295 let vec2 = value_to_f32_vec(v2)?;
296 let sim = cosine_similarity(&vec1, &vec2)?;
297 Ok(Value::Float(sim as f64))
298}
299
300fn cosine_similarity_inner<T: Copy + Into<f64>>(a: &[T], b: &[T]) -> Result<f64, SimilarToError> {
306 if a.len() != b.len() {
307 return Err(SimilarToError::DimensionMismatch {
308 a: a.len(),
309 b: b.len(),
310 });
311 }
312 let mut dot = 0.0f64;
313 let mut mag1 = 0.0f64;
314 let mut mag2 = 0.0f64;
315 for (&x, &y) in a.iter().zip(b.iter()) {
316 let (x, y): (f64, f64) = (x.into(), y.into());
317 dot += x * y;
318 mag1 += x * x;
319 mag2 += y * y;
320 }
321 let mag1 = mag1.sqrt();
322 let mag2 = mag2.sqrt();
323 if mag1 == 0.0 || mag2 == 0.0 {
324 return Ok(0.0);
325 }
326 Ok((dot / (mag1 * mag2)).clamp(-1.0, 1.0))
327}
328
329fn value_to_vec<T>(v: &Value, cast: impl Fn(f64) -> T) -> Result<Vec<T>, SimilarToError> {
336 match v {
337 Value::Vector(vec) => Ok(vec.iter().map(|&x| cast(x as f64)).collect()),
338 Value::List(list) => list
339 .iter()
340 .map(|v| {
341 v.as_f64()
342 .map(&cast)
343 .ok_or_else(|| SimilarToError::InvalidOption {
344 message: "vector element must be a number".to_string(),
345 })
346 })
347 .collect(),
348 _ => Err(SimilarToError::InvalidVectorValue {
349 actual: format!("{v:?}"),
350 }),
351 }
352}
353
354fn value_to_f64_vec(v: &Value) -> Result<Vec<f64>, SimilarToError> {
356 value_to_vec(v, |f| f)
357}
358
359pub fn value_to_f32_vec(v: &Value) -> Result<Vec<f32>, SimilarToError> {
361 value_to_vec(v, |f| f as f32)
362}
363
364pub fn validate_options(opts: &SimilarToOptions, num_sources: usize) -> Result<(), SimilarToError> {
366 if let Some(ref weights) = opts.weights {
367 if weights.len() != num_sources {
368 return Err(SimilarToError::WeightsLengthMismatch {
369 weights_len: weights.len(),
370 sources_len: num_sources,
371 });
372 }
373 let sum: f32 = weights.iter().sum();
374 if (sum - 1.0).abs() > 0.01 {
375 return Err(SimilarToError::WeightsNotNormalized { sum });
376 }
377 }
378 Ok(())
379}
380
381pub fn validate_pair(
386 source_type: &SourceType,
387 query_is_vector: bool,
388 query_is_string: bool,
389 source_index: usize,
390) -> Result<(), SimilarToError> {
391 match source_type {
392 SourceType::Fts if query_is_vector => Err(SimilarToError::TypeMismatch { source_index }),
393 SourceType::Vector {
394 has_embedding_config: false,
395 ..
396 } if query_is_string => Err(SimilarToError::NoEmbeddingConfig { source_index }),
397 _ => Ok(()),
398 }
399}
400
401pub fn fuse_scores(scores: &[f32], opts: &SimilarToOptions) -> Result<f32, SimilarToError> {
403 if scores.len() == 1 {
404 return Ok(scores[0]);
405 }
406
407 match opts.method {
408 FusionMethod::Weighted => {
409 let weights = opts
410 .weights
411 .as_ref()
412 .ok_or(SimilarToError::WeightsRequired)?;
413 Ok(fusion::fuse_weighted_multi(scores, weights))
414 }
415 FusionMethod::Rrf => {
416 let (score, _) = fusion::fuse_rrf_point(scores);
420 Ok(score)
421 }
422 }
423}
424
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Asserts that `actual` is strictly within `tolerance` of `expected`.
    fn assert_close(actual: f32, expected: f32, tolerance: f32) {
        assert!(
            (actual - expected).abs() < tolerance,
            "expected {expected} (tolerance {tolerance}), got {actual}"
        );
    }

    /// Unwraps a `Value::Float`, panicking on any other variant.
    fn expect_float(value: Value) -> f64 {
        match value {
            Value::Float(f) => f,
            _ => panic!("Expected Float"),
        }
    }

    /// Builds an options map value from `(key, value)` pairs.
    fn options_map(entries: Vec<(&str, Value)>) -> Value {
        let map: HashMap<String, Value> = entries
            .into_iter()
            .map(|(key, value)| (key.to_string(), value))
            .collect();
        Value::Map(map)
    }

    /// Default options with only `weights` overridden.
    fn options_with_weights(weights: Vec<f32>) -> SimilarToOptions {
        SimilarToOptions {
            weights: Some(weights),
            ..Default::default()
        }
    }

    /// A cosine-metric vector source with the given embedding-config flag.
    fn cosine_source(has_embedding_config: bool) -> SourceType {
        SourceType::Vector {
            metric: DistanceMetric::Cosine,
            has_embedding_config,
        }
    }

    // ---- parse_options ----

    #[test]
    fn test_parse_options_default() {
        let opts = parse_options(&Value::Null).unwrap();
        assert_eq!(opts.method, FusionMethod::Rrf);
        assert_eq!(opts.k, 60);
        assert_close(opts.fts_k, 1.0, 1e-6);
        assert!(opts.weights.is_none());
    }

    #[test]
    fn test_parse_options_weighted() {
        let raw = options_map(vec![
            ("method", Value::String("weighted".to_string())),
            (
                "weights",
                Value::List(vec![Value::Float(0.7), Value::Float(0.3)]),
            ),
        ]);
        let opts = parse_options(&raw).unwrap();
        assert_eq!(opts.method, FusionMethod::Weighted);
        let weights = opts.weights.unwrap();
        assert_close(weights[0], 0.7, 1e-6);
        assert_close(weights[1], 0.3, 1e-6);
    }

    #[test]
    fn test_parse_options_rrf_with_k() {
        let raw = options_map(vec![
            ("method", Value::String("rrf".to_string())),
            ("k", Value::Int(30)),
        ]);
        let opts = parse_options(&raw).unwrap();
        assert_eq!(opts.method, FusionMethod::Rrf);
        assert_eq!(opts.k, 30);
    }

    #[test]
    fn test_parse_options_fts_k() {
        let raw = options_map(vec![("fts_k", Value::Float(2.0))]);
        let opts = parse_options(&raw).unwrap();
        assert_close(opts.fts_k, 2.0, 1e-6);
    }

    #[test]
    fn test_parse_options_invalid_method() {
        let raw = options_map(vec![("method", Value::String("invalid".to_string()))]);
        assert!(parse_options(&raw).is_err());
    }

    // ---- cosine_similarity ----

    #[test]
    fn test_cosine_similarity_identical() {
        let v = [1.0_f32, 0.0, 0.0];
        assert_close(cosine_similarity(&v, &v).unwrap(), 1.0, 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = [1.0_f32, 0.0];
        let b = [0.0_f32, 1.0];
        assert_close(cosine_similarity(&a, &b).unwrap(), 0.0, 1e-6);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = [1.0_f32, 0.0];
        let b = [-1.0_f32, 0.0];
        assert_close(cosine_similarity(&a, &b).unwrap(), -1.0, 1e-6);
    }

    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        let a = [1.0_f32, 0.0];
        let b = [1.0_f32, 0.0, 0.0];
        assert!(cosine_similarity(&a, &b).is_err());
    }

    // ---- normalize_bm25 ----

    #[test]
    fn test_normalize_bm25() {
        assert_close(normalize_bm25(0.0, 1.0), 0.0, 1e-6);
        assert_close(normalize_bm25(1.0, 1.0), 0.5, 1e-6);
        assert_close(normalize_bm25(9.0, 1.0), 0.9, 1e-6);
        assert_close(normalize_bm25(99.0, 1.0), 0.99, 1e-4);
    }

    #[test]
    fn test_normalize_bm25_custom_k() {
        assert_close(normalize_bm25(2.0, 2.0), 0.5, 1e-6);
    }

    // ---- eval_similar_to_pure ----

    #[test]
    fn test_eval_similar_to_pure() {
        let v1 = Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        let v2 = Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        let similarity = expect_float(eval_similar_to_pure(&v1, &v2).unwrap());
        assert!((similarity - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_eval_similar_to_pure_vector_type() {
        let v1 = Value::Vector(vec![1.0, 0.0]);
        let v2 = Value::Vector(vec![0.0, 1.0]);
        let similarity = expect_float(eval_similar_to_pure(&v1, &v2).unwrap());
        assert!((similarity - 0.0).abs() < 1e-6);
    }

    // ---- validate_options ----

    #[test]
    fn test_validate_options_weights_length() {
        let opts = options_with_weights(vec![0.5]);
        assert!(validate_options(&opts, 2).is_err());
    }

    #[test]
    fn test_validate_options_weights_sum() {
        let opts = options_with_weights(vec![0.5, 0.3]);
        assert!(validate_options(&opts, 2).is_err());
    }

    #[test]
    fn test_validate_options_ok() {
        let opts = options_with_weights(vec![0.7, 0.3]);
        assert!(validate_options(&opts, 2).is_ok());
    }

    // ---- validate_pair ----

    #[test]
    fn test_validate_pair_fts_vector_query() {
        assert!(validate_pair(&SourceType::Fts, true, false, 0).is_err());
    }

    #[test]
    fn test_validate_pair_vector_string_no_embed() {
        assert!(validate_pair(&cosine_source(false), false, true, 0).is_err());
    }

    #[test]
    fn test_validate_pair_vector_string_with_embed() {
        assert!(validate_pair(&cosine_source(true), false, true, 0).is_ok());
    }

    #[test]
    fn test_validate_pair_vector_vector() {
        assert!(validate_pair(&cosine_source(false), true, false, 0).is_ok());
    }

    #[test]
    fn test_validate_pair_fts_string() {
        assert!(validate_pair(&SourceType::Fts, false, true, 0).is_ok());
    }

    // ---- fuse_scores ----

    #[test]
    fn test_fuse_scores_single() {
        let opts = SimilarToOptions::default();
        assert_close(fuse_scores(&[0.8], &opts).unwrap(), 0.8, 1e-6);
    }

    #[test]
    fn test_fuse_scores_weighted() {
        let opts = SimilarToOptions {
            method: FusionMethod::Weighted,
            weights: Some(vec![0.7, 0.3]),
            ..Default::default()
        };
        assert_close(fuse_scores(&[0.8, 0.6], &opts).unwrap(), 0.74, 1e-6);
    }

    #[test]
    fn test_fuse_scores_rrf_fallback() {
        let opts = SimilarToOptions::default();
        assert_close(fuse_scores(&[0.8, 0.6], &opts).unwrap(), 0.7, 1e-6);
    }

    // ---- score_vectors ----

    #[test]
    fn test_score_vectors_cosine_identical() {
        let v = [1.0_f32, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::Cosine).unwrap();
        assert_close(score, 1.0, 1e-6);
    }

    #[test]
    fn test_score_vectors_cosine_matches_raw() {
        let a = [1.0_f32, 0.0, 0.0];
        let b = [0.8_f32, 0.6, 0.0];
        let raw = cosine_similarity(&a, &b).unwrap();
        let scored = score_vectors(&a, &b, &DistanceMetric::Cosine).unwrap();
        assert_close(raw, scored, 1e-6);
    }

    #[test]
    fn test_score_vectors_l2() {
        let a = [1.0_f32, 0.0, 0.0];
        let b = [0.0_f32, 1.0, 0.0];
        let score = score_vectors(&a, &b, &DistanceMetric::L2).unwrap();
        assert_close(score, 1.0 / 3.0, 1e-5);
    }

    #[test]
    fn test_score_vectors_l2_identical() {
        let v = [1.0_f32, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::L2).unwrap();
        assert_close(score, 1.0, 1e-6);
    }

    #[test]
    fn test_score_vectors_dot() {
        let a = [1.0_f32, 0.0, 0.0];
        let b = [0.8_f32, 0.6, 0.0];
        let score = score_vectors(&a, &b, &DistanceMetric::Dot).unwrap();
        assert_close(score, 0.8, 1e-6);
    }

    #[test]
    fn test_score_vectors_dot_identical() {
        let v = [1.0_f32, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::Dot).unwrap();
        assert_close(score, 1.0, 1e-6);
    }

    #[test]
    fn test_score_vectors_dimension_mismatch() {
        let a = [1.0_f32, 0.0];
        let b = [1.0_f32, 0.0, 0.0];
        assert!(score_vectors(&a, &b, &DistanceMetric::Cosine).is_err());
    }
}