1use syn::Error;
2
3fn path_to_single_string(path: &syn::Path) -> Option<String> {
4 path.segments
5 .iter()
6 .next()
7 .map(|segment| segment.ident.to_string())
8}
9
/// One node of a parsed requirement expression: `any(…)`, `all(…)`, `not(…)` or a field.
#[derive(Debug)]
struct Group {
    /// The kind of node, together with its children.
    variant: GroupVariant,
    /// Cached `variant.n_init_impls()` (set in `Group::new`): how many requirement
    /// combinations this node expands to.
    n_init_impls: usize,
}
15
/// What one generated implementation demands of a single struct field.
///
/// `Display` renders the variants as `_`, `S` and `U`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Dependency {
    /// The requirement does not constrain this field (it stays generic).
    Generic,
    /// The field must be set (it appears un-negated in the requirement).
    Set,
    /// The field must be unset (it appears under `not`).
    Unset,
}
22
23impl std::fmt::Display for Dependency {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 match self {
26 Dependency::Generic => write!(f, "_"),
27 Dependency::Set => write!(f, "S"),
28 Dependency::Unset => write!(f, "U"),
29 }
30 }
31}
32
/// One generated implementation: one [`Dependency`] per struct field, indexed like the
/// `field_names` slice passed to [`impls`].
pub type Implementation = Vec<Dependency>;
34
35impl Group {
36 fn new(variant: GroupVariant) -> Self {
37 let n_init_impls = variant.n_init_impls();
38
39 Self {
40 variant,
41 n_init_impls,
42 }
43 }
44}
45
/// The kinds of node a requirement expression is built from.
#[derive(Debug)]
enum GroupVariant {
    /// `any(…)`: alternatives; its combinations are those of all arguments added together.
    Any(Vec<Group>),
    /// `all(…)`: conjunction; its combinations are the cartesian product of the arguments'.
    All(Vec<Group>),
    /// `not(…)`: negation of its single argument.
    Not(Box<Group>),
    /// A bare field name.
    Field(String),
}
53
54impl GroupVariant {
55 fn n_init_impls(&self) -> usize {
56 match self {
57 GroupVariant::Any(args) => args
59 .iter()
60 .fold(0, |acc, arg| acc + arg.variant.n_init_impls()),
61 GroupVariant::All(args) => args
63 .iter()
64 .fold(1, |acc, arg| acc * arg.variant.n_init_impls()),
65 GroupVariant::Not(arg) => arg.n_init_impls,
66 GroupVariant::Field(_) => 1,
67 }
68 }
69}
70
71impl TryFrom<&syn::Expr> for Group {
72 type Error = Error;
73
74 fn try_from(expr: &syn::Expr) -> Result<Self, Self::Error> {
75 match expr {
76 syn::Expr::Call(syn::ExprCall { func, args, .. }) => match func.as_ref() {
77 syn::Expr::Path(path) => {
78 let Some(path) = path_to_single_string(&path.path)
79 else {
80 return Err(Error::new_spanned(func, "Expected `any`, `all` or `not`"))
81 };
82
83 match path.as_str() {
84 "any" => Ok(Self::new(GroupVariant::Any(
85 args.into_iter()
86 .map(Self::try_from)
87 .collect::<Result<Vec<_>, Error>>()?,
88 ))),
89 "all" => Ok(Self::new(GroupVariant::All(
90 args.into_iter()
91 .map(Self::try_from)
92 .collect::<Result<Vec<_>, Error>>()?,
93 ))),
94 "not" => {
95 let mut args_iter = args.iter();
96
97 let Some(arg) = args_iter.next()
98 else {
99 return Err(Error::new_spanned(args, "Expected exactly one argument for `not`"));
100 };
101
102 if args_iter.next().is_some() {
103 return Err(Error::new_spanned(
104 args,
105 "Expected only one argument for `not`",
106 ));
107 }
108
109 Ok(Self::new(GroupVariant::Not(Box::new(Self::try_from(arg)?))))
110 }
111 _ => Err(Error::new_spanned(func, "Expected `any`, `all` or `not`")),
112 }
113 }
114 _ => Err(Error::new_spanned(func, "Expected `any`, `all` or `not`")),
115 },
116 syn::Expr::Path(path) => {
117 let Some(field_name) = path_to_single_string(&path.path)
118 else {
119 return Err(Error::new_spanned(path, "Expected a field name"));
120 };
121
122 Ok(Self::new(GroupVariant::Field(field_name)))
123 }
124 _ => Err(Error::new_spanned(
125 expr,
126 "Expected `any(…)`, `all(…)`, `not(…)` or a field",
127 )),
128 }
129 }
130}
131
132impl Group {
133 fn set_init_impls<T>(
135 &self,
136 impls: &mut [Implementation],
137 field_names: &[T],
138 invert: bool,
139 ) -> bool
140 where
141 T: PartialEq<String>,
142 {
143 let mut contains_not = false;
144
145 match &self.variant {
146 GroupVariant::Any(args) => {
147 let mut ind = 0;
148
149 let mult = impls.len() / args.iter().fold(0, |acc, arg| acc + arg.n_init_impls);
150
151 for arg in args {
152 let new_ind = ind + mult * arg.n_init_impls;
153 let new_contains_not =
154 arg.set_init_impls(&mut impls[ind..new_ind], field_names, invert);
155 ind = new_ind;
156
157 contains_not = contains_not || new_contains_not;
158 }
159 }
160 GroupVariant::All(args) => {
161 let mut remaining_n_init_impls = impls.len();
162
163 for arg in args {
164 for split_ind in 0..impls.len() / remaining_n_init_impls {
165 let init = split_ind * remaining_n_init_impls;
166 let new_contains_not = arg.set_init_impls(
167 &mut impls[init..init + remaining_n_init_impls],
168 field_names,
169 invert,
170 );
171
172 contains_not = contains_not || new_contains_not;
173 }
174
175 remaining_n_init_impls /= arg.n_init_impls;
176 }
177 }
178 GroupVariant::Not(arg) => {
179 arg.set_init_impls(impls, field_names, !invert);
180
181 contains_not = true;
182 }
183 GroupVariant::Field(field_name) => {
184 let Some(ind) = field_names.iter().position(|x| x == field_name)
185 else {
186 panic!("{field_name} not found as struct field!");
187 };
188
189 for implementation in impls {
190 implementation[ind] = if invert {
191 Dependency::Unset
192 } else {
193 Dependency::Set
194 };
195 }
196 }
197 };
198
199 contains_not
200 }
201
202 fn no_conflict(first_impl: &[Dependency], second_impl: &[Dependency]) -> bool {
203 let mut generic_non_generic_found = false;
204 let mut non_generic_generic_found = false;
205 let mut non_generic_diff = false;
206
207 for (first_dep, second_dep) in second_impl.iter().zip(first_impl.iter()) {
208 match (first_dep, second_dep) {
209 (Dependency::Generic, Dependency::Set)
210 | (Dependency::Generic, Dependency::Unset) => {
211 generic_non_generic_found = true;
212
213 if non_generic_generic_found {
214 break;
215 }
216 }
217 (Dependency::Set, Dependency::Generic)
218 | (Dependency::Unset, Dependency::Generic) => {
219 non_generic_generic_found = true;
220
221 if generic_non_generic_found {
222 break;
223 }
224 }
225 (Dependency::Set, Dependency::Unset) | (Dependency::Unset, Dependency::Set) => {
226 non_generic_diff = true;
227 break;
228 }
229 _ => (),
230 }
231 }
232
233 non_generic_diff || (generic_non_generic_found && non_generic_generic_found)
234 }
235
236 fn check_input_conflicts(impls: &[Implementation]) -> Result<(), &'static str> {
237 for (upper_impl_ind, upper_impl) in impls.iter().enumerate().take(impls.len() - 1) {
238 let no_conflict = impls[upper_impl_ind + 1..]
239 .iter()
240 .all(|lower_impl| Self::no_conflict(lower_impl, upper_impl));
241
242 if !no_conflict {
243 return Err("A logical conflict was detected in the input. Something like `any(a, all(a, b))` which must be only `a`");
244 }
245 }
246
247 Ok(())
248 }
249
250 fn fix_conflicts(
251 init_impls: &[Implementation],
252 impls: &mut [Implementation],
253 additional_impls: &mut Vec<Implementation>,
254 new_additional_impls: &mut Vec<Implementation>,
255 focus_dep: Dependency,
256 inverse_focus_dep: Dependency,
257 ) {
258 for (impl_ind, init_impl) in init_impls.iter().enumerate() {
259 let focus_deps = init_impl
260 .iter()
261 .copied()
262 .enumerate()
263 .filter_map(|(ind, dep)| if dep == focus_dep { Some(ind) } else { None })
264 .collect::<Vec<_>>();
265
266 if focus_deps.is_empty() {
268 continue;
269 }
270
271 for lower_impl in impls
272 .iter_mut()
273 .skip(impl_ind + 1)
274 .chain(additional_impls.iter_mut())
275 {
276 let Some(shadowed_generics) = focus_deps
277 .iter()
278 .copied()
279 .filter_map(|ind| {
280 let dep = lower_impl[ind];
281
282 if dep == focus_dep {
283 None
285 } else if dep == inverse_focus_dep {
286 Some(None)
288 } else {
289 Some(Some(ind))
290 }
291 })
292 .collect::<Option<Vec<_>>>()
293 else {
294 continue;
296 };
297
298 if let Some(first_shadowed_generic) = shadowed_generics.first() {
299 lower_impl[*first_shadowed_generic] = inverse_focus_dep;
302
303 for (shadowed_generic_ind, shadowed_generic) in
308 shadowed_generics.iter().skip(1).copied().enumerate()
309 {
310 let mut additional_impl = lower_impl.clone();
311
312 for left_shadowed_generic in shadowed_generics
314 .iter()
315 .take(shadowed_generic_ind + 1)
316 .copied()
317 {
318 additional_impl[left_shadowed_generic] = focus_dep;
319 }
320 additional_impl[shadowed_generic] = inverse_focus_dep;
322
323 new_additional_impls.push(additional_impl);
324 }
325 }
326 }
327
328 additional_impls.extend_from_slice(new_additional_impls);
329 new_additional_impls.clear();
330 }
331 }
332
333 fn impls<T>(self, field_names: &[T]) -> Result<Vec<Implementation>, &'static str>
334 where
335 T: PartialEq<String>,
336 {
337 let n_fields = field_names.len();
338
339 let mut impls = vec![vec![Dependency::Generic; n_fields]; self.n_init_impls];
340
341 let contains_not = self.set_init_impls(&mut impls, field_names, false);
342
343 Self::check_input_conflicts(&impls)?;
344
345 impls.sort_by_cached_key(|implementation| {
347 implementation.iter().fold(0, |acc, dep| match dep {
348 Dependency::Set | Dependency::Unset => acc + 1,
349 Dependency::Generic => acc,
350 })
351 });
352
353 let mut init_impls = impls.clone();
354 let mut additional_impls = Vec::new();
355 let mut new_additional_impls = Vec::new();
356
357 Self::fix_conflicts(
359 &init_impls,
360 &mut impls,
361 &mut additional_impls,
362 &mut new_additional_impls,
363 Dependency::Set,
364 Dependency::Unset,
365 );
366
367 if contains_not {
368 Self::fix_conflicts(
370 &init_impls,
371 &mut impls,
372 &mut additional_impls,
373 &mut new_additional_impls,
374 Dependency::Unset,
375 Dependency::Set,
376 );
377
378 init_impls.clear();
382 let mut cleaned_up_init_impls = init_impls;
383
384 impls.sort_by_cached_key(|implementation| {
386 implementation
387 .iter()
388 .filter(|dep| matches!(dep, Dependency::Set | Dependency::Unset))
389 .count()
390 });
391
392 for (impl_ind, lower_impl) in impls.iter().rev().enumerate() {
394 let no_conflict = impls
395 .iter()
396 .rev()
397 .skip(impl_ind + 1)
398 .all(|upper_impl| Self::no_conflict(upper_impl, lower_impl));
399
400 if no_conflict {
403 cleaned_up_init_impls.push(lower_impl.clone());
404 }
405 }
406
407 impls = cleaned_up_init_impls;
408 }
409
410 new_additional_impls.clear();
412 let mut cleaned_up_additional_impls = new_additional_impls;
413
414 for (additional_impl_ind, additional_impl) in additional_impls.iter().enumerate() {
416 let no_conflict = impls
417 .iter()
418 .chain(additional_impls[..additional_impl_ind].iter())
419 .all(|implementation| Self::no_conflict(implementation, additional_impl));
420
421 if no_conflict {
422 cleaned_up_additional_impls.push(additional_impl.clone());
423 }
424 }
425
426 impls.extend_from_slice(&cleaned_up_additional_impls);
427
428 Ok(impls)
429 }
430}
431
432pub fn impls<T>(expr: &syn::Expr, field_names: &[T]) -> Result<Vec<Implementation>, Error>
433where
434 T: PartialEq<String>,
435{
436 let group = Group::try_from(expr)?;
437
438 group
439 .impls(field_names)
440 .map_err(|e| Error::new_spanned(expr, e))
441}