Skip to main content

rustpython_vm/builtins/
memory.rs

1use super::{
2    PositionIterInternal, PyBytes, PyBytesRef, PyGenericAlias, PyInt, PyListRef, PySlice, PyStr,
3    PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, PyUtf8StrRef, iter::builtins_iter,
4};
5use crate::common::lock::LazyLock;
6use crate::{
7    AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult,
8    TryFromBorrowedObject, TryFromObject, VirtualMachine, atomic_func,
9    buffer::FormatSpec,
10    bytes_inner::bytes_to_hex,
11    class::PyClassImpl,
12    common::{
13        borrow::{BorrowedValue, BorrowedValueMut},
14        hash::PyHash,
15        lock::OnceCell,
16    },
17    convert::ToPyObject,
18    function::Either,
19    function::{FuncArgs, OptionalArg, PyComparisonValue},
20    protocol::{
21        BufferDescriptor, BufferMethods, PyBuffer, PyIterReturn, PyMappingMethods,
22        PySequenceMethods, VecBuffer,
23    },
24    sliceable::SequenceIndexOp,
25    types::{
26        AsBuffer, AsMapping, AsSequence, Comparable, Constructor, Hashable, IterNext, Iterable,
27        PyComparisonOp, Representable, SelfIter,
28    },
29};
30use core::{cmp::Ordering, fmt::Debug, mem::ManuallyDrop, ops::Range};
31use crossbeam_utils::atomic::AtomicCell;
32use itertools::Itertools;
33use rustpython_common::lock::PyMutex;
34
// Arguments accepted by the `memoryview(object)` constructor (consumed by
// `Constructor::py_new`).
#[derive(FromArgs)]
pub struct PyMemoryViewNewArgs {
    // The object whose buffer is exposed; if it is already a memoryview, the
    // new view shares its underlying buffer (see `PyMemoryView::from_object`).
    object: PyObjectRef,
}
39
// Python's `memoryview` object: a view over the buffer exported by another
// object, with its own (possibly sliced or cast) layout stored in `desc`.
#[pyclass(module = false, name = "memoryview")]
#[derive(Debug)]
pub struct PyMemoryView {
    // The exporter's buffer. Wrapped in `ManuallyDrop` to avoid a double
    // release when this memoryview already released the buffer before drop.
    buffer: ManuallyDrop<PyBuffer>,
    // Set once `release()` has run. A released memoryview does not mean the
    // buffer is destroyed, because another memoryview may still be viewing it.
    released: AtomicCell<bool>,
    // Base byte offset into the exporter's bytes; relative positions computed
    // from `desc` are added to it to get absolute positions. It does NOT mean
    // the bytes before `start` are never visited (a negative stride walks
    // backwards from here).
    start: usize,
    // Parsed form of `desc.format`, used to pack/unpack items.
    format_spec: FormatSpec,
    // This memoryview's own layout; it can differ from the buffer's descriptor
    // (e.g. after slicing or `cast`).
    desc: BufferDescriptor,
    // Lazily-initialised hash cache (presumably filled by the `Hashable` impl,
    // which is outside this excerpt).
    hash: OnceCell<PyHash>,
    // exports:
    // memoryview keeps no export count of its own; instead it relies on the
    // buffer it is viewing to maintain the count (via `retain`/`release`).
}
60
61impl Constructor for PyMemoryView {
62    type Args = PyMemoryViewNewArgs;
63
64    fn py_new(_cls: &Py<PyType>, args: Self::Args, vm: &VirtualMachine) -> PyResult<Self> {
65        Self::from_object(&args.object, vm)
66    }
67}
68
69impl PyMemoryView {
70    fn parse_format(format: &str, vm: &VirtualMachine) -> PyResult<FormatSpec> {
71        FormatSpec::parse(format.as_bytes(), vm)
72    }
73
74    /// this should be the main entrance to create the memoryview
75    /// to avoid the chained memoryview
76    pub fn from_object(obj: &PyObject, vm: &VirtualMachine) -> PyResult<Self> {
77        if let Some(other) = obj.downcast_ref::<Self>() {
78            Ok(other.new_view())
79        } else {
80            let buffer = PyBuffer::try_from_borrowed_object(vm, obj)?;
81            Self::from_buffer(buffer, vm)
82        }
83    }
84
85    /// don't use this function to create the memoryview if the buffer is exporting
86    /// via another memoryview, use PyMemoryView::new_view() or PyMemoryView::from_object
87    /// to reduce the chain
88    pub fn from_buffer(buffer: PyBuffer, vm: &VirtualMachine) -> PyResult<Self> {
89        // when we get a buffer means the buffered object is size locked
90        // so we can assume the buffer's options will never change as long
91        // as memoryview is still alive
92        let format_spec = Self::parse_format(&buffer.desc.format, vm)?;
93        let desc = buffer.desc.clone();
94
95        Ok(Self {
96            buffer: ManuallyDrop::new(buffer),
97            released: AtomicCell::new(false),
98            start: 0,
99            format_spec,
100            desc,
101            hash: OnceCell::new(),
102        })
103    }
104
105    /// don't use this function to create the memoryview if the buffer is exporting
106    /// via another memoryview, use PyMemoryView::new_view() or PyMemoryView::from_object
107    /// to reduce the chain
108    pub fn from_buffer_range(
109        buffer: PyBuffer,
110        range: Range<usize>,
111        vm: &VirtualMachine,
112    ) -> PyResult<Self> {
113        let mut zelf = Self::from_buffer(buffer, vm)?;
114
115        zelf.init_range(range, 0);
116        zelf.init_len();
117        Ok(zelf)
118    }
119
120    /// this should be the only way to create a memoryview from another memoryview
121    pub fn new_view(&self) -> Self {
122        let zelf = Self {
123            buffer: self.buffer.clone(),
124            released: AtomicCell::new(false),
125            start: self.start,
126            format_spec: self.format_spec.clone(),
127            desc: self.desc.clone(),
128            hash: OnceCell::new(),
129        };
130        zelf.buffer.retain();
131        zelf
132    }
133
134    fn try_not_released(&self, vm: &VirtualMachine) -> PyResult<()> {
135        if self.released.load() {
136            Err(vm.new_value_error("operation forbidden on released memoryview object"))
137        } else {
138            Ok(())
139        }
140    }
141
142    fn getitem_by_idx(&self, i: isize, vm: &VirtualMachine) -> PyResult {
143        if self.desc.ndim() != 1 {
144            return Err(
145                vm.new_not_implemented_error("multi-dimensional sub-views are not implemented")
146            );
147        }
148        let (shape, stride, suboffset) = self.desc.dim_desc[0];
149        let index = i
150            .wrapped_at(shape)
151            .ok_or_else(|| vm.new_index_error("index out of range"))?;
152        let index = index as isize * stride + suboffset;
153        let pos = (index + self.start as isize) as usize;
154        self.unpack_single(pos, vm)
155    }
156
157    fn getitem_by_slice(&self, slice: &PySlice, vm: &VirtualMachine) -> PyResult {
158        let mut other = self.new_view();
159        other.init_slice(slice, 0, vm)?;
160        other.init_len();
161
162        Ok(other.into_ref(&vm.ctx).into())
163    }
164
165    fn getitem_by_multi_idx(&self, indexes: &[isize], vm: &VirtualMachine) -> PyResult {
166        let pos = self.pos_from_multi_index(indexes, vm)?;
167        let bytes = self.buffer.obj_bytes();
168        format_unpack(&self.format_spec, &bytes[pos..pos + self.desc.itemsize], vm)
169    }
170
171    fn setitem_by_idx(&self, i: isize, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
172        if self.desc.ndim() != 1 {
173            return Err(vm.new_not_implemented_error("sub-views are not implemented"));
174        }
175        let (shape, stride, suboffset) = self.desc.dim_desc[0];
176        let index = i
177            .wrapped_at(shape)
178            .ok_or_else(|| vm.new_index_error("index out of range"))?;
179        let index = index as isize * stride + suboffset;
180        let pos = (index + self.start as isize) as usize;
181        self.pack_single(pos, value, vm)
182    }
183
184    fn setitem_by_multi_idx(
185        &self,
186        indexes: &[isize],
187        value: PyObjectRef,
188        vm: &VirtualMachine,
189    ) -> PyResult<()> {
190        let pos = self.pos_from_multi_index(indexes, vm)?;
191        self.pack_single(pos, value, vm)
192    }
193
194    fn pack_single(&self, pos: usize, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
195        let mut bytes = self.buffer.obj_bytes_mut();
196        // TODO: Optimize
197        let data = self.format_spec.pack(vec![value], vm).map_err(|_| {
198            vm.new_type_error(format!(
199                "memoryview: invalid type for format '{}'",
200                &self.desc.format
201            ))
202        })?;
203        bytes[pos..pos + self.desc.itemsize].copy_from_slice(&data);
204        Ok(())
205    }
206
207    fn unpack_single(&self, pos: usize, vm: &VirtualMachine) -> PyResult {
208        let bytes = self.buffer.obj_bytes();
209        // TODO: Optimize
210        self.format_spec
211            .unpack(&bytes[pos..pos + self.desc.itemsize], vm)
212            .map(|x| {
213                if x.len() == 1 {
214                    x[0].to_owned()
215                } else {
216                    x.into()
217                }
218            })
219    }
220
221    fn pos_from_multi_index(&self, indexes: &[isize], vm: &VirtualMachine) -> PyResult<usize> {
222        match indexes.len().cmp(&self.desc.ndim()) {
223            Ordering::Less => {
224                return Err(vm.new_not_implemented_error("sub-views are not implemented"));
225            }
226            Ordering::Greater => {
227                return Err(vm.new_type_error(format!(
228                    "cannot index {}-dimension view with {}-element tuple",
229                    self.desc.ndim(),
230                    indexes.len()
231                )));
232            }
233            Ordering::Equal => (),
234        }
235
236        let pos = self.desc.position(indexes, vm)?;
237        let pos = (pos + self.start as isize) as usize;
238        Ok(pos)
239    }
240
241    fn init_len(&mut self) {
242        let product: usize = self.desc.dim_desc.iter().map(|x| x.0).product();
243        self.desc.len = product * self.desc.itemsize;
244    }
245
246    fn init_range(&mut self, range: Range<usize>, dim: usize) {
247        let (shape, stride, _) = self.desc.dim_desc[dim];
248        debug_assert!(shape >= range.len());
249
250        let mut is_adjusted = false;
251        for (_, _, suboffset) in self.desc.dim_desc.iter_mut().rev() {
252            if *suboffset != 0 {
253                *suboffset += stride * range.start as isize;
254                is_adjusted = true;
255                break;
256            }
257        }
258        if !is_adjusted {
259            // no suboffset set, stride must be positive
260            self.start += stride as usize * range.start;
261        }
262        let new_len = range.len();
263        self.desc.dim_desc[dim].0 = new_len;
264    }
265
266    fn init_slice(&mut self, slice: &PySlice, dim: usize, vm: &VirtualMachine) -> PyResult<()> {
267        let (shape, stride, _) = self.desc.dim_desc[dim];
268        let slice = slice.to_saturated(vm)?;
269        let (range, step, slice_len) = slice.adjust_indices(shape);
270
271        let mut is_adjusted_suboffset = false;
272        for (_, _, suboffset) in self.desc.dim_desc.iter_mut().rev() {
273            if *suboffset != 0 {
274                *suboffset += stride * range.start as isize;
275                is_adjusted_suboffset = true;
276                break;
277            }
278        }
279        if !is_adjusted_suboffset {
280            // no suboffset set, stride must be positive
281            self.start += stride as usize
282                * if step.is_negative() {
283                    range.end - 1
284                } else {
285                    range.start
286                };
287        }
288        self.desc.dim_desc[dim].0 = slice_len;
289        self.desc.dim_desc[dim].1 *= step;
290
291        Ok(())
292    }
293
294    fn _to_list(
295        &self,
296        bytes: &[u8],
297        mut index: isize,
298        dim: usize,
299        vm: &VirtualMachine,
300    ) -> PyResult<PyListRef> {
301        let (shape, stride, suboffset) = self.desc.dim_desc[dim];
302        if dim + 1 == self.desc.ndim() {
303            let mut v = Vec::with_capacity(shape);
304            for _ in 0..shape {
305                let pos = index + suboffset;
306                let pos = (pos + self.start as isize) as usize;
307                let obj =
308                    format_unpack(&self.format_spec, &bytes[pos..pos + self.desc.itemsize], vm)?;
309                v.push(obj);
310                index += stride;
311            }
312            return Ok(vm.ctx.new_list(v));
313        }
314
315        let mut v = Vec::with_capacity(shape);
316        for _ in 0..shape {
317            let obj = self._to_list(bytes, index + suboffset, dim + 1, vm)?.into();
318            v.push(obj);
319            index += stride;
320        }
321        Ok(vm.ctx.new_list(v))
322    }
323
324    fn eq(zelf: &Py<Self>, other: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
325        if zelf.is(other) {
326            return Ok(true);
327        }
328        if zelf.released.load() {
329            return Ok(false);
330        }
331
332        if let Some(other) = other.downcast_ref::<Self>()
333            && other.released.load()
334        {
335            return Ok(false);
336        }
337
338        let other = match PyBuffer::try_from_borrowed_object(vm, other) {
339            Ok(buf) => buf,
340            Err(_) => return Ok(false),
341        };
342
343        if !is_equiv_shape(&zelf.desc, &other.desc) {
344            return Ok(false);
345        }
346
347        let a_itemsize = zelf.desc.itemsize;
348        let b_itemsize = other.desc.itemsize;
349        let a_format_spec = &zelf.format_spec;
350        let b_format_spec = &Self::parse_format(&other.desc.format, vm)?;
351
352        if zelf.desc.ndim() == 0 {
353            let a_val = format_unpack(a_format_spec, &zelf.buffer.obj_bytes()[..a_itemsize], vm)?;
354            let b_val = format_unpack(b_format_spec, &other.obj_bytes()[..b_itemsize], vm)?;
355            return vm.bool_eq(&a_val, &b_val);
356        }
357
358        // TODO: optimize cmp by format
359        let mut ret = Ok(true);
360        let a_bytes = zelf.buffer.obj_bytes();
361        let b_bytes = other.obj_bytes();
362        zelf.desc.zip_eq(&other.desc, false, |a_range, b_range| {
363            let a_range = (a_range.start + zelf.start as isize) as usize
364                ..(a_range.end + zelf.start as isize) as usize;
365            let b_range = b_range.start as usize..b_range.end as usize;
366            let a_val = match format_unpack(a_format_spec, &a_bytes[a_range], vm) {
367                Ok(val) => val,
368                Err(e) => {
369                    ret = Err(e);
370                    return true;
371                }
372            };
373            let b_val = match format_unpack(b_format_spec, &b_bytes[b_range], vm) {
374                Ok(val) => val,
375                Err(e) => {
376                    ret = Err(e);
377                    return true;
378                }
379            };
380            ret = vm.bool_eq(&a_val, &b_val);
381            if let Ok(b) = ret { !b } else { true }
382        });
383        ret
384    }
385
386    fn obj_bytes(&self) -> BorrowedValue<'_, [u8]> {
387        if self.desc.is_contiguous() {
388            BorrowedValue::map(self.buffer.obj_bytes(), |x| {
389                &x[self.start..self.start + self.desc.len]
390            })
391        } else {
392            BorrowedValue::map(self.buffer.obj_bytes(), |x| &x[self.start..])
393        }
394    }
395
396    fn obj_bytes_mut(&self) -> BorrowedValueMut<'_, [u8]> {
397        if self.desc.is_contiguous() {
398            BorrowedValueMut::map(self.buffer.obj_bytes_mut(), |x| {
399                &mut x[self.start..self.start + self.desc.len]
400            })
401        } else {
402            BorrowedValueMut::map(self.buffer.obj_bytes_mut(), |x| &mut x[self.start..])
403        }
404    }
405
406    fn as_contiguous(&self) -> Option<BorrowedValue<'_, [u8]>> {
407        self.desc.is_contiguous().then(|| {
408            BorrowedValue::map(self.buffer.obj_bytes(), |x| {
409                &x[self.start..self.start + self.desc.len]
410            })
411        })
412    }
413
414    fn _as_contiguous_mut(&self) -> Option<BorrowedValueMut<'_, [u8]>> {
415        self.desc.is_contiguous().then(|| {
416            BorrowedValueMut::map(self.buffer.obj_bytes_mut(), |x| {
417                &mut x[self.start..self.start + self.desc.len]
418            })
419        })
420    }
421
422    fn append_to(&self, buf: &mut Vec<u8>) {
423        if let Some(bytes) = self.as_contiguous() {
424            buf.extend_from_slice(&bytes);
425        } else {
426            buf.reserve(self.desc.len);
427            let bytes = &*self.buffer.obj_bytes();
428            self.desc.for_each_segment(true, |range| {
429                let start = (range.start + self.start as isize) as usize;
430                let end = (range.end + self.start as isize) as usize;
431                buf.extend_from_slice(&bytes[start..end]);
432            })
433        }
434    }
435
436    fn contiguous_or_collect<R, F: FnOnce(&[u8]) -> R>(&self, f: F) -> R {
437        let borrowed;
438        let mut collected;
439        let v = if let Some(bytes) = self.as_contiguous() {
440            borrowed = bytes;
441            &*borrowed
442        } else {
443            collected = vec![];
444            self.append_to(&mut collected);
445            &collected
446        };
447        f(v)
448    }
449
450    /// clone data from memoryview
451    /// keep the shape, convert to contiguous
452    pub fn to_contiguous(&self, vm: &VirtualMachine) -> PyBuffer {
453        let mut data = vec![];
454        self.append_to(&mut data);
455
456        if self.desc.ndim() == 0 {
457            return VecBuffer::from(data)
458                .into_ref(&vm.ctx)
459                .into_pybuffer_with_descriptor(self.desc.clone());
460        }
461
462        let mut dim_desc = self.desc.dim_desc.clone();
463        dim_desc.last_mut().unwrap().1 = self.desc.itemsize as isize;
464        dim_desc.last_mut().unwrap().2 = 0;
465        for i in (0..dim_desc.len() - 1).rev() {
466            dim_desc[i].1 = dim_desc[i + 1].1 * dim_desc[i + 1].0 as isize;
467            dim_desc[i].2 = 0;
468        }
469
470        let desc = BufferDescriptor {
471            len: self.desc.len,
472            readonly: self.desc.readonly,
473            itemsize: self.desc.itemsize,
474            format: self.desc.format.clone(),
475            dim_desc,
476        };
477
478        VecBuffer::from(data)
479            .into_ref(&vm.ctx)
480            .into_pybuffer_with_descriptor(desc)
481    }
482}
483
484impl Py<PyMemoryView> {
485    fn setitem_by_slice(
486        &self,
487        slice: &PySlice,
488        src: PyObjectRef,
489        vm: &VirtualMachine,
490    ) -> PyResult<()> {
491        if self.desc.ndim() != 1 {
492            return Err(vm.new_not_implemented_error("sub-view are not implemented"));
493        }
494
495        let mut dest = self.new_view();
496        dest.init_slice(slice, 0, vm)?;
497        dest.init_len();
498
499        if self.is(&src) {
500            return if !is_equiv_structure(&self.desc, &dest.desc) {
501                Err(vm.new_value_error(
502                    "memoryview assignment: lvalue and rvalue have different structures",
503                ))
504            } else {
505                // assign self[:] to self
506                Ok(())
507            };
508        };
509
510        let src = if let Some(src) = src.downcast_ref::<PyMemoryView>() {
511            if self.buffer.obj.is(&src.buffer.obj) {
512                src.to_contiguous(vm)
513            } else {
514                AsBuffer::as_buffer(src, vm)?
515            }
516        } else {
517            PyBuffer::try_from_object(vm, src)?
518        };
519
520        if !is_equiv_structure(&src.desc, &dest.desc) {
521            return Err(vm.new_value_error(
522                "memoryview assignment: lvalue and rvalue have different structures",
523            ));
524        }
525
526        let mut bytes_mut = dest.buffer.obj_bytes_mut();
527        let src_bytes = src.obj_bytes();
528        dest.desc.zip_eq(&src.desc, true, |a_range, b_range| {
529            let a_range = (a_range.start + dest.start as isize) as usize
530                ..(a_range.end + dest.start as isize) as usize;
531            let b_range = b_range.start as usize..b_range.end as usize;
532            bytes_mut[a_range].copy_from_slice(&src_bytes[b_range]);
533            false
534        });
535
536        Ok(())
537    }
538}
539
540#[pyclass(
541    with(
542        Py,
543        Hashable,
544        Comparable,
545        AsBuffer,
546        AsMapping,
547        AsSequence,
548        Constructor,
549        Iterable,
550        Representable
551    ),
552    flags(SEQUENCE, HAS_WEAKREF)
553)]
554impl PyMemoryView {
555    #[pyclassmethod]
556    fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
557        PyGenericAlias::from_args(cls, args, vm)
558    }
559
560    #[pymethod]
561    pub fn release(&self) {
562        if self.released.compare_exchange(false, true).is_ok() {
563            self.buffer.release();
564        }
565    }
566
567    #[pygetset]
568    fn obj(&self, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
569        self.try_not_released(vm).map(|_| self.buffer.obj.clone())
570    }
571
572    #[pygetset]
573    fn nbytes(&self, vm: &VirtualMachine) -> PyResult<usize> {
574        self.try_not_released(vm).map(|_| self.desc.len)
575    }
576
577    #[pygetset]
578    fn readonly(&self, vm: &VirtualMachine) -> PyResult<bool> {
579        self.try_not_released(vm).map(|_| self.desc.readonly)
580    }
581
582    #[pygetset]
583    fn itemsize(&self, vm: &VirtualMachine) -> PyResult<usize> {
584        self.try_not_released(vm).map(|_| self.desc.itemsize)
585    }
586
587    #[pygetset]
588    fn ndim(&self, vm: &VirtualMachine) -> PyResult<usize> {
589        self.try_not_released(vm).map(|_| self.desc.ndim())
590    }
591
592    #[pygetset]
593    fn shape(&self, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
594        self.try_not_released(vm)?;
595        Ok(vm.ctx.new_tuple(
596            self.desc
597                .dim_desc
598                .iter()
599                .map(|(shape, _, _)| shape.to_pyobject(vm))
600                .collect(),
601        ))
602    }
603
604    #[pygetset]
605    fn strides(&self, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
606        self.try_not_released(vm)?;
607        Ok(vm.ctx.new_tuple(
608            self.desc
609                .dim_desc
610                .iter()
611                .map(|(_, stride, _)| stride.to_pyobject(vm))
612                .collect(),
613        ))
614    }
615
616    #[pygetset]
617    fn suboffsets(&self, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
618        self.try_not_released(vm)?;
619        let has_suboffsets = self
620            .desc
621            .dim_desc
622            .iter()
623            .any(|(_, _, suboffset)| *suboffset != 0);
624        if has_suboffsets {
625            Ok(vm.ctx.new_tuple(
626                self.desc
627                    .dim_desc
628                    .iter()
629                    .map(|(_, _, suboffset)| suboffset.to_pyobject(vm))
630                    .collect(),
631            ))
632        } else {
633            Ok(vm.ctx.empty_tuple.clone())
634        }
635    }
636
637    #[pygetset]
638    fn format(&self, vm: &VirtualMachine) -> PyResult<PyStr> {
639        self.try_not_released(vm)
640            .map(|_| PyStr::from(self.desc.format.clone()))
641    }
642
643    #[pygetset]
644    fn contiguous(&self, vm: &VirtualMachine) -> PyResult<bool> {
645        self.try_not_released(vm).map(|_| self.desc.is_contiguous())
646    }
647
648    #[pygetset]
649    fn c_contiguous(&self, vm: &VirtualMachine) -> PyResult<bool> {
650        self.try_not_released(vm).map(|_| self.desc.is_contiguous())
651    }
652
653    #[pygetset]
654    fn f_contiguous(&self, vm: &VirtualMachine) -> PyResult<bool> {
655        // TODO: column-major order
656        self.try_not_released(vm)
657            .map(|_| self.desc.ndim() <= 1 && self.desc.is_contiguous())
658    }
659
660    #[pymethod]
661    fn __enter__(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
662        zelf.try_not_released(vm).map(|_| zelf)
663    }
664
665    #[pymethod]
666    fn __exit__(&self, _args: FuncArgs) {
667        self.release();
668    }
669
670    fn __getitem__(zelf: PyRef<Self>, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult {
671        zelf.try_not_released(vm)?;
672        if zelf.desc.ndim() == 0 {
673            // 0-d memoryview can be referenced using mv[...] or mv[()] only
674            if needle.is(&vm.ctx.ellipsis) {
675                return Ok(zelf.into());
676            }
677            if let Some(tuple) = needle.downcast_ref::<PyTuple>()
678                && tuple.is_empty()
679            {
680                return zelf.unpack_single(0, vm);
681            }
682            return Err(vm.new_type_error("invalid indexing of 0-dim memory"));
683        }
684
685        match SubscriptNeedle::try_from_object(vm, needle)? {
686            SubscriptNeedle::Index(i) => zelf.getitem_by_idx(i, vm),
687            SubscriptNeedle::Slice(slice) => zelf.getitem_by_slice(&slice, vm),
688            SubscriptNeedle::MultiIndex(indices) => zelf.getitem_by_multi_idx(&indices, vm),
689        }
690    }
691
692    fn __delitem__(&self, _needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
693        if self.desc.readonly {
694            return Err(vm.new_type_error("cannot modify read-only memory"));
695        }
696        Err(vm.new_type_error("cannot delete memory"))
697    }
698
699    fn __len__(&self, vm: &VirtualMachine) -> PyResult<usize> {
700        self.try_not_released(vm)?;
701        if self.desc.ndim() == 0 {
702            // 0-dimensional memoryview has no length
703            Err(vm.new_type_error("0-dim memory has no length"))
704        } else {
705            // shape for dim[0]
706            Ok(self.desc.dim_desc[0].0)
707        }
708    }
709
710    #[pymethod]
711    fn tobytes(&self, vm: &VirtualMachine) -> PyResult<PyBytesRef> {
712        self.try_not_released(vm)?;
713        let mut v = vec![];
714        self.append_to(&mut v);
715        Ok(PyBytes::from(v).into_ref(&vm.ctx))
716    }
717
718    #[pymethod]
719    fn tolist(&self, vm: &VirtualMachine) -> PyResult<PyListRef> {
720        self.try_not_released(vm)?;
721        let bytes = self.buffer.obj_bytes();
722        if self.desc.ndim() == 0 {
723            return Ok(vm.ctx.new_list(vec![format_unpack(
724                &self.format_spec,
725                &bytes[..self.desc.itemsize],
726                vm,
727            )?]));
728        }
729        self._to_list(&bytes, 0, 0, vm)
730    }
731
732    #[pymethod]
733    fn toreadonly(&self, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
734        self.try_not_released(vm)?;
735        let mut other = self.new_view();
736        other.desc.readonly = true;
737        Ok(other.into_ref(&vm.ctx))
738    }
739
740    #[pymethod]
741    fn hex(
742        &self,
743        sep: OptionalArg<Either<PyStrRef, PyBytesRef>>,
744        bytes_per_sep: OptionalArg<isize>,
745        vm: &VirtualMachine,
746    ) -> PyResult<String> {
747        self.try_not_released(vm)?;
748        self.contiguous_or_collect(|x| bytes_to_hex(x, sep, bytes_per_sep, vm))
749    }
750
751    #[pymethod]
752    fn count(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
753        self.try_not_released(vm)?;
754        if self.desc.ndim() != 1 {
755            return Err(
756                vm.new_not_implemented_error("multi-dimensional sub-views are not implemented")
757            );
758        }
759        let len = self.desc.dim_desc[0].0;
760        let mut count = 0;
761        for i in 0..len {
762            let item = self.getitem_by_idx(i as isize, vm)?;
763            if vm.bool_eq(&item, &value)? {
764                count += 1;
765            }
766        }
767        Ok(count)
768    }
769
770    #[pymethod]
771    fn index(
772        &self,
773        value: PyObjectRef,
774        start: OptionalArg<isize>,
775        stop: OptionalArg<isize>,
776        vm: &VirtualMachine,
777    ) -> PyResult<usize> {
778        self.try_not_released(vm)?;
779        if self.desc.ndim() != 1 {
780            return Err(
781                vm.new_not_implemented_error("multi-dimensional sub-views are not implemented")
782            );
783        }
784        let len = self.desc.dim_desc[0].0;
785        let start = start.unwrap_or(0);
786        let stop = stop.unwrap_or(len as isize);
787
788        let start = if start < 0 {
789            (start + len as isize).max(0) as usize
790        } else {
791            (start as usize).min(len)
792        };
793        let stop = if stop < 0 {
794            (stop + len as isize).max(0) as usize
795        } else {
796            (stop as usize).min(len)
797        };
798
799        for i in start..stop {
800            let item = self.getitem_by_idx(i as isize, vm)?;
801            if vm.bool_eq(&item, &value)? {
802                return Ok(i);
803            }
804        }
805        Err(vm.new_value_error("memoryview.index(x): x not in memoryview"))
806    }
807
808    fn cast_to_1d(&self, format: PyUtf8StrRef, vm: &VirtualMachine) -> PyResult<Self> {
809        let format_str = format.as_str();
810        let format_spec = Self::parse_format(format_str, vm)?;
811        let itemsize = format_spec.size();
812        if !self.desc.len.is_multiple_of(itemsize) {
813            return Err(vm.new_type_error("memoryview: length is not a multiple of itemsize"));
814        }
815
816        Ok(Self {
817            buffer: self.buffer.clone(),
818            released: AtomicCell::new(false),
819            start: self.start,
820            format_spec,
821            desc: BufferDescriptor {
822                len: self.desc.len,
823                readonly: self.desc.readonly,
824                itemsize,
825                format: format_str.to_owned().into(),
826                dim_desc: vec![(self.desc.len / itemsize, itemsize as isize, 0)],
827            },
828            hash: OnceCell::new(),
829        })
830    }
831
832    #[pymethod]
833    fn cast(&self, args: CastArgs, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
834        self.try_not_released(vm)?;
835        if !self.desc.is_contiguous() {
836            return Err(vm.new_type_error("memoryview: casts are restricted to C-contiguous views"));
837        }
838
839        let CastArgs { format, shape } = args;
840
841        if let OptionalArg::Present(shape) = shape {
842            if self.desc.is_zero_in_shape() {
843                return Err(vm.new_type_error(
844                    "memoryview: cannot cast view with zeros in shape or strides",
845                ));
846            }
847
848            let tup;
849            let list;
850            let list_borrow;
851            let shape = match shape {
852                Either::A(shape) => {
853                    tup = shape;
854                    tup.as_slice()
855                }
856                Either::B(shape) => {
857                    list = shape;
858                    list_borrow = list.borrow_vec();
859                    &list_borrow
860                }
861            };
862
863            let shape_ndim = shape.len();
864            // TODO: MAX_NDIM
865            if self.desc.ndim() != 1 && shape_ndim != 1 {
866                return Err(vm.new_type_error("memoryview: cast must be 1D -> ND or ND -> 1D"));
867            }
868
869            let mut other = self.cast_to_1d(format, vm)?;
870            let itemsize = other.desc.itemsize;
871
872            // 0 ndim is single item
873            if shape_ndim == 0 {
874                other.desc.dim_desc = vec![];
875                other.desc.len = itemsize;
876                return Ok(other.into_ref(&vm.ctx));
877            }
878
879            let mut product_shape = itemsize;
880            let mut dim_descriptor = Vec::with_capacity(shape_ndim);
881
882            for x in shape {
883                let x = usize::try_from_borrowed_object(vm, x)?;
884
885                if x > isize::MAX as usize / product_shape {
886                    return Err(vm.new_value_error("memoryview.cast(): product(shape) > SSIZE_MAX"));
887                }
888                product_shape *= x;
889                dim_descriptor.push((x, 0, 0));
890            }
891
892            dim_descriptor.last_mut().unwrap().1 = itemsize as isize;
893            for i in (0..dim_descriptor.len() - 1).rev() {
894                dim_descriptor[i].1 = dim_descriptor[i + 1].1 * dim_descriptor[i + 1].0 as isize;
895            }
896
897            if product_shape != other.desc.len {
898                return Err(
899                    vm.new_type_error("memoryview: product(shape) * itemsize != buffer size")
900                );
901            }
902
903            other.desc.dim_desc = dim_descriptor;
904
905            Ok(other.into_ref(&vm.ctx))
906        } else {
907            Ok(self.cast_to_1d(format, vm)?.into_ref(&vm.ctx))
908        }
909    }
910}
911
912#[pyclass]
913impl Py<PyMemoryView> {
914    fn __setitem__(
915        &self,
916        needle: PyObjectRef,
917        value: PyObjectRef,
918        vm: &VirtualMachine,
919    ) -> PyResult<()> {
920        self.try_not_released(vm)?;
921        if self.desc.readonly {
922            return Err(vm.new_type_error("cannot modify read-only memory"));
923        }
924        if value.is(&vm.ctx.none) {
925            return Err(vm.new_type_error("cannot delete memory"));
926        }
927
928        if self.desc.ndim() == 0 {
929            // TODO: merge branches when we got conditional if let
930            if needle.is(&vm.ctx.ellipsis) {
931                return self.pack_single(0, value, vm);
932            } else if let Some(tuple) = needle.downcast_ref::<PyTuple>()
933                && tuple.is_empty()
934            {
935                return self.pack_single(0, value, vm);
936            }
937            return Err(vm.new_type_error("invalid indexing of 0-dim memory"));
938        }
939        match SubscriptNeedle::try_from_object(vm, needle)? {
940            SubscriptNeedle::Index(i) => self.setitem_by_idx(i, value, vm),
941            SubscriptNeedle::Slice(slice) => self.setitem_by_slice(&slice, value, vm),
942            SubscriptNeedle::MultiIndex(indices) => self.setitem_by_multi_idx(&indices, value, vm),
943        }
944    }
945
946    #[pymethod]
947    fn __reduce_ex__(&self, _proto: usize, vm: &VirtualMachine) -> PyResult {
948        self.__reduce__(vm)
949    }
950
951    #[pymethod]
952    fn __reduce__(&self, vm: &VirtualMachine) -> PyResult {
953        Err(vm.new_type_error("cannot pickle 'memoryview' object"))
954    }
955}
956
// Arguments accepted by `memoryview.cast`.
#[derive(FromArgs)]
struct CastArgs {
    // Target element format string; `cast` hands it to `cast_to_1d`.
    #[pyarg(any)]
    format: PyUtf8StrRef,
    // Optional new shape, given as either a tuple or a list. `cast` converts
    // each element to an integer.
    #[pyarg(any, optional)]
    shape: OptionalArg<Either<PyTupleRef, PyListRef>>,
}
964
// A subscript key classified for memoryview item assignment; built by the
// `TryFromObject` impl for this type.
enum SubscriptNeedle {
    // A single integer index.
    Index(isize),
    // A slice object.
    Slice(PyRef<PySlice>),
    // A tuple whose elements are all integers.
    MultiIndex(Vec<isize>),
    // Tuples of slices are currently rejected with NotImplementedError during
    // conversion; supporting them would need a variant like:
    // MultiSlice(Vec<PySliceRef>),
}
971
972impl TryFromObject for SubscriptNeedle {
973    fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
974        // TODO: number protocol
975        if let Some(i) = obj.downcast_ref::<PyInt>() {
976            Ok(Self::Index(i.try_to_primitive(vm)?))
977        } else if obj.downcastable::<PySlice>() {
978            Ok(Self::Slice(unsafe { obj.downcast_unchecked::<PySlice>() }))
979        } else if let Ok(i) = obj.try_index(vm) {
980            Ok(Self::Index(i.try_to_primitive(vm)?))
981        } else {
982            if let Some(tuple) = obj.downcast_ref::<PyTuple>() {
983                if tuple.iter().all(|x| x.downcastable::<PyInt>()) {
984                    let v = tuple
985                        .iter()
986                        .map(|x| {
987                            unsafe { x.downcast_unchecked_ref::<PyInt>() }
988                                .try_to_primitive::<isize>(vm)
989                        })
990                        .try_collect()?;
991                    return Ok(Self::MultiIndex(v));
992                } else if tuple.iter().all(|x| x.downcastable::<PySlice>()) {
993                    return Err(vm.new_not_implemented_error(
994                        "multi-dimensional slicing is not implemented",
995                    ));
996                }
997            }
998            Err(vm.new_type_error("memoryview: invalid slice key"))
999        }
1000    }
1001}
1002
// Callbacks for buffers exported by a memoryview (installed by `as_buffer`).
// Byte access goes through the memoryview's own `obj_bytes`/`obj_bytes_mut`.
// `release`/`retain` are forwarded to the `PyBuffer` that the memoryview wraps.
static BUFFER_METHODS: BufferMethods = BufferMethods {
    obj_bytes: |buffer| buffer.obj_as::<PyMemoryView>().obj_bytes(),
    obj_bytes_mut: |buffer| buffer.obj_as::<PyMemoryView>().obj_bytes_mut(),
    release: |buffer| buffer.obj_as::<PyMemoryView>().buffer.release(),
    retain: |buffer| buffer.obj_as::<PyMemoryView>().buffer.retain(),
};
1009
1010impl AsBuffer for PyMemoryView {
1011    fn as_buffer(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyBuffer> {
1012        if zelf.released.load() {
1013            Err(vm.new_value_error("operation forbidden on released memoryview object"))
1014        } else {
1015            Ok(PyBuffer::new(
1016                zelf.to_owned().into(),
1017                zelf.desc.clone(),
1018                &BUFFER_METHODS,
1019            ))
1020        }
1021    }
1022}
1023
impl Drop for PyMemoryView {
    // `buffer` is wrapped in `ManuallyDrop` so that a view which was already
    // released explicitly does not release its buffer a second time here.
    fn drop(&mut self) {
        if self.released.load() {
            // SAFETY: `self.buffer` is not touched again after `drop`. The view was
            // already released, so the release step is skipped. NOTE(review): what
            // exactly `drop_without_release` frees is defined on `PyBuffer`.
            unsafe { self.buffer.drop_without_release() };
        } else {
            // SAFETY: `self.buffer` is dropped exactly once, here, and never
            // accessed afterwards.
            unsafe { ManuallyDrop::drop(&mut self.buffer) };
        }
    }
}
1033
1034impl AsMapping for PyMemoryView {
1035    fn as_mapping() -> &'static PyMappingMethods {
1036        static AS_MAPPING: PyMappingMethods = PyMappingMethods {
1037            length: atomic_func!(|mapping, vm| PyMemoryView::mapping_downcast(mapping).__len__(vm)),
1038            subscript: atomic_func!(|mapping, needle, vm| {
1039                let zelf = PyMemoryView::mapping_downcast(mapping);
1040                PyMemoryView::__getitem__(zelf.to_owned(), needle.to_owned(), vm)
1041            }),
1042            ass_subscript: atomic_func!(|mapping, needle, value, vm| {
1043                let zelf = PyMemoryView::mapping_downcast(mapping);
1044                if let Some(value) = value {
1045                    zelf.__setitem__(needle.to_owned(), value, vm)
1046                } else {
1047                    Err(vm.new_type_error("cannot delete memory".to_owned()))
1048                }
1049            }),
1050        };
1051        &AS_MAPPING
1052    }
1053}
1054
1055impl AsSequence for PyMemoryView {
1056    fn as_sequence() -> &'static PySequenceMethods {
1057        static AS_SEQUENCE: LazyLock<PySequenceMethods> = LazyLock::new(|| PySequenceMethods {
1058            length: atomic_func!(|seq, vm| {
1059                let zelf = PyMemoryView::sequence_downcast(seq);
1060                zelf.try_not_released(vm)?;
1061                zelf.__len__(vm)
1062            }),
1063            item: atomic_func!(|seq, i, vm| {
1064                let zelf = PyMemoryView::sequence_downcast(seq);
1065                zelf.try_not_released(vm)?;
1066                zelf.getitem_by_idx(i, vm)
1067            }),
1068            ..PySequenceMethods::NOT_IMPLEMENTED
1069        });
1070        &AS_SEQUENCE
1071    }
1072}
1073
1074impl Comparable for PyMemoryView {
1075    fn cmp(
1076        zelf: &Py<Self>,
1077        other: &PyObject,
1078        op: PyComparisonOp,
1079        vm: &VirtualMachine,
1080    ) -> PyResult<PyComparisonValue> {
1081        match op {
1082            PyComparisonOp::Ne => {
1083                Self::eq(zelf, other, vm).map(|x| PyComparisonValue::Implemented(!x))
1084            }
1085            PyComparisonOp::Eq => Self::eq(zelf, other, vm).map(PyComparisonValue::Implemented),
1086            _ => Err(vm.new_type_error(format!(
1087                "'{}' not supported between instances of '{}' and '{}'",
1088                op.operator_token(),
1089                zelf.class().name(),
1090                other.class().name()
1091            ))),
1092        }
1093    }
1094}
1095
1096impl Hashable for PyMemoryView {
1097    fn hash(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyHash> {
1098        if let Some(val) = zelf.hash.get() {
1099            return Ok(*val);
1100        }
1101        zelf.try_not_released(vm)?;
1102        if !zelf.desc.readonly {
1103            return Err(vm.new_value_error("cannot hash writable memoryview object"));
1104        }
1105        let val = zelf.contiguous_or_collect(|bytes| vm.state.hash_secret.hash_bytes(bytes));
1106        let _ = zelf.hash.set(val);
1107        Ok(*zelf.hash.get().unwrap())
1108    }
1109}
1110
// Associates the `PyMemoryView` payload with the interpreter's `memoryview`
// type object.
impl PyPayload for PyMemoryView {
    #[inline]
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.memoryview_type
    }
}
1117
1118impl Representable for PyMemoryView {
1119    #[inline]
1120    fn repr_str(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<String> {
1121        let repr = if zelf.released.load() {
1122            format!("<released memory at {:#x}>", zelf.get_id())
1123        } else {
1124            format!("<memory at {:#x}>", zelf.get_id())
1125        };
1126        Ok(repr)
1127    }
1128}
1129
/// Initializes the `memoryview` and `memory_iterator` type objects in `ctx` by
/// calling `extend_class` on each.
pub(crate) fn init(ctx: &'static Context) {
    PyMemoryView::extend_class(ctx, ctx.types.memoryview_type);
    PyMemoryViewIterator::extend_class(ctx, ctx.types.memoryviewiterator_type);
}
1134
1135fn format_unpack(
1136    format_spec: &FormatSpec,
1137    bytes: &[u8],
1138    vm: &VirtualMachine,
1139) -> PyResult<PyObjectRef> {
1140    format_spec.unpack(bytes, vm).map(|x| {
1141        if x.len() == 1 {
1142            x[0].to_owned()
1143        } else {
1144            x.into()
1145        }
1146    })
1147}
1148
1149fn is_equiv_shape(a: &BufferDescriptor, b: &BufferDescriptor) -> bool {
1150    if a.ndim() != b.ndim() {
1151        return false;
1152    }
1153
1154    let a_iter = a.dim_desc.iter().map(|x| x.0);
1155    let b_iter = b.dim_desc.iter().map(|x| x.0);
1156    for (a_shape, b_shape) in a_iter.zip(b_iter) {
1157        if a_shape != b_shape {
1158            return false;
1159        }
1160        // if both shape is 0, ignore the rest
1161        if a_shape == 0 {
1162            break;
1163        }
1164    }
1165    true
1166}
1167
1168fn is_equiv_format(a: &BufferDescriptor, b: &BufferDescriptor) -> bool {
1169    // TODO: skip @
1170    a.itemsize == b.itemsize && a.format == b.format
1171}
1172
1173fn is_equiv_structure(a: &BufferDescriptor, b: &BufferDescriptor) -> bool {
1174    is_equiv_format(a, b) && is_equiv_shape(a, b)
1175}
1176
1177impl Iterable for PyMemoryView {
1178    fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
1179        Ok(PyMemoryViewIterator {
1180            internal: PyMutex::new(PositionIterInternal::new(zelf, 0)),
1181        }
1182        .into_pyobject(vm))
1183    }
1184}
1185
// Python-level `memory_iterator`. It holds the iterated memoryview together
// with the current position in a `PositionIterInternal`, guarded by a mutex.
#[pyclass(module = false, name = "memory_iterator")]
#[derive(Debug, Traverse)]
pub struct PyMemoryViewIterator {
    internal: PyMutex<PositionIterInternal<PyRef<PyMemoryView>>>,
}
1191
// Associates the iterator payload with the interpreter's `memory_iterator`
// type object.
impl PyPayload for PyMemoryViewIterator {
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.memoryviewiterator_type
    }
}
1197
1198#[pyclass(flags(DISALLOW_INSTANTIATION), with(IterNext, Iterable))]
1199impl PyMemoryViewIterator {
1200    #[pymethod]
1201    fn __reduce__(&self, vm: &VirtualMachine) -> PyTupleRef {
1202        let func = builtins_iter(vm);
1203        self.internal.lock().reduce(
1204            func,
1205            |x| x.clone().into(),
1206            |vm| vm.ctx.empty_tuple.clone().into(),
1207            vm,
1208        )
1209    }
1210}
1211
// Marker impl. Presumably this makes `iter(it)` return the iterator itself;
// see the `SelfIter` trait.
impl SelfIter for PyMemoryViewIterator {}
1213impl IterNext for PyMemoryViewIterator {
1214    fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
1215        zelf.internal.lock().next(|mv, pos| {
1216            let len = mv.__len__(vm)?;
1217            Ok(if pos >= len {
1218                PyIterReturn::StopIteration(None)
1219            } else {
1220                PyIterReturn::Return(mv.getitem_by_idx(pos.try_into().unwrap(), vm)?)
1221            })
1222        })
1223    }
1224}