1use arrow_array::builder::{FixedSizeListBuilder, Float32Builder, StringBuilder, UInt64Builder};
17use arrow_array::{Array, ArrayRef, FixedSizeListArray, Float32Array, Int64Array, RecordBatch};
18use arrow_schema::{DataType, Field, Schema, SchemaRef};
19use datafusion::common::Result as DFResult;
20use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
21use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
22use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
23use futures::Stream;
24use std::any::Any;
25use std::collections::HashMap;
26use std::fmt;
27use std::pin::Pin;
28use std::sync::Arc;
29use std::task::{Context, Poll};
30use uni_common::Value;
31use uni_common::core::id::Vid;
32use uni_common::core::schema::{DistanceMetric, PropertyMeta};
33use uni_cypher::ast::Expr;
34use uni_plugin::traits::index::{IndexHandle, IndexKind};
35
36use crate::query::df_graph::GraphExecutionContext;
37use crate::query::df_graph::common::{
38 arrow_err, calculate_score, compute_plan_properties, evaluate_simple_expr, labels_data_type,
39};
40use crate::query::df_graph::scan::{property_field, resolve_property_type};
41
42#[derive(Clone)]
60pub(crate) enum VectorSource {
61 Native,
63 Plugin {
65 #[allow(dead_code)]
68 kind: IndexKind,
69 handle: Arc<dyn IndexHandle>,
71 },
72}
73
74impl fmt::Debug for VectorSource {
75 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76 match self {
77 Self::Native => f.write_str("Native"),
78 Self::Plugin { kind, .. } => f.debug_struct("Plugin").field("kind", kind).finish(),
79 }
80 }
81}
82
83pub struct GraphVectorKnnExec {
88 graph_ctx: Arc<GraphExecutionContext>,
90
91 label_id: u16,
93
94 label_name: String,
96
97 variable: String,
99
100 property: String,
102
103 query_expr: Expr,
105
106 k: usize,
108
109 threshold: Option<f32>,
111
112 params: HashMap<String, Value>,
114
115 target_properties: Vec<String>,
117
118 schema: SchemaRef,
120
121 properties: Arc<PlanProperties>,
123
124 source: VectorSource,
128
129 metrics: ExecutionPlanMetricsSet,
131}
132
133impl fmt::Debug for GraphVectorKnnExec {
134 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135 f.debug_struct("GraphVectorKnnExec")
136 .field("label_id", &self.label_id)
137 .field("variable", &self.variable)
138 .field("property", &self.property)
139 .field("k", &self.k)
140 .field("threshold", &self.threshold)
141 .finish()
142 }
143}
144
145impl GraphVectorKnnExec {
146 #[expect(clippy::too_many_arguments)]
160 pub fn new(
161 graph_ctx: Arc<GraphExecutionContext>,
162 label_id: u16,
163 label_name: impl Into<String>,
164 variable: impl Into<String>,
165 property: impl Into<String>,
166 query_expr: Expr,
167 k: usize,
168 threshold: Option<f32>,
169 params: HashMap<String, Value>,
170 target_properties: Vec<String>,
171 ) -> Self {
172 let variable = variable.into();
173 let property = property.into();
174 let label_name = label_name.into();
175
176 let uni_schema = graph_ctx.storage().schema_manager().schema();
178 let label_props = uni_schema.properties.get(label_name.as_str());
179
180 let schema = Self::build_schema(&variable, &target_properties, label_props);
181 let properties = compute_plan_properties(schema.clone());
182
183 Self {
184 graph_ctx,
185 label_id,
186 label_name,
187 variable,
188 property,
189 query_expr,
190 k,
191 threshold,
192 params,
193 target_properties,
194 schema,
195 properties,
196 source: VectorSource::Native,
197 metrics: ExecutionPlanMetricsSet::new(),
198 }
199 }
200
201 #[expect(clippy::too_many_arguments)]
208 pub fn with_plugin_source(
209 graph_ctx: Arc<GraphExecutionContext>,
210 label_id: u16,
211 label_name: impl Into<String>,
212 variable: impl Into<String>,
213 property: impl Into<String>,
214 query_expr: Expr,
215 k: usize,
216 threshold: Option<f32>,
217 params: HashMap<String, Value>,
218 target_properties: Vec<String>,
219 kind: IndexKind,
220 handle: Arc<dyn IndexHandle>,
221 ) -> Self {
222 let mut exec = Self::new(
223 graph_ctx,
224 label_id,
225 label_name,
226 variable,
227 property,
228 query_expr,
229 k,
230 threshold,
231 params,
232 target_properties,
233 );
234 exec.source = VectorSource::Plugin { kind, handle };
235 exec
236 }
237
238 fn build_schema(
246 variable: &str,
247 target_properties: &[String],
248 label_props: Option<&HashMap<String, PropertyMeta>>,
249 ) -> SchemaRef {
250 let mut fields = vec![
251 Field::new(format!("{}._vid", variable), DataType::UInt64, false),
252 Field::new(variable, DataType::Utf8, false),
253 Field::new(format!("{}._labels", variable), labels_data_type(), true),
254 Field::new(format!("{}._score", variable), DataType::Float32, true),
255 ];
256
257 for prop_name in target_properties {
259 let col_name = format!("{}.{}", variable, prop_name);
260 let arrow_type = resolve_property_type(prop_name, label_props);
261 let uni_type = label_props
262 .and_then(|p| p.get(prop_name))
263 .map(|m| &m.r#type);
264 fields.push(property_field(&col_name, arrow_type, uni_type));
265 }
266
267 Arc::new(Schema::new(fields))
268 }
269
270 fn evaluate_query_vector(&self) -> DFResult<Vec<f32>> {
272 let value = evaluate_simple_expr(&self.query_expr, &self.params, &HashMap::new())?;
273
274 match value {
275 Value::Vector(vec) => Ok(vec),
276 Value::List(arr) => {
277 let mut vec = Vec::with_capacity(arr.len());
278 for v in arr {
279 if let Some(f) = v.as_f64() {
280 vec.push(f as f32);
281 } else {
282 return Err(datafusion::error::DataFusionError::Execution(
283 "Query vector must contain numbers".to_string(),
284 ));
285 }
286 }
287 Ok(vec)
288 }
289 _ => Err(datafusion::error::DataFusionError::Execution(
290 "Query vector must be a list or vector".to_string(),
291 )),
292 }
293 }
294
295 fn evaluate_query_multivector(&self) -> DFResult<Vec<Vec<f32>>> {
297 let value = evaluate_simple_expr(&self.query_expr, &self.params, &HashMap::new())?;
298 let Value::List(tokens) = value else {
299 return Err(datafusion::error::DataFusionError::Execution(
300 "Multi-vector query must be a list of vectors".to_string(),
301 ));
302 };
303 tokens
304 .into_iter()
305 .map(|tok| match tok {
306 Value::Vector(v) => Ok(v),
307 Value::List(inner) => inner
308 .iter()
309 .map(|x| {
310 x.as_f64().map(|f| f as f32).ok_or_else(|| {
311 datafusion::error::DataFusionError::Execution(
312 "Multi-vector query token must contain numbers".to_string(),
313 )
314 })
315 })
316 .collect(),
317 _ => Err(datafusion::error::DataFusionError::Execution(
318 "Multi-vector query must be a list of vectors".to_string(),
319 )),
320 })
321 .collect()
322 }
323
324 fn is_multivector_property(&self) -> bool {
326 let uni_schema = self.graph_ctx.storage().schema_manager().schema();
327 let label_props = uni_schema.properties.get(self.label_name.as_str());
328 matches!(
329 resolve_property_type(&self.property, label_props),
330 DataType::List(ref inner)
331 if matches!(inner.data_type(), DataType::FixedSizeList(_, _))
332 )
333 }
334}
335
336impl DisplayAs for GraphVectorKnnExec {
337 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
338 write!(
339 f,
340 "GraphVectorKnnExec: label={}, property={}, k={}, variable={}",
341 self.label_name, self.property, self.k, self.variable
342 )
343 }
344}
345
346impl ExecutionPlan for GraphVectorKnnExec {
347 fn name(&self) -> &str {
348 "GraphVectorKnnExec"
349 }
350
351 fn as_any(&self) -> &dyn Any {
352 self
353 }
354
355 fn schema(&self) -> SchemaRef {
356 self.schema.clone()
357 }
358
359 fn properties(&self) -> &Arc<PlanProperties> {
360 &self.properties
361 }
362
363 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
364 vec![]
365 }
366
367 fn with_new_children(
368 self: Arc<Self>,
369 children: Vec<Arc<dyn ExecutionPlan>>,
370 ) -> DFResult<Arc<dyn ExecutionPlan>> {
371 if !children.is_empty() {
372 return Err(datafusion::error::DataFusionError::Internal(
373 "GraphVectorKnnExec has no children".to_string(),
374 ));
375 }
376 Ok(self)
377 }
378
379 fn execute(
380 &self,
381 partition: usize,
382 _context: Arc<TaskContext>,
383 ) -> DFResult<SendableRecordBatchStream> {
384 let metrics = BaselineMetrics::new(&self.metrics, partition);
385
386 let (query_vector, multivec_query) = if self.is_multivector_property() {
390 (Vec::new(), Some(self.evaluate_query_multivector()?))
391 } else {
392 (self.evaluate_query_vector()?, None)
393 };
394
395 Ok(Box::pin(VectorKnnStream::new(
396 self.graph_ctx.clone(),
397 self.label_name.clone(),
398 self.variable.clone(),
399 self.property.clone(),
400 query_vector,
401 multivec_query,
402 self.k,
403 self.threshold,
404 self.target_properties.clone(),
405 self.schema.clone(),
406 self.source.clone(),
407 metrics,
408 )))
409 }
410
411 fn metrics(&self) -> Option<MetricsSet> {
412 Some(self.metrics.clone_inner())
413 }
414}
415
416enum VectorKnnState {
418 Init,
420 Executing(Pin<Box<dyn std::future::Future<Output = DFResult<Option<RecordBatch>>> + Send>>),
422 Done,
424}
425
426struct VectorKnnStream {
428 graph_ctx: Arc<GraphExecutionContext>,
430
431 label_name: String,
433
434 variable: String,
436
437 property: String,
439
440 query_vector: Vec<f32>,
442
443 multivec_query: Option<Vec<Vec<f32>>>,
446
447 k: usize,
449
450 threshold: Option<f32>,
452
453 target_properties: Vec<String>,
455
456 schema: SchemaRef,
458
459 source: VectorSource,
461
462 state: VectorKnnState,
464
465 metrics: BaselineMetrics,
467}
468
469impl VectorKnnStream {
470 #[expect(clippy::too_many_arguments)]
471 fn new(
472 graph_ctx: Arc<GraphExecutionContext>,
473 label_name: String,
474 variable: String,
475 property: String,
476 query_vector: Vec<f32>,
477 multivec_query: Option<Vec<Vec<f32>>>,
478 k: usize,
479 threshold: Option<f32>,
480 target_properties: Vec<String>,
481 schema: SchemaRef,
482 source: VectorSource,
483 metrics: BaselineMetrics,
484 ) -> Self {
485 Self {
486 graph_ctx,
487 label_name,
488 variable,
489 property,
490 query_vector,
491 multivec_query,
492 k,
493 threshold,
494 target_properties,
495 schema,
496 source,
497 state: VectorKnnState::Init,
498 metrics,
499 }
500 }
501}
502
503impl Stream for VectorKnnStream {
504 type Item = DFResult<RecordBatch>;
505
506 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
507 let metrics = self.metrics.clone();
508 let _timer = metrics.elapsed_compute().timer();
509 loop {
510 let state = std::mem::replace(&mut self.state, VectorKnnState::Done);
511
512 match state {
513 VectorKnnState::Init => {
514 let graph_ctx = self.graph_ctx.clone();
516 let label_name = self.label_name.clone();
517 let variable = self.variable.clone();
518 let property = self.property.clone();
519 let query_vector = self.query_vector.clone();
520 let multivec_query = self.multivec_query.clone();
521 let k = self.k;
522 let threshold = self.threshold;
523 let target_properties = self.target_properties.clone();
524 let schema = self.schema.clone();
525 let source = self.source.clone();
526
527 let fut = async move {
528 graph_ctx.check_timeout().map_err(|e| {
530 datafusion::error::DataFusionError::Execution(e.to_string())
531 })?;
532
533 execute_vector_search(
534 &graph_ctx,
535 &label_name,
536 &variable,
537 &property,
538 &query_vector,
539 multivec_query.as_deref(),
540 k,
541 threshold,
542 &target_properties,
543 &schema,
544 &source,
545 )
546 .await
547 };
548
549 self.state = VectorKnnState::Executing(Box::pin(fut));
550 }
552 VectorKnnState::Executing(mut fut) => match fut.as_mut().poll(cx) {
553 Poll::Ready(Ok(batch)) => {
554 self.state = VectorKnnState::Done;
555 self.metrics
556 .record_output(batch.as_ref().map(|b| b.num_rows()).unwrap_or(0));
557 return Poll::Ready(batch.map(Ok));
558 }
559 Poll::Ready(Err(e)) => {
560 self.state = VectorKnnState::Done;
561 return Poll::Ready(Some(Err(e)));
562 }
563 Poll::Pending => {
564 self.state = VectorKnnState::Executing(fut);
565 return Poll::Pending;
566 }
567 },
568 VectorKnnState::Done => {
569 return Poll::Ready(None);
570 }
571 }
572 }
573 }
574}
575
576impl RecordBatchStream for VectorKnnStream {
577 fn schema(&self) -> SchemaRef {
578 self.schema.clone()
579 }
580}
581
582#[expect(clippy::too_many_arguments)]
584async fn execute_vector_search(
585 graph_ctx: &GraphExecutionContext,
586 label_name: &str,
587 variable: &str,
588 property: &str,
589 query_vector: &[f32],
590 multivec_query: Option<&[Vec<f32>]>,
591 k: usize,
592 threshold: Option<f32>,
593 target_properties: &[String],
594 schema: &SchemaRef,
595 source: &VectorSource,
596) -> DFResult<Option<RecordBatch>> {
597 let storage = graph_ctx.storage();
598
599 let results = retrieve_vid_scores(
601 graph_ctx,
602 label_name,
603 property,
604 query_vector,
605 multivec_query,
606 k,
607 source,
608 )
609 .await?;
610
611 let default_metric = if multivec_query.is_some() {
615 DistanceMetric::Cosine
616 } else {
617 DistanceMetric::L2
618 };
619 let metric = storage
620 .schema_manager()
621 .schema()
622 .vector_index_for_property(label_name, property)
623 .map(|cfg| cfg.metric.clone())
624 .unwrap_or(default_metric);
625
626 let mut vids = Vec::new();
628 let mut scores = Vec::new();
629
630 for (vid, value) in results {
631 let similarity = if multivec_query.is_some() {
634 value
635 } else {
636 calculate_score(value, &metric)
637 };
638
639 if let Some(thresh) = threshold
640 && similarity < thresh
641 {
642 continue;
643 }
644
645 vids.push(vid);
646 scores.push(similarity);
647 }
648
649 if vids.is_empty() {
650 return Ok(Some(RecordBatch::new_empty(schema.clone())));
651 }
652
653 let batch = build_result_batch(
655 &vids,
656 &scores,
657 variable,
658 target_properties,
659 label_name,
660 graph_ctx,
661 schema,
662 )
663 .await?;
664 Ok(Some(batch))
665}
666
667async fn retrieve_vid_scores(
679 graph_ctx: &GraphExecutionContext,
680 label_name: &str,
681 property: &str,
682 query_vector: &[f32],
683 multivec_query: Option<&[Vec<f32>]>,
684 k: usize,
685 source: &VectorSource,
686) -> DFResult<Vec<(Vid, f32)>> {
687 match source {
688 VectorSource::Native => {
689 let storage = graph_ctx.storage();
690 let query_ctx = graph_ctx.query_context();
691 if let Some(mv) = multivec_query {
698 let property_manager = graph_ctx.property_manager();
699 let metric = storage
700 .schema_manager()
701 .schema()
702 .vector_index_for_property(label_name, property)
703 .map(|cfg| cfg.metric.clone())
704 .unwrap_or(DistanceMetric::Cosine);
705 let retrieval_k = k
706 .saturating_mul(
707 crate::query::df_graph::search_procedures::MULTIVECTOR_OVER_FETCH,
708 )
709 .max(k);
710 let (ranked, _props) =
711 crate::query::df_graph::search_procedures::multivector_rerank(
712 storage,
713 property_manager,
714 &query_ctx,
715 label_name,
716 property,
717 mv,
718 k,
719 retrieval_k,
720 None,
721 uni_store::VectorQueryOpts::default(),
722 &metric,
723 )
724 .await?;
725 return Ok(ranked);
726 }
727 storage
728 .vector_search(
729 label_name,
730 property,
731 query_vector,
732 k,
733 None,
734 uni_store::VectorQueryOpts::default(),
735 Some(&query_ctx),
736 )
737 .await
738 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))
739 }
740 VectorSource::Plugin { handle, .. } => {
741 let dim = i32::try_from(query_vector.len()).map_err(|_| {
744 datafusion::error::DataFusionError::Execution(
745 "query vector exceeds i32::MAX dimensions".to_string(),
746 )
747 })?;
748 let item_field = Arc::new(Field::new("item", DataType::Float32, true));
749 let mut fsl_builder =
750 FixedSizeListBuilder::new(Float32Builder::with_capacity(query_vector.len()), dim)
751 .with_field(Arc::clone(&item_field));
752 for &v in query_vector {
753 fsl_builder.values().append_value(v);
754 }
755 fsl_builder.append(true);
756 let fsl: FixedSizeListArray = fsl_builder.finish();
757
758 let query_schema = Arc::new(Schema::new(vec![Field::new(
759 "vector",
760 DataType::FixedSizeList(item_field, dim),
761 false,
762 )]));
763 let query_batch =
764 RecordBatch::try_new(query_schema, vec![Arc::new(fsl)]).map_err(arrow_err)?;
765
766 let result = handle.probe(&query_batch, k).map_err(|e| {
767 datafusion::error::DataFusionError::Execution(format!(
768 "IndexHandle::probe failed: {e:?}"
769 ))
770 })?;
771
772 let vid_col = result
775 .column_by_name("vid")
776 .ok_or_else(|| {
777 datafusion::error::DataFusionError::Execution(
778 "IndexHandle::probe result missing `vid` column".to_string(),
779 )
780 })?
781 .as_any()
782 .downcast_ref::<Int64Array>()
783 .ok_or_else(|| {
784 datafusion::error::DataFusionError::Execution(
785 "IndexHandle::probe result `vid` column is not Int64".to_string(),
786 )
787 })?;
788 let dist_col = result
789 .column_by_name("distance")
790 .ok_or_else(|| {
791 datafusion::error::DataFusionError::Execution(
792 "IndexHandle::probe result missing `distance` column".to_string(),
793 )
794 })?
795 .as_any()
796 .downcast_ref::<Float32Array>()
797 .ok_or_else(|| {
798 datafusion::error::DataFusionError::Execution(
799 "IndexHandle::probe result `distance` column is not Float32".to_string(),
800 )
801 })?;
802
803 let mut pairs = Vec::with_capacity(result.num_rows());
804 for i in 0..result.num_rows() {
805 if vid_col.is_null(i) {
806 continue;
807 }
808 let vid_i64 = vid_col.value(i);
809 let dist = if dist_col.is_null(i) {
810 f32::INFINITY
811 } else {
812 dist_col.value(i)
813 };
814 pairs.push((Vid::from(vid_i64 as u64), dist));
815 }
816 Ok(pairs)
817 }
818 }
819}
820
821async fn build_result_batch(
823 vids: &[Vid],
824 scores: &[f32],
825 _variable: &str,
826 target_properties: &[String],
827 label_name: &str,
828 graph_ctx: &GraphExecutionContext,
829 schema: &SchemaRef,
830) -> DFResult<RecordBatch> {
831 let num_rows = vids.len();
832
833 let mut vid_builder = UInt64Builder::with_capacity(num_rows);
835 for vid in vids {
836 vid_builder.append_value(vid.as_u64());
837 }
838
839 let mut var_builder = StringBuilder::with_capacity(num_rows, num_rows * 20);
841 for vid in vids {
842 var_builder.append_value(vid.to_string());
843 }
844
845 let mut labels_builder = arrow_array::builder::ListBuilder::new(StringBuilder::new());
847 for _vid in vids {
848 labels_builder.values().append_value(label_name);
849 labels_builder.append(true);
850 }
851
852 let mut score_builder = Float32Builder::with_capacity(num_rows);
854 for &score in scores {
855 score_builder.append_value(score);
856 }
857
858 let mut columns: Vec<ArrayRef> = vec![
859 Arc::new(vid_builder.finish()),
860 Arc::new(var_builder.finish()),
861 Arc::new(labels_builder.finish()),
862 Arc::new(score_builder.finish()),
863 ];
864
865 if !target_properties.is_empty() {
867 let property_manager = graph_ctx.property_manager();
868 let query_ctx = graph_ctx.query_context();
869
870 let props_map = property_manager
871 .get_batch_vertex_props_for_label(vids, label_name, Some(&query_ctx))
872 .await
873 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
874
875 let uni_schema = graph_ctx.storage().schema_manager().schema();
876 let label_props = uni_schema.properties.get(label_name);
877
878 for prop_name in target_properties {
879 let data_type = resolve_property_type(prop_name, label_props);
880 let column = crate::query::df_graph::scan::build_property_column_static(
881 vids, &props_map, prop_name, &data_type,
882 )?;
883 columns.push(column);
884 }
885 }
886
887 RecordBatch::try_new(schema.clone(), columns).map_err(arrow_err)
888}
889
#[cfg(test)]
mod tests {
    use super::*;
    use uni_cypher::ast::CypherLiteral;

    /// With no extra properties the schema holds exactly the four fixed columns, in order.
    #[test]
    fn test_build_schema() {
        let schema = GraphVectorKnnExec::build_schema("n", &[], None);

        assert_eq!(schema.fields().len(), 4);
        assert_eq!(schema.field(0).name(), "n._vid");
        assert_eq!(schema.field(1).name(), "n");
        assert_eq!(schema.field(2).name(), "n._labels");
        assert_eq!(schema.field(3).name(), "n._score");
    }

    /// A literal list expression evaluates to a `Value::List` of the same length.
    #[test]
    fn test_evaluate_literal_list() {
        let expr = Expr::List(vec![
            Expr::Literal(CypherLiteral::Float(0.1)),
            Expr::Literal(CypherLiteral::Float(0.2)),
            Expr::Literal(CypherLiteral::Float(0.3)),
        ]);

        let result = evaluate_simple_expr(&expr, &HashMap::new(), &HashMap::new()).unwrap();
        match result {
            Value::List(arr) => {
                assert_eq!(arr.len(), 3);
            }
            _ => panic!("Expected list"),
        }
    }

    /// A parameter expression resolves to the value bound in the params map.
    #[test]
    fn test_evaluate_parameter() {
        let expr = Expr::Parameter("query".to_string());
        let mut params = HashMap::new();
        params.insert(
            "query".to_string(),
            Value::List(vec![Value::Float(0.1), Value::Float(0.2)]),
        );

        let result = evaluate_simple_expr(&expr, &params, &HashMap::new()).unwrap();
        match result {
            Value::List(arr) => {
                assert_eq!(arr.len(), 2);
            }
            _ => panic!("Expected list"),
        }
    }

    /// Each target property adds a `{variable}.{property}` column to the schema.
    #[test]
    fn test_build_schema_with_extra_properties() {
        let extra_props = vec!["name".to_string(), "embedding".to_string()];
        let schema = GraphVectorKnnExec::build_schema("doc", &extra_props, None);

        assert!(schema.field_with_name("doc._vid").is_ok());
        assert!(schema.field_with_name("doc").is_ok());
        assert!(schema.field_with_name("doc._score").is_ok());
        assert!(
            schema.field_with_name("doc.name").is_ok(),
            "Extra property 'name' should be in schema"
        );
        assert!(
            schema.field_with_name("doc.embedding").is_ok(),
            "Extra property 'embedding' should be in schema"
        );
    }

    /// A variable expression resolves to the value bound in the variables map.
    #[test]
    fn test_evaluate_variable() {
        let expr = Expr::Variable("x".to_string());
        let mut variables = HashMap::new();
        variables.insert(
            "x".to_string(),
            Value::List(vec![Value::Float(0.5), Value::Float(0.6)]),
        );

        let result = evaluate_simple_expr(&expr, &HashMap::new(), &variables).unwrap();
        match result {
            Value::List(arr) => {
                assert_eq!(arr.len(), 2);
            }
            _ => panic!("Expected list, got {:?}", result),
        }
    }
}