1use alloc::string::{String, ToString};
11use alloc::vec::Vec;
12use base64::Engine;
13use core::fmt;
14
/// Reserved value for a string-length slot in the `u32` section.
///
/// `EncodedData::push_str` refuses to encode a string whose byte length equals this
/// value, so it never appears as a real length. NOTE(review): presumably it marks a
/// reference to a previously cached string on the wire — confirm against its readers.
pub(crate) const CACHED_STRING_SENTINEL: u32 = u32::MAX;
16
/// Everything that can go wrong while decoding an IPC message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecodeError {
    /// The input is shorter than the fixed-size header.
    MessageTooShort {
        /// Minimum number of bytes required.
        expected: usize,
        /// Number of bytes actually supplied.
        actual: usize,
    },
    /// A `u8` was requested but the `u8` section is exhausted.
    U8BufferEmpty,
    /// A `u16` was requested but the `u16` section is exhausted.
    U16BufferEmpty,
    /// A `u32` was requested but the `u32` section is exhausted.
    U32BufferEmpty,
    /// A string's encoded length exceeds the bytes left in the string section.
    StringBufferTooShort {
        /// Length announced for the string, in bytes.
        expected: usize,
        /// Bytes that were left in the string section.
        actual: usize,
    },
    /// The string payload is not valid UTF-8.
    InvalidUtf8 {
        /// Byte offset within the string where the invalid sequence starts.
        position: usize,
    },
    /// The leading message-type byte is not a known `MessageType`.
    InvalidMessageType {
        /// The unrecognized tag byte.
        value: u8,
    },
    /// The header's section offsets are inconsistent with each other or the input.
    InvalidHeaderOffsets {
        /// Start of the `u16` section (end of the `u32` section).
        u16_offset: u32,
        /// Start of the `u8` section (end of the `u16` section).
        u8_offset: u32,
        /// Start of the string section (end of the `u8` section).
        str_offset: u32,
        /// Total length of the input, in bytes.
        total_len: usize,
    },
    /// Any other failure, described by a free-form message.
    Custom(String),
}
44
45impl fmt::Display for DecodeError {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 match self {
48 DecodeError::MessageTooShort { expected, actual } => {
49 write!(
50 f,
51 "message too short: expected at least {expected} bytes, got {actual}"
52 )
53 }
54 DecodeError::U8BufferEmpty => write!(f, "u8 buffer empty when trying to read"),
55 DecodeError::U16BufferEmpty => write!(f, "u16 buffer empty when trying to read"),
56 DecodeError::U32BufferEmpty => write!(f, "u32 buffer empty when trying to read"),
57 DecodeError::StringBufferTooShort { expected, actual } => {
58 write!(
59 f,
60 "string buffer too short: expected {expected} bytes, got {actual}"
61 )
62 }
63 DecodeError::InvalidUtf8 { position } => {
64 write!(f, "invalid UTF-8 at position {position}")
65 }
66 DecodeError::InvalidMessageType { value } => {
67 write!(f, "invalid message type: {value}")
68 }
69 DecodeError::InvalidHeaderOffsets {
70 u16_offset,
71 u8_offset,
72 str_offset,
73 total_len,
74 } => {
75 write!(
76 f,
77 "invalid header offsets: u16={u16_offset}, u8={u8_offset}, str={str_offset}, total_len={total_len}"
78 )
79 }
80 DecodeError::Custom(msg) => write!(f, "{msg}"),
81 }
82 }
83}
84
85impl core::error::Error for DecodeError {}
86
87impl From<DecodeError> for String {
88 fn from(err: DecodeError) -> String {
89 err.to_string()
90 }
91}
92
/// Kind of an IPC message, stored as the first byte of the message's `u8` section.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub(crate) enum MessageType {
    /// Tag byte `0`.
    Evaluate = 0,
    /// Tag byte `1`.
    Respond = 1,
}
102
103#[derive(Debug, Clone)]
105pub(crate) struct OutboundIPCMessage {
106 pub(crate) message: IPCMessage,
107 pub(crate) top_level: bool,
116}
117
118impl OutboundIPCMessage {
119 pub(crate) fn new(message: IPCMessage, top_level: bool) -> Self {
120 Self { message, top_level }
121 }
122}
123
124#[derive(Debug, Clone)]
152pub(crate) struct IPCMessage {
153 data: Vec<u8>,
154}
155
156impl IPCMessage {
157 pub fn new(data: Vec<u8>) -> Self {
159 Self { data }
160 }
161
162 pub fn ty(&self) -> Result<MessageType, DecodeError> {
164 let mut decoded = DecodedData::from_bytes(&self.data)?;
165 let message_type = decoded.take_u8()?;
166 match message_type {
167 0 => Ok(MessageType::Evaluate),
168 1 => Ok(MessageType::Respond),
169 v => Err(DecodeError::InvalidMessageType { value: v }),
170 }
171 }
172
173 pub fn decoded(&self) -> Result<DecodedVariant<'_>, DecodeError> {
175 let mut decoded = DecodedData::from_bytes(&self.data)?;
176 let message_type = decoded.take_u8()?;
177 let message_type = match message_type {
178 0 => DecodedVariant::Evaluate { data: decoded },
179 1 => DecodedVariant::Respond { data: decoded },
180 v => return Err(DecodeError::InvalidMessageType { value: v }),
181 };
182 Ok(message_type)
183 }
184
185 pub fn data(&self) -> &[u8] {
187 &self.data
188 }
189
190 pub fn into_data(self) -> Vec<u8> {
192 self.data
193 }
194}
195
/// A parsed message, split by its type tag.
///
/// The `data` cursor has already consumed the tag byte.
#[derive(Debug)]
pub(crate) enum DecodedVariant<'a> {
    /// A message tagged `1` (`MessageType::Respond`).
    Respond { data: DecodedData<'a> },
    /// A message tagged `0` (`MessageType::Evaluate`).
    Evaluate { data: DecodedData<'a> },
}

/// Zero-copy read cursor over the four sections of an encoded message.
///
/// Each `take_*` call consumes from the front of its own section.
#[derive(Debug)]
pub struct DecodedData<'a> {
    /// Remaining single bytes.
    u8_buf: &'a [u8],
    /// Remaining `u16` words.
    u16_buf: &'a [u16],
    /// Remaining `u32` words (also carries string lengths and split `u64`/`u128`s).
    u32_buf: &'a [u32],
    /// Remaining concatenated UTF-8 string payloads.
    str_buf: &'a [u8],
}
213
214impl<'a> DecodedData<'a> {
215 pub(crate) fn from_bytes(bytes: &'a [u8]) -> Result<Self, DecodeError> {
217 if bytes.len() < 12 {
218 return Err(DecodeError::MessageTooShort {
219 expected: 12,
220 actual: bytes.len(),
221 });
222 }
223
224 let header: [u32; 3] = bytemuck::cast_slice(&bytes[0..12])
225 .try_into()
226 .map_err(|_| DecodeError::Custom("failed to parse header".to_string()))?;
227 let [u16_offset, u8_offset, str_offset] = header;
228
229 let total_len = bytes.len();
231 if u16_offset as usize > total_len
232 || u8_offset as usize > total_len
233 || str_offset as usize > total_len
234 || u16_offset < 12
235 || u8_offset < u16_offset
236 || str_offset < u8_offset
237 {
238 return Err(DecodeError::InvalidHeaderOffsets {
239 u16_offset,
240 u8_offset,
241 str_offset,
242 total_len,
243 });
244 }
245
246 let u32_buf = bytemuck::cast_slice(&bytes[12..u16_offset as usize]);
247 let u16_buf = bytemuck::cast_slice(&bytes[u16_offset as usize..u8_offset as usize]);
248 let u8_buf = &bytes[u8_offset as usize..str_offset as usize];
249 let str_buf = &bytes[str_offset as usize..];
250
251 Ok(Self {
252 u8_buf,
253 u16_buf,
254 u32_buf,
255 str_buf,
256 })
257 }
258
259 pub(crate) fn take_u8(&mut self) -> Result<u8, DecodeError> {
261 let [first, rest @ ..] = &self.u8_buf else {
262 return Err(DecodeError::U8BufferEmpty);
263 };
264 self.u8_buf = rest;
265 Ok(*first)
266 }
267
268 pub(crate) fn take_u16(&mut self) -> Result<u16, DecodeError> {
270 let [first, rest @ ..] = &self.u16_buf else {
271 return Err(DecodeError::U16BufferEmpty);
272 };
273 self.u16_buf = rest;
274 Ok(*first)
275 }
276
277 pub(crate) fn take_u32(&mut self) -> Result<u32, DecodeError> {
279 let [first, rest @ ..] = &self.u32_buf else {
280 return Err(DecodeError::U32BufferEmpty);
281 };
282 self.u32_buf = rest;
283 Ok(*first)
284 }
285
286 pub(crate) fn take_u64(&mut self) -> Result<u64, DecodeError> {
288 let low = self.take_u32()? as u64;
289 let high = self.take_u32()? as u64;
290 Ok((high << 32) | low)
291 }
292
293 pub(crate) fn take_u128(&mut self) -> Result<u128, DecodeError> {
295 let low = self.take_u64()? as u128;
296 let high = self.take_u64()? as u128;
297 Ok((high << 64) | low)
298 }
299
300 pub(crate) fn take_str(&mut self) -> Result<&'a str, DecodeError> {
302 let len = self.take_u32()? as usize;
303 let actual_len = self.str_buf.len();
304 let Some((buf, rem)) = self.str_buf.split_at_checked(len) else {
305 return Err(DecodeError::StringBufferTooShort {
306 expected: len,
307 actual: actual_len,
308 });
309 };
310 let s = core::str::from_utf8(buf).map_err(|e| DecodeError::InvalidUtf8 {
311 position: e.valid_up_to(),
312 })?;
313 self.str_buf = rem;
314 Ok(s)
315 }
316
317 pub(crate) fn is_empty(&self) -> bool {
319 self.u8_buf.is_empty()
320 && self.u16_buf.is_empty()
321 && self.u32_buf.is_empty()
322 && self.str_buf.is_empty()
323 }
324}
325
326#[derive(Debug, Default)]
328pub struct EncodedData {
329 pub(crate) u8_buf: Vec<u8>,
330 pub(crate) u16_buf: Vec<u16>,
331 pub(crate) u32_buf: Vec<u32>,
332 pub(crate) str_buf: Vec<u8>,
333 pub(crate) heap_ids_to_recycle_after_flush: Vec<u64>,
334 pub(crate) pending_type_ids: Vec<u32>,
338 pub(crate) needs_flush: bool,
341}
342
343impl EncodedData {
344 pub fn new() -> Self {
346 Self {
347 u8_buf: Vec::new(),
348 u16_buf: Vec::new(),
349 u32_buf: Vec::new(),
350 str_buf: Vec::new(),
351 heap_ids_to_recycle_after_flush: Vec::new(),
352 pending_type_ids: Vec::new(),
353 needs_flush: false,
354 }
355 }
356
357 pub fn mark_needs_flush(&mut self) {
360 self.needs_flush = true;
361 }
362
363 pub(crate) fn register_pending_type_id(&mut self, type_id: u32) {
366 self.pending_type_ids.push(type_id);
367 }
368
369 pub(crate) fn take_pending_type_ids(&mut self) -> Vec<u32> {
370 core::mem::take(&mut self.pending_type_ids)
371 }
372
373 pub(crate) fn defer_heap_id_recycle_until_flush(&mut self, id: u64) {
374 self.heap_ids_to_recycle_after_flush.push(id);
375 }
376
377 pub(crate) fn take_heap_ids_to_recycle_after_flush(&mut self) -> Vec<u64> {
378 core::mem::take(&mut self.heap_ids_to_recycle_after_flush)
379 }
380
381 pub(crate) fn byte_len(&self) -> usize {
383 12 + self.u32_buf.len() * 4
384 + self.u16_buf.len() * 2
385 + self.u8_buf.len()
386 + self.str_buf.len()
387 }
388
389 pub(crate) fn push_u8(&mut self, value: u8) {
391 self.u8_buf.push(value);
392 }
393
394 pub(crate) fn push_u16(&mut self, value: u16) {
396 self.u16_buf.push(value);
397 }
398
399 pub(crate) fn push_u32(&mut self, value: u32) {
401 self.u32_buf.push(value);
402 }
403
404 pub(crate) fn insert_u32s(&mut self, index: usize, values: &[u32]) {
406 let index = index.min(self.u32_buf.len());
407 let mut u32_buf = Vec::with_capacity(values.len() + self.u32_buf.len());
408 u32_buf.extend_from_slice(&self.u32_buf[..index]);
409 u32_buf.extend_from_slice(values);
410 u32_buf.extend_from_slice(&self.u32_buf[index..]);
411 self.u32_buf = u32_buf;
412 }
413
414 pub(crate) fn push_u64(&mut self, value: u64) {
416 self.push_u32((value & 0xFFFFFFFF) as u32);
417 self.push_u32((value >> 32) as u32);
418 }
419
420 pub(crate) fn push_u128(&mut self, value: u128) {
422 self.push_u64((value & 0xFFFFFFFFFFFFFFFF) as u64);
423 self.push_u64((value >> 64) as u64);
424 }
425
426 pub(crate) fn push_str(&mut self, value: &str) {
428 let len = u32::try_from(value.len()).expect("string length exceeds u32::MAX");
429 assert_ne!(
430 len, CACHED_STRING_SENTINEL,
431 "string length conflicts with cached string sentinel"
432 );
433 self.push_u32(len);
434 self.str_buf.extend_from_slice(value.as_bytes());
435 }
436
437 pub(crate) fn to_bytes(&self) -> Vec<u8> {
439 let u16_offset = 12 + self.u32_buf.len() * 4;
440 let u8_offset = u16_offset + self.u16_buf.len() * 2;
441 let str_offset = u8_offset + self.u8_buf.len();
442
443 let total_len = str_offset + self.str_buf.len();
444 let mut bytes = Vec::with_capacity(total_len);
445
446 bytes.extend_from_slice(&(u16_offset as u32).to_le_bytes());
448 bytes.extend_from_slice(&(u8_offset as u32).to_le_bytes());
449 bytes.extend_from_slice(&(str_offset as u32).to_le_bytes());
450
451 for &u in &self.u32_buf {
453 bytes.extend_from_slice(&u.to_le_bytes());
454 }
455
456 for &u in &self.u16_buf {
458 bytes.extend_from_slice(&u.to_le_bytes());
459 }
460
461 bytes.extend_from_slice(&self.u8_buf);
463
464 bytes.extend_from_slice(&self.str_buf);
466
467 bytes
468 }
469
470 pub(crate) fn extend(&mut self, other: &EncodedData) {
472 self.u8_buf.extend_from_slice(&other.u8_buf);
473 self.u16_buf.extend_from_slice(&other.u16_buf);
474 self.u32_buf.extend_from_slice(&other.u32_buf);
475 self.str_buf.extend_from_slice(&other.str_buf);
476 self.heap_ids_to_recycle_after_flush
477 .extend_from_slice(&other.heap_ids_to_recycle_after_flush);
478 self.needs_flush |= other.needs_flush;
479 }
480}
481
482pub(crate) fn decode_data(bytes: &[u8]) -> Option<IPCMessage> {
484 let engine = base64::engine::general_purpose::STANDARD;
485 let data = engine.decode(bytes).ok()?;
486 Some(IPCMessage { data })
487}
488
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn message_header_only_carries_message_type() {
        let mut encoder = EncodedData::new();
        encoder.push_u8(MessageType::Evaluate as u8);
        encoder.push_u32(99);

        let msg = IPCMessage::new(encoder.to_bytes());
        assert_eq!(msg.ty().unwrap(), MessageType::Evaluate);

        let DecodedVariant::Evaluate { mut data, .. } = msg.decoded().unwrap() else {
            panic!("expected Evaluate message");
        };
        assert_eq!(data.take_u32().unwrap(), 99);
    }

    #[test]
    fn deferred_recycle_ids_are_encoder_local() {
        let mut queued = EncodedData::new();
        queued.defer_heap_id_recycle_until_flush(10);

        let mut unrelated = EncodedData::new();
        unrelated.defer_heap_id_recycle_until_flush(20);

        assert_eq!(unrelated.take_heap_ids_to_recycle_after_flush(), vec![20]);
        assert_eq!(queued.take_heap_ids_to_recycle_after_flush(), vec![10]);
    }

    #[test]
    fn deferred_recycle_ids_extend_with_encoder_data() {
        let mut outer = EncodedData::new();
        outer.defer_heap_id_recycle_until_flush(10);

        let mut encoded_during_op = EncodedData::new();
        encoded_during_op.defer_heap_id_recycle_until_flush(20);

        outer.extend(&encoded_during_op);

        assert_eq!(outer.take_heap_ids_to_recycle_after_flush(), vec![10, 20]);
    }

    /// Every value kind written by `EncodedData` reads back unchanged.
    #[test]
    fn scalars_and_strings_round_trip() {
        let mut encoder = EncodedData::new();
        encoder.push_u8(MessageType::Respond as u8);
        encoder.push_u16(0xBEEF);
        encoder.push_u64(0x0123_4567_89AB_CDEF);
        encoder.push_u128(u128::MAX - 1);
        encoder.push_str("héllo");
        encoder.push_str("");

        let bytes = encoder.to_bytes();
        assert_eq!(bytes.len(), encoder.byte_len());

        let msg = IPCMessage::new(bytes);
        assert_eq!(msg.ty().unwrap(), MessageType::Respond);
        let DecodedVariant::Respond { mut data } = msg.decoded().unwrap() else {
            panic!("expected Respond message");
        };
        assert_eq!(data.take_u16().unwrap(), 0xBEEF);
        assert_eq!(data.take_u64().unwrap(), 0x0123_4567_89AB_CDEF);
        assert_eq!(data.take_u128().unwrap(), u128::MAX - 1);
        assert_eq!(data.take_str().unwrap(), "héllo");
        assert_eq!(data.take_str().unwrap(), "");
        assert!(data.is_empty());
        assert_eq!(data.take_u8(), Err(DecodeError::U8BufferEmpty));
    }

    /// Tags other than 0/1 are reported rather than misclassified.
    #[test]
    fn unknown_message_type_is_rejected() {
        let mut encoder = EncodedData::new();
        encoder.push_u8(7);
        let msg = IPCMessage::new(encoder.to_bytes());
        assert_eq!(msg.ty(), Err(DecodeError::InvalidMessageType { value: 7 }));
        assert!(matches!(
            msg.decoded(),
            Err(DecodeError::InvalidMessageType { value: 7 })
        ));
    }

    /// Inputs shorter than the header fail cleanly.
    #[test]
    fn truncated_message_is_rejected() {
        let err = DecodedData::from_bytes(&[0u8; 4]).unwrap_err();
        assert_eq!(err, DecodeError::MessageTooShort { expected: 12, actual: 4 });
    }

    /// Out-of-order header offsets fail cleanly.
    #[test]
    fn out_of_order_offsets_are_rejected() {
        let mut bytes = Vec::new();
        for word in [12u32, 8, 12] {
            bytes.extend_from_slice(&word.to_le_bytes());
        }
        let err = DecodedData::from_bytes(&bytes).unwrap_err();
        assert_eq!(
            err,
            DecodeError::InvalidHeaderOffsets {
                u16_offset: 12,
                u8_offset: 8,
                str_offset: 12,
                total_len: 12,
            }
        );
    }
}