Skip to main content

rustpython_vm/builtins/
complex.rs

1use super::{PyStr, PyType, PyTypeRef, float};
2use crate::{
3    AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
4    builtins::PyUtf8StrRef,
5    class::PyClassImpl,
6    common::{format::FormatSpec, wtf8::Wtf8Buf},
7    convert::{IntoPyException, ToPyObject, ToPyResult},
8    function::{FuncArgs, OptionalArg, PyComparisonValue},
9    protocol::PyNumberMethods,
10    stdlib::_warnings,
11    types::{AsNumber, Callable, Comparable, Constructor, Hashable, PyComparisonOp, Representable},
12};
13use core::cell::Cell;
14use core::num::Wrapping;
15use core::ptr::NonNull;
16use num_complex::Complex64;
17use num_traits::Zero;
18use rustpython_common::hash;
19
/// Create a complex number from a real part and an optional imaginary part.
///
/// This is equivalent to (real + imag*1j) where imag defaults to 0.
// NOTE(review): the `///` text above is presumably exposed as the Python-level
// `complex.__doc__` by `#[pyclass]`; treat it as runtime text when editing.
#[pyclass(module = false, name = "complex")]
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct PyComplex {
    // The wrapped value; the arithmetic, hashing, comparison and repr code in this
    // file all read it directly.
    value: Complex64,
}

// spell-checker:ignore MAXFREELIST
// Per-thread list of `*mut PyObject` entries for `PyComplex`, filled by
// `PyPayload::freelist_push` (capped at `MAX_FREELIST`) and drained by `freelist_pop`.
thread_local! {
    static COMPLEX_FREELIST: Cell<crate::object::FreeList<PyComplex>> = const { Cell::new(crate::object::FreeList::new()) };
}
33
34impl PyPayload for PyComplex {
35    const MAX_FREELIST: usize = 100;
36    const HAS_FREELIST: bool = true;
37
38    #[inline]
39    fn class(ctx: &Context) -> &'static Py<PyType> {
40        ctx.types.complex_type
41    }
42
43    #[inline]
44    unsafe fn freelist_push(obj: *mut PyObject) -> bool {
45        COMPLEX_FREELIST
46            .try_with(|fl| {
47                let mut list = fl.take();
48                let stored = if list.len() < Self::MAX_FREELIST {
49                    list.push(obj);
50                    true
51                } else {
52                    false
53                };
54                fl.set(list);
55                stored
56            })
57            .unwrap_or(false)
58    }
59
60    #[inline]
61    unsafe fn freelist_pop(_payload: &Self) -> Option<NonNull<PyObject>> {
62        COMPLEX_FREELIST
63            .try_with(|fl| {
64                let mut list = fl.take();
65                let result = list.pop().map(|p| unsafe { NonNull::new_unchecked(p) });
66                fl.set(list);
67                result
68            })
69            .ok()
70            .flatten()
71    }
72}
73
74impl ToPyObject for Complex64 {
75    fn to_pyobject(self, vm: &VirtualMachine) -> PyObjectRef {
76        PyComplex::from(self).to_pyobject(vm)
77    }
78}
79
80impl From<Complex64> for PyComplex {
81    fn from(value: Complex64) -> Self {
82        Self { value }
83    }
84}
85
86impl PyObjectRef {
87    /// Tries converting a python object into a complex, returns an option of whether the complex
88    /// and whether the  object was a complex originally or coerced into one
89    pub fn try_complex(&self, vm: &VirtualMachine) -> PyResult<Option<(Complex64, bool)>> {
90        if let Some(complex) = self.downcast_ref_if_exact::<PyComplex>(vm) {
91            return Ok(Some((complex.value, true)));
92        }
93        if let Some(method) = vm.get_method(self.clone(), identifier!(vm, __complex__)) {
94            let result = method?.call((), vm)?;
95
96            let ret_class = result.class().to_owned();
97            if let Some(ret) = result.downcast_ref::<PyComplex>() {
98                _warnings::warn(
99                    vm.ctx.exceptions.deprecation_warning,
100                    format!(
101                        "__complex__ returned non-complex (type {ret_class}).  \
102                    The ability to return an instance of a strict subclass of complex \
103                    is deprecated, and may be removed in a future version of Python."
104                    ),
105                    1,
106                    vm,
107                )?;
108
109                return Ok(Some((ret.value, true)));
110            } else {
111                return match result.downcast_ref::<PyComplex>() {
112                    Some(complex_obj) => Ok(Some((complex_obj.value, true))),
113                    None => Err(vm.new_type_error(format!(
114                        "__complex__ returned non-complex (type '{}')",
115                        result.class().name()
116                    ))),
117                };
118            }
119        }
120        // `complex` does not have a `__complex__` by default, so subclasses might not either,
121        // use the actual stored value in this case
122        if let Some(complex) = self.downcast_ref::<PyComplex>() {
123            return Ok(Some((complex.value, true)));
124        }
125        if let Some(float) = self.try_float_opt(vm) {
126            return Ok(Some((Complex64::new(float?.to_f64(), 0.0), false)));
127        }
128        Ok(None)
129    }
130}
131
132pub fn init(context: &'static Context) {
133    PyComplex::extend_class(context, context.types.complex_type);
134}
135
136fn to_op_complex(value: &PyObject, vm: &VirtualMachine) -> PyResult<Option<Complex64>> {
137    let r = if let Some(complex) = value.downcast_ref::<PyComplex>() {
138        Some(complex.value)
139    } else {
140        float::to_op_float(value, vm)?.map(|float| Complex64::new(float, 0.0))
141    };
142    Ok(r)
143}
144
145fn inner_div(v1: Complex64, v2: Complex64, vm: &VirtualMachine) -> PyResult<Complex64> {
146    if v2.is_zero() {
147        return Err(vm.new_zero_division_error("complex division by zero"));
148    }
149
150    Ok(v1.fdiv(v2))
151}
152
153fn inner_pow(v1: Complex64, v2: Complex64, vm: &VirtualMachine) -> PyResult<Complex64> {
154    if v1.is_zero() {
155        return if v2.re < 0.0 || v2.im != 0.0 {
156            let msg = format!("{v1} cannot be raised to a negative or complex power");
157            Err(vm.new_zero_division_error(msg))
158        } else if v2.is_zero() {
159            Ok(Complex64::new(1.0, 0.0))
160        } else {
161            Ok(Complex64::new(0.0, 0.0))
162        };
163    }
164
165    let ans = powc(v1, v2);
166    if ans.is_infinite() && !(v1.is_infinite() || v2.is_infinite()) {
167        Err(vm.new_overflow_error("complex exponentiation overflow"))
168    } else {
169        Ok(ans)
170    }
171}
172
173// num-complex changed their powc() implementation in 0.4.4, making it incompatible
174// with what the regression tests expect. this is that old formula.
175fn powc(a: Complex64, exp: Complex64) -> Complex64 {
176    let (r, theta) = a.to_polar();
177    if r.is_zero() {
178        return Complex64::new(r, r);
179    }
180    Complex64::from_polar(
181        r.powf(exp.re) * (-exp.im * theta).exp(),
182        exp.re * theta + exp.im * r.ln(),
183    )
184}
185
186impl Constructor for PyComplex {
187    type Args = ComplexArgs;
188
189    fn slot_new(cls: PyTypeRef, func_args: FuncArgs, vm: &VirtualMachine) -> PyResult {
190        // Optimization: return exact complex as-is (only when imag is not provided)
191        if cls.is(vm.ctx.types.complex_type)
192            && func_args.args.len() == 1
193            && func_args.kwargs.is_empty()
194            && func_args.args[0].class().is(vm.ctx.types.complex_type)
195        {
196            return Ok(func_args.args[0].clone());
197        }
198
199        let args: Self::Args = func_args.bind(vm)?;
200        let payload = Self::py_new(&cls, args, vm)?;
201        payload.into_ref_with_type(vm, cls).map(Into::into)
202    }
203
204    fn py_new(_cls: &Py<PyType>, args: Self::Args, vm: &VirtualMachine) -> PyResult<Self> {
205        let imag_missing = args.imag.is_missing();
206        let (real, real_was_complex) = match args.real {
207            OptionalArg::Missing => (Complex64::new(0.0, 0.0), false),
208            OptionalArg::Present(val) => {
209                if let Some(c) = val.try_complex(vm)? {
210                    c
211                } else if let Some(s) = val.downcast_ref::<PyStr>() {
212                    if args.imag.is_present() {
213                        return Err(vm.new_type_error(
214                            "complex() can't take second arg if first is a string",
215                        ));
216                    }
217                    let (re, im) = s
218                        .to_str()
219                        .and_then(rustpython_literal::complex::parse_str)
220                        .ok_or_else(|| vm.new_value_error("complex() arg is a malformed string"))?;
221                    return Ok(Self::from(Complex64 { re, im }));
222                } else {
223                    return Err(vm.new_type_error(format!(
224                        "complex() first argument must be a string or a number, not '{}'",
225                        val.class().name()
226                    )));
227                }
228            }
229        };
230
231        let (imag, imag_was_complex) = match args.imag {
232            // Copy the imaginary from the real to the real of the imaginary
233            // if an  imaginary argument is not passed in
234            OptionalArg::Missing => (Complex64::new(real.im, 0.0), false),
235            OptionalArg::Present(obj) => {
236                if let Some(c) = obj.try_complex(vm)? {
237                    c
238                } else if obj.class().fast_issubclass(vm.ctx.types.str_type) {
239                    return Err(vm.new_type_error("complex() second arg can't be a string"));
240                } else {
241                    return Err(vm.new_type_error(format!(
242                        "complex() second argument must be a number, not '{}'",
243                        obj.class().name()
244                    )));
245                }
246            }
247        };
248
249        let final_real = if imag_was_complex {
250            real.re - imag.im
251        } else {
252            real.re
253        };
254
255        let final_imag = if real_was_complex && !imag_missing {
256            imag.re + real.im
257        } else {
258            imag.re
259        };
260        let value = Complex64::new(final_real, final_imag);
261        Ok(Self::from(value))
262    }
263}
264
265impl PyComplex {
266    #[deprecated(note = "use PyComplex::from(...).into_ref() instead")]
267    pub fn new_ref(value: Complex64, ctx: &Context) -> PyRef<Self> {
268        Self::from(value).into_ref(ctx)
269    }
270
271    pub const fn to_complex64(self) -> Complex64 {
272        self.value
273    }
274
275    pub const fn to_complex(&self) -> Complex64 {
276        self.value
277    }
278
279    fn number_op<F, R>(a: &PyObject, b: &PyObject, op: F, vm: &VirtualMachine) -> PyResult
280    where
281        F: FnOnce(Complex64, Complex64, &VirtualMachine) -> R,
282        R: ToPyResult,
283    {
284        if let (Some(a), Some(b)) = (to_op_complex(a, vm)?, to_op_complex(b, vm)?) {
285            op(a, b, vm).to_pyresult(vm)
286        } else {
287            Ok(vm.ctx.not_implemented())
288        }
289    }
290
291    fn complex_real_binop<CCF, RCF, CRF, R>(
292        a: &PyObject,
293        b: &PyObject,
294        cc_op: CCF,
295        cr_op: CRF,
296        rc_op: RCF,
297        vm: &VirtualMachine,
298    ) -> PyResult
299    where
300        CCF: FnOnce(Complex64, Complex64) -> R,
301        CRF: FnOnce(Complex64, f64) -> R,
302        RCF: FnOnce(f64, Complex64) -> R,
303        R: ToPyResult,
304    {
305        let value = match (a.downcast_ref::<PyComplex>(), b.downcast_ref::<PyComplex>()) {
306            // complex + complex
307            (Some(a_complex), Some(b_complex)) => cc_op(a_complex.value, b_complex.value),
308            (Some(a_complex), None) => {
309                let Some(b_real) = float::to_op_float(b, vm)? else {
310                    return Ok(vm.ctx.not_implemented());
311                };
312
313                // complex + real
314                cr_op(a_complex.value, b_real)
315            }
316            (None, Some(b_complex)) => {
317                let Some(a_real) = float::to_op_float(a, vm)? else {
318                    return Ok(vm.ctx.not_implemented());
319                };
320
321                // real + complex
322                rc_op(a_real, b_complex.value)
323            }
324            (None, None) => return Ok(vm.ctx.not_implemented()),
325        };
326        value.to_pyresult(vm)
327    }
328}
329
330#[pyclass(
331    flags(BASETYPE),
332    with(PyRef, Comparable, Hashable, Constructor, AsNumber, Representable)
333)]
334impl PyComplex {
335    #[pygetset]
336    const fn real(&self) -> f64 {
337        self.value.re
338    }
339
340    #[pygetset]
341    const fn imag(&self) -> f64 {
342        self.value.im
343    }
344
345    #[pymethod]
346    fn conjugate(&self) -> Complex64 {
347        self.value.conj()
348    }
349
350    #[pymethod]
351    const fn __getnewargs__(&self) -> (f64, f64) {
352        let Complex64 { re, im } = self.value;
353        (re, im)
354    }
355
356    #[pymethod]
357    fn __format__(zelf: &Py<Self>, spec: PyUtf8StrRef, vm: &VirtualMachine) -> PyResult<Wtf8Buf> {
358        // Empty format spec: equivalent to str(self)
359        if spec.is_empty() {
360            return Ok(zelf.as_object().str(vm)?.as_wtf8().to_owned());
361        }
362        let format_spec =
363            FormatSpec::parse(spec.as_str()).map_err(|err| err.into_pyexception(vm))?;
364        let result = if format_spec.has_locale_format() {
365            let locale = crate::format::get_locale_info();
366            format_spec.format_complex_locale(&zelf.value, &locale)
367        } else {
368            format_spec.format_complex(&zelf.value)
369        };
370        result
371            .map(Wtf8Buf::from_string)
372            .map_err(|err| err.into_pyexception(vm))
373    }
374
375    #[pyclassmethod]
376    fn from_number(cls: PyTypeRef, number: PyObjectRef, vm: &VirtualMachine) -> PyResult {
377        if number.class().is(vm.ctx.types.complex_type) && cls.is(vm.ctx.types.complex_type) {
378            return Ok(number);
379        }
380        let value = number
381            .try_complex(vm)?
382            .ok_or_else(|| {
383                vm.new_type_error(format!(
384                    "must be real number, not {}",
385                    number.class().name()
386                ))
387            })?
388            .0;
389        let result = vm.ctx.new_complex(value);
390        if cls.is(vm.ctx.types.complex_type) {
391            Ok(result.into())
392        } else {
393            PyType::call(&cls, vec![result.into()].into(), vm)
394        }
395    }
396}
397
398#[pyclass]
399impl PyRef<PyComplex> {
400    #[pymethod]
401    fn __complex__(self, vm: &VirtualMachine) -> Self {
402        if self.is(vm.ctx.types.complex_type) {
403            self
404        } else {
405            PyComplex::from(self.value).into_ref(&vm.ctx)
406        }
407    }
408}
409
410impl Comparable for PyComplex {
411    fn cmp(
412        zelf: &Py<Self>,
413        other: &PyObject,
414        op: PyComparisonOp,
415        vm: &VirtualMachine,
416    ) -> PyResult<PyComparisonValue> {
417        op.eq_only(|| {
418            let result = if let Some(other) = other.downcast_ref::<Self>() {
419                zelf.value == other.value
420            } else {
421                match float::to_op_float(other, vm) {
422                    Ok(Some(other)) => zelf.value == other.into(),
423                    Err(_) => false,
424                    Ok(None) => return Ok(PyComparisonValue::NotImplemented),
425                }
426            };
427            Ok(PyComparisonValue::Implemented(result))
428        })
429    }
430}
431
432impl Hashable for PyComplex {
433    #[inline]
434    fn hash(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<hash::PyHash> {
435        let value = zelf.value;
436
437        let re_hash =
438            hash::hash_float(value.re).unwrap_or_else(|| hash::hash_object_id(zelf.get_id()));
439
440        let im_hash =
441            hash::hash_float(value.im).unwrap_or_else(|| hash::hash_object_id(zelf.get_id()));
442
443        let Wrapping(ret) = Wrapping(re_hash) + Wrapping(im_hash) * Wrapping(hash::IMAG);
444        Ok(hash::fix_sentinel(ret))
445    }
446}
447
448impl AsNumber for PyComplex {
449    fn as_number() -> &'static PyNumberMethods {
450        static AS_NUMBER: PyNumberMethods = PyNumberMethods {
451            add: Some(|a, b, vm| {
452                PyComplex::complex_real_binop(
453                    a,
454                    b,
455                    |a, b| a + b,
456                    |a_complex, b_real| Complex64::new(a_complex.re + b_real, a_complex.im),
457                    |a_real, b_complex| Complex64::new(a_real + b_complex.re, b_complex.im),
458                    vm,
459                )
460            }),
461            subtract: Some(|a, b, vm| {
462                PyComplex::complex_real_binop(
463                    a,
464                    b,
465                    |a, b| a - b,
466                    |a_complex, b_real| Complex64::new(a_complex.re - b_real, a_complex.im),
467                    |a_real, b_complex| Complex64::new(a_real - b_complex.re, -b_complex.im),
468                    vm,
469                )
470            }),
471            multiply: Some(|a, b, vm| PyComplex::number_op(a, b, |a, b, _vm| a * b, vm)),
472            power: Some(|a, b, c, vm| {
473                if vm.is_none(c) {
474                    PyComplex::number_op(a, b, inner_pow, vm)
475                } else {
476                    Err(vm.new_value_error(String::from("complex modulo")))
477                }
478            }),
479            negative: Some(|number, vm| {
480                let value = PyComplex::number_downcast(number).value;
481                (-value).to_pyresult(vm)
482            }),
483            positive: Some(|number, vm| {
484                PyComplex::number_downcast_exact(number, vm).to_pyresult(vm)
485            }),
486            absolute: Some(|number, vm| {
487                let value = PyComplex::number_downcast(number).value;
488                let result = value.norm();
489                // Check for overflow: hypot returns inf for finite inputs that overflow
490                if result.is_infinite() && value.re.is_finite() && value.im.is_finite() {
491                    return Err(vm.new_overflow_error("absolute value too large"));
492                }
493                result.to_pyresult(vm)
494            }),
495            boolean: Some(|number, _vm| Ok(!PyComplex::number_downcast(number).value.is_zero())),
496            true_divide: Some(|a, b, vm| PyComplex::number_op(a, b, inner_div, vm)),
497            ..PyNumberMethods::NOT_IMPLEMENTED
498        };
499        &AS_NUMBER
500    }
501
502    fn clone_exact(zelf: &Py<Self>, vm: &VirtualMachine) -> PyRef<Self> {
503        vm.ctx.new_complex(zelf.value)
504    }
505}
506
507impl Representable for PyComplex {
508    #[inline]
509    fn repr_str(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<String> {
510        // TODO: when you fix this, move it to rustpython_common::complex::repr and update
511        //       ast/src/unparse.rs + impl Display for Constant in ast/src/constant.rs
512        let Complex64 { re, im } = zelf.value;
513        Ok(rustpython_literal::complex::to_string(re, im))
514    }
515}
516
/// Arguments bound for `complex(real, imag)` by `Constructor::slot_new`; both are optional.
// NOTE(review): `pyarg(any, ..)` presumably allows each to be passed positionally or by
// keyword — confirm against the `FromArgs` derive.
#[derive(FromArgs)]
pub struct ComplexArgs {
    // First operand: a number, an object convertible via `try_complex`, or a string literal.
    #[pyarg(any, optional)]
    real: OptionalArg<PyObjectRef>,
    // Second operand: must convert via `try_complex`; rejected when `real` is a string.
    #[pyarg(any, optional)]
    imag: OptionalArg<PyObjectRef>,
}