1#[derive(Debug, thiserror::Error)]
6pub enum ChannelError {
7 #[error("I/O error: {0}")]
9 Io(#[from] std::io::Error),
10
11 #[error("channel closed")]
13 ChannelClosed,
14
15 #[error("confirmation cancelled")]
17 ConfirmCancelled,
18
19 #[error("{0}")]
21 Other(String),
22}
23
/// Broad media category of an [`Attachment`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AttachmentKind {
    /// Audio data.
    Audio,
    /// Image data.
    Image,
    /// Video data.
    Video,
    /// Any other file payload.
    File,
}
32
33#[derive(Debug, Clone)]
35pub struct Attachment {
36 pub kind: AttachmentKind,
37 pub data: Vec<u8>,
38 pub filename: Option<String>,
39}
40
41#[derive(Debug, Clone)]
43pub struct ChannelMessage {
44 pub text: String,
45 pub attachments: Vec<Attachment>,
46}
47
48pub trait Channel: Send {
50 fn recv(&mut self)
56 -> impl Future<Output = Result<Option<ChannelMessage>, ChannelError>> + Send;
57
58 fn try_recv(&mut self) -> Option<ChannelMessage> {
60 None
61 }
62
63 fn supports_exit(&self) -> bool {
68 true
69 }
70
71 fn send(&mut self, text: &str) -> impl Future<Output = Result<(), ChannelError>> + Send;
77
78 fn send_chunk(&mut self, chunk: &str) -> impl Future<Output = Result<(), ChannelError>> + Send;
84
85 fn flush_chunks(&mut self) -> impl Future<Output = Result<(), ChannelError>> + Send;
91
92 fn send_typing(&mut self) -> impl Future<Output = Result<(), ChannelError>> + Send {
98 async { Ok(()) }
99 }
100
101 fn send_status(
107 &mut self,
108 _text: &str,
109 ) -> impl Future<Output = Result<(), ChannelError>> + Send {
110 async { Ok(()) }
111 }
112
113 fn send_thinking_chunk(
119 &mut self,
120 _chunk: &str,
121 ) -> impl Future<Output = Result<(), ChannelError>> + Send {
122 async { Ok(()) }
123 }
124
125 fn send_queue_count(
131 &mut self,
132 _count: usize,
133 ) -> impl Future<Output = Result<(), ChannelError>> + Send {
134 async { Ok(()) }
135 }
136
137 fn send_usage(
143 &mut self,
144 _input_tokens: u64,
145 _output_tokens: u64,
146 _context_window: u64,
147 ) -> impl Future<Output = Result<(), ChannelError>> + Send {
148 async { Ok(()) }
149 }
150
151 fn send_diff(
157 &mut self,
158 _diff: crate::DiffData,
159 ) -> impl Future<Output = Result<(), ChannelError>> + Send {
160 async { Ok(()) }
161 }
162
163 fn send_tool_start(
173 &mut self,
174 _tool_name: &str,
175 _tool_call_id: &str,
176 _params: Option<serde_json::Value>,
177 _parent_tool_use_id: Option<String>,
178 ) -> impl Future<Output = Result<(), ChannelError>> + Send {
179 async { Ok(()) }
180 }
181
182 #[allow(clippy::too_many_arguments)]
193 fn send_tool_output(
194 &mut self,
195 tool_name: &str,
196 body: &str,
197 _diff: Option<crate::DiffData>,
198 _filter_stats: Option<String>,
199 _kept_lines: Option<Vec<usize>>,
200 _locations: Option<Vec<String>>,
201 _tool_call_id: &str,
202 _is_error: bool,
203 _parent_tool_use_id: Option<String>,
204 _raw_response: Option<serde_json::Value>,
205 _started_at: Option<std::time::Instant>,
206 ) -> impl Future<Output = Result<(), ChannelError>> + Send {
207 let formatted = crate::agent::format_tool_output(tool_name, body);
208 async move { self.send(&formatted).await }
209 }
210
211 fn confirm(
218 &mut self,
219 _prompt: &str,
220 ) -> impl Future<Output = Result<bool, ChannelError>> + Send {
221 async { Ok(true) }
222 }
223
224 fn send_stop_hint(
233 &mut self,
234 _hint: StopHint,
235 ) -> impl Future<Output = Result<(), ChannelError>> + Send {
236 async { Ok(()) }
237 }
238}
239
/// Reason reported through [`Channel::send_stop_hint`] when a turn ends early.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StopHint {
    /// Stopped at a token limit (presumably the model's output-token cap — confirm with callers).
    MaxTokens,
    /// Stopped after reaching a limit on turn requests (semantics defined by the caller).
    MaxTurnRequests,
}
252
253#[derive(Debug, Clone)]
255pub enum LoopbackEvent {
256 Chunk(String),
257 Flush,
258 FullMessage(String),
259 Status(String),
260 ToolStart {
262 tool_name: String,
263 tool_call_id: String,
264 params: Option<serde_json::Value>,
266 parent_tool_use_id: Option<String>,
268 started_at: std::time::Instant,
270 },
271 ToolOutput {
272 tool_name: String,
273 display: String,
274 diff: Option<crate::DiffData>,
275 filter_stats: Option<String>,
276 kept_lines: Option<Vec<usize>>,
277 locations: Option<Vec<String>>,
278 tool_call_id: String,
279 is_error: bool,
280 terminal_id: Option<String>,
282 parent_tool_use_id: Option<String>,
284 raw_response: Option<serde_json::Value>,
286 started_at: Option<std::time::Instant>,
288 },
289 Usage {
291 input_tokens: u64,
292 output_tokens: u64,
293 context_window: u64,
294 },
295 SessionTitle(String),
297 Plan(Vec<(String, PlanItemStatus)>),
299 ThinkingChunk(String),
301 Stop(StopHint),
305}
306
/// Progress state of one entry in a [`LoopbackEvent::Plan`].
///
/// The other field-less enums in this module ([`AttachmentKind`], [`StopHint`]) are `Copy` and
/// comparable with `==`, and this enum now is too. Statuses can be copied out of a plan and
/// compared directly instead of through `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlanItemStatus {
    /// Not started yet.
    Pending,
    /// Currently being worked on.
    InProgress,
    /// Finished.
    Completed,
}
314
315pub struct LoopbackHandle {
317 pub input_tx: tokio::sync::mpsc::Sender<ChannelMessage>,
318 pub output_rx: tokio::sync::mpsc::Receiver<LoopbackEvent>,
319 pub cancel_signal: std::sync::Arc<tokio::sync::Notify>,
321}
322
323pub struct LoopbackChannel {
325 input_rx: tokio::sync::mpsc::Receiver<ChannelMessage>,
326 output_tx: tokio::sync::mpsc::Sender<LoopbackEvent>,
327}
328
329impl LoopbackChannel {
330 #[must_use]
332 pub fn pair(buffer: usize) -> (Self, LoopbackHandle) {
333 let (input_tx, input_rx) = tokio::sync::mpsc::channel(buffer);
334 let (output_tx, output_rx) = tokio::sync::mpsc::channel(buffer);
335 let cancel_signal = std::sync::Arc::new(tokio::sync::Notify::new());
336 (
337 Self {
338 input_rx,
339 output_tx,
340 },
341 LoopbackHandle {
342 input_tx,
343 output_rx,
344 cancel_signal,
345 },
346 )
347 }
348}
349
350impl Channel for LoopbackChannel {
351 fn supports_exit(&self) -> bool {
352 false
353 }
354
355 async fn recv(&mut self) -> Result<Option<ChannelMessage>, ChannelError> {
356 Ok(self.input_rx.recv().await)
357 }
358
359 async fn send(&mut self, text: &str) -> Result<(), ChannelError> {
360 self.output_tx
361 .send(LoopbackEvent::FullMessage(text.to_owned()))
362 .await
363 .map_err(|_| ChannelError::ChannelClosed)
364 }
365
366 async fn send_chunk(&mut self, chunk: &str) -> Result<(), ChannelError> {
367 self.output_tx
368 .send(LoopbackEvent::Chunk(chunk.to_owned()))
369 .await
370 .map_err(|_| ChannelError::ChannelClosed)
371 }
372
373 async fn flush_chunks(&mut self) -> Result<(), ChannelError> {
374 self.output_tx
375 .send(LoopbackEvent::Flush)
376 .await
377 .map_err(|_| ChannelError::ChannelClosed)
378 }
379
380 async fn send_status(&mut self, text: &str) -> Result<(), ChannelError> {
381 self.output_tx
382 .send(LoopbackEvent::Status(text.to_owned()))
383 .await
384 .map_err(|_| ChannelError::ChannelClosed)
385 }
386
387 async fn send_thinking_chunk(&mut self, chunk: &str) -> Result<(), ChannelError> {
388 self.output_tx
389 .send(LoopbackEvent::ThinkingChunk(chunk.to_owned()))
390 .await
391 .map_err(|_| ChannelError::ChannelClosed)
392 }
393
394 async fn send_tool_start(
395 &mut self,
396 tool_name: &str,
397 tool_call_id: &str,
398 params: Option<serde_json::Value>,
399 parent_tool_use_id: Option<String>,
400 ) -> Result<(), ChannelError> {
401 self.output_tx
402 .send(LoopbackEvent::ToolStart {
403 tool_name: tool_name.to_owned(),
404 tool_call_id: tool_call_id.to_owned(),
405 params,
406 parent_tool_use_id,
407 started_at: std::time::Instant::now(),
408 })
409 .await
410 .map_err(|_| ChannelError::ChannelClosed)
411 }
412
413 #[allow(clippy::too_many_arguments)]
414 async fn send_tool_output(
415 &mut self,
416 tool_name: &str,
417 body: &str,
418 diff: Option<crate::DiffData>,
419 filter_stats: Option<String>,
420 kept_lines: Option<Vec<usize>>,
421 locations: Option<Vec<String>>,
422 tool_call_id: &str,
423 is_error: bool,
424 parent_tool_use_id: Option<String>,
425 raw_response: Option<serde_json::Value>,
426 started_at: Option<std::time::Instant>,
427 ) -> Result<(), ChannelError> {
428 self.output_tx
429 .send(LoopbackEvent::ToolOutput {
430 tool_name: tool_name.to_owned(),
431 display: body.to_owned(),
432 diff,
433 filter_stats,
434 kept_lines,
435 locations,
436 tool_call_id: tool_call_id.to_owned(),
437 is_error,
438 terminal_id: None,
439 parent_tool_use_id,
440 raw_response,
441 started_at,
442 })
443 .await
444 .map_err(|_| ChannelError::ChannelClosed)
445 }
446
447 async fn confirm(&mut self, _prompt: &str) -> Result<bool, ChannelError> {
448 Ok(true)
449 }
450
451 async fn send_stop_hint(&mut self, hint: StopHint) -> Result<(), ChannelError> {
452 self.output_tx
453 .send(LoopbackEvent::Stop(hint))
454 .await
455 .map_err(|_| ChannelError::ChannelClosed)
456 }
457
458 async fn send_usage(
459 &mut self,
460 input_tokens: u64,
461 output_tokens: u64,
462 context_window: u64,
463 ) -> Result<(), ChannelError> {
464 self.output_tx
465 .send(LoopbackEvent::Usage {
466 input_tokens,
467 output_tokens,
468 context_window,
469 })
470 .await
471 .map_err(|_| ChannelError::ChannelClosed)
472 }
473}
474
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a text-only inbound message.
    fn text_message(text: &str) -> ChannelMessage {
        ChannelMessage {
            text: text.to_owned(),
            attachments: vec![],
        }
    }

    #[test]
    fn channel_message_creation() {
        let msg = text_message("hello");
        assert_eq!(msg.text, "hello");
        assert!(msg.attachments.is_empty());
    }

    /// Minimal [`Channel`] that implements only the required methods, so every other call
    /// exercises the trait's default bodies.
    struct StubChannel;

    impl Channel for StubChannel {
        async fn recv(&mut self) -> Result<Option<ChannelMessage>, ChannelError> {
            Ok(None)
        }

        async fn send(&mut self, _text: &str) -> Result<(), ChannelError> {
            Ok(())
        }

        async fn send_chunk(&mut self, _chunk: &str) -> Result<(), ChannelError> {
            Ok(())
        }

        async fn flush_chunks(&mut self) -> Result<(), ChannelError> {
            Ok(())
        }
    }

    // `send_chunk` and `flush_chunks` have no trait defaults; these exercise the stub's own
    // implementations.
    #[tokio::test]
    async fn stub_channel_send_chunk_ok() {
        let mut ch = StubChannel;
        ch.send_chunk("partial").await.unwrap();
    }

    #[tokio::test]
    async fn stub_channel_flush_chunks_ok() {
        let mut ch = StubChannel;
        ch.flush_chunks().await.unwrap();
    }

    #[tokio::test]
    async fn stub_channel_confirm_auto_approves() {
        let mut ch = StubChannel;
        let result = ch.confirm("Delete everything?").await.unwrap();
        assert!(result);
    }

    #[tokio::test]
    async fn stub_channel_send_typing_default() {
        let mut ch = StubChannel;
        ch.send_typing().await.unwrap();
    }

    #[tokio::test]
    async fn stub_channel_optional_sends_default_to_ok() {
        let mut ch = StubChannel;
        ch.send_status("working").await.unwrap();
        ch.send_thinking_chunk("hmm").await.unwrap();
        ch.send_usage(1, 2, 3).await.unwrap();
        ch.send_tool_start("bash", "call-1", None, None)
            .await
            .unwrap();
        ch.send_stop_hint(StopHint::MaxTokens).await.unwrap();
    }

    #[tokio::test]
    async fn stub_channel_send_tool_output_default_ok() {
        let mut ch = StubChannel;
        ch.send_tool_output(
            "bash", "exit 0", None, None, None, None, "call-1", false, None, None, None,
        )
        .await
        .unwrap();
    }

    #[test]
    fn stub_channel_supports_exit_by_default() {
        assert!(StubChannel.supports_exit());
    }

    #[tokio::test]
    async fn stub_channel_recv_returns_none() {
        let mut ch = StubChannel;
        let msg = ch.recv().await.unwrap();
        assert!(msg.is_none());
    }

    #[tokio::test]
    async fn stub_channel_send_ok() {
        let mut ch = StubChannel;
        ch.send("hello").await.unwrap();
    }

    #[test]
    fn channel_message_clone() {
        let msg = text_message("test");
        let cloned = msg.clone();
        assert_eq!(cloned.text, "test");
    }

    #[test]
    fn channel_message_debug() {
        let msg = text_message("debug");
        let debug = format!("{msg:?}");
        assert!(debug.contains("debug"));
    }

    #[test]
    fn attachment_kind_equality() {
        assert_eq!(AttachmentKind::Audio, AttachmentKind::Audio);
        assert_ne!(AttachmentKind::Audio, AttachmentKind::Image);
    }

    #[test]
    fn attachment_construction() {
        let a = Attachment {
            kind: AttachmentKind::Audio,
            data: vec![0, 1, 2],
            filename: Some("test.wav".into()),
        };
        assert_eq!(a.kind, AttachmentKind::Audio);
        assert_eq!(a.data.len(), 3);
        assert_eq!(a.filename.as_deref(), Some("test.wav"));
    }

    #[test]
    fn channel_message_with_attachments() {
        let msg = ChannelMessage {
            text: String::new(),
            attachments: vec![Attachment {
                kind: AttachmentKind::Audio,
                data: vec![42],
                filename: None,
            }],
        };
        assert_eq!(msg.attachments.len(), 1);
        assert_eq!(msg.attachments[0].kind, AttachmentKind::Audio);
    }

    #[test]
    fn stub_channel_try_recv_returns_none() {
        let mut ch = StubChannel;
        assert!(ch.try_recv().is_none());
    }

    #[tokio::test]
    async fn stub_channel_send_queue_count_noop() {
        let mut ch = StubChannel;
        ch.send_queue_count(5).await.unwrap();
    }

    // Plain `#[test]`: building a pair must not require a running tokio runtime.
    #[test]
    fn loopback_pair_constructs_without_runtime() {
        let (channel, handle) = LoopbackChannel::pair(8);
        drop(channel);
        drop(handle);
    }

    #[tokio::test]
    async fn loopback_pair_returns_linked_handles() {
        let (mut channel, mut handle) = LoopbackChannel::pair(8);

        handle.input_tx.send(text_message("ping")).await.unwrap();
        let inbound = channel.recv().await.unwrap().unwrap();
        assert_eq!(inbound.text, "ping");

        channel.send("pong").await.unwrap();
        let outbound = handle.output_rx.recv().await.unwrap();
        assert!(matches!(outbound, LoopbackEvent::FullMessage(t) if t == "pong"));
    }

    #[test]
    fn loopback_channel_does_not_support_exit() {
        let (channel, _handle) = LoopbackChannel::pair(8);
        assert!(!channel.supports_exit());
    }

    #[tokio::test]
    async fn loopback_cancel_signal_can_be_notified_and_awaited() {
        let (_channel, handle) = LoopbackChannel::pair(8);
        let signal = std::sync::Arc::clone(&handle.cancel_signal);
        let notified = signal.notified();
        // `notify_one` stores a permit when nobody is waiting yet, so the await completes.
        handle.cancel_signal.notify_one();
        notified.await;
    }

    #[tokio::test]
    async fn loopback_cancel_signal_shared_across_clones() {
        let (_channel, handle) = LoopbackChannel::pair(8);
        let signal_a = std::sync::Arc::clone(&handle.cancel_signal);
        let signal_b = std::sync::Arc::clone(&handle.cancel_signal);
        let notified = signal_b.notified();
        signal_a.notify_one();
        notified.await;
    }

    #[tokio::test]
    async fn loopback_send_recv_round_trip() {
        let (mut channel, handle) = LoopbackChannel::pair(8);
        handle.input_tx.send(text_message("hello")).await.unwrap();
        let msg = channel.recv().await.unwrap().unwrap();
        assert_eq!(msg.text, "hello");
    }

    #[tokio::test]
    async fn loopback_recv_returns_none_when_handle_dropped() {
        let (mut channel, handle) = LoopbackChannel::pair(8);
        drop(handle);
        let result = channel.recv().await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn loopback_send_produces_full_message_event() {
        let (mut channel, mut handle) = LoopbackChannel::pair(8);
        channel.send("world").await.unwrap();
        let event = handle.output_rx.recv().await.unwrap();
        assert!(matches!(event, LoopbackEvent::FullMessage(t) if t == "world"));
    }

    #[tokio::test]
    async fn loopback_send_chunk_then_flush() {
        let (mut channel, mut handle) = LoopbackChannel::pair(8);
        channel.send_chunk("part1").await.unwrap();
        channel.flush_chunks().await.unwrap();
        let ev1 = handle.output_rx.recv().await.unwrap();
        let ev2 = handle.output_rx.recv().await.unwrap();
        assert!(matches!(ev1, LoopbackEvent::Chunk(t) if t == "part1"));
        assert!(matches!(ev2, LoopbackEvent::Flush));
    }

    #[tokio::test]
    async fn loopback_send_thinking_chunk_event() {
        let (mut channel, mut handle) = LoopbackChannel::pair(8);
        channel.send_thinking_chunk("hmm").await.unwrap();
        let event = handle.output_rx.recv().await.unwrap();
        assert!(matches!(event, LoopbackEvent::ThinkingChunk(t) if t == "hmm"));
    }

    #[tokio::test]
    async fn loopback_send_tool_start_event() {
        let (mut channel, mut handle) = LoopbackChannel::pair(8);
        channel
            .send_tool_start("bash", "call-1", None, Some("parent".to_owned()))
            .await
            .unwrap();
        match handle.output_rx.recv().await.unwrap() {
            LoopbackEvent::ToolStart {
                tool_name,
                tool_call_id,
                params,
                parent_tool_use_id,
                ..
            } => {
                assert_eq!(tool_name, "bash");
                assert_eq!(tool_call_id, "call-1");
                assert!(params.is_none());
                assert_eq!(parent_tool_use_id.as_deref(), Some("parent"));
            }
            other => panic!("expected ToolStart event, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn loopback_send_tool_output() {
        let (mut channel, mut handle) = LoopbackChannel::pair(8);
        channel
            .send_tool_output(
                "bash", "exit 0", None, None, None, None, "", false, None, None, None,
            )
            .await
            .unwrap();
        match handle.output_rx.recv().await.unwrap() {
            LoopbackEvent::ToolOutput {
                tool_name,
                display,
                diff,
                filter_stats,
                kept_lines,
                locations,
                tool_call_id,
                is_error,
                terminal_id,
                parent_tool_use_id,
                raw_response,
                started_at,
            } => {
                assert_eq!(tool_name, "bash");
                assert_eq!(display, "exit 0");
                assert!(diff.is_none());
                assert!(filter_stats.is_none());
                assert!(kept_lines.is_none());
                assert!(locations.is_none());
                assert_eq!(tool_call_id, "");
                assert!(!is_error);
                assert!(terminal_id.is_none());
                assert!(parent_tool_use_id.is_none());
                assert!(raw_response.is_none());
                assert!(started_at.is_none());
            }
            other => panic!("expected ToolOutput event, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn loopback_send_stop_hint_event() {
        let (mut channel, mut handle) = LoopbackChannel::pair(8);
        channel
            .send_stop_hint(StopHint::MaxTurnRequests)
            .await
            .unwrap();
        let event = handle.output_rx.recv().await.unwrap();
        assert!(matches!(
            event,
            LoopbackEvent::Stop(StopHint::MaxTurnRequests)
        ));
    }

    #[tokio::test]
    async fn loopback_confirm_auto_approves() {
        let (mut channel, _handle) = LoopbackChannel::pair(8);
        let result = channel.confirm("are you sure?").await.unwrap();
        assert!(result);
    }

    #[tokio::test]
    async fn loopback_send_error_when_output_closed() {
        let (mut channel, handle) = LoopbackChannel::pair(8);
        drop(handle);
        let result = channel.send("too late").await;
        assert!(matches!(result, Err(ChannelError::ChannelClosed)));
    }

    #[tokio::test]
    async fn loopback_send_chunk_error_when_output_closed() {
        let (mut channel, handle) = LoopbackChannel::pair(8);
        drop(handle);
        let result = channel.send_chunk("chunk").await;
        assert!(matches!(result, Err(ChannelError::ChannelClosed)));
    }

    #[tokio::test]
    async fn loopback_flush_error_when_output_closed() {
        let (mut channel, handle) = LoopbackChannel::pair(8);
        drop(handle);
        let result = channel.flush_chunks().await;
        assert!(matches!(result, Err(ChannelError::ChannelClosed)));
    }

    #[tokio::test]
    async fn loopback_send_status_event() {
        let (mut channel, mut handle) = LoopbackChannel::pair(8);
        channel.send_status("working...").await.unwrap();
        let event = handle.output_rx.recv().await.unwrap();
        assert!(matches!(event, LoopbackEvent::Status(s) if s == "working..."));
    }

    #[tokio::test]
    async fn loopback_send_usage_produces_usage_event() {
        let (mut channel, mut handle) = LoopbackChannel::pair(8);
        channel.send_usage(100, 50, 200_000).await.unwrap();
        match handle.output_rx.recv().await.unwrap() {
            LoopbackEvent::Usage {
                input_tokens,
                output_tokens,
                context_window,
            } => {
                assert_eq!(input_tokens, 100);
                assert_eq!(output_tokens, 50);
                assert_eq!(context_window, 200_000);
            }
            other => panic!("expected Usage event, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn loopback_send_usage_error_when_closed() {
        let (mut channel, handle) = LoopbackChannel::pair(8);
        drop(handle);
        let result = channel.send_usage(1, 2, 3).await;
        assert!(matches!(result, Err(ChannelError::ChannelClosed)));
    }

    #[test]
    fn plan_item_status_variants_are_distinct() {
        assert!(!matches!(
            PlanItemStatus::Pending,
            PlanItemStatus::InProgress
        ));
        assert!(!matches!(
            PlanItemStatus::InProgress,
            PlanItemStatus::Completed
        ));
        assert!(!matches!(
            PlanItemStatus::Completed,
            PlanItemStatus::Pending
        ));
    }

    #[test]
    fn loopback_event_session_title_carries_string() {
        let event = LoopbackEvent::SessionTitle("hello".to_owned());
        assert!(matches!(event, LoopbackEvent::SessionTitle(s) if s == "hello"));
    }

    #[test]
    fn loopback_event_plan_carries_entries() {
        let entries = vec![
            ("step 1".to_owned(), PlanItemStatus::Pending),
            ("step 2".to_owned(), PlanItemStatus::InProgress),
        ];
        match LoopbackEvent::Plan(entries) {
            LoopbackEvent::Plan(e) => {
                assert_eq!(e.len(), 2);
                assert!(matches!(e[0].1, PlanItemStatus::Pending));
                assert!(matches!(e[1].1, PlanItemStatus::InProgress));
            }
            other => panic!("expected Plan event, got {other:?}"),
        }
    }
}