Skip to main content

vernier_partial/
envelope.rs

//! Wire envelope: magic, version, header, CRC framing.
//!
//! Two-archive layout, little-endian:
//!
//! ```text
//! [ 4 bytes : MAGIC = b"VRPS"        ]
//! [ 1 byte  : FORMAT_VERSION         ]
//! [ 4 bytes : header_archive_len u32 ]
//! [ N bytes : header rkyv archive    ]
//! [ M bytes : body rkyv archive      ]
//! [ 4 bytes : CRC32 over the preceding (9 + N + M) bytes ]
//! ```
//!
//! The body archive has no length prefix of its own: it spans everything
//! between the end of the header archive and the CRC trailer.
//!
//! The header archive carries paradigm-agnostic metadata
//! ([`WireEnvelopeHeader`]). The body archive belongs to each
//! paradigm crate — `vernier-partial` ships and validates body bytes
//! as opaque `&[u8]` and never touches paradigm-specific rkyv types.
//! This keeps the dependency DAG flat: paradigm crates depend on
//! `vernier-partial`; `vernier-partial` does not depend on them.
20
21use rkyv::rancor::Error as RkyvError;
22use rkyv::util::AlignedVec;
23
24use crate::error::{PartialError, PartialFormatErrorKind};
25use crate::traits::PartialExpectation;
26
/// Wire-format magic: the ASCII bytes `"VRPS"` (vernier partial state).
/// Every well-formed partial begins with exactly these four bytes.
pub const MAGIC: [u8; 4] = [b'V', b'R', b'P', b'S'];

/// Wire-format version.
///
/// Incremented on any breaking change to the framing or to the archived
/// header layout. A blob carrying any other version is refused at decode
/// time with [`PartialFormatErrorKind::WrongVersion`].
///
/// **v2 (ADR-0032):** the `kernel_kind` u8 was generalized into a
/// `paradigm_kind` u8 plus a `discriminator` u32. A
/// `shape_fingerprint: [u32; 4]` was added. The body was split out of the
/// header archive (two archives, with a length prefix on the header) so
/// that paradigm crates own their own rkyv invocations.
pub const FORMAT_VERSION: u8 = 2;

/// Smallest blob worth attempting to parse: the fixed framing overhead of
/// 4 magic bytes + 1 version byte + 4-byte header length + 4-byte CRC
/// trailer (13 bytes in total).
///
/// The header and body archives are not counted. They are not forced to
/// be non-empty here; if one is malformed, the rkyv access call reports
/// the appropriate error.
const MIN_PARTIAL_BYTES: usize =
    MAGIC.len() + core::mem::size_of::<u8>() + 2 * core::mem::size_of::<u32>();

/// Alignment of the scratch buffer that archive bytes are copied into
/// before [`rkyv::access`].
///
/// Sixteen bytes covers every primitive rkyv writes on x86_64 / aarch64.
/// Transport bytes supplied by callers carry no alignment guarantee, so
/// decode pays one copy instead of the wire format imposing alignment.
const RKYV_ALIGN: usize = 16;
53
54/// Paradigm-agnostic envelope header. Carries everything
55/// `vernier-partial` needs to validate cross-rank compatibility:
56/// paradigm + sub-kind, parity mode, dataset / params hashes, the
57/// shape fingerprint, and the rank id.
58///
59/// **Field order is wire-format load-bearing.** rkyv's archived
60/// layout follows declaration order; reordering or inserting a field
61/// changes the byte layout. Add new fields at the end and bump
62/// [`FORMAT_VERSION`].
63#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize, Debug, Clone)]
64pub struct WireEnvelopeHeader {
65    /// [`ParadigmKind::as_u8`](crate::traits::ParadigmKind::as_u8).
66    pub paradigm_kind: u8,
67    /// Paradigm-defined sub-kind. Instance: `KernelKind` discriminant.
68    /// Semantic: `0`. Panoptic: `0`.
69    pub discriminator: u32,
70    /// `ParityMode` discriminant.
71    pub parity_mode: u8,
72    /// Optional rank identifier (none for single-rank flows).
73    pub rank_id: Option<u32>,
74    /// BLAKE3 over the canonical dataset form.
75    pub dataset_hash: [u8; 32],
76    /// BLAKE3 over the canonical params archive.
77    pub params_hash: [u8; 32],
78    /// Paradigm-specific four-slot shape fingerprint. Instance:
79    /// `(n_categories, n_area_ranges, n_images, retain_iou as u32)`.
80    /// Semantic: `(n_classes, 0, n_images, 0)`. Panoptic:
81    /// `(n_categories, 0, n_images, things_stuff_split as u32)`.
82    pub shape_fingerprint: [u32; 4],
83}
84
85/// Serialize a header + already-archived body into a framed partial
86/// blob.
87///
88/// Each paradigm archives its body via its own rkyv path (the body
89/// type lives in the paradigm crate and rkyv's generic bounds are
90/// paradigm-specific) and hands the resulting bytes here. We frame
91/// with magic, version, header archive, body archive, CRC.
92pub fn encode(header: &WireEnvelopeHeader, body_archive: &[u8]) -> Result<Vec<u8>, PartialError> {
93    let header_archive = rkyv::to_bytes::<RkyvError>(header).map_err(|e| PartialError::Format {
94        kind: PartialFormatErrorKind::RkyvDecode {
95            detail: format!("rkyv::to_bytes(header) failed: {e}"),
96        },
97    })?;
98    let header_len = u32::try_from(header_archive.len()).map_err(|_| PartialError::Format {
99        kind: PartialFormatErrorKind::RkyvDecode {
100            detail: format!("header archive too large: {} bytes", header_archive.len()),
101        },
102    })?;
103
104    let mut out = Vec::with_capacity(MIN_PARTIAL_BYTES + header_archive.len() + body_archive.len());
105    out.extend_from_slice(&MAGIC);
106    out.push(FORMAT_VERSION);
107    out.extend_from_slice(&header_len.to_le_bytes());
108    out.extend_from_slice(&header_archive);
109    out.extend_from_slice(body_archive);
110    let crc = crc32fast::hash(&out);
111    out.extend_from_slice(&crc.to_le_bytes());
112    Ok(out)
113}
114
115/// Validated archive view passed to the body callback in
116/// [`with_validated_envelope`]. Carries the typed rank id (decoded
117/// from the header) plus the unaligned body bytes the callback will
118/// rkyv-access itself.
119pub struct ValidatedView<'a> {
120    /// The validated archived header. Borrows from the same aligned
121    /// buffer that the callback's lifetime is tied to.
122    pub header: &'a ArchivedWireEnvelopeHeader,
123    /// Body archive bytes. Caller copies into its own
124    /// [`rkyv::util::AlignedVec`] before [`rkyv::access`].
125    pub body_archive: &'a [u8],
126}
127
128/// Validate a partial blob's framing + header fields and run a
129/// callback on the validated view.
130///
131/// The callback runs while the aligned buffer holding the archived
132/// header is still in scope; this is why the API is callback-based
133/// rather than returning an owned view.
134///
135/// **Validation order** (cheapest-first; same as ADR-0031):
136///
137/// 1. Length: at least the framing-overhead minimum (magic + version
138///    byte + header-length prefix + CRC trailer = 13 bytes).
139/// 2. Magic: `bytes[..4] == MAGIC`.
140/// 3. Version: `bytes[4] == FORMAT_VERSION`.
141/// 4. CRC: stored CRC matches `crc32(bytes[..len-4])`.
142/// 5. rkyv access on the header archive.
143/// 6. Header field comparison against `expected`.
144///
145/// If any step fails the callback is not invoked.
146pub fn with_validated_envelope<R>(
147    bytes: &[u8],
148    expected: &PartialExpectation,
149    body_callback: impl FnOnce(ValidatedView<'_>) -> Result<R, PartialError>,
150) -> Result<R, PartialError> {
151    let (header_bytes, body_bytes) = validate_framing(bytes)?;
152
153    let mut aligned: AlignedVec<RKYV_ALIGN> = AlignedVec::with_capacity(header_bytes.len());
154    aligned.extend_from_slice(header_bytes);
155    let archived =
156        rkyv::access::<ArchivedWireEnvelopeHeader, RkyvError>(&aligned).map_err(|e| {
157            PartialError::Format {
158                kind: PartialFormatErrorKind::RkyvDecode {
159                    detail: format!("rkyv::access(header) failed: {e}"),
160                },
161            }
162        })?;
163    validate_header_fields(archived, expected)?;
164    body_callback(ValidatedView {
165        header: archived,
166        body_archive: body_bytes,
167    })
168}
169
170/// Validate magic / version / CRC / header_len framing. On success
171/// returns `(header_archive_bytes, body_archive_bytes)` slices of the
172/// input.
173fn validate_framing(bytes: &[u8]) -> Result<(&[u8], &[u8]), PartialError> {
174    if bytes.len() < MIN_PARTIAL_BYTES {
175        return Err(PartialError::Format {
176            kind: PartialFormatErrorKind::TooShort {
177                observed: bytes.len(),
178                minimum: MIN_PARTIAL_BYTES,
179            },
180        });
181    }
182
183    let magic: [u8; 4] = bytes[..4].try_into().map_err(|_| PartialError::Format {
184        kind: PartialFormatErrorKind::TooShort {
185            observed: bytes.len(),
186            minimum: MIN_PARTIAL_BYTES,
187        },
188    })?;
189    if magic != MAGIC {
190        return Err(PartialError::Format {
191            kind: PartialFormatErrorKind::WrongMagic { found: magic },
192        });
193    }
194
195    let version = bytes[4];
196    if version != FORMAT_VERSION {
197        return Err(PartialError::Format {
198            kind: PartialFormatErrorKind::WrongVersion {
199                expected: FORMAT_VERSION,
200                found: version,
201            },
202        });
203    }
204
205    // CRC must match before we trust any further byte read. Otherwise
206    // a corrupted header_len could point past the end of the buffer
207    // and we'd surface a confusing error.
208    let crc_split = bytes.len() - 4;
209    let stored_crc =
210        u32::from_le_bytes(
211            bytes[crc_split..]
212                .try_into()
213                .map_err(|_| PartialError::Format {
214                    kind: PartialFormatErrorKind::Crc,
215                })?,
216        );
217    let actual_crc = crc32fast::hash(&bytes[..crc_split]);
218    if stored_crc != actual_crc {
219        return Err(PartialError::Format {
220            kind: PartialFormatErrorKind::Crc,
221        });
222    }
223
224    let header_len =
225        u32::from_le_bytes(bytes[5..9].try_into().map_err(|_| PartialError::Format {
226            kind: PartialFormatErrorKind::TooShort {
227                observed: bytes.len(),
228                minimum: MIN_PARTIAL_BYTES,
229            },
230        })?) as usize;
231    let header_end = 9usize.saturating_add(header_len);
232    if header_end > crc_split {
233        return Err(PartialError::Format {
234            kind: PartialFormatErrorKind::TooShort {
235                observed: bytes.len(),
236                minimum: header_end + 4,
237            },
238        });
239    }
240
241    Ok((&bytes[9..header_end], &bytes[header_end..crc_split]))
242}
243
244fn validate_header_fields(
245    archived: &ArchivedWireEnvelopeHeader,
246    expected: &PartialExpectation,
247) -> Result<(), PartialError> {
248    let paradigm = archived.paradigm_kind;
249    if paradigm != expected.paradigm.as_u8() {
250        return Err(PartialError::Format {
251            kind: PartialFormatErrorKind::ParadigmMismatch {
252                expected: expected.paradigm.as_u8(),
253                found: paradigm,
254            },
255        });
256    }
257    // Paradigm matched; sub-kind discriminator must match too.
258    let discriminator = archived.discriminator.to_native();
259    if discriminator != expected.discriminator {
260        return Err(PartialError::Format {
261            kind: PartialFormatErrorKind::KernelMismatch {
262                expected: expected.discriminator,
263                found: discriminator,
264            },
265        });
266    }
267    let parity_mode = archived.parity_mode;
268    if parity_mode != expected.parity_mode {
269        return Err(PartialError::Format {
270            kind: PartialFormatErrorKind::ParityMismatch {
271                expected: expected.parity_mode,
272                found: parity_mode,
273            },
274        });
275    }
276    let fingerprint = [
277        archived.shape_fingerprint[0].to_native(),
278        archived.shape_fingerprint[1].to_native(),
279        archived.shape_fingerprint[2].to_native(),
280        archived.shape_fingerprint[3].to_native(),
281    ];
282    if fingerprint != expected.shape_fingerprint {
283        return Err(PartialError::Format {
284            kind: PartialFormatErrorKind::GridMismatch {
285                detail: format!(
286                    "expected {:?}, got {:?}",
287                    expected.shape_fingerprint, fingerprint
288                ),
289            },
290        });
291    }
292    let dataset_hash: [u8; 32] = archived.dataset_hash;
293    if dataset_hash != expected.dataset_hash {
294        return Err(PartialError::DatasetMismatch {
295            expected: expected.dataset_hash,
296            actual: dataset_hash,
297        });
298    }
299    let params_hash: [u8; 32] = archived.params_hash;
300    if params_hash != expected.params_hash {
301        return Err(PartialError::ParamsMismatch {
302            expected: expected.params_hash,
303            actual: params_hash,
304        });
305    }
306    Ok(())
307}
308
309/// Read the typed `rank_id` field off an archived header. Centralized
310/// because [`Option<u32>`]'s archived form is `ArchivedOption<u32>`
311/// which needs the `.as_ref().map(...)` pattern at every read site.
312pub fn rank_id_from_archive(header: &ArchivedWireEnvelopeHeader) -> Option<u32> {
313    header.rank_id.as_ref().map(|v| v.to_native())
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::ParadigmKind;

    fn fake_expectation() -> PartialExpectation {
        PartialExpectation {
            paradigm: ParadigmKind::Instance,
            discriminator: 0,
            parity_mode: 1,
            dataset_hash: [0xAB; 32],
            params_hash: [0xCD; 32],
            shape_fingerprint: [80, 4, 5000, 0],
            strict_mode: false,
        }
    }

    fn fake_header() -> WireEnvelopeHeader {
        WireEnvelopeHeader {
            paradigm_kind: ParadigmKind::Instance.as_u8(),
            discriminator: 0,
            parity_mode: 1,
            rank_id: None,
            dataset_hash: [0xAB; 32],
            params_hash: [0xCD; 32],
            shape_fingerprint: [80, 4, 5000, 0],
        }
    }

    /// Recompute the CRC trailer of a hand-edited blob, so a test can
    /// reach the checks that run after the CRC gate.
    fn reseal(bytes: &mut [u8]) {
        let split = bytes.len() - 4;
        let crc = crc32fast::hash(&bytes[..split]);
        bytes[split..].copy_from_slice(&crc.to_le_bytes());
    }

    #[test]
    fn round_trip_empty_body() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let exp = fake_expectation();
        with_validated_envelope(&bytes, &exp, |view| {
            assert!(view.body_archive.is_empty());
            assert_eq!(view.header.paradigm_kind, ParadigmKind::Instance.as_u8());
            Ok(())
        })
        .unwrap();
    }

    #[test]
    fn rejects_too_short() {
        let err = validate_framing(b"VRP").unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::TooShort { .. }
            }
        ));
    }

    #[test]
    fn rejects_wrong_magic() {
        let mut bytes = vec![0u8; MIN_PARTIAL_BYTES + 8];
        bytes[..4].copy_from_slice(b"FAKE");
        bytes[4] = FORMAT_VERSION;
        // The CRC is also wrong, but the magic check fires first.
        let err = validate_framing(&bytes).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::WrongMagic { .. }
            }
        ));
    }

    #[test]
    fn rejects_wrong_version() {
        let mut bytes = vec![0u8; MIN_PARTIAL_BYTES + 8];
        bytes[..4].copy_from_slice(&MAGIC);
        bytes[4] = 99;
        let err = validate_framing(&bytes).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::WrongVersion { .. }
            }
        ));
    }

    #[test]
    fn rejects_bad_crc() {
        let mut bytes = encode(&fake_header(), &[]).unwrap();
        let n = bytes.len();
        bytes[n - 1] ^= 0xFF;
        let exp = fake_expectation();
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::Crc
            }
        ));
    }

    #[test]
    fn rejects_header_len_past_end_even_with_valid_crc() {
        // The CRC is not a MAC: a crafted blob can carry a valid checksum
        // together with an absurd header length. It must be rejected
        // cleanly, never by panicking.
        let mut bytes = encode(&fake_header(), &[]).unwrap();
        bytes[5..9].copy_from_slice(&u32::MAX.to_le_bytes());
        reseal(&mut bytes);
        let err = validate_framing(&bytes).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::TooShort { .. }
            }
        ));
    }

    #[test]
    fn rejects_paradigm_mismatch() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let mut exp = fake_expectation();
        exp.paradigm = ParadigmKind::Semantic;
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        match err {
            PartialError::Format {
                kind: PartialFormatErrorKind::ParadigmMismatch { expected, found },
            } => {
                assert_eq!(expected, ParadigmKind::Semantic.as_u8());
                assert_eq!(found, ParadigmKind::Instance.as_u8());
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn rejects_discriminator_mismatch() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let mut exp = fake_expectation();
        exp.discriminator = 1;
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::KernelMismatch { .. }
            }
        ));
    }

    #[test]
    fn rejects_parity_mismatch() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let mut exp = fake_expectation();
        exp.parity_mode = 0;
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::ParityMismatch {
                    expected: 0,
                    found: 1
                }
            }
        ));
    }

    #[test]
    fn rejects_shape_fingerprint_mismatch() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let mut exp = fake_expectation();
        exp.shape_fingerprint = [80, 4, 5001, 0];
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::GridMismatch { .. }
            }
        ));
    }

    #[test]
    fn rejects_dataset_hash_mismatch() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let mut exp = fake_expectation();
        exp.dataset_hash = [0; 32];
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        assert!(matches!(err, PartialError::DatasetMismatch { .. }));
    }

    #[test]
    fn rejects_params_hash_mismatch() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let mut exp = fake_expectation();
        exp.params_hash = [0; 32];
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        assert!(matches!(err, PartialError::ParamsMismatch { .. }));
    }

    #[test]
    fn callback_not_invoked_on_validation_failure() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let mut exp = fake_expectation();
        exp.params_hash = [0; 32];
        let mut called = false;
        let result = with_validated_envelope(&bytes, &exp, |_| {
            called = true;
            Ok(())
        });
        assert!(result.is_err());
        assert!(!called, "callback must not run when validation fails");
    }

    #[test]
    fn rank_id_round_trips() {
        let mut header = fake_header();
        header.rank_id = Some(7);
        let bytes = encode(&header, &[]).unwrap();
        let exp = fake_expectation();
        let rank = with_validated_envelope(&bytes, &exp, |view| {
            Ok(rank_id_from_archive(view.header))
        })
        .unwrap();
        assert_eq!(rank, Some(7));
    }

    #[test]
    fn round_trip_with_body() {
        // From this crate's perspective the body archive is opaque bytes.
        let body = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b";
        let bytes = encode(&fake_header(), body).unwrap();
        let exp = fake_expectation();
        with_validated_envelope(&bytes, &exp, |view| {
            assert_eq!(view.body_archive, body);
            Ok(())
        })
        .unwrap();
    }
}