strict_encoding/
traits.rs

// Strict encoding library for deterministic binary serialization.
//
// SPDX-License-Identifier: Apache-2.0
//
// Written in 2019-2024 by
//     Dr. Maxim Orlovsky <orlovsky@ubideco.org>
//
// Copyright 2022-2024 UBIDECO Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
21
22use std::io::{BufRead, Seek};
23use std::marker::PhantomData;
24use std::{fs, io};
25
26use amplify::confinement::{Collection, Confined};
27use amplify::num::u24;
28use amplify::Wrapper;
29
30use super::{DecodeError, DecodeRawLe, VariantName};
31use crate::reader::StreamReader;
32use crate::writer::StreamWriter;
33use crate::{
34    DeserializeError, FieldName, Primitive, SerializeError, Sizing, StrictDumb, StrictEnum,
35    StrictReader, StrictStruct, StrictSum, StrictTuple, StrictType, StrictUnion, StrictWriter,
36};
37
/// Marker for types that nested definers and writers hand control back to.
///
/// The `Parent` associated type of the tuple and struct definers/writers is
/// bounded by this trait, and their `complete` methods return it.
pub trait TypedParent
where
    Self: Sized,
{
}
39
40pub trait WriteRaw {
41    fn write_raw<const MAX_LEN: usize>(&mut self, bytes: impl AsRef<[u8]>) -> io::Result<()>;
42    fn write_raw_array<const LEN: usize>(&mut self, raw: [u8; LEN]) -> io::Result<()> {
43        self.write_raw::<LEN>(raw)
44    }
45    fn write_raw_len<const MAX_LEN: usize>(&mut self, len: usize) -> io::Result<()> {
46        match MAX_LEN {
47            tiny if tiny <= u8::MAX as usize => self.write_raw_array((len as u8).to_le_bytes()),
48            small if small <= u16::MAX as usize => self.write_raw_array((len as u16).to_le_bytes()),
49            medium if medium <= u24::MAX.into_usize() => {
50                self.write_raw_array((u24::with(len as u32)).to_le_bytes())
51            }
52            large if large <= u32::MAX as usize => self.write_raw_array((len as u32).to_le_bytes()),
53            huge if huge <= u64::MAX as usize => self.write_raw_array((len as u64).to_le_bytes()),
54            _ => unreachable!("confined collections larger than u64::MAX must not exist"),
55        }
56    }
57}
58
59impl<T: WriteRaw> WriteRaw for &mut T {
60    fn write_raw<const MAX_LEN: usize>(&mut self, bytes: impl AsRef<[u8]>) -> io::Result<()> {
61        (*self).write_raw::<MAX_LEN>(bytes)
62    }
63}
64
65#[allow(unused_variables)]
66pub trait TypedWrite: Sized {
67    type TupleWriter: WriteTuple<Parent = Self>;
68    type StructWriter: WriteStruct<Parent = Self>;
69    type UnionDefiner: DefineUnion<Parent = Self>;
70    type RawWriter: WriteRaw;
71
72    #[doc(hidden)]
73    unsafe fn raw_writer(&mut self) -> &mut Self::RawWriter;
74
75    fn write_union<T: StrictUnion>(
76        self,
77        inner: impl FnOnce(Self::UnionDefiner) -> io::Result<Self>,
78    ) -> io::Result<Self>;
79    fn write_enum<T: StrictEnum>(self, value: T) -> io::Result<Self>
80    where u8: From<T>;
81    fn write_tuple<T: StrictTuple>(
82        self,
83        inner: impl FnOnce(Self::TupleWriter) -> io::Result<Self>,
84    ) -> io::Result<Self>;
85    fn write_struct<T: StrictStruct>(
86        self,
87        inner: impl FnOnce(Self::StructWriter) -> io::Result<Self>,
88    ) -> io::Result<Self>;
89    fn write_newtype<T: StrictTuple>(self, value: &impl StrictEncode) -> io::Result<Self> {
90        self.write_tuple::<T>(|writer| Ok(writer.write_field(value)?.complete()))
91    }
92
93    #[doc(hidden)]
94    unsafe fn register_primitive(self, prim: Primitive) -> Self { self }
95    #[doc(hidden)]
96    unsafe fn register_array(self, ty: &impl StrictEncode, len: u16) -> Self { self }
97    #[doc(hidden)]
98    unsafe fn register_unicode(self, sizing: Sizing) -> Self { self }
99    #[doc(hidden)]
100    #[deprecated(since = "2.3.1", note = "use register_rstring")]
101    unsafe fn register_ascii(self, sizing: Sizing) -> Self {
102        panic!("TypedWrite::register_ascii must not be called; pls see compilation warnings")
103    }
104    #[doc(hidden)]
105    unsafe fn register_rstring(
106        self,
107        c: &impl StrictEncode,
108        c1: &impl StrictEncode,
109        sizing: Sizing,
110    ) -> Self {
111        self
112    }
113    #[doc(hidden)]
114    unsafe fn register_list(self, ty: &impl StrictEncode, sizing: Sizing) -> Self { self }
115    #[doc(hidden)]
116    unsafe fn register_set(self, ty: &impl StrictEncode, sizing: Sizing) -> Self { self }
117    #[doc(hidden)]
118    unsafe fn register_map(
119        self,
120        ket: &impl StrictEncode,
121        ty: &impl StrictEncode,
122        sizing: Sizing,
123    ) -> Self {
124        self
125    }
126
127    /// Used by unicode strings, ASCII strings and restricted char set strings.
128    #[doc(hidden)]
129    unsafe fn write_string<const MAX_LEN: usize>(
130        mut self,
131        bytes: impl AsRef<[u8]>,
132    ) -> io::Result<Self> {
133        self.raw_writer().write_raw_len::<MAX_LEN>(bytes.as_ref().len())?;
134        self.raw_writer().write_raw::<MAX_LEN>(bytes)?;
135        Ok(self)
136    }
137
138    /// Vec and sets - excluding strings, written by [`Self::write_string`].
139    #[doc(hidden)]
140    unsafe fn write_collection<C: Collection, const MIN_LEN: usize, const MAX_LEN: usize>(
141        mut self,
142        col: &Confined<C, MIN_LEN, MAX_LEN>,
143    ) -> io::Result<Self>
144    where
145        for<'a> &'a C: IntoIterator,
146        for<'a> <&'a C as IntoIterator>::Item: StrictEncode,
147    {
148        self.raw_writer().write_raw_len::<MAX_LEN>(col.len())?;
149        for item in col {
150            self = item.strict_encode(self)?;
151        }
152        Ok(self)
153    }
154
155    // TODO: Do `write_keyed_collection`
156}
157
158pub trait ReadRaw {
159    fn read_raw<const MAX_LEN: usize>(&mut self, len: usize) -> io::Result<Vec<u8>>;
160
161    fn read_raw_array<const LEN: usize>(&mut self) -> io::Result<[u8; LEN]>;
162
163    fn read_raw_len<const MAX_LEN: usize>(&mut self) -> Result<usize, DecodeError> {
164        Ok(match MAX_LEN {
165            tiny if tiny <= u8::MAX as usize => u8::decode_raw_le(self)? as usize,
166            small if small <= u16::MAX as usize => u16::decode_raw_le(self)? as usize,
167            medium if medium <= u24::MAX.into_usize() => u24::decode_raw_le(self)?.into_usize(),
168            large if large <= u32::MAX as usize => u32::decode_raw_le(self)? as usize,
169            huge if huge <= u64::MAX as usize => u64::decode_raw_le(self)? as usize,
170            _ => unreachable!("confined collections larger than u64::MAX must not exist"),
171        })
172    }
173}
174
175impl<T: ReadRaw> ReadRaw for &mut T {
176    fn read_raw<const MAX_LEN: usize>(&mut self, len: usize) -> io::Result<Vec<u8>> {
177        (*self).read_raw::<MAX_LEN>(len)
178    }
179
180    fn read_raw_array<const LEN: usize>(&mut self) -> io::Result<[u8; LEN]> {
181        (*self).read_raw_array::<LEN>()
182    }
183}
184
185pub trait TypedRead {
186    type TupleReader<'parent>: ReadTuple
187    where Self: 'parent;
188    type StructReader<'parent>: ReadStruct
189    where Self: 'parent;
190    type UnionReader: ReadUnion;
191    type RawReader: ReadRaw;
192
193    #[doc(hidden)]
194    unsafe fn raw_reader(&mut self) -> &mut Self::RawReader;
195
196    fn read_union<T: StrictUnion>(
197        &mut self,
198        inner: impl FnOnce(VariantName, &mut Self::UnionReader) -> Result<T, DecodeError>,
199    ) -> Result<T, DecodeError>;
200
201    fn read_enum<T: StrictEnum>(&mut self) -> Result<T, DecodeError>
202    where u8: From<T>;
203
204    fn read_tuple<'parent, 'me, T: StrictTuple>(
205        &'me mut self,
206        inner: impl FnOnce(&mut Self::TupleReader<'parent>) -> Result<T, DecodeError>,
207    ) -> Result<T, DecodeError>
208    where
209        Self: 'parent,
210        'me: 'parent;
211
212    fn read_struct<'parent, 'me, T: StrictStruct>(
213        &'me mut self,
214        inner: impl FnOnce(&mut Self::StructReader<'parent>) -> Result<T, DecodeError>,
215    ) -> Result<T, DecodeError>
216    where
217        Self: 'parent,
218        'me: 'parent;
219
220    fn read_newtype<T: StrictTuple + Wrapper>(&mut self) -> Result<T, DecodeError>
221    where T::Inner: StrictDecode {
222        self.read_tuple(|reader| reader.read_field().map(T::from_inner))
223    }
224
225    #[doc(hidden)]
226    unsafe fn read_string<const MAX_LEN: usize>(&mut self) -> Result<Vec<u8>, DecodeError> {
227        let len = self.raw_reader().read_raw_len::<MAX_LEN>()?;
228        self.raw_reader().read_raw::<MAX_LEN>(len).map_err(DecodeError::from)
229    }
230}
231
232pub trait DefineTuple: Sized {
233    type Parent: TypedParent;
234    fn define_field<T: StrictEncode + StrictDumb>(self) -> Self;
235    fn complete(self) -> Self::Parent;
236}
237
238pub trait WriteTuple: Sized {
239    type Parent: TypedParent;
240    fn write_field(self, value: &impl StrictEncode) -> io::Result<Self>;
241    fn complete(self) -> Self::Parent;
242}
243
244pub trait ReadTuple {
245    fn read_field<T: StrictDecode>(&mut self) -> Result<T, DecodeError>;
246}
247
248pub trait DefineStruct: Sized {
249    type Parent: TypedParent;
250    fn define_field<T: StrictEncode + StrictDumb>(self, name: FieldName) -> Self;
251    fn complete(self) -> Self::Parent;
252}
253
254pub trait WriteStruct: Sized {
255    type Parent: TypedParent;
256    fn write_field(self, name: FieldName, value: &impl StrictEncode) -> io::Result<Self>;
257    fn complete(self) -> Self::Parent;
258}
259
260pub trait ReadStruct {
261    fn read_field<T: StrictDecode>(&mut self, field: FieldName) -> Result<T, DecodeError>;
262}
263
264pub trait DefineEnum: Sized {
265    type Parent: TypedWrite;
266    type EnumWriter: WriteEnum<Parent = Self::Parent>;
267    fn define_variant(self, name: VariantName) -> Self;
268    fn complete(self) -> Self::EnumWriter;
269}
270
271pub trait WriteEnum: Sized {
272    type Parent: TypedWrite;
273    fn write_variant(self, name: VariantName) -> io::Result<Self>;
274    fn complete(self) -> Self::Parent;
275}
276
277pub trait DefineUnion: Sized {
278    type Parent: TypedWrite;
279    type TupleDefiner: DefineTuple<Parent = Self>;
280    type StructDefiner: DefineStruct<Parent = Self>;
281    type UnionWriter: WriteUnion<Parent = Self::Parent>;
282
283    fn define_unit(self, name: VariantName) -> Self;
284    fn define_newtype<T: StrictEncode + StrictDumb>(self, name: VariantName) -> Self {
285        self.define_tuple(name, |definer| definer.define_field::<T>().complete())
286    }
287    fn define_tuple(
288        self,
289        name: VariantName,
290        inner: impl FnOnce(Self::TupleDefiner) -> Self,
291    ) -> Self;
292    fn define_struct(
293        self,
294        name: VariantName,
295        inner: impl FnOnce(Self::StructDefiner) -> Self,
296    ) -> Self;
297
298    fn complete(self) -> Self::UnionWriter;
299}
300
301pub trait WriteUnion: Sized {
302    type Parent: TypedWrite;
303    type TupleWriter: WriteTuple<Parent = Self>;
304    type StructWriter: WriteStruct<Parent = Self>;
305
306    fn write_unit(self, name: VariantName) -> io::Result<Self>;
307    fn write_newtype(self, name: VariantName, value: &impl StrictEncode) -> io::Result<Self> {
308        self.write_tuple(name, |writer| Ok(writer.write_field(value)?.complete()))
309    }
310    fn write_tuple(
311        self,
312        name: VariantName,
313        inner: impl FnOnce(Self::TupleWriter) -> io::Result<Self>,
314    ) -> io::Result<Self>;
315    fn write_struct(
316        self,
317        name: VariantName,
318        inner: impl FnOnce(Self::StructWriter) -> io::Result<Self>,
319    ) -> io::Result<Self>;
320
321    fn complete(self) -> Self::Parent;
322}
323
324pub trait ReadUnion: Sized {
325    type TupleReader<'parent>: ReadTuple
326    where Self: 'parent;
327    type StructReader<'parent>: ReadStruct
328    where Self: 'parent;
329
330    fn read_tuple<'parent, 'me, T: StrictSum>(
331        &'me mut self,
332        inner: impl FnOnce(&mut Self::TupleReader<'parent>) -> Result<T, DecodeError>,
333    ) -> Result<T, DecodeError>
334    where
335        Self: 'parent,
336        'me: 'parent;
337
338    fn read_struct<'parent, 'me, T: StrictSum>(
339        &'me mut self,
340        inner: impl FnOnce(&mut Self::StructReader<'parent>) -> Result<T, DecodeError>,
341    ) -> Result<T, DecodeError>
342    where
343        Self: 'parent,
344        'me: 'parent;
345
346    fn read_newtype<T: StrictSum + From<I>, I: StrictDecode>(&mut self) -> Result<T, DecodeError> {
347        self.read_tuple(|reader| reader.read_field::<I>().map(T::from))
348    }
349}
350
351pub trait StrictEncode: StrictType {
352    fn strict_encode<W: TypedWrite>(&self, writer: W) -> io::Result<W>;
353    fn strict_write(&self, writer: impl WriteRaw) -> io::Result<()> {
354        let w = StrictWriter::with(writer);
355        self.strict_encode(w)?;
356        Ok(())
357    }
358}
359
360pub trait StrictDecode: StrictType {
361    fn strict_decode(reader: &mut impl TypedRead) -> Result<Self, DecodeError>;
362    fn strict_read(reader: impl ReadRaw) -> Result<Self, DecodeError> {
363        let mut r = StrictReader::with(reader);
364        Self::strict_decode(&mut r)
365    }
366}
367
368impl<T: StrictEncode> StrictEncode for &T {
369    fn strict_encode<W: TypedWrite>(&self, writer: W) -> io::Result<W> {
370        (*self).strict_encode(writer)
371    }
372}
373
374impl<T> StrictEncode for PhantomData<T> {
375    fn strict_encode<W: TypedWrite>(&self, writer: W) -> io::Result<W> { Ok(writer) }
376}
377
378impl<T> StrictDecode for PhantomData<T> {
379    fn strict_decode(_reader: &mut impl TypedRead) -> Result<Self, DecodeError> { Ok(default!()) }
380}
381
382pub trait StrictSerialize: StrictEncode {
383    fn strict_serialized_len<const MAX: usize>(&self) -> io::Result<usize> {
384        let counter = StrictWriter::counter::<MAX>();
385        Ok(self.strict_encode(counter)?.unbox().unconfine().count)
386    }
387
388    fn to_strict_serialized<const MAX: usize>(
389        &self,
390    ) -> Result<Confined<Vec<u8>, 0, MAX>, SerializeError> {
391        let ast_data = StrictWriter::in_memory::<MAX>();
392        let data = self.strict_encode(ast_data)?.unbox().unconfine();
393        Confined::<Vec<u8>, 0, MAX>::try_from(data).map_err(SerializeError::from)
394    }
395
396    fn strict_serialize_to_file<const MAX: usize>(
397        &self,
398        path: impl AsRef<std::path::Path>,
399    ) -> Result<(), SerializeError> {
400        let file = fs::File::create(path)?;
401        // TODO: Do FileReader
402        let file = StrictWriter::with(StreamWriter::new::<MAX>(file));
403        self.strict_encode(file)?;
404        Ok(())
405    }
406}
407
408pub trait StrictDeserialize: StrictDecode {
409    fn from_strict_serialized<const MAX: usize>(
410        ast_data: Confined<Vec<u8>, 0, MAX>,
411    ) -> Result<Self, DeserializeError> {
412        let mut reader = StrictReader::in_memory::<MAX>(ast_data);
413        let me = Self::strict_decode(&mut reader)?;
414        let mut cursor = reader.into_cursor();
415        if !cursor.fill_buf()?.is_empty() {
416            return Err(DeserializeError::DataNotEntirelyConsumed);
417        }
418        Ok(me)
419    }
420
421    fn strict_deserialize_from_file<const MAX: usize>(
422        path: impl AsRef<std::path::Path>,
423    ) -> Result<Self, DeserializeError> {
424        let file = fs::File::open(path)?;
425        // TODO: Do FileReader
426        let mut reader = StrictReader::with(StreamReader::new::<MAX>(file));
427        let me = Self::strict_decode(&mut reader)?;
428        let mut file = reader.unbox().unconfine();
429        if file.stream_position()? != file.seek(io::SeekFrom::End(0))? {
430            return Err(DeserializeError::DataNotEntirelyConsumed);
431        }
432        Ok(me)
433    }
434}