use std::fmt::{self, Debug};
use std::sync::Arc;
use ecow::eco_vec;
use super::{Context, Error, TryFromValue, Type, Value};
/// Reference-counted, type-erased callable stored inside a [`Func`].
///
/// The callable receives a [`Context`] and a slice of argument [`Value`]s and returns either a
/// result [`Value`] or an [`Error`]. The trait object carries no `Send`/`Sync` bound.
type FuncImpl<T> = Arc<dyn Fn(&Context<T>, &[Value<T>]) -> Result<Value<T>, Error>>;
#[derive(Clone)]
pub struct Func<T>(FuncImpl<T>);
impl<T> Func<T> {
pub fn new<F>(f: F) -> Self
where
F: Fn(&Context<T>, &[Value<T>]) -> Result<Value<T>, Error> + 'static,
{
Self(Arc::new(f) as _)
}
pub fn call(&self, ctx: &Context<T>, args: &[Value<T>]) -> Result<Value<T>, Error> {
(self.0)(ctx, args)
}
}
impl<T> Debug for Func<T> {
    /// Formats the value as a `Func` tuple struct whose single field is rendered as `..`,
    /// since the wrapped closure is not inspected.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_tuple("Func");
        // `..` is a `RangeFull`, whose `Debug` output is the literal text `..`.
        builder.field(&..);
        builder.finish()
    }
}
/// Argument-validation helpers: they check the number of arguments passed to a function and
/// convert the argument values into a requested Rust type.
impl<T> Func<T> {
    /// Checks that the function named `id` was called without any arguments.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidArgumentCount`] with `expected: 0` when `args` is not empty.
    pub fn expect_no_args(id: &str, _ctx: &Context<T>, args: &[Value<T>]) -> Result<(), Error> {
        if args.is_empty() {
            Ok(())
        } else {
            Err(Error::InvalidArgumentCount {
                func: id.into(),
                expected: 0,
                is_min: false,
                found: args.len(),
            })
        }
    }

    /// Checks that exactly `N` arguments were passed and converts each of them to `V`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidArgumentCount`] (with `is_min: false`) when `args.len() != N`,
    /// or the first error produced by `V::try_from_value` while converting the arguments.
    pub fn expect_args_exact<V: TryFromValue<T> + Debug, const N: usize>(
        func: &str,
        _ctx: &Context<T>,
        args: &[Value<T>],
    ) -> Result<[V; N], Error>
    where
        T: Clone,
    {
        // Reject both too few and too many arguments; surplus arguments must not be
        // silently dropped by an "exact" check.
        if args.len() != N {
            return Err(Error::InvalidArgumentCount {
                func: func.into(),
                expected: N,
                is_min: false,
                found: args.len(),
            });
        }

        let converted = args
            .iter()
            .cloned()
            .map(V::try_from_value)
            .collect::<Result<Vec<_>, _>>()?;

        // The length equals `N` here, so the `Vec` -> array conversion cannot fail.
        Ok(converted
            .try_into()
            .expect("we checked both min and max of the args"))
    }

    /// Checks that at least `N` arguments were passed and converts all of them to `V`.
    ///
    /// On success the first `N` converted arguments are returned as an array and every
    /// remaining converted argument is returned, in order, in the `Vec`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidArgumentCount`] (with `is_min: true`) when `args.len() < N`,
    /// or the first error produced by `V::try_from_value` while converting the arguments.
    pub fn expect_args_min<V: TryFromValue<T> + Debug, const N: usize>(
        func: &str,
        _ctx: &Context<T>,
        args: &[Value<T>],
    ) -> Result<([V; N], Vec<V>), Error>
    where
        T: Clone,
    {
        if args.len() < N {
            return Err(Error::InvalidArgumentCount {
                func: func.into(),
                expected: N,
                is_min: true,
                found: args.len(),
            });
        }

        // `N <= args.len()` was verified above, so the split cannot panic.
        let (required, rest) = args.split_at(N);

        // `required` holds exactly `N` values, so the `Vec` -> array conversion cannot fail.
        let min: [V; N] = required
            .iter()
            .cloned()
            .map(V::try_from_value)
            .collect::<Result<Vec<_>, _>>()?
            .try_into()
            .expect("we checked both min and max of the args");

        let extra = rest
            .iter()
            .cloned()
            .map(V::try_from_value)
            .collect::<Result<Vec<_>, _>>()?;

        Ok((min, extra))
    }
}
impl<T> TryFromValue<T> for Func<T> {
fn try_from_value(value: Value<T>) -> Result<Self, Error> {
Ok(match value {
Value::Func(set) => set,
_ => {
return Err(Error::TypeMismatch {
expected: eco_vec![Type::Func],
found: value.as_type(),
})
}
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::Num;

    const NUM: Num = Num(0);
    const VAL: Value<()> = Value::Num(NUM);

    /// Exercises `expect_args_min` for minimum counts 0, 1 and 2 against 0, 1 and 2 arguments.
    #[test]
    fn test_expect_args_variadic_min_length() {
        let ctx = Context::new();
        let one = [VAL];
        let two = [VAL, VAL];

        // Minimum of zero: every argument ends up in the variadic tail.
        assert_eq!(
            Func::expect_args_min::<Num, 0>("f", &ctx, &[]).unwrap(),
            ([], vec![]),
        );
        assert_eq!(
            Func::expect_args_min::<Num, 0>("f", &ctx, &one).unwrap(),
            ([], vec![NUM]),
        );
        assert_eq!(
            Func::expect_args_min::<Num, 0>("f", &ctx, &two).unwrap(),
            ([], vec![NUM, NUM]),
        );

        // Minimum of one: an empty argument list is rejected.
        assert!(Func::expect_args_min::<Num, 1>("f", &ctx, &[]).is_err());
        assert_eq!(
            Func::expect_args_min::<Num, 1>("f", &ctx, &one).unwrap(),
            ([NUM], vec![]),
        );
        assert_eq!(
            Func::expect_args_min::<Num, 1>("f", &ctx, &two).unwrap(),
            ([NUM], vec![NUM]),
        );

        // Minimum of two: fewer than two arguments are rejected.
        assert!(Func::expect_args_min::<Num, 2>("f", &ctx, &[]).is_err());
        assert!(Func::expect_args_min::<Num, 2>("f", &ctx, &one).is_err());
        assert_eq!(
            Func::expect_args_min::<Num, 2>("f", &ctx, &two).unwrap(),
            ([NUM, NUM], vec![]),
        );
    }
}