1use specta::{
2 Types,
3 datatype::{DataType, Fields, NamedReferenceType, Reference},
4};
5
6#[derive(Debug, Clone, Default)]
49pub struct Remapper {
50 rules: Vec<(DataType, DataType)>,
51}
52
53impl Remapper {
54 pub fn new() -> Self {
56 Self::default()
57 }
58
59 pub fn rule(mut self, from: DataType, to: DataType) -> Self {
63 self.rules.push((from, to));
64 self
65 }
66
67 pub fn remap_dt(&self, mut dt: DataType) -> DataType {
69 self.remap_internal(&mut dt);
70 dt
71 }
72
73 pub fn remap_types(&self, types: Types) -> Types {
75 types.map(|mut ndt| {
76 ndt.generics.to_mut().iter_mut().for_each(|generic| {
77 if let Some(dt) = &mut generic.default {
78 self.remap_internal(dt);
79 }
80 });
81 if let Some(dt) = &mut ndt.ty {
82 self.remap_internal(dt);
83 }
84 ndt
85 })
86 }
87
88 fn remap_internal(&self, dt: &mut DataType) {
89 self.remap_rules(dt);
90
91 match dt {
92 DataType::Primitive(_) | DataType::Generic(_) => {}
93 DataType::List(list) => self.remap_internal(&mut list.ty),
94 DataType::Map(map) => {
95 self.remap_internal(map.key_ty_mut());
96 self.remap_internal(map.value_ty_mut());
97 }
98 DataType::Struct(s) => self.remap_fields(&mut s.fields),
99 DataType::Enum(e) => {
100 for (_, variant) in &mut e.variants {
101 self.remap_fields(&mut variant.fields);
102 }
103 }
104 DataType::Tuple(tuple) => {
105 for dt in &mut tuple.elements {
106 self.remap_internal(dt);
107 }
108 }
109 DataType::Nullable(dt) => self.remap_internal(dt),
110 DataType::Intersection(dts) => {
111 for dt in dts {
112 self.remap_internal(dt);
113 }
114 }
115 DataType::Reference(r) => self.remap_reference(r),
116 }
117 }
118
119 fn remap_rules(&self, dt: &mut DataType) {
120 for (from, to) in &self.rules {
121 if *dt == *from {
122 *dt = to.clone();
123 }
124 }
125 }
126
127 fn remap_fields(&self, fields: &mut Fields) {
128 match fields {
129 Fields::Unit => {}
130 Fields::Unnamed(fields) => {
131 for field in &mut fields.fields {
132 if let Some(dt) = &mut field.ty {
133 self.remap_internal(dt);
134 }
135 }
136 }
137 Fields::Named(fields) => {
138 for (_, field) in &mut fields.fields {
139 if let Some(dt) = &mut field.ty {
140 self.remap_internal(dt);
141 }
142 }
143 }
144 }
145 }
146
147 fn remap_reference(&self, reference: &mut Reference) {
148 let Reference::Named(reference) = reference else {
149 return;
150 };
151
152 match &mut reference.inner {
153 NamedReferenceType::Recursive(_) => {}
154 NamedReferenceType::Inline { dt, .. } => self.remap_internal(dt),
155 NamedReferenceType::Reference { generics, .. } => {
156 for (_, dt) in generics {
157 self.remap_internal(dt);
158 }
159 }
160 }
161 }
162}
163
#[cfg(test)]
mod tests {
    use specta::{
        Types,
        datatype::{DataType, Field, List, NamedDataType, Primitive, Struct, Tuple},
    };

    use super::Remapper;

    /// Wraps `elements` in a `DataType::Tuple`.
    fn tuple_of(elements: Vec<DataType>) -> DataType {
        DataType::Tuple(Tuple::new(elements))
    }

    /// Wraps `element` in a `DataType::List`.
    fn list_of(element: DataType) -> DataType {
        DataType::List(List::new(element))
    }

    #[test]
    fn remaps_multiple_rules_in_one_crawl() {
        let remapper = Remapper::new()
            .rule(Primitive::u32.into(), Primitive::str.into())
            .rule(Primitive::i32.into(), Primitive::bool.into());

        let input = tuple_of(vec![Primitive::u32.into(), Primitive::i32.into()]);
        let expected = tuple_of(vec![Primitive::str.into(), Primitive::bool.into()]);

        assert_eq!(remapper.remap_dt(input), expected);
    }

    #[test]
    fn rules_are_piped_in_registration_order() {
        // The output of `u32 -> i32` must be visible to the following `i32 -> bool` rule.
        let remapper = Remapper::new()
            .rule(Primitive::u32.into(), Primitive::i32.into())
            .rule(Primitive::i32.into(), Primitive::bool.into());

        let expected: DataType = Primitive::bool.into();
        assert_eq!(remapper.remap_dt(Primitive::u32.into()), expected);
    }

    #[test]
    fn replacement_is_recrawled() {
        // The list produced by the first rule contains an `i32`, which the second rule
        // must still reach.
        let remapper = Remapper::new()
            .rule(Primitive::u32.into(), list_of(Primitive::i32.into()))
            .rule(Primitive::i32.into(), Primitive::bool.into());

        assert_eq!(
            remapper.remap_dt(Primitive::u32.into()),
            list_of(Primitive::bool.into()),
        );
    }

    #[test]
    fn remaps_named_type_bodies() {
        let mut types = Types::default();
        NamedDataType::new("User", &mut types, |_, ndt| {
            let body = Struct::named()
                .field("id", Field::new(Primitive::u32.into()))
                .field("active", Field::new(Primitive::i32.into()))
                .build();
            ndt.ty = Some(body);
        });

        let remapped = Remapper::new()
            .rule(Primitive::u32.into(), Primitive::str.into())
            .rule(Primitive::i32.into(), Primitive::bool.into())
            .remap_types(types);

        let rendered = format!("{remapped:?}");
        for needle in ["Primitive(str)", "Primitive(bool)"] {
            assert!(rendered.contains(needle), "missing `{needle}` in {rendered}");
        }
    }
}