1use std::any::{Any, TypeId};
2use std::fmt::{Debug, Display};
3use std::str::FromStr;
4use std::sync::Arc;
5
6use itertools::Itertools;
7use vortex_array::{Array, ArrayRef};
8use vortex_dtype::{DType, FieldPathSet};
9use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
10use vortex_utils::aliases::hash_map::HashMap;
11
12use crate::scope_vars::{ScopeVar, ScopeVars};
13
/// Mapping from an [`Identifier`] to a value of type `T`.
///
/// The scope types in this module keep their named (non-identity) bindings in this map and hold
/// the identity/root binding in a separate field.
type ExprScope<T> = HashMap<Identifier, T>;
15
/// Name of a binding inside a scope.
///
/// `From<&str>` maps the empty string to [`Identifier::Identity`], and `Identity` displays as the
/// empty string. `FromStr` is stricter: it rejects empty input with an error.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum Identifier {
    /// The root (identity) binding of a scope.
    Identity,
    /// Any other, explicitly named binding.
    Other(Arc<str>),
}
21
22impl FromStr for Identifier {
23 type Err = VortexError;
24
25 fn from_str(s: &str) -> Result<Self, Self::Err> {
26 if s.is_empty() {
27 vortex_bail!("Empty strings aren't allowed in identifiers")
28 } else {
29 Ok(Identifier::Other(s.into()))
30 }
31 }
32}
33
34impl PartialEq<str> for Identifier {
35 fn eq(&self, other: &str) -> bool {
36 match self {
37 Identifier::Identity => other.is_empty(),
38 Identifier::Other(str) => str.as_ref() == other,
39 }
40 }
41}
42
43impl From<&str> for Identifier {
44 fn from(value: &str) -> Self {
45 if value.is_empty() {
46 Identifier::Identity
47 } else {
48 Identifier::Other(Arc::from(value))
49 }
50 }
51}
52
53impl Identifier {
54 pub fn is_identity(&self) -> bool {
55 matches!(self, Self::Identity)
56 }
57}
58
59impl Display for Identifier {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match self {
62 Identifier::Identity => write!(f, ""),
63 Identifier::Other(v) => write!(f, "{v}"),
64 }
65 }
66}
67
/// A set of arrays and variables bound to [`Identifier`]s.
///
/// Every array bound in the scope must have length [`Scope::len`]; `with_array` asserts this.
/// The derived `Default` produces a scope of length 0 with nothing bound.
#[derive(Clone, Default)]
pub struct Scope {
    /// Length that every bound array must have (checked in `with_array`).
    array_len: usize,
    /// Array bound to [`Identifier::Identity`], if any.
    root_scope: Option<ArrayRef>,
    /// Arrays bound to non-identity identifiers.
    arrays: ExprScope<ArrayRef>,
    /// Type-erased values bound via `with_var` and read back via `vars`.
    vars: ExprScope<Arc<dyn Any + Send + Sync>>,
    /// Typed scope variables, keyed by `TypeId::of::<V>()` of the stored type.
    scope_vars: ScopeVars,
}
96
/// An `(identifier, array)` pair, as accepted by `Scope::with_array_pair`.
pub type ScopeElement = (Identifier, ArrayRef);
98
99impl Scope {
100 pub fn new(arr: ArrayRef) -> Self {
101 Self {
102 array_len: arr.len(),
103 root_scope: Some(arr),
104 ..Default::default()
105 }
106 }
107
108 pub fn empty(len: usize) -> Self {
109 Self {
110 array_len: len,
111 ..Default::default()
112 }
113 }
114
115 pub fn array(&self, id: &Identifier) -> Option<&ArrayRef> {
117 if id.is_identity() {
118 return self.root_scope.as_ref();
119 }
120 self.arrays.get(id)
121 }
122
123 pub fn vars(&self, id: Identifier) -> VortexResult<&Arc<dyn Any + Send + Sync>> {
124 self.vars
125 .get(&id)
126 .ok_or_else(|| vortex_err!("cannot find {} in var scope", id))
127 }
128
129 pub fn is_empty(&self) -> bool {
130 self.array_len == 0
131 }
132
133 pub fn len(&self) -> usize {
134 self.array_len
135 }
136
137 pub fn copy_with_array(&self, ident: Identifier, value: ArrayRef) -> Self {
138 self.clone().with_array(ident, value)
139 }
140
141 pub fn with_array(mut self, ident: Identifier, value: ArrayRef) -> Self {
143 assert_eq!(value.len(), self.len());
144
145 if ident.is_identity() {
146 self.root_scope = Some(value);
147 } else {
148 self.arrays.insert(ident, value);
149 }
150 self
151 }
152
153 pub fn with_array_pair(self, (ident, value): ScopeElement) -> Self {
155 self.with_array(ident, value)
156 }
157
158 pub fn with_var(mut self, ident: Identifier, var: Arc<dyn Any + Send + Sync>) -> Self {
159 self.vars.insert(ident, var);
160 self
161 }
162
163 pub fn with_scope_var<V: ScopeVar>(mut self, var: V) -> Self {
165 self.scope_vars.insert(TypeId::of::<V>(), Box::new(var));
166 self
167 }
168
169 pub fn scope_var<V: ScopeVar>(&self) -> Option<&V> {
171 self.scope_vars
172 .get(&TypeId::of::<V>())
173 .and_then(|boxed| (**boxed).as_any().downcast_ref::<V>())
174 }
175
176 pub fn scope_var_mut<V: ScopeVar>(&mut self) -> Option<&mut V> {
178 self.scope_vars
179 .get_mut(&TypeId::of::<V>())
180 .and_then(|boxed| (**boxed).as_any_mut().downcast_mut::<V>())
181 }
182
183 pub fn iter(&self) -> impl Iterator<Item = (&Identifier, &ArrayRef)> {
184 let values = self.arrays.iter();
185
186 self.root_scope
187 .iter()
188 .map(|s| (&Identifier::Identity, s))
189 .chain(values)
190 }
191}
192
193impl From<ArrayRef> for Scope {
194 fn from(value: ArrayRef) -> Self {
195 Self::new(value)
196 }
197}
198
/// The dtype-level counterpart of [`Scope`].
///
/// Records the [`DType`] of the root binding and of each named binding, without holding any
/// array data. It can be built from a [`Scope`] via `From<&Scope>`.
#[derive(Clone, Default, Debug)]
pub struct ScopeDType {
    /// DType bound to [`Identifier::Identity`], if any.
    root: Option<DType>,
    /// DTypes bound to non-identity identifiers.
    types: ExprScope<DType>,
}
204
205impl Display for ScopeDType {
206 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207 if let Some(root) = self.root.as_ref() {
208 write!(f, "$: {}", root)?;
209 }
210 if !self.types.is_empty() {
211 write!(f, ". ")?;
212 write!(
213 f,
214 "{}",
215 self.types
216 .iter()
217 .format_with(",", |x, f| f(&format_args!("{}: {}", x.0, x.1)))
218 )?;
219 }
220 Ok(())
221 }
222}
223
/// An `(identifier, dtype)` pair, as accepted by `ScopeDType::with_dtype_element`.
pub type ScopeDTypeElement = (Identifier, DType);
225
226impl From<&Scope> for ScopeDType {
227 fn from(ctx: &Scope) -> Self {
228 Self {
229 root: ctx.root_scope.as_ref().map(|s| s.dtype().clone()),
230 types: HashMap::from_iter(
231 ctx.arrays
232 .iter()
233 .map(|(k, v)| (k.clone(), v.dtype().clone())),
234 ),
235 }
236 }
237}
238
239impl ScopeDType {
240 pub fn new(dtype: DType) -> Self {
241 Self {
242 root: Some(dtype),
243 ..Default::default()
244 }
245 }
246
247 pub fn dtype(&self, id: &Identifier) -> Option<&DType> {
248 if id.is_identity() {
249 return self.root.as_ref();
250 }
251 self.types.get(id)
252 }
253
254 pub fn copy_with_dtype(&self, ident: Identifier, dtype: DType) -> Self {
255 self.clone().with_dtype(ident, dtype)
256 }
257
258 pub fn with_dtype(mut self, ident: Identifier, dtype: DType) -> Self {
259 if ident.is_identity() {
260 self.root = Some(dtype);
261 } else {
262 self.types.insert(ident, dtype);
263 }
264 self
265 }
266
267 pub fn with_dtype_element(self, (ident, dtype): ScopeDTypeElement) -> Self {
268 self.with_dtype(ident, dtype)
269 }
270}
271
/// Associates a [`FieldPathSet`] with the root binding and with each named binding.
///
/// Uses the same root/named layout as [`ScopeDType`].
#[derive(Default, Clone, Debug)]
pub struct ScopeFieldPathSet {
    /// Field path set bound to [`Identifier::Identity`], if any.
    root: Option<FieldPathSet>,
    /// Field path sets bound to non-identity identifiers.
    sets: ExprScope<FieldPathSet>,
}
277
/// An `(identifier, field path set)` pair, as accepted by `ScopeFieldPathSet::with_set_element`.
pub type ScopeFieldPathSetElement = (Identifier, FieldPathSet);
279
280impl ScopeFieldPathSet {
281 pub fn new(path_set: FieldPathSet) -> Self {
282 Self {
283 root: Some(path_set),
284 ..Default::default()
285 }
286 }
287
288 pub fn set(&self, id: &Identifier) -> Option<&FieldPathSet> {
289 if id.is_identity() {
290 return self.root.as_ref();
291 }
292 self.sets.get(id)
293 }
294
295 pub fn copy_with_set(&self, ident: Identifier, set: FieldPathSet) -> Self {
296 self.clone().with_set(ident, set)
297 }
298
299 pub fn with_set(mut self, ident: Identifier, set: FieldPathSet) -> Self {
300 if ident.is_identity() {
301 self.root = Some(set);
302 } else {
303 self.sets.insert(ident, set);
304 }
305 self
306 }
307
308 pub fn with_set_element(self, (ident, set): ScopeFieldPathSetElement) -> Self {
309 self.with_set(ident, set)
310 }
311}
312
#[cfg(test)]
mod test {
    use super::*;

    /// Minimal scope variable used to exercise typed lookups on [`Scope`].
    #[derive(Clone, PartialEq, Eq, Debug)]
    struct TestVar {
        value: i32,
    }

    #[test]
    fn test_scope_var() {
        // Nothing of this type has been stored yet.
        let scope = Scope::empty(100);
        assert!(scope.scope_var::<TestVar>().is_none());

        // A stored value can be retrieved by its type.
        let expected = TestVar { value: 42 };
        let mut scope = scope.with_scope_var(expected.clone());
        assert_eq!(scope.scope_var::<TestVar>(), Some(&expected));

        // Mutations through the mutable accessor are visible to later lookups.
        let stored = scope.scope_var_mut::<TestVar>().unwrap();
        stored.value = 43;
        assert_eq!(scope.scope_var::<TestVar>(), Some(&TestVar { value: 43 }));
    }
}