use allocative::Allocative;
use dupe::Dupe;
use starlark_derive::starlark_value;
use starlark_derive::NoSerialize;
use starlark_derive::ProvidesStaticType;
use crate as starlark;
use crate::typing::Ty;
use crate::values::layout::avalue::alloc_static;
use crate::values::layout::avalue::AValueImpl;
use crate::values::layout::avalue::Basic;
use crate::values::layout::heap::repr::AValueRepr;
use crate::values::type_repr::StarlarkTypeRepr;
use crate::values::AllocFrozenValue;
use crate::values::FrozenHeap;
use crate::values::FrozenValue;
use crate::values::StarlarkValue;
use crate::values::UnpackValue;
use crate::values::Value;
/// Stateless Starlark value whose type name is `typing.Callable`.
///
/// Used as a type annotation, it evaluates (via `eval_type`) to the same type that
/// `StarlarkCallable` reports, i.e. `Ty::any_function()`. Its `Display` output is
/// just the type name (`Self::TYPE`). Where it is registered into the `typing`
/// namespace is not visible in this file.
#[derive(
    Debug,
    derive_more::Display,
    Allocative,
    ProvidesStaticType,
    NoSerialize
)]
// Render the value as its own type name, e.g. `typing.Callable`.
#[display(fmt = "{}", Self::TYPE)]
pub(crate) struct TypingCallable;
#[starlark_value(type = "typing.Callable")]
impl<'v> StarlarkValue<'v> for TypingCallable {
    /// Type denoted by `typing.Callable` when it appears in an annotation.
    ///
    /// Delegates to `StarlarkCallable`'s type representation so the annotation
    /// and the runtime unpacker always agree on what "callable" means.
    fn eval_type(&self) -> Option<Ty> {
        let callable_ty: Ty = StarlarkCallable::starlark_type_repr();
        Some(callable_ty)
    }
}
impl AllocFrozenValue for TypingCallable {
    /// Hands out the single statically-allocated `TypingCallable` value.
    ///
    /// Nothing is allocated on the frozen heap: the value lives in a `static`
    /// built with `alloc_static`, so every call returns a reference to the same
    /// object and the `heap` argument is ignored.
    fn alloc_frozen_value(self, _heap: &FrozenHeap) -> FrozenValue {
        static TYPING_CALLABLE_REPR: AValueRepr<AValueImpl<Basic, TypingCallable>> =
            alloc_static(Basic, TypingCallable);

        FrozenValue::new_repr(&TYPING_CALLABLE_REPR)
    }
}
/// A Starlark [`Value`] that has been checked to be callable.
///
/// Obtained through its `UnpackValue` implementation, which accepts any value
/// whose vtable reports an `invoke` implementation (`HAS_invoke`). The wrapped
/// value is stored unchanged in the public field. Its static type is
/// `Ty::any_function()`.
#[derive(Debug, Copy, Clone, Dupe)]
pub struct StarlarkCallable<'v>(pub Value<'v>);
/// The static type of any callable is "some function"; the lifetime of the
/// wrapped value plays no role in that, hence the elided `'_`.
impl StarlarkTypeRepr for StarlarkCallable<'_> {
    fn starlark_type_repr() -> Ty {
        Ty::any_function()
    }
}
impl<'v> UnpackValue<'v> for StarlarkCallable<'v> {
    /// Wraps `value` if its type implements `invoke`, otherwise yields `None`.
    ///
    /// The check is a single vtable flag lookup; no call is attempted.
    #[inline]
    fn unpack_value(value: Value<'v>) -> Option<Self> {
        let invocable = value.vtable().starlark_value.HAS_invoke;
        invocable.then_some(StarlarkCallable(value))
    }
}
// Tests for `typing.Callable`, both as a runtime `isinstance` target and as a
// parameter annotation checked at compile time.
#[cfg(test)]
mod tests {
    use crate::assert;

    // Runtime check: lambdas, builtins and record types are callable; ints are not.
    #[test]
    fn test_callable_runtime() {
        assert::is_true("isinstance(lambda: None, typing.Callable)");
        assert::is_true("isinstance(len, typing.Callable)");
        assert::is_true("Rec = record(); isinstance(Rec, typing.Callable)");
        assert::is_false("isinstance(37, typing.Callable)");
    }

    // Passing callables to a `typing.Callable` parameter must type-check.
    #[test]
    fn test_callable_pass_compile_time() {
        assert::pass(
            r#"
Rec = record()
def foo(x: typing.Callable):
pass
def bar():
foo(len)
foo(lambda x: 1)
foo(Rec)
"#,
        );
    }

    // Passing an int to a `typing.Callable` parameter must be rejected with a
    // type error whose message contains "Expected type".
    #[test]
    fn test_callable_fail_compile_time() {
        assert::fail(
            r#"
def foo(x: typing.Callable):
pass
def bar():
foo(1)
"#,
            "Expected type",
        );
    }
}