1use std::{fmt, panic::Location, str::FromStr, vec};
2
3mod macros;
4
5use lazy_static::lazy_static;
6use regex::Regex;
7use syn::{GenericArgument, PathArguments, Type};
8
lazy_static! {
    /// Matches a complete `Option<...>` type string; capture group 1 is the inner type text.
    static ref OPTION_REGEX: Regex = Regex::new(r"^Option<(.+)>$").unwrap();
    /// Matches a complete `Vec<...>` type string; capture group 1 is the inner type text.
    static ref VEC_REGEX: Regex = Regex::new(r"^Vec<(.+)>$").unwrap();
}
15
/// A TypeScript type expression.
///
/// The TypeScript text each variant renders to (see the `Display` impl) is
/// noted on the variant.
#[derive(Clone, Eq, PartialEq, Hash)]
pub enum TsType {
    /// A name or literal written verbatim (trimmed on output), e.g. `string`
    /// or `` `0x${string}` ``.
    Base(String),
    /// An array of the inner type: `T[]`.
    Array(Box<TsType>),
    /// A parenthesised type: `(T)`.
    Paren(Box<TsType>),
    /// A tuple: `[A, B, ...]`.
    Tuple(Vec<TsType>),
    /// A union: `A | B | ...`.
    Union(Vec<TsType>),
    /// An intersection: `A & B & ...`.
    Intersection(Vec<TsType>),
    /// A generic application `Base<A, B, ...>` (base type, then arguments).
    Generic(Box<TsType>, Vec<TsType>),
    /// An indexed access type `Object[Key]` (object type, then key type).
    IndexedAccess(Box<TsType>, Box<TsType>),
}
35
36impl Default for TsType {
37 fn default() -> Self {
38 TsType::Base("any".to_string())
39 }
40}
41
/// Error returned when a Rust or TypeScript type cannot be converted or parsed.
pub struct TsTypeError {
    /// Human-readable description of what went wrong.
    pub message: String,
    /// Where the error was raised, as text. Presumably rendered from the
    /// caller's `std::panic::Location` by `type_error_at!` — TODO confirm.
    pub location: String,
}
48
49impl fmt::Display for TsTypeError {
50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51 write!(
52 f,
53 "TypeError: {}\n Location: {}",
54 self.message, self.location
55 )
56 }
57}
58
59impl fmt::Debug for TsTypeError {
60 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61 write!(
62 f,
63 "TypeError: {}\n Location: {}",
64 self.message, self.location
65 )
66 }
67}
68
69impl TsType {
72 pub fn is_union_with(&self, other: &Self) -> bool {
74 match self {
75 Self::Union(types) => types.iter().any(|ty| ty == other),
76 _ => false,
77 }
78 }
79
80 pub fn contains(&self, other: &Self) -> bool {
83 match self {
84 Self::Base(name) => match other {
85 Self::Base(other_name) => name == other_name,
87 _ => false,
89 },
90 Self::Array(inner) => inner.contains(other),
91 Self::Paren(inner) => inner.contains(other),
92 Self::IndexedAccess(base, key) => base.contains(other) || key.contains(other),
93 Self::Generic(base, args) => {
94 if base.contains(other) {
95 return true;
96 }
97 for arg in args {
98 if arg.contains(other) {
99 return true;
100 }
101 }
102 false
103 }
104 Self::Union(types) => types.iter().any(|ty| ty.contains(other)),
105 Self::Intersection(types) => types.iter().any(|ty| ty.contains(other)),
106 Self::Tuple(types) => types.iter().any(|ty| ty.contains(other)),
107 }
108 }
109
110 pub fn as_generic(self, args: Vec<Self>) -> Self {
112 match self {
113 Self::Base(_) => Self::Generic(Box::new(self), args),
114 Self::IndexedAccess(_, _) => Self::Generic(Box::new(self), args),
115 _ => panic!("Type can't be generic."),
116 }
117 }
118
119 pub fn property(self, key: Self) -> Self {
121 Self::IndexedAccess(Box::new(self), Box::new(key))
122 }
123
124 pub fn in_array(self) -> Self {
126 Self::Array(Box::new(self))
127 }
128
129 pub fn in_parens(self) -> Self {
131 match self {
132 Self::Intersection(_) => Self::Paren(Box::new(self)),
133 Self::Union(_) => Self::Paren(Box::new(self)),
134 _ => self,
136 }
137 }
138
139 pub fn or(self, other: Self) -> Self {
141 match self {
142 Self::Union(mut types) => match other {
143 Self::Union(mut other_types) => {
145 types.append(&mut other_types);
146 Self::Union(types)
147 }
148 _ => {
150 types.push(other);
151 Self::Union(types)
152 }
153 },
154 _ => match other {
155 Self::Union(mut other_types) => {
157 other_types.insert(0, self);
158 Self::Union(other_types)
159 }
160 _ => Self::Union(vec![self, other]),
162 },
163 }
164 }
165
166 pub fn and(self, other: Self) -> Self {
168 match self {
169 Self::Intersection(mut types) => match other {
170 Self::Intersection(mut other_types) => {
172 types.append(&mut other_types);
173 Self::Intersection(types)
174 }
175 _ => {
177 types.push(other);
178 Self::Intersection(types)
179 }
180 },
181 _ => match other {
182 Self::Intersection(mut types) => {
184 types.insert(0, self);
185 Self::Intersection(types)
186 }
187 _ => Self::Intersection(vec![self, other]),
189 },
190 }
191 }
192
193 pub fn join(self, other: Self) -> Result<Self, &'static str> {
196 match self {
197 Self::Base(mut str) => match other {
198 Self::Base(other_str) => {
200 str.push_str(&other_str);
201 Ok(Self::Base(str))
202 }
203 _ => other.join(Self::Base(str)),
205 },
206 Self::Generic(ty, mut args) => {
208 args.push(other);
209 Ok(Self::Generic(ty, args))
210 }
211 Self::IndexedAccess(object, key) => {
213 let key_inner = *key;
214 Ok(Self::IndexedAccess(
215 object,
216 Box::new(key_inner.join(other)?),
217 ))
218 }
219 Self::Union(mut types) => {
220 match other {
221 Self::Union(mut other_types) => {
223 types.append(&mut other_types);
224 Ok(Self::Union(types))
225 }
226 _ => {
228 types.push(other);
229 Ok(Self::Union(types))
230 }
231 }
232 }
233 Self::Intersection(mut types) => {
234 match other {
235 Self::Intersection(mut other_types) => {
237 types.append(&mut other_types);
238 Ok(Self::Intersection(types))
239 }
240 Self::Union(mut union_types) => {
242 let first_member = union_types.remove(0);
243 let intersection = Self::Intersection(types);
244 let intersected_member = intersection.and(first_member);
245 union_types.insert(0, intersected_member);
246 Ok(Self::Union(union_types))
247 }
248 _ => {
250 types.push(other);
251 Ok(Self::Intersection(types))
252 }
253 }
254 }
255 Self::Tuple(mut types) => {
257 types.push(other);
258 Ok(Self::Tuple(types))
259 }
260 Self::Paren(inner) => inner.join(other),
262 _ => Err("Type does not support joining."),
263 }
264 }
265
266 #[track_caller]
283 pub fn from_ts_str(str: &str) -> Result<Self, TsTypeError> {
284 let location = Location::caller();
285
286 if str.is_empty() {
287 return Err(type_error_at!(location, "Empty string."));
288 }
289
290 let mut stacks: Vec<Vec<Self>> = vec![];
291 let mut pending_stack: Vec<Self> = vec![];
292 let mut pending_type: Option<Self> = None;
293 let mut ambiguous_bracket = false;
294 let chars = str.trim().chars();
295
296 for char in chars {
297 if ambiguous_bracket && char != ']' {
298 pending_stack.push(pending_type.unwrap().property(TsType::Base("".to_string())));
299 pending_type = None;
300 ambiguous_bracket = false;
301 }
302 match char {
303 ' ' => continue,
304 '|' => {
305 let member = match pending_type {
308 Some(ty) => vec![ty],
309 None => vec![],
310 };
311 let mut _union = Self::Union(member);
312 pending_stack.push(_union);
313 pending_type = None;
314 }
315 '&' => {
316 let member = match pending_type {
319 Some(ty) => vec![ty],
320 None => vec![],
321 };
322 let intersection = Self::Intersection(member);
323 pending_stack.push(intersection);
324 pending_type = None;
325 }
326 '<' => {
327 if pending_type.is_none() {
329 return Err(type_error_at!(location, "Unexpected `<` found."));
330 }
331 let inner = pending_type.unwrap();
332 let generic = inner.as_generic(vec![]);
333 pending_stack.push(generic);
334 pending_type = None;
335 }
336 ',' => {
337 if pending_type.is_none() {
340 return Err(type_error_at!(location, "Unexpected `,` found."));
341 }
342 let mut inner = pending_type.unwrap();
343
344 loop {
345 let top = pending_stack.pop().unwrap();
346 inner = top.join(inner).unwrap();
347
348 match inner {
349 Self::Generic(_, _) => break,
350 Self::IndexedAccess(_, _) => break,
351 Self::Tuple(_) => break,
352 _ => {}
353 }
354
355 if pending_stack.is_empty() {
356 return Err(type_error_at!(location, "Unexpected `,` found."));
357 }
358 }
359 pending_stack.push(inner);
360 pending_type = None;
361 }
362 '>' => {
363 if pending_type.is_none() {
366 return Err(type_error_at!(location, "Unexpected `>` found."));
367 };
368 let mut ty = pending_type.unwrap();
369 loop {
370 let top = pending_stack.pop().unwrap();
371 ty = top.join(ty).unwrap();
372
373 if let Self::Generic(_, _) = ty {
374 break;
375 }
376
377 if pending_stack.is_empty() {
378 return Err(type_error_at!(location, "Unexpected `,` found."));
379 }
380 }
381 pending_type = Some(ty);
382 }
383 '[' => {
384 if pending_type.is_none() {
385 let tuple = Self::Tuple(vec![]);
387 pending_stack.push(tuple);
388 } else {
389 ambiguous_bracket = true;
392 }
393 }
394 ']' => {
395 if pending_type.is_none() {
396 return Err(type_error_at!(location, "Unexpected `]` found."));
397 };
398 let mut ty = pending_type.unwrap();
399
400 if ambiguous_bracket {
403 pending_type = Some(ty.in_array());
404 ambiguous_bracket = false;
405 } else {
406 loop {
407 let top = pending_stack.pop().unwrap();
408 ty = top.join(ty).unwrap();
409
410 match ty {
411 Self::IndexedAccess(_, _) => break,
412 Self::Tuple(_) => break,
413 _ => {}
414 }
415
416 if pending_stack.is_empty() {
417 return Err(type_error_at!(location, "Unexpected `]` found."));
418 }
419 }
420 pending_type = Some(ty);
421 }
422 }
423 '(' => {
424 if pending_type.is_some() {
425 return Err(type_error_at!(location, "Unexpected `(` found."));
426 };
427
428 stacks.push(pending_stack);
430 pending_stack = vec![];
431 }
432 ')' => {
433 if pending_type.is_none() {
436 return Err(type_error_at!(location, "Unexpected `)` found."));
437 };
438 let mut inner = pending_type.unwrap();
439
440 for _ in 0..pending_stack.len() {
441 let top = pending_stack.pop().unwrap();
442 inner = top.join(inner).unwrap();
443 }
444
445 pending_type = Some(inner.in_parens());
446 pending_stack = stacks.pop().unwrap();
447 }
448 part => {
449 match pending_type {
450 Some(pending) => {
451 let next = Self::Base(part.to_string());
453 pending_type = Some(pending.join(next).unwrap());
454 }
455 None => {
456 pending_type = Some(Self::Base(part.to_string()));
458 }
459 }
460 }
461 }
462 }
463
464 let mut final_ty = pending_type.unwrap_or_else(|| pending_stack.pop().unwrap());
465
466 while let Some(top) = pending_stack.pop() {
467 final_ty = top.join(final_ty).unwrap();
468 }
469
470 Ok(final_ty)
471 }
472}
473
474impl fmt::Display for TsType {
477 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
478 match self {
479 TsType::Base(name) => write!(f, "{}", name.trim()),
480 TsType::Array(ty) => match ty.as_ref() {
481 TsType::Union(_) => write!(f, "({})[]", ty),
484 TsType::Intersection(_) => write!(f, "({})[]", ty),
485 _ => write!(f, "{}[]", ty.to_string()),
486 },
487 TsType::Paren(ty) => write!(f, "({})", ty.to_string()),
488 TsType::IndexedAccess(ty, key_ty) => {
489 write!(f, "{}[{}]", ty.to_string(), key_ty.to_string())
490 }
491 TsType::Generic(name, args) => {
492 let args = args
493 .iter()
494 .map(|ty| ty.to_string())
495 .collect::<Vec<_>>()
496 .join(", ");
497 write!(f, "{}<{}>", name, args)
498 }
499 TsType::Union(types) => {
500 let types = types
501 .iter()
502 .map(|ty| match ty {
503 TsType::Intersection(_) => format!("({})", ty),
506 _ => ty.to_string(),
507 })
508 .collect::<Vec<_>>()
509 .join(" | ");
510 write!(f, "{}", types)
511 }
512 TsType::Intersection(types) => {
513 let types = types
514 .iter()
515 .map(|ty| ty.to_string())
516 .collect::<Vec<_>>()
517 .join(" & ");
518 write!(f, "{}", types)
519 }
520 TsType::Tuple(types) => {
521 let types = types
522 .iter()
523 .map(|ty| ty.to_string())
524 .collect::<Vec<_>>()
525 .join(", ");
526 write!(f, "[{}]", types)
527 }
528 }
529 }
530}
531
532impl fmt::Debug for TsType {
533 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
534 match self {
535 TsType::Base(name) => write!(f, "Base({})", name.trim()),
536 TsType::Array(ty) => write!(f, "Array({:?})", ty),
537 TsType::Paren(ty) => write!(f, "Paren({:?})", ty),
538 TsType::IndexedAccess(ty, key_ty) => {
539 write!(f, "IndexedAccess({:?}[{:?}])", ty, key_ty)
540 }
541 TsType::Generic(name, args) => {
542 write!(f, "Generic(")?;
543 write!(f, "{:?}<", name)?;
544
545 for (i, arg) in args.iter().enumerate() {
546 write!(f, "{:?}", arg)?;
547 if i < args.len() - 1 {
548 write!(f, ", ")?;
549 }
550 }
551
552 write!(f, ">)")
553 }
554 TsType::Union(types) => {
555 write!(f, "Union(")?;
556
557 for (i, ty) in types.iter().enumerate() {
558 write!(f, "{:?}", ty)?;
559 if i < types.len() - 1 {
560 write!(f, " | ")?;
561 }
562 }
563
564 write!(f, ")")
565 }
566 TsType::Intersection(types) => {
567 write!(f, "Intersection(")?;
568
569 for (i, ty) in types.iter().enumerate() {
570 write!(f, "{:?}", ty)?;
571 if i < types.len() - 1 {
572 write!(f, " & ")?;
573 }
574 }
575
576 write!(f, ")")
577 }
578 TsType::Tuple(types) => {
579 write!(f, "Tuple([")?;
580
581 for (i, ty) in types.iter().enumerate() {
582 write!(f, "{:?}", ty)?;
583 if i < types.len() - 1 {
584 write!(f, ", ")?;
585 }
586 }
587
588 write!(f, "])")
589 }
590 }
591 }
592}
593
594impl TryFrom<&Type> for TsType {
597 type Error = TsTypeError;
598
599 #[track_caller]
600 fn try_from(ty: &Type) -> Result<Self, Self::Error> {
601 let rust_type_str = strip_type(ty)?;
602
603 if let Some(ts_type) = match_simple_type(&rust_type_str) {
605 return Ok(ts_type);
606 }
607
608 if let Some(captures) = OPTION_REGEX.captures(&rust_type_str) {
610 let inner_rust_type_str = &captures[1];
611 let ts_type = match_simple_type(inner_rust_type_str)
612 .unwrap_or(TsType::from_str(inner_rust_type_str)?);
613 return Ok(ts_type.or(ts_type!(undefined)));
614 }
615
616 if let Some(captures) = VEC_REGEX.captures(&rust_type_str) {
618 let inner_rust_type_str = &captures[1];
619 let ts_type = match_simple_type(inner_rust_type_str)
620 .unwrap_or(TsType::from_str(inner_rust_type_str)?);
621 return Ok(ts_type.in_array());
622 }
623
624 TsType::from_ts_str(&rust_type_str)
627 }
628}
629
630impl FromStr for TsType {
631 type Err = TsTypeError;
632
633 #[track_caller]
634 fn from_str(s: &str) -> Result<Self, Self::Err> {
635 TsType::from_ts_str(s)
636 }
637}
638
639pub trait ToTsType {
640 fn to_ts_type(&self) -> Result<TsType, TsTypeError>;
641}
642
643impl ToTsType for Type {
644 #[track_caller]
645 fn to_ts_type(&self) -> Result<TsType, TsTypeError> {
646 TsType::try_from(self)
647 }
648}
649
650impl ToTsType for &Type {
651 #[track_caller]
652 fn to_ts_type(&self) -> Result<TsType, TsTypeError> {
653 TsType::try_from(*self)
654 }
655}
656
657impl ToTsType for &str {
658 #[track_caller]
659 fn to_ts_type(&self) -> Result<TsType, TsTypeError> {
660 TsType::from_str(self)
661 }
662}
663
664impl ToTsType for String {
665 #[track_caller]
666 fn to_ts_type(&self) -> Result<TsType, TsTypeError> {
667 TsType::from_str(self.as_str())
668 }
669}
670
671#[track_caller]
675fn strip_type(ty: &Type) -> Result<String, TsTypeError> {
676 let location = Location::caller();
677 match ty {
678 Type::Group(group) => strip_type(&group.elem),
679 Type::Paren(paren) => strip_type(&paren.elem),
680 Type::Ptr(ptr) => strip_type(&ptr.elem),
681 Type::Reference(reference) => strip_type(&reference.elem),
682 Type::Slice(type_slice) => Ok(format!("[{}]", strip_type(&type_slice.elem)?)),
683 Type::Array(type_array) => Ok(format!("[{}; _]", strip_type(&type_array.elem)?)),
684 Type::Tuple(tuple) => {
685 if tuple.elems.is_empty() {
686 Ok("()".to_string())
687 } else {
688 let types = tuple
689 .elems
690 .iter()
691 .map(|elem| strip_type(elem))
692 .collect::<Result<Vec<_>, _>>()?
693 .join(", ");
694 Ok(format!("({})", types))
695 }
696 }
697 Type::Path(path) => {
698 let last_segment = path
699 .path
700 .segments
701 .last()
702 .ok_or_else(|| type_error_at!(location, "No segments found"))?;
703 let outer_type = last_segment.ident.to_string();
704
705 if last_segment.arguments.is_empty() {
706 Ok(outer_type)
707 } else {
708 let arguments = match &last_segment.arguments {
709 PathArguments::AngleBracketed(angle) => {
710 let args = angle
711 .args
712 .iter()
713 .map(|arg| match arg {
714 GenericArgument::Type(ty) => strip_type(ty),
715 _ => Err(type_error_at!(location, "Unsupported type argument.",)),
716 })
717 .collect::<Result<Vec<_>, _>>()?
718 .join(", ");
719 format!("<{}>", args)
720 }
721 PathArguments::Parenthesized(paren) => {
722 let inputs = paren
723 .inputs
724 .iter()
725 .map(strip_type)
726 .collect::<Result<Vec<_>, _>>()?
727 .join(", ");
728 format!("({})", inputs)
729 }
730 _ => String::new(),
731 };
732
733 Ok(format!("{}{}", last_segment.ident, arguments))
734 }
735 }
736
737 _ => Err(type_error_at!(location, "Unsupported type.")),
738 }
739}
740
741fn match_simple_type(rust_type: &str) -> Option<TsType> {
749 let simple_match = match rust_type {
750 "bool" => ts_type!(boolean),
752 "String" | "str" | "char" => ts_type!(string),
753 "u8" | "i8" | "u16" | "i16" | "u32" | "i32" | "f32" | "f64" => ts_type!(number),
755 "u64" | "i64" | "u128" | "i128" => ts_type!(bigint),
756
757 "U256" | "I256" => ts_type!(bigint),
759 "Address" => TsType::from_ts_str("`0x${string}`").unwrap(),
760
761 "BigInt" => ts_type!(bigint),
763 "Boolean" => ts_type!(boolean),
764 "JsString" => ts_type!(string),
765 "Number" => ts_type!(number),
766 "Object" => ts_type!(object),
767
768 "FixedPoint" => ts_type!(bigint),
770
771 _ => return None,
773 };
774 Some(simple_match)
775}
776
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the `Display` output of types built with `ts_type!` and
    /// `TsType::from_ts_str`.
    #[test]
    fn test_formatting() {
        let base = ts_type!(string);
        assert_eq!(base.to_string(), "string");

        // Spacing inside the macro input is deliberate; rustfmt must not touch it.
        #[rustfmt::skip]
        let array = ts_type!(string [ ]);
        assert_eq!(array.to_string(), "string[]");

        #[rustfmt::skip]
        let paren = ts_type!(( string | number ));
        assert_eq!(paren.to_string(), "(string | number)");

        #[rustfmt::skip]
        let generic = ts_type!(Set< string, number >);
        assert_eq!(generic.to_string(), "Set<string, number>");

        // A leading `|` before the first member is accepted.
        #[rustfmt::skip]
        let _union = ts_type!(
            | string
            | number
            | boolean
        );
        assert_eq!(_union.to_string(), "string | number | boolean");

        #[rustfmt::skip]
        let intersection = ts_type!( Foo & Bar & Baz);
        assert_eq!(intersection.to_string(), "Foo & Bar & Baz");

        // `&` binds tighter than `|`; intersection members of a union are
        // parenthesised on output.
        #[rustfmt::skip]
        let wrapped_intersection = ts_type!(Foo | Bar & Baz);
        assert_eq!(wrapped_intersection.to_string(), "Foo | (Bar & Baz)");

        // Template literal types survive a parse/print round trip.
        let template_string = TsType::from_ts_str("`0x${string}`");
        assert_eq!(template_string.unwrap().to_string(), "`0x${string}`");
    }

    /// Checks that `(#var)` in `ts_type!` splices in a previously built
    /// `TsType` variable, alone and in generic, union and intersection
    /// positions.
    #[test]
    fn test_variable_parsing() {
        let base = ts_type!(string);
        let generic = ts_type!(Set<string>);
        let group = ts_type!((string | number));
        let intersection = ts_type!(string & number);
        let _union = ts_type!(string | number);

        // A lone interpolated variable renders exactly like the variable.
        let single = ts_type!((#base));
        assert_eq!(single.to_string(), "string",);

        let single_generic = ts_type!((#generic));
        assert_eq!(single_generic.to_string(), "Set<string>");

        let single_group = ts_type!((#group));
        assert_eq!(single_group.to_string(), "(string | number)");

        let single_intersection = ts_type!((#intersection));
        assert_eq!(single_intersection.to_string(), "string & number");

        let single_union = ts_type!((#_union));
        assert_eq!(single_union.to_string(), "string | number");

        // Interpolated variables as generic arguments.
        let generic = ts_type!(Set<(#base)>);
        assert_eq!(generic.to_string(), "Set<string>");

        let generic_two = ts_type!(Set<(#base), (#_union)>);
        assert_eq!(generic_two.to_string(), "Set<string, string | number>");

        // Interpolated variables at each position of a union.
        let start_union = ts_type!((#base) | true | false);
        assert_eq!(start_union.to_string(), "string | true | false");

        let mid_union = ts_type!(true | (#base) | false);
        assert_eq!(mid_union.to_string(), "true | string | false");

        let end_union = ts_type!(true | false | (#base));
        assert_eq!(end_union.to_string(), "true | false | string");

        let start_union_pair = ts_type!((#base) | true);
        assert_eq!(start_union_pair.to_string(), "string | true");

        let end_union_pair = ts_type!(true | (#base));
        assert_eq!(end_union_pair.to_string(), "true | string");

        let var_union = ts_type!((#base) | (#generic) | (#group));
        assert_eq!(
            var_union.to_string(),
            "string | Set<string> | (string | number)"
        );

        let var_union_pair = ts_type!((#base) | (#generic));
        assert_eq!(var_union_pair.to_string(), "string | Set<string>");

        // Interpolated variables at each position of an intersection.
        let start_intersection = ts_type!((#base) & true & false);
        assert_eq!(start_intersection.to_string(), "string & true & false");

        let mid_intersection = ts_type!(true & (#base) & false);
        assert_eq!(mid_intersection.to_string(), "true & string & false");

        let end_intersection = ts_type!(true & false & (#base));
        assert_eq!(end_intersection.to_string(), "true & false & string");

        let start_intersection_pair = ts_type!((#base) & true);
        assert_eq!(start_intersection_pair.to_string(), "string & true");

        let end_intersection_pair = ts_type!(true & (#base));
        assert_eq!(end_intersection_pair.to_string(), "true & string");

        let var_intersection = ts_type!((#base) & (#generic) & (#group));
        assert_eq!(
            var_intersection.to_string(),
            "string & Set<string> & (string | number)"
        );

        let var_intersection_pair = ts_type!((#base) & (#generic));
        assert_eq!(var_intersection_pair.to_string(), "string & Set<string>");
    }
}