1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
use crate::algebraic_value::de::{ValueDeserializeError, ValueDeserializer};
use crate::algebraic_value::ser::value_serialize;
use crate::meta_type::MetaType;
use crate::{de::Deserialize, ser::Serialize};
use crate::{AlgebraicType, AlgebraicValue, SumTypeVariant};
/// A structural sum type.
///
/// Unlike most languages, sums in SATS are *[structural]* and not nominal.
/// When checking whether two nominal types are the same,
/// their names and/or declaration sites (e.g., module / namespace) are considered.
/// Meanwhile, a structural type system would only check the structure of the type itself,
/// e.g., the names of its variants and their inner data types in the case of a sum.
///
/// This is also known as a discriminated union (implementation) or disjoint union.
/// Another name is [coproduct (category theory)](https://ncatlab.org/nlab/show/coproduct).
///
/// These structures are known as sum types because the number of possible values a sum
/// ```ignore
/// { N_0(T_0), N_1(T_1), ..., N_n(T_n) }
/// ```
/// can take is the sum of the number of values of each variant's type:
/// ```ignore
/// Σ (i ∈ 0..=n). values(T_i)
/// ```
/// so for example, `values({ A(U64), B(Bool) }) = values(U64) + values(Bool)`.
///
/// See also: <https://ncatlab.org/nlab/show/sum+type>.
///
/// [structural]: https://en.wikipedia.org/wiki/Structural_type_system
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
#[sats(crate = crate)]
pub struct SumType {
/// The possible variants of the sum type.
///
/// The order is relevant as it defines the tags of the variants at runtime:
/// a variant's tag is its position in this slice.
pub variants: Box<[SumTypeVariant]>,
}
impl SumType {
    /// Returns a sum type with these possible `variants`.
    ///
    /// The variants are stored in the order given.
    pub const fn new(variants: Box<[SumTypeVariant]>) -> Self {
        Self { variants }
    }

    /// Returns a sum type of unnamed variants taken from `types`.
    ///
    /// Every type is converted into a [`SumTypeVariant`] through its `Into` conversion,
    /// preserving the order of `types`.
    pub fn new_unnamed(types: Box<[AlgebraicType]>) -> Self {
        let variants = Vec::from(types).into_iter().map(|ty| ty.into()).collect();
        Self { variants }
    }

    /// Returns the type `T` of the `some` variant if this sum type looks like an option type,
    /// and `None` otherwise.
    ///
    /// An option type has `some(T)` as its first variant and `none` as its second.
    /// That is, `{ some(T), none }` or `some: T | none` depending on your notation.
    /// There must be exactly two variants, and the second must satisfy
    /// [`SumTypeVariant::is_unit`].
    pub fn as_option(&self) -> Option<&AlgebraicType> {
        match &*self.variants {
            [first, second]
                if second.is_unit() // Done first to avoid pointer indirection when it doesn't matter.
                    && first.has_name("some")
                    && second.has_name("none") =>
            {
                Some(&first.algebraic_type)
            }
            _ => None,
        }
    }

    /// Returns whether this sum type is like an enum in C,
    /// i.e., every variant satisfies [`SumTypeVariant::is_unit`] and so carries no data.
    ///
    /// A sum type without any variants is vacuously a simple enum.
    pub fn is_simple_enum(&self) -> bool {
        self.variants.iter().all(SumTypeVariant::is_unit)
    }

    /// Returns the first variant named `tag_name` together with its tag (its position).
    ///
    /// Returns `None` when no variant has that name.
    /// Tags are `u8`, so only the first `u8::MAX + 1` variants are considered;
    /// a variant at a later position cannot be given a valid tag and is reported as `None`
    /// rather than with a wrapped-around tag.
    pub fn get_variant(&self, tag_name: &str) -> Option<(u8, &SumTypeVariant)> {
        // Pairing with `0..=u8::MAX` yields each tag as a `u8` directly,
        // so there is no lossy `usize -> u8` cast.
        (0..=u8::MAX)
            .zip(self.variants.iter())
            .find(|(_, variant)| variant.name.as_deref() == Some(tag_name))
    }

    /// Returns the variant named `tag_name` together with its tag (its position),
    /// but only if this sum type is a simple enum according to [`Self::is_simple_enum`].
    ///
    /// Returns `None` if the sum type is not a simple enum or has no variant with that name.
    pub fn get_variant_simple(&self, tag_name: &str) -> Option<(u8, &SumTypeVariant)> {
        if self.is_simple_enum() {
            self.get_variant(tag_name)
        } else {
            None
        }
    }
}
impl From<Box<[SumTypeVariant]>> for SumType {
fn from(fields: Box<[SumTypeVariant]>) -> Self {
SumType::new(fields)
}
}
impl<const N: usize> From<[SumTypeVariant; N]> for SumType {
fn from(fields: [SumTypeVariant; N]) -> Self {
SumType::new(fields.into())
}
}
/// Builds a [`SumType`] from `(optional name, type)` pairs, one variant per pair, in order.
impl<const N: usize> From<[(Option<&str>, AlgebraicType); N]> for SumType {
    fn from(fields: [(Option<&str>, AlgebraicType); N]) -> Self {
        // Turn each pair into a variant, converting the borrowed name (if any) into an owned one.
        let variants = fields.map(|(name, ty)| {
            let owned_name = name.map(Into::into);
            SumTypeVariant::new(ty, owned_name)
        });
        Self::from(variants)
    }
}
/// Builds a [`SumType`] from `(name, type)` pairs, one named variant per pair, in order.
impl<const N: usize> From<[(&str, AlgebraicType); N]> for SumType {
    fn from(fields: [(&str, AlgebraicType); N]) -> Self {
        let variants = fields.map(|(name, ty)| SumTypeVariant::new_named(ty, name));
        Self::from(variants)
    }
}
/// Builds a [`SumType`] from bare types, converting each into a variant, in order.
impl<const N: usize> From<[AlgebraicType; N]> for SumType {
    fn from(fields: [AlgebraicType; N]) -> Self {
        let variants = fields.map(SumTypeVariant::from);
        Self::from(variants)
    }
}
impl MetaType for SumType {
    /// Returns the type describing a [`SumType`] itself:
    /// a product with a single `variants` field holding an array of the variant meta type.
    fn meta_type() -> AlgebraicType {
        let variants_ty = AlgebraicType::array(SumTypeVariant::meta_type());
        AlgebraicType::product([("variants", variants_ty)])
    }
}
impl SumType {
    /// Converts this sum type into an [`AlgebraicValue`] by passing it to `value_serialize`.
    pub fn as_value(&self) -> AlgebraicValue {
        value_serialize(self)
    }

    /// Reads a [`SumType`] back out of `value` by deserializing through a
    /// [`ValueDeserializer`] that borrows it.
    ///
    /// Returns a [`ValueDeserializeError`] if deserialization fails.
    pub fn from_value(value: &AlgebraicValue) -> Result<SumType, ValueDeserializeError> {
        let deserializer = ValueDeserializer::from_ref(value);
        Self::deserialize(deserializer)
    }
}