Skip to main content

rustpython_vm/stdlib/
_typing.rs

1// spell-checker:ignore typevarobject funcobj typevartuples
2use crate::{
3    Context, PyResult, VirtualMachine, builtins::pystr::AsPyStr, class::PyClassImpl,
4    function::IntoFuncArgs,
5};
6
7pub use crate::stdlib::typevar::{
8    Generic, ParamSpec, ParamSpecArgs, ParamSpecKwargs, TypeVar, TypeVarTuple,
9    set_typeparam_default,
10};
11pub(crate) use decl::module_def;
12pub use decl::*;
13
14/// Initialize typing types (call extend_class)
15pub fn init(ctx: &'static Context) {
16    NoDefault::extend_class(ctx, ctx.types.typing_no_default_type);
17}
18
19pub fn call_typing_func_object<'a>(
20    vm: &VirtualMachine,
21    func_name: impl AsPyStr<'a>,
22    args: impl IntoFuncArgs,
23) -> PyResult {
24    let module = vm.import("typing", 0)?;
25    let func = module.get_attr(func_name.as_pystr(&vm.ctx), vm)?;
26    func.call(args, vm)
27}
28
29#[pymodule(name = "_typing", with(super::typevar::typevar))]
30pub(crate) mod decl {
31    use crate::common::lock::LazyLock;
32    use crate::{
33        AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, atomic_func,
34        builtins::{PyGenericAlias, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, type_},
35        common::wtf8::Wtf8Buf,
36        function::FuncArgs,
37        protocol::{PyMappingMethods, PyNumberMethods},
38        types::{AsMapping, AsNumber, Callable, Constructor, Iterable, Representable},
39    };
40
41    #[pyfunction]
42    pub(crate) fn _idfunc(args: FuncArgs, _vm: &VirtualMachine) -> PyObjectRef {
43        args.args[0].clone()
44    }
45
46    #[pyfunction(name = "override")]
47    pub(crate) fn r#override(func: PyObjectRef, vm: &VirtualMachine) -> PyResult {
48        // Set __override__ attribute to True
49        // Skip the attribute silently if it is not writable.
50        // AttributeError happens if the object has __slots__ or a
51        // read-only property, TypeError if it's a builtin class.
52        let _ = func.set_attr("__override__", vm.ctx.true_value.clone(), vm);
53        Ok(func)
54    }
55
56    #[pyclass(no_attr, name = "NoDefaultType", module = "typing")]
57    #[derive(Debug, PyPayload)]
58    pub struct NoDefault;
59
60    #[pyclass(with(Constructor, Representable), flags(IMMUTABLETYPE))]
61    impl NoDefault {
62        #[pymethod]
63        fn __reduce__(&self, _vm: &VirtualMachine) -> String {
64            "NoDefault".to_owned()
65        }
66    }
67
68    impl Constructor for NoDefault {
69        type Args = ();
70
71        fn slot_new(_cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
72            let _: () = args.bind(vm)?;
73            Ok(vm.ctx.typing_no_default.clone().into())
74        }
75
76        fn py_new(_cls: &Py<PyType>, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<Self> {
77            unreachable!("NoDefault is a singleton, use slot_new")
78        }
79    }
80
81    impl Representable for NoDefault {
82        #[inline(always)]
83        fn repr_str(_zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<String> {
84            Ok("typing.NoDefault".to_owned())
85        }
86    }
87
    /// Callable wrapper around an already-known value, exposed as
    /// `_typing._ConstEvaluator`.
    ///
    /// Instances cannot be created from Python (`slot_new` raises `TypeError`).
    /// Rust code builds them with `const_evaluator_alloc`. Calling one with an
    /// integer format returns `value`, or a string rendering of it when the
    /// format equals `ANNOTATE_FORMAT_STRING`.
    #[pyattr]
    #[pyclass(name = "_ConstEvaluator", module = "_typing")]
    #[derive(Debug, PyPayload)]
    pub(crate) struct ConstEvaluator {
        /// The wrapped value; returned (or rendered as a string) on every call.
        value: PyObjectRef,
    }

    // No extra Python-visible methods: construction, calling and repr all come
    // from the trait impls listed in `with(...)`.
    #[pyclass(with(Constructor, Callable, Representable), flags(IMMUTABLETYPE))]
    impl ConstEvaluator {}
97
98    impl Constructor for ConstEvaluator {
99        type Args = FuncArgs;
100
101        fn slot_new(_cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult {
102            Err(vm.new_type_error("cannot create '_typing._ConstEvaluator' instances"))
103        }
104
105        fn py_new(_cls: &Py<PyType>, _args: Self::Args, _vm: &VirtualMachine) -> PyResult<Self> {
106            unreachable!("ConstEvaluator cannot be instantiated from Python")
107        }
108    }
109
110    /// annotationlib.Format.STRING = 4
111    const ANNOTATE_FORMAT_STRING: i32 = 4;
112
113    impl Callable for ConstEvaluator {
114        type Args = FuncArgs;
115
116        fn call(zelf: &Py<Self>, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
117            let (format,): (i32,) = args.bind(vm)?;
118            let value = &zelf.value;
119            if format == ANNOTATE_FORMAT_STRING {
120                return typing_type_repr_value(value, vm);
121            }
122            Ok(value.clone())
123        }
124    }
125
126    /// String representation of a type for annotation purposes.
127    /// Equivalent of _Py_typing_type_repr.
128    fn typing_type_repr(obj: &PyObjectRef, vm: &VirtualMachine) -> PyResult<String> {
129        // Ellipsis
130        if obj.is(&vm.ctx.ellipsis) {
131            return Ok("...".to_owned());
132        }
133        // NoneType -> "None"
134        if obj.is(&vm.ctx.types.none_type.as_object()) {
135            return Ok("None".to_owned());
136        }
137        // Generic aliases (has __origin__ and __args__) -> repr
138        let has_origin = obj.get_attr("__origin__", vm).is_ok();
139        let has_args = obj.get_attr("__args__", vm).is_ok();
140        if has_origin && has_args {
141            return Ok(obj.repr(vm)?.to_string());
142        }
143        // Has __qualname__ and __module__
144        if let Ok(qualname) = obj.get_attr("__qualname__", vm)
145            && let Ok(module) = obj.get_attr("__module__", vm)
146            && !vm.is_none(&module)
147            && let Some(module_str) = module.downcast_ref::<crate::builtins::PyStr>()
148        {
149            if module_str.as_bytes() == b"builtins" {
150                return Ok(qualname.str_utf8(vm)?.as_str().to_owned());
151            }
152            return Ok(format!(
153                "{}.{}",
154                module_str.as_wtf8(),
155                qualname.str_utf8(vm)?.as_str()
156            ));
157        }
158        // Fallback to repr
159        Ok(obj.repr(vm)?.to_string())
160    }
161
162    /// Format a value as a string for ANNOTATE_FORMAT_STRING.
163    /// Handles tuples specially by wrapping in parentheses.
164    fn typing_type_repr_value(value: &PyObjectRef, vm: &VirtualMachine) -> PyResult {
165        if let Ok(tuple) = value.try_to_ref::<PyTuple>(vm) {
166            let mut parts = Vec::with_capacity(tuple.len());
167            for item in tuple.iter() {
168                parts.push(typing_type_repr(item, vm)?);
169            }
170            let inner = if parts.len() == 1 {
171                format!("{},", parts[0])
172            } else {
173                parts.join(", ")
174            };
175            Ok(vm.ctx.new_str(format!("({})", inner)).into())
176        } else {
177            Ok(vm.ctx.new_str(typing_type_repr(value, vm)?).into())
178        }
179    }
180
181    impl Representable for ConstEvaluator {
182        fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
183            let value_repr = zelf.value.repr(vm)?;
184            Ok(format!("<constevaluator {}>", value_repr))
185        }
186    }
187
188    pub(crate) fn const_evaluator_alloc(value: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
189        ConstEvaluator { value }.into_ref(&vm.ctx).into()
190    }
191
    /// Runtime object for type aliases (`typing.TypeAliasType`).
    ///
    /// Built either lazily via `new` (the intrinsic path, value computed on
    /// demand) or eagerly via `new_eager` (the Python-level constructor).
    #[pyattr]
    #[pyclass(name, module = "typing")]
    #[derive(Debug, PyPayload)]
    pub(crate) struct TypeAliasType {
        /// Alias name; exposed as `__name__` and used verbatim as the repr.
        name: PyStrRef,
        /// Declared type parameters; an empty tuple makes the alias non-subscriptable.
        type_params: PyTupleRef,
        /// Lazy aliases: evaluator called with a format code to produce the value.
        /// Eager aliases: the value itself.
        compute_value: PyObjectRef,
        /// Memoized `__value__`; pre-filled for eager aliases.
        cached_value: crate::common::lock::PyMutex<Option<PyObjectRef>>,
        /// Explicit `__module__` captured by the constructor; `None` for aliases
        /// created through `new`.
        module: Option<PyObjectRef>,
        /// `true` when `compute_value` is an evaluator, `false` when it is the value.
        is_lazy: bool,
    }
203    #[pyclass(
204        with(Constructor, Representable, AsMapping, AsNumber, Iterable),
205        flags(IMMUTABLETYPE)
206    )]
207    impl TypeAliasType {
208        /// Create from intrinsic: compute_value is a callable that returns the value
209        pub fn new(name: PyStrRef, type_params: PyTupleRef, compute_value: PyObjectRef) -> Self {
210            Self {
211                name,
212                type_params,
213                compute_value,
214                cached_value: crate::common::lock::PyMutex::new(None),
215                module: None,
216                is_lazy: true,
217            }
218        }
219
220        /// Create with an eagerly evaluated value (used by constructor)
221        fn new_eager(
222            name: PyStrRef,
223            type_params: PyTupleRef,
224            value: PyObjectRef,
225            module: Option<PyObjectRef>,
226        ) -> Self {
227            Self {
228                name,
229                type_params,
230                compute_value: value.clone(),
231                cached_value: crate::common::lock::PyMutex::new(Some(value)),
232                module,
233                is_lazy: false,
234            }
235        }
236
237        #[pygetset]
238        fn __name__(&self) -> PyObjectRef {
239            self.name.clone().into()
240        }
241
242        #[pygetset]
243        fn __value__(&self, vm: &VirtualMachine) -> PyResult {
244            let cached = self.cached_value.lock().clone();
245            if let Some(value) = cached {
246                return Ok(value);
247            }
248            // Call evaluator with format=1 (FORMAT_VALUE)
249            let value = self.compute_value.call((1i32,), vm)?;
250            *self.cached_value.lock() = Some(value.clone());
251            Ok(value)
252        }
253
254        #[pygetset]
255        fn __type_params__(&self) -> PyTupleRef {
256            self.type_params.clone()
257        }
258
259        #[pygetset]
260        fn __parameters__(&self, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
261            // TypeVarTuples must be unpacked in __parameters__
262            unpack_typevartuples(&self.type_params, vm).map(|t| t.into())
263        }
264
265        #[pygetset]
266        fn __module__(&self, vm: &VirtualMachine) -> PyObjectRef {
267            if let Some(ref module) = self.module {
268                return module.clone();
269            }
270            // Fall back to compute_value's __module__ (like PyFunction_GetModule)
271            if let Ok(module) = self.compute_value.get_attr("__module__", vm) {
272                return module;
273            }
274            vm.ctx.none()
275        }
276
277        fn __getitem__(zelf: PyRef<Self>, args: PyObjectRef, vm: &VirtualMachine) -> PyResult {
278            if zelf.type_params.is_empty() {
279                return Err(vm.new_type_error("Only generic type aliases are subscriptable"));
280            }
281            let args_tuple = if let Ok(tuple) = args.try_to_ref::<PyTuple>(vm) {
282                tuple.to_owned()
283            } else {
284                PyTuple::new_ref(vec![args], &vm.ctx)
285            };
286            let origin: PyObjectRef = zelf.as_object().to_owned();
287            Ok(PyGenericAlias::new(origin, args_tuple, false, vm).into_pyobject(vm))
288        }
289
290        #[pymethod]
291        fn __reduce__(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyObjectRef {
292            zelf.name.clone().into()
293        }
294
295        #[pymethod]
296        fn __typing_unpacked_tuple_args__(&self, vm: &VirtualMachine) -> PyObjectRef {
297            vm.ctx.none()
298        }
299
300        #[pygetset]
301        fn evaluate_value(&self, vm: &VirtualMachine) -> PyResult {
302            if self.is_lazy {
303                return Ok(self.compute_value.clone());
304            }
305            Ok(const_evaluator_alloc(self.compute_value.clone(), vm))
306        }
307
308        /// Check type_params ordering: non-default params must precede default params.
309        /// Uses __default__ attribute to check if a type param has a default value,
310        /// comparing against typing.NoDefault sentinel (like get_type_param_default).
311        fn check_type_params(
312            type_params: &PyTupleRef,
313            vm: &VirtualMachine,
314        ) -> PyResult<Option<PyTupleRef>> {
315            if type_params.is_empty() {
316                return Ok(None);
317            }
318            let no_default = &vm.ctx.typing_no_default;
319            let mut default_seen = false;
320            for param in type_params.iter() {
321                let dflt = param.get_attr("__default__", vm).map_err(|_| {
322                    vm.new_type_error(format!(
323                        "Expected a type param, got {}",
324                        param
325                            .repr(vm)
326                            .map(|s| s.to_string())
327                            .unwrap_or_else(|_| "?".to_owned())
328                    ))
329                })?;
330                let is_no_default = dflt.is(no_default);
331                if is_no_default {
332                    if default_seen {
333                        return Err(vm.new_type_error(format!(
334                            "non-default type parameter '{}' follows default type parameter",
335                            param.repr(vm)?
336                        )));
337                    }
338                } else {
339                    default_seen = true;
340                }
341            }
342            Ok(Some(type_params.clone()))
343        }
344    }
345
    impl Constructor for TypeAliasType {
        type Args = FuncArgs;

        /// Python-level constructor: `TypeAliasType(name, value, *, type_params=())`.
        ///
        /// Arguments are parsed by hand from the raw `FuncArgs`. The checks run
        /// in a fixed order, which decides which `TypeError` is reported when
        /// several problems are present:
        /// 1. unknown keywords;
        /// 2. too many positionals;
        /// 3. `name`;
        /// 4. `value`;
        /// 5. the type of `name`;
        /// 6. `type_params`.
        ///
        /// The result is an eager alias whose `value` is stored and cached
        /// directly. Its module is the `__name__` global of the current frame,
        /// if there is a frame and that lookup succeeds.
        fn py_new(_cls: &Py<PyType>, args: Self::Args, vm: &VirtualMachine) -> PyResult<Self> {
            // typealias(name, value, *, type_params=())
            // name and value are positional-or-keyword; type_params is keyword-only.

            // Reject unexpected keyword arguments
            for key in args.kwargs.keys() {
                if key != "name" && key != "value" && key != "type_params" {
                    return Err(vm.new_type_error(format!(
                        "typealias() got an unexpected keyword argument '{key}'"
                    )));
                }
            }

            // Reject too many positional arguments
            if args.args.len() > 2 {
                return Err(vm.new_type_error(format!(
                    "typealias() takes exactly 2 positional arguments ({} given)",
                    args.args.len()
                )));
            }

            // Resolve name: positional[0] or kwarg
            let name = if !args.args.is_empty() {
                if args.kwargs.contains_key("name") {
                    return Err(vm.new_type_error(
                        "argument for typealias() given by name ('name') and position (1)",
                    ));
                }
                args.args[0].clone()
            } else {
                args.kwargs.get("name").cloned().ok_or_else(|| {
                    vm.new_type_error("typealias() missing required argument 'name' (pos 1)")
                })?
            };

            // Resolve value: positional[1] or kwarg
            let value = if args.args.len() >= 2 {
                if args.kwargs.contains_key("value") {
                    return Err(vm.new_type_error(
                        "argument for typealias() given by name ('value') and position (2)",
                    ));
                }
                args.args[1].clone()
            } else {
                args.kwargs.get("value").cloned().ok_or_else(|| {
                    vm.new_type_error("typealias() missing required argument 'value' (pos 2)")
                })?
            };

            // `name` must be a str (subclasses accepted by the downcast).
            let name = name.downcast::<crate::builtins::PyStr>().map_err(|obj| {
                vm.new_type_error(format!(
                    "typealias() argument 'name' must be str, not {}",
                    obj.class().name()
                ))
            })?;

            // Optional keyword-only `type_params`: must be a tuple with valid
            // default ordering; absent means the shared empty tuple.
            let type_params = if let Some(tp) = args.kwargs.get("type_params") {
                let tp = tp
                    .clone()
                    .downcast::<crate::builtins::PyTuple>()
                    .map_err(|_| vm.new_type_error("type_params must be a tuple"))?;
                Self::check_type_params(&tp, vm)?;
                tp
            } else {
                vm.ctx.empty_tuple.clone()
            };

            // Get caller's module name from frame globals, like typevar.rs caller()
            let module = vm
                .current_frame()
                .and_then(|f| f.globals.get_item("__name__", vm).ok());

            Ok(Self::new_eager(name, type_params, value, module))
        }
    }
424
425    impl Representable for TypeAliasType {
426        fn repr_wtf8(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<Wtf8Buf> {
427            Ok(zelf.name.as_wtf8().to_owned())
428        }
429    }
430
431    impl AsMapping for TypeAliasType {
432        fn as_mapping() -> &'static PyMappingMethods {
433            static AS_MAPPING: LazyLock<PyMappingMethods> = LazyLock::new(|| PyMappingMethods {
434                subscript: atomic_func!(|mapping, needle, vm| {
435                    let zelf = TypeAliasType::mapping_downcast(mapping);
436                    TypeAliasType::__getitem__(zelf.to_owned(), needle.to_owned(), vm)
437                }),
438                ..PyMappingMethods::NOT_IMPLEMENTED
439            });
440            &AS_MAPPING
441        }
442    }
443
444    impl AsNumber for TypeAliasType {
445        fn as_number() -> &'static PyNumberMethods {
446            static AS_NUMBER: PyNumberMethods = PyNumberMethods {
447                or: Some(|a, b, vm| type_::or_(a.to_owned(), b.to_owned(), vm)),
448                ..PyNumberMethods::NOT_IMPLEMENTED
449            };
450            &AS_NUMBER
451        }
452    }
453
454    impl Iterable for TypeAliasType {
455        fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
456            // Import typing.Unpack and return iter((Unpack[self],))
457            let typing = vm.import("typing", 0)?;
458            let unpack = typing.get_attr("Unpack", vm)?;
459            let zelf_obj: PyObjectRef = zelf.into();
460            let unpacked = vm.call_method(&unpack, "__getitem__", (zelf_obj,))?;
461            let tuple = PyTuple::new_ref(vec![unpacked], &vm.ctx);
462            Ok(tuple.as_object().get_iter(vm)?.into())
463        }
464    }
465
466    /// Wrap TypeVarTuples in Unpack[], matching unpack_typevartuples()
467    pub(crate) fn unpack_typevartuples(
468        type_params: &PyTupleRef,
469        vm: &VirtualMachine,
470    ) -> PyResult<PyTupleRef> {
471        let has_tvt = type_params
472            .iter()
473            .any(|p| p.downcastable::<crate::stdlib::typevar::TypeVarTuple>());
474        if !has_tvt {
475            return Ok(type_params.clone());
476        }
477        let typing = vm.import("typing", 0)?;
478        let unpack_cls = typing.get_attr("Unpack", vm)?;
479        let new_params: Vec<PyObjectRef> = type_params
480            .iter()
481            .map(|p| {
482                if p.downcastable::<crate::stdlib::typevar::TypeVarTuple>() {
483                    vm.call_method(&unpack_cls, "__getitem__", (p.clone(),))
484                } else {
485                    Ok(p.clone())
486                }
487            })
488            .collect::<PyResult<_>>()?;
489        Ok(PyTuple::new_ref(new_params, &vm.ctx))
490    }
491
492    pub(crate) fn module_exec(
493        vm: &VirtualMachine,
494        module: &Py<crate::builtins::PyModule>,
495    ) -> PyResult<()> {
496        __module_exec(vm, module);
497
498        extend_module!(vm, module, {
499            "NoDefault" => vm.ctx.typing_no_default.clone(),
500            "Union" => vm.ctx.types.union_type.to_owned(),
501        });
502
503        Ok(())
504    }
505}