strict_encoding/
types.rs

1// Strict encoding library for deterministic binary serialization.
2//
3// SPDX-License-Identifier: Apache-2.0
4//
5// Written in 2019-2024 by
6//     Dr. Maxim Orlovsky <orlovsky@ubideco.org>
7//
8// Copyright 2022-2024 UBIDECO Labs
9//
10// Licensed under the Apache License, Version 2.0 (the "License");
11// you may not use this file except in compliance with the License.
12// You may obtain a copy of the License at
13//
14//     http://www.apache.org/licenses/LICENSE-2.0
15//
16// Unless required by applicable law or agreed to in writing, software
17// distributed under the License is distributed on an "AS IS" BASIS,
18// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19// See the License for the specific language governing permissions and
20// limitations under the License.
21
22use std::any;
23use std::collections::BTreeSet;
24use std::fmt::{Debug, Display};
25use std::marker::PhantomData;
26
27use crate::{LibName, TypeName, VariantName};
28
/// Builds a compact identifier for `T` from [`std::any::type_name`].
///
/// The steps are:
/// 1. Remove every `&`.
/// 2. Cut the result at the delimiters `,` `<` `>` `(` `)`.
/// 3. Trim each non-empty piece and reduce it to its last `::`-separated segment.
/// 4. Concatenate the pieces without a separator.
///
/// For instance, `Option<Vec<String>>` becomes `"OptionVecString"`.
///
/// The exact text of `std::any::type_name` is not guaranteed to be stable between compiler
/// releases, so neither is the returned string.
pub fn type_name<T>() -> String {
    let full = any::type_name::<T>().replace('&', "");
    full.split([',', '<', '>', '(', ')'])
        .map(str::trim)
        .filter(|segment| !segment.is_empty())
        .map(|segment| match segment.rsplit_once("::") {
            Some((_, last)) => last.trim(),
            None => segment,
        })
        .collect::<String>()
}
45
46#[derive(Clone, Eq, PartialEq, Debug, Display, Error)]
47#[display("unexpected variant {1} for enum or union {0:?}")]
48pub struct VariantError<V: Debug + Display>(pub Option<String>, pub V);
49
50impl<V: Debug + Display> VariantError<V> {
51    pub fn with<T>(val: V) -> Self { VariantError(Some(type_name::<T>()), val) }
52    pub fn typed(name: impl Into<String>, val: V) -> Self { VariantError(Some(name.into()), val) }
53    pub fn untyped(val: V) -> Self { VariantError(None, val) }
54}
55
56pub trait StrictDumb: Sized {
57    fn strict_dumb() -> Self;
58}
59
60impl<T> StrictDumb for T
61where T: StrictType + Default
62{
63    fn strict_dumb() -> T { T::default() }
64}
65
66pub trait StrictType: Sized {
67    const STRICT_LIB_NAME: &'static str;
68    fn strict_name() -> Option<TypeName> { Some(tn!(type_name::<Self>())) }
69}
70
71impl<T: StrictType> StrictType for &T {
72    const STRICT_LIB_NAME: &'static str = T::STRICT_LIB_NAME;
73}
74
75impl<T> StrictType for PhantomData<T> {
76    const STRICT_LIB_NAME: &'static str = "";
77}
78
79pub trait StrictProduct: StrictType + StrictDumb {}
80
81pub trait StrictTuple: StrictProduct {
82    const FIELD_COUNT: u8;
83    fn strict_check_fields() {
84        let name = Self::strict_name().unwrap_or_else(|| tn!("__unnamed"));
85        assert_ne!(
86            Self::FIELD_COUNT,
87            0,
88            "tuple type {} does not contain a single field defined",
89            name
90        );
91    }
92
93    fn strict_type_info() -> TypeInfo<Self> {
94        Self::strict_check_fields();
95        TypeInfo {
96            lib: libname!(Self::STRICT_LIB_NAME),
97            name: Self::strict_name().map(|name| tn!(name)),
98            cls: TypeClass::Tuple(Self::FIELD_COUNT),
99            dumb: Self::strict_dumb(),
100        }
101    }
102}
103
104pub trait StrictStruct: StrictProduct {
105    const ALL_FIELDS: &'static [&'static str];
106
107    fn strict_check_fields() {
108        let name = Self::strict_name().unwrap_or_else(|| tn!("__unnamed"));
109        assert!(
110            !Self::ALL_FIELDS.is_empty(),
111            "struct type {} does not contain a single field defined",
112            name
113        );
114        let names: BTreeSet<_> = Self::ALL_FIELDS.iter().copied().collect();
115        assert_eq!(
116            names.len(),
117            Self::ALL_FIELDS.len(),
118            "struct type {} contains repeated field names",
119            name
120        );
121    }
122
123    fn strict_type_info() -> TypeInfo<Self> {
124        Self::strict_check_fields();
125        TypeInfo {
126            lib: libname!(Self::STRICT_LIB_NAME),
127            name: Self::strict_name().map(|name| tn!(name)),
128            cls: TypeClass::Struct(Self::ALL_FIELDS),
129            dumb: Self::strict_dumb(),
130        }
131    }
132}
133
134pub trait StrictSum: StrictType {
135    const ALL_VARIANTS: &'static [(u8, &'static str)];
136
137    fn strict_check_variants() {
138        let name = Self::strict_name().unwrap_or_else(|| tn!("__unnamed"));
139        assert!(
140            !Self::ALL_VARIANTS.is_empty(),
141            "type {} does not contain a single variant defined",
142            name
143        );
144        let (ords, names): (BTreeSet<_>, BTreeSet<_>) = Self::ALL_VARIANTS.iter().copied().unzip();
145        assert_eq!(
146            ords.len(),
147            Self::ALL_VARIANTS.len(),
148            "type {} contains repeated variant ids",
149            name
150        );
151        assert_eq!(
152            names.len(),
153            Self::ALL_VARIANTS.len(),
154            "type {} contains repeated variant names",
155            name
156        );
157    }
158
159    fn variant_name_by_tag(tag: u8) -> Option<VariantName> {
160        Self::ALL_VARIANTS
161            .iter()
162            .find(|(n, _)| *n == tag)
163            .map(|(_, variant_name)| vname!(*variant_name))
164    }
165
166    fn variant_ord(&self) -> u8 {
167        let variant = self.variant_name();
168        for (tag, name) in Self::ALL_VARIANTS {
169            if *name == variant {
170                return *tag;
171            }
172        }
173        unreachable!(
174            "not all variants are enumerated for {} enum in StrictUnion::all_variants \
175             implementation",
176            type_name::<Self>()
177        )
178    }
179    fn variant_name(&self) -> &'static str;
180}
181
182pub trait StrictUnion: StrictSum + StrictDumb {
183    fn strict_type_info() -> TypeInfo<Self> {
184        Self::strict_check_variants();
185        TypeInfo {
186            lib: libname!(Self::STRICT_LIB_NAME),
187            name: Self::strict_name().map(|name| tn!(name)),
188            cls: TypeClass::Union(Self::ALL_VARIANTS),
189            dumb: Self::strict_dumb(),
190        }
191    }
192}
193
194pub trait StrictEnum
195where
196    Self: StrictSum + Copy + TryFrom<u8, Error = VariantError<u8>>,
197    u8: From<Self>,
198{
199    fn from_variant_name(name: &VariantName) -> Result<Self, VariantError<&VariantName>> {
200        for (tag, n) in Self::ALL_VARIANTS {
201            if *n == name.as_str() {
202                return Self::try_from(*tag).map_err(|_| VariantError::with::<Self>(name));
203            }
204        }
205        Err(VariantError::with::<Self>(name))
206    }
207
208    fn strict_type_info() -> TypeInfo<Self> {
209        Self::strict_check_variants();
210        TypeInfo {
211            lib: libname!(Self::STRICT_LIB_NAME),
212            name: Self::strict_name().map(|name| tn!(name)),
213            cls: TypeClass::Enum(Self::ALL_VARIANTS),
214            dumb: Self::try_from(Self::ALL_VARIANTS[0].0)
215                .expect("first variant contains invalid value"),
216        }
217    }
218}
219
/// Kind of a strict type, together with the per-kind data recorded in [`TypeInfo::cls`].
///
/// The enum derives the common traits so descriptions can be copied, compared, hashed and
/// debug-printed. Every payload is a `'static` slice or a `u8`, so all of these derives are
/// available.
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum TypeClass {
    /// Not produced by the `strict_type_info` methods of this module.
    /// NOTE(review): presumably for types with a built-in encoding — confirm.
    Embedded,
    /// Enum variants as `(tag, name)` pairs.
    Enum(&'static [(u8, &'static str)]),
    /// Union variants as `(tag, name)` pairs.
    Union(&'static [(u8, &'static str)]),
    /// Number of positional tuple fields.
    Tuple(u8),
    /// Names of struct fields.
    Struct(&'static [&'static str]),
}
227
228pub struct TypeInfo<T: StrictType> {
229    pub lib: LibName,
230    pub name: Option<TypeName>,
231    pub cls: TypeClass,
232    pub dumb: T,
233}
234
#[cfg(test)]
mod test {
    use amplify::confinement::TinyVec;

    use super::*;

    /// `Option<TinyVec<u8>>` is expected to report no strict type name.
    #[test]
    fn name_derivation() {
        let derived = Option::<TinyVec<u8>>::strict_name();
        assert_eq!(derived, None);
    }
}