1use rkyv::rancor::Error as RkyvError;
22use rkyv::util::AlignedVec;
23
24use crate::error::{PartialError, PartialFormatErrorKind};
25use crate::traits::PartialExpectation;
26
/// Four-byte tag that opens every partial blob (`b"VRPS"`).
pub const MAGIC: [u8; 4] = [b'V', b'R', b'P', b'S'];

/// Wire-format version written by [`encode`]; the decoder rejects any other value.
pub const FORMAT_VERSION: u8 = 2;

/// Smallest possible framed blob: magic, version byte, `u32` header length and
/// the trailing `u32` CRC32 (an empty header and body add nothing).
const MIN_PARTIAL_BYTES: usize = MAGIC.len()
    + core::mem::size_of::<u8>()
    + core::mem::size_of::<u32>()
    + core::mem::size_of::<u32>();

/// Alignment of the scratch buffer the header archive is copied into before
/// it is handed to `rkyv::access`.
const RKYV_ALIGN: usize = 16;
53
54#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize, Debug, Clone)]
64pub struct WireEnvelopeHeader {
65 pub paradigm_kind: u8,
67 pub discriminator: u32,
70 pub parity_mode: u8,
72 pub rank_id: Option<u32>,
74 pub dataset_hash: [u8; 32],
76 pub params_hash: [u8; 32],
78 pub shape_fingerprint: [u32; 4],
83}
84
85pub fn encode(header: &WireEnvelopeHeader, body_archive: &[u8]) -> Result<Vec<u8>, PartialError> {
93 let header_archive = rkyv::to_bytes::<RkyvError>(header).map_err(|e| PartialError::Format {
94 kind: PartialFormatErrorKind::RkyvDecode {
95 detail: format!("rkyv::to_bytes(header) failed: {e}"),
96 },
97 })?;
98 let header_len = u32::try_from(header_archive.len()).map_err(|_| PartialError::Format {
99 kind: PartialFormatErrorKind::RkyvDecode {
100 detail: format!("header archive too large: {} bytes", header_archive.len()),
101 },
102 })?;
103
104 let mut out = Vec::with_capacity(MIN_PARTIAL_BYTES + header_archive.len() + body_archive.len());
105 out.extend_from_slice(&MAGIC);
106 out.push(FORMAT_VERSION);
107 out.extend_from_slice(&header_len.to_le_bytes());
108 out.extend_from_slice(&header_archive);
109 out.extend_from_slice(body_archive);
110 let crc = crc32fast::hash(&out);
111 out.extend_from_slice(&crc.to_le_bytes());
112 Ok(out)
113}
114
115pub struct ValidatedView<'a> {
120 pub header: &'a ArchivedWireEnvelopeHeader,
123 pub body_archive: &'a [u8],
126}
127
128pub fn with_validated_envelope<R>(
147 bytes: &[u8],
148 expected: &PartialExpectation,
149 body_callback: impl FnOnce(ValidatedView<'_>) -> Result<R, PartialError>,
150) -> Result<R, PartialError> {
151 let (header_bytes, body_bytes) = validate_framing(bytes)?;
152
153 let mut aligned: AlignedVec<RKYV_ALIGN> = AlignedVec::with_capacity(header_bytes.len());
154 aligned.extend_from_slice(header_bytes);
155 let archived =
156 rkyv::access::<ArchivedWireEnvelopeHeader, RkyvError>(&aligned).map_err(|e| {
157 PartialError::Format {
158 kind: PartialFormatErrorKind::RkyvDecode {
159 detail: format!("rkyv::access(header) failed: {e}"),
160 },
161 }
162 })?;
163 validate_header_fields(archived, expected)?;
164 body_callback(ValidatedView {
165 header: archived,
166 body_archive: body_bytes,
167 })
168}
169
170fn validate_framing(bytes: &[u8]) -> Result<(&[u8], &[u8]), PartialError> {
174 if bytes.len() < MIN_PARTIAL_BYTES {
175 return Err(PartialError::Format {
176 kind: PartialFormatErrorKind::TooShort {
177 observed: bytes.len(),
178 minimum: MIN_PARTIAL_BYTES,
179 },
180 });
181 }
182
183 let magic: [u8; 4] = bytes[..4].try_into().map_err(|_| PartialError::Format {
184 kind: PartialFormatErrorKind::TooShort {
185 observed: bytes.len(),
186 minimum: MIN_PARTIAL_BYTES,
187 },
188 })?;
189 if magic != MAGIC {
190 return Err(PartialError::Format {
191 kind: PartialFormatErrorKind::WrongMagic { found: magic },
192 });
193 }
194
195 let version = bytes[4];
196 if version != FORMAT_VERSION {
197 return Err(PartialError::Format {
198 kind: PartialFormatErrorKind::WrongVersion {
199 expected: FORMAT_VERSION,
200 found: version,
201 },
202 });
203 }
204
205 let crc_split = bytes.len() - 4;
209 let stored_crc =
210 u32::from_le_bytes(
211 bytes[crc_split..]
212 .try_into()
213 .map_err(|_| PartialError::Format {
214 kind: PartialFormatErrorKind::Crc,
215 })?,
216 );
217 let actual_crc = crc32fast::hash(&bytes[..crc_split]);
218 if stored_crc != actual_crc {
219 return Err(PartialError::Format {
220 kind: PartialFormatErrorKind::Crc,
221 });
222 }
223
224 let header_len =
225 u32::from_le_bytes(bytes[5..9].try_into().map_err(|_| PartialError::Format {
226 kind: PartialFormatErrorKind::TooShort {
227 observed: bytes.len(),
228 minimum: MIN_PARTIAL_BYTES,
229 },
230 })?) as usize;
231 let header_end = 9usize.saturating_add(header_len);
232 if header_end > crc_split {
233 return Err(PartialError::Format {
234 kind: PartialFormatErrorKind::TooShort {
235 observed: bytes.len(),
236 minimum: header_end + 4,
237 },
238 });
239 }
240
241 Ok((&bytes[9..header_end], &bytes[header_end..crc_split]))
242}
243
244fn validate_header_fields(
245 archived: &ArchivedWireEnvelopeHeader,
246 expected: &PartialExpectation,
247) -> Result<(), PartialError> {
248 let paradigm = archived.paradigm_kind;
249 if paradigm != expected.paradigm.as_u8() {
250 return Err(PartialError::Format {
251 kind: PartialFormatErrorKind::ParadigmMismatch {
252 expected: expected.paradigm.as_u8(),
253 found: paradigm,
254 },
255 });
256 }
257 let discriminator = archived.discriminator.to_native();
259 if discriminator != expected.discriminator {
260 return Err(PartialError::Format {
261 kind: PartialFormatErrorKind::KernelMismatch {
262 expected: expected.discriminator,
263 found: discriminator,
264 },
265 });
266 }
267 let parity_mode = archived.parity_mode;
268 if parity_mode != expected.parity_mode {
269 return Err(PartialError::Format {
270 kind: PartialFormatErrorKind::ParityMismatch {
271 expected: expected.parity_mode,
272 found: parity_mode,
273 },
274 });
275 }
276 let fingerprint = [
277 archived.shape_fingerprint[0].to_native(),
278 archived.shape_fingerprint[1].to_native(),
279 archived.shape_fingerprint[2].to_native(),
280 archived.shape_fingerprint[3].to_native(),
281 ];
282 if fingerprint != expected.shape_fingerprint {
283 return Err(PartialError::Format {
284 kind: PartialFormatErrorKind::GridMismatch {
285 detail: format!(
286 "expected {:?}, got {:?}",
287 expected.shape_fingerprint, fingerprint
288 ),
289 },
290 });
291 }
292 let dataset_hash: [u8; 32] = archived.dataset_hash;
293 if dataset_hash != expected.dataset_hash {
294 return Err(PartialError::DatasetMismatch {
295 expected: expected.dataset_hash,
296 actual: dataset_hash,
297 });
298 }
299 let params_hash: [u8; 32] = archived.params_hash;
300 if params_hash != expected.params_hash {
301 return Err(PartialError::ParamsMismatch {
302 expected: expected.params_hash,
303 actual: params_hash,
304 });
305 }
306 Ok(())
307}
308
309pub fn rank_id_from_archive(header: &ArchivedWireEnvelopeHeader) -> Option<u32> {
313 header.rank_id.as_ref().map(|v| v.to_native())
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::ParadigmKind;

    /// Expectation that matches `fake_header()` field for field.
    fn fake_expectation() -> PartialExpectation {
        PartialExpectation {
            paradigm: ParadigmKind::Instance,
            discriminator: 0,
            parity_mode: 1,
            dataset_hash: [0xAB; 32],
            params_hash: [0xCD; 32],
            shape_fingerprint: [80, 4, 5000, 0],
            strict_mode: false,
        }
    }

    /// Header that satisfies `fake_expectation()`.
    fn fake_header() -> WireEnvelopeHeader {
        WireEnvelopeHeader {
            paradigm_kind: ParadigmKind::Instance.as_u8(),
            discriminator: 0,
            parity_mode: 1,
            rank_id: None,
            dataset_hash: [0xAB; 32],
            params_hash: [0xCD; 32],
            shape_fingerprint: [80, 4, 5000, 0],
        }
    }

    #[test]
    fn round_trip_empty_body() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let exp = fake_expectation();
        with_validated_envelope(&bytes, &exp, |view| {
            assert!(view.body_archive.is_empty());
            assert_eq!(view.header.paradigm_kind, ParadigmKind::Instance.as_u8());
            Ok(())
        })
        .unwrap();
    }

    #[test]
    fn rejects_too_short() {
        let err = validate_framing(b"VRP").unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::TooShort { .. }
            }
        ));
    }

    #[test]
    fn rejects_wrong_magic() {
        let mut bytes = vec![0u8; MIN_PARTIAL_BYTES + 8];
        bytes[..4].copy_from_slice(b"FAKE");
        bytes[4] = FORMAT_VERSION;
        let err = validate_framing(&bytes).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::WrongMagic { .. }
            }
        ));
    }

    #[test]
    fn rejects_wrong_version() {
        let mut bytes = vec![0u8; MIN_PARTIAL_BYTES + 8];
        bytes[..4].copy_from_slice(&MAGIC);
        bytes[4] = 99;
        let err = validate_framing(&bytes).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::WrongVersion { .. }
            }
        ));
    }

    #[test]
    fn rejects_bad_crc() {
        let mut bytes = encode(&fake_header(), &[]).unwrap();
        let n = bytes.len();
        bytes[n - 1] ^= 0xFF;
        let exp = fake_expectation();
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::Crc
            }
        ));
    }

    #[test]
    fn rejects_corrupted_body() {
        let mut bytes = encode(&fake_header(), b"payload").unwrap();
        let n = bytes.len();
        // Last body byte, immediately before the 4-byte CRC.
        bytes[n - 5] ^= 0x01;
        let exp = fake_expectation();
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::Crc
            }
        ));
    }

    #[test]
    fn rejects_header_len_past_crc() {
        // Valid magic, version and CRC, but a declared header length that runs
        // far past the end of the blob. It must be rejected, not overflow.
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&MAGIC);
        bytes.push(FORMAT_VERSION);
        bytes.extend_from_slice(&u32::MAX.to_le_bytes());
        bytes.extend_from_slice(&[0u8; 8]);
        let crc = crc32fast::hash(&bytes);
        bytes.extend_from_slice(&crc.to_le_bytes());
        let err = validate_framing(&bytes).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::TooShort { .. }
            }
        ));
    }

    #[test]
    fn rejects_paradigm_mismatch() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let mut exp = fake_expectation();
        exp.paradigm = ParadigmKind::Semantic;
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        match err {
            PartialError::Format {
                kind: PartialFormatErrorKind::ParadigmMismatch { expected, found },
            } => {
                assert_eq!(expected, ParadigmKind::Semantic.as_u8());
                assert_eq!(found, ParadigmKind::Instance.as_u8());
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn rejects_discriminator_mismatch() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let mut exp = fake_expectation();
        exp.discriminator = 1;
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::KernelMismatch { .. }
            }
        ));
    }

    #[test]
    fn rejects_parity_mismatch() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let mut exp = fake_expectation();
        exp.parity_mode = 0;
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        match err {
            PartialError::Format {
                kind: PartialFormatErrorKind::ParityMismatch { expected, found },
            } => {
                assert_eq!(expected, 0);
                assert_eq!(found, 1);
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn rejects_shape_fingerprint_mismatch() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let mut exp = fake_expectation();
        exp.shape_fingerprint = [1, 2, 3, 4];
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        assert!(matches!(
            err,
            PartialError::Format {
                kind: PartialFormatErrorKind::GridMismatch { .. }
            }
        ));
    }

    #[test]
    fn rejects_dataset_hash_mismatch() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let mut exp = fake_expectation();
        exp.dataset_hash = [0; 32];
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        assert!(matches!(err, PartialError::DatasetMismatch { .. }));
    }

    #[test]
    fn rejects_params_hash_mismatch() {
        let bytes = encode(&fake_header(), &[]).unwrap();
        let mut exp = fake_expectation();
        exp.params_hash = [0; 32];
        let err = with_validated_envelope(&bytes, &exp, |_| Ok(())).unwrap_err();
        assert!(matches!(err, PartialError::ParamsMismatch { .. }));
    }

    #[test]
    fn round_trip_with_body() {
        let body = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b";
        let bytes = encode(&fake_header(), body).unwrap();
        let exp = fake_expectation();
        with_validated_envelope(&bytes, &exp, |view| {
            assert_eq!(view.body_archive, body);
            Ok(())
        })
        .unwrap();
    }

    #[test]
    fn rank_id_round_trips() {
        let mut header = fake_header();
        header.rank_id = Some(7);
        let bytes = encode(&header, &[]).unwrap();
        let exp = fake_expectation();
        let rank =
            with_validated_envelope(&bytes, &exp, |view| Ok(rank_id_from_archive(view.header)))
                .unwrap();
        assert_eq!(rank, Some(7));
    }
}