1use std::sync::Arc;
2
3use axum::response::sse::Event;
4use serde_json::json;
5use systemprompt_identifiers::{ContextId, MessageId, TaskId};
6use systemprompt_models::{A2AEventBuilder, AgUiEventBuilder, RequestContext};
7use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
8
9use crate::models::a2a::jsonrpc::NumberOrString;
10use crate::models::a2a::{Message, TaskState};
11use crate::repository::task::TaskRepository;
12use crate::services::a2a_server::processing::message::{MessageProcessor, StreamEvent};
13
14use super::handlers::text::TextStreamState;
15use super::handlers::{HandleCompleteParams, HandleErrorParams, handle_complete, handle_error};
16use super::webhook_client::WebhookContext;
17
18pub struct ProcessEventsParams {
19 pub tx: UnboundedSender<Event>,
20 pub chunk_rx: UnboundedReceiver<StreamEvent>,
21 pub task_id: TaskId,
22 pub context_id: ContextId,
23 pub message_id: MessageId,
24 pub original_message: Message,
25 pub agent_name: String,
26 pub context: RequestContext,
27 pub task_repo: TaskRepository,
28 pub processor: Arc<MessageProcessor>,
29 pub request_id: NumberOrString,
30}
31
32impl std::fmt::Debug for ProcessEventsParams {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("ProcessEventsParams")
35 .field("task_id", &self.task_id)
36 .field("context_id", &self.context_id)
37 .field("message_id", &self.message_id)
38 .field("agent_name", &self.agent_name)
39 .finish_non_exhaustive()
40 }
41}
42
43struct SendA2aStatusEventParams<'a> {
44 tx: &'a UnboundedSender<Event>,
45 task_id: &'a TaskId,
46 context_id: &'a ContextId,
47 state: &'a str,
48 is_final: bool,
49 request_id: &'a NumberOrString,
50}
51
52fn send_a2a_status_event(params: &SendA2aStatusEventParams<'_>) {
53 let SendA2aStatusEventParams {
54 tx,
55 task_id,
56 context_id,
57 state,
58 is_final,
59 request_id,
60 } = params;
61 let event = json!({
62 "jsonrpc": "2.0",
63 "id": request_id,
64 "result": {
65 "kind": "status-update",
66 "taskId": task_id.as_str(),
67 "contextId": context_id.as_str(),
68 "status": {
69 "state": state,
70 "timestamp": chrono::Utc::now().to_rfc3339()
71 },
72 "final": is_final
73 }
74 });
75 if tx.send(Event::default().data(event.to_string())).is_err() {
76 tracing::trace!("Failed to send status event, channel closed");
77 }
78}
79
80pub struct EmitRunStartedParams<'a> {
81 pub tx: &'a UnboundedSender<Event>,
82 pub webhook_context: &'a WebhookContext,
83 pub context_id: &'a ContextId,
84 pub task_id: &'a TaskId,
85 pub task_repo: &'a TaskRepository,
86 pub request_id: &'a NumberOrString,
87}
88
89pub async fn emit_run_started(params: EmitRunStartedParams<'_>) {
90 let EmitRunStartedParams {
91 tx,
92 webhook_context,
93 context_id,
94 task_id,
95 task_repo,
96 request_id,
97 } = params;
98 let working_timestamp = chrono::Utc::now();
99 if let Err(e) = task_repo
100 .update_task_state(task_id, TaskState::Working, &working_timestamp)
101 .await
102 {
103 tracing::error!(task_id = %task_id, error = %e, "Failed to update task state");
104 return;
105 }
106
107 send_a2a_status_event(&SendA2aStatusEventParams {
108 tx,
109 task_id,
110 context_id,
111 state: "working",
112 is_final: false,
113 request_id,
114 });
115
116 let a2a_event = A2AEventBuilder::task_status_update(
117 task_id.clone(),
118 context_id.clone(),
119 TaskState::Working,
120 None,
121 );
122 if let Err(e) = webhook_context.broadcast_a2a(a2a_event).await {
123 tracing::error!(error = %e, "Failed to broadcast A2A working");
124 }
125
126 let event = AgUiEventBuilder::run_started(context_id.clone(), task_id.clone(), None);
127 if let Err(e) = webhook_context.broadcast_agui(event).await {
128 tracing::error!(error = %e, "Failed to broadcast RUN_STARTED");
129 }
130}
131
132pub async fn process_events(params: ProcessEventsParams) {
133 let ProcessEventsParams {
134 tx,
135 mut chunk_rx,
136 task_id,
137 context_id,
138 message_id,
139 original_message,
140 agent_name,
141 context,
142 task_repo,
143 processor,
144 request_id,
145 } = params;
146
147 let webhook_context =
148 WebhookContext::new(context.user_id().clone(), context.auth_token().as_str());
149
150 emit_run_started(EmitRunStartedParams {
151 tx: &tx,
152 webhook_context: &webhook_context,
153 context_id: &context_id,
154 task_id: &task_id,
155 task_repo: &task_repo,
156 request_id: &request_id,
157 })
158 .await;
159
160 tracing::info!("Stream channel received, waiting for events...");
161
162 let mut text_state = TextStreamState::new().with_webhook_context(webhook_context.clone());
163
164 while let Some(event) = chunk_rx.recv().await {
165 match event {
166 StreamEvent::Text(text) => {
167 text_state.handle_text(text, &message_id).await;
168 },
169 StreamEvent::ToolCallStarted(tool_call) => {
170 let tool_call_id = tool_call.ai_tool_call_id.as_str();
171 let start_event = AgUiEventBuilder::tool_call_start(
172 tool_call_id,
173 &tool_call.name,
174 Some(message_id.to_string()),
175 );
176 if let Err(e) = webhook_context.broadcast_agui(start_event).await {
177 tracing::error!(error = %e, "Failed to broadcast TOOL_CALL_START");
178 }
179
180 let args_json =
181 serde_json::to_string(&tool_call.arguments).unwrap_or_else(|_| String::new());
182 let args_event = AgUiEventBuilder::tool_call_args(tool_call_id, &args_json);
183 if let Err(e) = webhook_context.broadcast_agui(args_event).await {
184 tracing::error!(error = %e, "Failed to broadcast TOOL_CALL_ARGS");
185 }
186
187 let end_event = AgUiEventBuilder::tool_call_end(tool_call_id);
188 if let Err(e) = webhook_context.broadcast_agui(end_event).await {
189 tracing::error!(error = %e, "Failed to broadcast TOOL_CALL_END");
190 }
191 },
192 StreamEvent::ToolResult { call_id, result } => {
193 let result_value =
194 serde_json::to_value(&result).unwrap_or_else(|_| serde_json::Value::Null);
195 let result_event = AgUiEventBuilder::tool_call_result(
196 uuid::Uuid::new_v4().to_string(),
197 &call_id,
198 result_value,
199 );
200 if let Err(e) = webhook_context.broadcast_agui(result_event).await {
201 tracing::error!(error = %e, "Failed to broadcast TOOL_CALL_RESULT");
202 }
203 },
204 StreamEvent::ExecutionStepUpdate { step } => {
205 let step_event = AgUiEventBuilder::execution_step(step.clone(), context_id.clone());
206 if let Err(e) = webhook_context.broadcast_agui(step_event).await {
207 tracing::error!(error = %e, "Failed to broadcast execution_step");
208 }
209 },
210 StreamEvent::Complete {
211 full_text,
212 artifacts,
213 } => {
214 text_state.finalize(&message_id).await;
215
216 let complete_params = HandleCompleteParams {
217 tx: &tx,
218 webhook_context: &webhook_context,
219 full_text,
220 artifacts,
221 task_id: &task_id,
222 context_id: &context_id,
223 id: message_id.as_str(),
224 original_message: &original_message,
225 agent_name: &agent_name,
226 context: &context,
227 auth_token: context.auth_token().as_str(),
228 task_repo: &task_repo,
229 processor: &processor,
230 };
231 handle_complete(complete_params).await;
232
233 send_a2a_status_event(&SendA2aStatusEventParams {
234 tx: &tx,
235 task_id: &task_id,
236 context_id: &context_id,
237 state: "completed",
238 is_final: true,
239 request_id: &request_id,
240 });
241
242 let a2a_event = A2AEventBuilder::task_status_update(
243 task_id.clone(),
244 context_id.clone(),
245 TaskState::Completed,
246 None,
247 );
248 if let Err(e) = webhook_context.broadcast_a2a(a2a_event).await {
249 tracing::error!(error = %e, "Failed to broadcast A2A completed");
250 }
251
252 break;
253 },
254 StreamEvent::Error(error) => {
255 text_state.finalize(&message_id).await;
256 handle_error(HandleErrorParams {
257 tx: &tx,
258 webhook_context: &webhook_context,
259 error,
260 task_id: &task_id,
261 context_id: &context_id,
262 task_repo: &task_repo,
263 })
264 .await;
265
266 send_a2a_status_event(&SendA2aStatusEventParams {
267 tx: &tx,
268 task_id: &task_id,
269 context_id: &context_id,
270 state: "failed",
271 is_final: true,
272 request_id: &request_id,
273 });
274
275 let a2a_event = A2AEventBuilder::task_status_update(
276 task_id.clone(),
277 context_id.clone(),
278 TaskState::Failed,
279 None,
280 );
281 if let Err(e) = webhook_context.broadcast_a2a(a2a_event).await {
282 tracing::error!(error = %e, "Failed to broadcast A2A failed");
283 }
284
285 break;
286 },
287 }
288 }
289
290 drop(tx);
291
292 tracing::info!("Stream event loop ended");
293}
294
295pub async fn handle_stream_creation_error(
296 webhook_context: &WebhookContext,
297 error: anyhow::Error,
298 task_id: &TaskId,
299 _context_id: &ContextId,
300 task_repo: &TaskRepository,
301) {
302 let error_msg = format!("Failed to create stream: {}", error);
303 tracing::error!(task_id = %task_id, error = %error, "Failed to create stream");
304
305 let failed_timestamp = chrono::Utc::now();
306 if let Err(e) = task_repo
307 .update_task_failed_with_error(task_id, &error_msg, &failed_timestamp)
308 .await
309 {
310 tracing::error!(task_id = %task_id, error = %e, "Failed to update task to failed state");
311 }
312
313 let error_event = AgUiEventBuilder::run_error(
314 format!("Failed to process message: {error}"),
315 Some("STREAM_CREATION_ERROR".to_string()),
316 );
317 if let Err(e) = webhook_context.broadcast_agui(error_event).await {
318 tracing::error!(error = %e, "Failed to broadcast RUN_ERROR");
319 }
320}