willow_data_model/relative_encodings/
area_area.rs

1use compact_u64::{CompactU64, EncodingWidth, Tag, TagWidth};
2use ufotofu::{BulkConsumer, BulkProducer};
3use ufotofu_codec::{
4    Blame, DecodableCanonic, DecodeError, Encodable, EncodableKnownSize, EncodableSync,
5    RelativeDecodable, RelativeDecodableCanonic, RelativeDecodableSync, RelativeEncodable,
6    RelativeEncodableKnownSize, RelativeEncodableSync,
7};
8use willow_encoding::is_bitflagged;
9
10use crate::{
11    grouping::{Area, AreaSubspace, Range, RangeEnd},
12    Path, SubspaceId,
13};
14
15impl<const MCL: usize, const MCC: usize, const MPL: usize, S>
16    RelativeEncodable<Area<MCL, MCC, MPL, S>> for Area<MCL, MCC, MPL, S>
17where
18    S: SubspaceId + Encodable,
19{
20    /// Encodes this [`Area`] relative to another [`Area`] which [includes](https://willowprotocol.org/specs/grouping-entries/index.html#area_include_area) it.
21    ///
22    /// [Definition](https://willowprotocol.org/specs/encodings/index.html#enc_area_in_area).
23    async fn relative_encode<C>(
24        &self,
25        consumer: &mut C,
26        r: &Area<MCL, MCC, MPL, S>,
27    ) -> Result<(), C::Error>
28    where
29        C: BulkConsumer<Item = u8>,
30    {
31        if !r.includes_area(self) {
32            panic!("Tried to encode an area relative to a area it is not included by")
33        }
34
35        let start_diff = core::cmp::min(
36            self.times().start - r.times().start,
37            u64::from(&r.times().end) - self.times().start,
38        );
39
40        let end_diff = core::cmp::min(
41            u64::from(&self.times().end) - r.times().start,
42            u64::from(&r.times().end) - u64::from(&self.times().end),
43        );
44
45        let mut header = 0;
46
47        if self.subspace() != r.subspace() {
48            header |= 0b1000_0000;
49        }
50
51        if self.times().end == RangeEnd::Open {
52            header |= 0b0100_0000;
53        }
54
55        if start_diff == self.times().start - r.times().start {
56            header |= 0b0010_0000;
57        }
58
59        if self.times().end != RangeEnd::Open
60            && end_diff == u64::from(&self.times().end) - r.times().start
61        {
62            header |= 0b0001_0000;
63        }
64
65        let start_diff_tag = Tag::min_tag(start_diff, TagWidth::two());
66        let end_diff_tag = Tag::min_tag(end_diff, TagWidth::two());
67
68        header |= start_diff_tag.data_at_offset(4);
69        header |= end_diff_tag.data_at_offset(6);
70
71        consumer.consume(header).await?;
72
73        match (&self.subspace(), &r.subspace()) {
74            (AreaSubspace::Any, AreaSubspace::Any) => {} // Same subspace
75            (AreaSubspace::Id(_), AreaSubspace::Id(_)) => {} // Same subspace
76            (AreaSubspace::Id(subspace), AreaSubspace::Any) => {
77                subspace.encode(consumer).await?;
78            }
79            (AreaSubspace::Any, AreaSubspace::Id(_)) => {
80                unreachable!(
81                    "We should have already rejected an area not included by another area!"
82                )
83            }
84        }
85
86        self.path().relative_encode(consumer, r.path()).await?;
87
88        CompactU64(start_diff)
89            .relative_encode(consumer, &start_diff_tag.encoding_width())
90            .await?;
91
92        if self.times().end != RangeEnd::Open {
93            CompactU64(end_diff)
94                .relative_encode(consumer, &end_diff_tag.encoding_width())
95                .await?;
96        }
97
98        Ok(())
99    }
100}
101
102impl<const MCL: usize, const MCC: usize, const MPL: usize, S>
103    RelativeDecodable<Area<MCL, MCC, MPL, S>, Blame> for Area<MCL, MCC, MPL, S>
104where
105    S: SubspaceId + DecodableCanonic,
106    Blame: From<S::ErrorReason> + From<S::ErrorCanonic>,
107{
108    /// Decodes an [`Area`] relative to another [`Area`] which [includes](https://willowprotocol.org/specs/grouping-entries/index.html#area_include_area) it.
109    ///
110    /// Will return an error if the encoding has not been produced by the corresponding encoding function.
111    ///
112    /// [Definition](https://willowprotocol.org/specs/encodings/index.html#enc_area_in_area).
113    async fn relative_decode<P>(
114        producer: &mut P,
115        r: &Area<MCL, MCC, MPL, S>,
116    ) -> Result<Self, DecodeError<P::Final, P::Error, Blame>>
117    where
118        P: BulkProducer<Item = u8>,
119        Self: Sized,
120    {
121        relative_decode_maybe_canonic::<false, MCL, MCC, MPL, S, P>(producer, r).await
122    }
123}
124
125impl<const MCL: usize, const MCC: usize, const MPL: usize, S>
126    RelativeDecodableCanonic<Area<MCL, MCC, MPL, S>, Blame, Blame> for Area<MCL, MCC, MPL, S>
127where
128    S: SubspaceId + DecodableCanonic,
129    Blame: From<S::ErrorReason> + From<S::ErrorCanonic>,
130{
131    async fn relative_decode_canonic<P>(
132        producer: &mut P,
133        r: &Area<MCL, MCC, MPL, S>,
134    ) -> Result<Self, DecodeError<P::Final, P::Error, Blame>>
135    where
136        P: BulkProducer<Item = u8>,
137        Self: Sized,
138    {
139        relative_decode_maybe_canonic::<true, MCL, MCC, MPL, S, P>(producer, r).await
140    }
141}
142
143impl<const MCL: usize, const MCC: usize, const MPL: usize, S>
144    RelativeEncodableKnownSize<Area<MCL, MCC, MPL, S>> for Area<MCL, MCC, MPL, S>
145where
146    S: SubspaceId + EncodableKnownSize,
147{
148    fn relative_len_of_encoding(&self, r: &Area<MCL, MCC, MPL, S>) -> usize {
149        if !r.includes_area(self) {
150            panic!("Tried to encode an area relative to a area it is not included by")
151        }
152
153        let start_diff = core::cmp::min(
154            self.times().start - r.times().start,
155            u64::from(&r.times().end) - self.times().start,
156        );
157
158        let end_diff = core::cmp::min(
159            u64::from(&self.times().end) - r.times().start,
160            u64::from(&r.times().end) - u64::from(&self.times().end),
161        );
162
163        let start_diff_tag = Tag::min_tag(start_diff, TagWidth::two());
164        let end_diff_tag = Tag::min_tag(end_diff, TagWidth::two());
165
166        let subspace_len = match (&self.subspace(), &r.subspace()) {
167            (AreaSubspace::Any, AreaSubspace::Any) => 0, // Same subspace
168            (AreaSubspace::Id(_), AreaSubspace::Id(_)) => 0, // Same subspace
169            (AreaSubspace::Id(subspace), AreaSubspace::Any) => subspace.len_of_encoding(),
170            (AreaSubspace::Any, AreaSubspace::Id(_)) => {
171                unreachable!(
172                    "We should have already rejected an area not included by another area!"
173                )
174            }
175        };
176
177        let path_len = self.path().relative_len_of_encoding(r.path());
178
179        let start_diff_len =
180            CompactU64(start_diff).relative_len_of_encoding(&start_diff_tag.encoding_width());
181
182        let end_diff_len = if self.times().end != RangeEnd::Open {
183            CompactU64(end_diff).relative_len_of_encoding(&end_diff_tag.encoding_width())
184        } else {
185            0
186        };
187
188        1 + subspace_len + path_len + start_diff_len + end_diff_len
189    }
190}
191
192impl<const MCL: usize, const MCC: usize, const MPL: usize, S>
193    RelativeEncodableSync<Area<MCL, MCC, MPL, S>> for Area<MCL, MCC, MPL, S>
194where
195    S: SubspaceId + EncodableSync,
196{
197}
198
199impl<const MCL: usize, const MCC: usize, const MPL: usize, S>
200    RelativeDecodableSync<Area<MCL, MCC, MPL, S>, Blame> for Area<MCL, MCC, MPL, S>
201where
202    S: SubspaceId + DecodableCanonic,
203    Blame: From<S::ErrorReason> + From<S::ErrorCanonic>,
204{
205}
206
207async fn relative_decode_maybe_canonic<
208    const CANONIC: bool,
209    const MCL: usize,
210    const MCC: usize,
211    const MPL: usize,
212    S,
213    P,
214>(
215    producer: &mut P,
216    r: &Area<MCL, MCC, MPL, S>,
217) -> Result<Area<MCL, MCC, MPL, S>, DecodeError<P::Final, P::Error, Blame>>
218where
219    P: BulkProducer<Item = u8>,
220    S: SubspaceId + DecodableCanonic,
221    Blame: From<S::ErrorReason> + From<S::ErrorCanonic>,
222{
223    let header = producer.produce_item().await?;
224
225    // Decode subspace?
226    let is_subspace_encoded = is_bitflagged(header, 0);
227
228    // Decode end value of times?
229    let is_times_end_open = is_bitflagged(header, 1);
230
231    // Add start_diff to out.get_times().start, or subtract from out.get_times().end?
232    let add_start_diff = is_bitflagged(header, 2);
233
234    // Add end_diff to out.get_times().start, or subtract from out.get_times().end?
235    let add_end_diff = is_bitflagged(header, 3);
236
237    // === Necessary to produce canonic encodings. ===
238    // Verify that we don't add_end_diff when open...
239    if CANONIC && add_end_diff && is_times_end_open {
240        return Err(DecodeError::Other(Blame::TheirFault));
241    }
242    // ===============================================
243
244    let start_time_diff_tag = Tag::from_raw(header, TagWidth::two(), 4);
245    let end_time_diff_tag = Tag::from_raw(header, TagWidth::two(), 6);
246
247    // === Necessary to produce canonic encodings. ===
248    // Verify the last two bits are zero if is_times_end_open
249    if CANONIC && is_times_end_open && (end_time_diff_tag.encoding_width() != EncodingWidth::one())
250    {
251        return Err(DecodeError::Other(Blame::TheirFault));
252    }
253    // ===============================================
254
255    let subspace = if is_subspace_encoded {
256        let id = if CANONIC {
257            S::decode_canonic(producer)
258                .await
259                .map_err(DecodeError::map_other_from)?
260        } else {
261            S::decode(producer)
262                .await
263                .map_err(DecodeError::map_other_from)?
264        };
265        let sub = AreaSubspace::Id(id);
266
267        // === Necessary to produce canonic encodings. ===
268        // Verify that subspace wasn't needlessly encoded
269        if CANONIC && &sub == r.subspace() {
270            return Err(DecodeError::Other(Blame::TheirFault));
271        }
272        // ===============================================
273
274        sub
275    } else {
276        r.subspace().clone()
277    };
278
279    // Verify that the decoded subspace is included by the reference subspace
280    match (&r.subspace(), &subspace) {
281        (AreaSubspace::Any, AreaSubspace::Any) => {}
282        (AreaSubspace::Any, AreaSubspace::Id(_)) => {}
283        (AreaSubspace::Id(_), AreaSubspace::Any) => {
284            return Err(DecodeError::Other(Blame::TheirFault));
285        }
286        (AreaSubspace::Id(a), AreaSubspace::Id(b)) => {
287            if a != b {
288                return Err(DecodeError::Other(Blame::TheirFault));
289            }
290        }
291    }
292
293    let path = if CANONIC {
294        Path::relative_decode_canonic(producer, r.path())
295            .await
296            .map_err(DecodeError::map_other_from)?
297    } else {
298        Path::relative_decode(producer, r.path())
299            .await
300            .map_err(DecodeError::map_other_from)?
301    };
302
303    // Verify the decoded path is prefixed by the reference path
304    if !path.is_prefixed_by(r.path()) {
305        return Err(DecodeError::Other(Blame::TheirFault));
306    }
307
308    let start_diff = if CANONIC {
309        CompactU64::relative_decode_canonic(producer, &start_time_diff_tag)
310            .await
311            .map_err(DecodeError::map_other_from)?
312            .0
313    } else {
314        CompactU64::relative_decode(producer, &start_time_diff_tag)
315            .await
316            .map_err(DecodeError::map_other_from)?
317            .0
318    };
319
320    let start = if add_start_diff {
321        r.times().start.checked_add(start_diff)
322    } else {
323        u64::from(&r.times().end).checked_sub(start_diff)
324    }
325    .ok_or(DecodeError::Other(Blame::TheirFault))?;
326
327    // TODO: DOES THE BELOW NEED TO BE PART OF CANONIC CHECK?
328    // Verify they sent correct start diff
329    let expected_start_diff = core::cmp::min(
330        start.checked_sub(r.times().start),
331        u64::from(&r.times().end).checked_sub(start),
332    )
333    .ok_or(DecodeError::Other(Blame::TheirFault))?;
334
335    if expected_start_diff != start_diff {
336        return Err(DecodeError::Other(Blame::TheirFault));
337    }
338
339    if CANONIC {
340        // === Necessary to produce canonic encodings. ===
341        // Verify that bit 2 of the header was set correctly
342        let should_add_start_diff = start_diff
343            == start
344                .checked_sub(r.times().start)
345                .ok_or(DecodeError::Other(Blame::TheirFault))?;
346
347        if add_start_diff != should_add_start_diff {
348            return Err(DecodeError::Other(Blame::TheirFault));
349        }
350        // ===============================================
351    }
352
353    let end = if is_times_end_open {
354        if add_end_diff {
355            return Err(DecodeError::Other(Blame::TheirFault));
356        }
357
358        RangeEnd::Open
359    } else {
360        let end_diff = if CANONIC {
361            CompactU64::relative_decode_canonic(producer, &end_time_diff_tag)
362                .await
363                .map_err(DecodeError::map_other_from)?
364                .0
365        } else {
366            CompactU64::relative_decode(producer, &end_time_diff_tag)
367                .await
368                .map_err(DecodeError::map_other_from)?
369                .0
370        };
371
372        let end = if add_end_diff {
373            r.times().start.checked_add(end_diff)
374        } else {
375            u64::from(&r.times().end).checked_sub(end_diff)
376        }
377        .ok_or(DecodeError::Other(Blame::TheirFault))?;
378
379        // Verify they sent correct end diff
380        let expected_end_diff = core::cmp::min(
381            end.checked_sub(r.times().start),
382            u64::from(&r.times().end).checked_sub(end),
383        )
384        .ok_or(DecodeError::Other(Blame::TheirFault))?;
385
386        if end_diff != expected_end_diff {
387            return Err(DecodeError::Other(Blame::TheirFault));
388        }
389
390        // === Necessary to produce canonic encodings. ===
391        if CANONIC {
392            let should_add_end_diff = end_diff
393                == end
394                    .checked_sub(r.times().start)
395                    .ok_or(DecodeError::Other(Blame::TheirFault))?;
396
397            if add_end_diff != should_add_end_diff {
398                return Err(DecodeError::Other(Blame::TheirFault));
399            }
400        }
401        // ============================================
402
403        RangeEnd::Closed(end)
404    };
405
406    let times = Range { start, end };
407
408    // Verify the decoded time range is included by the reference time range
409    if !r.times().includes_range(&times) {
410        return Err(DecodeError::Other(Blame::TheirFault));
411    }
412
413    Ok(Area::new(subspace, path, times))
414}