1use std::collections::HashMap;
59use std::io::{self, Read, Write};
60use std::sync::Arc;
61use std::sync::atomic::{AtomicU64, Ordering};
62
63use parking_lot::Mutex;
64
/// Identifier correlating a request frame with its response or error frame.
pub type RequestId = u64;

/// Identifier shared by every frame that belongs to one stream.
pub type StreamId = u64;
70
/// Kind of message carried by a frame; the discriminant is the wire byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum MessageType {
    /// A request; `IpcServer::process` answers it with `Response` or `Error`.
    Request = 0,
    /// Successful reply to a `Request`.
    Response = 1,
    /// Opens a stream; the server hands it to `RequestHandler::handle_stream`.
    StreamStart = 2,
    /// One chunk of stream payload.
    StreamData = 3,
    /// Marks the end of a stream (carries no payload).
    StreamEnd = 4,
    /// Error reply; payload is a little-endian `u32` code followed by message bytes.
    Error = 5,
    /// Flow-control signal; not acted on by the server or multiplexer in this file.
    FlowPause = 6,
    /// Flow-control signal; not acted on by the server or multiplexer in this file.
    FlowResume = 7,
    /// Liveness probe; the server answers with `Pong`.
    Ping = 8,
    /// Reply to `Ping`.
    Pong = 9,
    /// Cancellation notice for the frame's id (ignored by `IpcServer::process`).
    Cancel = 10,
}
98
99impl TryFrom<u8> for MessageType {
100 type Error = IpcError;
101
102 fn try_from(value: u8) -> Result<Self, <Self as TryFrom<u8>>::Error> {
103 match value {
104 0 => Ok(MessageType::Request),
105 1 => Ok(MessageType::Response),
106 2 => Ok(MessageType::StreamStart),
107 3 => Ok(MessageType::StreamData),
108 4 => Ok(MessageType::StreamEnd),
109 5 => Ok(MessageType::Error),
110 6 => Ok(MessageType::FlowPause),
111 7 => Ok(MessageType::FlowResume),
112 8 => Ok(MessageType::Ping),
113 9 => Ok(MessageType::Pong),
114 10 => Ok(MessageType::Cancel),
115 _ => Err(IpcError::InvalidMessageType(value)),
116 }
117 }
118}
119
120#[derive(Debug, Clone, Copy)]
122pub struct FrameHeader {
123 pub length: u32,
125 pub id: u64,
127 pub msg_type: MessageType,
129 pub flags: u8,
131}
132
133impl FrameHeader {
134 pub const SIZE: usize = 14; pub const MAX_PAYLOAD: u32 = 16 * 1024 * 1024;
139
140 pub fn new(id: u64, msg_type: MessageType, payload_len: usize) -> Self {
142 Self {
143 length: payload_len as u32,
144 id,
145 msg_type,
146 flags: 0,
147 }
148 }
149
150 pub fn to_bytes(&self) -> [u8; Self::SIZE] {
152 let mut buf = [0u8; Self::SIZE];
153 buf[0..4].copy_from_slice(&self.length.to_le_bytes());
154 buf[4..12].copy_from_slice(&self.id.to_le_bytes());
155 buf[12] = self.msg_type as u8;
156 buf[13] = self.flags;
157 buf
158 }
159
160 pub fn from_bytes(buf: &[u8; Self::SIZE]) -> Result<Self, IpcError> {
162 let length = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
163 let id = u64::from_le_bytes([
164 buf[4], buf[5], buf[6], buf[7], buf[8], buf[9], buf[10], buf[11],
165 ]);
166 let msg_type = MessageType::try_from(buf[12])?;
167 let flags = buf[13];
168
169 if length > Self::MAX_PAYLOAD {
170 return Err(IpcError::PayloadTooLarge(length as usize));
171 }
172
173 Ok(Self {
174 length,
175 id,
176 msg_type,
177 flags,
178 })
179 }
180}
181
182#[derive(Debug, Clone)]
184pub struct Frame {
185 pub header: FrameHeader,
186 pub payload: Vec<u8>,
187}
188
189impl Frame {
190 pub fn request(id: RequestId, payload: Vec<u8>) -> Self {
192 Self {
193 header: FrameHeader::new(id, MessageType::Request, payload.len()),
194 payload,
195 }
196 }
197
198 pub fn response(id: RequestId, payload: Vec<u8>) -> Self {
200 Self {
201 header: FrameHeader::new(id, MessageType::Response, payload.len()),
202 payload,
203 }
204 }
205
206 pub fn stream_start(id: StreamId, payload: Vec<u8>) -> Self {
208 Self {
209 header: FrameHeader::new(id, MessageType::StreamStart, payload.len()),
210 payload,
211 }
212 }
213
214 pub fn stream_data(id: StreamId, payload: Vec<u8>) -> Self {
216 Self {
217 header: FrameHeader::new(id, MessageType::StreamData, payload.len()),
218 payload,
219 }
220 }
221
222 pub fn stream_end(id: StreamId) -> Self {
224 Self {
225 header: FrameHeader::new(id, MessageType::StreamEnd, 0),
226 payload: Vec::new(),
227 }
228 }
229
230 pub fn error(id: RequestId, error_code: u32, message: &str) -> Self {
232 let mut payload = Vec::with_capacity(4 + message.len());
233 payload.extend_from_slice(&error_code.to_le_bytes());
234 payload.extend_from_slice(message.as_bytes());
235 Self {
236 header: FrameHeader::new(id, MessageType::Error, payload.len()),
237 payload,
238 }
239 }
240
241 pub fn ping(id: RequestId) -> Self {
243 Self {
244 header: FrameHeader::new(id, MessageType::Ping, 0),
245 payload: Vec::new(),
246 }
247 }
248
249 pub fn pong(id: RequestId) -> Self {
251 Self {
252 header: FrameHeader::new(id, MessageType::Pong, 0),
253 payload: Vec::new(),
254 }
255 }
256
257 pub fn cancel(id: RequestId) -> Self {
259 Self {
260 header: FrameHeader::new(id, MessageType::Cancel, 0),
261 payload: Vec::new(),
262 }
263 }
264
265 pub fn to_bytes(&self) -> Vec<u8> {
267 let mut buf = Vec::with_capacity(FrameHeader::SIZE + self.payload.len());
268 buf.extend_from_slice(&self.header.to_bytes());
269 buf.extend_from_slice(&self.payload);
270 buf
271 }
272}
273
274#[derive(Debug)]
276pub enum IpcError {
277 Io(io::Error),
278 InvalidMessageType(u8),
279 PayloadTooLarge(usize),
280 UnexpectedEof,
281 RequestCancelled(RequestId),
282 StreamClosed(StreamId),
283 Timeout,
284}
285
286impl From<io::Error> for IpcError {
287 fn from(e: io::Error) -> Self {
288 IpcError::Io(e)
289 }
290}
291
292impl std::fmt::Display for IpcError {
293 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
294 match self {
295 IpcError::Io(e) => write!(f, "IO error: {}", e),
296 IpcError::InvalidMessageType(t) => write!(f, "Invalid message type: {}", t),
297 IpcError::PayloadTooLarge(size) => write!(f, "Payload too large: {} bytes", size),
298 IpcError::UnexpectedEof => write!(f, "Unexpected end of stream"),
299 IpcError::RequestCancelled(id) => write!(f, "Request {} cancelled", id),
300 IpcError::StreamClosed(id) => write!(f, "Stream {} closed", id),
301 IpcError::Timeout => write!(f, "Operation timed out"),
302 }
303 }
304}
305
306impl std::error::Error for IpcError {}
307
308pub struct FrameReader<R: Read> {
310 reader: R,
311 header_buf: [u8; FrameHeader::SIZE],
312}
313
314impl<R: Read> FrameReader<R> {
315 pub fn new(reader: R) -> Self {
316 Self {
317 reader,
318 header_buf: [0u8; FrameHeader::SIZE],
319 }
320 }
321
322 pub fn read_frame(&mut self) -> Result<Frame, IpcError> {
324 self.reader.read_exact(&mut self.header_buf)?;
326 let header = FrameHeader::from_bytes(&self.header_buf)?;
327
328 let mut payload = vec![0u8; header.length as usize];
330 self.reader.read_exact(&mut payload)?;
331
332 Ok(Frame { header, payload })
333 }
334
335 pub fn into_inner(self) -> R {
337 self.reader
338 }
339}
340
341pub struct FrameWriter<W: Write> {
343 writer: W,
344 buffer: Vec<u8>,
346 max_buffer: usize,
348}
349
350impl<W: Write> FrameWriter<W> {
351 const DEFAULT_BUFFER: usize = 64 * 1024;
353
354 pub fn new(writer: W) -> Self {
355 Self {
356 writer,
357 buffer: Vec::with_capacity(Self::DEFAULT_BUFFER),
358 max_buffer: Self::DEFAULT_BUFFER,
359 }
360 }
361
362 pub fn write_frame(&mut self, frame: &Frame) -> Result<(), IpcError> {
364 let bytes = frame.to_bytes();
365
366 if bytes.len() > self.max_buffer {
368 self.flush()?;
369 self.writer.write_all(&bytes)?;
370 return Ok(());
371 }
372
373 if self.buffer.len() + bytes.len() > self.max_buffer {
375 self.flush()?;
376 }
377
378 self.buffer.extend_from_slice(&bytes);
379 Ok(())
380 }
381
382 pub fn flush(&mut self) -> Result<(), IpcError> {
384 if !self.buffer.is_empty() {
385 self.writer.write_all(&self.buffer)?;
386 self.buffer.clear();
387 }
388 self.writer.flush()?;
389 Ok(())
390 }
391
392 pub fn into_inner(self) -> W {
394 self.writer
395 }
396}
397
398struct PendingRequest {
400 callback: Box<dyn FnOnce(Result<Frame, IpcError>) + Send>,
401}
402
403pub struct RequestMultiplexer {
405 next_id: AtomicU64,
407 pending: Mutex<HashMap<RequestId, PendingRequest>>,
409 streams: Mutex<HashMap<StreamId, StreamState>>,
411}
412
413struct StreamState {
415 on_data: Box<dyn Fn(Vec<u8>) + Send>,
417 on_end: Box<dyn FnOnce() + Send>,
419 #[allow(dead_code)]
421 paused: bool,
422}
423
424impl Default for RequestMultiplexer {
425 fn default() -> Self {
426 Self::new()
427 }
428}
429
430impl RequestMultiplexer {
431 pub fn new() -> Self {
432 Self {
433 next_id: AtomicU64::new(1),
434 pending: Mutex::new(HashMap::new()),
435 streams: Mutex::new(HashMap::new()),
436 }
437 }
438
439 pub fn next_id(&self) -> RequestId {
441 self.next_id.fetch_add(1, Ordering::SeqCst)
442 }
443
444 pub fn register_request<F>(&self, id: RequestId, callback: F)
446 where
447 F: FnOnce(Result<Frame, IpcError>) + Send + 'static,
448 {
449 self.pending.lock().insert(
450 id,
451 PendingRequest {
452 callback: Box::new(callback),
453 },
454 );
455 }
456
457 pub fn register_stream<D, E>(&self, id: StreamId, on_data: D, on_end: E)
459 where
460 D: Fn(Vec<u8>) + Send + 'static,
461 E: FnOnce() + Send + 'static,
462 {
463 self.streams.lock().insert(
464 id,
465 StreamState {
466 on_data: Box::new(on_data),
467 on_end: Box::new(on_end),
468 paused: false,
469 },
470 );
471 }
472
473 pub fn handle_frame(&self, frame: Frame) {
475 match frame.header.msg_type {
476 MessageType::Response | MessageType::Error => {
477 if let Some(pending) = self.pending.lock().remove(&frame.header.id) {
478 (pending.callback)(Ok(frame));
479 }
480 }
481 MessageType::StreamData => {
482 if let Some(state) = self.streams.lock().get(&frame.header.id) {
483 (state.on_data)(frame.payload);
484 }
485 }
486 MessageType::StreamEnd => {
487 if let Some(state) = self.streams.lock().remove(&frame.header.id) {
488 (state.on_end)();
489 }
490 }
491 MessageType::Pong => {
492 }
494 _ => {
495 }
497 }
498 }
499
500 pub fn cancel(&self, id: RequestId) {
502 if let Some(pending) = self.pending.lock().remove(&id) {
503 (pending.callback)(Err(IpcError::RequestCancelled(id)));
504 }
505 if let Some(state) = self.streams.lock().remove(&id) {
506 (state.on_end)();
507 }
508 }
509
510 pub fn pending_count(&self) -> usize {
512 self.pending.lock().len()
513 }
514}
515
516pub struct BatchRequest {
518 requests: Vec<(RequestId, Vec<u8>)>,
519}
520
521impl Default for BatchRequest {
522 fn default() -> Self {
523 Self::new()
524 }
525}
526
527impl BatchRequest {
528 pub fn new() -> Self {
529 Self {
530 requests: Vec::new(),
531 }
532 }
533
534 pub fn add(&mut self, id: RequestId, payload: Vec<u8>) -> &mut Self {
536 self.requests.push((id, payload));
537 self
538 }
539
540 pub fn build(self) -> Vec<Frame> {
542 self.requests
543 .into_iter()
544 .map(|(id, payload)| Frame::request(id, payload))
545 .collect()
546 }
547
548 pub fn len(&self) -> usize {
550 self.requests.len()
551 }
552
553 pub fn is_empty(&self) -> bool {
555 self.requests.is_empty()
556 }
557}
558
/// Byte-count window used to throttle a sender.
///
/// `outstanding` counts bytes recorded as sent but not yet acknowledged.
/// Sending pauses once `outstanding` reaches the window. It resumes once
/// acknowledgements bring `outstanding` below half the window, or down to zero.
#[derive(Debug, Clone)]
pub struct FlowControl {
    /// Maximum number of unacknowledged bytes allowed in flight.
    pub window_size: usize,
    /// Bytes sent but not yet acknowledged.
    pub outstanding: usize,
    /// Whether sending is paused, by a full window or by an explicit `pause`.
    pub paused: bool,
}

impl Default for FlowControl {
    /// A 64 KiB window with nothing outstanding.
    fn default() -> Self {
        Self::new(64 * 1024)
    }
}

impl FlowControl {
    /// Creates an unpaused controller with the given window and nothing outstanding.
    pub fn new(window_size: usize) -> Self {
        Self {
            window_size,
            outstanding: 0,
            paused: false,
        }
    }

    /// Whether another send is allowed right now.
    pub fn can_send(&self) -> bool {
        !self.paused && self.outstanding < self.window_size
    }

    /// Records `bytes` as sent, pausing once the window is full.
    pub fn record_sent(&mut self, bytes: usize) {
        // Saturate: a plain `+=` panics on overflow in debug builds.
        self.outstanding = self.outstanding.saturating_add(bytes);
        if self.outstanding >= self.window_size {
            self.paused = true;
        }
    }

    /// Records `bytes` as acknowledged, resuming once enough has drained.
    pub fn record_acked(&mut self, bytes: usize) {
        self.outstanding = self.outstanding.saturating_sub(bytes);
        // Hysteresis: resume below half the window. The `== 0` clause covers
        // windows smaller than 2 bytes. There `window_size / 2 == 0`, so the
        // strict comparison never holds, and the sender would stay paused
        // forever even with nothing outstanding.
        if self.outstanding < self.window_size / 2 || self.outstanding == 0 {
            self.paused = false;
        }
    }

    /// Pauses sending until `resume` or sufficient acknowledgements.
    pub fn pause(&mut self) {
        self.paused = true;
    }

    /// Clears the paused flag. Sending still requires `outstanding < window_size`.
    pub fn resume(&mut self) {
        self.paused = false;
    }
}
620
621pub struct StreamWriter<W: Write> {
623 writer: Arc<Mutex<FrameWriter<W>>>,
624 stream_id: StreamId,
625 flow_control: FlowControl,
626}
627
628impl<W: Write> StreamWriter<W> {
629 pub fn new(writer: Arc<Mutex<FrameWriter<W>>>, stream_id: StreamId) -> Self {
630 Self {
631 writer,
632 stream_id,
633 flow_control: FlowControl::default(),
634 }
635 }
636
637 pub fn write_chunk(&mut self, data: Vec<u8>) -> Result<(), IpcError> {
639 while !self.flow_control.can_send() {
641 std::thread::yield_now();
642 }
643
644 let frame = Frame::stream_data(self.stream_id, data);
645 let size = frame.payload.len();
646
647 self.writer.lock().write_frame(&frame)?;
648 self.flow_control.record_sent(size);
649
650 Ok(())
651 }
652
653 pub fn finish(self) -> Result<(), IpcError> {
655 let frame = Frame::stream_end(self.stream_id);
656 let mut writer = self.writer.lock();
657 writer.write_frame(&frame)?;
658 writer.flush()
659 }
660}
661
662pub trait RequestHandler: Send + Sync {
664 fn handle_request(&self, request_id: RequestId, payload: &[u8]) -> Result<Vec<u8>, IpcError>;
666
667 fn handle_stream<W: Write>(
669 &self,
670 stream_id: StreamId,
671 payload: &[u8],
672 writer: StreamWriter<W>,
673 ) -> Result<(), IpcError>;
674}
675
676pub struct IpcServer<H: RequestHandler> {
678 handler: Arc<H>,
679}
680
681impl<H: RequestHandler> IpcServer<H> {
682 pub fn new(handler: H) -> Self {
683 Self {
684 handler: Arc::new(handler),
685 }
686 }
687
688 pub fn process<R: Read, W: Write>(
690 &self,
691 reader: &mut FrameReader<R>,
692 writer: Arc<Mutex<FrameWriter<W>>>,
693 ) -> Result<(), IpcError> {
694 loop {
695 let frame = match reader.read_frame() {
696 Ok(f) => f,
697 Err(IpcError::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => {
698 return Ok(()); }
700 Err(e) => return Err(e),
701 };
702
703 match frame.header.msg_type {
704 MessageType::Request => {
705 let response =
706 match self.handler.handle_request(frame.header.id, &frame.payload) {
707 Ok(data) => Frame::response(frame.header.id, data),
708 Err(e) => Frame::error(frame.header.id, 1, &e.to_string()),
709 };
710 writer.lock().write_frame(&response)?;
711 }
712 MessageType::StreamStart => {
713 let stream_writer = StreamWriter::new(Arc::clone(&writer), frame.header.id);
714 if let Err(e) =
715 self.handler
716 .handle_stream(frame.header.id, &frame.payload, stream_writer)
717 {
718 let err = Frame::error(frame.header.id, 2, &e.to_string());
719 writer.lock().write_frame(&err)?;
720 }
721 }
722 MessageType::Ping => {
723 let pong = Frame::pong(frame.header.id);
724 writer.lock().write_frame(&pong)?;
725 }
726 MessageType::Cancel => {
727 }
729 _ => {
730 }
732 }
733
734 writer.lock().flush()?;
736 }
737 }
738}
739
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use std::sync::atomic::AtomicBool;

    /// Builds a request callback that sets `flag` when it is invoked.
    fn flag_setter(flag: &Arc<AtomicBool>) -> impl FnOnce(Result<Frame, IpcError>) + Send + 'static {
        let flag = Arc::clone(flag);
        move |_| flag.store(true, Ordering::SeqCst)
    }

    #[test]
    fn test_frame_header_roundtrip() {
        let encoded = FrameHeader::new(12345, MessageType::Request, 100).to_bytes();
        let decoded = FrameHeader::from_bytes(&encoded).unwrap();

        assert_eq!(
            (decoded.id, decoded.msg_type, decoded.length),
            (12345, MessageType::Request, 100)
        );
    }

    #[test]
    fn test_frame_roundtrip() {
        let wire = Frame::request(1, b"hello world".to_vec()).to_bytes();
        let frame = FrameReader::new(Cursor::new(wire)).read_frame().unwrap();

        assert_eq!(frame.header.id, 1);
        assert_eq!(frame.header.msg_type, MessageType::Request);
        assert_eq!(frame.payload, b"hello world");
    }

    #[test]
    fn test_batch_request() {
        let mut batch = BatchRequest::new();
        batch
            .add(1, b"request1".to_vec())
            .add(2, b"request2".to_vec())
            .add(3, b"request3".to_vec());

        let ids: Vec<RequestId> = batch.build().iter().map(|f| f.header.id).collect();
        assert_eq!(ids, vec![1, 2, 3]);
    }

    #[test]
    fn test_multiplexer() {
        let mux = RequestMultiplexer::new();
        let (first, second) = (mux.next_id(), mux.next_id());
        assert_ne!(first, second);

        let got_first = Arc::new(AtomicBool::new(false));
        let got_second = Arc::new(AtomicBool::new(false));
        mux.register_request(first, flag_setter(&got_first));
        mux.register_request(second, flag_setter(&got_second));

        // Answering the second request must not fire the first callback.
        mux.handle_frame(Frame::response(second, b"resp2".to_vec()));
        assert!(!got_first.load(Ordering::SeqCst));
        assert!(got_second.load(Ordering::SeqCst));

        mux.handle_frame(Frame::response(first, b"resp1".to_vec()));
        assert!(got_first.load(Ordering::SeqCst));
    }

    #[test]
    fn test_flow_control() {
        let mut fc = FlowControl::new(100);
        assert!(fc.can_send());

        fc.record_sent(50);
        assert!(fc.can_send());
        assert_eq!(fc.outstanding, 50);

        // 110 >= 100: the window is now full.
        fc.record_sent(60);
        assert!(!fc.can_send());
        assert!(fc.paused);

        // 30 < 100 / 2: acknowledgements reopen the window.
        fc.record_acked(80);
        assert!(fc.can_send());
        assert!(!fc.paused);
    }

    #[test]
    fn test_error_frame() {
        let frame = Frame::error(42, 500, "Internal error");
        assert_eq!(frame.header.id, 42);
        assert_eq!(frame.header.msg_type, MessageType::Error);

        // Payload layout: 4-byte little-endian code, then the UTF-8 message.
        let (code_bytes, text) = frame.payload.split_at(4);
        let code = u32::from_le_bytes(code_bytes.try_into().unwrap());

        assert_eq!(code, 500);
        assert_eq!(std::str::from_utf8(text).unwrap(), "Internal error");
    }
}