sus_impls/
lib.rs

1use syn::Error;
2
3fn path_to_single_string(path: &syn::Path) -> Option<String> {
4    path.segments
5        .iter()
6        .next()
7        .map(|segment| segment.ident.to_string())
8}
9
10#[derive(Debug)]
11struct Group {
12    variant: GroupVariant,
13    n_init_impls: usize,
14}
15
/// State required of a single field by one implementation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Dependency {
    /// No requirement on the field; displayed as `_`.
    Generic,
    /// The field is required to be set; displayed as `S`.
    Set,
    /// The field is required to be unset; displayed as `U`.
    Unset,
}

impl std::fmt::Display for Dependency {
    /// Writes the one-character symbol of the dependency (`_`, `S` or `U`).
    ///
    /// `Formatter::pad` is used so that width, fill and alignment flags such
    /// as `{:>3}` are honoured; with a plain `{}` the output is just the
    /// symbol itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            Dependency::Generic => "_",
            Dependency::Set => "S",
            Dependency::Unset => "U",
        };

        f.pad(symbol)
    }
}
32
33pub type Implementation = Vec<Dependency>;
34
35impl Group {
36    fn new(variant: GroupVariant) -> Self {
37        let n_init_impls = variant.n_init_impls();
38
39        Self {
40            variant,
41            n_init_impls,
42        }
43    }
44}
45
46#[derive(Debug)]
47enum GroupVariant {
48    Any(Vec<Group>),
49    All(Vec<Group>),
50    Not(Box<Group>),
51    Field(String),
52}
53
54impl GroupVariant {
55    fn n_init_impls(&self) -> usize {
56        match self {
57            // Sum!
58            GroupVariant::Any(args) => args
59                .iter()
60                .fold(0, |acc, arg| acc + arg.variant.n_init_impls()),
61            // Product!
62            GroupVariant::All(args) => args
63                .iter()
64                .fold(1, |acc, arg| acc * arg.variant.n_init_impls()),
65            GroupVariant::Not(arg) => arg.n_init_impls,
66            GroupVariant::Field(_) => 1,
67        }
68    }
69}
70
71impl TryFrom<&syn::Expr> for Group {
72    type Error = Error;
73
74    fn try_from(expr: &syn::Expr) -> Result<Self, Self::Error> {
75        match expr {
76            syn::Expr::Call(syn::ExprCall { func, args, .. }) => match func.as_ref() {
77                syn::Expr::Path(path) => {
78                    let Some(path) = path_to_single_string(&path.path)
79                    else {
80                        return Err(Error::new_spanned(func, "Expected `any`, `all` or `not`"))
81                    };
82
83                    match path.as_str() {
84                        "any" => Ok(Self::new(GroupVariant::Any(
85                            args.into_iter()
86                                .map(Self::try_from)
87                                .collect::<Result<Vec<_>, Error>>()?,
88                        ))),
89                        "all" => Ok(Self::new(GroupVariant::All(
90                            args.into_iter()
91                                .map(Self::try_from)
92                                .collect::<Result<Vec<_>, Error>>()?,
93                        ))),
94                        "not" => {
95                            let mut args_iter = args.iter();
96
97                            let Some(arg) = args_iter.next()
98                            else {
99                                return Err(Error::new_spanned(args, "Expected exactly one argument for `not`"));
100                            };
101
102                            if args_iter.next().is_some() {
103                                return Err(Error::new_spanned(
104                                    args,
105                                    "Expected only one argument for `not`",
106                                ));
107                            }
108
109                            Ok(Self::new(GroupVariant::Not(Box::new(Self::try_from(arg)?))))
110                        }
111                        _ => Err(Error::new_spanned(func, "Expected `any`, `all` or `not`")),
112                    }
113                }
114                _ => Err(Error::new_spanned(func, "Expected `any`, `all` or `not`")),
115            },
116            syn::Expr::Path(path) => {
117                let Some(field_name) = path_to_single_string(&path.path)
118                else {
119                    return Err(Error::new_spanned(path, "Expected a field name"));
120                };
121
122                Ok(Self::new(GroupVariant::Field(field_name)))
123            }
124            _ => Err(Error::new_spanned(
125                expr,
126                "Expected `any(…)`, `all(…)`, `not(…)` or a field",
127            )),
128        }
129    }
130}
131
132impl Group {
133    // Returns a bool about whether the group contains a `not`.
134    fn set_init_impls<T>(
135        &self,
136        impls: &mut [Implementation],
137        field_names: &[T],
138        invert: bool,
139    ) -> bool
140    where
141        T: PartialEq<String>,
142    {
143        let mut contains_not = false;
144
145        match &self.variant {
146            GroupVariant::Any(args) => {
147                let mut ind = 0;
148
149                let mult = impls.len() / args.iter().fold(0, |acc, arg| acc + arg.n_init_impls);
150
151                for arg in args {
152                    let new_ind = ind + mult * arg.n_init_impls;
153                    let new_contains_not =
154                        arg.set_init_impls(&mut impls[ind..new_ind], field_names, invert);
155                    ind = new_ind;
156
157                    contains_not = contains_not || new_contains_not;
158                }
159            }
160            GroupVariant::All(args) => {
161                let mut remaining_n_init_impls = impls.len();
162
163                for arg in args {
164                    for split_ind in 0..impls.len() / remaining_n_init_impls {
165                        let init = split_ind * remaining_n_init_impls;
166                        let new_contains_not = arg.set_init_impls(
167                            &mut impls[init..init + remaining_n_init_impls],
168                            field_names,
169                            invert,
170                        );
171
172                        contains_not = contains_not || new_contains_not;
173                    }
174
175                    remaining_n_init_impls /= arg.n_init_impls;
176                }
177            }
178            GroupVariant::Not(arg) => {
179                arg.set_init_impls(impls, field_names, !invert);
180
181                contains_not = true;
182            }
183            GroupVariant::Field(field_name) => {
184                let Some(ind) = field_names.iter().position(|x| x == field_name)
185                else {
186                    panic!("{field_name} not found as struct field!");
187                };
188
189                for implementation in impls {
190                    implementation[ind] = if invert {
191                        Dependency::Unset
192                    } else {
193                        Dependency::Set
194                    };
195                }
196            }
197        };
198
199        contains_not
200    }
201
202    fn no_conflict(first_impl: &[Dependency], second_impl: &[Dependency]) -> bool {
203        let mut generic_non_generic_found = false;
204        let mut non_generic_generic_found = false;
205        let mut non_generic_diff = false;
206
207        for (first_dep, second_dep) in second_impl.iter().zip(first_impl.iter()) {
208            match (first_dep, second_dep) {
209                (Dependency::Generic, Dependency::Set)
210                | (Dependency::Generic, Dependency::Unset) => {
211                    generic_non_generic_found = true;
212
213                    if non_generic_generic_found {
214                        break;
215                    }
216                }
217                (Dependency::Set, Dependency::Generic)
218                | (Dependency::Unset, Dependency::Generic) => {
219                    non_generic_generic_found = true;
220
221                    if generic_non_generic_found {
222                        break;
223                    }
224                }
225                (Dependency::Set, Dependency::Unset) | (Dependency::Unset, Dependency::Set) => {
226                    non_generic_diff = true;
227                    break;
228                }
229                _ => (),
230            }
231        }
232
233        non_generic_diff || (generic_non_generic_found && non_generic_generic_found)
234    }
235
236    fn check_input_conflicts(impls: &[Implementation]) -> Result<(), &'static str> {
237        for (upper_impl_ind, upper_impl) in impls.iter().enumerate().take(impls.len() - 1) {
238            let no_conflict = impls[upper_impl_ind + 1..]
239                .iter()
240                .all(|lower_impl| Self::no_conflict(lower_impl, upper_impl));
241
242            if !no_conflict {
243                return Err("A logical conflict was detected in the input. Something like `any(a, all(a, b))` which must be only `a`");
244            }
245        }
246
247        Ok(())
248    }
249
250    fn fix_conflicts(
251        init_impls: &[Implementation],
252        impls: &mut [Implementation],
253        additional_impls: &mut Vec<Implementation>,
254        new_additional_impls: &mut Vec<Implementation>,
255        focus_dep: Dependency,
256        inverse_focus_dep: Dependency,
257    ) {
258        for (impl_ind, init_impl) in init_impls.iter().enumerate() {
259            let focus_deps = init_impl
260                .iter()
261                .copied()
262                .enumerate()
263                .filter_map(|(ind, dep)| if dep == focus_dep { Some(ind) } else { None })
264                .collect::<Vec<_>>();
265
266            // Could happen if there is a `not`.
267            if focus_deps.is_empty() {
268                continue;
269            }
270
271            for lower_impl in impls
272                .iter_mut()
273                .skip(impl_ind + 1)
274                .chain(additional_impls.iter_mut())
275            {
276                let Some(shadowed_generics) = focus_deps
277                    .iter()
278                    .copied()
279                    .filter_map(|ind| {
280                        let dep = lower_impl[ind];
281
282                        if dep == focus_dep {
283                            // Ignore only this dep.
284                            None
285                        } else if dep == inverse_focus_dep {
286                            // Ignore the whole lower implementation and go to the next one.
287                            Some(None)
288                        } else {
289                            Some(Some(ind))
290                        }
291                    })
292                    .collect::<Option<Vec<_>>>()
293                else {
294                    // dep == inverse_focus_dep.
295                    continue;
296                };
297
298                if let Some(first_shadowed_generic) = shadowed_generics.first() {
299                    // Set the first inverse by mutating the initial implementation.
300                    // U _ _
301                    lower_impl[*first_shadowed_generic] = inverse_focus_dep;
302
303                    // Set the possible rest of the diagonal shadow with additional implementations.
304                    // S U _
305                    // S S U
306                    // etc.
307                    for (shadowed_generic_ind, shadowed_generic) in
308                        shadowed_generics.iter().skip(1).copied().enumerate()
309                    {
310                        let mut additional_impl = lower_impl.clone();
311
312                        // Set the focus deps.
313                        for left_shadowed_generic in shadowed_generics
314                            .iter()
315                            .take(shadowed_generic_ind + 1)
316                            .copied()
317                        {
318                            additional_impl[left_shadowed_generic] = focus_dep;
319                        }
320                        // Set the inverse.
321                        additional_impl[shadowed_generic] = inverse_focus_dep;
322
323                        new_additional_impls.push(additional_impl);
324                    }
325                }
326            }
327
328            additional_impls.extend_from_slice(new_additional_impls);
329            new_additional_impls.clear();
330        }
331    }
332
333    fn impls<T>(self, field_names: &[T]) -> Result<Vec<Implementation>, &'static str>
334    where
335        T: PartialEq<String>,
336    {
337        let n_fields = field_names.len();
338
339        let mut impls = vec![vec![Dependency::Generic; n_fields]; self.n_init_impls];
340
341        let contains_not = self.set_init_impls(&mut impls, field_names, false);
342
343        Self::check_input_conflicts(&impls)?;
344
345        // Sort implementations on the number of non-generic dependencies.
346        impls.sort_by_cached_key(|implementation| {
347            implementation.iter().fold(0, |acc, dep| match dep {
348                Dependency::Set | Dependency::Unset => acc + 1,
349                Dependency::Generic => acc,
350            })
351        });
352
353        let mut init_impls = impls.clone();
354        let mut additional_impls = Vec::new();
355        let mut new_additional_impls = Vec::new();
356
357        // Fix conflicts related to fields with `S` in the initial implementations.
358        Self::fix_conflicts(
359            &init_impls,
360            &mut impls,
361            &mut additional_impls,
362            &mut new_additional_impls,
363            Dependency::Set,
364            Dependency::Unset,
365        );
366
367        if contains_not {
368            // Fix conflicts related to `not` fields with `U` in the initial implementations.
369            Self::fix_conflicts(
370                &init_impls,
371                &mut impls,
372                &mut additional_impls,
373                &mut new_additional_impls,
374                Dependency::Unset,
375                Dependency::Set,
376            );
377
378            // Remove possible less generic conflicting implementations from the mutated initial implementations.
379
380            // Recycling allocated vector.
381            init_impls.clear();
382            let mut cleaned_up_init_impls = init_impls;
383
384            // Sort implementations to have the one with more generics higher.
385            impls.sort_by_cached_key(|implementation| {
386                implementation
387                    .iter()
388                    .filter(|dep| matches!(dep, Dependency::Set | Dependency::Unset))
389                    .count()
390            });
391
392            // Iterate starting from the bottom where the less generic implementations are.
393            for (impl_ind, lower_impl) in impls.iter().rev().enumerate() {
394                let no_conflict = impls
395                    .iter()
396                    .rev()
397                    .skip(impl_ind + 1)
398                    .all(|upper_impl| Self::no_conflict(upper_impl, lower_impl));
399
400                // If there is a conflict, the less generic implementation is removed
401                // (not added to the cleaned up ones).
402                if no_conflict {
403                    cleaned_up_init_impls.push(lower_impl.clone());
404                }
405            }
406
407            impls = cleaned_up_init_impls;
408        }
409
410        // Recycling allocated vector.
411        new_additional_impls.clear();
412        let mut cleaned_up_additional_impls = new_additional_impls;
413
414        // Remove conflicting additional implementations.
415        for (additional_impl_ind, additional_impl) in additional_impls.iter().enumerate() {
416            let no_conflict = impls
417                .iter()
418                .chain(additional_impls[..additional_impl_ind].iter())
419                .all(|implementation| Self::no_conflict(implementation, additional_impl));
420
421            if no_conflict {
422                cleaned_up_additional_impls.push(additional_impl.clone());
423            }
424        }
425
426        impls.extend_from_slice(&cleaned_up_additional_impls);
427
428        Ok(impls)
429    }
430}
431
432pub fn impls<T>(expr: &syn::Expr, field_names: &[T]) -> Result<Vec<Implementation>, Error>
433where
434    T: PartialEq<String>,
435{
436    let group = Group::try_from(expr)?;
437
438    group
439        .impls(field_names)
440        .map_err(|e| Error::new_spanned(expr, e))
441}