vortex_array/expr/
scalar_fn.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::fmt::Debug;
6use std::fmt::Display;
7use std::fmt::Formatter;
8use std::hash::Hash;
9use std::hash::Hasher;
10use std::ops::Deref;
11
12use vortex_dtype::DType;
13use vortex_error::VortexResult;
14use vortex_utils::debug_with::DebugWith;
15use vortex_vector::Datum;
16
17use crate::ArrayRef;
18use crate::expr::ExecutionArgs;
19use crate::expr::ExprId;
20use crate::expr::ExprVTable;
21use crate::expr::Expression;
22use crate::expr::ReduceCtx;
23use crate::expr::ReduceNode;
24use crate::expr::ReduceNodeRef;
25use crate::expr::VTable;
26use crate::expr::options::ExpressionOptions;
27use crate::expr::signature::ExpressionSignature;
28
29/// An instance of an expression bound to some invocation options.
30pub struct ScalarFn {
31    vtable: ExprVTable,
32    options: Box<dyn Any + Send + Sync>,
33}
34
35impl ScalarFn {
36    /// Create a new bound expression from raw vtable and options.
37    ///
38    /// # Safety
39    ///
40    /// The caller must ensure that the provided options are compatible with the provided vtable.
41    pub(super) unsafe fn new_unchecked(
42        vtable: ExprVTable,
43        options: Box<dyn Any + Send + Sync>,
44    ) -> Self {
45        Self { vtable, options }
46    }
47
48    /// Create a new bound expression from a vtable.
49    pub fn new<V: VTable>(vtable: V, options: V::Options) -> Self {
50        let vtable = ExprVTable::new::<V>(vtable);
51        let options = Box::new(options);
52        Self { vtable, options }
53    }
54
55    /// Create a new expression from a static vtable.
56    pub fn new_static<V: VTable>(vtable: &'static V, options: V::Options) -> Self {
57        let vtable = ExprVTable::new_static::<V>(vtable);
58        let options = Box::new(options);
59        Self { vtable, options }
60    }
61
62    /// The vtable for this expression.
63    pub fn vtable(&self) -> &ExprVTable {
64        &self.vtable
65    }
66
67    /// Returns the ID of this expression.
68    pub fn id(&self) -> ExprId {
69        self.vtable.id()
70    }
71
72    /// The type-erased options for this expression.
73    pub fn options(&self) -> ExpressionOptions<'_> {
74        ExpressionOptions {
75            vtable: &self.vtable,
76            options: self.options.deref(),
77        }
78    }
79
80    /// Returns whether the scalar function is of the given vtable type.
81    pub fn is<V: VTable>(&self) -> bool {
82        self.vtable.is::<V>()
83    }
84
85    /// Returns the typed options for this expression if it matches the given vtable type.
86    pub fn as_opt<V: VTable>(&self) -> Option<&V::Options> {
87        self.options().as_any().downcast_ref::<V::Options>()
88    }
89
90    /// Signature information for this expression.
91    pub fn signature(&self) -> ExpressionSignature<'_> {
92        ExpressionSignature {
93            vtable: &self.vtable,
94            options: self.options.deref(),
95        }
96    }
97
98    /// Compute the return [`DType`] of this expression given the input argument types.
99    pub fn return_dtype(&self, arg_types: &[DType]) -> VortexResult<DType> {
100        self.vtable
101            .as_dyn()
102            .return_dtype(self.options.deref(), arg_types)
103    }
104
105    /// Evaluate the expression, returning an ArrayRef.
106    ///
107    /// NOTE: this function will soon be deprecated as all expressions will evaluate trivially
108    ///  into an ExprArray.
109    pub fn evaluate(&self, expr: &Expression, scope: &ArrayRef) -> VortexResult<ArrayRef> {
110        self.vtable.as_dyn().evaluate(expr, scope)
111    }
112
113    /// Execute the expression given the input arguments.
114    pub fn execute(&self, ctx: ExecutionArgs) -> VortexResult<Datum> {
115        self.vtable.as_dyn().execute(self.options.deref(), ctx)
116    }
117
118    /// Perform abstract reduction on this scalar function node.
119    pub fn reduce(
120        &self,
121        node: &dyn ReduceNode,
122        ctx: &dyn ReduceCtx,
123    ) -> VortexResult<Option<ReduceNodeRef>> {
124        self.vtable.as_dyn().reduce(self.options.deref(), node, ctx)
125    }
126}
127
128impl Clone for ScalarFn {
129    fn clone(&self) -> Self {
130        ScalarFn {
131            vtable: self.vtable.clone(),
132            options: self.vtable.as_dyn().options_clone(self.options.deref()),
133        }
134    }
135}
136
137impl Debug for ScalarFn {
138    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
139        f.debug_struct("BoundExpression")
140            .field("vtable", &self.vtable)
141            .field(
142                "options",
143                &DebugWith(|fmt| {
144                    self.vtable
145                        .as_dyn()
146                        .options_debug(self.options.deref(), fmt)
147                }),
148            )
149            .finish()
150    }
151}
152
153impl Display for ScalarFn {
154    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
155        write!(f, "{}(", self.vtable.id())?;
156        self.vtable
157            .as_dyn()
158            .options_display(self.options.deref(), f)?;
159        write!(f, ")")
160    }
161}
162
163impl PartialEq for ScalarFn {
164    fn eq(&self, other: &Self) -> bool {
165        self.vtable == other.vtable
166            && self
167                .vtable
168                .as_dyn()
169                .options_eq(self.options.deref(), other.options.deref())
170    }
171}
172impl Eq for ScalarFn {}
173
174impl Hash for ScalarFn {
175    fn hash<H: Hasher>(&self, state: &mut H) {
176        self.vtable.hash(state);
177        self.vtable
178            .as_dyn()
179            .options_hash(self.options.deref(), state);
180    }
181}