1use super::{PyCode, PyGenericAlias, PyStrRef, PyType, PyTypeRef};
2use crate::{
3 AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
4 class::PyClassImpl,
5 coroutine::{Coro, warn_deprecated_throw_signature},
6 frame::FrameRef,
7 function::OptionalArg,
8 object::{Traverse, TraverseFn},
9 protocol::PyIterReturn,
10 types::{Destructor, IterNext, Iterable, Representable, SelfIter},
11};
12use crossbeam_utils::atomic::AtomicCell;
13
// Native payload behind Python's `coroutine` type. All execution state
// (frame, name, qualname, running/closed flags) is held by the shared `Coro`
// machinery; this struct only exposes it to Python. GC traversal is written by
// hand (`traverse = "manual"`) in the `Traverse` impl below.
#[pyclass(module = false, name = "coroutine", traverse = "manual")]
#[derive(Debug)]
pub struct PyCoroutine {
    // Shared coroutine state machine; every Python-visible method delegates to it.
    inner: Coro,
}
20
/// Reports the object references owned by a coroutine to the garbage collector.
// SAFETY: `PyCoroutine` has a single field, `inner`, so delegating to its
// traversal visits every reference this payload owns.
unsafe impl Traverse for PyCoroutine {
    fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
        self.inner.traverse(tracer_fn);
    }
}
26
/// Binds `PyCoroutine` payloads to the context's `coroutine` type object.
impl PyPayload for PyCoroutine {
    #[inline]
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.coroutine_type
    }
}
33
34#[pyclass(
35 flags(DISALLOW_INSTANTIATION, HAS_WEAKREF),
36 with(Py, IterNext, Representable, Destructor)
37)]
38impl PyCoroutine {
39 pub const fn as_coro(&self) -> &Coro {
40 &self.inner
41 }
42
43 pub fn new(frame: FrameRef, name: PyStrRef, qualname: PyStrRef) -> Self {
44 Self {
45 inner: Coro::new(frame, name, qualname),
46 }
47 }
48
49 #[pygetset]
50 fn __name__(&self) -> PyStrRef {
51 self.inner.name()
52 }
53
54 #[pygetset(setter)]
55 fn set___name__(&self, name: PyStrRef) {
56 self.inner.set_name(name)
57 }
58
59 #[pygetset]
60 fn __qualname__(&self) -> PyStrRef {
61 self.inner.qualname()
62 }
63
64 #[pygetset(setter)]
65 fn set___qualname__(&self, qualname: PyStrRef) {
66 self.inner.set_qualname(qualname)
67 }
68
69 #[pymethod(name = "__await__")]
70 fn r#await(zelf: PyRef<Self>) -> PyCoroutineWrapper {
71 PyCoroutineWrapper {
72 coro: zelf,
73 closed: AtomicCell::new(false),
74 }
75 }
76
77 #[pygetset]
78 fn cr_await(&self, _vm: &VirtualMachine) -> Option<PyObjectRef> {
79 self.inner.frame().yield_from_target()
80 }
81 #[pygetset]
82 fn cr_frame(&self, _vm: &VirtualMachine) -> Option<FrameRef> {
83 if self.inner.closed() {
84 None
85 } else {
86 Some(self.inner.frame())
87 }
88 }
89 #[pygetset]
90 fn cr_running(&self, _vm: &VirtualMachine) -> bool {
91 self.inner.running()
92 }
93 #[pygetset]
94 fn cr_code(&self, _vm: &VirtualMachine) -> PyRef<PyCode> {
95 self.inner.frame().code.clone()
96 }
97 #[pygetset]
100 const fn cr_origin(&self, _vm: &VirtualMachine) -> Option<(PyStrRef, usize, PyStrRef)> {
101 None
102 }
103
104 #[pyclassmethod]
105 fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
106 PyGenericAlias::from_args(cls, args, vm)
107 }
108}
109
110#[pyclass]
111impl Py<PyCoroutine> {
112 #[pymethod]
113 fn send(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
114 self.inner.send(self.as_object(), value, vm)
115 }
116
117 #[pymethod]
118 fn throw(
119 &self,
120 exc_type: PyObjectRef,
121 exc_val: OptionalArg,
122 exc_tb: OptionalArg,
123 vm: &VirtualMachine,
124 ) -> PyResult<PyIterReturn> {
125 warn_deprecated_throw_signature(&exc_val, &exc_tb, vm)?;
126 self.inner.throw(
127 self.as_object(),
128 exc_type,
129 exc_val.unwrap_or_none(vm),
130 exc_tb.unwrap_or_none(vm),
131 vm,
132 )
133 }
134
135 #[pymethod]
136 fn close(&self, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
137 self.inner.close(self.as_object(), vm)
138 }
139}
140
141impl Representable for PyCoroutine {
142 #[inline]
143 fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
144 Ok(zelf.inner.repr(zelf.as_object(), zelf.get_id(), vm))
145 }
146}
147
148impl SelfIter for PyCoroutine {}
149impl IterNext for PyCoroutine {
150 fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
151 zelf.send(vm.ctx.none(), vm)
152 }
153}
154
155impl Destructor for PyCoroutine {
156 fn del(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<()> {
157 if zelf.inner.closed() || zelf.inner.running() {
158 return Ok(());
159 }
160 if zelf.inner.frame().lasti() == 0 {
161 zelf.inner.closed.store(true);
162 return Ok(());
163 }
164 if let Err(e) = zelf.inner.close(zelf.as_object(), vm) {
165 vm.run_unraisable(e, None, zelf.as_object().to_owned());
166 }
167 Ok(())
168 }
169}
170
// Object returned by `coroutine.__await__()`: an iterator that forwards
// `send`/`throw`/`close`/`__next__` to the wrapped coroutine. GC traversal is
// hand-written (`traverse = "manual"`) in the `Traverse` impl below.
#[pyclass(module = false, name = "coroutine_wrapper", traverse = "manual")]
#[derive(Debug)]
pub struct PyCoroutineWrapper {
    // The coroutine this wrapper drives.
    coro: PyRef<PyCoroutine>,
    // Set once the wrapper sees the coroutine finish (StopIteration) or is
    // explicitly closed; after that, `send`/`throw` raise RuntimeError.
    closed: AtomicCell<bool>,
}
177}
178
/// Reports the wrapped coroutine to the garbage collector.
// SAFETY: the only object reference owned by the wrapper is `coro`; `closed`
// is a plain `AtomicCell<bool>` that holds no Python objects.
unsafe impl Traverse for PyCoroutineWrapper {
    fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
        self.coro.traverse(tracer_fn);
    }
}
183}
184
/// Binds `PyCoroutineWrapper` payloads to the context's `coroutine_wrapper` type.
impl PyPayload for PyCoroutineWrapper {
    #[inline]
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.coroutine_wrapper_type
    }
}
190}
191
192#[pyclass(with(IterNext, Iterable))]
193impl PyCoroutineWrapper {
194 fn check_closed(&self, vm: &VirtualMachine) -> PyResult<()> {
195 if self.closed.load() {
196 return Err(vm.new_runtime_error("cannot reuse already awaited coroutine"));
197 }
198 Ok(())
199 }
200
201 #[pymethod]
202 fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
203 self.check_closed(vm)?;
204 let result = self.coro.send(val, vm);
205 if let Ok(PyIterReturn::StopIteration(_)) = &result {
207 self.closed.store(true);
208 }
209 result
210 }
211
212 #[pymethod]
213 fn throw(
214 &self,
215 exc_type: PyObjectRef,
216 exc_val: OptionalArg,
217 exc_tb: OptionalArg,
218 vm: &VirtualMachine,
219 ) -> PyResult<PyIterReturn> {
220 self.check_closed(vm)?;
221 warn_deprecated_throw_signature(&exc_val, &exc_tb, vm)?;
222 let result = self.coro.throw(exc_type, exc_val, exc_tb, vm);
223 if let Ok(PyIterReturn::StopIteration(_)) = &result {
225 self.closed.store(true);
226 }
227 result
228 }
229
230 #[pymethod]
231 fn close(&self, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
232 self.closed.store(true);
233 self.coro.close(vm)
234 }
235}
236
237impl SelfIter for PyCoroutineWrapper {}
238impl IterNext for PyCoroutineWrapper {
239 fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
240 Self::send(zelf, vm.ctx.none(), vm)
241 }
242}
243
244impl Drop for PyCoroutine {
245 fn drop(&mut self) {
246 self.inner.frame().clear_generator();
247 }
248}
249
250pub fn init(ctx: &'static Context) {
251 PyCoroutine::extend_class(ctx, ctx.types.coroutine_type);
252 PyCoroutineWrapper::extend_class(ctx, ctx.types.coroutine_wrapper_type);
253}