use syn::Error;
/// Returns the path as a `String` if it is a single plain identifier (e.g. `any` or `my_field`).
///
/// Returns `None` for anything else: multi-segment paths such as `a::b`, paths with a leading
/// `::`, or a segment carrying generic arguments. Callers then report an error instead of
/// silently using only the first segment.
fn path_to_single_string(path: &syn::Path) -> Option<String> {
    path.get_ident().map(|ident| ident.to_string())
}
/// A parsed condition expression, together with the number of initial impl rows it expands to.
#[derive(Debug)]
struct Group {
    /// Which kind of expression this is, with its sub-expressions.
    variant: GroupVariant,
    /// `variant.n_init_impls()`, computed once when the group is built by [`Group::new`].
    n_init_impls: usize,
}

impl Group {
    /// Wraps `variant` and caches how many initial impl rows it expands to.
    fn new(variant: GroupVariant) -> Self {
        // Field initialisers run in source order: the count borrows `variant`
        // before it is moved into the struct.
        Self {
            n_init_impls: variant.n_init_impls(),
            variant,
        }
    }
}
/// The shape of a condition expression.
#[derive(Debug)]
enum GroupVariant {
    /// `any(…)`: each argument contributes its own impl rows.
    Any(Vec<Group>),
    /// `all(…)`: the impl rows are the combinations of every argument's rows.
    All(Vec<Group>),
    /// `not(…)` around a single sub-group.
    Not(Box<Group>),
    /// A bare field name.
    Field(String),
}

impl GroupVariant {
    /// Number of initial impl rows this variant expands to, before conflicts are resolved.
    ///
    /// The count is:
    /// - the sum over `any` arguments (`0` for an empty `any()`);
    /// - the product over `all` arguments (`1` for an empty `all()`);
    /// - the argument's count for `not`;
    /// - one row for a field.
    ///
    /// Arguments are already-built [`Group`]s, so their cached `n_init_impls` is read
    /// instead of recursing. Building a tree bottom-up with `Group::new` therefore stays
    /// linear in its size rather than quadratic in its depth.
    fn n_init_impls(&self) -> usize {
        match self {
            GroupVariant::Any(args) => args.iter().map(|arg| arg.n_init_impls).sum(),
            GroupVariant::All(args) => args.iter().map(|arg| arg.n_init_impls).product(),
            GroupVariant::Not(arg) => arg.n_init_impls,
            GroupVariant::Field(_) => 1,
        }
    }
}
impl TryFrom<&syn::Expr> for Group {
type Error = Error;
fn try_from(expr: &syn::Expr) -> Result<Self, Self::Error> {
match expr {
syn::Expr::Call(syn::ExprCall { func, args, .. }) => match func.as_ref() {
syn::Expr::Path(path) => {
let Some(path) = path_to_single_string(&path.path)
else {
return Err(Error::new_spanned(func, "Expected `any`, `all` or `not`"))
};
match path.as_str() {
"any" => Ok(Self::new(GroupVariant::Any(
args.into_iter()
.map(Self::try_from)
.collect::<Result<Vec<_>, Error>>()?,
))),
"all" => Ok(Self::new(GroupVariant::All(
args.into_iter()
.map(Self::try_from)
.collect::<Result<Vec<_>, Error>>()?,
))),
"not" => {
let mut args_iter = args.iter();
let Some(arg) = args_iter.next()
else {
return Err(Error::new_spanned(args, "Expected exactly one argument for `not`"));
};
if args_iter.next().is_some() {
return Err(Error::new_spanned(
args,
"Expected only one argument for `not`",
));
}
Ok(Self::new(GroupVariant::Not(Box::new(Self::try_from(arg)?))))
}
_ => Err(Error::new_spanned(func, "Expected `any`, `all` or `not`")),
}
}
_ => Err(Error::new_spanned(func, "Expected `any`, `all` or `not`")),
},
syn::Expr::Path(path) => {
let Some(field_name) = path_to_single_string(&path.path)
else {
return Err(Error::new_spanned(path, "Expected a field name"));
};
Ok(Self::new(GroupVariant::Field(field_name)))
}
_ => Err(Error::new_spanned(
expr,
"Expected `any(…)`, `all(…)`, `not(…)` or a field",
)),
}
}
}
/// The requirement one impl row places on one struct field.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Dependency {
    /// No requirement: the row does not constrain this field.
    Generic,
    /// The row requires the field to be set.
    Set,
    /// The row requires the field to be unset.
    Unset,
}

impl std::fmt::Display for Dependency {
    /// Writes a one-character tag: `_` for generic, `S` for set, `U` for unset.
    /// Width and alignment flags are ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tag = match self {
            Self::Generic => "_",
            Self::Set => "S",
            Self::Unset => "U",
        };
        f.write_str(tag)
    }
}
impl Group {
fn set_init_impls<T>(
&self,
impls: &mut [Vec<Dependency>],
field_names: &[T],
invert: bool,
) -> bool
where
T: PartialEq<String>,
{
let mut contains_not = false;
match &self.variant {
GroupVariant::Any(args) => {
let mut ind = 0;
let mult = impls.len() / args.iter().fold(0, |acc, arg| acc + arg.n_init_impls);
for arg in args {
let new_ind = ind + mult * arg.n_init_impls;
let new_contains_not =
arg.set_init_impls(&mut impls[ind..new_ind], field_names, invert);
ind = new_ind;
contains_not = contains_not || new_contains_not;
}
}
GroupVariant::All(args) => {
let mut remaining_n_init_impls = impls.len();
for arg in args {
for split_ind in 0..impls.len() / remaining_n_init_impls {
let init = split_ind * remaining_n_init_impls;
let new_contains_not = arg.set_init_impls(
&mut impls[init..init + remaining_n_init_impls],
field_names,
invert,
);
contains_not = contains_not || new_contains_not;
}
remaining_n_init_impls /= arg.n_init_impls;
}
}
GroupVariant::Not(arg) => {
arg.set_init_impls(impls, field_names, !invert);
contains_not = true;
}
GroupVariant::Field(field_name) => {
let Some(ind) = field_names.iter().position(|x| x == field_name)
else {
panic!("{field_name} not found as struct field!");
};
for implementation in impls {
implementation[ind] = if invert {
Dependency::Unset
} else {
Dependency::Set
};
}
}
};
contains_not
}
fn no_conflict(first_impl: &[Dependency], second_impl: &[Dependency]) -> bool {
let mut generic_non_generic_found = false;
let mut non_generic_generic_found = false;
let mut non_generic_diff = false;
for (first_dep, second_dep) in second_impl.iter().zip(first_impl.iter()) {
match (first_dep, second_dep) {
(Dependency::Generic, Dependency::Set)
| (Dependency::Generic, Dependency::Unset) => {
generic_non_generic_found = true;
if non_generic_generic_found {
break;
}
}
(Dependency::Set, Dependency::Generic)
| (Dependency::Unset, Dependency::Generic) => {
non_generic_generic_found = true;
if generic_non_generic_found {
break;
}
}
(Dependency::Set, Dependency::Unset) | (Dependency::Unset, Dependency::Set) => {
non_generic_diff = true;
break;
}
_ => (),
}
}
non_generic_diff || (generic_non_generic_found && non_generic_generic_found)
}
fn check_input_conflicts(impls: &[Vec<Dependency>]) -> Result<(), &'static str> {
for (upper_impl_ind, upper_impl) in impls.iter().enumerate().take(impls.len() - 1) {
let no_conflict = impls[upper_impl_ind + 1..]
.iter()
.all(|lower_impl| Self::no_conflict(lower_impl, upper_impl));
if !no_conflict {
return Err("A logical conflict was detected in the input. Something like `any(a, all(a, b))` which must be only `a`");
}
}
Ok(())
}
fn fix_conflicts(
init_impls: &[Vec<Dependency>],
impls: &mut [Vec<Dependency>],
additional_impls: &mut Vec<Vec<Dependency>>,
new_additional_impls: &mut Vec<Vec<Dependency>>,
focus_dep: Dependency,
inverse_focus_dep: Dependency,
) {
for (impl_ind, init_impl) in init_impls.iter().enumerate() {
let focus_deps = init_impl
.iter()
.copied()
.enumerate()
.filter_map(|(ind, dep)| if dep == focus_dep { Some(ind) } else { None })
.collect::<Vec<_>>();
if focus_deps.is_empty() {
continue;
}
for lower_impl in impls
.iter_mut()
.skip(impl_ind + 1)
.chain(additional_impls.iter_mut())
{
let Some(shadowed_generics) = focus_deps
.iter()
.copied()
.filter_map(|ind| {
let dep = lower_impl[ind];
if dep == focus_dep {
None
} else if dep == inverse_focus_dep {
Some(None)
} else {
Some(Some(ind))
}
})
.collect::<Option<Vec<_>>>()
else {
continue;
};
if let Some(first_shadowed_generic) = shadowed_generics.first() {
lower_impl[*first_shadowed_generic] = inverse_focus_dep;
for (shadowed_generic_ind, shadowed_generic) in
shadowed_generics.iter().skip(1).copied().enumerate()
{
let mut additional_impl = lower_impl.clone();
for left_shadowed_generic in shadowed_generics
.iter()
.take(shadowed_generic_ind + 1)
.copied()
{
additional_impl[left_shadowed_generic] = focus_dep;
}
additional_impl[shadowed_generic] = inverse_focus_dep;
new_additional_impls.push(additional_impl);
}
}
}
additional_impls.extend_from_slice(new_additional_impls);
new_additional_impls.clear();
}
}
fn impls<T>(self, field_names: &[T]) -> Result<Vec<Vec<Dependency>>, &'static str>
where
T: PartialEq<String>,
{
let n_fields = field_names.len();
let mut impls = vec![vec![Dependency::Generic; n_fields]; self.n_init_impls];
let contains_not = self.set_init_impls(&mut impls, field_names, false);
Self::check_input_conflicts(&impls)?;
impls.sort_by_cached_key(|implementation| {
implementation.iter().fold(0, |acc, dep| match dep {
Dependency::Set | Dependency::Unset => acc + 1,
Dependency::Generic => acc,
})
});
let mut init_impls = impls.clone();
let mut additional_impls = Vec::new();
let mut new_additional_impls = Vec::new();
Self::fix_conflicts(
&init_impls,
&mut impls,
&mut additional_impls,
&mut new_additional_impls,
Dependency::Set,
Dependency::Unset,
);
if contains_not {
Self::fix_conflicts(
&init_impls,
&mut impls,
&mut additional_impls,
&mut new_additional_impls,
Dependency::Unset,
Dependency::Set,
);
init_impls.clear();
let mut cleaned_up_init_impls = init_impls;
impls.sort_by_cached_key(|implementation| {
implementation
.iter()
.filter(|dep| matches!(dep, Dependency::Set | Dependency::Unset))
.count()
});
for (impl_ind, lower_impl) in impls.iter().rev().enumerate() {
let no_conflict = impls
.iter()
.rev()
.skip(impl_ind + 1)
.all(|upper_impl| Self::no_conflict(upper_impl, lower_impl));
if no_conflict {
cleaned_up_init_impls.push(lower_impl.clone());
}
}
impls = cleaned_up_init_impls;
}
new_additional_impls.clear();
let mut cleaned_up_additional_impls = new_additional_impls;
for (additional_impl_ind, additional_impl) in additional_impls.iter().enumerate() {
let no_conflict = impls
.iter()
.chain(additional_impls[..additional_impl_ind].iter())
.all(|implementation| Self::no_conflict(implementation, additional_impl));
if no_conflict {
cleaned_up_additional_impls.push(additional_impl.clone());
}
}
impls.extend_from_slice(&cleaned_up_additional_impls);
Ok(impls)
}
}
pub fn impls<T>(expr: &syn::Expr, field_names: &[T]) -> Result<Vec<Vec<Dependency>>, Error>
where
T: PartialEq<String>,
{
let group = Group::try_from(expr)?;
group
.impls(field_names)
.map_err(|e| Error::new_spanned(expr, e))
}