1use std::sync::Arc;
30
31use syncable_ag_ui_core::{
32 BaseEvent, Event, InterruptInfo, JsonValue, MessageId, Role, RunFinishedEvent,
33 RunFinishedOutcome, RunId, RunStartedEvent, TextMessageContentEvent, TextMessageEndEvent,
34 TextMessageStartEvent, ThreadId, ToolCallArgsEvent, ToolCallEndEvent, ToolCallId,
35 ToolCallStartEvent,
36};
37use tokio::sync::{RwLock, broadcast};
38
39#[derive(Clone)]
45pub struct EventBridge {
46 event_tx: broadcast::Sender<Event<JsonValue>>,
47 thread_id: Arc<RwLock<ThreadId>>,
48 run_id: Arc<RwLock<Option<RunId>>>,
49 current_message_id: Arc<RwLock<Option<MessageId>>>,
50 current_step_name: Arc<RwLock<Option<String>>>,
51}
52
53impl EventBridge {
54 pub fn new(
56 event_tx: broadcast::Sender<Event<JsonValue>>,
57 thread_id: Arc<RwLock<ThreadId>>,
58 run_id: Arc<RwLock<Option<RunId>>>,
59 ) -> Self {
60 Self {
61 event_tx,
62 thread_id,
63 run_id,
64 current_message_id: Arc::new(RwLock::new(None)),
65 current_step_name: Arc::new(RwLock::new(None)),
66 }
67 }
68
69 fn emit(&self, event: Event<JsonValue>) {
71 let _ = self.event_tx.send(event);
73 }
74
75 pub async fn start_run(&self) {
83 let thread_id = self.thread_id.read().await.clone();
84 let run_id = RunId::random();
85
86 *self.run_id.write().await = Some(run_id.clone());
88
89 self.emit(Event::RunStarted(RunStartedEvent {
90 base: BaseEvent::with_current_timestamp(),
91 thread_id,
92 run_id,
93 }));
94 }
95
96 pub async fn finish_run(&self) {
98 let thread_id = self.thread_id.read().await.clone();
99 let run_id = self.run_id.write().await.take();
100 let Some(run_id) = run_id else {
101 return; };
103
104 self.emit(Event::RunFinished(RunFinishedEvent {
105 base: BaseEvent::with_current_timestamp(),
106 thread_id,
107 run_id,
108 outcome: Some(RunFinishedOutcome::Success),
109 result: None,
110 interrupt: None,
111 }));
112 }
113
114 pub async fn finish_run_with_error(&self, message: &str) {
116 let _run_id = self.run_id.write().await.take();
117
118 self.emit(Event::RunError(syncable_ag_ui_core::RunErrorEvent {
119 base: BaseEvent::with_current_timestamp(),
120 message: message.to_string(),
121 code: None,
122 }));
123 }
124
125 pub async fn interrupt(&self, reason: Option<&str>, payload: Option<serde_json::Value>) {
138 let thread_id = self.thread_id.read().await.clone();
139 let run_id = self.run_id.write().await.take();
140 let Some(run_id) = run_id else {
141 return; };
143
144 let mut info = InterruptInfo::new();
145 if let Some(r) = reason {
146 info = info.with_reason(r);
147 }
148 if let Some(p) = payload {
149 info = info.with_payload(p);
150 }
151
152 self.emit(Event::RunFinished(RunFinishedEvent {
153 base: BaseEvent::with_current_timestamp(),
154 thread_id,
155 run_id,
156 outcome: Some(RunFinishedOutcome::Interrupt),
157 result: None,
158 interrupt: Some(info),
159 }));
160 }
161
162 pub async fn interrupt_with_id(
167 &self,
168 id: &str,
169 reason: Option<&str>,
170 payload: Option<serde_json::Value>,
171 ) {
172 let thread_id = self.thread_id.read().await.clone();
173 let run_id = self.run_id.write().await.take();
174 let Some(run_id) = run_id else {
175 return; };
177
178 let mut info = InterruptInfo::new().with_id(id);
179 if let Some(r) = reason {
180 info = info.with_reason(r);
181 }
182 if let Some(p) = payload {
183 info = info.with_payload(p);
184 }
185
186 self.emit(Event::RunFinished(RunFinishedEvent {
187 base: BaseEvent::with_current_timestamp(),
188 thread_id,
189 run_id,
190 outcome: Some(RunFinishedOutcome::Interrupt),
191 result: None,
192 interrupt: Some(info),
193 }));
194 }
195
196 pub async fn start_message(&self) -> MessageId {
202 let message_id = MessageId::random();
203 *self.current_message_id.write().await = Some(message_id.clone());
204
205 self.emit(Event::TextMessageStart(TextMessageStartEvent {
206 base: BaseEvent::with_current_timestamp(),
207 message_id: message_id.clone(),
208 role: Role::Assistant,
209 }));
210
211 message_id
212 }
213
214 pub async fn emit_text_chunk(&self, delta: &str) {
216 let message_id = self.current_message_id.read().await.clone();
217 if let Some(message_id) = message_id {
218 self.emit(Event::TextMessageContent(
219 TextMessageContentEvent::new_unchecked(message_id, delta),
220 ));
221 }
222 }
223
224 pub async fn end_message(&self) {
226 let message_id = self.current_message_id.write().await.take();
227 if let Some(message_id) = message_id {
228 self.emit(Event::TextMessageEnd(TextMessageEndEvent {
229 base: BaseEvent::with_current_timestamp(),
230 message_id,
231 }));
232 }
233 }
234
235 pub async fn emit_message(&self, content: &str) {
237 let _message_id = self.start_message().await;
238 self.emit_text_chunk(content).await;
239 self.end_message().await;
240 }
241
242 pub async fn start_tool_call(&self, name: &str, args: &JsonValue) -> ToolCallId {
250 let tool_call_id = ToolCallId::random();
251
252 let message_id = {
254 let mut current = self.current_message_id.write().await;
255 if current.is_none() {
256 *current = Some(MessageId::random());
257 }
258 current.clone().unwrap()
259 };
260
261 self.emit(Event::ToolCallStart(ToolCallStartEvent {
262 base: BaseEvent::with_current_timestamp(),
263 tool_call_id: tool_call_id.clone(),
264 tool_call_name: name.to_string(),
265 parent_message_id: Some(message_id),
266 }));
267
268 if !args.is_null() {
270 if let Ok(args_str) = serde_json::to_string(args) {
271 self.emit(Event::ToolCallArgs(ToolCallArgsEvent {
272 base: BaseEvent::with_current_timestamp(),
273 tool_call_id: tool_call_id.clone(),
274 delta: args_str,
275 }));
276 }
277 }
278
279 tool_call_id
280 }
281
282 pub async fn emit_tool_args_chunk(&self, tool_call_id: &ToolCallId, delta: &str) {
284 self.emit(Event::ToolCallArgs(ToolCallArgsEvent {
285 base: BaseEvent::with_current_timestamp(),
286 tool_call_id: tool_call_id.clone(),
287 delta: delta.to_string(),
288 }));
289 }
290
291 pub async fn end_tool_call(&self, tool_call_id: &ToolCallId) {
295 self.emit(Event::ToolCallEnd(ToolCallEndEvent {
296 base: BaseEvent::with_current_timestamp(),
297 tool_call_id: tool_call_id.clone(),
298 }));
299 }
300
301 pub async fn emit_tool_call(&self, name: &str, args: &JsonValue) {
303 let tool_call_id = self.start_tool_call(name, args).await;
304 self.end_tool_call(&tool_call_id).await;
305 }
306
307 pub async fn emit_state_snapshot(&self, state: JsonValue) {
313 self.emit(Event::StateSnapshot(
314 syncable_ag_ui_core::StateSnapshotEvent {
315 base: BaseEvent::with_current_timestamp(),
316 snapshot: state,
317 },
318 ));
319 }
320
321 pub async fn emit_state_delta(&self, delta: Vec<JsonValue>) {
323 self.emit(Event::StateDelta(syncable_ag_ui_core::StateDeltaEvent {
324 base: BaseEvent::with_current_timestamp(),
325 delta,
326 }));
327 }
328
329 pub async fn start_thinking(&self, title: Option<&str>) {
335 self.emit(Event::ThinkingStart(
336 syncable_ag_ui_core::ThinkingStartEvent {
337 base: BaseEvent::with_current_timestamp(),
338 title: title.map(|s| s.to_string()),
339 },
340 ));
341 }
342
343 pub async fn end_thinking(&self) {
345 self.emit(Event::ThinkingEnd(syncable_ag_ui_core::ThinkingEndEvent {
346 base: BaseEvent::with_current_timestamp(),
347 }));
348 }
349
350 pub async fn start_step(&self, name: &str) {
352 *self.current_step_name.write().await = Some(name.to_string());
353 self.emit(Event::StepStarted(syncable_ag_ui_core::StepStartedEvent {
354 base: BaseEvent::with_current_timestamp(),
355 step_name: name.to_string(),
356 }));
357 }
358
359 pub async fn end_step(&self) {
361 let step_name = self
362 .current_step_name
363 .write()
364 .await
365 .take()
366 .unwrap_or_else(|| "unknown".to_string());
367 self.emit(Event::StepFinished(
368 syncable_ag_ui_core::StepFinishedEvent {
369 base: BaseEvent::with_current_timestamp(),
370 step_name,
371 },
372 ));
373 }
374
375 pub async fn emit_custom(&self, name: &str, value: JsonValue) {
381 self.emit(Event::Custom(syncable_ag_ui_core::CustomEvent {
382 base: BaseEvent::with_current_timestamp(),
383 name: name.to_string(),
384 value,
385 }));
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    type EventReceiver = broadcast::Receiver<Event<JsonValue>>;

    /// Builds a bridge together with a receiver subscribed to its channel.
    fn create_bridge_with_receiver() -> (EventBridge, EventReceiver) {
        let (tx, rx) = broadcast::channel(100);
        let bridge = EventBridge::new(
            tx,
            Arc::new(RwLock::new(ThreadId::random())),
            Arc::new(RwLock::new(None)),
        );
        (bridge, rx)
    }

    /// Builds a bridge whose channel has no subscriber.
    fn create_bridge() -> EventBridge {
        let (bridge, _) = create_bridge_with_receiver();
        bridge
    }

    /// Waits for the next event on `rx`, failing the test if none arrives.
    async fn next_event(rx: &mut EventReceiver) -> Event<JsonValue> {
        rx.recv().await.expect("Should receive event")
    }

    /// Starting a run records a run id; finishing it clears the id again.
    #[tokio::test]
    async fn test_start_and_finish_run() {
        let bridge = create_bridge();

        bridge.start_run().await;
        assert!(bridge.run_id.read().await.is_some());

        bridge.finish_run().await;
        assert!(bridge.run_id.read().await.is_none());
    }

    /// The current message id is set by start and cleared by end.
    #[tokio::test]
    async fn test_message_lifecycle() {
        let bridge = create_bridge();

        let _msg_id = bridge.start_message().await;
        assert!(bridge.current_message_id.read().await.is_some());

        bridge.emit_text_chunk("Hello").await;
        bridge.end_message().await;

        assert!(bridge.current_message_id.read().await.is_none());
    }

    /// Emitting a whole message without any subscriber must not panic.
    #[tokio::test]
    async fn test_emit_complete_message() {
        let bridge = create_bridge();
        bridge.emit_message("Hello, world!").await;
    }

    /// A full tool-call sequence without any subscriber must not panic.
    #[tokio::test]
    async fn test_tool_call() {
        let bridge = create_bridge();

        let args = serde_json::json!({"key": "value"});
        let tool_id = bridge.start_tool_call("test", &args).await;
        bridge.emit_tool_args_chunk(&tool_id, "more args").await;
        bridge.end_tool_call(&tool_id).await;
    }

    /// Interrupting an active run clears its run id.
    #[tokio::test]
    async fn test_interrupt() {
        let bridge = create_bridge();

        bridge.start_run().await;
        assert!(bridge.run_id.read().await.is_some());

        bridge.interrupt(Some("file_write"), None).await;
        assert!(bridge.run_id.read().await.is_none());
    }

    /// Interrupting with a payload also clears the run id.
    #[tokio::test]
    async fn test_interrupt_with_payload() {
        let bridge = create_bridge();

        bridge.start_run().await;

        let payload = serde_json::json!({"file": "main.rs", "action": "write"});
        bridge.interrupt(Some("deployment"), Some(payload)).await;

        assert!(bridge.run_id.read().await.is_none());
    }

    /// Interrupting with an explicit id also clears the run id.
    #[tokio::test]
    async fn test_interrupt_with_id() {
        let bridge = create_bridge();

        bridge.start_run().await;
        bridge
            .interrupt_with_id("int-123", Some("deployment"), None)
            .await;

        assert!(bridge.run_id.read().await.is_none());
    }

    /// Interrupting when no run is active must not panic.
    #[tokio::test]
    async fn test_interrupt_without_run() {
        let bridge = create_bridge();

        bridge.interrupt(Some("test"), None).await;
    }

    /// A subscriber sees the run start followed by the three message events,
    /// in order.
    #[tokio::test]
    async fn test_events_received_by_subscriber() {
        let (bridge, mut rx) = create_bridge_with_receiver();

        bridge.start_run().await;
        assert!(
            matches!(next_event(&mut rx).await, Event::RunStarted(_)),
            "Expected RunStarted event"
        );

        bridge.emit_message("Hello").await;
        assert!(
            matches!(next_event(&mut rx).await, Event::TextMessageStart(_)),
            "Expected TextMessageStart"
        );
        assert!(
            matches!(next_event(&mut rx).await, Event::TextMessageContent(_)),
            "Expected TextMessageContent"
        );
        assert!(
            matches!(next_event(&mut rx).await, Event::TextMessageEnd(_)),
            "Expected TextMessageEnd"
        );
    }

    /// Step and thinking calls each produce their matching event.
    #[tokio::test]
    async fn test_step_and_thinking_events() {
        let (bridge, mut rx) = create_bridge_with_receiver();

        bridge.start_step("processing").await;
        assert!(
            matches!(next_event(&mut rx).await, Event::StepStarted(_)),
            "Expected StepStarted"
        );

        bridge.start_thinking(Some("Analyzing")).await;
        assert!(
            matches!(next_event(&mut rx).await, Event::ThinkingStart(_)),
            "Expected ThinkingStart"
        );

        bridge.end_thinking().await;
        assert!(
            matches!(next_event(&mut rx).await, Event::ThinkingEnd(_)),
            "Expected ThinkingEnd"
        );

        bridge.end_step().await;
        assert!(
            matches!(next_event(&mut rx).await, Event::StepFinished(_)),
            "Expected StepFinished"
        );
    }

    /// A state snapshot reaches the subscriber with its contents intact.
    #[tokio::test]
    async fn test_state_snapshot_event() {
        let (bridge, mut rx) = create_bridge_with_receiver();

        let state = serde_json::json!({
            "model": "gpt-4",
            "turn_count": 5
        });

        bridge.emit_state_snapshot(state).await;

        let Event::StateSnapshot(snapshot_event) = next_event(&mut rx).await else {
            panic!("Expected StateSnapshot");
        };
        assert_eq!(snapshot_event.snapshot["model"], "gpt-4");
        assert_eq!(snapshot_event.snapshot["turn_count"], 5);
    }
}