Skip to main content

vortex_array/expr/
scalar_fn.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::fmt::Debug;
6use std::fmt::Display;
7use std::fmt::Formatter;
8use std::hash::Hash;
9use std::hash::Hasher;
10use std::ops::Deref;
11
12use vortex_dtype::DType;
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_utils::debug_with::DebugWith;
16
17use crate::ArrayRef;
18use crate::expr::EmptyOptions;
19use crate::expr::ExecutionArgs;
20use crate::expr::ExprId;
21use crate::expr::ExprVTable;
22use crate::expr::Expression;
23use crate::expr::IsNull;
24use crate::expr::Not;
25use crate::expr::ReduceCtx;
26use crate::expr::ReduceNode;
27use crate::expr::ReduceNodeRef;
28use crate::expr::VTable;
29use crate::expr::VTableExt;
30use crate::expr::options::ExpressionOptions;
31use crate::expr::signature::ExpressionSignature;
32
33/// An instance of an expression bound to some invocation options.
34pub struct ScalarFn {
35    vtable: ExprVTable,
36    options: Box<dyn Any + Send + Sync>,
37}
38
39impl ScalarFn {
40    /// Create a new bound expression from raw vtable and options.
41    ///
42    /// # Safety
43    ///
44    /// The caller must ensure that the provided options are compatible with the provided vtable.
45    pub(super) unsafe fn new_unchecked(
46        vtable: ExprVTable,
47        options: Box<dyn Any + Send + Sync>,
48    ) -> Self {
49        Self { vtable, options }
50    }
51
52    /// Create a new bound expression from a vtable.
53    pub fn new<V: VTable>(vtable: V, options: V::Options) -> Self {
54        let vtable = ExprVTable::new::<V>(vtable);
55        let options = Box::new(options);
56        Self { vtable, options }
57    }
58
59    /// Create a new expression from a static vtable.
60    pub fn new_static<V: VTable>(vtable: &'static V, options: V::Options) -> Self {
61        let vtable = ExprVTable::new_static::<V>(vtable);
62        let options = Box::new(options);
63        Self { vtable, options }
64    }
65
66    /// The vtable for this expression.
67    pub fn vtable(&self) -> &ExprVTable {
68        &self.vtable
69    }
70
71    /// Returns the ID of this expression.
72    pub fn id(&self) -> ExprId {
73        self.vtable.id()
74    }
75
76    /// The type-erased options for this expression.
77    pub fn options(&self) -> ExpressionOptions<'_> {
78        ExpressionOptions {
79            vtable: &self.vtable,
80            options: self.options.deref(),
81        }
82    }
83
84    /// Returns whether the scalar function is of the given vtable type.
85    pub fn is<V: VTable>(&self) -> bool {
86        self.vtable.is::<V>()
87    }
88
89    /// Returns the typed options for this `ScalarFn` if it matches the given vtable type.
90    pub fn as_opt<V: VTable>(&self) -> Option<&V::Options> {
91        self.vtable.is::<V>().then(|| {
92            self.options()
93                .as_any()
94                .downcast_ref::<V::Options>()
95                .vortex_expect("Expression options type mismatch")
96        })
97    }
98
99    /// Returns the typed options for this `ScalarFn` if it matches the given vtable type.
100    pub fn as_<V: VTable>(&self) -> &V::Options {
101        self.as_opt::<V>()
102            .vortex_expect("Expression options type mismatch")
103    }
104    /// Signature information for this expression.
105    pub fn signature(&self) -> ExpressionSignature<'_> {
106        ExpressionSignature {
107            vtable: &self.vtable,
108            options: self.options.deref(),
109        }
110    }
111
112    /// Compute the return [`DType`] of this expression given the input argument types.
113    pub fn return_dtype(&self, arg_types: &[DType]) -> VortexResult<DType> {
114        self.vtable
115            .as_dyn()
116            .return_dtype(self.options.deref(), arg_types)
117    }
118
119    /// Transforms the expression into one representing the validity of this expression.
120    pub fn validity(&self, expr: &Expression) -> VortexResult<Expression> {
121        Ok(self.vtable.as_dyn().validity(expr)?.unwrap_or_else(|| {
122            // TODO(ngates): make validity a mandatory method on VTable to avoid this fallback.
123            // TODO(ngates): add an IsNotNull expression.
124            Not.new_expr(
125                EmptyOptions,
126                [IsNull.new_expr(EmptyOptions, [expr.clone()])],
127            )
128        }))
129    }
130
131    /// Execute the expression given the input arguments.
132    pub fn execute(&self, ctx: ExecutionArgs) -> VortexResult<ArrayRef> {
133        self.vtable.as_dyn().execute(self.options.deref(), ctx)
134    }
135
136    /// Perform abstract reduction on this scalar function node.
137    pub fn reduce(
138        &self,
139        node: &dyn ReduceNode,
140        ctx: &dyn ReduceCtx,
141    ) -> VortexResult<Option<ReduceNodeRef>> {
142        self.vtable.as_dyn().reduce(self.options.deref(), node, ctx)
143    }
144}
145
146impl Clone for ScalarFn {
147    fn clone(&self) -> Self {
148        ScalarFn {
149            vtable: self.vtable.clone(),
150            options: self.vtable.as_dyn().options_clone(self.options.deref()),
151        }
152    }
153}
154
155impl Debug for ScalarFn {
156    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
157        f.debug_struct("BoundExpression")
158            .field("vtable", &self.vtable)
159            .field(
160                "options",
161                &DebugWith(|fmt| {
162                    self.vtable
163                        .as_dyn()
164                        .options_debug(self.options.deref(), fmt)
165                }),
166            )
167            .finish()
168    }
169}
170
171impl Display for ScalarFn {
172    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
173        write!(f, "{}(", self.vtable.id())?;
174        self.vtable
175            .as_dyn()
176            .options_display(self.options.deref(), f)?;
177        write!(f, ")")
178    }
179}
180
181impl PartialEq for ScalarFn {
182    fn eq(&self, other: &Self) -> bool {
183        self.vtable == other.vtable
184            && self
185                .vtable
186                .as_dyn()
187                .options_eq(self.options.deref(), other.options.deref())
188    }
189}
190impl Eq for ScalarFn {}
191
192impl Hash for ScalarFn {
193    fn hash<H: Hasher>(&self, state: &mut H) {
194        self.vtable.hash(state);
195        self.vtable
196            .as_dyn()
197            .options_hash(self.options.deref(), state);
198    }
199}