systemprompt_agent/services/a2a_server/streaming/
event_loop.rs1use std::sync::Arc;
2
3use axum::response::sse::Event;
4use serde_json::json;
5use systemprompt_identifiers::{ContextId, MessageId, TaskId};
6use systemprompt_models::{A2AEventBuilder, AgUiEventBuilder, RequestContext};
7use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
8
9use crate::models::a2a::jsonrpc::NumberOrString;
10use crate::models::a2a::{Message, TaskState};
11use crate::repository::task::TaskRepository;
12use crate::services::a2a_server::processing::message::{MessageProcessor, StreamEvent};
13
14use super::handlers::text::TextStreamState;
15use super::handlers::{handle_complete, handle_error, HandleCompleteParams};
16use super::webhook_client::WebhookContext;
17
18pub struct ProcessEventsParams {
19 pub tx: UnboundedSender<Event>,
20 pub chunk_rx: UnboundedReceiver<StreamEvent>,
21 pub task_id: TaskId,
22 pub context_id: ContextId,
23 pub message_id: MessageId,
24 pub original_message: Message,
25 pub agent_name: String,
26 pub context: RequestContext,
27 pub task_repo: TaskRepository,
28 pub processor: Arc<MessageProcessor>,
29 pub request_id: NumberOrString,
30}
31
32impl std::fmt::Debug for ProcessEventsParams {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("ProcessEventsParams")
35 .field("task_id", &self.task_id)
36 .field("context_id", &self.context_id)
37 .field("message_id", &self.message_id)
38 .field("agent_name", &self.agent_name)
39 .finish_non_exhaustive()
40 }
41}
42
43fn send_a2a_status_event(
44 tx: &UnboundedSender<Event>,
45 task_id: &TaskId,
46 context_id: &ContextId,
47 state: &str,
48 is_final: bool,
49 request_id: &NumberOrString,
50) {
51 let event = json!({
52 "jsonrpc": "2.0",
53 "id": request_id,
54 "result": {
55 "kind": "status-update",
56 "taskId": task_id.as_str(),
57 "contextId": context_id.as_str(),
58 "status": {
59 "state": state,
60 "timestamp": chrono::Utc::now().to_rfc3339()
61 },
62 "final": is_final
63 }
64 });
65 if tx.send(Event::default().data(event.to_string())).is_err() {
66 tracing::trace!("Failed to send status event, channel closed");
67 }
68}
69
70pub async fn emit_run_started(
71 tx: &UnboundedSender<Event>,
72 webhook_context: &WebhookContext,
73 context_id: &ContextId,
74 task_id: &TaskId,
75 task_repo: &TaskRepository,
76 request_id: &NumberOrString,
77) {
78 let working_timestamp = chrono::Utc::now();
79 if let Err(e) = task_repo
80 .update_task_state(&task_id, TaskState::Working, &working_timestamp)
81 .await
82 {
83 tracing::error!(task_id = %task_id, error = %e, "Failed to update task state");
84 return;
85 }
86
87 send_a2a_status_event(tx, task_id, context_id, "working", false, request_id);
88
89 let a2a_event = A2AEventBuilder::task_status_update(
90 task_id.clone(),
91 context_id.clone(),
92 TaskState::Working,
93 None,
94 );
95 if let Err(e) = webhook_context.broadcast_a2a(a2a_event).await {
96 tracing::error!(error = %e, "Failed to broadcast A2A working");
97 }
98
99 let event = AgUiEventBuilder::run_started(context_id.clone(), task_id.clone(), None);
100 if let Err(e) = webhook_context.broadcast_agui(event).await {
101 tracing::error!(error = %e, "Failed to broadcast RUN_STARTED");
102 }
103}
104
105pub async fn process_events(params: ProcessEventsParams) {
106 let ProcessEventsParams {
107 tx,
108 mut chunk_rx,
109 task_id,
110 context_id,
111 message_id,
112 original_message,
113 agent_name,
114 context,
115 task_repo,
116 processor,
117 request_id,
118 } = params;
119
120 let webhook_context =
121 WebhookContext::new(context.user_id().as_str(), context.auth_token().as_str());
122
123 emit_run_started(
124 &tx,
125 &webhook_context,
126 &context_id,
127 &task_id,
128 &task_repo,
129 &request_id,
130 )
131 .await;
132
133 tracing::info!("Stream channel received, waiting for events...");
134
135 let mut text_state = TextStreamState::new().with_webhook_context(webhook_context.clone());
136
137 while let Some(event) = chunk_rx.recv().await {
138 match event {
139 StreamEvent::Text(text) => {
140 text_state.handle_text(text, message_id.as_str()).await;
141 },
142 StreamEvent::ToolCallStarted(tool_call) => {
143 let tool_call_id = tool_call.ai_tool_call_id.as_str();
144 let start_event = AgUiEventBuilder::tool_call_start(
145 tool_call_id,
146 &tool_call.name,
147 Some(message_id.to_string()),
148 );
149 if let Err(e) = webhook_context.broadcast_agui(start_event).await {
150 tracing::error!(error = %e, "Failed to broadcast TOOL_CALL_START");
151 }
152
153 let args_json =
154 serde_json::to_string(&tool_call.arguments).unwrap_or_else(|_| String::new());
155 let args_event = AgUiEventBuilder::tool_call_args(tool_call_id, &args_json);
156 if let Err(e) = webhook_context.broadcast_agui(args_event).await {
157 tracing::error!(error = %e, "Failed to broadcast TOOL_CALL_ARGS");
158 }
159
160 let end_event = AgUiEventBuilder::tool_call_end(tool_call_id);
161 if let Err(e) = webhook_context.broadcast_agui(end_event).await {
162 tracing::error!(error = %e, "Failed to broadcast TOOL_CALL_END");
163 }
164 },
165 StreamEvent::ToolResult { call_id, result } => {
166 let result_value =
167 serde_json::to_value(&result).unwrap_or_else(|_| serde_json::Value::Null);
168 let result_event = AgUiEventBuilder::tool_call_result(
169 &uuid::Uuid::new_v4().to_string(),
170 &call_id,
171 result_value,
172 );
173 if let Err(e) = webhook_context.broadcast_agui(result_event).await {
174 tracing::error!(error = %e, "Failed to broadcast TOOL_CALL_RESULT");
175 }
176 },
177 StreamEvent::ExecutionStepUpdate { step } => {
178 let step_event = AgUiEventBuilder::execution_step(step.clone(), context_id.clone());
179 if let Err(e) = webhook_context.broadcast_agui(step_event).await {
180 tracing::error!(error = %e, "Failed to broadcast execution_step");
181 }
182 },
183 StreamEvent::Complete {
184 full_text,
185 artifacts,
186 } => {
187 text_state.finalize(message_id.as_str()).await;
188
189 let complete_params = HandleCompleteParams {
190 tx: &tx,
191 webhook_context: &webhook_context,
192 full_text,
193 artifacts,
194 task_id: &task_id,
195 context_id: &context_id,
196 id: message_id.as_str(),
197 original_message: &original_message,
198 agent_name: &agent_name,
199 context: &context,
200 auth_token: context.auth_token().as_str(),
201 task_repo: &task_repo,
202 processor: &processor,
203 };
204 handle_complete(complete_params).await;
205
206 send_a2a_status_event(&tx, &task_id, &context_id, "completed", true, &request_id);
207
208 let a2a_event = A2AEventBuilder::task_status_update(
209 task_id.clone(),
210 context_id.clone(),
211 TaskState::Completed,
212 None,
213 );
214 if let Err(e) = webhook_context.broadcast_a2a(a2a_event).await {
215 tracing::error!(error = %e, "Failed to broadcast A2A completed");
216 }
217
218 break;
219 },
220 StreamEvent::Error(error) => {
221 text_state.finalize(message_id.as_str()).await;
222 handle_error(
223 &tx,
224 &webhook_context,
225 error,
226 &task_id,
227 &context_id,
228 &task_repo,
229 )
230 .await;
231
232 send_a2a_status_event(&tx, &task_id, &context_id, "failed", true, &request_id);
233
234 let a2a_event = A2AEventBuilder::task_status_update(
235 task_id.clone(),
236 context_id.clone(),
237 TaskState::Failed,
238 None,
239 );
240 if let Err(e) = webhook_context.broadcast_a2a(a2a_event).await {
241 tracing::error!(error = %e, "Failed to broadcast A2A failed");
242 }
243
244 break;
245 },
246 }
247 }
248
249 drop(tx);
250
251 tracing::info!("Stream event loop ended");
252}
253
254pub async fn handle_stream_creation_error(
255 webhook_context: &WebhookContext,
256 error: anyhow::Error,
257 task_id: &TaskId,
258 _context_id: &ContextId,
259 task_repo: &TaskRepository,
260) {
261 let error_msg = format!("Failed to create stream: {}", error);
262 tracing::error!(task_id = %task_id, error = %error, "Failed to create stream");
263
264 let failed_timestamp = chrono::Utc::now();
265 if let Err(e) = task_repo
266 .update_task_failed_with_error(task_id, &error_msg, &failed_timestamp)
267 .await
268 {
269 tracing::error!(task_id = %task_id, error = %e, "Failed to update task to failed state");
270 }
271
272 let error_event = AgUiEventBuilder::run_error(
273 format!("Failed to process message: {error}"),
274 Some("STREAM_CREATION_ERROR".to_string()),
275 );
276 if let Err(e) = webhook_context.broadcast_agui(error_event).await {
277 tracing::error!(error = %e, "Failed to broadcast RUN_ERROR");
278 }
279}