vortex_expr/
scope.rs

1use std::any::{Any, TypeId};
2use std::fmt::{Debug, Display};
3use std::str::FromStr;
4use std::sync::Arc;
5
6use itertools::Itertools;
7use vortex_array::{Array, ArrayRef};
8use vortex_dtype::{DType, FieldPathSet};
9use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
10use vortex_utils::aliases::hash_map::HashMap;
11
12use crate::scope_vars::{ScopeVar, ScopeVars};
13
/// A map from [`Identifier`]s to values of type `V`.
type ExprScope<V> = HashMap<Identifier, V>;

/// The name under which a value is bound in a scope.
///
/// [`Identifier::Identity`] names the root binding; every other binding is named by
/// [`Identifier::Other`].
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Identifier {
    /// The root binding (e.g. the initial array being evaluated in a [`Scope`]).
    Identity,
    /// A binding with an explicit name.
    Other(Arc<str>),
}
21
22impl FromStr for Identifier {
23    type Err = VortexError;
24
25    fn from_str(s: &str) -> Result<Self, Self::Err> {
26        if s.is_empty() {
27            vortex_bail!("Empty strings aren't allowed in identifiers")
28        } else {
29            Ok(Identifier::Other(s.into()))
30        }
31    }
32}
33
34impl PartialEq<str> for Identifier {
35    fn eq(&self, other: &str) -> bool {
36        match self {
37            Identifier::Identity => other.is_empty(),
38            Identifier::Other(str) => str.as_ref() == other,
39        }
40    }
41}
42
43impl From<&str> for Identifier {
44    fn from(value: &str) -> Self {
45        if value.is_empty() {
46            Identifier::Identity
47        } else {
48            Identifier::Other(Arc::from(value))
49        }
50    }
51}
52
53impl Identifier {
54    pub fn is_identity(&self) -> bool {
55        matches!(self, Self::Identity)
56    }
57}
58
59impl Display for Identifier {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        match self {
62            Identifier::Identity => write!(f, ""),
63            Identifier::Other(v) => write!(f, "{v}"),
64        }
65    }
66}
67
68/// Scope define the evaluation context/scope that an expression uses when being evaluated.
69/// There is a special `Identifier` (`Identity`) which is used to bind the initial array being evaluated
70///
71/// Other identifier can be bound with variables either before execution or while executing (see `Let`).
72/// Values can be extracted from the scope using the `Var` expression.
73///
74/// ```code
75/// <let x = lit(1) in var(Identifier::Identity) + var(x), { Identity -> Primitive[1,2,3]> ->
76/// <var(Identifier::Identity) + var(x), { Identity -> Primitive[1,2,3], x -> ConstantArray(1)> ->
77/// <Primitive[1,2,3] + var(x), { Identity -> Primitive[1,2,3], x -> ConstantArray(1)> ->
78/// <Primitive[1,2,3] + ConstantArray(1), { Identity -> Primitive[1,2,3], x -> ConstantArray(1)> ->
79/// <Primitive[2,3,4], { Identity -> Primitive[1,2,3], x -> ConstantArray(1)>
80/// ```
81///
82/// Other values can be bound before execution e.g.
83///  `<var("x") + var("y") + var("z"), x -> ..., y -> ..., z -> ...>`
84#[derive(Clone, Default)]
85pub struct Scope {
86    array_len: usize,
87    root_scope: Option<ArrayRef>,
88    /// A map from identifiers to arrays
89    arrays: ExprScope<ArrayRef>,
90    /// A map identifiers to opaque values used by expressions, but
91    /// cannot affect the result type/shape.
92    vars: ExprScope<Arc<dyn Any + Send + Sync>>,
93    /// Variables that can be set on the scope during expression evaluation.
94    scope_vars: ScopeVars,
95}
96
97pub type ScopeElement = (Identifier, ArrayRef);
98
99impl Scope {
100    pub fn new(arr: ArrayRef) -> Self {
101        Self {
102            array_len: arr.len(),
103            root_scope: Some(arr),
104            ..Default::default()
105        }
106    }
107
108    pub fn empty(len: usize) -> Self {
109        Self {
110            array_len: len,
111            ..Default::default()
112        }
113    }
114
115    /// Get a value out of the scope by its [`Identifier`]
116    pub fn array(&self, id: &Identifier) -> Option<&ArrayRef> {
117        if id.is_identity() {
118            return self.root_scope.as_ref();
119        }
120        self.arrays.get(id)
121    }
122
123    pub fn vars(&self, id: Identifier) -> VortexResult<&Arc<dyn Any + Send + Sync>> {
124        self.vars
125            .get(&id)
126            .ok_or_else(|| vortex_err!("cannot find {} in var scope", id))
127    }
128
129    pub fn is_empty(&self) -> bool {
130        self.array_len == 0
131    }
132
133    pub fn len(&self) -> usize {
134        self.array_len
135    }
136
137    pub fn copy_with_array(&self, ident: Identifier, value: ArrayRef) -> Self {
138        self.clone().with_array(ident, value)
139    }
140
141    /// Register an array with an identifier in the scope, overriding any existing value stored in it.
142    pub fn with_array(mut self, ident: Identifier, value: ArrayRef) -> Self {
143        assert_eq!(value.len(), self.len());
144
145        if ident.is_identity() {
146            self.root_scope = Some(value);
147        } else {
148            self.arrays.insert(ident, value);
149        }
150        self
151    }
152
153    /// Register an array with an identifier in the scope, overriding any existing value stored in it.
154    pub fn with_array_pair(self, (ident, value): ScopeElement) -> Self {
155        self.with_array(ident, value)
156    }
157
158    pub fn with_var(mut self, ident: Identifier, var: Arc<dyn Any + Send + Sync>) -> Self {
159        self.vars.insert(ident, var);
160        self
161    }
162
163    /// Returns a new evaluation scope with the given variable applied.
164    pub fn with_scope_var<V: ScopeVar>(mut self, var: V) -> Self {
165        self.scope_vars.insert(TypeId::of::<V>(), Box::new(var));
166        self
167    }
168
169    /// Returns the scope variable of type `V` if it exists.
170    pub fn scope_var<V: ScopeVar>(&self) -> Option<&V> {
171        self.scope_vars
172            .get(&TypeId::of::<V>())
173            .and_then(|boxed| (**boxed).as_any().downcast_ref::<V>())
174    }
175
176    /// Returns the mutable scope variable of type `V` if it exists.
177    pub fn scope_var_mut<V: ScopeVar>(&mut self) -> Option<&mut V> {
178        self.scope_vars
179            .get_mut(&TypeId::of::<V>())
180            .and_then(|boxed| (**boxed).as_any_mut().downcast_mut::<V>())
181    }
182
183    pub fn iter(&self) -> impl Iterator<Item = (&Identifier, &ArrayRef)> {
184        let values = self.arrays.iter();
185
186        self.root_scope
187            .iter()
188            .map(|s| (&Identifier::Identity, s))
189            .chain(values)
190    }
191}
192
193impl From<ArrayRef> for Scope {
194    fn from(value: ArrayRef) -> Self {
195        Self::new(value)
196    }
197}
198
199#[derive(Clone, Default, Debug)]
200pub struct ScopeDType {
201    root: Option<DType>,
202    types: ExprScope<DType>,
203}
204
205impl Display for ScopeDType {
206    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207        if let Some(root) = self.root.as_ref() {
208            write!(f, "$: {}", root)?;
209        }
210        if !self.types.is_empty() {
211            write!(f, ". ")?;
212            write!(
213                f,
214                "{}",
215                self.types
216                    .iter()
217                    .format_with(",", |x, f| f(&format_args!("{}: {}", x.0, x.1)))
218            )?;
219        }
220        Ok(())
221    }
222}
223
224pub type ScopeDTypeElement = (Identifier, DType);
225
226impl From<&Scope> for ScopeDType {
227    fn from(ctx: &Scope) -> Self {
228        Self {
229            root: ctx.root_scope.as_ref().map(|s| s.dtype().clone()),
230            types: HashMap::from_iter(
231                ctx.arrays
232                    .iter()
233                    .map(|(k, v)| (k.clone(), v.dtype().clone())),
234            ),
235        }
236    }
237}
238
239impl ScopeDType {
240    pub fn new(dtype: DType) -> Self {
241        Self {
242            root: Some(dtype),
243            ..Default::default()
244        }
245    }
246
247    pub fn dtype(&self, id: &Identifier) -> Option<&DType> {
248        if id.is_identity() {
249            return self.root.as_ref();
250        }
251        self.types.get(id)
252    }
253
254    pub fn copy_with_dtype(&self, ident: Identifier, dtype: DType) -> Self {
255        self.clone().with_dtype(ident, dtype)
256    }
257
258    pub fn with_dtype(mut self, ident: Identifier, dtype: DType) -> Self {
259        if ident.is_identity() {
260            self.root = Some(dtype);
261        } else {
262            self.types.insert(ident, dtype);
263        }
264        self
265    }
266
267    pub fn with_dtype_element(self, (ident, dtype): ScopeDTypeElement) -> Self {
268        self.with_dtype(ident, dtype)
269    }
270}
271
272#[derive(Default, Clone, Debug)]
273pub struct ScopeFieldPathSet {
274    root: Option<FieldPathSet>,
275    sets: ExprScope<FieldPathSet>,
276}
277
278pub type ScopeFieldPathSetElement = (Identifier, FieldPathSet);
279
280impl ScopeFieldPathSet {
281    pub fn new(path_set: FieldPathSet) -> Self {
282        Self {
283            root: Some(path_set),
284            ..Default::default()
285        }
286    }
287
288    pub fn set(&self, id: &Identifier) -> Option<&FieldPathSet> {
289        if id.is_identity() {
290            return self.root.as_ref();
291        }
292        self.sets.get(id)
293    }
294
295    pub fn copy_with_set(&self, ident: Identifier, set: FieldPathSet) -> Self {
296        self.clone().with_set(ident, set)
297    }
298
299    pub fn with_set(mut self, ident: Identifier, set: FieldPathSet) -> Self {
300        if ident.is_identity() {
301            self.root = Some(set);
302        } else {
303            self.sets.insert(ident, set);
304        }
305        self
306    }
307
308    pub fn with_set_element(self, (ident, set): ScopeFieldPathSetElement) -> Self {
309        self.with_set(ident, set)
310    }
311}
312
#[cfg(test)]
mod test {
    use super::*;

    /// A minimal scope variable used to exercise typed lookups.
    #[derive(Clone, PartialEq, Eq, Debug)]
    struct TestVar {
        value: i32,
    }

    #[test]
    fn test_scope_var() {
        let empty = Scope::empty(100);
        assert_eq!(empty.scope_var::<TestVar>(), None);

        let original = TestVar { value: 42 };
        let mut scope = empty.with_scope_var(original.clone());
        assert_eq!(scope.scope_var::<TestVar>(), Some(&original));

        let stored = scope
            .scope_var_mut::<TestVar>()
            .expect("TestVar was just inserted");
        stored.value = 43;
        assert_eq!(scope.scope_var::<TestVar>(), Some(&TestVar { value: 43 }));
    }
}