1use super::{
2 PyGenericAlias, PyInt, PyIntRef, PySlice, PyTupleRef, PyType, PyTypeRef, builtins_iter,
3 tuple::tuple_hash,
4};
5use crate::common::lock::LazyLock;
6use crate::{
7 AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject,
8 VirtualMachine, atomic_func,
9 class::PyClassImpl,
10 common::hash::PyHash,
11 function::{ArgIndex, FuncArgs, OptionalArg, PyComparisonValue},
12 protocol::{PyIterReturn, PyMappingMethods, PyNumberMethods, PySequenceMethods},
13 types::{
14 AsMapping, AsNumber, AsSequence, Comparable, Hashable, IterNext, Iterable, PyComparisonOp,
15 Representable, SelfIter,
16 },
17};
18use core::cell::Cell;
19use core::cmp::max;
20use core::ptr::NonNull;
21use crossbeam_utils::atomic::AtomicCell;
22use malachite_bigint::{BigInt, Sign};
23use num_integer::Integer;
24use num_traits::{One, Signed, ToPrimitive, Zero};
25
/// Selects what [`iter_search`] reports about the elements it scans.
enum SearchType {
    /// Report how many elements compare equal to the needle.
    Count,
    /// Report `1` if any element compares equal to the needle, else `0`.
    Contains,
    /// Report where the first equal element was found.
    Index,
}
32
33#[inline]
34fn iter_search(
35 obj: &PyObject,
36 item: &PyObject,
37 flag: SearchType,
38 vm: &VirtualMachine,
39) -> PyResult<usize> {
40 let mut count = 0;
41 let iter = obj.get_iter(vm)?;
42 for element in iter.iter_without_hint::<PyObjectRef>(vm)? {
43 if vm.bool_eq(item, &*element?)? {
44 match flag {
45 SearchType::Index => return Ok(count),
46 SearchType::Contains => return Ok(1),
47 SearchType::Count => count += 1,
48 }
49 }
50 }
51 match flag {
52 SearchType::Count => Ok(count),
53 SearchType::Contains => Ok(0),
54 SearchType::Index => Err(vm.new_value_error(format!(
55 "{} not in range",
56 item.repr(vm)
57 .as_ref()
58 .map_or("value".as_ref(), |s| s.as_wtf8())
59 .to_owned()
60 ))),
61 }
62}
63
64#[pyclass(module = false, name = "range")]
65#[derive(Debug, Clone)]
66pub struct PyRange {
67 pub start: PyIntRef,
68 pub stop: PyIntRef,
69 pub step: PyIntRef,
70}
71
72thread_local! {
74 static RANGE_FREELIST: Cell<crate::object::FreeList<PyRange>> = const { Cell::new(crate::object::FreeList::new()) };
75}
76
77impl PyPayload for PyRange {
78 const MAX_FREELIST: usize = 6;
79 const HAS_FREELIST: bool = true;
80
81 #[inline]
82 fn class(ctx: &Context) -> &'static Py<PyType> {
83 ctx.types.range_type
84 }
85
86 #[inline]
87 unsafe fn freelist_push(obj: *mut PyObject) -> bool {
88 RANGE_FREELIST
89 .try_with(|fl| {
90 let mut list = fl.take();
91 let stored = if list.len() < Self::MAX_FREELIST {
92 list.push(obj);
93 true
94 } else {
95 false
96 };
97 fl.set(list);
98 stored
99 })
100 .unwrap_or(false)
101 }
102
103 #[inline]
104 unsafe fn freelist_pop(_payload: &Self) -> Option<NonNull<PyObject>> {
105 RANGE_FREELIST
106 .try_with(|fl| {
107 let mut list = fl.take();
108 let result = list.pop().map(|p| unsafe { NonNull::new_unchecked(p) });
109 fl.set(list);
110 result
111 })
112 .ok()
113 .flatten()
114 }
115}
116
117impl PyRange {
118 #[inline]
119 fn offset(&self, value: &BigInt) -> Option<BigInt> {
120 let start = self.start.as_bigint();
121 let stop = self.stop.as_bigint();
122 let step = self.step.as_bigint();
123 match step.sign() {
124 Sign::Plus if value >= start && value < stop => Some(value - start),
125 Sign::Minus if value <= self.start.as_bigint() && value > stop => Some(start - value),
126 _ => None,
127 }
128 }
129
130 #[inline]
131 pub fn index_of(&self, value: &BigInt) -> Option<BigInt> {
132 let step = self.step.as_bigint();
133 match self.offset(value) {
134 Some(ref offset) if offset.is_multiple_of(step) => Some((offset / step).abs()),
135 Some(_) | None => None,
136 }
137 }
138
139 #[inline]
140 pub fn is_empty(&self) -> bool {
141 self.compute_length().is_zero()
142 }
143
144 #[inline]
145 pub fn forward(&self) -> bool {
146 self.start.as_bigint() < self.stop.as_bigint()
147 }
148
149 #[inline]
150 pub fn get(&self, index: &BigInt) -> Option<BigInt> {
151 let start = self.start.as_bigint();
152 let step = self.step.as_bigint();
153 let stop = self.stop.as_bigint();
154 if self.is_empty() {
155 return None;
156 }
157
158 if index.is_negative() {
159 let length = self.compute_length();
160 let index: BigInt = &length + index;
161 if index.is_negative() {
162 return None;
163 }
164
165 Some(if step.is_one() {
166 start + index
167 } else {
168 start + step * index
169 })
170 } else {
171 let index = if step.is_one() {
172 start + index
173 } else {
174 start + step * index
175 };
176
177 if (step.is_positive() && stop > &index) || (step.is_negative() && stop < &index) {
178 Some(index)
179 } else {
180 None
181 }
182 }
183 }
184
185 #[inline]
186 fn compute_length(&self) -> BigInt {
187 let start = self.start.as_bigint();
188 let stop = self.stop.as_bigint();
189 let step = self.step.as_bigint();
190
191 match step.sign() {
192 Sign::Plus if start < stop => {
193 if step.is_one() {
194 stop - start
195 } else {
196 (stop - start - 1usize) / step + 1
197 }
198 }
199 Sign::Minus if start > stop => (start - stop - 1usize) / (-step) + 1,
200 Sign::Plus | Sign::Minus => BigInt::zero(),
201 Sign::NoSign => unreachable!(),
202 }
203 }
204}
205
206pub fn init(context: &'static Context) {
211 PyRange::extend_class(context, context.types.range_type);
212 PyLongRangeIterator::extend_class(context, context.types.long_range_iterator_type);
213 PyRangeIterator::extend_class(context, context.types.range_iterator_type);
214}
215
216#[pyclass(
217 with(
218 Py,
219 AsMapping,
220 AsNumber,
221 AsSequence,
222 Hashable,
223 Comparable,
224 Iterable,
225 Representable
226 ),
227 flags(SEQUENCE)
228)]
229impl PyRange {
230 fn new(cls: PyTypeRef, stop: ArgIndex, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
231 Self {
232 start: vm.ctx.new_pyref(0),
233 stop: stop.into(),
234 step: vm.ctx.new_pyref(1),
235 }
236 .into_ref_with_type(vm, cls)
237 }
238
239 fn new_from(
240 cls: PyTypeRef,
241 start: PyObjectRef,
242 stop: PyObjectRef,
243 step: OptionalArg<ArgIndex>,
244 vm: &VirtualMachine,
245 ) -> PyResult<PyRef<Self>> {
246 let step = step.map_or_else(|| vm.ctx.new_int(1), |step| step.into());
247 if step.as_bigint().is_zero() {
248 return Err(vm.new_value_error("range() arg 3 must not be zero"));
249 }
250 Self {
251 start: start.try_index(vm)?,
252 stop: stop.try_index(vm)?,
253 step,
254 }
255 .into_ref_with_type(vm, cls)
256 }
257
258 #[pygetset]
259 fn start(&self) -> PyIntRef {
260 self.start.clone()
261 }
262
263 #[pygetset]
264 fn stop(&self) -> PyIntRef {
265 self.stop.clone()
266 }
267
268 #[pygetset]
269 fn step(&self) -> PyIntRef {
270 self.step.clone()
271 }
272
273 #[pymethod]
274 fn __reversed__(&self, vm: &VirtualMachine) -> PyResult {
275 let start = self.start.as_bigint();
276 let step = self.step.as_bigint();
277
278 let length = self.__len__();
280 let new_stop = start - step;
281 let start = &new_stop + length.clone() * step;
282 let step = -step;
283
284 Ok(
285 if let (Some(start), Some(step), Some(_)) =
286 (start.to_isize(), step.to_isize(), new_stop.to_isize())
287 {
288 PyRangeIterator {
289 index: AtomicCell::new(0),
290 start,
291 step,
292 length: length.to_usize().unwrap_or(0),
295 }
296 .into_pyobject(vm)
297 } else {
298 PyLongRangeIterator {
299 index: AtomicCell::new(0),
300 start,
301 step,
302 length,
303 }
304 .into_pyobject(vm)
305 },
306 )
307 }
308
309 fn __len__(&self) -> BigInt {
310 self.compute_length()
311 }
312
313 #[pymethod]
314 fn __reduce__(&self, vm: &VirtualMachine) -> (PyTypeRef, PyTupleRef) {
315 let range_parameters: Vec<PyObjectRef> = [&self.start, &self.stop, &self.step]
316 .iter()
317 .map(|x| x.as_object().to_owned())
318 .collect();
319 let range_parameters_tuple = vm.ctx.new_tuple(range_parameters);
320 (vm.ctx.types.range_type.to_owned(), range_parameters_tuple)
321 }
322
323 fn __getitem__(&self, subscript: PyObjectRef, vm: &VirtualMachine) -> PyResult {
324 match RangeIndex::try_from_object(vm, subscript)? {
325 RangeIndex::Slice(slice) => {
326 let (mut sub_start, mut sub_stop, mut sub_step) =
327 slice.inner_indices(&self.compute_length(), vm)?;
328 let range_step = &self.step;
329 let range_start = &self.start;
330
331 sub_step *= range_step.as_bigint();
332 sub_start = (sub_start * range_step.as_bigint()) + range_start.as_bigint();
333 sub_stop = (sub_stop * range_step.as_bigint()) + range_start.as_bigint();
334
335 Ok(Self {
336 start: vm.ctx.new_pyref(sub_start),
337 stop: vm.ctx.new_pyref(sub_stop),
338 step: vm.ctx.new_pyref(sub_step),
339 }
340 .into_ref(&vm.ctx)
341 .into())
342 }
343 RangeIndex::Int(index) => match self.get(index.as_bigint()) {
344 Some(value) => Ok(vm.ctx.new_int(value).into()),
345 None => Err(vm.new_index_error("range object index out of range")),
346 },
347 }
348 }
349
350 #[pyslot]
351 fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult {
352 let range = if args.args.len() <= 1 {
353 let stop = args.bind(vm)?;
354 Self::new(cls, stop, vm)
355 } else {
356 let (start, stop, step) = args.bind(vm)?;
357 Self::new_from(cls, start, stop, step, vm)
358 }?;
359
360 Ok(range.into())
361 }
362
363 fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
366 PyGenericAlias::from_args(cls, args, vm)
367 }
368}
369
370#[pyclass]
371impl Py<PyRange> {
372 fn contains_inner(&self, needle: &PyObject, vm: &VirtualMachine) -> bool {
373 if let Some(int) = needle.downcast_ref_if_exact::<PyInt>(vm) {
375 match self.offset(int.as_bigint()) {
376 Some(ref offset) => offset.is_multiple_of(self.step.as_bigint()),
377 None => false,
378 }
379 } else {
380 iter_search(self.as_object(), needle, SearchType::Contains, vm).unwrap_or(0) != 0
381 }
382 }
383
384 fn __contains__(&self, needle: PyObjectRef, vm: &VirtualMachine) -> bool {
385 self.contains_inner(&needle, vm)
386 }
387
388 #[pymethod]
389 fn index(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult<BigInt> {
390 if let Ok(int) = needle.clone().downcast::<PyInt>() {
391 match self.index_of(int.as_bigint()) {
392 Some(idx) => Ok(idx),
393 None => Err(vm.new_value_error(format!("{int} is not in range"))),
394 }
395 } else {
396 Ok(BigInt::from_bytes_be(
398 Sign::Plus,
399 &iter_search(self.as_object(), &needle, SearchType::Index, vm)?.to_be_bytes(),
400 ))
401 }
402 }
403
404 #[pymethod]
405 fn count(&self, item: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
406 if let Ok(int) = item.clone().downcast::<PyInt>() {
407 let count = if self.index_of(int.as_bigint()).is_some() {
408 1
409 } else {
410 0
411 };
412 Ok(count)
413 } else {
414 iter_search(self.as_object(), &item, SearchType::Count, vm)
417 }
418 }
419}
420
421impl PyRange {
422 fn protocol_length(&self, vm: &VirtualMachine) -> PyResult<usize> {
423 PyInt::from(self.__len__())
424 .try_to_primitive::<isize>(vm)
425 .map(|x| x as usize)
426 }
427}
428
429impl AsMapping for PyRange {
430 fn as_mapping() -> &'static PyMappingMethods {
431 static AS_MAPPING: LazyLock<PyMappingMethods> = LazyLock::new(|| PyMappingMethods {
432 length: atomic_func!(
433 |mapping, vm| PyRange::mapping_downcast(mapping).protocol_length(vm)
434 ),
435 subscript: atomic_func!(|mapping, needle, vm| {
436 PyRange::mapping_downcast(mapping).__getitem__(needle.to_owned(), vm)
437 }),
438 ..PyMappingMethods::NOT_IMPLEMENTED
439 });
440 &AS_MAPPING
441 }
442}
443
444impl AsSequence for PyRange {
445 fn as_sequence() -> &'static PySequenceMethods {
446 static AS_SEQUENCE: LazyLock<PySequenceMethods> = LazyLock::new(|| PySequenceMethods {
447 length: atomic_func!(|seq, vm| PyRange::sequence_downcast(seq).protocol_length(vm)),
448 item: atomic_func!(|seq, i, vm| {
449 PyRange::sequence_downcast(seq)
450 .get(&i.into())
451 .map(|x| PyInt::from(x).into_ref(&vm.ctx).into())
452 .ok_or_else(|| vm.new_index_error("index out of range"))
453 }),
454 contains: atomic_func!(|seq, needle, vm| {
455 Ok(PyRange::sequence_downcast(seq).contains_inner(needle, vm))
456 }),
457 ..PySequenceMethods::NOT_IMPLEMENTED
458 });
459 &AS_SEQUENCE
460 }
461}
462
463impl AsNumber for PyRange {
464 fn as_number() -> &'static PyNumberMethods {
465 static AS_NUMBER: PyNumberMethods = PyNumberMethods {
466 boolean: Some(|number, _vm| {
467 let zelf = number.obj.downcast_ref::<PyRange>().unwrap();
468 Ok(!zelf.is_empty())
469 }),
470 ..PyNumberMethods::NOT_IMPLEMENTED
471 };
472 &AS_NUMBER
473 }
474}
475
476impl Hashable for PyRange {
477 fn hash(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyHash> {
478 let length = zelf.compute_length();
479 let elements = if length.is_zero() {
480 [vm.ctx.new_int(length).into(), vm.ctx.none(), vm.ctx.none()]
481 } else if length.is_one() {
482 [
483 vm.ctx.new_int(length).into(),
484 zelf.start().into(),
485 vm.ctx.none(),
486 ]
487 } else {
488 [
489 vm.ctx.new_int(length).into(),
490 zelf.start().into(),
491 zelf.step().into(),
492 ]
493 };
494 tuple_hash(&elements, vm)
495 }
496}
497
498impl Comparable for PyRange {
499 fn cmp(
500 zelf: &Py<Self>,
501 other: &PyObject,
502 op: PyComparisonOp,
503 _vm: &VirtualMachine,
504 ) -> PyResult<PyComparisonValue> {
505 op.eq_only(|| {
506 if zelf.is(other) {
507 return Ok(true.into());
508 }
509 let rhs = class_or_notimplemented!(Self, other);
510 let lhs_len = zelf.compute_length();
511 let eq = if lhs_len != rhs.compute_length() {
512 false
513 } else if lhs_len.is_zero() {
514 true
515 } else if zelf.start.as_bigint() != rhs.start.as_bigint() {
516 false
517 } else if lhs_len.is_one() {
518 true
519 } else {
520 zelf.step.as_bigint() == rhs.step.as_bigint()
521 };
522 Ok(eq.into())
523 })
524 }
525}
526
527impl Iterable for PyRange {
528 fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
529 let (start, stop, step, length) = (
530 zelf.start.as_bigint(),
531 zelf.stop.as_bigint(),
532 zelf.step.as_bigint(),
533 zelf.__len__(),
534 );
535 if let (Some(start), Some(step), Some(_), Some(_)) = (
536 start.to_isize(),
537 step.to_isize(),
538 stop.to_isize(),
539 (start + step).to_isize(),
540 ) {
541 Ok(PyRangeIterator {
542 index: AtomicCell::new(0),
543 start,
544 step,
545 length: length.to_usize().unwrap_or(0),
548 }
549 .into_pyobject(vm))
550 } else {
551 Ok(PyLongRangeIterator {
552 index: AtomicCell::new(0),
553 start: start.clone(),
554 step: step.clone(),
555 length,
556 }
557 .into_pyobject(vm))
558 }
559 }
560}
561
562impl Representable for PyRange {
563 #[inline]
564 fn repr_str(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<String> {
565 let repr = if zelf.step.as_bigint().is_one() {
566 format!("range({}, {})", zelf.start, zelf.stop)
567 } else {
568 format!("range({}, {}, {})", zelf.start, zelf.stop, zelf.step)
569 };
570 Ok(repr)
571 }
572}
573
574#[pyclass(module = false, name = "longrange_iterator")]
583#[derive(Debug)]
584pub struct PyLongRangeIterator {
585 index: AtomicCell<usize>,
586 start: BigInt,
587 step: BigInt,
588 length: BigInt,
589}
590
591impl PyPayload for PyLongRangeIterator {
592 #[inline]
593 fn class(ctx: &Context) -> &'static Py<PyType> {
594 ctx.types.long_range_iterator_type
595 }
596}
597
598#[pyclass(flags(DISALLOW_INSTANTIATION), with(IterNext, Iterable))]
599impl PyLongRangeIterator {
600 #[pymethod]
601 fn __length_hint__(&self) -> BigInt {
602 let index = BigInt::from(self.index.load());
603 if index < self.length {
604 self.length.clone() - index
605 } else {
606 BigInt::zero()
607 }
608 }
609
610 #[pymethod]
611 fn __setstate__(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
612 self.index.store(range_state(&self.length, state, vm)?);
613 Ok(())
614 }
615
616 #[pymethod]
617 fn __reduce__(&self, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
618 range_iter_reduce(
619 self.start.clone(),
620 self.length.clone(),
621 self.step.clone(),
622 self.index.load(),
623 vm,
624 )
625 }
626}
627
628impl SelfIter for PyLongRangeIterator {}
629impl IterNext for PyLongRangeIterator {
630 fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
631 let index = BigInt::from(zelf.index.fetch_add(1));
635 let r = if index < zelf.length {
636 let value = zelf.start.clone() + index * zelf.step.clone();
637 PyIterReturn::Return(vm.ctx.new_int(value).into())
638 } else {
639 PyIterReturn::StopIteration(None)
640 };
641 Ok(r)
642 }
643}
644
645#[pyclass(module = false, name = "range_iterator")]
648#[derive(Debug)]
649pub struct PyRangeIterator {
650 index: AtomicCell<usize>,
651 start: isize,
652 step: isize,
653 length: usize,
654}
655
656impl PyPayload for PyRangeIterator {
657 #[inline]
658 fn class(ctx: &Context) -> &'static Py<PyType> {
659 ctx.types.range_iterator_type
660 }
661}
662
663#[pyclass(flags(DISALLOW_INSTANTIATION), with(IterNext, Iterable))]
664impl PyRangeIterator {
665 #[pymethod]
666 fn __length_hint__(&self) -> usize {
667 let index = self.index.load();
668 self.length.saturating_sub(index)
669 }
670
671 #[pymethod]
672 fn __setstate__(&self, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
673 self.index
674 .store(range_state(&BigInt::from(self.length), state, vm)?);
675 Ok(())
676 }
677
678 #[pymethod]
679 fn __reduce__(&self, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
680 range_iter_reduce(
681 BigInt::from(self.start),
682 BigInt::from(self.length),
683 BigInt::from(self.step),
684 self.index.load(),
685 vm,
686 )
687 }
688}
689
690impl PyRangeIterator {
691 pub(crate) fn fast_next(&self) -> Option<isize> {
694 let index = self.index.fetch_add(1);
695 if index < self.length {
696 Some(self.start + (index as isize) * self.step)
697 } else {
698 None
699 }
700 }
701}
702
703impl SelfIter for PyRangeIterator {}
704impl IterNext for PyRangeIterator {
705 fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
706 let r = match zelf.fast_next() {
707 Some(value) => PyIterReturn::Return(vm.ctx.new_int(value).into()),
708 None => PyIterReturn::StopIteration(None),
709 };
710 Ok(r)
711 }
712}
713
714fn range_iter_reduce(
715 start: BigInt,
716 length: BigInt,
717 step: BigInt,
718 index: usize,
719 vm: &VirtualMachine,
720) -> PyResult<PyTupleRef> {
721 let iter = builtins_iter(vm);
722 let stop = start.clone() + length * step.clone();
723 let range = PyRange {
724 start: PyInt::from(start).into_ref(&vm.ctx),
725 stop: PyInt::from(stop).into_ref(&vm.ctx),
726 step: PyInt::from(step).into_ref(&vm.ctx),
727 };
728 Ok(vm.new_tuple((iter, (range,), index)))
729}
730
731fn range_state(length: &BigInt, state: PyObjectRef, vm: &VirtualMachine) -> PyResult<usize> {
733 if let Some(i) = state.downcast_ref::<PyInt>() {
734 let mut index = i.as_bigint();
735 let max_usize = BigInt::from(usize::MAX);
736 if index > length {
737 index = max(length, &max_usize);
738 }
739 Ok(index.to_usize().unwrap_or(0))
740 } else {
741 Err(vm.new_type_error("an integer is required."))
742 }
743}
744
745pub enum RangeIndex {
746 Int(PyIntRef),
747 Slice(PyRef<PySlice>),
748}
749
750impl TryFromObject for RangeIndex {
751 fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
752 match_class!(match obj {
753 i @ PyInt => Ok(Self::Int(i)),
754 s @ PySlice => Ok(Self::Slice(s)),
755 obj => {
756 let val = obj.try_index(vm).map_err(|_| vm.new_type_error(format!(
757 "sequence indices be integers or slices or classes that override __index__ operator, not '{}'",
758 obj.class().name()
759 )))?;
760 Ok(Self::Int(val))
761 }
762 })
763 }
764}