Skip to main content

vantage_expressions/
any_expression.rs

//! Type-erased expression wrapper
//!
//! `AnyExpression` provides a way to store expressions of different types uniformly
//! while preserving the ability to recover the concrete type through downcasting.
5
6use std::any::{Any, TypeId};
7
/// Trait for expression types that can be type-erased
///
/// This trait is automatically implemented for any type that is
/// `Clone + Send + Sync + 'static`, which matches the requirements for
/// `TableSource::Expr`.
///
/// # Method-name overlap with `Any`
///
/// [`ExpressionLike::type_id`] has the same name as [`Any::type_id`]. With both traits
/// in scope, a plain `value.type_id()` call may be ambiguous or may resolve to
/// `Any::type_id` (for example on a `Box<dyn ExpressionLike>`, where it would describe
/// the box rather than its contents). Prefer fully-qualified syntax such as
/// `ExpressionLike::type_id(&*boxed)` when the distinction matters.
pub trait ExpressionLike: Send + Sync {
    /// Returns an independent copy of this expression behind a fresh box.
    fn clone_box(&self) -> Box<dyn ExpressionLike>;

    /// Exposes the value as `&dyn Any` so callers can attempt a downcast.
    fn as_any(&self) -> &dyn Any;

    /// Returns the `TypeId` of the concrete type.
    fn type_id(&self) -> TypeId;

    /// Returns the name of the concrete type; intended for debugging output.
    fn type_name(&self) -> &'static str;
}
25
26impl<T> ExpressionLike for T
27where
28    T: Clone + Send + Sync + 'static,
29{
30    fn clone_box(&self) -> Box<dyn ExpressionLike> {
31        Box::new(self.clone())
32    }
33
34    fn as_any(&self) -> &dyn Any {
35        self
36    }
37
38    fn type_id(&self) -> TypeId {
39        TypeId::of::<T>()
40    }
41
42    fn type_name(&self) -> &'static str {
43        std::any::type_name::<T>()
44    }
45}
46
47/// Type-erased expression that can be downcast to concrete expression types
48pub struct AnyExpression {
49    inner: Box<dyn ExpressionLike>,
50    type_id: TypeId,
51    type_name: &'static str,
52}
53
54impl AnyExpression {
55    /// Create a new AnyExpression from a concrete expression type
56    pub fn new<T: Clone + Send + Sync + 'static>(expr: T) -> Self {
57        Self {
58            inner: Box::new(expr),
59            type_id: TypeId::of::<T>(),
60            type_name: std::any::type_name::<T>(),
61        }
62    }
63
64    /// Attempt to downcast to a concrete expression type
65    pub fn downcast<T: Clone + 'static>(self) -> Result<T, Self> {
66        if self.type_id != TypeId::of::<T>() {
67            return Err(self);
68        }
69
70        let any = self.inner.as_any();
71        match any.downcast_ref::<T>() {
72            Some(expr) => Ok(expr.clone()),
73            None => Err(self),
74        }
75    }
76
77    /// Get a reference to the expression as a specific type
78    pub fn downcast_ref<T: 'static>(&self) -> Option<&T> {
79        if self.type_id != TypeId::of::<T>() {
80            return None;
81        }
82        self.inner.as_any().downcast_ref::<T>()
83    }
84
85    /// Check if this expression matches the given type
86    pub fn is_type<T: 'static>(&self) -> bool {
87        self.type_id == TypeId::of::<T>()
88    }
89
90    /// Get the expression type name for debugging
91    pub fn type_name(&self) -> &str {
92        self.type_name
93    }
94
95    /// Get the TypeId
96    pub fn type_id(&self) -> TypeId {
97        self.type_id
98    }
99
100    /// Convert to a boxed Any for advanced use cases
101    pub fn into_any(self) -> Box<dyn Any> {
102        // We need to extract the inner value through Any
103        // This is a bit tricky since we can't directly convert ExpressionLike to Any
104        Box::new(self)
105    }
106}
107
108impl Clone for AnyExpression {
109    fn clone(&self) -> Self {
110        Self {
111            inner: self.inner.clone_box(),
112            type_id: self.type_id,
113            type_name: self.type_name,
114        }
115    }
116}
117
118impl std::fmt::Debug for AnyExpression {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        f.debug_struct("AnyExpression")
121            .field("type_name", &self.type_name())
122            .finish()
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal expression stand-in carrying an integer payload.
    #[derive(Debug, Clone, PartialEq)]
    struct TestExpr {
        value: i32,
    }

    /// Second, unrelated expression type used for mismatch cases.
    #[derive(Debug, Clone, PartialEq)]
    struct OtherExpr {
        text: String,
    }

    #[test]
    fn test_any_expression_creation_and_downcast() {
        let expr = TestExpr { value: 42 };
        let any = AnyExpression::new(expr.clone());

        assert_eq!(any.type_name(), std::any::type_name::<TestExpr>());
        assert!(any.is_type::<TestExpr>());

        // Successful downcast
        let recovered = any.downcast::<TestExpr>().unwrap();
        assert_eq!(recovered, expr);
    }

    #[test]
    fn test_any_expression_downcast_ref() {
        let expr = TestExpr { value: 42 };
        let any = AnyExpression::new(expr.clone());

        // Successful downcast_ref
        let expr_ref = any.downcast_ref::<TestExpr>().unwrap();
        assert_eq!(expr_ref, &expr);

        // Can still use any after downcast_ref
        assert!(any.is_type::<TestExpr>());
    }

    #[test]
    fn test_any_expression_downcast_ref_wrong_type() {
        let any = AnyExpression::new(OtherExpr {
            text: "abc".to_string(),
        });

        // A mismatched type yields None and leaves the wrapper usable.
        assert!(any.downcast_ref::<TestExpr>().is_none());
        assert_eq!(
            any.downcast_ref::<OtherExpr>().map(|e| e.text.as_str()),
            Some("abc")
        );
    }

    #[test]
    fn test_any_expression_downcast_wrong_type() {
        let expr = TestExpr { value: 42 };
        let any = AnyExpression::new(expr);

        // Try to downcast to wrong type
        let result = any.downcast::<OtherExpr>();
        assert!(result.is_err());
    }

    #[test]
    fn test_any_expression_failed_downcast_returns_original() {
        let expr = TestExpr { value: 7 };
        let any = AnyExpression::new(expr.clone());

        // The Err variant hands back the wrapper, which must still hold the value.
        let returned = any.downcast::<OtherExpr>().unwrap_err();
        assert!(returned.is_type::<TestExpr>());
        assert_eq!(returned.downcast::<TestExpr>().unwrap(), expr);
    }

    #[test]
    fn test_any_expression_is_type() {
        let expr = TestExpr { value: 42 };
        let any = AnyExpression::new(expr);

        assert!(any.is_type::<TestExpr>());
        assert!(!any.is_type::<OtherExpr>());
    }

    #[test]
    fn test_any_expression_clone() {
        let expr = TestExpr { value: 42 };
        let any = AnyExpression::new(expr.clone());
        let cloned = any.clone();

        assert_eq!(cloned.type_name(), any.type_name());
        assert_eq!(cloned.type_id(), any.type_id());

        // Both should downcast successfully
        let recovered1 = any.downcast::<TestExpr>().unwrap();
        let recovered2 = cloned.downcast::<TestExpr>().unwrap();
        assert_eq!(recovered1, recovered2);
    }

    #[test]
    fn test_any_expression_into_any_boxes_wrapper() {
        let expr = TestExpr { value: 3 };
        let boxed = AnyExpression::new(expr.clone()).into_any();

        // The box contains the wrapper, not the bare expression.
        assert!(boxed.downcast_ref::<TestExpr>().is_none());
        let wrapper = boxed.downcast::<AnyExpression>().unwrap();
        assert_eq!(wrapper.downcast_ref::<TestExpr>(), Some(&expr));
    }

    #[test]
    fn test_any_expression_debug() {
        let expr = TestExpr { value: 42 };
        let any = AnyExpression::new(expr);

        let debug_str = format!("{:?}", any);
        assert!(debug_str.contains("AnyExpression"));
        assert!(debug_str.contains("type_name"));
    }
}