1use std::collections::HashMap;
56use std::io::{self, Read, Write};
57use std::sync::atomic::{AtomicU64, Ordering};
58use std::sync::Arc;
59
60use parking_lot::Mutex;
61
/// Identifier that ties a request to its response (or error) frame.
pub type RequestId = u64;

/// Identifier shared by every frame that belongs to one stream.
pub type StreamId = u64;

/// Kind of a frame; its `u8` discriminant is the byte stored at offset 12 of
/// the encoded [`FrameHeader`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum MessageType {
    /// A request; answered with `Response` or `Error` under the same id.
    Request = 0x00,
    /// Successful reply to a `Request`.
    Response = 0x01,
    /// Opens a stream; the payload is handed to the stream handler.
    StreamStart = 0x02,
    /// One chunk of stream data.
    StreamData = 0x03,
    /// Marks the end of a stream (empty payload).
    StreamEnd = 0x04,
    /// Failure report: 4-byte little-endian code followed by a UTF-8 message.
    Error = 0x05,
    /// Flow-control signal: pause.
    FlowPause = 0x06,
    /// Flow-control signal: resume.
    FlowResume = 0x07,
    /// Liveness probe; the server replies with `Pong` using the same id.
    Ping = 0x08,
    /// Reply to a `Ping`.
    Pong = 0x09,
    /// Cancellation notice for the identified request or stream.
    Cancel = 0x0A,
}
95
96impl TryFrom<u8> for MessageType {
97 type Error = IpcError;
98
99 fn try_from(value: u8) -> Result<Self, <Self as TryFrom<u8>>::Error> {
100 match value {
101 0 => Ok(MessageType::Request),
102 1 => Ok(MessageType::Response),
103 2 => Ok(MessageType::StreamStart),
104 3 => Ok(MessageType::StreamData),
105 4 => Ok(MessageType::StreamEnd),
106 5 => Ok(MessageType::Error),
107 6 => Ok(MessageType::FlowPause),
108 7 => Ok(MessageType::FlowResume),
109 8 => Ok(MessageType::Ping),
110 9 => Ok(MessageType::Pong),
111 10 => Ok(MessageType::Cancel),
112 _ => Err(IpcError::InvalidMessageType(value)),
113 }
114 }
115}
116
117#[derive(Debug, Clone, Copy)]
119pub struct FrameHeader {
120 pub length: u32,
122 pub id: u64,
124 pub msg_type: MessageType,
126 pub flags: u8,
128}
129
130impl FrameHeader {
131 pub const SIZE: usize = 14; pub const MAX_PAYLOAD: u32 = 16 * 1024 * 1024;
136
137 pub fn new(id: u64, msg_type: MessageType, payload_len: usize) -> Self {
139 Self {
140 length: payload_len as u32,
141 id,
142 msg_type,
143 flags: 0,
144 }
145 }
146
147 pub fn to_bytes(&self) -> [u8; Self::SIZE] {
149 let mut buf = [0u8; Self::SIZE];
150 buf[0..4].copy_from_slice(&self.length.to_le_bytes());
151 buf[4..12].copy_from_slice(&self.id.to_le_bytes());
152 buf[12] = self.msg_type as u8;
153 buf[13] = self.flags;
154 buf
155 }
156
157 pub fn from_bytes(buf: &[u8; Self::SIZE]) -> Result<Self, IpcError> {
159 let length = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
160 let id = u64::from_le_bytes([buf[4], buf[5], buf[6], buf[7], buf[8], buf[9], buf[10], buf[11]]);
161 let msg_type = MessageType::try_from(buf[12])?;
162 let flags = buf[13];
163
164 if length > Self::MAX_PAYLOAD {
165 return Err(IpcError::PayloadTooLarge(length as usize));
166 }
167
168 Ok(Self {
169 length,
170 id,
171 msg_type,
172 flags,
173 })
174 }
175}
176
177#[derive(Debug, Clone)]
179pub struct Frame {
180 pub header: FrameHeader,
181 pub payload: Vec<u8>,
182}
183
184impl Frame {
185 pub fn request(id: RequestId, payload: Vec<u8>) -> Self {
187 Self {
188 header: FrameHeader::new(id, MessageType::Request, payload.len()),
189 payload,
190 }
191 }
192
193 pub fn response(id: RequestId, payload: Vec<u8>) -> Self {
195 Self {
196 header: FrameHeader::new(id, MessageType::Response, payload.len()),
197 payload,
198 }
199 }
200
201 pub fn stream_start(id: StreamId, payload: Vec<u8>) -> Self {
203 Self {
204 header: FrameHeader::new(id, MessageType::StreamStart, payload.len()),
205 payload,
206 }
207 }
208
209 pub fn stream_data(id: StreamId, payload: Vec<u8>) -> Self {
211 Self {
212 header: FrameHeader::new(id, MessageType::StreamData, payload.len()),
213 payload,
214 }
215 }
216
217 pub fn stream_end(id: StreamId) -> Self {
219 Self {
220 header: FrameHeader::new(id, MessageType::StreamEnd, 0),
221 payload: Vec::new(),
222 }
223 }
224
225 pub fn error(id: RequestId, error_code: u32, message: &str) -> Self {
227 let mut payload = Vec::with_capacity(4 + message.len());
228 payload.extend_from_slice(&error_code.to_le_bytes());
229 payload.extend_from_slice(message.as_bytes());
230 Self {
231 header: FrameHeader::new(id, MessageType::Error, payload.len()),
232 payload,
233 }
234 }
235
236 pub fn ping(id: RequestId) -> Self {
238 Self {
239 header: FrameHeader::new(id, MessageType::Ping, 0),
240 payload: Vec::new(),
241 }
242 }
243
244 pub fn pong(id: RequestId) -> Self {
246 Self {
247 header: FrameHeader::new(id, MessageType::Pong, 0),
248 payload: Vec::new(),
249 }
250 }
251
252 pub fn cancel(id: RequestId) -> Self {
254 Self {
255 header: FrameHeader::new(id, MessageType::Cancel, 0),
256 payload: Vec::new(),
257 }
258 }
259
260 pub fn to_bytes(&self) -> Vec<u8> {
262 let mut buf = Vec::with_capacity(FrameHeader::SIZE + self.payload.len());
263 buf.extend_from_slice(&self.header.to_bytes());
264 buf.extend_from_slice(&self.payload);
265 buf
266 }
267}
268
269#[derive(Debug)]
271pub enum IpcError {
272 Io(io::Error),
273 InvalidMessageType(u8),
274 PayloadTooLarge(usize),
275 UnexpectedEof,
276 RequestCancelled(RequestId),
277 StreamClosed(StreamId),
278 Timeout,
279}
280
281impl From<io::Error> for IpcError {
282 fn from(e: io::Error) -> Self {
283 IpcError::Io(e)
284 }
285}
286
287impl std::fmt::Display for IpcError {
288 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289 match self {
290 IpcError::Io(e) => write!(f, "IO error: {}", e),
291 IpcError::InvalidMessageType(t) => write!(f, "Invalid message type: {}", t),
292 IpcError::PayloadTooLarge(size) => write!(f, "Payload too large: {} bytes", size),
293 IpcError::UnexpectedEof => write!(f, "Unexpected end of stream"),
294 IpcError::RequestCancelled(id) => write!(f, "Request {} cancelled", id),
295 IpcError::StreamClosed(id) => write!(f, "Stream {} closed", id),
296 IpcError::Timeout => write!(f, "Operation timed out"),
297 }
298 }
299}
300
301impl std::error::Error for IpcError {}
302
303pub struct FrameReader<R: Read> {
305 reader: R,
306 header_buf: [u8; FrameHeader::SIZE],
307}
308
309impl<R: Read> FrameReader<R> {
310 pub fn new(reader: R) -> Self {
311 Self {
312 reader,
313 header_buf: [0u8; FrameHeader::SIZE],
314 }
315 }
316
317 pub fn read_frame(&mut self) -> Result<Frame, IpcError> {
319 self.reader.read_exact(&mut self.header_buf)?;
321 let header = FrameHeader::from_bytes(&self.header_buf)?;
322
323 let mut payload = vec![0u8; header.length as usize];
325 self.reader.read_exact(&mut payload)?;
326
327 Ok(Frame { header, payload })
328 }
329
330 pub fn into_inner(self) -> R {
332 self.reader
333 }
334}
335
336pub struct FrameWriter<W: Write> {
338 writer: W,
339 buffer: Vec<u8>,
341 max_buffer: usize,
343}
344
345impl<W: Write> FrameWriter<W> {
346 const DEFAULT_BUFFER: usize = 64 * 1024;
348
349 pub fn new(writer: W) -> Self {
350 Self {
351 writer,
352 buffer: Vec::with_capacity(Self::DEFAULT_BUFFER),
353 max_buffer: Self::DEFAULT_BUFFER,
354 }
355 }
356
357 pub fn write_frame(&mut self, frame: &Frame) -> Result<(), IpcError> {
359 let bytes = frame.to_bytes();
360
361 if bytes.len() > self.max_buffer {
363 self.flush()?;
364 self.writer.write_all(&bytes)?;
365 return Ok(());
366 }
367
368 if self.buffer.len() + bytes.len() > self.max_buffer {
370 self.flush()?;
371 }
372
373 self.buffer.extend_from_slice(&bytes);
374 Ok(())
375 }
376
377 pub fn flush(&mut self) -> Result<(), IpcError> {
379 if !self.buffer.is_empty() {
380 self.writer.write_all(&self.buffer)?;
381 self.buffer.clear();
382 }
383 self.writer.flush()?;
384 Ok(())
385 }
386
387 pub fn into_inner(self) -> W {
389 self.writer
390 }
391}
392
393struct PendingRequest {
395 callback: Box<dyn FnOnce(Result<Frame, IpcError>) + Send>,
396}
397
398pub struct RequestMultiplexer {
400 next_id: AtomicU64,
402 pending: Mutex<HashMap<RequestId, PendingRequest>>,
404 streams: Mutex<HashMap<StreamId, StreamState>>,
406}
407
408struct StreamState {
410 on_data: Box<dyn Fn(Vec<u8>) + Send>,
412 on_end: Box<dyn FnOnce() + Send>,
414 #[allow(dead_code)]
416 paused: bool,
417}
418
419impl Default for RequestMultiplexer {
420 fn default() -> Self {
421 Self::new()
422 }
423}
424
425impl RequestMultiplexer {
426 pub fn new() -> Self {
427 Self {
428 next_id: AtomicU64::new(1),
429 pending: Mutex::new(HashMap::new()),
430 streams: Mutex::new(HashMap::new()),
431 }
432 }
433
434 pub fn next_id(&self) -> RequestId {
436 self.next_id.fetch_add(1, Ordering::SeqCst)
437 }
438
439 pub fn register_request<F>(&self, id: RequestId, callback: F)
441 where
442 F: FnOnce(Result<Frame, IpcError>) + Send + 'static,
443 {
444 self.pending.lock().insert(
445 id,
446 PendingRequest {
447 callback: Box::new(callback),
448 },
449 );
450 }
451
452 pub fn register_stream<D, E>(&self, id: StreamId, on_data: D, on_end: E)
454 where
455 D: Fn(Vec<u8>) + Send + 'static,
456 E: FnOnce() + Send + 'static,
457 {
458 self.streams.lock().insert(
459 id,
460 StreamState {
461 on_data: Box::new(on_data),
462 on_end: Box::new(on_end),
463 paused: false,
464 },
465 );
466 }
467
468 pub fn handle_frame(&self, frame: Frame) {
470 match frame.header.msg_type {
471 MessageType::Response | MessageType::Error => {
472 if let Some(pending) = self.pending.lock().remove(&frame.header.id) {
473 (pending.callback)(Ok(frame));
474 }
475 }
476 MessageType::StreamData => {
477 if let Some(state) = self.streams.lock().get(&frame.header.id) {
478 (state.on_data)(frame.payload);
479 }
480 }
481 MessageType::StreamEnd => {
482 if let Some(state) = self.streams.lock().remove(&frame.header.id) {
483 (state.on_end)();
484 }
485 }
486 MessageType::Pong => {
487 }
489 _ => {
490 }
492 }
493 }
494
495 pub fn cancel(&self, id: RequestId) {
497 if let Some(pending) = self.pending.lock().remove(&id) {
498 (pending.callback)(Err(IpcError::RequestCancelled(id)));
499 }
500 if let Some(state) = self.streams.lock().remove(&id) {
501 (state.on_end)();
502 }
503 }
504
505 pub fn pending_count(&self) -> usize {
507 self.pending.lock().len()
508 }
509}
510
511pub struct BatchRequest {
513 requests: Vec<(RequestId, Vec<u8>)>,
514}
515
516impl Default for BatchRequest {
517 fn default() -> Self {
518 Self::new()
519 }
520}
521
522impl BatchRequest {
523 pub fn new() -> Self {
524 Self {
525 requests: Vec::new(),
526 }
527 }
528
529 pub fn add(&mut self, id: RequestId, payload: Vec<u8>) -> &mut Self {
531 self.requests.push((id, payload));
532 self
533 }
534
535 pub fn build(self) -> Vec<Frame> {
537 self.requests
538 .into_iter()
539 .map(|(id, payload)| Frame::request(id, payload))
540 .collect()
541 }
542
543 pub fn len(&self) -> usize {
545 self.requests.len()
546 }
547
548 pub fn is_empty(&self) -> bool {
550 self.requests.is_empty()
551 }
552}
553
/// Byte-counting send window with hysteresis.
///
/// Sending pauses once `outstanding` reaches `window_size`. It resumes when
/// acknowledgements bring `outstanding` below half the window, or to zero.
#[derive(Debug, Clone)]
pub struct FlowControl {
    /// Maximum number of unacknowledged bytes.
    pub window_size: usize,
    /// Bytes recorded as sent but not yet acknowledged.
    pub outstanding: usize,
    /// Whether sending is paused, either by the window or by [`FlowControl::pause`].
    pub paused: bool,
}

impl Default for FlowControl {
    /// A window of [`FlowControl::DEFAULT_WINDOW`] bytes with nothing outstanding.
    fn default() -> Self {
        Self::new(Self::DEFAULT_WINDOW)
    }
}

impl FlowControl {
    /// Window size used by [`Default`] (64 KiB).
    pub const DEFAULT_WINDOW: usize = 64 * 1024;

    /// A window of `window_size` bytes, not paused, with nothing outstanding.
    pub fn new(window_size: usize) -> Self {
        Self {
            window_size,
            outstanding: 0,
            paused: false,
        }
    }

    /// `true` if not paused and the window still has room.
    pub fn can_send(&self) -> bool {
        !self.paused && self.outstanding < self.window_size
    }

    /// Records `bytes` as sent and pauses once the window is full.
    pub fn record_sent(&mut self, bytes: usize) {
        // Saturate instead of overflowing, which would panic in debug builds.
        self.outstanding = self.outstanding.saturating_add(bytes);
        if self.outstanding >= self.window_size {
            self.paused = true;
        }
    }

    /// Records `bytes` as acknowledged and resumes below the low-water mark.
    pub fn record_acked(&mut self, bytes: usize) {
        self.outstanding = self.outstanding.saturating_sub(bytes);
        // Low-water mark is half the window. A fully drained window always
        // resumes: for windows of 0 or 1 bytes the half-window mark is 0,
        // which `outstanding` can never be below, so the window would
        // otherwise stay paused forever.
        let low_water = self.window_size / 2;
        if self.outstanding == 0 || self.outstanding < low_water {
            self.paused = false;
        }
    }

    /// Pauses sending regardless of the window.
    pub fn pause(&mut self) {
        self.paused = true;
    }

    /// Clears the paused flag; the window limit still applies in `can_send`.
    pub fn resume(&mut self) {
        self.paused = false;
    }
}
615
616pub struct StreamWriter<W: Write> {
618 writer: Arc<Mutex<FrameWriter<W>>>,
619 stream_id: StreamId,
620 flow_control: FlowControl,
621}
622
623impl<W: Write> StreamWriter<W> {
624 pub fn new(writer: Arc<Mutex<FrameWriter<W>>>, stream_id: StreamId) -> Self {
625 Self {
626 writer,
627 stream_id,
628 flow_control: FlowControl::default(),
629 }
630 }
631
632 pub fn write_chunk(&mut self, data: Vec<u8>) -> Result<(), IpcError> {
634 while !self.flow_control.can_send() {
636 std::thread::yield_now();
637 }
638
639 let frame = Frame::stream_data(self.stream_id, data);
640 let size = frame.payload.len();
641
642 self.writer.lock().write_frame(&frame)?;
643 self.flow_control.record_sent(size);
644
645 Ok(())
646 }
647
648 pub fn finish(self) -> Result<(), IpcError> {
650 let frame = Frame::stream_end(self.stream_id);
651 let mut writer = self.writer.lock();
652 writer.write_frame(&frame)?;
653 writer.flush()
654 }
655}
656
657pub trait RequestHandler: Send + Sync {
659 fn handle_request(&self, request_id: RequestId, payload: &[u8]) -> Result<Vec<u8>, IpcError>;
661
662 fn handle_stream<W: Write>(
664 &self,
665 stream_id: StreamId,
666 payload: &[u8],
667 writer: StreamWriter<W>,
668 ) -> Result<(), IpcError>;
669}
670
671pub struct IpcServer<H: RequestHandler> {
673 handler: Arc<H>,
674}
675
676impl<H: RequestHandler> IpcServer<H> {
677 pub fn new(handler: H) -> Self {
678 Self {
679 handler: Arc::new(handler),
680 }
681 }
682
683 pub fn process<R: Read, W: Write>(
685 &self,
686 reader: &mut FrameReader<R>,
687 writer: Arc<Mutex<FrameWriter<W>>>,
688 ) -> Result<(), IpcError> {
689 loop {
690 let frame = match reader.read_frame() {
691 Ok(f) => f,
692 Err(IpcError::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => {
693 return Ok(()); }
695 Err(e) => return Err(e),
696 };
697
698 match frame.header.msg_type {
699 MessageType::Request => {
700 let response = match self.handler.handle_request(frame.header.id, &frame.payload)
701 {
702 Ok(data) => Frame::response(frame.header.id, data),
703 Err(e) => Frame::error(frame.header.id, 1, &e.to_string()),
704 };
705 writer.lock().write_frame(&response)?;
706 }
707 MessageType::StreamStart => {
708 let stream_writer = StreamWriter::new(Arc::clone(&writer), frame.header.id);
709 if let Err(e) =
710 self.handler
711 .handle_stream(frame.header.id, &frame.payload, stream_writer)
712 {
713 let err = Frame::error(frame.header.id, 2, &e.to_string());
714 writer.lock().write_frame(&err)?;
715 }
716 }
717 MessageType::Ping => {
718 let pong = Frame::pong(frame.header.id);
719 writer.lock().write_frame(&pong)?;
720 }
721 MessageType::Cancel => {
722 }
724 _ => {
725 }
727 }
728
729 writer.lock().flush()?;
731 }
732 }
733}
734
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use std::sync::atomic::AtomicBool;

    #[test]
    fn test_frame_header_roundtrip() {
        let encoded = FrameHeader::new(12345, MessageType::Request, 100).to_bytes();
        let decoded = FrameHeader::from_bytes(&encoded).expect("header should decode");

        assert_eq!(decoded.id, 12345);
        assert_eq!(decoded.msg_type, MessageType::Request);
        assert_eq!(decoded.length, 100);
    }

    #[test]
    fn test_frame_roundtrip() {
        let wire = Frame::request(1, b"hello world".to_vec()).to_bytes();
        let frame = FrameReader::new(Cursor::new(wire))
            .read_frame()
            .expect("frame should decode");

        assert_eq!(frame.header.id, 1);
        assert_eq!(frame.header.msg_type, MessageType::Request);
        assert_eq!(frame.payload, b"hello world");
    }

    #[test]
    fn test_batch_request() {
        let mut batch = BatchRequest::new();
        batch
            .add(1, b"request1".to_vec())
            .add(2, b"request2".to_vec())
            .add(3, b"request3".to_vec());

        let ids: Vec<RequestId> = batch.build().iter().map(|f| f.header.id).collect();
        assert_eq!(ids, vec![1, 2, 3]);
    }

    #[test]
    fn test_multiplexer() {
        let mux = RequestMultiplexer::new();
        let first = mux.next_id();
        let second = mux.next_id();
        assert_ne!(first, second);

        // Registers a callback for `id` and returns the flag it sets when run.
        let track = |id: RequestId| {
            let flag = Arc::new(AtomicBool::new(false));
            let seen = Arc::clone(&flag);
            mux.register_request(id, move |_| seen.store(true, Ordering::SeqCst));
            flag
        };
        let got_first = track(first);
        let got_second = track(second);

        mux.handle_frame(Frame::response(second, b"resp2".to_vec()));
        assert!(!got_first.load(Ordering::SeqCst));
        assert!(got_second.load(Ordering::SeqCst));

        mux.handle_frame(Frame::response(first, b"resp1".to_vec()));
        assert!(got_first.load(Ordering::SeqCst));
    }

    #[test]
    fn test_flow_control() {
        let mut window = FlowControl::new(100);
        assert!(window.can_send());

        window.record_sent(50);
        assert!(window.can_send());
        assert_eq!(window.outstanding, 50);

        window.record_sent(60);
        assert!(!window.can_send());
        assert!(window.paused);

        window.record_acked(80);
        assert!(window.can_send());
        assert!(!window.paused);
    }

    #[test]
    fn test_error_frame() {
        let frame = Frame::error(42, 500, "Internal error");
        assert_eq!(frame.header.id, 42);
        assert_eq!(frame.header.msg_type, MessageType::Error);

        let (code_bytes, text) = frame.payload.split_at(4);
        let code = u32::from_le_bytes(code_bytes.try_into().expect("4-byte error code"));
        assert_eq!(code, 500);
        assert_eq!(std::str::from_utf8(text).unwrap(), "Internal error");
    }
}