1use std::collections::HashMap;
59use std::io::{self, Read, Write};
60use std::sync::atomic::{AtomicU64, Ordering};
61use std::sync::Arc;
62
63use parking_lot::Mutex;
64
/// Identifier placed in a frame header to pair a reply with its request.
pub type RequestId = u64;

/// Identifier placed in a frame header to tie stream frames to one stream.
pub type StreamId = u64;
70
71#[derive(Debug, Clone, Copy, PartialEq, Eq)]
73#[repr(u8)]
74pub enum MessageType {
75 Request = 0,
77 Response = 1,
79 StreamStart = 2,
81 StreamData = 3,
83 StreamEnd = 4,
85 Error = 5,
87 FlowPause = 6,
89 FlowResume = 7,
91 Ping = 8,
93 Pong = 9,
95 Cancel = 10,
97}
98
99impl TryFrom<u8> for MessageType {
100 type Error = IpcError;
101
102 fn try_from(value: u8) -> Result<Self, <Self as TryFrom<u8>>::Error> {
103 match value {
104 0 => Ok(MessageType::Request),
105 1 => Ok(MessageType::Response),
106 2 => Ok(MessageType::StreamStart),
107 3 => Ok(MessageType::StreamData),
108 4 => Ok(MessageType::StreamEnd),
109 5 => Ok(MessageType::Error),
110 6 => Ok(MessageType::FlowPause),
111 7 => Ok(MessageType::FlowResume),
112 8 => Ok(MessageType::Ping),
113 9 => Ok(MessageType::Pong),
114 10 => Ok(MessageType::Cancel),
115 _ => Err(IpcError::InvalidMessageType(value)),
116 }
117 }
118}
119
120#[derive(Debug, Clone, Copy)]
122pub struct FrameHeader {
123 pub length: u32,
125 pub id: u64,
127 pub msg_type: MessageType,
129 pub flags: u8,
131}
132
133impl FrameHeader {
134 pub const SIZE: usize = 14; pub const MAX_PAYLOAD: u32 = 16 * 1024 * 1024;
139
140 pub fn new(id: u64, msg_type: MessageType, payload_len: usize) -> Self {
142 Self {
143 length: payload_len as u32,
144 id,
145 msg_type,
146 flags: 0,
147 }
148 }
149
150 pub fn to_bytes(&self) -> [u8; Self::SIZE] {
152 let mut buf = [0u8; Self::SIZE];
153 buf[0..4].copy_from_slice(&self.length.to_le_bytes());
154 buf[4..12].copy_from_slice(&self.id.to_le_bytes());
155 buf[12] = self.msg_type as u8;
156 buf[13] = self.flags;
157 buf
158 }
159
160 pub fn from_bytes(buf: &[u8; Self::SIZE]) -> Result<Self, IpcError> {
162 let length = u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]);
163 let id = u64::from_le_bytes([buf[4], buf[5], buf[6], buf[7], buf[8], buf[9], buf[10], buf[11]]);
164 let msg_type = MessageType::try_from(buf[12])?;
165 let flags = buf[13];
166
167 if length > Self::MAX_PAYLOAD {
168 return Err(IpcError::PayloadTooLarge(length as usize));
169 }
170
171 Ok(Self {
172 length,
173 id,
174 msg_type,
175 flags,
176 })
177 }
178}
179
180#[derive(Debug, Clone)]
182pub struct Frame {
183 pub header: FrameHeader,
184 pub payload: Vec<u8>,
185}
186
187impl Frame {
188 pub fn request(id: RequestId, payload: Vec<u8>) -> Self {
190 Self {
191 header: FrameHeader::new(id, MessageType::Request, payload.len()),
192 payload,
193 }
194 }
195
196 pub fn response(id: RequestId, payload: Vec<u8>) -> Self {
198 Self {
199 header: FrameHeader::new(id, MessageType::Response, payload.len()),
200 payload,
201 }
202 }
203
204 pub fn stream_start(id: StreamId, payload: Vec<u8>) -> Self {
206 Self {
207 header: FrameHeader::new(id, MessageType::StreamStart, payload.len()),
208 payload,
209 }
210 }
211
212 pub fn stream_data(id: StreamId, payload: Vec<u8>) -> Self {
214 Self {
215 header: FrameHeader::new(id, MessageType::StreamData, payload.len()),
216 payload,
217 }
218 }
219
220 pub fn stream_end(id: StreamId) -> Self {
222 Self {
223 header: FrameHeader::new(id, MessageType::StreamEnd, 0),
224 payload: Vec::new(),
225 }
226 }
227
228 pub fn error(id: RequestId, error_code: u32, message: &str) -> Self {
230 let mut payload = Vec::with_capacity(4 + message.len());
231 payload.extend_from_slice(&error_code.to_le_bytes());
232 payload.extend_from_slice(message.as_bytes());
233 Self {
234 header: FrameHeader::new(id, MessageType::Error, payload.len()),
235 payload,
236 }
237 }
238
239 pub fn ping(id: RequestId) -> Self {
241 Self {
242 header: FrameHeader::new(id, MessageType::Ping, 0),
243 payload: Vec::new(),
244 }
245 }
246
247 pub fn pong(id: RequestId) -> Self {
249 Self {
250 header: FrameHeader::new(id, MessageType::Pong, 0),
251 payload: Vec::new(),
252 }
253 }
254
255 pub fn cancel(id: RequestId) -> Self {
257 Self {
258 header: FrameHeader::new(id, MessageType::Cancel, 0),
259 payload: Vec::new(),
260 }
261 }
262
263 pub fn to_bytes(&self) -> Vec<u8> {
265 let mut buf = Vec::with_capacity(FrameHeader::SIZE + self.payload.len());
266 buf.extend_from_slice(&self.header.to_bytes());
267 buf.extend_from_slice(&self.payload);
268 buf
269 }
270}
271
272#[derive(Debug)]
274pub enum IpcError {
275 Io(io::Error),
276 InvalidMessageType(u8),
277 PayloadTooLarge(usize),
278 UnexpectedEof,
279 RequestCancelled(RequestId),
280 StreamClosed(StreamId),
281 Timeout,
282}
283
284impl From<io::Error> for IpcError {
285 fn from(e: io::Error) -> Self {
286 IpcError::Io(e)
287 }
288}
289
290impl std::fmt::Display for IpcError {
291 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
292 match self {
293 IpcError::Io(e) => write!(f, "IO error: {}", e),
294 IpcError::InvalidMessageType(t) => write!(f, "Invalid message type: {}", t),
295 IpcError::PayloadTooLarge(size) => write!(f, "Payload too large: {} bytes", size),
296 IpcError::UnexpectedEof => write!(f, "Unexpected end of stream"),
297 IpcError::RequestCancelled(id) => write!(f, "Request {} cancelled", id),
298 IpcError::StreamClosed(id) => write!(f, "Stream {} closed", id),
299 IpcError::Timeout => write!(f, "Operation timed out"),
300 }
301 }
302}
303
304impl std::error::Error for IpcError {}
305
306pub struct FrameReader<R: Read> {
308 reader: R,
309 header_buf: [u8; FrameHeader::SIZE],
310}
311
312impl<R: Read> FrameReader<R> {
313 pub fn new(reader: R) -> Self {
314 Self {
315 reader,
316 header_buf: [0u8; FrameHeader::SIZE],
317 }
318 }
319
320 pub fn read_frame(&mut self) -> Result<Frame, IpcError> {
322 self.reader.read_exact(&mut self.header_buf)?;
324 let header = FrameHeader::from_bytes(&self.header_buf)?;
325
326 let mut payload = vec![0u8; header.length as usize];
328 self.reader.read_exact(&mut payload)?;
329
330 Ok(Frame { header, payload })
331 }
332
333 pub fn into_inner(self) -> R {
335 self.reader
336 }
337}
338
339pub struct FrameWriter<W: Write> {
341 writer: W,
342 buffer: Vec<u8>,
344 max_buffer: usize,
346}
347
348impl<W: Write> FrameWriter<W> {
349 const DEFAULT_BUFFER: usize = 64 * 1024;
351
352 pub fn new(writer: W) -> Self {
353 Self {
354 writer,
355 buffer: Vec::with_capacity(Self::DEFAULT_BUFFER),
356 max_buffer: Self::DEFAULT_BUFFER,
357 }
358 }
359
360 pub fn write_frame(&mut self, frame: &Frame) -> Result<(), IpcError> {
362 let bytes = frame.to_bytes();
363
364 if bytes.len() > self.max_buffer {
366 self.flush()?;
367 self.writer.write_all(&bytes)?;
368 return Ok(());
369 }
370
371 if self.buffer.len() + bytes.len() > self.max_buffer {
373 self.flush()?;
374 }
375
376 self.buffer.extend_from_slice(&bytes);
377 Ok(())
378 }
379
380 pub fn flush(&mut self) -> Result<(), IpcError> {
382 if !self.buffer.is_empty() {
383 self.writer.write_all(&self.buffer)?;
384 self.buffer.clear();
385 }
386 self.writer.flush()?;
387 Ok(())
388 }
389
390 pub fn into_inner(self) -> W {
392 self.writer
393 }
394}
395
396struct PendingRequest {
398 callback: Box<dyn FnOnce(Result<Frame, IpcError>) + Send>,
399}
400
401pub struct RequestMultiplexer {
403 next_id: AtomicU64,
405 pending: Mutex<HashMap<RequestId, PendingRequest>>,
407 streams: Mutex<HashMap<StreamId, StreamState>>,
409}
410
411struct StreamState {
413 on_data: Box<dyn Fn(Vec<u8>) + Send>,
415 on_end: Box<dyn FnOnce() + Send>,
417 #[allow(dead_code)]
419 paused: bool,
420}
421
422impl Default for RequestMultiplexer {
423 fn default() -> Self {
424 Self::new()
425 }
426}
427
428impl RequestMultiplexer {
429 pub fn new() -> Self {
430 Self {
431 next_id: AtomicU64::new(1),
432 pending: Mutex::new(HashMap::new()),
433 streams: Mutex::new(HashMap::new()),
434 }
435 }
436
437 pub fn next_id(&self) -> RequestId {
439 self.next_id.fetch_add(1, Ordering::SeqCst)
440 }
441
442 pub fn register_request<F>(&self, id: RequestId, callback: F)
444 where
445 F: FnOnce(Result<Frame, IpcError>) + Send + 'static,
446 {
447 self.pending.lock().insert(
448 id,
449 PendingRequest {
450 callback: Box::new(callback),
451 },
452 );
453 }
454
455 pub fn register_stream<D, E>(&self, id: StreamId, on_data: D, on_end: E)
457 where
458 D: Fn(Vec<u8>) + Send + 'static,
459 E: FnOnce() + Send + 'static,
460 {
461 self.streams.lock().insert(
462 id,
463 StreamState {
464 on_data: Box::new(on_data),
465 on_end: Box::new(on_end),
466 paused: false,
467 },
468 );
469 }
470
471 pub fn handle_frame(&self, frame: Frame) {
473 match frame.header.msg_type {
474 MessageType::Response | MessageType::Error => {
475 if let Some(pending) = self.pending.lock().remove(&frame.header.id) {
476 (pending.callback)(Ok(frame));
477 }
478 }
479 MessageType::StreamData => {
480 if let Some(state) = self.streams.lock().get(&frame.header.id) {
481 (state.on_data)(frame.payload);
482 }
483 }
484 MessageType::StreamEnd => {
485 if let Some(state) = self.streams.lock().remove(&frame.header.id) {
486 (state.on_end)();
487 }
488 }
489 MessageType::Pong => {
490 }
492 _ => {
493 }
495 }
496 }
497
498 pub fn cancel(&self, id: RequestId) {
500 if let Some(pending) = self.pending.lock().remove(&id) {
501 (pending.callback)(Err(IpcError::RequestCancelled(id)));
502 }
503 if let Some(state) = self.streams.lock().remove(&id) {
504 (state.on_end)();
505 }
506 }
507
508 pub fn pending_count(&self) -> usize {
510 self.pending.lock().len()
511 }
512}
513
514pub struct BatchRequest {
516 requests: Vec<(RequestId, Vec<u8>)>,
517}
518
519impl Default for BatchRequest {
520 fn default() -> Self {
521 Self::new()
522 }
523}
524
525impl BatchRequest {
526 pub fn new() -> Self {
527 Self {
528 requests: Vec::new(),
529 }
530 }
531
532 pub fn add(&mut self, id: RequestId, payload: Vec<u8>) -> &mut Self {
534 self.requests.push((id, payload));
535 self
536 }
537
538 pub fn build(self) -> Vec<Frame> {
540 self.requests
541 .into_iter()
542 .map(|(id, payload)| Frame::request(id, payload))
543 .collect()
544 }
545
546 pub fn len(&self) -> usize {
548 self.requests.len()
549 }
550
551 pub fn is_empty(&self) -> bool {
553 self.requests.is_empty()
554 }
555}
556
/// Byte-window flow-control bookkeeping for one sender.
///
/// The sender may transmit while it is not paused and fewer than
/// `window_size` bytes are unacknowledged.
#[derive(Debug, Clone)]
pub struct FlowControl {
    /// Number of unacknowledged bytes at which sending pauses.
    pub window_size: usize,
    /// Bytes sent but not yet acknowledged.
    pub outstanding: usize,
    /// Whether sending is currently paused.
    pub paused: bool,
}

impl Default for FlowControl {
    /// Creates a controller with a 64 KiB window.
    fn default() -> Self {
        Self::new(Self::DEFAULT_WINDOW)
    }
}

impl FlowControl {
    /// Window used by [`FlowControl::default`], in bytes.
    const DEFAULT_WINDOW: usize = 64 * 1024;

    /// Creates an unpaused controller with nothing outstanding.
    pub fn new(window_size: usize) -> Self {
        Self {
            window_size,
            outstanding: 0,
            paused: false,
        }
    }

    /// Returns `true` when more data may be sent right now.
    pub fn can_send(&self) -> bool {
        !self.paused && self.outstanding < self.window_size
    }

    /// Records that `bytes` were sent; pauses once the window is full.
    ///
    /// The counter saturates at `usize::MAX` instead of overflowing.
    pub fn record_sent(&mut self, bytes: usize) {
        self.outstanding = self.outstanding.saturating_add(bytes);
        if self.outstanding >= self.window_size {
            self.paused = true;
        }
    }

    /// Records that `bytes` were acknowledged.
    ///
    /// Sending resumes once less than half the window is outstanding, which
    /// avoids toggling on every ack. It also resumes whenever nothing is
    /// outstanding: for a window of 0 or 1 the half-window test can never
    /// pass, and without this the controller would stay paused forever.
    pub fn record_acked(&mut self, bytes: usize) {
        self.outstanding = self.outstanding.saturating_sub(bytes);
        if self.outstanding == 0 || self.outstanding < self.window_size / 2 {
            self.paused = false;
        }
    }

    /// Pauses sending regardless of the window state.
    pub fn pause(&mut self) {
        self.paused = true;
    }

    /// Clears the paused flag; `can_send` still honours the window.
    pub fn resume(&mut self) {
        self.paused = false;
    }
}
618
619pub struct StreamWriter<W: Write> {
621 writer: Arc<Mutex<FrameWriter<W>>>,
622 stream_id: StreamId,
623 flow_control: FlowControl,
624}
625
626impl<W: Write> StreamWriter<W> {
627 pub fn new(writer: Arc<Mutex<FrameWriter<W>>>, stream_id: StreamId) -> Self {
628 Self {
629 writer,
630 stream_id,
631 flow_control: FlowControl::default(),
632 }
633 }
634
635 pub fn write_chunk(&mut self, data: Vec<u8>) -> Result<(), IpcError> {
637 while !self.flow_control.can_send() {
639 std::thread::yield_now();
640 }
641
642 let frame = Frame::stream_data(self.stream_id, data);
643 let size = frame.payload.len();
644
645 self.writer.lock().write_frame(&frame)?;
646 self.flow_control.record_sent(size);
647
648 Ok(())
649 }
650
651 pub fn finish(self) -> Result<(), IpcError> {
653 let frame = Frame::stream_end(self.stream_id);
654 let mut writer = self.writer.lock();
655 writer.write_frame(&frame)?;
656 writer.flush()
657 }
658}
659
660pub trait RequestHandler: Send + Sync {
662 fn handle_request(&self, request_id: RequestId, payload: &[u8]) -> Result<Vec<u8>, IpcError>;
664
665 fn handle_stream<W: Write>(
667 &self,
668 stream_id: StreamId,
669 payload: &[u8],
670 writer: StreamWriter<W>,
671 ) -> Result<(), IpcError>;
672}
673
674pub struct IpcServer<H: RequestHandler> {
676 handler: Arc<H>,
677}
678
679impl<H: RequestHandler> IpcServer<H> {
680 pub fn new(handler: H) -> Self {
681 Self {
682 handler: Arc::new(handler),
683 }
684 }
685
686 pub fn process<R: Read, W: Write>(
688 &self,
689 reader: &mut FrameReader<R>,
690 writer: Arc<Mutex<FrameWriter<W>>>,
691 ) -> Result<(), IpcError> {
692 loop {
693 let frame = match reader.read_frame() {
694 Ok(f) => f,
695 Err(IpcError::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => {
696 return Ok(()); }
698 Err(e) => return Err(e),
699 };
700
701 match frame.header.msg_type {
702 MessageType::Request => {
703 let response = match self.handler.handle_request(frame.header.id, &frame.payload)
704 {
705 Ok(data) => Frame::response(frame.header.id, data),
706 Err(e) => Frame::error(frame.header.id, 1, &e.to_string()),
707 };
708 writer.lock().write_frame(&response)?;
709 }
710 MessageType::StreamStart => {
711 let stream_writer = StreamWriter::new(Arc::clone(&writer), frame.header.id);
712 if let Err(e) =
713 self.handler
714 .handle_stream(frame.header.id, &frame.payload, stream_writer)
715 {
716 let err = Frame::error(frame.header.id, 2, &e.to_string());
717 writer.lock().write_frame(&err)?;
718 }
719 }
720 MessageType::Ping => {
721 let pong = Frame::pong(frame.header.id);
722 writer.lock().write_frame(&pong)?;
723 }
724 MessageType::Cancel => {
725 }
727 _ => {
728 }
730 }
731
732 writer.lock().flush()?;
734 }
735 }
736}
737
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;
    use std::sync::atomic::AtomicBool;

    /// Header fields survive an encode/decode round trip.
    #[test]
    fn test_frame_header_roundtrip() {
        let encoded = FrameHeader::new(12345, MessageType::Request, 100).to_bytes();
        let decoded = FrameHeader::from_bytes(&encoded).unwrap();

        assert_eq!(decoded.id, 12345);
        assert_eq!(decoded.msg_type, MessageType::Request);
        assert_eq!(decoded.length, 100);
    }

    /// A serialised frame can be read back through `FrameReader`.
    #[test]
    fn test_frame_roundtrip() {
        let wire = Frame::request(1, b"hello world".to_vec()).to_bytes();

        let decoded = FrameReader::new(Cursor::new(wire)).read_frame().unwrap();

        assert_eq!(decoded.header.id, 1);
        assert_eq!(decoded.header.msg_type, MessageType::Request);
        assert_eq!(decoded.payload, b"hello world");
    }

    /// `build` yields one frame per added request, in insertion order.
    #[test]
    fn test_batch_request() {
        let mut batch = BatchRequest::new();
        batch.add(1, b"request1".to_vec());
        batch.add(2, b"request2".to_vec());
        batch.add(3, b"request3".to_vec());

        let frames = batch.build();
        assert_eq!(frames.len(), 3);
        for (frame, expected_id) in frames.iter().zip(1u64..) {
            assert_eq!(frame.header.id, expected_id);
        }
    }

    /// Responses reach exactly the callback registered for their id.
    #[test]
    fn test_multiplexer() {
        let mux = RequestMultiplexer::new();

        let id1 = mux.next_id();
        let id2 = mux.next_id();
        assert_ne!(id1, id2);

        // Registers a callback for `id` and returns the flag it sets.
        let watch = |id: RequestId| {
            let flag = Arc::new(AtomicBool::new(false));
            let setter = Arc::clone(&flag);
            mux.register_request(id, move |_| {
                setter.store(true, Ordering::SeqCst);
            });
            flag
        };
        let received1 = watch(id1);
        let received2 = watch(id2);

        mux.handle_frame(Frame::response(id2, b"resp2".to_vec()));
        assert!(!received1.load(Ordering::SeqCst));
        assert!(received2.load(Ordering::SeqCst));

        mux.handle_frame(Frame::response(id1, b"resp1".to_vec()));
        assert!(received1.load(Ordering::SeqCst));
    }

    /// The window pauses when full and resumes below half.
    #[test]
    fn test_flow_control() {
        let mut fc = FlowControl::new(100);
        assert!(fc.can_send());

        fc.record_sent(50);
        assert!(fc.can_send());
        assert_eq!(fc.outstanding, 50);

        // 110 outstanding is past the 100-byte window.
        fc.record_sent(60);
        assert!(!fc.can_send());
        assert!(fc.paused);

        // 30 outstanding is below half the window.
        fc.record_acked(80);
        assert!(fc.can_send());
        assert!(!fc.paused);
    }

    /// Error frames carry a little-endian code followed by the message.
    #[test]
    fn test_error_frame() {
        let frame = Frame::error(42, 500, "Internal error");

        assert_eq!(frame.header.id, 42);
        assert_eq!(frame.header.msg_type, MessageType::Error);

        let (code_bytes, text_bytes) = frame.payload.split_at(4);
        let mut raw_code = [0u8; 4];
        raw_code.copy_from_slice(code_bytes);
        let error_code = u32::from_le_bytes(raw_code);
        let message = std::str::from_utf8(text_bytes).unwrap();

        assert_eq!(error_code, 500);
        assert_eq!(message, "Internal error");
    }
}