1use anyhow::Result;
15use uni_common::Value;
16use uni_common::core::schema::{DistanceMetric, Schema};
17
18use crate::query::df_graph::common::calculate_score;
19use crate::query::fusion;
20
21#[derive(Debug, thiserror::Error)]
23pub enum SimilarToError {
24 #[error("similar_to: property '{label}.{property}' has no vector or full-text index")]
25 NoIndex { label: String, property: String },
26
27 #[error(
28 "similar_to: source {source_index} is FTS-indexed but query is a vector (FTS cannot score against vectors)"
29 )]
30 TypeMismatch { source_index: usize },
31
32 #[error(
33 "similar_to: source {source_index} is a vector property but query is a string, and the index has no embedding config for auto-embedding"
34 )]
35 NoEmbeddingConfig { source_index: usize },
36
37 #[error("similar_to: weights length ({weights_len}) != sources length ({sources_len})")]
38 WeightsLengthMismatch {
39 weights_len: usize,
40 sources_len: usize,
41 },
42
43 #[error("similar_to: weights must sum to 1.0 (got {sum})")]
44 WeightsNotNormalized { sum: f32 },
45
46 #[error("similar_to: unknown method '{method}', expected 'rrf' or 'weighted'")]
47 InvalidMethod { method: String },
48
49 #[error("similar_to: {message}")]
50 InvalidOption { message: String },
51
52 #[error("similar_to: vector dimensions mismatch: {a} vs {b}")]
53 DimensionMismatch { a: usize, b: usize },
54
55 #[error("similar_to: expected vector or list of numbers, got {actual}")]
56 InvalidVectorValue { actual: String },
57
58 #[error("similar_to: weighted fusion requires 'weights' option")]
59 WeightsRequired,
60
61 #[error("similar_to takes 2 or 3 arguments (sources, queries [, options]), got {count}")]
62 InvalidArity { count: usize },
63
64 #[error("similar_to requires GraphExecutionContext")]
65 NoGraphContext,
66}
67
/// Strategy for combining several per-source scores into a single score.
#[derive(Clone, Debug, Default, PartialEq)]
pub enum FusionMethod {
    /// Selected by the `'rrf'` method option; this is the default variant.
    /// (Presumably "reciprocal rank fusion" - confirm against the fusion module.)
    #[default]
    Rrf,
    /// Selected by the `'weighted'` method option; [`fuse_scores`] requires
    /// `weights` to be present when this variant is used.
    Weighted,
}
78
79#[derive(Debug, Clone)]
81pub struct SimilarToOptions {
82 pub method: FusionMethod,
84 pub weights: Option<Vec<f32>>,
86 pub k: usize,
88 pub fts_k: f32,
90}
91
92impl Default for SimilarToOptions {
93 fn default() -> Self {
94 Self {
95 method: FusionMethod::Rrf,
96 weights: None,
97 k: 60,
98 fts_k: 1.0,
99 }
100 }
101}
102
103pub fn parse_options(value: &Value) -> Result<SimilarToOptions, SimilarToError> {
105 let map = match value {
106 Value::Map(m) => m,
107 Value::Null => return Ok(SimilarToOptions::default()),
108 _ => {
109 return Err(SimilarToError::InvalidOption {
110 message: format!("options must be a map, got {:?}", value),
111 });
112 }
113 };
114
115 let mut opts = SimilarToOptions::default();
116
117 if let Some(method_val) = map.get("method") {
118 match method_val.as_str() {
119 Some("rrf") => opts.method = FusionMethod::Rrf,
120 Some("weighted") => opts.method = FusionMethod::Weighted,
121 Some(other) => {
122 return Err(SimilarToError::InvalidMethod {
123 method: other.to_string(),
124 });
125 }
126 None => {
127 return Err(SimilarToError::InvalidOption {
128 message: "'method' must be a string ('rrf' or 'weighted')".to_string(),
129 });
130 }
131 }
132 }
133
134 if let Some(weights_val) = map.get("weights") {
135 match weights_val {
136 Value::List(list) => {
137 let weights: Result<Vec<f32>, SimilarToError> = list
138 .iter()
139 .map(|v| {
140 v.as_f64()
141 .map(|f| f as f32)
142 .ok_or_else(|| SimilarToError::InvalidOption {
143 message: "weight must be a number".to_string(),
144 })
145 })
146 .collect();
147 opts.weights = Some(weights?);
148 }
149 _ => {
150 return Err(SimilarToError::InvalidOption {
151 message: "'weights' must be a list of numbers".to_string(),
152 });
153 }
154 }
155 }
156
157 if let Some(k_val) = map.get("k") {
158 opts.k = k_val
159 .as_i64()
160 .ok_or_else(|| SimilarToError::InvalidOption {
161 message: "'k' must be an integer".to_string(),
162 })? as usize;
163 }
164
165 if let Some(fts_k_val) = map.get("fts_k") {
166 opts.fts_k = fts_k_val
167 .as_f64()
168 .ok_or_else(|| SimilarToError::InvalidOption {
169 message: "'fts_k' must be a number".to_string(),
170 })? as f32;
171 }
172
173 Ok(opts)
174}
175
176#[derive(Debug, Clone)]
178pub enum SourceType {
179 Vector {
181 metric: DistanceMetric,
182 has_embedding_config: bool,
183 },
184 Fts,
186}
187
188pub fn resolve_source_type(
190 schema: &Schema,
191 label: &str,
192 property: &str,
193) -> Result<SourceType, SimilarToError> {
194 if let Some(vec_config) = schema.vector_index_for_property(label, property) {
196 return Ok(SourceType::Vector {
197 metric: vec_config.metric.clone(),
198 has_embedding_config: vec_config.embedding_config.is_some(),
199 });
200 }
201
202 if schema
204 .fulltext_index_for_property(label, property)
205 .is_some()
206 {
207 return Ok(SourceType::Fts);
208 }
209
210 Err(SimilarToError::NoIndex {
211 label: label.to_string(),
212 property: property.to_string(),
213 })
214}
215
216pub fn cosine_similarity(a: &[f32], b: &[f32]) -> Result<f32, SimilarToError> {
218 if a.len() != b.len() {
219 return Err(SimilarToError::DimensionMismatch {
220 a: a.len(),
221 b: b.len(),
222 });
223 }
224
225 let mut dot = 0.0f64;
226 let mut mag1 = 0.0f64;
227 let mut mag2 = 0.0f64;
228 for (x, y) in a.iter().zip(b.iter()) {
229 let x = *x as f64;
230 let y = *y as f64;
231 dot += x * y;
232 mag1 += x * x;
233 mag2 += y * y;
234 }
235 let mag1 = mag1.sqrt();
236 let mag2 = mag2.sqrt();
237
238 if mag1 == 0.0 || mag2 == 0.0 {
239 return Ok(0.0);
240 }
241
242 let sim = (dot / (mag1 * mag2)) as f32;
244 Ok(sim.clamp(-1.0, 1.0))
245}
246
247pub fn score_vectors(a: &[f32], b: &[f32], metric: &DistanceMetric) -> Result<f32, SimilarToError> {
254 if a.len() != b.len() {
255 return Err(SimilarToError::DimensionMismatch {
256 a: a.len(),
257 b: b.len(),
258 });
259 }
260 let distance = metric.compute_distance(a, b);
261 match metric {
262 DistanceMetric::Cosine => cosine_similarity(a, b),
263 DistanceMetric::Dot => Ok(-distance),
266 _ => Ok(calculate_score(distance, metric)),
268 }
269}
270
/// Maps a BM25 score onto a bounded scale using `score / (score + fts_k)`.
///
/// * Non-positive and NaN scores map to `0.0` (treated as "no match").
/// * `+inf` maps to `1.0`, the limit of the formula, rather than the NaN that
///   `inf / inf` would produce.
///
/// With a non-negative `fts_k` the result lies in `[0.0, 1.0]`. A negative
/// `fts_k` is not guarded against here and can produce values outside that
/// range.
pub fn normalize_bm25(score: f32, fts_k: f32) -> f32 {
    // NaN fails every ordered comparison, so it has to be tested explicitly.
    if score.is_nan() || score <= 0.0 {
        return 0.0;
    }
    if score == f32::INFINITY {
        return 1.0;
    }
    score / (score + fts_k)
}
280
281pub fn eval_similar_to_pure(v1: &Value, v2: &Value) -> Result<Value> {
289 let has_list = matches!(v1, Value::List(_)) || matches!(v2, Value::List(_));
291 let f64_vecs = has_list
292 .then(|| value_to_f64_vec(v1).ok().zip(value_to_f64_vec(v2).ok()))
293 .flatten();
294 if let Some((vec1, vec2)) = f64_vecs {
295 let sim = cosine_similarity_f64(&vec1, &vec2)?;
296 return Ok(Value::Float(sim));
297 }
298 let vec1 = value_to_f32_vec(v1)?;
300 let vec2 = value_to_f32_vec(v2)?;
301 let sim = cosine_similarity(&vec1, &vec2)?;
302 Ok(Value::Float(sim as f64))
303}
304
305fn cosine_similarity_f64(a: &[f64], b: &[f64]) -> Result<f64, SimilarToError> {
307 if a.len() != b.len() {
308 return Err(SimilarToError::DimensionMismatch {
309 a: a.len(),
310 b: b.len(),
311 });
312 }
313 let mut dot = 0.0f64;
314 let mut mag1 = 0.0f64;
315 let mut mag2 = 0.0f64;
316 for (x, y) in a.iter().zip(b.iter()) {
317 dot += x * y;
318 mag1 += x * x;
319 mag2 += y * y;
320 }
321 let mag1 = mag1.sqrt();
322 let mag2 = mag2.sqrt();
323 if mag1 == 0.0 || mag2 == 0.0 {
324 return Ok(0.0);
325 }
326 Ok((dot / (mag1 * mag2)).clamp(-1.0, 1.0))
327}
328
329fn value_to_f64_vec(v: &Value) -> Result<Vec<f64>, SimilarToError> {
331 match v {
332 Value::Vector(vec) => Ok(vec.iter().map(|&x| x as f64).collect()),
333 Value::List(list) => list
334 .iter()
335 .map(|v| {
336 v.as_f64().ok_or_else(|| SimilarToError::InvalidOption {
337 message: "vector element must be a number".to_string(),
338 })
339 })
340 .collect(),
341 _ => Err(SimilarToError::InvalidVectorValue {
342 actual: format!("{v:?}"),
343 }),
344 }
345}
346
347pub fn value_to_f32_vec(v: &Value) -> Result<Vec<f32>, SimilarToError> {
349 match v {
350 Value::Vector(vec) => Ok(vec.clone()),
351 Value::List(list) => list
352 .iter()
353 .map(|v| {
354 v.as_f64()
355 .map(|f| f as f32)
356 .ok_or_else(|| SimilarToError::InvalidOption {
357 message: "vector element must be a number".to_string(),
358 })
359 })
360 .collect(),
361 _ => Err(SimilarToError::InvalidVectorValue {
362 actual: format!("{v:?}"),
363 }),
364 }
365}
366
367pub fn validate_options(opts: &SimilarToOptions, num_sources: usize) -> Result<(), SimilarToError> {
369 if let Some(ref weights) = opts.weights {
370 if weights.len() != num_sources {
371 return Err(SimilarToError::WeightsLengthMismatch {
372 weights_len: weights.len(),
373 sources_len: num_sources,
374 });
375 }
376 let sum: f32 = weights.iter().sum();
377 if (sum - 1.0).abs() > 0.01 {
378 return Err(SimilarToError::WeightsNotNormalized { sum });
379 }
380 }
381 Ok(())
382}
383
384pub fn validate_pair(
389 source_type: &SourceType,
390 query_is_vector: bool,
391 query_is_string: bool,
392 source_index: usize,
393) -> Result<(), SimilarToError> {
394 match source_type {
395 SourceType::Fts if query_is_vector => Err(SimilarToError::TypeMismatch { source_index }),
396 SourceType::Vector {
397 has_embedding_config: false,
398 ..
399 } if query_is_string => Err(SimilarToError::NoEmbeddingConfig { source_index }),
400 _ => Ok(()),
401 }
402}
403
404pub fn fuse_scores(scores: &[f32], opts: &SimilarToOptions) -> Result<f32, SimilarToError> {
406 if scores.len() == 1 {
407 return Ok(scores[0]);
408 }
409
410 match opts.method {
411 FusionMethod::Weighted => {
412 let weights = opts
413 .weights
414 .as_ref()
415 .ok_or(SimilarToError::WeightsRequired)?;
416 Ok(fusion::fuse_weighted_multi(scores, weights))
417 }
418 FusionMethod::Rrf => {
419 let (score, _) = fusion::fuse_rrf_point(scores);
423 Ok(score)
424 }
425 }
426}
427
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Tolerance used by most floating-point assertions in this module.
    const EPS: f32 = 1e-6;

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected} (tolerance {tol}), got {actual}"
        );
    }

    /// Wraps `(key, value)` pairs in a `Value::Map` options argument.
    fn options(entries: Vec<(&str, Value)>) -> Value {
        let map: HashMap<String, Value> = entries
            .into_iter()
            .map(|(key, value)| (key.to_string(), value))
            .collect();
        Value::Map(map)
    }

    /// Returns the payload of a `Value::Float`, panicking on any other variant.
    fn expect_float(value: Value) -> f64 {
        match value {
            Value::Float(f) => f,
            _ => panic!("Expected Float"),
        }
    }

    /// Builds a cosine-metric vector source with the given embedding flag.
    fn cosine_source(has_embedding_config: bool) -> SourceType {
        SourceType::Vector {
            metric: DistanceMetric::Cosine,
            has_embedding_config,
        }
    }

    /// Builds default options carrying only the given weights.
    fn options_with_weights(weights: Vec<f32>) -> SimilarToOptions {
        SimilarToOptions {
            weights: Some(weights),
            ..Default::default()
        }
    }

    // ---- parse_options -------------------------------------------------

    #[test]
    fn test_parse_options_default() {
        let opts = parse_options(&Value::Null).unwrap();
        assert_eq!(opts.method, FusionMethod::Rrf);
        assert_eq!(opts.k, 60);
        assert_close(opts.fts_k, 1.0, EPS);
        assert!(opts.weights.is_none());
    }

    #[test]
    fn test_parse_options_weighted() {
        let raw = options(vec![
            ("method", Value::String("weighted".to_string())),
            (
                "weights",
                Value::List(vec![Value::Float(0.7), Value::Float(0.3)]),
            ),
        ]);
        let opts = parse_options(&raw).unwrap();
        assert_eq!(opts.method, FusionMethod::Weighted);
        let weights = opts.weights.unwrap();
        assert_close(weights[0], 0.7, EPS);
        assert_close(weights[1], 0.3, EPS);
    }

    #[test]
    fn test_parse_options_rrf_with_k() {
        let raw = options(vec![
            ("method", Value::String("rrf".to_string())),
            ("k", Value::Int(30)),
        ]);
        let opts = parse_options(&raw).unwrap();
        assert_eq!(opts.method, FusionMethod::Rrf);
        assert_eq!(opts.k, 30);
    }

    #[test]
    fn test_parse_options_fts_k() {
        let raw = options(vec![("fts_k", Value::Float(2.0))]);
        let opts = parse_options(&raw).unwrap();
        assert_close(opts.fts_k, 2.0, EPS);
    }

    #[test]
    fn test_parse_options_invalid_method() {
        let raw = options(vec![("method", Value::String("invalid".to_string()))]);
        assert!(parse_options(&raw).is_err());
    }

    // ---- cosine_similarity ---------------------------------------------

    #[test]
    fn test_cosine_similarity_identical() {
        let v = [1.0f32, 0.0, 0.0];
        assert_close(cosine_similarity(&v, &v).unwrap(), 1.0, EPS);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let sim = cosine_similarity(&[1.0f32, 0.0], &[0.0f32, 1.0]).unwrap();
        assert_close(sim, 0.0, EPS);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let sim = cosine_similarity(&[1.0f32, 0.0], &[-1.0f32, 0.0]).unwrap();
        assert_close(sim, -1.0, EPS);
    }

    #[test]
    fn test_cosine_similarity_dimension_mismatch() {
        let result = cosine_similarity(&[1.0f32, 0.0], &[1.0f32, 0.0, 0.0]);
        assert!(result.is_err());
    }

    // ---- normalize_bm25 ------------------------------------------------

    #[test]
    fn test_normalize_bm25() {
        assert_close(normalize_bm25(0.0, 1.0), 0.0, EPS);
        assert_close(normalize_bm25(1.0, 1.0), 0.5, EPS);
        assert_close(normalize_bm25(9.0, 1.0), 0.9, EPS);
        assert_close(normalize_bm25(99.0, 1.0), 0.99, 1e-4);
    }

    #[test]
    fn test_normalize_bm25_custom_k() {
        assert_close(normalize_bm25(2.0, 2.0), 0.5, EPS);
    }

    // ---- eval_similar_to_pure ------------------------------------------

    #[test]
    fn test_eval_similar_to_pure() {
        let v1 = Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        let v2 = Value::List(vec![Value::Float(1.0), Value::Float(0.0)]);
        let similarity = expect_float(eval_similar_to_pure(&v1, &v2).unwrap());
        assert!((similarity - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_eval_similar_to_pure_vector_type() {
        let v1 = Value::Vector(vec![1.0, 0.0]);
        let v2 = Value::Vector(vec![0.0, 1.0]);
        let similarity = expect_float(eval_similar_to_pure(&v1, &v2).unwrap());
        assert!((similarity - 0.0).abs() < 1e-6);
    }

    // ---- validate_options ----------------------------------------------

    #[test]
    fn test_validate_options_weights_length() {
        let opts = options_with_weights(vec![0.5]);
        assert!(validate_options(&opts, 2).is_err());
    }

    #[test]
    fn test_validate_options_weights_sum() {
        let opts = options_with_weights(vec![0.5, 0.3]);
        assert!(validate_options(&opts, 2).is_err());
    }

    #[test]
    fn test_validate_options_ok() {
        let opts = options_with_weights(vec![0.7, 0.3]);
        assert!(validate_options(&opts, 2).is_ok());
    }

    // ---- validate_pair -------------------------------------------------

    #[test]
    fn test_validate_pair_fts_vector_query() {
        assert!(validate_pair(&SourceType::Fts, true, false, 0).is_err());
    }

    #[test]
    fn test_validate_pair_vector_string_no_embed() {
        assert!(validate_pair(&cosine_source(false), false, true, 0).is_err());
    }

    #[test]
    fn test_validate_pair_vector_string_with_embed() {
        assert!(validate_pair(&cosine_source(true), false, true, 0).is_ok());
    }

    #[test]
    fn test_validate_pair_vector_vector() {
        assert!(validate_pair(&cosine_source(false), true, false, 0).is_ok());
    }

    #[test]
    fn test_validate_pair_fts_string() {
        assert!(validate_pair(&SourceType::Fts, false, true, 0).is_ok());
    }

    // ---- fuse_scores ---------------------------------------------------

    #[test]
    fn test_fuse_scores_single() {
        let opts = SimilarToOptions::default();
        assert_close(fuse_scores(&[0.8], &opts).unwrap(), 0.8, EPS);
    }

    #[test]
    fn test_fuse_scores_weighted() {
        let opts = SimilarToOptions {
            method: FusionMethod::Weighted,
            weights: Some(vec![0.7, 0.3]),
            ..Default::default()
        };
        assert_close(fuse_scores(&[0.8, 0.6], &opts).unwrap(), 0.74, EPS);
    }

    #[test]
    fn test_fuse_scores_rrf_fallback() {
        let opts = SimilarToOptions::default();
        assert_close(fuse_scores(&[0.8, 0.6], &opts).unwrap(), 0.7, EPS);
    }

    // ---- score_vectors -------------------------------------------------

    #[test]
    fn test_score_vectors_cosine_identical() {
        let v = [1.0f32, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::Cosine).unwrap();
        assert_close(score, 1.0, EPS);
    }

    #[test]
    fn test_score_vectors_cosine_matches_raw() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.8f32, 0.6, 0.0];
        let raw = cosine_similarity(&a, &b).unwrap();
        let scored = score_vectors(&a, &b, &DistanceMetric::Cosine).unwrap();
        assert_close(raw, scored, EPS);
    }

    #[test]
    fn test_score_vectors_l2() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.0f32, 1.0, 0.0];
        let score = score_vectors(&a, &b, &DistanceMetric::L2).unwrap();
        assert_close(score, 1.0 / 3.0, 1e-5);
    }

    #[test]
    fn test_score_vectors_l2_identical() {
        let v = [1.0f32, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::L2).unwrap();
        assert_close(score, 1.0, EPS);
    }

    #[test]
    fn test_score_vectors_dot() {
        let a = [1.0f32, 0.0, 0.0];
        let b = [0.8f32, 0.6, 0.0];
        let score = score_vectors(&a, &b, &DistanceMetric::Dot).unwrap();
        assert_close(score, 0.8, EPS);
    }

    #[test]
    fn test_score_vectors_dot_identical() {
        let v = [1.0f32, 0.0, 0.0];
        let score = score_vectors(&v, &v, &DistanceMetric::Dot).unwrap();
        assert_close(score, 1.0, EPS);
    }

    #[test]
    fn test_score_vectors_dimension_mismatch() {
        let result = score_vectors(&[1.0f32, 0.0], &[1.0f32, 0.0, 0.0], &DistanceMetric::Cosine);
        assert!(result.is_err());
    }
}