willow_data_model/path/
codec.rs

1use ufotofu::{BulkConsumer, BulkProducer};
2
3use ufotofu_codec::{
4    Blame, Decodable, DecodableCanonic, DecodableSync, DecodeError, Encodable, EncodableKnownSize,
5    EncodableSync, RelativeDecodable, RelativeDecodableCanonic, RelativeDecodableSync,
6    RelativeEncodable, RelativeEncodableKnownSize, RelativeEncodableSync,
7};
8
9use compact_u64::*;
10
11use super::*;
12
13/// Essentially how to encode a Path, but working with an arbitrary iterator of components. Path encoding consists of calling this directly, relative path encoding consists of first encoding the length of the greatest common suffix and *then* calling this.
14async fn encode_from_iterator_of_components<'a, const MCL: usize, C, I>(
15    consumer: &mut C,
16    path_length: u64,
17    component_count: u64,
18    components: I,
19) -> Result<(), C::Error>
20where
21    C: BulkConsumer<Item = u8>,
22    I: Iterator<Item = Component<'a, MCL>>,
23{
24    // First byte contains two 4-bit tags of `CompactU64`s:
25    // total path length in bytes: 4 bit tag at offset zero
26    // total number of components: 4 bit tag at offset four
27    let path_length_tag = Tag::min_tag(path_length, TagWidth::four());
28    let component_count_tag = Tag::min_tag(component_count, TagWidth::four());
29
30    let first_byte = path_length_tag.data_at_offset(0) | component_count_tag.data_at_offset(4);
31    consumer.consume(first_byte).await?;
32
33    // Next, encode the total path length in compact bytes.
34    let total_bytes_bytes = CompactU64(path_length);
35    total_bytes_bytes
36        .relative_encode(consumer, &path_length_tag.encoding_width())
37        .await?;
38
39    // Next, encode the  total number of components in compact bytes.
40    let component_count_bytes = CompactU64(component_count);
41    component_count_bytes
42        .relative_encode(consumer, &component_count_tag.encoding_width())
43        .await?;
44
45    // Then, encode the components. Each is prefixed by its lenght as an 8-bit-tag CompactU64, except for the final component.
46    for (i, component) in components.enumerate() {
47        // The length of the final component is omitted (because a decoder can infer it from the total length and all prior components' lengths).
48        if i as u64 + 1 != component_count {
49            CompactU64(component.len() as u64).encode(consumer).await?;
50        }
51
52        // Each component length (if any) is followed by the raw component data itself.
53        consumer
54            .bulk_consume_full_slice(component.as_ref())
55            .await
56            .map_err(|err| err.into_reason())?;
57    }
58
59    Ok(())
60}
61
62impl<const MCL: usize, const MCC: usize, const MPL: usize> Encodable for Path<MCL, MCC, MPL> {
63    async fn encode<C>(&self, consumer: &mut C) -> Result<(), C::Error>
64    where
65        C: BulkConsumer<Item = u8>,
66    {
67        encode_from_iterator_of_components::<MCL, _, _>(
68            consumer,
69            self.path_length() as u64,
70            self.component_count() as u64,
71            self.components(),
72        )
73        .await
74    }
75}
76
77// Decodes the path length and component count as expected at the start of a path encoding, generic over whether the encoding must be canonic or not.
78// Implemented as a dedicated function so that it can be used in both absolute and relative decoding.
79async fn decode_total_length_and_component_count_maybe_canonic<const CANONIC: bool, P>(
80    producer: &mut P,
81) -> Result<(usize, usize), DecodeError<P::Final, P::Error, Blame>>
82where
83    P: BulkProducer<Item = u8>,
84{
85    // Decode the first byte - the two compact width tags for the path length and component count.
86    let first_byte = producer.produce_item().await?;
87    let path_length_tag = Tag::from_raw(first_byte, TagWidth::four(), 0);
88    let component_count_tag = Tag::from_raw(first_byte, TagWidth::four(), 4);
89
90    // Next, decode the total path length and the component count.
91    let total_length = relative_decode_cu64::<CANONIC, _>(producer, &path_length_tag).await?;
92    let component_count =
93        relative_decode_cu64::<CANONIC, _>(producer, &component_count_tag).await?;
94
95    // Convert them from u64 to usize, error if usize cannot represent the number.
96    let total_length = Blame::u64_to_usize(total_length)?;
97    let component_count = Blame::u64_to_usize(component_count)?;
98
99    Ok((total_length, component_count))
100}
101
102// Decodes the components of a path encoding, generic over whether the encoding must be canonic or not. Appends them into a PathBuilder. Needs to know the total length of components that had already been appended to that PathBuilder before.
103// Implemented as a dedicated function so that it can be used in both absolute and relative decoding.
104async fn decode_components_maybe_canonic<
105    const CANONIC: bool,
106    const MCL: usize,
107    const MCC: usize,
108    const MPL: usize,
109    P,
110>(
111    producer: &mut P,
112    mut builder: PathBuilder<MCL, MCC, MPL>,
113    initial_accumulated_component_length: usize,
114    remaining_component_count: usize,
115    expected_total_length: usize,
116) -> Result<Path<MCL, MCC, MPL>, DecodeError<P::Final, P::Error, Blame>>
117where
118    P: BulkProducer<Item = u8>,
119{
120    // Now decode the actual components.
121    // We track the sum of the lengths of all decoded components so far, because we need it to determine the length of the final component.
122    let mut accumulated_component_length = initial_accumulated_component_length;
123
124    // Handle decoding of the empty path with dedicated logic to prevent underflows in loop counters =S
125    if remaining_component_count == 0 {
126        if expected_total_length > accumulated_component_length {
127            // Claimed length is incorrect
128            Err(DecodeError::Other(Blame::TheirFault))
129        } else {
130            // Nothing more to do, decoding an empty path turns out to be simple!
131            Ok(builder.build())
132        }
133    } else {
134        // We have at least one component.
135
136        // Decode all but the final one (because the final one is encoded without its lenght and hence requires dedicated logic to decode).
137        for _ in 1..remaining_component_count {
138            let component_len = Blame::u64_to_usize(decode_cu64::<CANONIC, _>(producer).await?)?;
139
140            if component_len > MCL {
141                // Decoded path must respect the MCL.
142                return Err(DecodeError::Other(Blame::TheirFault));
143            } else {
144                // Increase the accumulated length, accounting for errors.
145                accumulated_component_length = accumulated_component_length
146                    .checked_add(component_len)
147                    .ok_or(DecodeError::Other(Blame::TheirFault))?;
148
149                // Copy the component bytes into the Path.
150                builder
151                    .append_component_from_bulk_producer(component_len, producer)
152                    .await?;
153            }
154        }
155
156        // For the final component, compute its length. If the computation result would be negative, then the encoding was invalid.
157        let final_component_length = expected_total_length
158            .checked_sub(accumulated_component_length)
159            .ok_or(DecodeError::Other(Blame::TheirFault))?;
160
161        if final_component_length > MCL {
162            // Decoded path must respect the MCL.
163            Err(DecodeError::Other(Blame::TheirFault))
164        } else {
165            // Copy the final component bytes into the Path.
166            builder
167                .append_component_from_bulk_producer(final_component_length, producer)
168                .await?;
169
170            // What a journey. We are done!
171            Ok(builder.build())
172        }
173    }
174}
175
176// Decodes a path, generic over whether the encoding must be canonic or not.
177async fn decode_maybe_canonic<
178    const CANONIC: bool,
179    const MCL: usize,
180    const MCC: usize,
181    const MPL: usize,
182    P,
183>(
184    producer: &mut P,
185) -> Result<Path<MCL, MCC, MPL>, DecodeError<P::Final, P::Error, Blame>>
186where
187    P: BulkProducer<Item = u8>,
188{
189    let (total_length, component_count) =
190        decode_total_length_and_component_count_maybe_canonic::<CANONIC, _>(producer).await?;
191
192    // Preallocate all storage for the path. Error if the total length or component_count are greater than MPL and MCC respectively allow.
193    let builder = PathBuilder::new(total_length, component_count)
194        .map_err(|_| DecodeError::Other(Blame::TheirFault))?;
195
196    decode_components_maybe_canonic::<CANONIC, MCL, MCC, MPL, _>(
197        producer,
198        builder,
199        0,
200        component_count,
201        total_length,
202    )
203    .await
204}
205
206impl<const MCL: usize, const MCC: usize, const MPL: usize> Decodable for Path<MCL, MCC, MPL> {
207    type ErrorReason = Blame;
208
209    async fn decode<P>(
210        producer: &mut P,
211    ) -> Result<Self, DecodeError<P::Final, P::Error, Self::ErrorReason>>
212    where
213        P: BulkProducer<Item = u8>,
214        Self: Sized,
215    {
216        decode_maybe_canonic::<false, MCL, MCC, MPL, _>(producer).await
217    }
218}
219
220impl<const MCL: usize, const MCC: usize, const MPL: usize> DecodableCanonic
221    for Path<MCL, MCC, MPL>
222{
223    type ErrorCanonic = Blame;
224
225    async fn decode_canonic<P>(
226        producer: &mut P,
227    ) -> Result<Self, DecodeError<P::Final, P::Error, Self::ErrorCanonic>>
228    where
229        P: BulkProducer<Item = u8>,
230        Self: Sized,
231    {
232        decode_maybe_canonic::<true, MCL, MCC, MPL, _>(producer).await
233    }
234}
235
236// Separate function to allow for reuse in relative encoding.
237fn encoding_len_from_iterator_of_components<'a, const MCL: usize, I>(
238    path_length: u64,
239    component_count: usize,
240    components: I,
241) -> usize
242where
243    I: Iterator<Item = Component<'a, MCL>>,
244{
245    let mut total_enc_len = 1; // First byte for the two four-bit tags at the start of the encoding.
246
247    total_enc_len += EncodingWidth::min_width(path_length, TagWidth::four()).as_usize();
248    total_enc_len += EncodingWidth::min_width(component_count as u64, TagWidth::four()).as_usize();
249
250    for (i, comp) in components.enumerate() {
251        if i + 1 < component_count {
252            total_enc_len += CompactU64(comp.len() as u64).len_of_encoding();
253        }
254
255        total_enc_len += comp.len();
256    }
257
258    total_enc_len
259}
260
261impl<const MCL: usize, const MCC: usize, const MPL: usize> EncodableKnownSize
262    for Path<MCL, MCC, MPL>
263{
264    fn len_of_encoding(&self) -> usize {
265        encoding_len_from_iterator_of_components::<MCL, _>(
266            self.path_length() as u64,
267            self.component_count(),
268            self.components(),
269        )
270    }
271}
272
273impl<const MCL: usize, const MCC: usize, const MPL: usize> EncodableSync for Path<MCL, MCC, MPL> {}
274impl<const MCL: usize, const MCC: usize, const MPL: usize> DecodableSync for Path<MCL, MCC, MPL> {}
275
// Relative encoding: a path encoded relative to another (reference) path.
277
278impl<const MCL: usize, const MCC: usize, const MPL: usize> RelativeEncodable<Path<MCL, MCC, MPL>>
279    for Path<MCL, MCC, MPL>
280{
281    /// Encodes this [`Path`] relative to a reference [`Path`].
282    ///
283    /// [Definition](https://willowprotocol.org/specs/encodings/index.html#enc_path_relative)
284    async fn relative_encode<Consumer>(
285        &self,
286        consumer: &mut Consumer,
287        reference: &Path<MCL, MCC, MPL>,
288    ) -> Result<(), Consumer::Error>
289    where
290        Consumer: BulkConsumer<Item = u8>,
291    {
292        let lcp = self.longest_common_prefix(reference);
293
294        CompactU64(lcp.component_count() as u64)
295            .encode(consumer)
296            .await?;
297
298        let suffix_length = self.path_length() - lcp.path_length();
299        let suffix_component_count = self.component_count() - lcp.component_count();
300
301        encode_from_iterator_of_components::<MCL, _, _>(
302            consumer,
303            suffix_length as u64,
304            suffix_component_count as u64,
305            self.suffix_components(lcp.component_count()),
306        )
307        .await
308    }
309}
310
311// Decodes a path relative to another path, generic over whether the encoding must be canonic or not.
312async fn relative_decode_maybe_canonic<
313    const CANONIC: bool,
314    const MCL: usize,
315    const MCC: usize,
316    const MPL: usize,
317    P,
318>(
319    producer: &mut P,
320    r: &Path<MCL, MCC, MPL>,
321) -> Result<Path<MCL, MCC, MPL>, DecodeError<P::Final, P::Error, Blame>>
322where
323    P: BulkProducer<Item = u8>,
324{
325    let prefix_component_count = Blame::u64_to_usize(decode_cu64::<CANONIC, _>(producer).await?)?;
326
327    let (suffix_length, suffix_component_count) =
328        decode_total_length_and_component_count_maybe_canonic::<CANONIC, _>(producer).await?;
329
330    if prefix_component_count > r.component_count() {
331        return Err(DecodeError::Other(Blame::TheirFault));
332    }
333
334    let prefix_path_length = r.path_length_of_prefix(prefix_component_count);
335
336    let total_length = prefix_path_length
337        .checked_add(suffix_length)
338        .ok_or(DecodeError::Other(Blame::TheirFault))?;
339    let total_component_count = prefix_component_count
340        .checked_add(suffix_component_count)
341        .ok_or(DecodeError::Other(Blame::TheirFault))?;
342
343    // Preallocate all storage for the path. Error if the total length or component_count are greater than MPL and MCC respectively allow.
344    let builder = PathBuilder::new_from_prefix(
345        total_length,
346        total_component_count,
347        r,
348        prefix_component_count,
349    )
350    .map_err(|_| DecodeError::Other(Blame::TheirFault))?;
351
352    // Decode the remaining components, add them to the builder, then build.
353    let decoded = decode_components_maybe_canonic::<CANONIC, MCL, MCC, MPL, _>(
354        producer,
355        builder,
356        prefix_path_length,
357        suffix_component_count,
358        total_length,
359    )
360    .await?;
361
362    if CANONIC {
363        // Did the encoding use the *longest* common prefix?
364        if prefix_component_count == r.component_count() {
365            // Could not have taken a longer prefix of `r`, i.e., the prefix was maximal.
366            Ok(decoded)
367        } else if prefix_component_count == decoded.component_count() {
368            // The prefix was the full path to decode, so it clearly was chosen maximally.
369            Ok(decoded)
370        } else {
371            // We check whether the next-longer prefix of `r` could have also been used for encoding. If so, error.
372            // To efficiently check, we check whether the next component of `r` is equal to its counterpart in what we decoded.
373            // Both next components exist, otherwise we would have been in an earlier branch of the `if` expression.
374            if r.component(prefix_component_count).unwrap()
375                == decoded.component(prefix_component_count).unwrap()
376            {
377                // Could have used a longer prefix for decoding. Not canonic!
378                Err(DecodeError::Other(Blame::TheirFault))
379            } else {
380                // Encoding was minimal, yay =)
381                Ok(decoded)
382            }
383        }
384    } else {
385        // No additional canonicity checks needed.
386        Ok(decoded)
387    }
388}
389
390impl<const MCL: usize, const MCC: usize, const MPL: usize>
391    RelativeDecodable<Path<MCL, MCC, MPL>, Blame> for Path<MCL, MCC, MPL>
392{
393    /// Decodes a [`Path`] relative to a reference [`Path`].
394    ///
395    /// [Definition](https://willowprotocol.org/specs/encodings/index.html#enc_path_relative)
396    async fn relative_decode<P>(
397        producer: &mut P,
398        r: &Path<MCL, MCC, MPL>,
399    ) -> Result<Self, DecodeError<P::Final, P::Error, Blame>>
400    where
401        P: BulkProducer<Item = u8>,
402        Self: Sized,
403    {
404        relative_decode_maybe_canonic::<false, MCL, MCC, MPL, _>(producer, r).await
405    }
406}
407
408impl<const MCL: usize, const MCC: usize, const MPL: usize>
409    RelativeDecodableCanonic<Path<MCL, MCC, MPL>, Blame, Blame> for Path<MCL, MCC, MPL>
410{
411    async fn relative_decode_canonic<P>(
412        producer: &mut P,
413        r: &Path<MCL, MCC, MPL>,
414    ) -> Result<Self, DecodeError<P::Final, P::Error, Blame>>
415    where
416        P: BulkProducer<Item = u8>,
417        Self: Sized,
418    {
419        relative_decode_maybe_canonic::<true, MCL, MCC, MPL, _>(producer, r).await
420    }
421}
422
423impl<const MCL: usize, const MCC: usize, const MPL: usize>
424    RelativeEncodableKnownSize<Path<MCL, MCC, MPL>> for Path<MCL, MCC, MPL>
425{
426    fn relative_len_of_encoding(&self, r: &Path<MCL, MCC, MPL>) -> usize {
427        let lcp = self.longest_common_prefix(r);
428        let path_len_of_suffix = self.path_length() - lcp.path_length();
429        let component_count_of_suffix = self.component_count() - lcp.component_count();
430
431        let mut total_enc_len = 0;
432
433        // Number of components in the longest common prefix, encoded as a CompactU64 with an 8-bit tag.
434        total_enc_len += CompactU64(lcp.component_count() as u64).len_of_encoding();
435
436        total_enc_len += encoding_len_from_iterator_of_components::<MCL, _>(
437            path_len_of_suffix as u64,
438            component_count_of_suffix,
439            self.suffix_components(lcp.component_count()),
440        );
441
442        total_enc_len
443    }
444}
445
446impl<const MCL: usize, const MCC: usize, const MPL: usize>
447    RelativeEncodableSync<Path<MCL, MCC, MPL>> for Path<MCL, MCC, MPL>
448{
449}
450
451impl<const MCL: usize, const MCC: usize, const MPL: usize>
452    RelativeDecodableSync<Path<MCL, MCC, MPL>, Blame> for Path<MCL, MCC, MPL>
453{
454}
455
// TODO: move the items below to more appropriate files/crates.
457
458/// Decodes a `CompactU64` relative to a `Tag`, generic over whether the encoding must be canonic or not.
459pub async fn relative_decode_cu64<const CANONIC: bool, P>(
460    producer: &mut P,
461    tag: &Tag,
462) -> Result<u64, DecodeError<P::Final, P::Error, Blame>>
463where
464    P: BulkProducer<Item = u8>,
465{
466    if CANONIC {
467        Ok(CompactU64::relative_decode_canonic(producer, tag)
468            .await
469            .map_err(|err| DecodeError::map_other(err, |_| Blame::TheirFault))?
470            .0)
471    } else {
472        Ok(CompactU64::relative_decode(producer, tag)
473            .await
474            .map_err(|err| DecodeError::map_other(err, |_| Blame::TheirFault))?
475            .0)
476    }
477}
478
479/// Decodes a `CompactU64`, generic over whether the encoding must be canonic or not.
480pub async fn decode_cu64<const CANONIC: bool, P>(
481    producer: &mut P,
482) -> Result<u64, DecodeError<P::Final, P::Error, Blame>>
483where
484    P: BulkProducer<Item = u8>,
485{
486    if CANONIC {
487        Ok(CompactU64::decode_canonic(producer)
488            .await
489            .map_err(|err| DecodeError::map_other(err, |_| Blame::TheirFault))?
490            .0)
491    } else {
492        Ok(CompactU64::decode(producer)
493            .await
494            .map_err(|err| DecodeError::map_other(err, |_| Blame::TheirFault))?
495            .0)
496    }
497}
498
499pub async fn encode_path_extends_path<const MCL: usize, const MCC: usize, const MPL: usize, C>(
500    consumer: &mut C,
501    path: &Path<MCL, MCC, MPL>,
502    extends: &Path<MCL, MCC, MPL>,
503) -> Result<(), C::Error>
504where
505    C: BulkConsumer<Item = u8>,
506{
507    // Check path extends extends
508    if !path.is_prefixed_by(extends) {
509        panic!("Tried to encode with PathExtendsPath with a path that does not extend another path")
510    }
511
512    let extends_count = extends.component_count();
513
514    let path_len = path.path_length() - extends.path_length();
515    let diff = path.component_count() - extends_count;
516
517    encode_from_iterator_of_components(
518        consumer,
519        path_len as u64,
520        diff as u64,
521        path.suffix_components(extends_count),
522    )
523    .await?;
524
525    Ok(())
526}
527
528pub async fn decode_path_extends_path<const MCL: usize, const MCC: usize, const MPL: usize, P>(
529    producer: &mut P,
530    prefix: &Path<MCL, MCC, MPL>,
531) -> Result<Path<MCL, MCC, MPL>, DecodeError<P::Final, P::Error, Blame>>
532where
533    P: BulkProducer<Item = u8>,
534{
535    let suffix = Path::<MCL, MCC, MPL>::decode(producer)
536        .await
537        .map_err(DecodeError::map_other_from)?;
538
539    let prefix_count = prefix.component_count();
540
541    let total_length = prefix.path_length() + suffix.path_length();
542    let total_count = prefix_count + suffix.component_count();
543
544    let mut path_builder =
545        PathBuilder::new_from_prefix(total_length, total_count, prefix, prefix_count)
546            .map_err(|_err| DecodeError::Other(Blame::TheirFault))?;
547
548    for component in suffix.components() {
549        path_builder.append_component(component);
550    }
551
552    Ok(path_builder.build())
553}