1use serde::{Deserialize, Serialize};
4use std::fmt;
5
/// A value type as understood by this module's checks: coercion ([`Type::can_coerce_to`]),
/// operator result typing ([`Type::binary_op_result`], [`Type::unary_op_result`]), property
/// lookup ([`Type::property_type`]) and conversion to/from the inference engine's
/// `crate::type_system::Type` ([`Type::to_inference_type`], [`Type::from_inference_type`]).
///
/// NOTE(review): the derived `Hash` and any index-based serde encoding depend on variant
/// order, so avoid reordering variants without checking persisted/serialized data.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Type {
    /// Numeric value; `number`, `Number`, `f64` and `float` annotations all map here.
    Number,
    /// Text value. `Number` coerces to `String`.
    String,
    /// Boolean value; the result of comparisons and of logical `And`/`Or`.
    Bool,
    /// The type of the null value. `NullCoalesce` with a `Null` left operand has the
    /// right operand's type.
    Null,
    /// Absence of a value; converts to the inference engine's `void`.
    Unit,
    /// Color value (`color` annotation).
    Color,
    /// Point in time (`timestamp` annotation). `Timestamp - Timestamp` is a `Duration`.
    Timestamp,
    /// `timeframe` annotation; its runtime meaning is defined elsewhere.
    Timeframe,
    /// `timeref` annotation; its runtime meaning is defined elsewhere.
    TimeRef,
    /// Length of time (`duration` annotation). Adds/subtracts with `Duration` and
    /// `Timestamp`, and scales by `Number`.
    Duration,
    /// `pattern` annotation; its runtime meaning is defined elsewhere.
    Pattern,
    /// Structural object type as ordered `(field name, field type)` pairs. An empty list is
    /// treated as an object of unknown shape.
    Object(Vec<(std::string::String, Type)>),
    /// Homogeneous array; displayed as `Vec<T>`.
    Array(Box<Type>),
    /// Matrix of elements (`Mat<T>`); an `Array` of `Array`s coerces to/from it.
    Matrix(Box<Type>),
    /// Column of elements (`Column<T>`); arithmetic and comparison with a `Number` scalar
    /// yield a `Column`.
    Column(Box<Type>),
    /// Function type with positional parameter types and a return type.
    Function {
        params: Vec<Type>,
        returns: Box<Type>,
    },
    /// A module value (`module` annotation).
    Module,
    /// Range over elements (`Range<T>`).
    Range(Box<Type>),
    /// Result type carrying only its success type (`Result<T>`); produced by `ErrorContext`.
    Result(Box<Type>),
    /// Type could not be determined; coerces to and from every type and maps to the
    /// inference engine's `any`.
    Unknown,
    /// Type of an ill-typed expression; coerces to and from every type (presumably so one
    /// error does not cascade — confirm) and maps to `Never`.
    Error,
}
56
57impl Type {
58 pub fn can_coerce_to(&self, other: &Type) -> bool {
60 match (self, other) {
61 (a, b) if a == b => true,
63
64 (Type::Unknown, _) => true,
66 (_, Type::Unknown) => true,
67
68 (Type::Error, _) | (_, Type::Error) => true,
70
71 (Type::Number, Type::String) => true, (Type::Array(a), Type::Array(b)) => a.can_coerce_to(b),
76 (Type::Matrix(a), Type::Matrix(b)) => a.can_coerce_to(b),
77 (Type::Array(rows), Type::Matrix(elem)) => match rows.as_ref() {
79 Type::Array(inner) => inner.can_coerce_to(elem),
80 _ => false,
81 },
82 (Type::Matrix(elem), Type::Array(rows)) => match rows.as_ref() {
83 Type::Array(inner) => elem.can_coerce_to(inner),
84 _ => false,
85 },
86 (Type::Column(a), Type::Column(b)) => a.can_coerce_to(b),
87 (Type::Range(a), Type::Range(b)) => a.can_coerce_to(b),
88 (Type::Result(a), Type::Result(b)) => a.can_coerce_to(b),
89
90 (Type::Object(a_fields), Type::Object(b_fields)) => {
92 if b_fields.is_empty() {
94 return true;
95 }
96 b_fields.iter().all(|(b_name, b_type)| {
98 a_fields
99 .iter()
100 .find(|(a_name, _)| a_name == b_name)
101 .map(|(_, a_type)| a_type.can_coerce_to(b_type))
102 .unwrap_or(false)
103 })
104 }
105
106 _ => false,
107 }
108 }
109
110 pub fn binary_op_result(&self, op: &shape_ast::ast::BinaryOp, rhs: &Type) -> Type {
112 use shape_ast::ast::BinaryOp;
113
114 match op {
115 BinaryOp::Add
117 | BinaryOp::Sub
118 | BinaryOp::Mul
119 | BinaryOp::Div
120 | BinaryOp::Mod
121 | BinaryOp::Pow => {
122 match (self, rhs) {
123 (Type::Number, Type::Number) => Type::Number,
124 (Type::String, Type::String) if matches!(op, BinaryOp::Add) => Type::String,
125 (Type::Timestamp, Type::Timestamp) if matches!(op, BinaryOp::Sub) => {
127 Type::Duration
128 }
129 (Type::Timestamp, Type::Duration)
130 if matches!(op, BinaryOp::Add | BinaryOp::Sub) =>
131 {
132 Type::Timestamp
133 }
134 (Type::Duration, Type::Timestamp) if matches!(op, BinaryOp::Add) => {
135 Type::Timestamp
136 }
137 (Type::Duration, Type::Duration)
139 if matches!(op, BinaryOp::Add | BinaryOp::Sub) =>
140 {
141 Type::Duration
142 }
143 (Type::Duration, Type::Number)
144 if matches!(op, BinaryOp::Mul | BinaryOp::Div) =>
145 {
146 Type::Duration
147 }
148 (Type::Number, Type::Duration) if matches!(op, BinaryOp::Mul) => Type::Duration,
149 (Type::Column(a), Type::Column(b)) if a == b => Type::Column(a.clone()),
151 (Type::Column(elem), Type::Number) | (Type::Number, Type::Column(elem))
152 if **elem == Type::Number =>
153 {
154 Type::Column(elem.clone())
155 }
156 (Type::Matrix(elem), Type::Array(vec_elem))
158 if matches!(op, BinaryOp::Mul)
159 && **elem == Type::Number
160 && **vec_elem == Type::Number =>
161 {
162 Type::Array(Box::new(Type::Number))
163 }
164 (Type::Matrix(left_elem), Type::Matrix(right_elem))
165 if matches!(op, BinaryOp::Mul)
166 && **left_elem == Type::Number
167 && **right_elem == Type::Number =>
168 {
169 Type::Matrix(Box::new(Type::Number))
170 }
171 (Type::Unknown, Type::Number) | (Type::Number, Type::Unknown) => Type::Number,
173 (Type::Unknown, Type::Unknown) => Type::Unknown,
174 _ => Type::Error,
175 }
176 }
177
178 BinaryOp::Greater
180 | BinaryOp::Less
181 | BinaryOp::GreaterEq
182 | BinaryOp::LessEq
183 | BinaryOp::Equal
184 | BinaryOp::NotEqual => {
185 match (self, rhs) {
186 (Type::Number, Type::Number) => Type::Bool,
187 (Type::String, Type::String) => Type::Bool,
188 (Type::Bool, Type::Bool) => Type::Bool,
189 (Type::Color, Type::Color) => Type::Bool,
190 (Type::Timestamp, Type::Timestamp) => Type::Bool,
192 (Type::Duration, Type::Duration) => Type::Bool,
193 (Type::Column(a), Type::Column(b)) if a == b => {
195 Type::Column(Box::new(Type::Bool))
196 }
197 (Type::Column(elem), Type::Number) | (Type::Number, Type::Column(elem))
198 if **elem == Type::Number =>
199 {
200 Type::Column(Box::new(Type::Bool))
201 }
202 (Type::Unknown, Type::Number) | (Type::Number, Type::Unknown) => Type::Bool,
204 (Type::Unknown, Type::Unknown) => Type::Bool,
205 _ => Type::Error,
206 }
207 }
208
209 BinaryOp::FuzzyEqual | BinaryOp::FuzzyGreater | BinaryOp::FuzzyLess => {
211 match (self, rhs) {
212 (Type::Number, Type::Number) => Type::Bool,
213 _ => Type::Error,
214 }
215 }
216
217 BinaryOp::BitAnd
219 | BinaryOp::BitOr
220 | BinaryOp::BitXor
221 | BinaryOp::BitShl
222 | BinaryOp::BitShr => match (self, rhs) {
223 (Type::Number, Type::Number) => Type::Number,
224 _ => Type::Error,
225 },
226
227 BinaryOp::And | BinaryOp::Or => match (self, rhs) {
229 (Type::Bool, Type::Bool) => Type::Bool,
230 _ => Type::Error,
231 },
232
233 BinaryOp::NullCoalesce => {
235 match (self, rhs) {
237 (Type::Null, right) => right.clone(),
238 (left, _) => left.clone(),
239 }
240 }
241
242 BinaryOp::ErrorContext => match self {
243 Type::Result(inner) => Type::Result(inner.clone()),
244 Type::Null => Type::Result(Box::new(Type::Unknown)),
245 other => Type::Result(Box::new(other.clone())),
246 },
247
248 BinaryOp::Pipe => Type::Unknown,
250 }
251 }
252
253 pub fn unary_op_result(&self, op: &shape_ast::ast::UnaryOp) -> Type {
255 use shape_ast::ast::UnaryOp;
256
257 match op {
258 UnaryOp::Not => match self {
259 Type::Bool => Type::Bool,
260 _ => Type::Error,
261 },
262 UnaryOp::Neg => match self {
263 Type::Number => Type::Number,
264 _ => Type::Error,
265 },
266 UnaryOp::BitNot => match self {
267 Type::Number => Type::Number,
268 _ => Type::Error,
269 },
270 }
271 }
272
273 pub fn property_type(&self, property: &str) -> Type {
275 match self {
276 Type::Object(fields) => {
277 if fields.is_empty() {
279 Type::Unknown
281 } else {
282 fields
283 .iter()
284 .find(|(name, _)| name == property)
285 .map(|(_, ty)| ty.clone())
286 .unwrap_or(Type::Error)
287 }
288 }
289 _ => Type::Error,
290 }
291 }
292
293 pub fn to_inference_type(&self) -> crate::type_system::Type {
298 use crate::type_system::{BuiltinTypes, Type as InferenceType};
299 use shape_ast::ast::TypeAnnotation;
300
301 match self {
302 Type::Number => BuiltinTypes::number(),
303 Type::String => BuiltinTypes::string(),
304 Type::Bool => BuiltinTypes::boolean(),
305 Type::Null => BuiltinTypes::null(),
306 Type::Unit => BuiltinTypes::void(),
307 Type::Color => InferenceType::Concrete(TypeAnnotation::Basic("color".to_string())),
308 Type::Timestamp => {
309 InferenceType::Concrete(TypeAnnotation::Basic("timestamp".to_string()))
310 }
311 Type::Timeframe => {
312 InferenceType::Concrete(TypeAnnotation::Basic("timeframe".to_string()))
313 }
314 Type::TimeRef => InferenceType::Concrete(TypeAnnotation::Basic("timeref".to_string())),
315 Type::Duration => {
316 InferenceType::Concrete(TypeAnnotation::Basic("duration".to_string()))
317 }
318 Type::Pattern => BuiltinTypes::pattern(),
319 Type::Module => InferenceType::Concrete(TypeAnnotation::Basic("module".to_string())),
320 Type::Unknown => BuiltinTypes::any(),
321 Type::Error => InferenceType::Concrete(TypeAnnotation::Never),
322
323 Type::Array(elem) => BuiltinTypes::array(elem.to_inference_type()),
324 Type::Matrix(elem) => InferenceType::Generic {
325 base: Box::new(InferenceType::Concrete(TypeAnnotation::Reference(
326 "Mat".to_string(),
327 ))),
328 args: vec![elem.to_inference_type()],
329 },
330
331 Type::Column(elem) => InferenceType::Generic {
332 base: Box::new(InferenceType::Concrete(TypeAnnotation::Reference(
333 "Column".to_string(),
334 ))),
335 args: vec![elem.to_inference_type()],
336 },
337
338 Type::Range(elem) => InferenceType::Generic {
339 base: Box::new(InferenceType::Concrete(TypeAnnotation::Reference(
340 "Range".to_string(),
341 ))),
342 args: vec![elem.to_inference_type()],
343 },
344
345 Type::Result(ok_type) => InferenceType::Generic {
346 base: Box::new(InferenceType::Concrete(TypeAnnotation::Reference(
347 "Result".to_string(),
348 ))),
349 args: vec![ok_type.to_inference_type()],
350 },
351
352 Type::Object(fields) => {
353 let obj_fields: Vec<_> = fields
354 .iter()
355 .map(|(name, ty)| shape_ast::ast::ObjectTypeField {
356 name: name.clone(),
357 optional: false,
358 type_annotation: ty.to_type_annotation(),
359 annotations: vec![],
360 })
361 .collect();
362 InferenceType::Concrete(TypeAnnotation::Object(obj_fields))
363 }
364
365 Type::Function { params, returns } => {
366 let param_annotations: Vec<_> = params
367 .iter()
368 .map(|p| shape_ast::ast::FunctionParam {
369 name: None,
370 optional: false,
371 type_annotation: p.to_type_annotation(),
372 })
373 .collect();
374 InferenceType::Concrete(TypeAnnotation::Function {
375 params: param_annotations,
376 returns: Box::new(returns.to_type_annotation()),
377 })
378 }
379 }
380 }
381
382 pub fn from_inference_type(ty: &crate::type_system::Type) -> Self {
387 use crate::type_system::Type as InferenceType;
388 use shape_ast::ast::TypeAnnotation;
389
390 match ty {
391 InferenceType::Concrete(ann) => Self::from_type_annotation(ann),
392
393 InferenceType::Variable(_) => Type::Unknown,
394
395 InferenceType::Generic { base, args } => {
396 if let InferenceType::Concrete(TypeAnnotation::Reference(name)) = base.as_ref() {
398 match name.as_str() {
399 "Vec" if !args.is_empty() => {
400 Type::Array(Box::new(Self::from_inference_type(&args[0])))
401 }
402 "Mat" if !args.is_empty() => {
403 Type::Matrix(Box::new(Self::from_inference_type(&args[0])))
404 }
405 "Column" if !args.is_empty() => {
406 Type::Column(Box::new(Self::from_inference_type(&args[0])))
407 }
408 "Range" if !args.is_empty() => {
409 Type::Range(Box::new(Self::from_inference_type(&args[0])))
410 }
411 "Result" if !args.is_empty() => {
412 Type::Result(Box::new(Self::from_inference_type(&args[0])))
413 }
414 _ => Type::Unknown,
415 }
416 } else {
417 Type::Unknown
418 }
419 }
420
421 InferenceType::Constrained { .. } => Type::Unknown,
422 InferenceType::Function { params, returns } => {
423 let param_types: Vec<_> = params.iter().map(Self::from_inference_type).collect();
424 let return_type = Self::from_inference_type(returns);
425 Type::Function {
426 params: param_types,
427 returns: Box::new(return_type),
428 }
429 }
430 }
431 }
432
433 fn to_type_annotation(&self) -> shape_ast::ast::TypeAnnotation {
435 use shape_ast::ast::TypeAnnotation;
436
437 match self {
438 Type::Number => TypeAnnotation::Basic("number".to_string()),
439 Type::String => TypeAnnotation::Basic("string".to_string()),
440 Type::Bool => TypeAnnotation::Basic("bool".to_string()),
441 Type::Null => TypeAnnotation::Null,
442 Type::Unit => TypeAnnotation::Void,
443 Type::Color => TypeAnnotation::Basic("color".to_string()),
444 Type::Timestamp => TypeAnnotation::Basic("timestamp".to_string()),
445 Type::Timeframe => TypeAnnotation::Basic("timeframe".to_string()),
446 Type::TimeRef => TypeAnnotation::Basic("timeref".to_string()),
447 Type::Duration => TypeAnnotation::Basic("duration".to_string()),
448 Type::Pattern => TypeAnnotation::Basic("pattern".to_string()),
449 Type::Module => TypeAnnotation::Basic("module".to_string()),
450 Type::Unknown => TypeAnnotation::Any,
451 Type::Error => TypeAnnotation::Never,
452
453 Type::Array(elem) => TypeAnnotation::Array(Box::new(elem.to_type_annotation())),
454 Type::Matrix(elem) => TypeAnnotation::Generic {
455 name: "Mat".to_string(),
456 args: vec![elem.to_type_annotation()],
457 },
458
459 Type::Column(elem) => TypeAnnotation::Generic {
460 name: "Column".to_string(),
461 args: vec![elem.to_type_annotation()],
462 },
463
464 Type::Range(elem) => TypeAnnotation::Generic {
465 name: "Range".to_string(),
466 args: vec![elem.to_type_annotation()],
467 },
468
469 Type::Result(ok_type) => TypeAnnotation::Generic {
470 name: "Result".to_string(),
471 args: vec![ok_type.to_type_annotation()],
472 },
473
474 Type::Object(fields) => {
475 let obj_fields: Vec<_> = fields
476 .iter()
477 .map(|(name, ty)| shape_ast::ast::ObjectTypeField {
478 name: name.clone(),
479 optional: false,
480 type_annotation: ty.to_type_annotation(),
481 annotations: vec![],
482 })
483 .collect();
484 TypeAnnotation::Object(obj_fields)
485 }
486
487 Type::Function { params, returns } => {
488 let param_annotations: Vec<_> = params
489 .iter()
490 .map(|p| shape_ast::ast::FunctionParam {
491 name: None,
492 optional: false,
493 type_annotation: p.to_type_annotation(),
494 })
495 .collect();
496 TypeAnnotation::Function {
497 params: param_annotations,
498 returns: Box::new(returns.to_type_annotation()),
499 }
500 }
501 }
502 }
503
504 fn from_type_annotation(ann: &shape_ast::ast::TypeAnnotation) -> Self {
506 use shape_ast::ast::TypeAnnotation;
507
508 match ann {
509 TypeAnnotation::Basic(name) => match name.as_str() {
510 "number" | "Number" | "f64" | "float" => Type::Number,
511 "string" | "String" => Type::String,
512 "bool" | "boolean" | "Boolean" => Type::Bool,
513 "null" | "Null" => Type::Null,
514 "color" | "Color" => Type::Color,
515 "timestamp" | "Timestamp" => Type::Timestamp,
516 "timeframe" | "Timeframe" => Type::Timeframe,
517 "timeref" | "TimeRef" => Type::TimeRef,
518 "duration" | "Duration" => Type::Duration,
519 "pattern" | "Pattern" => Type::Pattern,
520 "module" | "Module" => Type::Module,
521 "object" => Type::Object(vec![]),
522 _ => Type::Unknown,
523 },
524
525 TypeAnnotation::Array(elem) => Type::Array(Box::new(Self::from_type_annotation(elem))),
526
527 TypeAnnotation::Generic { name, args } => match name.as_str() {
528 "Column" if !args.is_empty() => {
529 Type::Column(Box::new(Self::from_type_annotation(&args[0])))
530 }
531 "Vec" if !args.is_empty() => {
532 Type::Array(Box::new(Self::from_type_annotation(&args[0])))
533 }
534 "Mat" if !args.is_empty() => {
535 Type::Matrix(Box::new(Self::from_type_annotation(&args[0])))
536 }
537 "Range" if !args.is_empty() => {
538 Type::Range(Box::new(Self::from_type_annotation(&args[0])))
539 }
540 "Result" if !args.is_empty() => {
541 Type::Result(Box::new(Self::from_type_annotation(&args[0])))
542 }
543 _ => Type::Unknown,
544 },
545
546 TypeAnnotation::Reference(name) => match name.as_str() {
547 "number" | "Number" => Type::Number,
548 "string" | "String" => Type::String,
549 "bool" | "Bool" => Type::Bool,
550 _ => Type::Unknown,
551 },
552
553 TypeAnnotation::Object(fields) => {
554 let type_fields: Vec<_> = fields
555 .iter()
556 .map(|f| {
557 (
558 f.name.clone(),
559 Self::from_type_annotation(&f.type_annotation),
560 )
561 })
562 .collect();
563 Type::Object(type_fields)
564 }
565
566 TypeAnnotation::Function { params, returns } => {
567 let param_types: Vec<_> = params
568 .iter()
569 .map(|p| Self::from_type_annotation(&p.type_annotation))
570 .collect();
571 Type::Function {
572 params: param_types,
573 returns: Box::new(Self::from_type_annotation(returns)),
574 }
575 }
576
577 TypeAnnotation::Optional(inner) => {
578 Self::from_type_annotation(inner)
580 }
581
582 TypeAnnotation::Void => Type::Unit,
583 TypeAnnotation::Any => Type::Unknown,
584 TypeAnnotation::Never => Type::Error,
585 TypeAnnotation::Null | TypeAnnotation::Undefined => Type::Null,
586
587 TypeAnnotation::Tuple(_)
588 | TypeAnnotation::Union(_)
589 | TypeAnnotation::Intersection(_)
590 | TypeAnnotation::Dyn(_) => Type::Unknown,
591 }
592 }
593}
594
#[cfg(test)]
mod tests {
    use super::*;
    use crate::type_system::BuiltinTypes;
    use shape_ast::ast::BinaryOp;

    /// Converts `ty` to its inference-engine form and straight back.
    fn roundtrip(ty: &Type) -> Type {
        Type::from_inference_type(&ty.to_inference_type())
    }

    /// A boxed `Number`, the element type used by the container cases below.
    fn number_elem() -> Box<Type> {
        Box::new(Type::Number)
    }

    #[test]
    fn test_to_inference_type_primitives() {
        let expectations = [
            (Type::Number, BuiltinTypes::number()),
            (Type::String, BuiltinTypes::string()),
            (Type::Bool, BuiltinTypes::boolean()),
        ];
        for (ty, expected) in expectations {
            assert_eq!(ty.to_inference_type(), expected);
        }
    }

    #[test]
    fn test_to_inference_type_array() {
        assert_eq!(
            roundtrip(&Type::Array(number_elem())),
            Type::Array(number_elem())
        );
    }

    #[test]
    fn test_to_inference_type_matrix() {
        assert_eq!(
            roundtrip(&Type::Matrix(number_elem())),
            Type::Matrix(number_elem())
        );
    }

    #[test]
    fn test_to_inference_type_column() {
        assert_eq!(
            roundtrip(&Type::Column(number_elem())),
            Type::Column(number_elem())
        );
    }

    #[test]
    fn test_binary_op_result_matrix_mul() {
        let lhs = Type::Matrix(number_elem());

        let by_vector = lhs.binary_op_result(&BinaryOp::Mul, &Type::Array(number_elem()));
        assert_eq!(by_vector, Type::Array(number_elem()));

        let by_matrix = lhs.binary_op_result(&BinaryOp::Mul, &Type::Matrix(number_elem()));
        assert_eq!(by_matrix, Type::Matrix(number_elem()));
    }

    #[test]
    fn test_from_inference_type_primitives() {
        let expectations = [
            (BuiltinTypes::number(), Type::Number),
            (BuiltinTypes::string(), Type::String),
            (BuiltinTypes::boolean(), Type::Bool),
        ];
        for (inference, expected) in expectations {
            assert_eq!(Type::from_inference_type(&inference), expected);
        }
    }

    #[test]
    fn test_roundtrip_object_type() {
        let object = Type::Object(vec![
            ("x".to_string(), Type::Number),
            ("y".to_string(), Type::String),
        ]);
        assert_eq!(roundtrip(&object), object);
    }

    #[test]
    fn test_roundtrip_function_type() {
        let function = Type::Function {
            params: vec![Type::Number, Type::String],
            returns: Box::new(Type::Bool),
        };
        assert_eq!(roundtrip(&function), function);
    }
}
698
699impl fmt::Display for Type {
700 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
701 match self {
702 Type::Number => write!(f, "Number"),
703 Type::String => write!(f, "String"),
704 Type::Bool => write!(f, "Bool"),
705 Type::Null => write!(f, "Null"),
706 Type::Unit => write!(f, "Unit"),
707 Type::Color => write!(f, "Color"),
708 Type::Timestamp => write!(f, "Timestamp"),
709 Type::Timeframe => write!(f, "Timeframe"),
710 Type::TimeRef => write!(f, "TimeRef"),
711 Type::Duration => write!(f, "Duration"),
712 Type::Pattern => write!(f, "Pattern"),
713 Type::Object(fields) => {
714 write!(f, "{{")?;
715 for (i, (name, ty)) in fields.iter().enumerate() {
716 if i > 0 {
717 write!(f, ", ")?;
718 }
719 write!(f, "{}: {}", name, ty)?;
720 }
721 write!(f, "}}")
722 }
723 Type::Array(elem) => write!(f, "Vec<{}>", elem),
724 Type::Matrix(elem) => write!(f, "Mat<{}>", elem),
725 Type::Column(elem) => write!(f, "Column<{}>", elem),
726 Type::Function { params, returns } => {
727 write!(f, "(")?;
728 for (i, param) in params.iter().enumerate() {
729 if i > 0 {
730 write!(f, ", ")?;
731 }
732 write!(f, "{}", param)?;
733 }
734 write!(f, ") -> {}", returns)
735 }
736 Type::Module => write!(f, "Module"),
737 Type::Range(elem) => write!(f, "Range<{}>", elem),
738 Type::Result(ok_type) => write!(f, "Result<{}>", ok_type),
739 Type::Unknown => write!(f, "?"),
740 Type::Error => write!(f, "Error"),
741 }
742 }
743}