1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
use substrait::proto::{
    expression::{field_reference::ReferenceType, Literal, RexType},
    Expression, Type,
};

use crate::{
    error::{Result, SubstraitExprError},
    util::HasRequiredPropertiesRef,
};

use super::{
    literals::{LiteralExt, LiteralInference},
    schema::SchemaInfo,
};

/// Convenience accessors layered on top of the protobuf [`Expression`] message.
///
/// Every method is fallible and reports problems through the crate's [`Result`] type
/// rather than panicking on a malformed message.
pub trait ExpressionExt {
    /// Borrows the expression's `rex_type`, a property Substrait requires to be set.
    ///
    /// Turns the underlying `Option<&RexType>` into a `Result<&RexType>`, producing a
    /// "required property missing" style error when the property is absent.
    ///
    /// TODO: Should this be public?
    fn try_rex_type(&self) -> Result<&RexType>;

    /// Borrows the expression as a Substrait [`Literal`].
    ///
    /// Fails when `rex_type` is missing or when the expression is not a literal.
    fn try_as_literal(&self) -> Result<&Literal>;

    /// Decodes the expression into a native Rust value of type `T`.
    ///
    /// The expression must be a literal whose `literal_type` is populated. The actual
    /// conversion is delegated to `T`'s [`LiteralInference`] implementation.
    fn try_as_rust_literal<T: LiteralInference>(&self) -> Result<T>;

    /// Computes the Substrait [`Type`] of the value this expression produces.
    ///
    /// `schema` supplies the input types needed to resolve field references:
    /// - literals report their own data type;
    /// - scalar functions report their declared `output_type`;
    /// - direct references rooted at the input record are looked up in `schema`.
    ///
    /// NOTE(review): not every expression kind is handled yet. Callers should not rely
    /// on this succeeding for arbitrary expressions.
    fn output_type(&self, schema: &SchemaInfo) -> Result<Type>;
}

impl ExpressionExt for Expression {
    /// Borrows `rex_type`, erroring (rather than panicking) when it is absent.
    fn try_rex_type(&self) -> Result<&RexType> {
        self.rex_type.as_ref().ok_or_else(|| {
            SubstraitExprError::invalid_substrait(
                "The required property rex_type was missing from an expression",
            )
        })
    }

    /// Unwraps the literal, checks that its `literal_type` is set, and then lets `T`
    /// perform the conversion.
    fn try_as_rust_literal<T: LiteralInference>(&self) -> Result<T> {
        let literal = self.try_as_literal()?;
        let literal_type = literal.literal_type.as_ref().ok_or_else(|| {
            SubstraitExprError::invalid_substrait(
                "The required property literal_type was missing from a literal",
            )
        })?;
        T::try_from_substrait(literal_type)
    }

    /// Succeeds only when `rex_type` is present and is a literal.
    fn try_as_literal(&self) -> Result<&Literal> {
        match self.try_rex_type()? {
            RexType::Literal(literal) => Ok(literal),
            _ => Err(SubstraitExprError::invalid_substrait(
                "Expected a literal but received something else",
            )),
        }
    }

    /// Derives the output type from the expression's `rex_type`.
    ///
    /// Expression kinds that are not handled yet produce an error instead of panicking,
    /// so callers can recover from perfectly valid but unsupported plans.
    fn output_type(&self, schema: &SchemaInfo) -> Result<Type> {
        use substrait::proto::expression::field_reference::RootType;

        match self.try_rex_type()? {
            RexType::Literal(literal) => literal.data_type(),
            RexType::ScalarFunction(func) => func.output_type.required("output_type").cloned(),
            RexType::Selection(selection) => {
                match selection.root_type.as_ref().required("root_type")? {
                    RootType::RootReference(_) => {
                        match selection.reference_type.as_ref().required("reference_type")? {
                            // Only direct references into the input record are resolved
                            // against the schema.
                            ReferenceType::DirectReference(root_segment) => {
                                schema.resolve_type(root_segment)
                            }
                            ReferenceType::MaskedReference(_) => {
                                Err(SubstraitExprError::invalid_substrait(
                                    "A root reference did not have a reference type of direct reference",
                                ))
                            }
                        }
                    }
                    // TODO: these are valid Substrait; a dedicated "unsupported" error
                    // variant would describe them better than invalid_substrait.
                    RootType::Expression(_) => Err(SubstraitExprError::invalid_substrait(
                        "output_type is not yet supported for field references rooted in an expression",
                    )),
                    RootType::OuterReference(_) => Err(SubstraitExprError::invalid_substrait(
                        "output_type is not yet supported for outer references",
                    )),
                }
            }
            _ => Err(SubstraitExprError::invalid_substrait(
                "output_type is not yet supported for this kind of expression",
            )),
        }
    }
}