// sql_cli/sql/functions/vector.rs

1use anyhow::{anyhow, Result};
2
3use crate::data::datatable::DataValue;
4use crate::sql::functions::{ArgCount, FunctionCategory, FunctionSignature, SqlFunction};
5
6/// Helper function to extract vector from DataValue
7fn get_vector(value: &DataValue) -> Result<Vec<f64>> {
8    match value {
9        DataValue::Vector(v) => Ok(v.clone()),
10        DataValue::String(s) => parse_vector_string(s),
11        _ => Err(anyhow!("Expected vector, got {:?}", value.data_type())),
12    }
13}
14
15/// Parse vector from string representation: "[1,2,3]" or "1 2 3"
16fn parse_vector_string(s: &str) -> Result<Vec<f64>> {
17    let trimmed = s.trim();
18
19    // Handle "[1,2,3]" format
20    let content = if trimmed.starts_with('[') && trimmed.ends_with(']') {
21        &trimmed[1..trimmed.len() - 1]
22    } else {
23        trimmed
24    };
25
26    // Parse components (comma or space separated)
27    let components: Result<Vec<f64>> = if content.contains(',') {
28        content
29            .split(',')
30            .map(|s| {
31                s.trim()
32                    .parse::<f64>()
33                    .map_err(|e| anyhow!("Failed to parse vector component '{}': {}", s.trim(), e))
34            })
35            .collect()
36    } else {
37        content
38            .split_whitespace()
39            .map(|s| {
40                s.parse::<f64>()
41                    .map_err(|e| anyhow!("Failed to parse vector component '{}': {}", s, e))
42            })
43            .collect()
44    };
45
46    components
47}
48
49/// VEC(x, y, z, ...) - Construct a vector from components
50pub struct VecFunction;
51
52impl SqlFunction for VecFunction {
53    fn signature(&self) -> FunctionSignature {
54        FunctionSignature {
55            name: "VEC",
56            category: FunctionCategory::Mathematical,
57            arg_count: ArgCount::Variadic,
58            description: "Construct a vector from numeric components",
59            returns: "Vector",
60            examples: vec!["SELECT VEC(1, 2, 3)", "SELECT VEC(10, 20)"],
61        }
62    }
63
64    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
65        if args.is_empty() {
66            return Err(anyhow!("VEC() requires at least one argument"));
67        }
68
69        let components: Result<Vec<f64>> = args
70            .iter()
71            .map(|arg| match arg {
72                DataValue::Integer(i) => Ok(*i as f64),
73                DataValue::Float(f) => Ok(*f),
74                DataValue::Null => Err(anyhow!("Cannot create vector with NULL component")),
75                _ => Err(anyhow!(
76                    "VEC() requires numeric arguments, got {:?}",
77                    arg.data_type()
78                )),
79            })
80            .collect();
81
82        Ok(DataValue::Vector(components?))
83    }
84}
85
86/// VEC_ADD(v1, v2) - Add two vectors element-wise
87pub struct VecAddFunction;
88
89impl SqlFunction for VecAddFunction {
90    fn signature(&self) -> FunctionSignature {
91        FunctionSignature {
92            name: "VEC_ADD",
93            category: FunctionCategory::Mathematical,
94            arg_count: ArgCount::Fixed(2),
95            description: "Add two vectors element-wise",
96            returns: "Vector",
97            examples: vec![
98                "SELECT VEC_ADD(VEC(1,2,3), VEC(4,5,6))",
99                "SELECT VEC_ADD(position, velocity)",
100            ],
101        }
102    }
103
104    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
105        self.validate_args(args)?;
106
107        let v1 = get_vector(&args[0])?;
108        let v2 = get_vector(&args[1])?;
109
110        if v1.len() != v2.len() {
111            return Err(anyhow!(
112                "Vector dimension mismatch: {} != {}",
113                v1.len(),
114                v2.len()
115            ));
116        }
117
118        let result: Vec<f64> = v1.iter().zip(v2.iter()).map(|(a, b)| a + b).collect();
119        Ok(DataValue::Vector(result))
120    }
121}
122
123/// VEC_SUB(v1, v2) - Subtract two vectors element-wise
124pub struct VecSubFunction;
125
126impl SqlFunction for VecSubFunction {
127    fn signature(&self) -> FunctionSignature {
128        FunctionSignature {
129            name: "VEC_SUB",
130            category: FunctionCategory::Mathematical,
131            arg_count: ArgCount::Fixed(2),
132            description: "Subtract two vectors element-wise (v1 - v2)",
133            returns: "Vector",
134            examples: vec![
135                "SELECT VEC_SUB(VEC(10,20,30), VEC(1,2,3))",
136                "SELECT VEC_SUB(position_end, position_start)",
137            ],
138        }
139    }
140
141    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
142        self.validate_args(args)?;
143
144        let v1 = get_vector(&args[0])?;
145        let v2 = get_vector(&args[1])?;
146
147        if v1.len() != v2.len() {
148            return Err(anyhow!(
149                "Vector dimension mismatch: {} != {}",
150                v1.len(),
151                v2.len()
152            ));
153        }
154
155        let result: Vec<f64> = v1.iter().zip(v2.iter()).map(|(a, b)| a - b).collect();
156        Ok(DataValue::Vector(result))
157    }
158}
159
160/// VEC_SCALE(vector, scalar) - Multiply vector by scalar
161pub struct VecScaleFunction;
162
163impl SqlFunction for VecScaleFunction {
164    fn signature(&self) -> FunctionSignature {
165        FunctionSignature {
166            name: "VEC_SCALE",
167            category: FunctionCategory::Mathematical,
168            arg_count: ArgCount::Fixed(2),
169            description: "Multiply vector by scalar value",
170            returns: "Vector",
171            examples: vec![
172                "SELECT VEC_SCALE(VEC(1,2,3), 2.5)",
173                "SELECT VEC_SCALE(velocity, time)",
174            ],
175        }
176    }
177
178    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
179        self.validate_args(args)?;
180
181        let v = get_vector(&args[0])?;
182        let scalar = match &args[1] {
183            DataValue::Integer(i) => *i as f64,
184            DataValue::Float(f) => *f,
185            _ => {
186                return Err(anyhow!(
187                    "Scalar must be numeric, got {:?}",
188                    args[1].data_type()
189                ))
190            }
191        };
192
193        let result: Vec<f64> = v.iter().map(|x| x * scalar).collect();
194        Ok(DataValue::Vector(result))
195    }
196}
197
198/// VEC_DOT(v1, v2) - Compute dot product
199pub struct VecDotFunction;
200
201impl SqlFunction for VecDotFunction {
202    fn signature(&self) -> FunctionSignature {
203        FunctionSignature {
204            name: "VEC_DOT",
205            category: FunctionCategory::Mathematical,
206            arg_count: ArgCount::Fixed(2),
207            description: "Compute dot product of two vectors",
208            returns: "Float",
209            examples: vec![
210                "SELECT VEC_DOT(VEC(1,2,3), VEC(4,5,6))",
211                "SELECT VEC_DOT(velocity1, velocity2)",
212            ],
213        }
214    }
215
216    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
217        self.validate_args(args)?;
218
219        let v1 = get_vector(&args[0])?;
220        let v2 = get_vector(&args[1])?;
221
222        if v1.len() != v2.len() {
223            return Err(anyhow!(
224                "Vector dimension mismatch: {} != {}",
225                v1.len(),
226                v2.len()
227            ));
228        }
229
230        let dot_product: f64 = v1.iter().zip(v2.iter()).map(|(a, b)| a * b).sum();
231        Ok(DataValue::Float(dot_product))
232    }
233}
234
235/// VEC_MAG(vector) - Compute magnitude (length) of vector
236pub struct VecMagFunction;
237
238impl SqlFunction for VecMagFunction {
239    fn signature(&self) -> FunctionSignature {
240        FunctionSignature {
241            name: "VEC_MAG",
242            category: FunctionCategory::Mathematical,
243            arg_count: ArgCount::Fixed(1),
244            description: "Compute magnitude (length) of a vector",
245            returns: "Float",
246            examples: vec!["SELECT VEC_MAG(VEC(3,4))", "SELECT VEC_MAG(velocity)"],
247        }
248    }
249
250    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
251        self.validate_args(args)?;
252
253        let v = get_vector(&args[0])?;
254        let magnitude = v.iter().map(|x| x * x).sum::<f64>().sqrt();
255        Ok(DataValue::Float(magnitude))
256    }
257}
258
259/// VEC_NORMALIZE(vector) - Normalize vector to unit length
260pub struct VecNormalizeFunction;
261
262impl SqlFunction for VecNormalizeFunction {
263    fn signature(&self) -> FunctionSignature {
264        FunctionSignature {
265            name: "VEC_NORMALIZE",
266            category: FunctionCategory::Mathematical,
267            arg_count: ArgCount::Fixed(1),
268            description: "Normalize vector to unit length",
269            returns: "Vector",
270            examples: vec![
271                "SELECT VEC_NORMALIZE(VEC(3,4))",
272                "SELECT VEC_NORMALIZE(direction)",
273            ],
274        }
275    }
276
277    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
278        self.validate_args(args)?;
279
280        let v = get_vector(&args[0])?;
281        let magnitude = v.iter().map(|x| x * x).sum::<f64>().sqrt();
282
283        if magnitude == 0.0 {
284            return Err(anyhow!("Cannot normalize zero vector"));
285        }
286
287        let normalized: Vec<f64> = v.iter().map(|x| x / magnitude).collect();
288        Ok(DataValue::Vector(normalized))
289    }
290}
291
292/// VEC_DISTANCE(v1, v2) - Compute Euclidean distance between two vectors
293pub struct VecDistanceFunction;
294
295impl SqlFunction for VecDistanceFunction {
296    fn signature(&self) -> FunctionSignature {
297        FunctionSignature {
298            name: "VEC_DISTANCE",
299            category: FunctionCategory::Mathematical,
300            arg_count: ArgCount::Fixed(2),
301            description: "Compute Euclidean distance between two vectors",
302            returns: "Float",
303            examples: vec![
304                "SELECT VEC_DISTANCE(VEC(0,0), VEC(3,4))",
305                "SELECT VEC_DISTANCE(position1, position2)",
306            ],
307        }
308    }
309
310    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
311        self.validate_args(args)?;
312
313        let v1 = get_vector(&args[0])?;
314        let v2 = get_vector(&args[1])?;
315
316        if v1.len() != v2.len() {
317            return Err(anyhow!(
318                "Vector dimension mismatch: {} != {}",
319                v1.len(),
320                v2.len()
321            ));
322        }
323
324        let distance = v1
325            .iter()
326            .zip(v2.iter())
327            .map(|(a, b)| (a - b).powi(2))
328            .sum::<f64>()
329            .sqrt();
330
331        Ok(DataValue::Float(distance))
332    }
333}
334
335/// VEC_CROSS(v1, v2) - Compute cross product (3D only)
336pub struct VecCrossFunction;
337
338impl SqlFunction for VecCrossFunction {
339    fn signature(&self) -> FunctionSignature {
340        FunctionSignature {
341            name: "VEC_CROSS",
342            category: FunctionCategory::Mathematical,
343            arg_count: ArgCount::Fixed(2),
344            description: "Compute cross product of two 3D vectors",
345            returns: "Vector",
346            examples: vec![
347                "SELECT VEC_CROSS(VEC(1,0,0), VEC(0,1,0))",
348                "SELECT VEC_CROSS(velocity, force)",
349            ],
350        }
351    }
352
353    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
354        self.validate_args(args)?;
355
356        let v1 = get_vector(&args[0])?;
357        let v2 = get_vector(&args[1])?;
358
359        if v1.len() != 3 || v2.len() != 3 {
360            return Err(anyhow!(
361                "VEC_CROSS requires 3D vectors, got dimensions {} and {}",
362                v1.len(),
363                v2.len()
364            ));
365        }
366
367        let cross = vec![
368            v1[1] * v2[2] - v1[2] * v2[1],
369            v1[2] * v2[0] - v1[0] * v2[2],
370            v1[0] * v2[1] - v1[1] * v2[0],
371        ];
372
373        Ok(DataValue::Vector(cross))
374    }
375}
376
377/// VEC_ANGLE(v1, v2) - Compute angle between two vectors in radians
378pub struct VecAngleFunction;
379
380impl SqlFunction for VecAngleFunction {
381    fn signature(&self) -> FunctionSignature {
382        FunctionSignature {
383            name: "VEC_ANGLE",
384            category: FunctionCategory::Mathematical,
385            arg_count: ArgCount::Fixed(2),
386            description: "Compute angle between two vectors in radians",
387            returns: "Float",
388            examples: vec![
389                "SELECT VEC_ANGLE(VEC(1,0), VEC(0,1))",
390                "SELECT VEC_ANGLE(direction1, direction2)",
391            ],
392        }
393    }
394
395    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
396        self.validate_args(args)?;
397
398        let v1 = get_vector(&args[0])?;
399        let v2 = get_vector(&args[1])?;
400
401        if v1.len() != v2.len() {
402            return Err(anyhow!(
403                "Vector dimension mismatch: {} != {}",
404                v1.len(),
405                v2.len()
406            ));
407        }
408
409        let dot: f64 = v1.iter().zip(v2.iter()).map(|(a, b)| a * b).sum();
410        let mag1 = v1.iter().map(|x| x * x).sum::<f64>().sqrt();
411        let mag2 = v2.iter().map(|x| x * x).sum::<f64>().sqrt();
412
413        if mag1 == 0.0 || mag2 == 0.0 {
414            return Err(anyhow!("Cannot compute angle with zero vector"));
415        }
416
417        let cos_angle = dot / (mag1 * mag2);
418        // Clamp to [-1, 1] to handle floating point errors
419        let cos_angle = cos_angle.max(-1.0).min(1.0);
420        let angle = cos_angle.acos();
421
422        Ok(DataValue::Float(angle))
423    }
424}
425
426/// LINE_INTERSECT(p1, p2, p3, p4) - Find exact intersection point of two 2D lines
427pub struct LineIntersectFunction;
428
429impl SqlFunction for LineIntersectFunction {
430    fn signature(&self) -> FunctionSignature {
431        FunctionSignature {
432            name: "LINE_INTERSECT",
433            category: FunctionCategory::Mathematical,
434            arg_count: ArgCount::Fixed(4),
435            description: "Find intersection point of two 2D lines (returns NULL if parallel)",
436            returns: "Vector or NULL",
437            examples: vec![
438                "SELECT LINE_INTERSECT(VEC(0,0), VEC(4,4), VEC(0,4), VEC(4,0))",
439                "SELECT LINE_INTERSECT(line1_p1, line1_p2, line2_p1, line2_p2)",
440            ],
441        }
442    }
443
444    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
445        self.validate_args(args)?;
446
447        // Get four 2D points defining two lines
448        let p1 = get_vector(&args[0])?;
449        let p2 = get_vector(&args[1])?;
450        let p3 = get_vector(&args[2])?;
451        let p4 = get_vector(&args[3])?;
452
453        if p1.len() != 2 || p2.len() != 2 || p3.len() != 2 || p4.len() != 2 {
454            return Err(anyhow!("LINE_INTERSECT requires 2D points"));
455        }
456
457        // Line 1: p1 + t * (p2 - p1)
458        // Line 2: p3 + s * (p4 - p3)
459        //
460        // Solving: p1 + t*(p2-p1) = p3 + s*(p4-p3)
461        // This gives us two equations (for x and y)
462        //
463        // Using determinant method:
464        let x1 = p1[0];
465        let y1 = p1[1];
466        let x2 = p2[0];
467        let y2 = p2[1];
468        let x3 = p3[0];
469        let y3 = p3[1];
470        let x4 = p4[0];
471        let y4 = p4[1];
472
473        let denom = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4);
474
475        // If denominator is 0, lines are parallel
476        if denom.abs() < 1e-10 {
477            return Ok(DataValue::Null);
478        }
479
480        // Calculate intersection point
481        let t = ((x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4)) / denom;
482
483        let intersect_x = x1 + t * (x2 - x1);
484        let intersect_y = y1 + t * (y2 - y1);
485
486        Ok(DataValue::Vector(vec![intersect_x, intersect_y]))
487    }
488}
489
490/// SEGMENT_INTERSECT(p1, p2, p3, p4) - Check if two line segments intersect
491pub struct SegmentIntersectFunction;
492
493impl SqlFunction for SegmentIntersectFunction {
494    fn signature(&self) -> FunctionSignature {
495        FunctionSignature {
496            name: "SEGMENT_INTERSECT",
497            category: FunctionCategory::Mathematical,
498            arg_count: ArgCount::Fixed(4),
499            description:
500                "Check if two 2D line segments intersect (returns intersection point or NULL)",
501            returns: "Vector or NULL",
502            examples: vec![
503                "SELECT SEGMENT_INTERSECT(VEC(0,0), VEC(2,2), VEC(0,2), VEC(2,0))",
504                "SELECT SEGMENT_INTERSECT(seg1_p1, seg1_p2, seg2_p1, seg2_p2)",
505            ],
506        }
507    }
508
509    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
510        self.validate_args(args)?;
511
512        // Get four 2D points defining two segments
513        let p1 = get_vector(&args[0])?;
514        let p2 = get_vector(&args[1])?;
515        let p3 = get_vector(&args[2])?;
516        let p4 = get_vector(&args[3])?;
517
518        if p1.len() != 2 || p2.len() != 2 || p3.len() != 2 || p4.len() != 2 {
519            return Err(anyhow!("SEGMENT_INTERSECT requires 2D points"));
520        }
521
522        let x1 = p1[0];
523        let y1 = p1[1];
524        let x2 = p2[0];
525        let y2 = p2[1];
526        let x3 = p3[0];
527        let y3 = p3[1];
528        let x4 = p4[0];
529        let y4 = p4[1];
530
531        let denom = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4);
532
533        // If denominator is 0, segments are parallel
534        if denom.abs() < 1e-10 {
535            return Ok(DataValue::Null);
536        }
537
538        // Calculate parameters t and s
539        let t = ((x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4)) / denom;
540        let s = ((x1 - x3) * (y1 - y2) - (y1 - y3) * (x1 - x2)) / denom;
541
542        // Check if intersection is within both segments (t and s in [0, 1])
543        if t >= 0.0 && t <= 1.0 && s >= 0.0 && s <= 1.0 {
544            let intersect_x = x1 + t * (x2 - x1);
545            let intersect_y = y1 + t * (y2 - y1);
546            Ok(DataValue::Vector(vec![intersect_x, intersect_y]))
547        } else {
548            // Segments don't intersect (they would if extended to lines)
549            Ok(DataValue::Null)
550        }
551    }
552}
553
554/// CLOSEST_POINT_ON_LINE(point, line_point, line_dir) - Find closest point on line to given point
555pub struct ClosestPointOnLineFunction;
556
557impl SqlFunction for ClosestPointOnLineFunction {
558    fn signature(&self) -> FunctionSignature {
559        FunctionSignature {
560            name: "CLOSEST_POINT_ON_LINE",
561            category: FunctionCategory::Mathematical,
562            arg_count: ArgCount::Fixed(3),
563            description: "Find closest point on a line to a given point (projection)",
564            returns: "Vector",
565            examples: vec![
566                "SELECT CLOSEST_POINT_ON_LINE(VEC(2,2), VEC(0,0), VEC(1,0))",
567                "SELECT CLOSEST_POINT_ON_LINE(point, line_start, line_direction)",
568            ],
569        }
570    }
571
572    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
573        self.validate_args(args)?;
574
575        let point = get_vector(&args[0])?;
576        let line_point = get_vector(&args[1])?;
577        let line_dir = get_vector(&args[2])?;
578
579        if point.len() != line_point.len() || point.len() != line_dir.len() {
580            return Err(anyhow!(
581                "All vectors must have same dimension, got {}, {}, {}",
582                point.len(),
583                line_point.len(),
584                line_dir.len()
585            ));
586        }
587
588        // Vector from line point to target point
589        let to_point: Vec<f64> = point
590            .iter()
591            .zip(line_point.iter())
592            .map(|(p, lp)| p - lp)
593            .collect();
594
595        // Project onto line direction: t = (to_point · line_dir) / (line_dir · line_dir)
596        let dot_product: f64 = to_point
597            .iter()
598            .zip(line_dir.iter())
599            .map(|(a, b)| a * b)
600            .sum();
601        let dir_mag_sq: f64 = line_dir.iter().map(|x| x * x).sum();
602
603        if dir_mag_sq < 1e-10 {
604            return Err(anyhow!("Line direction vector cannot be zero"));
605        }
606
607        let t = dot_product / dir_mag_sq;
608
609        // Closest point = line_point + t * line_dir
610        let closest: Vec<f64> = line_point
611            .iter()
612            .zip(line_dir.iter())
613            .map(|(lp, ld)| lp + t * ld)
614            .collect();
615
616        Ok(DataValue::Vector(closest))
617    }
618}
619
620/// POINT_LINE_DISTANCE(point, line_point, line_dir) - Distance from point to line
621pub struct PointLineDistanceFunction;
622
623impl SqlFunction for PointLineDistanceFunction {
624    fn signature(&self) -> FunctionSignature {
625        FunctionSignature {
626            name: "POINT_LINE_DISTANCE",
627            category: FunctionCategory::Mathematical,
628            arg_count: ArgCount::Fixed(3),
629            description: "Compute perpendicular distance from point to line",
630            returns: "Float",
631            examples: vec![
632                "SELECT POINT_LINE_DISTANCE(VEC(2,2), VEC(0,0), VEC(1,0))",
633                "SELECT POINT_LINE_DISTANCE(point, line_start, line_direction)",
634            ],
635        }
636    }
637
638    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
639        self.validate_args(args)?;
640
641        let point = get_vector(&args[0])?;
642        let line_point = get_vector(&args[1])?;
643        let line_dir = get_vector(&args[2])?;
644
645        // For 2D and 3D, we can use the cross product method
646        if point.len() == 2 {
647            // Extend to 3D for cross product
648            let point_3d = vec![point[0], point[1], 0.0];
649            let line_point_3d = vec![line_point[0], line_point[1], 0.0];
650            let line_dir_3d = vec![line_dir[0], line_dir[1], 0.0];
651
652            let to_point: Vec<f64> = point_3d
653                .iter()
654                .zip(line_point_3d.iter())
655                .map(|(p, lp)| p - lp)
656                .collect();
657
658            // Cross product
659            let cross_x = to_point[1] * line_dir_3d[2] - to_point[2] * line_dir_3d[1];
660            let cross_y = to_point[2] * line_dir_3d[0] - to_point[0] * line_dir_3d[2];
661            let cross_z = to_point[0] * line_dir_3d[1] - to_point[1] * line_dir_3d[0];
662
663            let cross_mag = (cross_x * cross_x + cross_y * cross_y + cross_z * cross_z).sqrt();
664            let dir_mag = (line_dir[0] * line_dir[0] + line_dir[1] * line_dir[1]).sqrt();
665
666            if dir_mag < 1e-10 {
667                return Err(anyhow!("Line direction cannot be zero"));
668            }
669
670            Ok(DataValue::Float(cross_mag / dir_mag))
671        } else if point.len() == 3 {
672            let to_point: Vec<f64> = point
673                .iter()
674                .zip(line_point.iter())
675                .map(|(p, lp)| p - lp)
676                .collect();
677
678            // 3D cross product
679            let cross_x = to_point[1] * line_dir[2] - to_point[2] * line_dir[1];
680            let cross_y = to_point[2] * line_dir[0] - to_point[0] * line_dir[2];
681            let cross_z = to_point[0] * line_dir[1] - to_point[1] * line_dir[0];
682
683            let cross_mag = (cross_x * cross_x + cross_y * cross_y + cross_z * cross_z).sqrt();
684            let dir_mag = line_dir.iter().map(|x| x * x).sum::<f64>().sqrt();
685
686            if dir_mag < 1e-10 {
687                return Err(anyhow!("Line direction cannot be zero"));
688            }
689
690            Ok(DataValue::Float(cross_mag / dir_mag))
691        } else {
692            Err(anyhow!(
693                "POINT_LINE_DISTANCE only supports 2D and 3D, got {}D",
694                point.len()
695            ))
696        }
697    }
698}
699
700/// LINE_REFLECT_POINT(point, line_point, line_dir) - Reflect point across line
701pub struct LineReflectPointFunction;
702
703impl SqlFunction for LineReflectPointFunction {
704    fn signature(&self) -> FunctionSignature {
705        FunctionSignature {
706            name: "LINE_REFLECT_POINT",
707            category: FunctionCategory::Mathematical,
708            arg_count: ArgCount::Fixed(3),
709            description: "Reflect a point across a line",
710            returns: "Vector",
711            examples: vec![
712                "SELECT LINE_REFLECT_POINT(VEC(2,2), VEC(0,0), VEC(1,0))",
713                "SELECT LINE_REFLECT_POINT(point, line_start, line_direction)",
714            ],
715        }
716    }
717
718    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
719        self.validate_args(args)?;
720
721        let point = get_vector(&args[0])?;
722        let line_point = get_vector(&args[1])?;
723        let line_dir = get_vector(&args[2])?;
724
725        if point.len() != line_point.len() || point.len() != line_dir.len() {
726            return Err(anyhow!("All vectors must have same dimension"));
727        }
728
729        // Find closest point on line (projection)
730        let to_point: Vec<f64> = point
731            .iter()
732            .zip(line_point.iter())
733            .map(|(p, lp)| p - lp)
734            .collect();
735
736        let dot_product: f64 = to_point
737            .iter()
738            .zip(line_dir.iter())
739            .map(|(a, b)| a * b)
740            .sum();
741        let dir_mag_sq: f64 = line_dir.iter().map(|x| x * x).sum();
742
743        if dir_mag_sq < 1e-10 {
744            return Err(anyhow!("Line direction vector cannot be zero"));
745        }
746
747        let t = dot_product / dir_mag_sq;
748
749        let closest: Vec<f64> = line_point
750            .iter()
751            .zip(line_dir.iter())
752            .map(|(lp, ld)| lp + t * ld)
753            .collect();
754
755        // Reflection = point + 2 * (closest - point) = 2 * closest - point
756        let reflected: Vec<f64> = closest
757            .iter()
758            .zip(point.iter())
759            .map(|(c, p)| 2.0 * c - p)
760            .collect();
761
762        Ok(DataValue::Vector(reflected))
763    }
764}
765
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps a `DataValue::Float`, failing the test for any other variant.
    fn float_of(value: DataValue) -> f64 {
        match value {
            DataValue::Float(f) => f,
            other => panic!("expected Float, got {:?}", other),
        }
    }

    /// Shorthand for building a `DataValue::Vector` from a slice literal.
    fn vector(components: &[f64]) -> DataValue {
        DataValue::Vector(components.to_vec())
    }

    #[test]
    fn test_parse_vector_string() {
        assert_eq!(parse_vector_string("[1,2,3]").unwrap(), vec![1.0, 2.0, 3.0]);
        assert_eq!(parse_vector_string("1 2 3").unwrap(), vec![1.0, 2.0, 3.0]);
        assert_eq!(
            parse_vector_string("1.5, 2.5, 3.5").unwrap(),
            vec![1.5, 2.5, 3.5]
        );
        assert_eq!(
            parse_vector_string("  [ 1 2 3 ]  ").unwrap(),
            vec![1.0, 2.0, 3.0]
        );
        assert!(parse_vector_string("[1,a,3]").is_err());
    }

    #[test]
    fn test_vec_function() {
        let func = VecFunction;
        let args = vec![
            DataValue::Integer(1),
            DataValue::Integer(2),
            DataValue::Integer(3),
        ];
        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Vector(vec![1.0, 2.0, 3.0]));

        let mixed = func
            .evaluate(&[DataValue::Integer(1), DataValue::Float(2.5)])
            .unwrap();
        assert_eq!(mixed, vector(&[1.0, 2.5]));

        assert!(func.evaluate(&[]).is_err());
        assert!(func
            .evaluate(&[DataValue::Integer(1), DataValue::Null])
            .is_err());
    }

    #[test]
    fn test_vec_add() {
        let func = VecAddFunction;
        let args = vec![
            DataValue::Vector(vec![1.0, 2.0, 3.0]),
            DataValue::Vector(vec![4.0, 5.0, 6.0]),
        ];
        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Vector(vec![5.0, 7.0, 9.0]));

        let mismatched = [vector(&[1.0, 2.0]), vector(&[1.0, 2.0, 3.0])];
        assert!(func.evaluate(&mismatched).is_err());
    }

    #[test]
    fn test_vec_sub() {
        let func = VecSubFunction;
        let args = [vector(&[10.0, 20.0, 30.0]), vector(&[1.0, 2.0, 3.0])];
        assert_eq!(func.evaluate(&args).unwrap(), vector(&[9.0, 18.0, 27.0]));
    }

    #[test]
    fn test_vec_scale() {
        let func = VecScaleFunction;
        let by_float = [vector(&[1.0, 2.0, 3.0]), DataValue::Float(2.5)];
        assert_eq!(func.evaluate(&by_float).unwrap(), vector(&[2.5, 5.0, 7.5]));

        let by_int = [vector(&[1.0, 2.0, 3.0]), DataValue::Integer(2)];
        assert_eq!(func.evaluate(&by_int).unwrap(), vector(&[2.0, 4.0, 6.0]));

        let by_null = [vector(&[1.0, 2.0, 3.0]), DataValue::Null];
        assert!(func.evaluate(&by_null).is_err());
    }

    #[test]
    fn test_vec_mag() {
        let func = VecMagFunction;
        let args = vec![DataValue::Vector(vec![3.0, 4.0])];
        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Float(5.0));
    }

    #[test]
    fn test_vec_dot() {
        let func = VecDotFunction;
        let args = vec![
            DataValue::Vector(vec![1.0, 2.0, 3.0]),
            DataValue::Vector(vec![4.0, 5.0, 6.0]),
        ];
        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Float(32.0)); // 1*4 + 2*5 + 3*6 = 32
    }

    #[test]
    fn test_vec_normalize() {
        let func = VecNormalizeFunction;
        assert_eq!(
            func.evaluate(&[vector(&[3.0, 4.0])]).unwrap(),
            vector(&[0.6, 0.8])
        );
        assert!(func.evaluate(&[vector(&[0.0, 0.0])]).is_err());
    }

    #[test]
    fn test_vec_distance() {
        let func = VecDistanceFunction;
        let args = [vector(&[0.0, 0.0]), vector(&[3.0, 4.0])];
        assert_eq!(func.evaluate(&args).unwrap(), DataValue::Float(5.0));
    }

    #[test]
    fn test_vec_cross() {
        let func = VecCrossFunction;
        let args = vec![
            DataValue::Vector(vec![1.0, 0.0, 0.0]),
            DataValue::Vector(vec![0.0, 1.0, 0.0]),
        ];
        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Vector(vec![0.0, 0.0, 1.0]));

        let flat = [vector(&[1.0, 0.0]), vector(&[0.0, 1.0])];
        assert!(func.evaluate(&flat).is_err());
    }

    #[test]
    fn test_vec_angle() {
        let func = VecAngleFunction;

        let right = float_of(
            func.evaluate(&[vector(&[1.0, 0.0]), vector(&[0.0, 1.0])])
                .unwrap(),
        );
        assert!((right - std::f64::consts::FRAC_PI_2).abs() < 1e-12);

        let parallel = float_of(
            func.evaluate(&[vector(&[1.0, 0.0]), vector(&[2.0, 0.0])])
                .unwrap(),
        );
        assert!(parallel.abs() < 1e-12);

        assert!(func
            .evaluate(&[vector(&[0.0, 0.0]), vector(&[1.0, 0.0])])
            .is_err());
    }

    #[test]
    fn test_line_intersect() {
        let func = LineIntersectFunction;
        let crossing = [
            vector(&[0.0, 0.0]),
            vector(&[4.0, 4.0]),
            vector(&[0.0, 4.0]),
            vector(&[4.0, 0.0]),
        ];
        assert_eq!(func.evaluate(&crossing).unwrap(), vector(&[2.0, 2.0]));

        let parallel = [
            vector(&[0.0, 0.0]),
            vector(&[1.0, 0.0]),
            vector(&[0.0, 1.0]),
            vector(&[1.0, 1.0]),
        ];
        assert_eq!(func.evaluate(&parallel).unwrap(), DataValue::Null);
    }

    #[test]
    fn test_segment_intersect() {
        let func = SegmentIntersectFunction;
        let crossing = [
            vector(&[0.0, 0.0]),
            vector(&[2.0, 2.0]),
            vector(&[0.0, 2.0]),
            vector(&[2.0, 0.0]),
        ];
        assert_eq!(func.evaluate(&crossing).unwrap(), vector(&[1.0, 1.0]));

        // The supporting lines meet at (1.5, 1.5), beyond the end of the first segment.
        let miss = [
            vector(&[0.0, 0.0]),
            vector(&[1.0, 1.0]),
            vector(&[3.0, 0.0]),
            vector(&[2.0, 1.0]),
        ];
        assert_eq!(func.evaluate(&miss).unwrap(), DataValue::Null);
    }

    #[test]
    fn test_closest_point_on_line() {
        let func = ClosestPointOnLineFunction;
        let args = [
            vector(&[2.0, 2.0]),
            vector(&[0.0, 0.0]),
            vector(&[1.0, 0.0]),
        ];
        assert_eq!(func.evaluate(&args).unwrap(), vector(&[2.0, 0.0]));
    }

    #[test]
    fn test_point_line_distance() {
        let func = PointLineDistanceFunction;

        let planar = [
            vector(&[2.0, 2.0]),
            vector(&[0.0, 0.0]),
            vector(&[1.0, 0.0]),
        ];
        assert_eq!(func.evaluate(&planar).unwrap(), DataValue::Float(2.0));

        let spatial = [
            vector(&[0.0, 0.0, 5.0]),
            vector(&[0.0, 0.0, 0.0]),
            vector(&[1.0, 0.0, 0.0]),
        ];
        assert_eq!(func.evaluate(&spatial).unwrap(), DataValue::Float(5.0));

        let four_d = [
            vector(&[1.0, 2.0, 3.0, 4.0]),
            vector(&[0.0, 0.0, 0.0, 0.0]),
            vector(&[1.0, 0.0, 0.0, 0.0]),
        ];
        assert!(func.evaluate(&four_d).is_err());
    }

    #[test]
    fn test_line_reflect_point() {
        let func = LineReflectPointFunction;
        let args = [
            vector(&[2.0, 2.0]),
            vector(&[0.0, 0.0]),
            vector(&[1.0, 0.0]),
        ];
        assert_eq!(func.evaluate(&args).unwrap(), vector(&[2.0, -2.0]));
    }
}
832}