1use std::marker::PhantomData;
39
40use syncable_ag_ui_core::{
41 AgentState, Event, InterruptInfo, JsonValue, MessageId, RunErrorEvent, RunFinishedEvent,
42 RunId, RunStartedEvent, TextMessageContentEvent, TextMessageEndEvent, TextMessageStartEvent,
43 ThinkingEndEvent, ThinkingStartEvent, ThinkingTextMessageContentEvent,
44 ThinkingTextMessageEndEvent, ThinkingTextMessageStartEvent, ThreadId, ToolCallArgsEvent,
45 ToolCallEndEvent, ToolCallId, ToolCallStartEvent,
46};
47use async_trait::async_trait;
48
49use crate::error::ServerError;
50use crate::transport::SseSender;
51
52#[async_trait]
68pub trait EventProducer<StateT: AgentState = JsonValue>: Send + Sync {
69 async fn emit(&self, event: Event<StateT>) -> Result<(), ServerError>;
73
74 async fn emit_many(&self, events: Vec<Event<StateT>>) -> Result<(), ServerError> {
78 for event in events {
79 self.emit(event).await?;
80 }
81 Ok(())
82 }
83
84 fn is_connected(&self) -> bool;
88}
89
90#[async_trait]
92impl<StateT: AgentState> EventProducer<StateT> for SseSender<StateT> {
93 async fn emit(&self, event: Event<StateT>) -> Result<(), ServerError> {
94 self.send(event)
95 .await
96 .map_err(|_| ServerError::Channel("SSE channel closed".into()))
97 }
98
99 fn is_connected(&self) -> bool {
100 !self.is_closed()
101 }
102}
103
104pub struct MessageStream<'a, P: EventProducer<StateT>, StateT: AgentState = JsonValue> {
118 producer: &'a P,
119 message_id: MessageId,
120 _state: PhantomData<StateT>,
121}
122
123impl<'a, P: EventProducer<StateT>, StateT: AgentState> MessageStream<'a, P, StateT> {
124 pub async fn start(producer: &'a P) -> Result<Self, ServerError> {
128 let message_id = MessageId::random();
129 producer
130 .emit(Event::TextMessageStart(TextMessageStartEvent::new(
131 message_id.clone(),
132 )))
133 .await?;
134 Ok(Self {
135 producer,
136 message_id,
137 _state: PhantomData,
138 })
139 }
140
141 pub async fn start_with_id(
143 producer: &'a P,
144 message_id: MessageId,
145 ) -> Result<Self, ServerError> {
146 producer
147 .emit(Event::TextMessageStart(TextMessageStartEvent::new(
148 message_id.clone(),
149 )))
150 .await?;
151 Ok(Self {
152 producer,
153 message_id,
154 _state: PhantomData,
155 })
156 }
157
158 pub async fn content(&self, delta: impl Into<String>) -> Result<(), ServerError> {
163 let delta = delta.into();
164 if delta.is_empty() {
165 return Ok(());
166 }
167 self.producer
168 .emit(Event::TextMessageContent(
169 TextMessageContentEvent::new_unchecked(self.message_id.clone(), delta),
170 ))
171 .await
172 }
173
174 pub async fn end(self) -> Result<MessageId, ServerError> {
179 self.producer
180 .emit(Event::TextMessageEnd(TextMessageEndEvent::new(
181 self.message_id.clone(),
182 )))
183 .await?;
184 Ok(self.message_id)
185 }
186
187 pub fn message_id(&self) -> &MessageId {
189 &self.message_id
190 }
191}
192
193pub struct ToolCallStream<'a, P: EventProducer<StateT>, StateT: AgentState = JsonValue> {
207 producer: &'a P,
208 tool_call_id: ToolCallId,
209 _state: PhantomData<StateT>,
210}
211
212impl<'a, P: EventProducer<StateT>, StateT: AgentState> ToolCallStream<'a, P, StateT> {
213 pub async fn start(producer: &'a P, name: impl Into<String>) -> Result<Self, ServerError> {
218 let tool_call_id = ToolCallId::random();
219 producer
220 .emit(Event::ToolCallStart(ToolCallStartEvent::new(
221 tool_call_id.clone(),
222 name,
223 )))
224 .await?;
225 Ok(Self {
226 producer,
227 tool_call_id,
228 _state: PhantomData,
229 })
230 }
231
232 pub async fn start_with_id(
234 producer: &'a P,
235 tool_call_id: ToolCallId,
236 name: impl Into<String>,
237 ) -> Result<Self, ServerError> {
238 producer
239 .emit(Event::ToolCallStart(ToolCallStartEvent::new(
240 tool_call_id.clone(),
241 name,
242 )))
243 .await?;
244 Ok(Self {
245 producer,
246 tool_call_id,
247 _state: PhantomData,
248 })
249 }
250
251 pub async fn args(&self, delta: impl Into<String>) -> Result<(), ServerError> {
255 self.producer
256 .emit(Event::ToolCallArgs(ToolCallArgsEvent::new(
257 self.tool_call_id.clone(),
258 delta,
259 )))
260 .await
261 }
262
263 pub async fn end(self) -> Result<ToolCallId, ServerError> {
268 self.producer
269 .emit(Event::ToolCallEnd(ToolCallEndEvent::new(
270 self.tool_call_id.clone(),
271 )))
272 .await?;
273 Ok(self.tool_call_id)
274 }
275
276 pub fn tool_call_id(&self) -> &ToolCallId {
278 &self.tool_call_id
279 }
280}
281
282pub struct ThinkingMessageStream<'a, P: EventProducer<StateT>, StateT: AgentState = JsonValue> {
296 producer: &'a P,
297 _state: PhantomData<StateT>,
298}
299
300impl<'a, P: EventProducer<StateT>, StateT: AgentState> ThinkingMessageStream<'a, P, StateT> {
301 pub async fn start(producer: &'a P) -> Result<Self, ServerError> {
305 producer
306 .emit(Event::ThinkingTextMessageStart(
307 ThinkingTextMessageStartEvent::new(),
308 ))
309 .await?;
310 Ok(Self {
311 producer,
312 _state: PhantomData,
313 })
314 }
315
316 pub async fn content(&self, delta: impl Into<String>) -> Result<(), ServerError> {
321 self.producer
322 .emit(Event::ThinkingTextMessageContent(
323 ThinkingTextMessageContentEvent::new(delta),
324 ))
325 .await
326 }
327
328 pub async fn end(self) -> Result<(), ServerError> {
333 self.producer
334 .emit(Event::ThinkingTextMessageEnd(
335 ThinkingTextMessageEndEvent::new(),
336 ))
337 .await
338 }
339}
340
341pub struct ThinkingStep<'a, P: EventProducer<StateT>, StateT: AgentState = JsonValue> {
361 producer: &'a P,
362 _state: PhantomData<StateT>,
363}
364
365impl<'a, P: EventProducer<StateT>, StateT: AgentState> ThinkingStep<'a, P, StateT> {
366 pub async fn start(
370 producer: &'a P,
371 title: Option<impl Into<String>>,
372 ) -> Result<Self, ServerError> {
373 let event = if let Some(t) = title {
374 ThinkingStartEvent::new().with_title(t)
375 } else {
376 ThinkingStartEvent::new()
377 };
378 producer.emit(Event::ThinkingStart(event)).await?;
379 Ok(Self {
380 producer,
381 _state: PhantomData,
382 })
383 }
384
385 pub async fn end(self) -> Result<(), ServerError> {
390 self.producer
391 .emit(Event::ThinkingEnd(ThinkingEndEvent::new()))
392 .await
393 }
394
395 pub fn producer(&self) -> &'a P {
399 self.producer
400 }
401}
402
403pub struct AgentSession<P: EventProducer<StateT>, StateT: AgentState = JsonValue> {
422 producer: P,
423 thread_id: ThreadId,
424 current_run: Option<RunId>,
425 _state: PhantomData<StateT>,
426}
427
428impl<P: EventProducer<StateT>, StateT: AgentState> AgentSession<P, StateT> {
429 pub fn new(producer: P) -> Self {
433 Self {
434 producer,
435 thread_id: ThreadId::random(),
436 current_run: None,
437 _state: PhantomData,
438 }
439 }
440
441 pub fn with_thread_id(producer: P, thread_id: ThreadId) -> Self {
443 Self {
444 producer,
445 thread_id,
446 current_run: None,
447 _state: PhantomData,
448 }
449 }
450
451 pub async fn start_run(&mut self) -> Result<RunId, ServerError> {
456 if self.current_run.is_some() {
457 return Err(ServerError::Channel("Run already in progress".into()));
458 }
459 let run_id = RunId::random();
460 self.producer
461 .emit(Event::RunStarted(RunStartedEvent::new(
462 self.thread_id.clone(),
463 run_id.clone(),
464 )))
465 .await?;
466 self.current_run = Some(run_id.clone());
467 Ok(run_id)
468 }
469
470 pub async fn finish_run(&mut self, result: Option<JsonValue>) -> Result<(), ServerError> {
475 if let Some(run_id) = self.current_run.take() {
476 let mut event = RunFinishedEvent::new(self.thread_id.clone(), run_id);
477 if let Some(r) = result {
478 event = event.with_result(r);
479 }
480 self.producer.emit(Event::RunFinished(event)).await?;
481 }
482 Ok(())
483 }
484
485 pub async fn run_error(&mut self, message: impl Into<String>) -> Result<(), ServerError> {
489 self.current_run = None;
490 self.producer
491 .emit(Event::RunError(RunErrorEvent::new(message)))
492 .await
493 }
494
495 pub async fn run_error_with_code(
497 &mut self,
498 message: impl Into<String>,
499 code: impl Into<String>,
500 ) -> Result<(), ServerError> {
501 self.current_run = None;
502 self.producer
503 .emit(Event::RunError(
504 RunErrorEvent::new(message).with_code(code),
505 ))
506 .await
507 }
508
509 pub fn producer(&self) -> &P {
511 &self.producer
512 }
513
514 pub fn thread_id(&self) -> &ThreadId {
516 &self.thread_id
517 }
518
519 pub fn run_id(&self) -> Option<&RunId> {
521 self.current_run.as_ref()
522 }
523
524 pub fn is_running(&self) -> bool {
526 self.current_run.is_some()
527 }
528
529 pub fn is_connected(&self) -> bool {
531 self.producer.is_connected()
532 }
533
534 pub async fn start_thinking(
546 &self,
547 title: Option<impl Into<String>>,
548 ) -> Result<ThinkingStep<'_, P, StateT>, ServerError> {
549 ThinkingStep::start(&self.producer, title).await
550 }
551
552 pub async fn interrupt(
570 &mut self,
571 reason: Option<impl Into<String>>,
572 payload: Option<JsonValue>,
573 ) -> Result<(), ServerError> {
574 let run_id = self.current_run.take();
575 if let Some(run_id) = run_id {
576 let mut info = InterruptInfo::new();
577 if let Some(r) = reason {
578 info = info.with_reason(r);
579 }
580 if let Some(p) = payload {
581 info = info.with_payload(p);
582 }
583
584 let event = RunFinishedEvent::new(self.thread_id.clone(), run_id).with_interrupt(info);
585 self.producer.emit(Event::RunFinished(event)).await?;
586 }
587 Ok(())
588 }
589
590 pub async fn interrupt_with_id(
608 &mut self,
609 id: impl Into<String>,
610 reason: Option<impl Into<String>>,
611 payload: Option<JsonValue>,
612 ) -> Result<(), ServerError> {
613 let run_id = self.current_run.take();
614 if let Some(run_id) = run_id {
615 let mut info = InterruptInfo::new().with_id(id);
616 if let Some(r) = reason {
617 info = info.with_reason(r);
618 }
619 if let Some(p) = payload {
620 info = info.with_payload(p);
621 }
622
623 let event = RunFinishedEvent::new(self.thread_id.clone(), run_id).with_interrupt(info);
624 self.producer.emit(Event::RunFinished(event)).await?;
625 }
626 Ok(())
627 }
628}
629
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// In-memory [`EventProducer`] that records every event it is asked to emit.
    struct MockProducer {
        /// Events captured so far, in emission order.
        events: Arc<Mutex<Vec<Event>>>,
        /// Value reported by `is_connected`; `emit` fails while this is `false`.
        connected: bool,
    }

    impl MockProducer {
        /// Creates a connected producer with an empty event log.
        fn new() -> Self {
            let events = Arc::new(Mutex::new(Vec::new()));
            Self {
                events,
                connected: true,
            }
        }

        /// Returns a cloned snapshot of the events recorded so far.
        fn events(&self) -> Vec<Event> {
            let recorded = self.events.lock().unwrap();
            recorded.clone()
        }
    }

    #[async_trait]
    impl EventProducer for MockProducer {
        async fn emit(&self, event: Event) -> Result<(), ServerError> {
            if self.connected {
                self.events.lock().unwrap().push(event);
                Ok(())
            } else {
                Err(ServerError::Channel("disconnected".into()))
            }
        }

        fn is_connected(&self) -> bool {
            self.connected
        }
    }

    /// A single `emit` call records exactly that event.
    #[tokio::test]
    async fn test_event_producer_emit() {
        let producer = MockProducer::new();

        let event = Event::RunError(RunErrorEvent::new("test"));
        producer.emit(event).await.unwrap();

        let recorded = producer.events();
        assert!(matches!(recorded.as_slice(), [Event::RunError(_)]));
    }

    /// The default `emit_many` forwards every event.
    #[tokio::test]
    async fn test_event_producer_emit_many() {
        let producer = MockProducer::new();

        let batch = vec![
            Event::RunError(RunErrorEvent::new("error1")),
            Event::RunError(RunErrorEvent::new("error2")),
        ];
        producer.emit_many(batch).await.unwrap();

        assert_eq!(producer.events().len(), 2);
    }

    /// A message stream produces start, one content event per delta, then end.
    #[tokio::test]
    async fn test_message_stream() {
        let producer = MockProducer::new();

        let stream = MessageStream::start(&producer).await.unwrap();
        stream.content("Hello, ").await.unwrap();
        stream.content("world!").await.unwrap();
        let _message_id = stream.end().await.unwrap();

        let recorded = producer.events();
        assert!(matches!(
            recorded.as_slice(),
            [
                Event::TextMessageStart(_),
                Event::TextMessageContent(_),
                Event::TextMessageContent(_),
                Event::TextMessageEnd(_),
            ]
        ));
    }

    /// Empty deltas are dropped by `MessageStream::content`.
    #[tokio::test]
    async fn test_message_stream_empty_content_ignored() {
        let producer = MockProducer::new();

        let stream = MessageStream::start(&producer).await.unwrap();
        stream.content("").await.unwrap();
        stream.content("Hello").await.unwrap();
        stream.end().await.unwrap();

        // start + one content + end; the empty delta leaves no trace.
        assert_eq!(producer.events().len(), 3);
    }

    /// A tool call stream produces start, one args event per fragment, then end.
    #[tokio::test]
    async fn test_tool_call_stream() {
        let producer = MockProducer::new();

        let stream = ToolCallStream::start(&producer, "get_weather")
            .await
            .unwrap();
        stream.args(r#"{"location": "#).await.unwrap();
        stream.args(r#""NYC"}"#).await.unwrap();
        let _tool_call_id = stream.end().await.unwrap();

        let recorded = producer.events();
        assert!(matches!(
            recorded.as_slice(),
            [
                Event::ToolCallStart(_),
                Event::ToolCallArgs(_),
                Event::ToolCallArgs(_),
                Event::ToolCallEnd(_),
            ]
        ));
    }

    /// start_run / finish_run toggle the running state and emit one event each.
    #[tokio::test]
    async fn test_agent_session_run_lifecycle() {
        let mut session = AgentSession::new(MockProducer::new());
        assert!(!session.is_running());

        let run_id = session.start_run().await.unwrap();
        assert!(session.is_running());
        assert_eq!(session.run_id(), Some(&run_id));

        session.finish_run(None).await.unwrap();
        assert!(!session.is_running());
        assert_eq!(session.run_id(), None);

        let recorded = session.producer().events();
        assert!(matches!(
            recorded.as_slice(),
            [Event::RunStarted(_), Event::RunFinished(_)]
        ));
    }

    /// run_error clears the running state and emits `RunError`.
    #[tokio::test]
    async fn test_agent_session_run_error() {
        let mut session = AgentSession::new(MockProducer::new());

        session.start_run().await.unwrap();
        session.run_error("Something went wrong").await.unwrap();
        assert!(!session.is_running());

        let recorded = session.producer().events();
        assert!(matches!(
            recorded.as_slice(),
            [Event::RunStarted(_), Event::RunError(_)]
        ));
    }

    /// Starting a second run while one is active is rejected.
    #[tokio::test]
    async fn test_agent_session_double_start_error() {
        let mut session = AgentSession::new(MockProducer::new());

        session.start_run().await.unwrap();
        let second_attempt = session.start_run().await;

        assert!(second_attempt.is_err());
    }

    /// finish_run without an active run succeeds and emits nothing.
    #[tokio::test]
    async fn test_agent_session_finish_without_run() {
        let mut session = AgentSession::new(MockProducer::new());

        session.finish_run(None).await.unwrap();

        assert!(session.producer().events().is_empty());
    }

    /// A thinking message stream produces start, one content event per delta, then end.
    #[tokio::test]
    async fn test_thinking_message_stream() {
        let producer = MockProducer::new();

        let stream = ThinkingMessageStream::start(&producer).await.unwrap();
        stream.content("Let me analyze...").await.unwrap();
        stream.content("The answer is...").await.unwrap();
        stream.end().await.unwrap();

        let recorded = producer.events();
        assert!(matches!(
            recorded.as_slice(),
            [
                Event::ThinkingTextMessageStart(_),
                Event::ThinkingTextMessageContent(_),
                Event::ThinkingTextMessageContent(_),
                Event::ThinkingTextMessageEnd(_),
            ]
        ));
    }

    /// Unlike text messages, empty thinking deltas are still emitted.
    #[tokio::test]
    async fn test_thinking_message_stream_empty_content_allowed() {
        let producer = MockProducer::new();

        let stream = ThinkingMessageStream::start(&producer).await.unwrap();
        stream.content("").await.unwrap();
        stream.content("Thinking...").await.unwrap();
        stream.end().await.unwrap();

        // start + two content events (including the empty one) + end.
        assert_eq!(producer.events().len(), 4);
    }

    /// A thinking step without a title emits start and end.
    #[tokio::test]
    async fn test_thinking_step() {
        let producer = MockProducer::new();

        let step = ThinkingStep::start(&producer, None::<String>)
            .await
            .unwrap();
        step.end().await.unwrap();

        let recorded = producer.events();
        assert!(matches!(
            recorded.as_slice(),
            [Event::ThinkingStart(_), Event::ThinkingEnd(_)]
        ));
    }

    /// The title passed to `ThinkingStep::start` ends up on the start event.
    #[tokio::test]
    async fn test_thinking_step_with_title() {
        let producer = MockProducer::new();

        let step = ThinkingStep::start(&producer, Some("Analyzing query"))
            .await
            .unwrap();
        step.end().await.unwrap();

        let recorded = producer.events();
        assert_eq!(recorded.len(), 2);

        match &recorded[0] {
            Event::ThinkingStart(start) => {
                assert_eq!(start.title, Some("Analyzing query".to_string()));
            }
            _ => panic!("Expected ThinkingStart event"),
        }
    }

    /// A thinking message can be nested inside a step via `ThinkingStep::producer`.
    #[tokio::test]
    async fn test_thinking_step_with_content() {
        let producer = MockProducer::new();

        let step = ThinkingStep::start(&producer, Some("Planning"))
            .await
            .unwrap();

        let inner = ThinkingMessageStream::start(step.producer())
            .await
            .unwrap();
        inner.content("First, consider...").await.unwrap();
        inner.end().await.unwrap();

        step.end().await.unwrap();

        let recorded = producer.events();
        assert!(matches!(
            recorded.as_slice(),
            [
                Event::ThinkingStart(_),
                Event::ThinkingTextMessageStart(_),
                Event::ThinkingTextMessageContent(_),
                Event::ThinkingTextMessageEnd(_),
                Event::ThinkingEnd(_),
            ]
        ));
    }

    /// `AgentSession::start_thinking` opens a step on the session's producer.
    #[tokio::test]
    async fn test_agent_session_start_thinking() {
        let session = AgentSession::new(MockProducer::new());

        let step = session.start_thinking(Some("Reasoning")).await.unwrap();
        step.end().await.unwrap();

        let recorded = session.producer().events();
        assert!(matches!(
            recorded.as_slice(),
            [Event::ThinkingStart(_), Event::ThinkingEnd(_)]
        ));
    }

    /// Without a title, the start event's `title` stays `None`.
    #[tokio::test]
    async fn test_agent_session_start_thinking_no_title() {
        let session = AgentSession::new(MockProducer::new());

        let step = session.start_thinking(None::<String>).await.unwrap();
        step.end().await.unwrap();

        let recorded = session.producer().events();
        assert_eq!(recorded.len(), 2);

        match &recorded[0] {
            Event::ThinkingStart(start) => assert!(start.title.is_none()),
            _ => panic!("Expected ThinkingStart event"),
        }
    }

    /// interrupt ends the run with an interrupt outcome carrying reason and payload.
    #[tokio::test]
    async fn test_agent_session_interrupt() {
        use syncable_ag_ui_core::RunFinishedOutcome;

        let mut session = AgentSession::new(MockProducer::new());

        session.start_run().await.unwrap();
        let payload = serde_json::json!({"action": "send_email"});
        session
            .interrupt(Some("human_approval"), Some(payload))
            .await
            .unwrap();

        assert!(!session.is_running());

        let recorded = session.producer().events();
        assert_eq!(recorded.len(), 2);
        assert!(matches!(recorded[0], Event::RunStarted(_)));

        match &recorded[1] {
            Event::RunFinished(finished) => {
                assert_eq!(finished.outcome, Some(RunFinishedOutcome::Interrupt));
                let info = finished
                    .interrupt
                    .as_ref()
                    .expect("interrupt info should be present");
                assert_eq!(info.reason, Some("human_approval".to_string()));
                assert!(info.payload.is_some());
            }
            _ => panic!("Expected RunFinished event"),
        }
    }

    /// interrupt_with_id additionally records the interrupt id.
    #[tokio::test]
    async fn test_agent_session_interrupt_with_id() {
        use syncable_ag_ui_core::RunFinishedOutcome;

        let mut session = AgentSession::new(MockProducer::new());

        session.start_run().await.unwrap();
        let payload = serde_json::json!({"query": "DELETE FROM users"});
        session
            .interrupt_with_id("approval-001", Some("database_modification"), Some(payload))
            .await
            .unwrap();

        assert!(!session.is_running());

        let recorded = session.producer().events();
        assert_eq!(recorded.len(), 2);

        match &recorded[1] {
            Event::RunFinished(finished) => {
                assert_eq!(finished.outcome, Some(RunFinishedOutcome::Interrupt));
                let info = finished.interrupt.as_ref().unwrap();
                assert_eq!(info.id, Some("approval-001".to_string()));
                assert_eq!(info.reason, Some("database_modification".to_string()));
            }
            _ => panic!("Expected RunFinished event"),
        }
    }

    /// interrupt without an active run succeeds and emits nothing.
    #[tokio::test]
    async fn test_agent_session_interrupt_without_run() {
        let mut session = AgentSession::new(MockProducer::new());

        session.interrupt(Some("test"), None).await.unwrap();

        assert!(session.producer().events().is_empty());
    }

    /// interrupt with neither reason nor payload yields an empty `InterruptInfo`.
    #[tokio::test]
    async fn test_agent_session_interrupt_minimal() {
        let mut session = AgentSession::new(MockProducer::new());

        session.start_run().await.unwrap();
        session.interrupt(None::<String>, None).await.unwrap();

        let recorded = session.producer().events();
        assert_eq!(recorded.len(), 2);

        match &recorded[1] {
            Event::RunFinished(finished) => {
                let info = finished.interrupt.as_ref().unwrap();
                assert!(info.id.is_none());
                assert!(info.reason.is_none());
                assert!(info.payload.is_none());
            }
            _ => panic!("Expected RunFinished event"),
        }
    }
}