Skip to main content

vortex_array/scalar/extension/
erased.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::type_name;
5use std::cmp::Ordering;
6use std::fmt;
7use std::hash::Hash;
8use std::hash::Hasher;
9use std::sync::Arc;
10
11use vortex_error::VortexExpect;
12use vortex_error::vortex_err;
13
14use crate::dtype::extension::ExtId;
15use crate::dtype::extension::ExtVTable;
16use crate::scalar::ScalarValue;
17use crate::scalar::extension::ExtScalarValue;
18use crate::scalar::extension::typed::DynExtScalarValue;
19use crate::scalar::extension::typed::ExtScalarValueInner;
20
21/// A type-erased extension scalar value.
22///
23/// This is the extension scalar analog of [`ExtDTypeRef`]: it stores an [`ExtVTable`]
24/// and a storage [`ScalarValue`] behind a trait object, allowing heterogeneous storage inside
25/// `ScalarValue::Extension` (so that we do not need a generic parameter).
26///
27/// You can use [`try_downcast()`] or [`downcast()`] to recover the concrete vtable type as an
28/// [`ExtScalarValue<V>`].
29///
30/// [`ExtDTypeRef`]: crate::dtype::extension::ExtDTypeRef
31/// [`try_downcast()`]: ExtScalarValueRef::try_downcast
32/// [`downcast()`]: ExtScalarValueRef::downcast
33#[derive(Clone)]
34pub struct ExtScalarValueRef(pub(super) Arc<dyn DynExtScalarValue>);
35
36// NB: If you need access to the vtable, you probably want to add a method and implementation to
37// `ExtScalarValueInnerImpl` and `ExtScalarValueInner`.
38/// Methods for downcasting type-erased extension scalars.
39impl ExtScalarValueRef {
40    /// Returns the [`ExtId`] identifying this extension scalar's type.
41    pub fn id(&self) -> ExtId {
42        self.0.id()
43    }
44
45    /// Returns a reference to the underlying storage [`ScalarValue`].
46    pub fn storage_value(&self) -> &ScalarValue {
47        self.0.storage_value()
48    }
49
50    /// Attempts to downcast to a concrete [`ExtScalarValue<V>`].
51    ///
52    /// # Errors
53    ///
54    /// Returns `Err(self)` if the underlying vtable type does not match `V`.
55    pub fn try_downcast<V: ExtVTable>(self) -> Result<ExtScalarValue<V>, ExtScalarValueRef> {
56        // `ExtScalarValueInner<V>` is the only implementor of `ExtScalarValueInnerImpl` (due to
57        // the sealed implementation below), so if the vtable is correct, we know the type can be
58        // downcasted and reinterpreted safely.
59        if !self.0.as_any().is::<ExtScalarValueInner<V>>() {
60            return Err(self);
61        }
62
63        let ptr = Arc::into_raw(self.0) as *const ExtScalarValueInner<V>;
64        // SAFETY: We verified the type matches above, so the size and alignment are correct.
65        let inner = unsafe { Arc::from_raw(ptr) };
66
67        Ok(ExtScalarValue(inner))
68    }
69
70    /// Downcasts to a concrete [`ExtScalarValue<V>`].
71    ///
72    /// # Panics
73    ///
74    /// Panics if the underlying vtable type does not match `V`.
75    pub fn downcast<V: ExtVTable>(self) -> ExtScalarValue<V> {
76        self.try_downcast::<V>()
77            .map_err(|this| {
78                vortex_err!(
79                    "Failed to downcast ExtScalar {} to {}",
80                    this.0.id(),
81                    type_name::<V>(),
82                )
83            })
84            .vortex_expect("Failed to downcast ExtScalar")
85    }
86}
87
88impl fmt::Display for ExtScalarValueRef {
89    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90        write!(f, "{}({})", self.0.id(), self.0.storage_value())
91    }
92}
93
94impl fmt::Debug for ExtScalarValueRef {
95    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96        f.debug_struct("ExtScalar")
97            .field("id", &self.0.id())
98            .field("storage_value", self.0.storage_value())
99            .finish()
100    }
101}
102
103// TODO(connor): In the future we may want to allow implementors to customize this behavior.
104
105impl PartialEq for ExtScalarValueRef {
106    fn eq(&self, other: &Self) -> bool {
107        self.0.id() == other.0.id() && self.0.storage_value() == other.0.storage_value()
108    }
109}
110impl Eq for ExtScalarValueRef {}
111
112impl PartialOrd for ExtScalarValueRef {
113    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
114        // TODO(connor): Should this check if the IDs are equal before ordering?
115        self.0.storage_value().partial_cmp(other.0.storage_value())
116    }
117}
118
119impl Hash for ExtScalarValueRef {
120    fn hash<H: Hasher>(&self, state: &mut H) {
121        self.0.id().hash(state);
122        self.0.storage_value().hash(state);
123    }
124}