1use super::{genericalias, type_};
2use crate::{
3 atomic_func,
4 builtins::{PyFrozenSet, PyStr, PyTuple, PyTupleRef, PyType},
5 class::PyClassImpl,
6 common::hash,
7 convert::{ToPyObject, ToPyResult},
8 function::PyComparisonValue,
9 protocol::{PyMappingMethods, PyNumberMethods},
10 types::{AsMapping, AsNumber, Comparable, GetAttr, Hashable, PyComparisonOp, Representable},
11 AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
12};
13use once_cell::sync::Lazy;
14use std::fmt;
15
/// Attribute names that `GetAttr::getattro` for `PyUnion` handles specially
/// instead of sending them through the general fallback lookup path.
const CLS_ATTRS: &[&str] = &["__module__"];
17
// Payload of Python's `types.UnionType` objects.
// NOTE: plain `//` comments are used here on purpose; doc comments on a
// `#[pyclass]` item may be picked up as the Python-visible `__doc__`.
#[pyclass(module = "types", name = "UnionType", traverse)]
pub struct PyUnion {
    // Member types of the union, exposed as `__args__`. When the union is built
    // through `make_union` these have already been flattened and de-duplicated.
    args: PyTupleRef,
    // Exposed as `__parameters__`; computed from `args` by `make_parameters`
    // in `PyUnion::new`.
    parameters: PyTupleRef,
}
23
24impl fmt::Debug for PyUnion {
25 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
26 f.write_str("UnionObject")
27 }
28}
29
30impl PyPayload for PyUnion {
31 fn class(ctx: &Context) -> &'static Py<PyType> {
32 ctx.types.union_type
33 }
34}
35
36impl PyUnion {
37 pub fn new(args: PyTupleRef, vm: &VirtualMachine) -> Self {
38 let parameters = make_parameters(&args, vm);
39 Self { args, parameters }
40 }
41
42 fn repr(&self, vm: &VirtualMachine) -> PyResult<String> {
43 fn repr_item(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<String> {
44 if obj.is(vm.ctx.types.none_type) {
45 return Ok("None".to_string());
46 }
47
48 if vm
49 .get_attribute_opt(obj.clone(), identifier!(vm, __origin__))?
50 .is_some()
51 && vm
52 .get_attribute_opt(obj.clone(), identifier!(vm, __args__))?
53 .is_some()
54 {
55 return Ok(obj.repr(vm)?.as_str().to_string());
56 }
57
58 match (
59 vm.get_attribute_opt(obj.clone(), identifier!(vm, __qualname__))?
60 .and_then(|o| o.downcast_ref::<PyStr>().map(|n| n.as_str().to_string())),
61 vm.get_attribute_opt(obj.clone(), identifier!(vm, __module__))?
62 .and_then(|o| o.downcast_ref::<PyStr>().map(|m| m.as_str().to_string())),
63 ) {
64 (None, _) | (_, None) => Ok(obj.repr(vm)?.as_str().to_string()),
65 (Some(qualname), Some(module)) => Ok(if module == "builtins" {
66 qualname
67 } else {
68 format!("{module}.{qualname}")
69 }),
70 }
71 }
72
73 Ok(self
74 .args
75 .iter()
76 .map(|o| repr_item(o.clone(), vm))
77 .collect::<PyResult<Vec<_>>>()?
78 .join(" | "))
79 }
80}
81
82#[pyclass(
83 flags(BASETYPE),
84 with(Hashable, Comparable, AsMapping, AsNumber, Representable)
85)]
86impl PyUnion {
87 #[pygetset(magic)]
88 fn parameters(&self) -> PyObjectRef {
89 self.parameters.clone().into()
90 }
91
92 #[pygetset(magic)]
93 fn args(&self) -> PyObjectRef {
94 self.args.clone().into()
95 }
96
97 #[pymethod(magic)]
98 fn instancecheck(zelf: PyRef<Self>, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
99 if zelf
100 .args
101 .iter()
102 .any(|x| x.class().is(vm.ctx.types.generic_alias_type))
103 {
104 Err(vm.new_type_error(
105 "isinstance() argument 2 cannot be a parameterized generic".to_owned(),
106 ))
107 } else {
108 obj.is_instance(zelf.args().as_object(), vm)
109 }
110 }
111
112 #[pymethod(magic)]
113 fn subclasscheck(zelf: PyRef<Self>, obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
114 if zelf
115 .args
116 .iter()
117 .any(|x| x.class().is(vm.ctx.types.generic_alias_type))
118 {
119 Err(vm.new_type_error(
120 "issubclass() argument 2 cannot be a parameterized generic".to_owned(),
121 ))
122 } else {
123 obj.is_subclass(zelf.args().as_object(), vm)
124 }
125 }
126
127 #[pymethod(name = "__ror__")]
128 #[pymethod(magic)]
129 fn or(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef {
130 type_::or_(zelf, other, vm)
131 }
132}
133
134pub fn is_unionable(obj: PyObjectRef, vm: &VirtualMachine) -> bool {
135 obj.class().is(vm.ctx.types.none_type)
136 || obj.payload_if_subclass::<PyType>(vm).is_some()
137 || obj.class().is(vm.ctx.types.generic_alias_type)
138 || obj.class().is(vm.ctx.types.union_type)
139}
140
141fn make_parameters(args: &Py<PyTuple>, vm: &VirtualMachine) -> PyTupleRef {
142 let parameters = genericalias::make_parameters(args, vm);
143 dedup_and_flatten_args(¶meters, vm)
144}
145
146fn flatten_args(args: &Py<PyTuple>, vm: &VirtualMachine) -> PyTupleRef {
147 let mut total_args = 0;
148 for arg in args {
149 if let Some(pyref) = arg.downcast_ref::<PyUnion>() {
150 total_args += pyref.args.len();
151 } else {
152 total_args += 1;
153 };
154 }
155
156 let mut flattened_args = Vec::with_capacity(total_args);
157 for arg in args {
158 if let Some(pyref) = arg.downcast_ref::<PyUnion>() {
159 flattened_args.extend(pyref.args.iter().cloned());
160 } else if vm.is_none(arg) {
161 flattened_args.push(vm.ctx.types.none_type.to_owned().into());
162 } else {
163 flattened_args.push(arg.clone());
164 };
165 }
166
167 PyTuple::new_ref(flattened_args, &vm.ctx)
168}
169
170fn dedup_and_flatten_args(args: &Py<PyTuple>, vm: &VirtualMachine) -> PyTupleRef {
171 let args = flatten_args(args, vm);
172
173 let mut new_args: Vec<PyObjectRef> = Vec::with_capacity(args.len());
174 for arg in &*args {
175 if !new_args.iter().any(|param| {
176 param
177 .rich_compare_bool(arg, PyComparisonOp::Eq, vm)
178 .expect("types are always comparable")
179 }) {
180 new_args.push(arg.clone());
181 }
182 }
183
184 new_args.shrink_to_fit();
185
186 PyTuple::new_ref(new_args, &vm.ctx)
187}
188
189pub fn make_union(args: &Py<PyTuple>, vm: &VirtualMachine) -> PyObjectRef {
190 let args = dedup_and_flatten_args(args, vm);
191 match args.len() {
192 1 => args.fast_getitem(0),
193 _ => PyUnion::new(args, vm).to_pyobject(vm),
194 }
195}
196
197impl PyUnion {
198 fn getitem(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult {
199 let new_args = genericalias::subs_parameters(
200 |vm| self.repr(vm),
201 self.args.clone(),
202 self.parameters.clone(),
203 needle,
204 vm,
205 )?;
206 let mut res;
207 if new_args.len() == 0 {
208 res = make_union(&new_args, vm);
209 } else {
210 res = new_args.fast_getitem(0);
211 for arg in new_args.iter().skip(1) {
212 res = vm._or(&res, arg)?;
213 }
214 }
215
216 Ok(res)
217 }
218}
219
220impl AsMapping for PyUnion {
221 fn as_mapping() -> &'static PyMappingMethods {
222 static AS_MAPPING: Lazy<PyMappingMethods> = Lazy::new(|| PyMappingMethods {
223 subscript: atomic_func!(|mapping, needle, vm| {
224 PyUnion::mapping_downcast(mapping).getitem(needle.to_owned(), vm)
225 }),
226 ..PyMappingMethods::NOT_IMPLEMENTED
227 });
228 &AS_MAPPING
229 }
230}
231
232impl AsNumber for PyUnion {
233 fn as_number() -> &'static PyNumberMethods {
234 static AS_NUMBER: PyNumberMethods = PyNumberMethods {
235 or: Some(|a, b, vm| PyUnion::or(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)),
236 ..PyNumberMethods::NOT_IMPLEMENTED
237 };
238 &AS_NUMBER
239 }
240}
241
242impl Comparable for PyUnion {
243 fn cmp(
244 zelf: &Py<Self>,
245 other: &PyObject,
246 op: PyComparisonOp,
247 vm: &VirtualMachine,
248 ) -> PyResult<PyComparisonValue> {
249 op.eq_only(|| {
250 let other = class_or_notimplemented!(Self, other);
251 let a = PyFrozenSet::from_iter(vm, zelf.args.into_iter().cloned())?;
252 let b = PyFrozenSet::from_iter(vm, other.args.into_iter().cloned())?;
253 Ok(PyComparisonValue::Implemented(
254 a.into_pyobject(vm).as_object().rich_compare_bool(
255 b.into_pyobject(vm).as_object(),
256 PyComparisonOp::Eq,
257 vm,
258 )?,
259 ))
260 })
261 }
262}
263
264impl Hashable for PyUnion {
265 #[inline]
266 fn hash(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<hash::PyHash> {
267 let set = PyFrozenSet::from_iter(vm, zelf.args.into_iter().cloned())?;
268 PyFrozenSet::hash(&set.into_ref(&vm.ctx), vm)
269 }
270}
271
272impl GetAttr for PyUnion {
273 fn getattro(zelf: &Py<Self>, attr: &Py<PyStr>, vm: &VirtualMachine) -> PyResult {
274 for &exc in CLS_ATTRS {
275 if *exc == attr.to_string() {
276 return zelf.as_object().generic_getattr(attr, vm);
277 }
278 }
279 zelf.as_object().get_attr(attr, vm)
280 }
281}
282
283impl Representable for PyUnion {
284 #[inline]
285 fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
286 zelf.repr(vm)
287 }
288}
289
290pub fn init(context: &Context) {
291 let union_type = &context.types.union_type;
292 PyUnion::extend_class(context, union_type);
293}