rustpython_vm/builtins/
union.rs

1use super::{genericalias, type_};
2use crate::{
3    atomic_func,
4    builtins::{PyFrozenSet, PyStr, PyTuple, PyTupleRef, PyType},
5    class::PyClassImpl,
6    common::hash,
7    convert::{ToPyObject, ToPyResult},
8    function::PyComparisonValue,
9    protocol::{PyMappingMethods, PyNumberMethods},
10    types::{AsMapping, AsNumber, Comparable, GetAttr, Hashable, PyComparisonOp, Representable},
11    AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
12};
13use once_cell::sync::Lazy;
14use std::fmt;
15
// Attribute names that `GetAttr::getattro` for `PyUnion` resolves through
// `generic_getattr` instead of the regular `get_attr` lookup.
const CLS_ATTRS: &[&str] = &["__module__"];
17
// Rust payload backing the Python class registered as `types.UnionType`
// (see the `module` / `name` arguments of the attribute below).
// NOTE(review): written as `//` rather than `///` in case the `pyclass` macro
// turns doc comments into the Python-level `__doc__` — confirm before converting.
#[pyclass(module = "types", name = "UnionType", traverse)]
pub struct PyUnion {
    // Member objects of the union; `repr` renders them joined by " | ".
    args: PyTupleRef,
    // Derived from `args` by `make_parameters` when the union is built in `new`.
    parameters: PyTupleRef,
}
23
24impl fmt::Debug for PyUnion {
25    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
26        f.write_str("UnionObject")
27    }
28}
29
impl PyPayload for PyUnion {
    /// Returns the type object stored in `ctx.types.union_type`.
    fn class(ctx: &Context) -> &'static Py<PyType> {
        ctx.types.union_type
    }
}
35
36impl PyUnion {
37    pub fn new(args: PyTupleRef, vm: &VirtualMachine) -> Self {
38        let parameters = make_parameters(&args, vm);
39        Self { args, parameters }
40    }
41
42    fn repr(&self, vm: &VirtualMachine) -> PyResult<String> {
43        fn repr_item(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<String> {
44            if obj.is(vm.ctx.types.none_type) {
45                return Ok("None".to_string());
46            }
47
48            if vm
49                .get_attribute_opt(obj.clone(), identifier!(vm, __origin__))?
50                .is_some()
51                && vm
52                    .get_attribute_opt(obj.clone(), identifier!(vm, __args__))?
53                    .is_some()
54            {
55                return Ok(obj.repr(vm)?.as_str().to_string());
56            }
57
58            match (
59                vm.get_attribute_opt(obj.clone(), identifier!(vm, __qualname__))?
60                    .and_then(|o| o.downcast_ref::<PyStr>().map(|n| n.as_str().to_string())),
61                vm.get_attribute_opt(obj.clone(), identifier!(vm, __module__))?
62                    .and_then(|o| o.downcast_ref::<PyStr>().map(|m| m.as_str().to_string())),
63            ) {
64                (None, _) | (_, None) => Ok(obj.repr(vm)?.as_str().to_string()),
65                (Some(qualname), Some(module)) => Ok(if module == "builtins" {
66                    qualname
67                } else {
68                    format!("{module}.{qualname}")
69                }),
70            }
71        }
72
73        Ok(self
74            .args
75            .iter()
76            .map(|o| repr_item(o.clone(), vm))
77            .collect::<PyResult<Vec<_>>>()?
78            .join(" | "))
79    }
80}
81
82#[pyclass(
83    flags(BASETYPE),
84    with(Hashable, Comparable, AsMapping, AsNumber, Representable)
85)]
86impl PyUnion {
87    #[pygetset(magic)]
88    fn parameters(&self) -> PyObjectRef {
89        self.parameters.clone().into()
90    }
91
92    #[pygetset(magic)]
93    fn args(&self) -> PyObjectRef {
94        self.args.clone().into()
95    }
96
97    #[pymethod(magic)]
98    fn instancecheck(zelf: PyRef<Self>, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
99        if zelf
100            .args
101            .iter()
102            .any(|x| x.class().is(vm.ctx.types.generic_alias_type))
103        {
104            Err(vm.new_type_error(
105                "isinstance() argument 2 cannot be a parameterized generic".to_owned(),
106            ))
107        } else {
108            obj.is_instance(zelf.args().as_object(), vm)
109        }
110    }
111
112    #[pymethod(magic)]
113    fn subclasscheck(zelf: PyRef<Self>, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
114        if zelf
115            .args
116            .iter()
117            .any(|x| x.class().is(vm.ctx.types.generic_alias_type))
118        {
119            Err(vm.new_type_error(
120                "issubclass() argument 2 cannot be a parameterized generic".to_owned(),
121            ))
122        } else {
123            obj.is_subclass(zelf.args().as_object(), vm)
124        }
125    }
126
127    #[pymethod(name = "__ror__")]
128    #[pymethod(magic)]
129    fn or(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
130        type_::or_(zelf, other, vm)
131    }
132}
133
134pub fn is_unionable(obj: PyObjectRef, vm: &VirtualMachine) -> bool {
135    obj.class().is(vm.ctx.types.none_type)
136        || obj.payload_if_subclass::<PyType>(vm).is_some()
137        || obj.class().is(vm.ctx.types.generic_alias_type)
138        || obj.class().is(vm.ctx.types.union_type)
139}
140
141fn make_parameters(args: &Py<PyTuple>, vm: &VirtualMachine) -> PyTupleRef {
142    let parameters = genericalias::make_parameters(args, vm);
143    dedup_and_flatten_args(&parameters, vm)
144}
145
146fn flatten_args(args: &Py<PyTuple>, vm: &VirtualMachine) -> PyTupleRef {
147    let mut total_args = 0;
148    for arg in args {
149        if let Some(pyref) = arg.downcast_ref::<PyUnion>() {
150            total_args += pyref.args.len();
151        } else {
152            total_args += 1;
153        };
154    }
155
156    let mut flattened_args = Vec::with_capacity(total_args);
157    for arg in args {
158        if let Some(pyref) = arg.downcast_ref::<PyUnion>() {
159            flattened_args.extend(pyref.args.iter().cloned());
160        } else if vm.is_none(arg) {
161            flattened_args.push(vm.ctx.types.none_type.to_owned().into());
162        } else {
163            flattened_args.push(arg.clone());
164        };
165    }
166
167    PyTuple::new_ref(flattened_args, &vm.ctx)
168}
169
170fn dedup_and_flatten_args(args: &Py<PyTuple>, vm: &VirtualMachine) -> PyTupleRef {
171    let args = flatten_args(args, vm);
172
173    let mut new_args: Vec<PyObjectRef> = Vec::with_capacity(args.len());
174    for arg in &*args {
175        if !new_args.iter().any(|param| {
176            param
177                .rich_compare_bool(arg, PyComparisonOp::Eq, vm)
178                .expect("types are always comparable")
179        }) {
180            new_args.push(arg.clone());
181        }
182    }
183
184    new_args.shrink_to_fit();
185
186    PyTuple::new_ref(new_args, &vm.ctx)
187}
188
189pub fn make_union(args: &Py<PyTuple>, vm: &VirtualMachine) -> PyObjectRef {
190    let args = dedup_and_flatten_args(args, vm);
191    match args.len() {
192        1 => args.fast_getitem(0),
193        _ => PyUnion::new(args, vm).to_pyobject(vm),
194    }
195}
196
197impl PyUnion {
198    fn getitem(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult {
199        let new_args = genericalias::subs_parameters(
200            |vm| self.repr(vm),
201            self.args.clone(),
202            self.parameters.clone(),
203            needle,
204            vm,
205        )?;
206        let mut res;
207        if new_args.len() == 0 {
208            res = make_union(&new_args, vm);
209        } else {
210            res = new_args.fast_getitem(0);
211            for arg in new_args.iter().skip(1) {
212                res = vm._or(&res, arg)?;
213            }
214        }
215
216        Ok(res)
217    }
218}
219
impl AsMapping for PyUnion {
    /// Returns the mapping method table for unions.
    ///
    /// Only `subscript` is provided (forwarding to `PyUnion::getitem`); every
    /// other slot is taken from `PyMappingMethods::NOT_IMPLEMENTED`. The table
    /// is built on first use through `Lazy`.
    fn as_mapping() -> &'static PyMappingMethods {
        static AS_MAPPING: Lazy<PyMappingMethods> = Lazy::new(|| PyMappingMethods {
            subscript: atomic_func!(|mapping, needle, vm| {
                PyUnion::mapping_downcast(mapping).getitem(needle.to_owned(), vm)
            }),
            ..PyMappingMethods::NOT_IMPLEMENTED
        });
        &AS_MAPPING
    }
}
231
impl AsNumber for PyUnion {
    /// Returns the number method table for unions.
    ///
    /// Only the `or` slot is filled, forwarding both operands to `PyUnion::or`;
    /// every other slot is taken from `PyNumberMethods::NOT_IMPLEMENTED`.
    fn as_number() -> &'static PyNumberMethods {
        static AS_NUMBER: PyNumberMethods = PyNumberMethods {
            or: Some(|a, b, vm| PyUnion::or(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)),
            ..PyNumberMethods::NOT_IMPLEMENTED
        };
        &AS_NUMBER
    }
}
241
242impl Comparable for PyUnion {
243    fn cmp(
244        zelf: &Py<Self>,
245        other: &PyObject,
246        op: PyComparisonOp,
247        vm: &VirtualMachine,
248    ) -> PyResult<PyComparisonValue> {
249        op.eq_only(|| {
250            let other = class_or_notimplemented!(Self, other);
251            let a = PyFrozenSet::from_iter(vm, zelf.args.into_iter().cloned())?;
252            let b = PyFrozenSet::from_iter(vm, other.args.into_iter().cloned())?;
253            Ok(PyComparisonValue::Implemented(
254                a.into_pyobject(vm).as_object().rich_compare_bool(
255                    b.into_pyobject(vm).as_object(),
256                    PyComparisonOp::Eq,
257                    vm,
258                )?,
259            ))
260        })
261    }
262}
263
264impl Hashable for PyUnion {
265    #[inline]
266    fn hash(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<hash::PyHash> {
267        let set = PyFrozenSet::from_iter(vm, zelf.args.into_iter().cloned())?;
268        PyFrozenSet::hash(&set.into_ref(&vm.ctx), vm)
269    }
270}
271
272impl GetAttr for PyUnion {
273    fn getattro(zelf: &Py<Self>, attr: &Py<PyStr>, vm: &VirtualMachine) -> PyResult {
274        for &exc in CLS_ATTRS {
275            if *exc == attr.to_string() {
276                return zelf.as_object().generic_getattr(attr, vm);
277            }
278        }
279        zelf.as_object().get_attr(attr, vm)
280    }
281}
282
impl Representable for PyUnion {
    /// Produces the repr string by delegating to `PyUnion::repr`.
    #[inline]
    fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
        zelf.repr(vm)
    }
}
289
290pub fn init(context: &Context) {
291    let union_type = &context.types.union_type;
292    PyUnion::extend_class(context, union_type);
293}