Skip to main content

rustpython_vm/builtins/
coroutine.rs

1use super::{PyCode, PyGenericAlias, PyStrRef, PyType, PyTypeRef};
2use crate::{
3    AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
4    class::PyClassImpl,
5    coroutine::{Coro, warn_deprecated_throw_signature},
6    frame::FrameRef,
7    function::OptionalArg,
8    object::{Traverse, TraverseFn},
9    protocol::PyIterReturn,
10    types::{Destructor, IterNext, Iterable, Representable, SelfIter},
11};
12use crossbeam_utils::atomic::AtomicCell;
13
14#[pyclass(module = false, name = "coroutine", traverse = "manual")]
15#[derive(Debug)]
16// PyCoro_Type in CPython
17pub struct PyCoroutine {
18    inner: Coro,
19}
20
21unsafe impl Traverse for PyCoroutine {
22    fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
23        self.inner.traverse(tracer_fn);
24    }
25}
26
27impl PyPayload for PyCoroutine {
28    #[inline]
29    fn class(ctx: &Context) -> &'static Py<PyType> {
30        ctx.types.coroutine_type
31    }
32}
33
34#[pyclass(
35    flags(DISALLOW_INSTANTIATION, HAS_WEAKREF),
36    with(Py, IterNext, Representable, Destructor)
37)]
38impl PyCoroutine {
39    pub const fn as_coro(&self) -> &Coro {
40        &self.inner
41    }
42
43    pub fn new(frame: FrameRef, name: PyStrRef, qualname: PyStrRef) -> Self {
44        Self {
45            inner: Coro::new(frame, name, qualname),
46        }
47    }
48
49    #[pygetset]
50    fn __name__(&self) -> PyStrRef {
51        self.inner.name()
52    }
53
54    #[pygetset(setter)]
55    fn set___name__(&self, name: PyStrRef) {
56        self.inner.set_name(name)
57    }
58
59    #[pygetset]
60    fn __qualname__(&self) -> PyStrRef {
61        self.inner.qualname()
62    }
63
64    #[pygetset(setter)]
65    fn set___qualname__(&self, qualname: PyStrRef) {
66        self.inner.set_qualname(qualname)
67    }
68
69    #[pymethod(name = "__await__")]
70    fn r#await(zelf: PyRef<Self>) -> PyCoroutineWrapper {
71        PyCoroutineWrapper {
72            coro: zelf,
73            closed: AtomicCell::new(false),
74        }
75    }
76
77    #[pygetset]
78    fn cr_await(&self, _vm: &VirtualMachine) -> Option<PyObjectRef> {
79        self.inner.frame().yield_from_target()
80    }
81    #[pygetset]
82    fn cr_frame(&self, _vm: &VirtualMachine) -> Option<FrameRef> {
83        if self.inner.closed() {
84            None
85        } else {
86            Some(self.inner.frame())
87        }
88    }
89    #[pygetset]
90    fn cr_running(&self, _vm: &VirtualMachine) -> bool {
91        self.inner.running()
92    }
93    #[pygetset]
94    fn cr_code(&self, _vm: &VirtualMachine) -> PyRef<PyCode> {
95        self.inner.frame().code.clone()
96    }
97    // TODO: coroutine origin tracking:
98    // https://docs.python.org/3/library/sys.html#sys.set_coroutine_origin_tracking_depth
99    #[pygetset]
100    const fn cr_origin(&self, _vm: &VirtualMachine) -> Option<(PyStrRef, usize, PyStrRef)> {
101        None
102    }
103
104    #[pyclassmethod]
105    fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
106        PyGenericAlias::from_args(cls, args, vm)
107    }
108}
109
110#[pyclass]
111impl Py<PyCoroutine> {
112    #[pymethod]
113    fn send(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
114        self.inner.send(self.as_object(), value, vm)
115    }
116
117    #[pymethod]
118    fn throw(
119        &self,
120        exc_type: PyObjectRef,
121        exc_val: OptionalArg,
122        exc_tb: OptionalArg,
123        vm: &VirtualMachine,
124    ) -> PyResult<PyIterReturn> {
125        warn_deprecated_throw_signature(&exc_val, &exc_tb, vm)?;
126        self.inner.throw(
127            self.as_object(),
128            exc_type,
129            exc_val.unwrap_or_none(vm),
130            exc_tb.unwrap_or_none(vm),
131            vm,
132        )
133    }
134
135    #[pymethod]
136    fn close(&self, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
137        self.inner.close(self.as_object(), vm)
138    }
139}
140
141impl Representable for PyCoroutine {
142    #[inline]
143    fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
144        Ok(zelf.inner.repr(zelf.as_object(), zelf.get_id(), vm))
145    }
146}
147
148impl SelfIter for PyCoroutine {}
149impl IterNext for PyCoroutine {
150    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
151        zelf.send(vm.ctx.none(), vm)
152    }
153}
154
155impl Destructor for PyCoroutine {
156    fn del(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<()> {
157        if zelf.inner.closed() || zelf.inner.running() {
158            return Ok(());
159        }
160        if zelf.inner.frame().lasti() == 0 {
161            zelf.inner.closed.store(true);
162            return Ok(());
163        }
164        if let Err(e) = zelf.inner.close(zelf.as_object(), vm) {
165            vm.run_unraisable(e, None, zelf.as_object().to_owned());
166        }
167        Ok(())
168    }
169}
170
171#[pyclass(module = false, name = "coroutine_wrapper", traverse = "manual")]
172#[derive(Debug)]
173// PyCoroWrapper_Type in CPython
174pub struct PyCoroutineWrapper {
175    coro: PyRef<PyCoroutine>,
176    closed: AtomicCell<bool>,
177}
178
179unsafe impl Traverse for PyCoroutineWrapper {
180    fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
181        self.coro.traverse(tracer_fn);
182    }
183}
184
185impl PyPayload for PyCoroutineWrapper {
186    #[inline]
187    fn class(ctx: &Context) -> &'static Py<PyType> {
188        ctx.types.coroutine_wrapper_type
189    }
190}
191
192#[pyclass(with(IterNext, Iterable))]
193impl PyCoroutineWrapper {
194    fn check_closed(&self, vm: &VirtualMachine) -> PyResult<()> {
195        if self.closed.load() {
196            return Err(vm.new_runtime_error("cannot reuse already awaited coroutine"));
197        }
198        Ok(())
199    }
200
201    #[pymethod]
202    fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
203        self.check_closed(vm)?;
204        let result = self.coro.send(val, vm);
205        // Mark as closed if exhausted
206        if let Ok(PyIterReturn::StopIteration(_)) = &result {
207            self.closed.store(true);
208        }
209        result
210    }
211
212    #[pymethod]
213    fn throw(
214        &self,
215        exc_type: PyObjectRef,
216        exc_val: OptionalArg,
217        exc_tb: OptionalArg,
218        vm: &VirtualMachine,
219    ) -> PyResult<PyIterReturn> {
220        self.check_closed(vm)?;
221        warn_deprecated_throw_signature(&exc_val, &exc_tb, vm)?;
222        let result = self.coro.throw(exc_type, exc_val, exc_tb, vm);
223        // Mark as closed if exhausted
224        if let Ok(PyIterReturn::StopIteration(_)) = &result {
225            self.closed.store(true);
226        }
227        result
228    }
229
230    #[pymethod]
231    fn close(&self, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
232        self.closed.store(true);
233        self.coro.close(vm)
234    }
235}
236
237impl SelfIter for PyCoroutineWrapper {}
238impl IterNext for PyCoroutineWrapper {
239    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
240        Self::send(zelf, vm.ctx.none(), vm)
241    }
242}
243
244impl Drop for PyCoroutine {
245    fn drop(&mut self) {
246        self.inner.frame().clear_generator();
247    }
248}
249
250pub fn init(ctx: &'static Context) {
251    PyCoroutine::extend_class(ctx, ctx.types.coroutine_type);
252    PyCoroutineWrapper::extend_class(ctx, ctx.types.coroutine_wrapper_type);
253}