1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
use crate::algebraic_value::de::{ValueDeserializeError, ValueDeserializer};
use crate::algebraic_value::ser::value_serialize;
use crate::meta_type::MetaType;
use crate::{de::Deserialize, ser::Serialize};
use crate::{AlgebraicType, AlgebraicValue, SumTypeVariant};

/// A structural sum type.
///
/// Unlike most languages, sums in SATS are *[structural]* and not nominal.
/// When checking whether two nominal types are the same,
/// their names and/or declaration sites (e.g., module / namespace) are considered.
/// Meanwhile, a structural type system would only check the structure of the type itself,
/// e.g., the names of its variants and their inner data types in the case of a sum.
///
/// A sum type is also known as a discriminated union (implementation) or disjoint union.
/// Another name is [coproduct (category theory)](https://ncatlab.org/nlab/show/coproduct).
///
/// These structures are known as sum types because the number of possible values of a sum
/// ```ignore
/// { N_0(T_0), N_1(T_1), ..., N_n(T_n) }
/// ```
/// is:
/// ```ignore
/// Σ (i ∈ 0..=n). values(T_i)
/// ```
/// so for example, `values({ A(U64), B(Bool) }) = values(U64) + values(Bool)`.
///
/// See also: <https://ncatlab.org/nlab/show/sum+type>.
///
/// [structural]: https://en.wikipedia.org/wiki/Structural_type_system
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
#[sats(crate = crate)]
pub struct SumType {
    /// The possible variants of the sum type.
    ///
    /// The order is relevant as it defines the tags of the variants at runtime:
    /// a variant's tag is its position in this slice.
    pub variants: Box<[SumTypeVariant]>,
}

impl SumType {
    /// Returns a sum type with these possible `variants`.
    ///
    /// The order of `variants` is preserved; a variant's position is its tag.
    pub const fn new(variants: Box<[SumTypeVariant]>) -> Self {
        Self { variants }
    }

    /// Returns a sum type of unnamed variants taken from `types`.
    ///
    /// Each type is converted with `SumTypeVariant::from`, in the order given.
    pub fn new_unnamed(types: Box<[AlgebraicType]>) -> Self {
        let variants = Vec::from(types).into_iter().map(SumTypeVariant::from).collect();
        Self { variants }
    }

    /// Returns the payload type `T` if this sum type looks like an option type,
    /// and `None` otherwise.
    ///
    /// An option type has `some(T)` as its first variant and `none` as its second.
    /// That is, `{ some(T), none }` or `some: T | none` depending on your notation.
    /// The `none` variant must additionally be a unit variant.
    pub fn as_option(&self) -> Option<&AlgebraicType> {
        match &*self.variants {
            [first, second]
                if second.is_unit() // Done first to avoid pointer indirection when it doesn't matter.
                    && first.has_name("some")
                    && second.has_name("none") =>
            {
                Some(&first.algebraic_type)
            }
            _ => None,
        }
    }

    /// Returns whether this sum type is like one in C, without data attached to the variants.
    ///
    /// That is, every variant is a unit variant. A sum type without variants
    /// trivially satisfies this.
    pub fn is_simple_enum(&self) -> bool {
        self.variants.iter().all(SumTypeVariant::is_unit)
    }

    /// Returns the first sum type variant named `tag_name`, together with its tag position.
    ///
    /// Returns `None` when no variant has that name, or when the matching variant
    /// sits at a position that does not fit in a `u8` tag.
    pub fn get_variant(&self, tag_name: &str) -> Option<(u8, &SumTypeVariant)> {
        // Tags are `u8`s, so pair the variants with `0..=u8::MAX` rather than casting
        // an `enumerate` index: a cast would silently wrap for positions >= 256
        // and hand back a tag that refers to a different variant.
        (0..=u8::MAX)
            .zip(self.variants.iter())
            .find(|(_, variant)| variant.name.as_deref() == Some(tag_name))
    }

    /// Returns the sum type variant named `tag_name` with its tag position,
    /// but only if this is a [Self::is_simple_enum]; otherwise returns `None`.
    pub fn get_variant_simple(&self, tag_name: &str) -> Option<(u8, &SumTypeVariant)> {
        if self.is_simple_enum() {
            self.get_variant(tag_name)
        } else {
            None
        }
    }
}

impl From<Box<[SumTypeVariant]>> for SumType {
    fn from(fields: Box<[SumTypeVariant]>) -> Self {
        SumType::new(fields)
    }
}
impl<const N: usize> From<[SumTypeVariant; N]> for SumType {
    fn from(fields: [SumTypeVariant; N]) -> Self {
        SumType::new(fields.into())
    }
}
impl<const N: usize> From<[(Option<&str>, AlgebraicType); N]> for SumType {
    /// Builds a sum type from `(optional name, type)` pairs,
    /// producing one variant per pair in the order given.
    fn from(fields: [(Option<&str>, AlgebraicType); N]) -> Self {
        let variants = fields.map(|(name, ty)| SumTypeVariant::new(ty, name.map(Into::into)));
        Self::from(variants)
    }
}
impl<const N: usize> From<[(&str, AlgebraicType); N]> for SumType {
    /// Builds a sum type from `(name, type)` pairs,
    /// producing one named variant per pair in the order given.
    fn from(fields: [(&str, AlgebraicType); N]) -> Self {
        let variants = fields.map(|(name, ty)| SumTypeVariant::new_named(ty, name));
        Self::from(variants)
    }
}
impl<const N: usize> From<[AlgebraicType; N]> for SumType {
    fn from(fields: [AlgebraicType; N]) -> Self {
        fields.map(SumTypeVariant::from).into()
    }
}

impl MetaType for SumType {
    /// Returns the type describing a `SumType` itself:
    /// built with `AlgebraicType::product` from a single `variants` element
    /// whose type is `AlgebraicType::array` of the [`SumTypeVariant`] meta type.
    fn meta_type() -> AlgebraicType {
        let variant_ty = SumTypeVariant::meta_type();
        let variants_field = ("variants", AlgebraicType::array(variant_ty));
        AlgebraicType::product([variants_field])
    }
}

impl SumType {
    /// Serializes this sum type into an [`AlgebraicValue`] using `value_serialize`.
    pub fn as_value(&self) -> AlgebraicValue {
        let this = self;
        value_serialize(this)
    }

    /// Deserializes a `SumType` from `value`, reading it by reference.
    ///
    /// # Errors
    ///
    /// Returns a [`ValueDeserializeError`] when deserialization fails,
    /// presumably because `value` does not have the shape of a sum type.
    pub fn from_value(value: &AlgebraicValue) -> Result<SumType, ValueDeserializeError> {
        let deserializer = ValueDeserializer::from_ref(value);
        Self::deserialize(deserializer)
    }
}