Skip to main content

rustpython_vm/builtins/
asyncgenerator.rs

1use super::{PyCode, PyGenerator, PyGenericAlias, PyStrRef, PyType, PyTypeRef};
2use crate::{
3    AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
4    builtins::PyBaseExceptionRef,
5    class::PyClassImpl,
6    common::lock::PyMutex,
7    coroutine::{Coro, warn_deprecated_throw_signature},
8    frame::FrameRef,
9    function::OptionalArg,
10    object::{Traverse, TraverseFn},
11    protocol::PyIterReturn,
12    types::{Destructor, IterNext, Iterable, Representable, SelfIter},
13};
14
15use crossbeam_utils::atomic::AtomicCell;
16
17#[pyclass(name = "async_generator", module = false, traverse = "manual")]
18#[derive(Debug)]
19pub struct PyAsyncGen {
20    inner: Coro,
21    running_async: AtomicCell<bool>,
22    // whether hooks have been initialized
23    ag_hooks_inited: AtomicCell<bool>,
24    // ag_origin_or_finalizer - stores the finalizer callback
25    ag_finalizer: PyMutex<Option<PyObjectRef>>,
26}
27
28unsafe impl Traverse for PyAsyncGen {
29    fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
30        self.inner.traverse(tracer_fn);
31        self.ag_finalizer.traverse(tracer_fn);
32    }
33}
34type PyAsyncGenRef = PyRef<PyAsyncGen>;
35
36impl PyPayload for PyAsyncGen {
37    #[inline]
38    fn class(ctx: &Context) -> &'static Py<PyType> {
39        ctx.types.async_generator
40    }
41}
42
43#[pyclass(
44    flags(DISALLOW_INSTANTIATION, HAS_WEAKREF),
45    with(PyRef, Representable, Destructor)
46)]
47impl PyAsyncGen {
48    pub const fn as_coro(&self) -> &Coro {
49        &self.inner
50    }
51
52    pub fn new(frame: FrameRef, name: PyStrRef, qualname: PyStrRef) -> Self {
53        Self {
54            inner: Coro::new(frame, name, qualname),
55            running_async: AtomicCell::new(false),
56            ag_hooks_inited: AtomicCell::new(false),
57            ag_finalizer: PyMutex::new(None),
58        }
59    }
60
61    /// Initialize async generator hooks.
62    /// Returns Ok(()) if successful, Err if firstiter hook raised an exception.
63    fn init_hooks(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<()> {
64        // = async_gen_init_hooks
65        if zelf.ag_hooks_inited.load() {
66            return Ok(());
67        }
68
69        zelf.ag_hooks_inited.store(true);
70
71        // Get and store finalizer from VM
72        let finalizer = vm.async_gen_finalizer.borrow().clone();
73        if let Some(finalizer) = finalizer {
74            *zelf.ag_finalizer.lock() = Some(finalizer);
75        }
76
77        // Call firstiter hook
78        let firstiter = vm.async_gen_firstiter.borrow().clone();
79        if let Some(firstiter) = firstiter {
80            let obj: PyObjectRef = zelf.to_owned().into();
81            firstiter.call((obj,), vm)?;
82        }
83
84        Ok(())
85    }
86
87    /// Call finalizer hook if set.
88    fn call_finalizer(zelf: &Py<Self>, vm: &VirtualMachine) {
89        let finalizer = zelf.ag_finalizer.lock().clone();
90        if let Some(finalizer) = finalizer
91            && !zelf.inner.closed.load()
92        {
93            // Create a strong reference for the finalizer call.
94            // This keeps the object alive during the finalizer execution.
95            let obj: PyObjectRef = zelf.to_owned().into();
96
97            // Call the finalizer. Any exceptions are handled as unraisable.
98            if let Err(e) = finalizer.call((obj,), vm) {
99                vm.run_unraisable(e, Some("async generator finalizer".to_owned()), finalizer);
100            }
101        }
102    }
103
104    #[pygetset]
105    fn __name__(&self) -> PyStrRef {
106        self.inner.name()
107    }
108
109    #[pygetset(setter)]
110    fn set___name__(&self, name: PyStrRef) {
111        self.inner.set_name(name)
112    }
113
114    #[pygetset]
115    fn __qualname__(&self) -> PyStrRef {
116        self.inner.qualname()
117    }
118
119    #[pygetset(setter)]
120    fn set___qualname__(&self, qualname: PyStrRef) {
121        self.inner.set_qualname(qualname)
122    }
123
124    #[pygetset]
125    fn ag_await(&self, _vm: &VirtualMachine) -> Option<PyObjectRef> {
126        self.inner.frame().yield_from_target()
127    }
128    #[pygetset]
129    fn ag_frame(&self, _vm: &VirtualMachine) -> Option<FrameRef> {
130        if self.inner.closed() {
131            None
132        } else {
133            Some(self.inner.frame())
134        }
135    }
136    #[pygetset]
137    fn ag_running(&self, _vm: &VirtualMachine) -> bool {
138        self.inner.running()
139    }
140    #[pygetset]
141    fn ag_code(&self, _vm: &VirtualMachine) -> PyRef<PyCode> {
142        self.inner.frame().code.clone()
143    }
144
145    #[pyclassmethod]
146    fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
147        PyGenericAlias::from_args(cls, args, vm)
148    }
149}
150
151#[pyclass]
152impl PyRef<PyAsyncGen> {
153    #[pymethod]
154    const fn __aiter__(self, _vm: &VirtualMachine) -> Self {
155        self
156    }
157
158    #[pymethod]
159    fn __anext__(self, vm: &VirtualMachine) -> PyResult<PyAsyncGenASend> {
160        PyAsyncGen::init_hooks(&self, vm)?;
161        Ok(PyAsyncGenASend {
162            ag: self,
163            state: AtomicCell::new(AwaitableState::Init),
164            value: vm.ctx.none(),
165        })
166    }
167
168    #[pymethod]
169    fn asend(self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyAsyncGenASend> {
170        PyAsyncGen::init_hooks(&self, vm)?;
171        Ok(PyAsyncGenASend {
172            ag: self,
173            state: AtomicCell::new(AwaitableState::Init),
174            value,
175        })
176    }
177
178    #[pymethod]
179    fn athrow(
180        self,
181        exc_type: PyObjectRef,
182        exc_val: OptionalArg,
183        exc_tb: OptionalArg,
184        vm: &VirtualMachine,
185    ) -> PyResult<PyAsyncGenAThrow> {
186        warn_deprecated_throw_signature(&exc_val, &exc_tb, vm)?;
187        PyAsyncGen::init_hooks(&self, vm)?;
188        Ok(PyAsyncGenAThrow {
189            ag: self,
190            aclose: false,
191            state: AtomicCell::new(AwaitableState::Init),
192            value: (
193                exc_type,
194                exc_val.unwrap_or_none(vm),
195                exc_tb.unwrap_or_none(vm),
196            ),
197        })
198    }
199
200    #[pymethod]
201    fn aclose(self, vm: &VirtualMachine) -> PyResult<PyAsyncGenAThrow> {
202        PyAsyncGen::init_hooks(&self, vm)?;
203        Ok(PyAsyncGenAThrow {
204            ag: self,
205            aclose: true,
206            state: AtomicCell::new(AwaitableState::Init),
207            value: (
208                vm.ctx.exceptions.generator_exit.to_owned().into(),
209                vm.ctx.none(),
210                vm.ctx.none(),
211            ),
212        })
213    }
214}
215
216impl Representable for PyAsyncGen {
217    #[inline]
218    fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
219        Ok(zelf.inner.repr(zelf.as_object(), zelf.get_id(), vm))
220    }
221}
222
223#[pyclass(
224    module = false,
225    name = "async_generator_wrapped_value",
226    traverse = "manual"
227)]
228#[derive(Debug)]
229pub(crate) struct PyAsyncGenWrappedValue(pub PyObjectRef);
230
231unsafe impl Traverse for PyAsyncGenWrappedValue {
232    fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
233        self.0.traverse(tracer_fn);
234    }
235}
236
237impl PyPayload for PyAsyncGenWrappedValue {
238    #[inline]
239    fn class(ctx: &Context) -> &'static Py<PyType> {
240        ctx.types.async_generator_wrapped_value
241    }
242}
243
244#[pyclass]
245impl PyAsyncGenWrappedValue {}
246
247impl PyAsyncGenWrappedValue {
248    fn unbox(ag: &PyAsyncGen, val: PyResult<PyIterReturn>, vm: &VirtualMachine) -> PyResult {
249        let (closed, async_done) = match &val {
250            Ok(PyIterReturn::StopIteration(_)) => (true, true),
251            Err(e) if e.fast_isinstance(vm.ctx.exceptions.generator_exit) => (true, true),
252            Err(_) => (false, true),
253            _ => (false, false),
254        };
255        if closed {
256            ag.inner.closed.store(true);
257        }
258        if async_done {
259            ag.running_async.store(false);
260        }
261        let val = val?.into_async_pyresult(vm)?;
262        match_class!(match val {
263            val @ Self => {
264                ag.running_async.store(false);
265                Err(vm.new_stop_iteration(Some(val.0.clone())))
266            }
267            val => Ok(val),
268        })
269    }
270}
271
/// Lifecycle shared by the `asend`/`athrow` awaitables and the `anext()`
/// default-value wrapper.
#[derive(Debug, Clone, Copy)]
enum AwaitableState {
    /// Created, but nothing has been sent or thrown into it yet.
    Init,
    /// Currently driving the underlying generator or awaitable.
    Iter,
    /// Finished or failed; further `send()`/`throw()` calls raise `RuntimeError`.
    Closed,
}
278
279#[pyclass(module = false, name = "async_generator_asend", traverse = "manual")]
280#[derive(Debug)]
281pub(crate) struct PyAsyncGenASend {
282    ag: PyAsyncGenRef,
283    state: AtomicCell<AwaitableState>,
284    value: PyObjectRef,
285}
286
287unsafe impl Traverse for PyAsyncGenASend {
288    fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
289        self.ag.traverse(tracer_fn);
290        self.value.traverse(tracer_fn);
291    }
292}
293
294impl PyPayload for PyAsyncGenASend {
295    #[inline]
296    fn class(ctx: &Context) -> &'static Py<PyType> {
297        ctx.types.async_generator_asend
298    }
299}
300
301#[pyclass(with(IterNext, Iterable))]
302impl PyAsyncGenASend {
303    #[pymethod(name = "__await__")]
304    const fn r#await(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
305        zelf
306    }
307
308    #[pymethod]
309    fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
310        let val = match self.state.load() {
311            AwaitableState::Closed => {
312                return Err(
313                    vm.new_runtime_error("cannot reuse already awaited __anext__()/asend()")
314                );
315            }
316            AwaitableState::Iter => val, // already running, all good
317            AwaitableState::Init => {
318                if self.ag.running_async.load() {
319                    return Err(
320                        vm.new_runtime_error("anext(): asynchronous generator is already running")
321                    );
322                }
323                self.ag.running_async.store(true);
324                self.state.store(AwaitableState::Iter);
325                if vm.is_none(&val) {
326                    self.value.clone()
327                } else {
328                    val
329                }
330            }
331        };
332        let res = self.ag.inner.send(self.ag.as_object(), val, vm);
333        let res = PyAsyncGenWrappedValue::unbox(&self.ag, res, vm);
334        if res.is_err() {
335            self.set_closed();
336        }
337        res
338    }
339
340    #[pymethod]
341    fn throw(
342        &self,
343        exc_type: PyObjectRef,
344        exc_val: OptionalArg,
345        exc_tb: OptionalArg,
346        vm: &VirtualMachine,
347    ) -> PyResult {
348        match self.state.load() {
349            AwaitableState::Closed => {
350                return Err(
351                    vm.new_runtime_error("cannot reuse already awaited __anext__()/asend()")
352                );
353            }
354            AwaitableState::Init => {
355                if self.ag.running_async.load() {
356                    self.state.store(AwaitableState::Closed);
357                    return Err(
358                        vm.new_runtime_error("anext(): asynchronous generator is already running")
359                    );
360                }
361                self.ag.running_async.store(true);
362                self.state.store(AwaitableState::Iter);
363            }
364            AwaitableState::Iter => {}
365        }
366
367        warn_deprecated_throw_signature(&exc_val, &exc_tb, vm)?;
368        let res = self.ag.inner.throw(
369            self.ag.as_object(),
370            exc_type,
371            exc_val.unwrap_or_none(vm),
372            exc_tb.unwrap_or_none(vm),
373            vm,
374        );
375        let res = PyAsyncGenWrappedValue::unbox(&self.ag, res, vm);
376        if res.is_err() {
377            self.set_closed();
378        }
379        res
380    }
381
382    #[pymethod]
383    fn close(&self, vm: &VirtualMachine) -> PyResult<()> {
384        if matches!(self.state.load(), AwaitableState::Closed) {
385            return Ok(());
386        }
387        let result = self.throw(
388            vm.ctx.exceptions.generator_exit.to_owned().into(),
389            OptionalArg::Missing,
390            OptionalArg::Missing,
391            vm,
392        );
393        match result {
394            Ok(_) => Err(vm.new_runtime_error("coroutine ignored GeneratorExit")),
395            Err(e)
396                if e.fast_isinstance(vm.ctx.exceptions.stop_iteration)
397                    || e.fast_isinstance(vm.ctx.exceptions.stop_async_iteration)
398                    || e.fast_isinstance(vm.ctx.exceptions.generator_exit) =>
399            {
400                Ok(())
401            }
402            Err(e) => Err(e),
403        }
404    }
405
406    fn set_closed(&self) {
407        self.state.store(AwaitableState::Closed);
408    }
409}
410
411impl SelfIter for PyAsyncGenASend {}
412impl IterNext for PyAsyncGenASend {
413    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
414        PyIterReturn::from_pyresult(zelf.send(vm.ctx.none(), vm), vm)
415    }
416}
417
418#[pyclass(module = false, name = "async_generator_athrow", traverse = "manual")]
419#[derive(Debug)]
420pub(crate) struct PyAsyncGenAThrow {
421    ag: PyAsyncGenRef,
422    aclose: bool,
423    state: AtomicCell<AwaitableState>,
424    value: (PyObjectRef, PyObjectRef, PyObjectRef),
425}
426
427unsafe impl Traverse for PyAsyncGenAThrow {
428    fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
429        self.ag.traverse(tracer_fn);
430        self.value.traverse(tracer_fn);
431    }
432}
433
434impl PyPayload for PyAsyncGenAThrow {
435    #[inline]
436    fn class(ctx: &Context) -> &'static Py<PyType> {
437        ctx.types.async_generator_athrow
438    }
439}
440
441#[pyclass(with(IterNext, Iterable))]
442impl PyAsyncGenAThrow {
443    #[pymethod(name = "__await__")]
444    const fn r#await(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
445        zelf
446    }
447
448    #[pymethod]
449    fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
450        match self.state.load() {
451            AwaitableState::Closed => {
452                Err(vm.new_runtime_error("cannot reuse already awaited aclose()/athrow()"))
453            }
454            AwaitableState::Init => {
455                if self.ag.running_async.load() {
456                    self.state.store(AwaitableState::Closed);
457                    let msg = if self.aclose {
458                        "aclose(): asynchronous generator is already running"
459                    } else {
460                        "athrow(): asynchronous generator is already running"
461                    };
462                    return Err(vm.new_runtime_error(msg.to_owned()));
463                }
464                if self.ag.inner.closed() {
465                    self.state.store(AwaitableState::Closed);
466                    return Err(vm.new_stop_iteration(None));
467                }
468                if !vm.is_none(&val) {
469                    return Err(vm.new_runtime_error(
470                        "can't send non-None value to a just-started async generator",
471                    ));
472                }
473                self.state.store(AwaitableState::Iter);
474                self.ag.running_async.store(true);
475
476                let (ty, val, tb) = self.value.clone();
477                let ret = self.ag.inner.throw(self.ag.as_object(), ty, val, tb, vm);
478                let ret = if self.aclose {
479                    if self.ignored_close(&ret) {
480                        Err(self.yield_close(vm))
481                    } else {
482                        ret.and_then(|o| o.into_async_pyresult(vm))
483                    }
484                } else {
485                    PyAsyncGenWrappedValue::unbox(&self.ag, ret, vm)
486                };
487                ret.map_err(|e| self.check_error(e, vm))
488            }
489            AwaitableState::Iter => {
490                let ret = self.ag.inner.send(self.ag.as_object(), val, vm);
491                if self.aclose {
492                    match ret {
493                        Ok(PyIterReturn::Return(v))
494                            if v.downcastable::<PyAsyncGenWrappedValue>() =>
495                        {
496                            Err(self.yield_close(vm))
497                        }
498                        other => other
499                            .and_then(|o| o.into_async_pyresult(vm))
500                            .map_err(|e| self.check_error(e, vm)),
501                    }
502                } else {
503                    PyAsyncGenWrappedValue::unbox(&self.ag, ret, vm)
504                }
505            }
506        }
507    }
508
509    #[pymethod]
510    fn throw(
511        &self,
512        exc_type: PyObjectRef,
513        exc_val: OptionalArg,
514        exc_tb: OptionalArg,
515        vm: &VirtualMachine,
516    ) -> PyResult {
517        match self.state.load() {
518            AwaitableState::Closed => {
519                return Err(vm.new_runtime_error("cannot reuse already awaited aclose()/athrow()"));
520            }
521            AwaitableState::Init => {
522                if self.ag.running_async.load() {
523                    self.state.store(AwaitableState::Closed);
524                    let msg = if self.aclose {
525                        "aclose(): asynchronous generator is already running"
526                    } else {
527                        "athrow(): asynchronous generator is already running"
528                    };
529                    return Err(vm.new_runtime_error(msg.to_owned()));
530                }
531                if self.ag.inner.closed() {
532                    self.state.store(AwaitableState::Closed);
533                    return Err(vm.new_stop_iteration(None));
534                }
535                self.ag.running_async.store(true);
536                self.state.store(AwaitableState::Iter);
537            }
538            AwaitableState::Iter => {}
539        }
540
541        warn_deprecated_throw_signature(&exc_val, &exc_tb, vm)?;
542        let ret = self.ag.inner.throw(
543            self.ag.as_object(),
544            exc_type,
545            exc_val.unwrap_or_none(vm),
546            exc_tb.unwrap_or_none(vm),
547            vm,
548        );
549        let res = if self.aclose {
550            if self.ignored_close(&ret) {
551                Err(self.yield_close(vm))
552            } else {
553                ret.and_then(|o| o.into_async_pyresult(vm))
554            }
555        } else {
556            PyAsyncGenWrappedValue::unbox(&self.ag, ret, vm)
557        };
558        res.map_err(|e| self.check_error(e, vm))
559    }
560
561    #[pymethod]
562    fn close(&self, vm: &VirtualMachine) -> PyResult<()> {
563        if matches!(self.state.load(), AwaitableState::Closed) {
564            return Ok(());
565        }
566        let result = self.throw(
567            vm.ctx.exceptions.generator_exit.to_owned().into(),
568            OptionalArg::Missing,
569            OptionalArg::Missing,
570            vm,
571        );
572        match result {
573            Ok(_) => Err(vm.new_runtime_error("coroutine ignored GeneratorExit")),
574            Err(e)
575                if e.fast_isinstance(vm.ctx.exceptions.stop_iteration)
576                    || e.fast_isinstance(vm.ctx.exceptions.stop_async_iteration)
577                    || e.fast_isinstance(vm.ctx.exceptions.generator_exit) =>
578            {
579                Ok(())
580            }
581            Err(e) => Err(e),
582        }
583    }
584
585    fn ignored_close(&self, res: &PyResult<PyIterReturn>) -> bool {
586        res.as_ref().is_ok_and(|v| match v {
587            PyIterReturn::Return(obj) => obj.downcastable::<PyAsyncGenWrappedValue>(),
588            PyIterReturn::StopIteration(_) => false,
589        })
590    }
591    fn yield_close(&self, vm: &VirtualMachine) -> PyBaseExceptionRef {
592        self.ag.running_async.store(false);
593        self.ag.inner.closed.store(true);
594        self.state.store(AwaitableState::Closed);
595        vm.new_runtime_error("async generator ignored GeneratorExit")
596    }
597    fn check_error(&self, exc: PyBaseExceptionRef, vm: &VirtualMachine) -> PyBaseExceptionRef {
598        self.ag.running_async.store(false);
599        self.ag.inner.closed.store(true);
600        self.state.store(AwaitableState::Closed);
601        if self.aclose
602            && (exc.fast_isinstance(vm.ctx.exceptions.stop_async_iteration)
603                || exc.fast_isinstance(vm.ctx.exceptions.generator_exit))
604        {
605            vm.new_stop_iteration(None)
606        } else {
607            exc
608        }
609    }
610}
611
612impl SelfIter for PyAsyncGenAThrow {}
613impl IterNext for PyAsyncGenAThrow {
614    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
615        PyIterReturn::from_pyresult(zelf.send(vm.ctx.none(), vm), vm)
616    }
617}
618
619/// Awaitable wrapper for anext() builtin with default value.
620/// When StopAsyncIteration is raised, it converts it to StopIteration(default).
621#[pyclass(module = false, name = "anext_awaitable", traverse = "manual")]
622#[derive(Debug)]
623pub struct PyAnextAwaitable {
624    wrapped: PyObjectRef,
625    default_value: PyObjectRef,
626    state: AtomicCell<AwaitableState>,
627}
628
629unsafe impl Traverse for PyAnextAwaitable {
630    fn traverse(&self, tracer_fn: &mut TraverseFn<'_>) {
631        self.wrapped.traverse(tracer_fn);
632        self.default_value.traverse(tracer_fn);
633    }
634}
635
636impl PyPayload for PyAnextAwaitable {
637    #[inline]
638    fn class(ctx: &Context) -> &'static Py<PyType> {
639        ctx.types.anext_awaitable
640    }
641}
642
643#[pyclass(with(IterNext, Iterable))]
644impl PyAnextAwaitable {
645    pub fn new(wrapped: PyObjectRef, default_value: PyObjectRef) -> Self {
646        Self {
647            wrapped,
648            default_value,
649            state: AtomicCell::new(AwaitableState::Init),
650        }
651    }
652
653    #[pymethod(name = "__await__")]
654    fn r#await(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
655        zelf
656    }
657
658    fn check_closed(&self, vm: &VirtualMachine) -> PyResult<()> {
659        if let AwaitableState::Closed = self.state.load() {
660            return Err(vm.new_runtime_error("cannot reuse already awaited __anext__()/asend()"));
661        }
662        Ok(())
663    }
664
665    /// Get the awaitable iterator from wrapped object.
666    // = anextawaitable_getiter.
667    fn get_awaitable_iter(&self, vm: &VirtualMachine) -> PyResult {
668        use crate::builtins::PyCoroutine;
669        use crate::protocol::PyIter;
670
671        let wrapped = &self.wrapped;
672
673        // If wrapped is already an async_generator_asend, it's an iterator
674        if wrapped.class().is(vm.ctx.types.async_generator_asend)
675            || wrapped.class().is(vm.ctx.types.async_generator_athrow)
676        {
677            return Ok(wrapped.clone());
678        }
679
680        // _PyCoro_GetAwaitableIter equivalent
681        let awaitable = if wrapped.class().is(vm.ctx.types.coroutine_type) {
682            // Coroutine - get __await__ later
683            wrapped.clone()
684        } else {
685            // Check for generator with CO_ITERABLE_COROUTINE flag
686            if let Some(generator) = wrapped.downcast_ref::<PyGenerator>()
687                && generator
688                    .as_coro()
689                    .frame()
690                    .code
691                    .flags
692                    .contains(crate::bytecode::CodeFlags::ITERABLE_COROUTINE)
693            {
694                // Return the generator itself as the iterator
695                return Ok(wrapped.clone());
696            }
697            // Try to get __await__ method
698            if let Some(await_method) = vm.get_method(wrapped.clone(), identifier!(vm, __await__)) {
699                await_method?.call((), vm)?
700            } else {
701                return Err(vm.new_type_error(format!(
702                    "'{}' object can't be awaited",
703                    wrapped.class().name()
704                )));
705            }
706        };
707
708        // If awaitable is a coroutine, get its __await__
709        if awaitable.class().is(vm.ctx.types.coroutine_type) {
710            let coro_await = vm.call_method(&awaitable, "__await__", ())?;
711            // Check that __await__ returned an iterator
712            if !PyIter::check(&coro_await) {
713                return Err(vm.new_type_error("__await__ returned a non-iterable"));
714            }
715            return Ok(coro_await);
716        }
717
718        // Check the result is an iterator, not a coroutine
719        if awaitable.downcast_ref::<PyCoroutine>().is_some() {
720            return Err(vm.new_type_error("__await__() returned a coroutine"));
721        }
722
723        // Check that the result is an iterator
724        if !PyIter::check(&awaitable) {
725            return Err(vm.new_type_error(format!(
726                "__await__() returned non-iterator of type '{}'",
727                awaitable.class().name()
728            )));
729        }
730
731        Ok(awaitable)
732    }
733
734    #[pymethod]
735    fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
736        self.check_closed(vm)?;
737        self.state.store(AwaitableState::Iter);
738        let awaitable = self.get_awaitable_iter(vm)?;
739        let result = vm.call_method(&awaitable, "send", (val,));
740        self.handle_result(result, vm)
741    }
742
743    #[pymethod]
744    fn throw(
745        &self,
746        exc_type: PyObjectRef,
747        exc_val: OptionalArg,
748        exc_tb: OptionalArg,
749        vm: &VirtualMachine,
750    ) -> PyResult {
751        self.check_closed(vm)?;
752        warn_deprecated_throw_signature(&exc_val, &exc_tb, vm)?;
753        self.state.store(AwaitableState::Iter);
754        let awaitable = self.get_awaitable_iter(vm)?;
755        let result = vm.call_method(
756            &awaitable,
757            "throw",
758            (
759                exc_type,
760                exc_val.unwrap_or_none(vm),
761                exc_tb.unwrap_or_none(vm),
762            ),
763        );
764        self.handle_result(result, vm)
765    }
766
767    #[pymethod]
768    fn close(&self, vm: &VirtualMachine) -> PyResult<()> {
769        self.state.store(AwaitableState::Closed);
770        if let Ok(awaitable) = self.get_awaitable_iter(vm) {
771            let _ = vm.call_method(&awaitable, "close", ());
772        }
773        Ok(())
774    }
775
776    /// Convert StopAsyncIteration to StopIteration(default_value)
777    fn handle_result(&self, result: PyResult, vm: &VirtualMachine) -> PyResult {
778        match result {
779            Ok(value) => Ok(value),
780            Err(exc) if exc.fast_isinstance(vm.ctx.exceptions.stop_async_iteration) => {
781                Err(vm.new_stop_iteration(Some(self.default_value.clone())))
782            }
783            Err(exc) => Err(exc),
784        }
785    }
786}
787
788impl SelfIter for PyAnextAwaitable {}
789impl IterNext for PyAnextAwaitable {
790    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
791        PyIterReturn::from_pyresult(zelf.send(vm.ctx.none(), vm), vm)
792    }
793}
794
795/// _PyGen_Finalize for async generators
796impl Destructor for PyAsyncGen {
797    fn del(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<()> {
798        // Generator is already closed, nothing to do
799        if zelf.inner.closed.load() {
800            return Ok(());
801        }
802
803        // Call the async generator finalizer hook if set.
804        Self::call_finalizer(zelf, vm);
805
806        Ok(())
807    }
808}
809
810impl Drop for PyAsyncGen {
811    fn drop(&mut self) {
812        self.inner.frame().clear_generator();
813    }
814}
815
816pub fn init(ctx: &'static Context) {
817    PyAsyncGen::extend_class(ctx, ctx.types.async_generator);
818    PyAsyncGenASend::extend_class(ctx, ctx.types.async_generator_asend);
819    PyAsyncGenAThrow::extend_class(ctx, ctx.types.async_generator_athrow);
820    PyAnextAwaitable::extend_class(ctx, ctx.types.anext_awaitable);
821}