use super::{PyInt, PyTupleRef, PyType};
use crate::{
class::PyClassImpl,
function::ArgCallable,
object::{Traverse, TraverseFn},
protocol::{PyIterReturn, PySequence, PySequenceMethods},
types::{IterNext, Iterable, SelfIter},
Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine,
};
use rustpython_common::{
lock::{PyMutex, PyRwLock, PyRwLockUpgradableReadGuard},
static_cell,
};
/// Progress marker shared by the iterator types in this file.
///
/// `Active` carries the value of type `T` that iteration still draws from;
/// `Exhausted` carries nothing.
#[derive(Debug, Clone)]
pub enum IterStatus<T> {
/// Iteration is still in progress; holds the underlying value.
Active(T),
/// Iteration has finished; no underlying value is retained.
Exhausted,
}
// NOTE(review): `Traverse` is an unsafe trait whose contract is defined elsewhere
// in the crate; presumably every contained traceable value must be visited — confirm.
unsafe impl<T: Traverse> Traverse for IterStatus<T> {
    /// Forwards the tracer to the payload of `Active`.
    ///
    /// `Exhausted` holds no value, so there is nothing to visit in that case.
    fn traverse(&self, tracer_fn: &mut TraverseFn) {
        if let IterStatus::Active(inner) = self {
            inner.traverse(tracer_fn);
        }
    }
}
/// An iteration status paired with a numeric position.
///
/// The stepping helpers on this type (`next`, `rev_next`) pass `position` to a
/// caller-supplied closure and then move it forwards or backwards.
#[derive(Debug)]
pub struct PositionIterInternal<T> {
/// Whether iteration is still active and, if so, the value being iterated.
pub status: IterStatus<T>,
/// Position handed to the fetch closure on the next step.
pub position: usize,
}
// NOTE(review): `Traverse` is an unsafe trait whose contract is defined elsewhere
// in the crate — confirm that delegating to `status` alone satisfies it.
unsafe impl<T: Traverse> Traverse for PositionIterInternal<T> {
    /// Traverses `status` only; `position` is a plain `usize` and is not visited.
    fn traverse(&self, tracer_fn: &mut TraverseFn) {
        let Self {
            status,
            position: _,
        } = self;
        status.traverse(tracer_fn);
    }
}
impl<T> PositionIterInternal<T> {
    /// Creates an active iterator state over `obj`, starting at `position`.
    pub fn new(obj: T, position: usize) -> Self {
        let status = IterStatus::Active(obj);
        Self { status, position }
    }

    /// Restores the position from `state`.
    ///
    /// * When the iterator is already exhausted this is a no-op returning `Ok(())`.
    /// * When `state` does not carry a `PyInt` payload, a type error with the
    ///   message `"an integer is required."` is returned.
    /// * Otherwise the integer is converted with `try_to_primitive`; a failed
    ///   conversion falls back to `0`. The stored position becomes
    ///   `f(obj, requested)`, letting the caller adjust the requested value.
    pub fn set_state<F>(&mut self, state: PyObjectRef, f: F, vm: &VirtualMachine) -> PyResult<()>
    where
        F: FnOnce(&T, usize) -> usize,
    {
        let obj = match &self.status {
            IterStatus::Active(obj) => obj,
            IterStatus::Exhausted => return Ok(()),
        };
        let int = match state.payload::<PyInt>() {
            Some(int) => int,
            None => return Err(vm.new_type_error("an integer is required.".to_owned())),
        };
        let requested = int.try_to_primitive(vm).unwrap_or(0);
        self.position = f(obj, requested);
        Ok(())
    }

    /// Builds the tuple shared by the `*_reduce` helpers.
    ///
    /// Active: `(func, (f(obj),), position)`.
    /// Exhausted: `(func, ([],))`, i.e. `func` applied to a fresh empty list.
    fn _reduce<F>(&self, func: PyObjectRef, f: F, vm: &VirtualMachine) -> PyTupleRef
    where
        F: FnOnce(&T) -> PyObjectRef,
    {
        match &self.status {
            IterStatus::Active(obj) => {
                let args = (f(obj),);
                vm.new_tuple((func, args, self.position))
            }
            IterStatus::Exhausted => {
                let empty = vm.ctx.new_list(Vec::new());
                vm.new_tuple((func, (empty,)))
            }
        }
    }

    /// Reduce tuple whose first element is the builtin `iter` object.
    pub fn builtins_iter_reduce<F>(&self, f: F, vm: &VirtualMachine) -> PyTupleRef
    where
        F: FnOnce(&T) -> PyObjectRef,
    {
        self._reduce(builtins_iter(vm).to_owned(), f, vm)
    }

    /// Reduce tuple whose first element is the builtin `reversed` object.
    pub fn builtins_reversed_reduce<F>(&self, f: F, vm: &VirtualMachine) -> PyTupleRef
    where
        F: FnOnce(&T) -> PyObjectRef,
    {
        self._reduce(builtins_reversed(vm).to_owned(), f, vm)
    }

    /// Common stepping logic.
    ///
    /// Calls `f(obj, position)` while active. If it yields
    /// `Ok(PyIterReturn::Return(_))`, `op` is run to move the position; any other
    /// outcome (stop or error) marks the iterator exhausted. The result of `f`
    /// is returned unchanged. An already exhausted iterator returns
    /// `StopIteration(None)` without calling `f`.
    fn _next<F, OP>(&mut self, f: F, op: OP) -> PyResult<PyIterReturn>
    where
        F: FnOnce(&T, usize) -> PyResult<PyIterReturn>,
        OP: FnOnce(&mut Self),
    {
        let obj = match &self.status {
            IterStatus::Active(obj) => obj,
            IterStatus::Exhausted => return Ok(PyIterReturn::StopIteration(None)),
        };
        let outcome = f(obj, self.position);
        if matches!(&outcome, Ok(PyIterReturn::Return(_))) {
            op(self);
        } else {
            self.status = IterStatus::Exhausted;
        }
        outcome
    }

    /// Steps forward: on a successful fetch the position is incremented by one.
    pub fn next<F>(&mut self, f: F) -> PyResult<PyIterReturn>
    where
        F: FnOnce(&T, usize) -> PyResult<PyIterReturn>,
    {
        self._next(f, |zelf| zelf.position += 1)
    }

    /// Steps backward: on a successful fetch the position is decremented by one,
    /// or the iterator is marked exhausted when the position is already `0`.
    pub fn rev_next<F>(&mut self, f: F) -> PyResult<PyIterReturn>
    where
        F: FnOnce(&T, usize) -> PyResult<PyIterReturn>,
    {
        self._next(f, |zelf| match zelf.position.checked_sub(1) {
            Some(previous) => zelf.position = previous,
            None => zelf.status = IterStatus::Exhausted,
        })
    }

    /// Forward length hint: `f(obj) - position`, saturating at `0`; `0` once exhausted.
    pub fn length_hint<F>(&self, f: F) -> usize
    where
        F: FnOnce(&T) -> usize,
    {
        match &self.status {
            IterStatus::Active(obj) => f(obj).saturating_sub(self.position),
            IterStatus::Exhausted => 0,
        }
    }

    /// Reverse length hint: `position + 1` while active and `position <= f(obj)`,
    /// otherwise `0`.
    pub fn rev_length_hint<F>(&self, f: F) -> usize
    where
        F: FnOnce(&T) -> usize,
    {
        match &self.status {
            IterStatus::Active(obj) => {
                if self.position <= f(obj) {
                    self.position + 1
                } else {
                    0
                }
            }
            IterStatus::Exhausted => 0,
        }
    }
}
/// Returns the `iter` attribute of `vm.builtins`, looked up once and cached in a
/// `static_cell!` for later calls.
///
/// # Panics
///
/// Panics if the first lookup of `"iter"` on `vm.builtins` fails.
pub fn builtins_iter(vm: &VirtualMachine) -> &PyObject {
    static_cell! {
        static INSTANCE: PyObjectRef;
    }
    let lookup = || vm.builtins.get_attr("iter", vm).unwrap();
    INSTANCE.get_or_init(lookup)
}
/// Returns the `reversed` attribute of `vm.builtins`, looked up once and cached
/// in a `static_cell!` for later calls.
///
/// # Panics
///
/// Panics if the first lookup of `"reversed"` on `vm.builtins` fails.
pub fn builtins_reversed(vm: &VirtualMachine) -> &PyObject {
    static_cell! {
        static INSTANCE: PyObjectRef;
    }
    let lookup = || vm.builtins.get_attr("reversed", vm).unwrap();
    INSTANCE.get_or_init(lookup)
}
// Index-based iterator over an object accessed through `PySequence`.
// Registered under the class name `iterator` (see the `pyclass` attribute).
// Plain `//` comments are used on purpose: NOTE(review) `#[pyclass]` may treat
// `///` doc comments as the Python-level docstring — confirm before converting.
#[pyclass(module = false, name = "iterator", traverse)]
#[derive(Debug)]
pub struct PySequenceIterator {
// Sequence method table captured at construction; excluded from traversal
// by `pytraverse(skip)`.
#[pytraverse(skip)]
seq_methods: &'static PySequenceMethods,
// The iterated object together with the next index, guarded by a mutex.
internal: PyMutex<PositionIterInternal<PyObjectRef>>,
}
impl PyPayload for PySequenceIterator {
    /// The type object for this payload is `iter_type` from the context's type table.
    fn class(ctx: &Context) -> &'static Py<PyType> {
        let types = &ctx.types;
        types.iter_type
    }
}
#[pyclass(with(IterNext, Iterable))]
impl PySequenceIterator {
    // Builds an iterator over `obj` starting at index 0.
    //
    // Returns the error produced by `PySequence::try_protocol` when `obj`
    // cannot be used through the sequence protocol. The method table found
    // there is cached so later steps can rebuild a `PySequence` view cheaply.
    pub fn new(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<Self> {
        let seq = PySequence::try_protocol(&obj, vm)?;
        Ok(Self {
            seq_methods: seq.methods,
            internal: PyMutex::new(PositionIterInternal::new(obj, 0)),
        })
    }

    // `__length_hint__`: estimated number of items still to be produced.
    //
    // * Active: the sequence length minus the number of items already
    //   consumed, clamped at 0 (the sequence may have shrunk since iteration
    //   began). The original code returned the full length here, which
    //   over-reports once iteration has started.
    // * Length query failed: `NotImplemented` (the error is discarded).
    // * Exhausted: 0.
    #[pymethod(magic)]
    fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef {
        let internal = self.internal.lock();
        if let IterStatus::Active(obj) = &internal.status {
            let seq = PySequence {
                obj,
                methods: self.seq_methods,
            };
            seq.length(vm)
                .map(|len| {
                    let remaining = len.saturating_sub(internal.position);
                    PyInt::from(remaining).into_pyobject(vm)
                })
                .unwrap_or_else(|_| vm.ctx.not_implemented())
        } else {
            PyInt::from(0).into_pyobject(vm)
        }
    }

    // `__reduce__`: `(iter, (obj,), position)` while active, or
    // `(iter, ([],))` once exhausted (see `PositionIterInternal::_reduce`).
    #[pymethod(magic)]
    fn reduce(&self, vm: &VirtualMachine) -> PyTupleRef {
        self.internal.lock().builtins_iter_reduce(|x| x.clone(), vm)
    }

    // `__setstate__`: stores the integer `state` as the new position without
    // adjusting it; non-integer input yields a type error.
    #[pymethod(magic)]
    fn setstate(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
        self.internal.lock().set_state(state, |_, pos| pos, vm)
    }
}
// Empty impl: relies entirely on the items `SelfIter` provides.
// NOTE(review): presumably this makes `__iter__` return the iterator itself — confirm in `crate::types`.
impl SelfIter for PySequenceIterator {}
impl IterNext for PySequenceIterator {
    /// Fetches the item at the current index through the cached sequence
    /// methods and converts the lookup result with
    /// `PyIterReturn::from_getitem_result`.
    ///
    /// The index only advances when an item is produced; any other outcome
    /// marks the iterator exhausted (see `PositionIterInternal::next`).
    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
        let methods = zelf.seq_methods;
        let mut internal = zelf.internal.lock();
        internal.next(|obj, pos| {
            let seq = PySequence { obj, methods };
            let item = seq.get_item(pos as isize, vm);
            PyIterReturn::from_getitem_result(item, vm)
        })
    }
}
// Iterator that repeatedly invokes a callable until a result compares equal
// to `sentinel`. Registered under the class name `callable_iterator`
// (see the `pyclass` attribute).
// Plain `//` comments are used on purpose: NOTE(review) `#[pyclass]` may treat
// `///` doc comments as the Python-level docstring — confirm before converting.
#[pyclass(module = false, name = "callable_iterator", traverse)]
#[derive(Debug)]
pub struct PyCallableIterator {
// Value that ends iteration when a call result compares equal to it.
sentinel: PyObjectRef,
// Holds the callable while active; set to `Exhausted` once the sentinel is seen.
status: PyRwLock<IterStatus<ArgCallable>>,
}
impl PyPayload for PyCallableIterator {
    /// The type object for this payload is `callable_iterator` from the
    /// context's type table.
    fn class(ctx: &Context) -> &'static Py<PyType> {
        let types = &ctx.types;
        types.callable_iterator
    }
}
#[pyclass(with(IterNext, Iterable))]
impl PyCallableIterator {
    // Creates an active iterator that will invoke `callable` on every step and
    // stop once a result compares equal to `sentinel`.
    pub fn new(callable: ArgCallable, sentinel: PyObjectRef) -> Self {
        let status = PyRwLock::new(IterStatus::Active(callable));
        Self { sentinel, status }
    }
}
// Empty impl: relies entirely on the items `SelfIter` provides.
// NOTE(review): presumably this makes `__iter__` return the iterator itself — confirm in `crate::types`.
impl SelfIter for PyCallableIterator {}
impl IterNext for PyCallableIterator {
    /// Invokes the callable once and yields its result, or stops when the
    /// result compares equal to the sentinel.
    ///
    /// No lock is held while user code runs (the callable itself and the
    /// equality comparison). Holding the upgradable read lock across those
    /// calls, as the previous version did, lets a re-entrant `next()` on the
    /// same iterator block on a lock the caller already owns.
    ///
    /// # Errors
    ///
    /// Propagates any error raised by the callable or by the comparison; the
    /// iterator state is left untouched in that case.
    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
        // Copy the callable out so the read guard is released before calling it.
        let callable = {
            let status = zelf.status.read();
            match &*status {
                IterStatus::Active(callable) => callable.clone(),
                IterStatus::Exhausted => return Ok(PyIterReturn::StopIteration(None)),
            }
        };

        let ret = callable.invoke((), vm)?;
        let hit_sentinel = vm.bool_eq(&ret, &zelf.sentinel)?;

        // Re-check under the lock: a nested or concurrent call may have
        // exhausted the iterator while the lock was released.
        let status = zelf.status.upgradable_read();
        if matches!(&*status, IterStatus::Exhausted) {
            return Ok(PyIterReturn::StopIteration(None));
        }
        if hit_sentinel {
            *PyRwLockUpgradableReadGuard::upgrade(status) = IterStatus::Exhausted;
            Ok(PyIterReturn::StopIteration(None))
        } else {
            Ok(PyIterReturn::Return(ret))
        }
    }
}
/// Calls `extend_class` for both iterator classes defined in this file, pairing
/// each with its type object from `context.types`.
pub fn init(context: &Context) {
    let types = &context.types;
    PySequenceIterator::extend_class(context, types.iter_type);
    PyCallableIterator::extend_class(context, types.callable_iterator);
}