1use anyhow::{anyhow, Result};
2
3use crate::data::datatable::DataValue;
4use crate::sql::functions::{ArgCount, FunctionCategory, FunctionSignature, SqlFunction};
5
6fn get_vector(value: &DataValue) -> Result<Vec<f64>> {
8 match value {
9 DataValue::Vector(v) => Ok(v.clone()),
10 DataValue::String(s) => parse_vector_string(s),
11 _ => Err(anyhow!("Expected vector, got {:?}", value.data_type())),
12 }
13}
14
15fn parse_vector_string(s: &str) -> Result<Vec<f64>> {
17 let trimmed = s.trim();
18
19 let content = if trimmed.starts_with('[') && trimmed.ends_with(']') {
21 &trimmed[1..trimmed.len() - 1]
22 } else {
23 trimmed
24 };
25
26 let components: Result<Vec<f64>> = if content.contains(',') {
28 content
29 .split(',')
30 .map(|s| {
31 s.trim()
32 .parse::<f64>()
33 .map_err(|e| anyhow!("Failed to parse vector component '{}': {}", s.trim(), e))
34 })
35 .collect()
36 } else {
37 content
38 .split_whitespace()
39 .map(|s| {
40 s.parse::<f64>()
41 .map_err(|e| anyhow!("Failed to parse vector component '{}': {}", s, e))
42 })
43 .collect()
44 };
45
46 components
47}
48
49pub struct VecFunction;
51
52impl SqlFunction for VecFunction {
53 fn signature(&self) -> FunctionSignature {
54 FunctionSignature {
55 name: "VEC",
56 category: FunctionCategory::Mathematical,
57 arg_count: ArgCount::Variadic,
58 description: "Construct a vector from numeric components",
59 returns: "Vector",
60 examples: vec!["SELECT VEC(1, 2, 3)", "SELECT VEC(10, 20)"],
61 }
62 }
63
64 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
65 if args.is_empty() {
66 return Err(anyhow!("VEC() requires at least one argument"));
67 }
68
69 let components: Result<Vec<f64>> = args
70 .iter()
71 .map(|arg| match arg {
72 DataValue::Integer(i) => Ok(*i as f64),
73 DataValue::Float(f) => Ok(*f),
74 DataValue::Null => Err(anyhow!("Cannot create vector with NULL component")),
75 _ => Err(anyhow!(
76 "VEC() requires numeric arguments, got {:?}",
77 arg.data_type()
78 )),
79 })
80 .collect();
81
82 Ok(DataValue::Vector(components?))
83 }
84}
85
86pub struct VecAddFunction;
88
89impl SqlFunction for VecAddFunction {
90 fn signature(&self) -> FunctionSignature {
91 FunctionSignature {
92 name: "VEC_ADD",
93 category: FunctionCategory::Mathematical,
94 arg_count: ArgCount::Fixed(2),
95 description: "Add two vectors element-wise",
96 returns: "Vector",
97 examples: vec![
98 "SELECT VEC_ADD(VEC(1,2,3), VEC(4,5,6))",
99 "SELECT VEC_ADD(position, velocity)",
100 ],
101 }
102 }
103
104 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
105 self.validate_args(args)?;
106
107 let v1 = get_vector(&args[0])?;
108 let v2 = get_vector(&args[1])?;
109
110 if v1.len() != v2.len() {
111 return Err(anyhow!(
112 "Vector dimension mismatch: {} != {}",
113 v1.len(),
114 v2.len()
115 ));
116 }
117
118 let result: Vec<f64> = v1.iter().zip(v2.iter()).map(|(a, b)| a + b).collect();
119 Ok(DataValue::Vector(result))
120 }
121}
122
123pub struct VecSubFunction;
125
126impl SqlFunction for VecSubFunction {
127 fn signature(&self) -> FunctionSignature {
128 FunctionSignature {
129 name: "VEC_SUB",
130 category: FunctionCategory::Mathematical,
131 arg_count: ArgCount::Fixed(2),
132 description: "Subtract two vectors element-wise (v1 - v2)",
133 returns: "Vector",
134 examples: vec![
135 "SELECT VEC_SUB(VEC(10,20,30), VEC(1,2,3))",
136 "SELECT VEC_SUB(position_end, position_start)",
137 ],
138 }
139 }
140
141 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
142 self.validate_args(args)?;
143
144 let v1 = get_vector(&args[0])?;
145 let v2 = get_vector(&args[1])?;
146
147 if v1.len() != v2.len() {
148 return Err(anyhow!(
149 "Vector dimension mismatch: {} != {}",
150 v1.len(),
151 v2.len()
152 ));
153 }
154
155 let result: Vec<f64> = v1.iter().zip(v2.iter()).map(|(a, b)| a - b).collect();
156 Ok(DataValue::Vector(result))
157 }
158}
159
160pub struct VecScaleFunction;
162
163impl SqlFunction for VecScaleFunction {
164 fn signature(&self) -> FunctionSignature {
165 FunctionSignature {
166 name: "VEC_SCALE",
167 category: FunctionCategory::Mathematical,
168 arg_count: ArgCount::Fixed(2),
169 description: "Multiply vector by scalar value",
170 returns: "Vector",
171 examples: vec![
172 "SELECT VEC_SCALE(VEC(1,2,3), 2.5)",
173 "SELECT VEC_SCALE(velocity, time)",
174 ],
175 }
176 }
177
178 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
179 self.validate_args(args)?;
180
181 let v = get_vector(&args[0])?;
182 let scalar = match &args[1] {
183 DataValue::Integer(i) => *i as f64,
184 DataValue::Float(f) => *f,
185 _ => {
186 return Err(anyhow!(
187 "Scalar must be numeric, got {:?}",
188 args[1].data_type()
189 ))
190 }
191 };
192
193 let result: Vec<f64> = v.iter().map(|x| x * scalar).collect();
194 Ok(DataValue::Vector(result))
195 }
196}
197
198pub struct VecDotFunction;
200
201impl SqlFunction for VecDotFunction {
202 fn signature(&self) -> FunctionSignature {
203 FunctionSignature {
204 name: "VEC_DOT",
205 category: FunctionCategory::Mathematical,
206 arg_count: ArgCount::Fixed(2),
207 description: "Compute dot product of two vectors",
208 returns: "Float",
209 examples: vec![
210 "SELECT VEC_DOT(VEC(1,2,3), VEC(4,5,6))",
211 "SELECT VEC_DOT(velocity1, velocity2)",
212 ],
213 }
214 }
215
216 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
217 self.validate_args(args)?;
218
219 let v1 = get_vector(&args[0])?;
220 let v2 = get_vector(&args[1])?;
221
222 if v1.len() != v2.len() {
223 return Err(anyhow!(
224 "Vector dimension mismatch: {} != {}",
225 v1.len(),
226 v2.len()
227 ));
228 }
229
230 let dot_product: f64 = v1.iter().zip(v2.iter()).map(|(a, b)| a * b).sum();
231 Ok(DataValue::Float(dot_product))
232 }
233}
234
235pub struct VecMagFunction;
237
238impl SqlFunction for VecMagFunction {
239 fn signature(&self) -> FunctionSignature {
240 FunctionSignature {
241 name: "VEC_MAG",
242 category: FunctionCategory::Mathematical,
243 arg_count: ArgCount::Fixed(1),
244 description: "Compute magnitude (length) of a vector",
245 returns: "Float",
246 examples: vec!["SELECT VEC_MAG(VEC(3,4))", "SELECT VEC_MAG(velocity)"],
247 }
248 }
249
250 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
251 self.validate_args(args)?;
252
253 let v = get_vector(&args[0])?;
254 let magnitude = v.iter().map(|x| x * x).sum::<f64>().sqrt();
255 Ok(DataValue::Float(magnitude))
256 }
257}
258
259pub struct VecNormalizeFunction;
261
262impl SqlFunction for VecNormalizeFunction {
263 fn signature(&self) -> FunctionSignature {
264 FunctionSignature {
265 name: "VEC_NORMALIZE",
266 category: FunctionCategory::Mathematical,
267 arg_count: ArgCount::Fixed(1),
268 description: "Normalize vector to unit length",
269 returns: "Vector",
270 examples: vec![
271 "SELECT VEC_NORMALIZE(VEC(3,4))",
272 "SELECT VEC_NORMALIZE(direction)",
273 ],
274 }
275 }
276
277 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
278 self.validate_args(args)?;
279
280 let v = get_vector(&args[0])?;
281 let magnitude = v.iter().map(|x| x * x).sum::<f64>().sqrt();
282
283 if magnitude == 0.0 {
284 return Err(anyhow!("Cannot normalize zero vector"));
285 }
286
287 let normalized: Vec<f64> = v.iter().map(|x| x / magnitude).collect();
288 Ok(DataValue::Vector(normalized))
289 }
290}
291
292pub struct VecDistanceFunction;
294
295impl SqlFunction for VecDistanceFunction {
296 fn signature(&self) -> FunctionSignature {
297 FunctionSignature {
298 name: "VEC_DISTANCE",
299 category: FunctionCategory::Mathematical,
300 arg_count: ArgCount::Fixed(2),
301 description: "Compute Euclidean distance between two vectors",
302 returns: "Float",
303 examples: vec![
304 "SELECT VEC_DISTANCE(VEC(0,0), VEC(3,4))",
305 "SELECT VEC_DISTANCE(position1, position2)",
306 ],
307 }
308 }
309
310 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
311 self.validate_args(args)?;
312
313 let v1 = get_vector(&args[0])?;
314 let v2 = get_vector(&args[1])?;
315
316 if v1.len() != v2.len() {
317 return Err(anyhow!(
318 "Vector dimension mismatch: {} != {}",
319 v1.len(),
320 v2.len()
321 ));
322 }
323
324 let distance = v1
325 .iter()
326 .zip(v2.iter())
327 .map(|(a, b)| (a - b).powi(2))
328 .sum::<f64>()
329 .sqrt();
330
331 Ok(DataValue::Float(distance))
332 }
333}
334
335pub struct VecCrossFunction;
337
338impl SqlFunction for VecCrossFunction {
339 fn signature(&self) -> FunctionSignature {
340 FunctionSignature {
341 name: "VEC_CROSS",
342 category: FunctionCategory::Mathematical,
343 arg_count: ArgCount::Fixed(2),
344 description: "Compute cross product of two 3D vectors",
345 returns: "Vector",
346 examples: vec![
347 "SELECT VEC_CROSS(VEC(1,0,0), VEC(0,1,0))",
348 "SELECT VEC_CROSS(velocity, force)",
349 ],
350 }
351 }
352
353 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
354 self.validate_args(args)?;
355
356 let v1 = get_vector(&args[0])?;
357 let v2 = get_vector(&args[1])?;
358
359 if v1.len() != 3 || v2.len() != 3 {
360 return Err(anyhow!(
361 "VEC_CROSS requires 3D vectors, got dimensions {} and {}",
362 v1.len(),
363 v2.len()
364 ));
365 }
366
367 let cross = vec![
368 v1[1] * v2[2] - v1[2] * v2[1],
369 v1[2] * v2[0] - v1[0] * v2[2],
370 v1[0] * v2[1] - v1[1] * v2[0],
371 ];
372
373 Ok(DataValue::Vector(cross))
374 }
375}
376
377pub struct VecAngleFunction;
379
380impl SqlFunction for VecAngleFunction {
381 fn signature(&self) -> FunctionSignature {
382 FunctionSignature {
383 name: "VEC_ANGLE",
384 category: FunctionCategory::Mathematical,
385 arg_count: ArgCount::Fixed(2),
386 description: "Compute angle between two vectors in radians",
387 returns: "Float",
388 examples: vec![
389 "SELECT VEC_ANGLE(VEC(1,0), VEC(0,1))",
390 "SELECT VEC_ANGLE(direction1, direction2)",
391 ],
392 }
393 }
394
395 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
396 self.validate_args(args)?;
397
398 let v1 = get_vector(&args[0])?;
399 let v2 = get_vector(&args[1])?;
400
401 if v1.len() != v2.len() {
402 return Err(anyhow!(
403 "Vector dimension mismatch: {} != {}",
404 v1.len(),
405 v2.len()
406 ));
407 }
408
409 let dot: f64 = v1.iter().zip(v2.iter()).map(|(a, b)| a * b).sum();
410 let mag1 = v1.iter().map(|x| x * x).sum::<f64>().sqrt();
411 let mag2 = v2.iter().map(|x| x * x).sum::<f64>().sqrt();
412
413 if mag1 == 0.0 || mag2 == 0.0 {
414 return Err(anyhow!("Cannot compute angle with zero vector"));
415 }
416
417 let cos_angle = dot / (mag1 * mag2);
418 let cos_angle = cos_angle.max(-1.0).min(1.0);
420 let angle = cos_angle.acos();
421
422 Ok(DataValue::Float(angle))
423 }
424}
425
426pub struct LineIntersectFunction;
428
429impl SqlFunction for LineIntersectFunction {
430 fn signature(&self) -> FunctionSignature {
431 FunctionSignature {
432 name: "LINE_INTERSECT",
433 category: FunctionCategory::Mathematical,
434 arg_count: ArgCount::Fixed(4),
435 description: "Find intersection point of two 2D lines (returns NULL if parallel)",
436 returns: "Vector or NULL",
437 examples: vec![
438 "SELECT LINE_INTERSECT(VEC(0,0), VEC(4,4), VEC(0,4), VEC(4,0))",
439 "SELECT LINE_INTERSECT(line1_p1, line1_p2, line2_p1, line2_p2)",
440 ],
441 }
442 }
443
444 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
445 self.validate_args(args)?;
446
447 let p1 = get_vector(&args[0])?;
449 let p2 = get_vector(&args[1])?;
450 let p3 = get_vector(&args[2])?;
451 let p4 = get_vector(&args[3])?;
452
453 if p1.len() != 2 || p2.len() != 2 || p3.len() != 2 || p4.len() != 2 {
454 return Err(anyhow!("LINE_INTERSECT requires 2D points"));
455 }
456
457 let x1 = p1[0];
465 let y1 = p1[1];
466 let x2 = p2[0];
467 let y2 = p2[1];
468 let x3 = p3[0];
469 let y3 = p3[1];
470 let x4 = p4[0];
471 let y4 = p4[1];
472
473 let denom = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4);
474
475 if denom.abs() < 1e-10 {
477 return Ok(DataValue::Null);
478 }
479
480 let t = ((x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4)) / denom;
482
483 let intersect_x = x1 + t * (x2 - x1);
484 let intersect_y = y1 + t * (y2 - y1);
485
486 Ok(DataValue::Vector(vec![intersect_x, intersect_y]))
487 }
488}
489
490pub struct SegmentIntersectFunction;
492
493impl SqlFunction for SegmentIntersectFunction {
494 fn signature(&self) -> FunctionSignature {
495 FunctionSignature {
496 name: "SEGMENT_INTERSECT",
497 category: FunctionCategory::Mathematical,
498 arg_count: ArgCount::Fixed(4),
499 description:
500 "Check if two 2D line segments intersect (returns intersection point or NULL)",
501 returns: "Vector or NULL",
502 examples: vec![
503 "SELECT SEGMENT_INTERSECT(VEC(0,0), VEC(2,2), VEC(0,2), VEC(2,0))",
504 "SELECT SEGMENT_INTERSECT(seg1_p1, seg1_p2, seg2_p1, seg2_p2)",
505 ],
506 }
507 }
508
509 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
510 self.validate_args(args)?;
511
512 let p1 = get_vector(&args[0])?;
514 let p2 = get_vector(&args[1])?;
515 let p3 = get_vector(&args[2])?;
516 let p4 = get_vector(&args[3])?;
517
518 if p1.len() != 2 || p2.len() != 2 || p3.len() != 2 || p4.len() != 2 {
519 return Err(anyhow!("SEGMENT_INTERSECT requires 2D points"));
520 }
521
522 let x1 = p1[0];
523 let y1 = p1[1];
524 let x2 = p2[0];
525 let y2 = p2[1];
526 let x3 = p3[0];
527 let y3 = p3[1];
528 let x4 = p4[0];
529 let y4 = p4[1];
530
531 let denom = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4);
532
533 if denom.abs() < 1e-10 {
535 return Ok(DataValue::Null);
536 }
537
538 let t = ((x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4)) / denom;
540 let s = ((x1 - x3) * (y1 - y2) - (y1 - y3) * (x1 - x2)) / denom;
541
542 if t >= 0.0 && t <= 1.0 && s >= 0.0 && s <= 1.0 {
544 let intersect_x = x1 + t * (x2 - x1);
545 let intersect_y = y1 + t * (y2 - y1);
546 Ok(DataValue::Vector(vec![intersect_x, intersect_y]))
547 } else {
548 Ok(DataValue::Null)
550 }
551 }
552}
553
554pub struct ClosestPointOnLineFunction;
556
557impl SqlFunction for ClosestPointOnLineFunction {
558 fn signature(&self) -> FunctionSignature {
559 FunctionSignature {
560 name: "CLOSEST_POINT_ON_LINE",
561 category: FunctionCategory::Mathematical,
562 arg_count: ArgCount::Fixed(3),
563 description: "Find closest point on a line to a given point (projection)",
564 returns: "Vector",
565 examples: vec![
566 "SELECT CLOSEST_POINT_ON_LINE(VEC(2,2), VEC(0,0), VEC(1,0))",
567 "SELECT CLOSEST_POINT_ON_LINE(point, line_start, line_direction)",
568 ],
569 }
570 }
571
572 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
573 self.validate_args(args)?;
574
575 let point = get_vector(&args[0])?;
576 let line_point = get_vector(&args[1])?;
577 let line_dir = get_vector(&args[2])?;
578
579 if point.len() != line_point.len() || point.len() != line_dir.len() {
580 return Err(anyhow!(
581 "All vectors must have same dimension, got {}, {}, {}",
582 point.len(),
583 line_point.len(),
584 line_dir.len()
585 ));
586 }
587
588 let to_point: Vec<f64> = point
590 .iter()
591 .zip(line_point.iter())
592 .map(|(p, lp)| p - lp)
593 .collect();
594
595 let dot_product: f64 = to_point
597 .iter()
598 .zip(line_dir.iter())
599 .map(|(a, b)| a * b)
600 .sum();
601 let dir_mag_sq: f64 = line_dir.iter().map(|x| x * x).sum();
602
603 if dir_mag_sq < 1e-10 {
604 return Err(anyhow!("Line direction vector cannot be zero"));
605 }
606
607 let t = dot_product / dir_mag_sq;
608
609 let closest: Vec<f64> = line_point
611 .iter()
612 .zip(line_dir.iter())
613 .map(|(lp, ld)| lp + t * ld)
614 .collect();
615
616 Ok(DataValue::Vector(closest))
617 }
618}
619
620pub struct PointLineDistanceFunction;
622
623impl SqlFunction for PointLineDistanceFunction {
624 fn signature(&self) -> FunctionSignature {
625 FunctionSignature {
626 name: "POINT_LINE_DISTANCE",
627 category: FunctionCategory::Mathematical,
628 arg_count: ArgCount::Fixed(3),
629 description: "Compute perpendicular distance from point to line",
630 returns: "Float",
631 examples: vec![
632 "SELECT POINT_LINE_DISTANCE(VEC(2,2), VEC(0,0), VEC(1,0))",
633 "SELECT POINT_LINE_DISTANCE(point, line_start, line_direction)",
634 ],
635 }
636 }
637
638 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
639 self.validate_args(args)?;
640
641 let point = get_vector(&args[0])?;
642 let line_point = get_vector(&args[1])?;
643 let line_dir = get_vector(&args[2])?;
644
645 if point.len() == 2 {
647 let point_3d = vec![point[0], point[1], 0.0];
649 let line_point_3d = vec![line_point[0], line_point[1], 0.0];
650 let line_dir_3d = vec![line_dir[0], line_dir[1], 0.0];
651
652 let to_point: Vec<f64> = point_3d
653 .iter()
654 .zip(line_point_3d.iter())
655 .map(|(p, lp)| p - lp)
656 .collect();
657
658 let cross_x = to_point[1] * line_dir_3d[2] - to_point[2] * line_dir_3d[1];
660 let cross_y = to_point[2] * line_dir_3d[0] - to_point[0] * line_dir_3d[2];
661 let cross_z = to_point[0] * line_dir_3d[1] - to_point[1] * line_dir_3d[0];
662
663 let cross_mag = (cross_x * cross_x + cross_y * cross_y + cross_z * cross_z).sqrt();
664 let dir_mag = (line_dir[0] * line_dir[0] + line_dir[1] * line_dir[1]).sqrt();
665
666 if dir_mag < 1e-10 {
667 return Err(anyhow!("Line direction cannot be zero"));
668 }
669
670 Ok(DataValue::Float(cross_mag / dir_mag))
671 } else if point.len() == 3 {
672 let to_point: Vec<f64> = point
673 .iter()
674 .zip(line_point.iter())
675 .map(|(p, lp)| p - lp)
676 .collect();
677
678 let cross_x = to_point[1] * line_dir[2] - to_point[2] * line_dir[1];
680 let cross_y = to_point[2] * line_dir[0] - to_point[0] * line_dir[2];
681 let cross_z = to_point[0] * line_dir[1] - to_point[1] * line_dir[0];
682
683 let cross_mag = (cross_x * cross_x + cross_y * cross_y + cross_z * cross_z).sqrt();
684 let dir_mag = line_dir.iter().map(|x| x * x).sum::<f64>().sqrt();
685
686 if dir_mag < 1e-10 {
687 return Err(anyhow!("Line direction cannot be zero"));
688 }
689
690 Ok(DataValue::Float(cross_mag / dir_mag))
691 } else {
692 Err(anyhow!(
693 "POINT_LINE_DISTANCE only supports 2D and 3D, got {}D",
694 point.len()
695 ))
696 }
697 }
698}
699
700pub struct LineReflectPointFunction;
702
703impl SqlFunction for LineReflectPointFunction {
704 fn signature(&self) -> FunctionSignature {
705 FunctionSignature {
706 name: "LINE_REFLECT_POINT",
707 category: FunctionCategory::Mathematical,
708 arg_count: ArgCount::Fixed(3),
709 description: "Reflect a point across a line",
710 returns: "Vector",
711 examples: vec![
712 "SELECT LINE_REFLECT_POINT(VEC(2,2), VEC(0,0), VEC(1,0))",
713 "SELECT LINE_REFLECT_POINT(point, line_start, line_direction)",
714 ],
715 }
716 }
717
718 fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
719 self.validate_args(args)?;
720
721 let point = get_vector(&args[0])?;
722 let line_point = get_vector(&args[1])?;
723 let line_dir = get_vector(&args[2])?;
724
725 if point.len() != line_point.len() || point.len() != line_dir.len() {
726 return Err(anyhow!("All vectors must have same dimension"));
727 }
728
729 let to_point: Vec<f64> = point
731 .iter()
732 .zip(line_point.iter())
733 .map(|(p, lp)| p - lp)
734 .collect();
735
736 let dot_product: f64 = to_point
737 .iter()
738 .zip(line_dir.iter())
739 .map(|(a, b)| a * b)
740 .sum();
741 let dir_mag_sq: f64 = line_dir.iter().map(|x| x * x).sum();
742
743 if dir_mag_sq < 1e-10 {
744 return Err(anyhow!("Line direction vector cannot be zero"));
745 }
746
747 let t = dot_product / dir_mag_sq;
748
749 let closest: Vec<f64> = line_point
750 .iter()
751 .zip(line_dir.iter())
752 .map(|(lp, ld)| lp + t * ld)
753 .collect();
754
755 let reflected: Vec<f64> = closest
757 .iter()
758 .zip(point.iter())
759 .map(|(c, p)| 2.0 * c - p)
760 .collect();
761
762 Ok(DataValue::Vector(reflected))
763 }
764}
765
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps a `DataValue::Float`, panicking with the actual value otherwise.
    fn as_float(value: DataValue) -> f64 {
        match value {
            DataValue::Float(f) => f,
            other => panic!("expected Float, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_vector_string() {
        assert_eq!(parse_vector_string("[1,2,3]").unwrap(), vec![1.0, 2.0, 3.0]);
        assert_eq!(parse_vector_string("1 2 3").unwrap(), vec![1.0, 2.0, 3.0]);
        assert_eq!(
            parse_vector_string("1.5, 2.5, 3.5").unwrap(),
            vec![1.5, 2.5, 3.5]
        );
    }

    #[test]
    fn test_parse_vector_string_edge_cases() {
        // Outer whitespace is trimmed; bracketed content may be whitespace-separated.
        assert_eq!(
            parse_vector_string("  [ 4 5 6 ]  ").unwrap(),
            vec![4.0, 5.0, 6.0]
        );
        // Empty brackets parse to an empty vector.
        assert_eq!(parse_vector_string("[]").unwrap(), Vec::<f64>::new());
        // Non-numeric components and trailing commas are rejected.
        assert!(parse_vector_string("[1,a]").is_err());
        assert!(parse_vector_string("1,2,").is_err());
    }

    #[test]
    fn test_vec_function() {
        let func = VecFunction;
        let args = vec![
            DataValue::Integer(1),
            DataValue::Integer(2),
            DataValue::Integer(3),
        ];
        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Vector(vec![1.0, 2.0, 3.0]));
    }

    #[test]
    fn test_vec_function_rejects_bad_arguments() {
        let func = VecFunction;
        assert!(func.evaluate(&[]).is_err());
        assert!(func
            .evaluate(&[DataValue::Integer(1), DataValue::Null])
            .is_err());
    }

    #[test]
    fn test_vec_add() {
        let func = VecAddFunction;
        let args = vec![
            DataValue::Vector(vec![1.0, 2.0, 3.0]),
            DataValue::Vector(vec![4.0, 5.0, 6.0]),
        ];
        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Vector(vec![5.0, 7.0, 9.0]));
    }

    #[test]
    fn test_vec_add_dimension_mismatch() {
        let args = vec![
            DataValue::Vector(vec![1.0, 2.0]),
            DataValue::Vector(vec![1.0, 2.0, 3.0]),
        ];
        assert!(VecAddFunction.evaluate(&args).is_err());
    }

    #[test]
    fn test_vec_sub() {
        let args = vec![
            DataValue::Vector(vec![10.0, 20.0, 30.0]),
            DataValue::Vector(vec![1.0, 2.0, 3.0]),
        ];
        let result = VecSubFunction.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Vector(vec![9.0, 18.0, 27.0]));
    }

    #[test]
    fn test_vec_scale() {
        let func = VecScaleFunction;
        let by_float = vec![
            DataValue::Vector(vec![1.0, 2.0, 3.0]),
            DataValue::Float(2.5),
        ];
        assert_eq!(
            func.evaluate(&by_float).unwrap(),
            DataValue::Vector(vec![2.5, 5.0, 7.5])
        );
        let by_int = vec![
            DataValue::Vector(vec![1.0, 2.0, 3.0]),
            DataValue::Integer(2),
        ];
        assert_eq!(
            func.evaluate(&by_int).unwrap(),
            DataValue::Vector(vec![2.0, 4.0, 6.0])
        );
    }

    #[test]
    fn test_vec_mag() {
        let func = VecMagFunction;
        let args = vec![DataValue::Vector(vec![3.0, 4.0])];
        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Float(5.0));
    }

    #[test]
    fn test_vec_dot() {
        let func = VecDotFunction;
        let args = vec![
            DataValue::Vector(vec![1.0, 2.0, 3.0]),
            DataValue::Vector(vec![4.0, 5.0, 6.0]),
        ];
        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Float(32.0));
    }

    #[test]
    fn test_vec_normalize() {
        let func = VecNormalizeFunction;
        let result = func
            .evaluate(&[DataValue::Vector(vec![3.0, 4.0])])
            .unwrap();
        assert_eq!(result, DataValue::Vector(vec![0.6, 0.8]));
        assert!(func
            .evaluate(&[DataValue::Vector(vec![0.0, 0.0])])
            .is_err());
    }

    #[test]
    fn test_vec_distance() {
        let args = vec![
            DataValue::Vector(vec![0.0, 0.0]),
            DataValue::Vector(vec![3.0, 4.0]),
        ];
        let result = VecDistanceFunction.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Float(5.0));
    }

    #[test]
    fn test_vec_cross() {
        let func = VecCrossFunction;
        let args = vec![
            DataValue::Vector(vec![1.0, 0.0, 0.0]),
            DataValue::Vector(vec![0.0, 1.0, 0.0]),
        ];
        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Vector(vec![0.0, 0.0, 1.0]));
    }

    #[test]
    fn test_vec_cross_requires_3d() {
        let args = vec![
            DataValue::Vector(vec![1.0, 0.0]),
            DataValue::Vector(vec![0.0, 1.0]),
        ];
        assert!(VecCrossFunction.evaluate(&args).is_err());
    }

    #[test]
    fn test_vec_angle() {
        let func = VecAngleFunction;
        let right_angle = func
            .evaluate(&[
                DataValue::Vector(vec![1.0, 0.0]),
                DataValue::Vector(vec![0.0, 1.0]),
            ])
            .unwrap();
        assert!((as_float(right_angle) - std::f64::consts::FRAC_PI_2).abs() < 1e-12);

        let with_zero = [
            DataValue::Vector(vec![0.0, 0.0]),
            DataValue::Vector(vec![1.0, 0.0]),
        ];
        assert!(func.evaluate(&with_zero).is_err());
    }

    #[test]
    fn test_line_intersect() {
        let func = LineIntersectFunction;
        let crossing = vec![
            DataValue::Vector(vec![0.0, 0.0]),
            DataValue::Vector(vec![4.0, 4.0]),
            DataValue::Vector(vec![0.0, 4.0]),
            DataValue::Vector(vec![4.0, 0.0]),
        ];
        assert_eq!(
            func.evaluate(&crossing).unwrap(),
            DataValue::Vector(vec![2.0, 2.0])
        );

        let parallel = vec![
            DataValue::Vector(vec![0.0, 0.0]),
            DataValue::Vector(vec![1.0, 0.0]),
            DataValue::Vector(vec![0.0, 1.0]),
            DataValue::Vector(vec![1.0, 1.0]),
        ];
        assert_eq!(func.evaluate(&parallel).unwrap(), DataValue::Null);
    }

    #[test]
    fn test_segment_intersect() {
        let func = SegmentIntersectFunction;
        let crossing = vec![
            DataValue::Vector(vec![0.0, 0.0]),
            DataValue::Vector(vec![2.0, 2.0]),
            DataValue::Vector(vec![0.0, 2.0]),
            DataValue::Vector(vec![2.0, 0.0]),
        ];
        assert_eq!(
            func.evaluate(&crossing).unwrap(),
            DataValue::Vector(vec![1.0, 1.0])
        );

        // The supporting lines meet at (-1, 0), which lies outside the first segment.
        let missing = vec![
            DataValue::Vector(vec![0.0, 0.0]),
            DataValue::Vector(vec![1.0, 0.0]),
            DataValue::Vector(vec![0.0, 1.0]),
            DataValue::Vector(vec![1.0, 2.0]),
        ];
        assert_eq!(func.evaluate(&missing).unwrap(), DataValue::Null);
    }

    #[test]
    fn test_closest_point_on_line() {
        let args = vec![
            DataValue::Vector(vec![2.0, 2.0]),
            DataValue::Vector(vec![0.0, 0.0]),
            DataValue::Vector(vec![1.0, 0.0]),
        ];
        let result = ClosestPointOnLineFunction.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Vector(vec![2.0, 0.0]));
    }

    #[test]
    fn test_point_line_distance() {
        let func = PointLineDistanceFunction;
        let planar = vec![
            DataValue::Vector(vec![2.0, 2.0]),
            DataValue::Vector(vec![0.0, 0.0]),
            DataValue::Vector(vec![1.0, 0.0]),
        ];
        assert_eq!(func.evaluate(&planar).unwrap(), DataValue::Float(2.0));

        let spatial = vec![
            DataValue::Vector(vec![0.0, 0.0, 5.0]),
            DataValue::Vector(vec![0.0, 0.0, 0.0]),
            DataValue::Vector(vec![1.0, 0.0, 0.0]),
        ];
        assert_eq!(func.evaluate(&spatial).unwrap(), DataValue::Float(5.0));
    }

    #[test]
    fn test_line_reflect_point() {
        let args = vec![
            DataValue::Vector(vec![2.0, 2.0]),
            DataValue::Vector(vec![0.0, 0.0]),
            DataValue::Vector(vec![1.0, 0.0]),
        ];
        let result = LineReflectPointFunction.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Vector(vec![2.0, -2.0]));
    }
}