Skip to main content

systemprompt_agent/services/a2a_server/streaming/
event_loop.rs

1use std::sync::Arc;
2
3use axum::response::sse::Event;
4use serde_json::json;
5use systemprompt_identifiers::{ContextId, MessageId, TaskId};
6use systemprompt_models::{A2AEventBuilder, AgUiEventBuilder, RequestContext};
7use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
8
9use crate::models::a2a::jsonrpc::NumberOrString;
10use crate::models::a2a::{Message, TaskState};
11use crate::repository::task::TaskRepository;
12use crate::services::a2a_server::processing::message::{MessageProcessor, StreamEvent};
13
14use super::handlers::text::TextStreamState;
15use super::handlers::{HandleCompleteParams, HandleErrorParams, handle_complete, handle_error};
16use super::webhook_client::WebhookContext;
17
18pub struct ProcessEventsParams {
19    pub tx: UnboundedSender<Event>,
20    pub chunk_rx: UnboundedReceiver<StreamEvent>,
21    pub task_id: TaskId,
22    pub context_id: ContextId,
23    pub message_id: MessageId,
24    pub original_message: Message,
25    pub agent_name: String,
26    pub context: RequestContext,
27    pub task_repo: TaskRepository,
28    pub processor: Arc<MessageProcessor>,
29    pub request_id: NumberOrString,
30}
31
32impl std::fmt::Debug for ProcessEventsParams {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("ProcessEventsParams")
35            .field("task_id", &self.task_id)
36            .field("context_id", &self.context_id)
37            .field("message_id", &self.message_id)
38            .field("agent_name", &self.agent_name)
39            .finish_non_exhaustive()
40    }
41}
42
43struct SendA2aStatusEventParams<'a> {
44    tx: &'a UnboundedSender<Event>,
45    task_id: &'a TaskId,
46    context_id: &'a ContextId,
47    state: &'a str,
48    is_final: bool,
49    request_id: &'a NumberOrString,
50}
51
52fn send_a2a_status_event(params: &SendA2aStatusEventParams<'_>) {
53    let SendA2aStatusEventParams {
54        tx,
55        task_id,
56        context_id,
57        state,
58        is_final,
59        request_id,
60    } = params;
61    let event = json!({
62        "jsonrpc": "2.0",
63        "id": request_id,
64        "result": {
65            "kind": "status-update",
66            "taskId": task_id.as_str(),
67            "contextId": context_id.as_str(),
68            "status": {
69                "state": state,
70                "timestamp": chrono::Utc::now().to_rfc3339()
71            },
72            "final": is_final
73        }
74    });
75    if tx.send(Event::default().data(event.to_string())).is_err() {
76        tracing::trace!("Failed to send status event, channel closed");
77    }
78}
79
80pub struct EmitRunStartedParams<'a> {
81    pub tx: &'a UnboundedSender<Event>,
82    pub webhook_context: &'a WebhookContext,
83    pub context_id: &'a ContextId,
84    pub task_id: &'a TaskId,
85    pub task_repo: &'a TaskRepository,
86    pub request_id: &'a NumberOrString,
87}
88
89pub async fn emit_run_started(params: EmitRunStartedParams<'_>) {
90    let EmitRunStartedParams {
91        tx,
92        webhook_context,
93        context_id,
94        task_id,
95        task_repo,
96        request_id,
97    } = params;
98    let working_timestamp = chrono::Utc::now();
99    if let Err(e) = task_repo
100        .update_task_state(task_id, TaskState::Working, &working_timestamp)
101        .await
102    {
103        tracing::error!(task_id = %task_id, error = %e, "Failed to update task state");
104        return;
105    }
106
107    send_a2a_status_event(&SendA2aStatusEventParams {
108        tx,
109        task_id,
110        context_id,
111        state: "working",
112        is_final: false,
113        request_id,
114    });
115
116    let a2a_event = A2AEventBuilder::task_status_update(
117        task_id.clone(),
118        context_id.clone(),
119        TaskState::Working,
120        None,
121    );
122    if let Err(e) = webhook_context.broadcast_a2a(a2a_event).await {
123        tracing::error!(error = %e, "Failed to broadcast A2A working");
124    }
125
126    let event = AgUiEventBuilder::run_started(context_id.clone(), task_id.clone(), None);
127    if let Err(e) = webhook_context.broadcast_agui(event).await {
128        tracing::error!(error = %e, "Failed to broadcast RUN_STARTED");
129    }
130}
131
132pub async fn process_events(params: ProcessEventsParams) {
133    let ProcessEventsParams {
134        tx,
135        mut chunk_rx,
136        task_id,
137        context_id,
138        message_id,
139        original_message,
140        agent_name,
141        context,
142        task_repo,
143        processor,
144        request_id,
145    } = params;
146
147    let webhook_context =
148        WebhookContext::new(context.user_id().clone(), context.auth_token().as_str());
149
150    emit_run_started(EmitRunStartedParams {
151        tx: &tx,
152        webhook_context: &webhook_context,
153        context_id: &context_id,
154        task_id: &task_id,
155        task_repo: &task_repo,
156        request_id: &request_id,
157    })
158    .await;
159
160    tracing::info!("Stream channel received, waiting for events...");
161
162    let mut text_state = TextStreamState::new().with_webhook_context(webhook_context.clone());
163
164    while let Some(event) = chunk_rx.recv().await {
165        match event {
166            StreamEvent::Text(text) => {
167                text_state.handle_text(text, &message_id).await;
168            },
169            StreamEvent::ToolCallStarted(tool_call) => {
170                let tool_call_id = tool_call.ai_tool_call_id.as_str();
171                let start_event = AgUiEventBuilder::tool_call_start(
172                    tool_call_id,
173                    &tool_call.name,
174                    Some(message_id.to_string()),
175                );
176                if let Err(e) = webhook_context.broadcast_agui(start_event).await {
177                    tracing::error!(error = %e, "Failed to broadcast TOOL_CALL_START");
178                }
179
180                let args_json =
181                    serde_json::to_string(&tool_call.arguments).unwrap_or_else(|_| String::new());
182                let args_event = AgUiEventBuilder::tool_call_args(tool_call_id, &args_json);
183                if let Err(e) = webhook_context.broadcast_agui(args_event).await {
184                    tracing::error!(error = %e, "Failed to broadcast TOOL_CALL_ARGS");
185                }
186
187                let end_event = AgUiEventBuilder::tool_call_end(tool_call_id);
188                if let Err(e) = webhook_context.broadcast_agui(end_event).await {
189                    tracing::error!(error = %e, "Failed to broadcast TOOL_CALL_END");
190                }
191            },
192            StreamEvent::ToolResult { call_id, result } => {
193                let result_value =
194                    serde_json::to_value(&result).unwrap_or_else(|_| serde_json::Value::Null);
195                let result_event = AgUiEventBuilder::tool_call_result(
196                    uuid::Uuid::new_v4().to_string(),
197                    &call_id,
198                    result_value,
199                );
200                if let Err(e) = webhook_context.broadcast_agui(result_event).await {
201                    tracing::error!(error = %e, "Failed to broadcast TOOL_CALL_RESULT");
202                }
203            },
204            StreamEvent::ExecutionStepUpdate { step } => {
205                let step_event = AgUiEventBuilder::execution_step(step.clone(), context_id.clone());
206                if let Err(e) = webhook_context.broadcast_agui(step_event).await {
207                    tracing::error!(error = %e, "Failed to broadcast execution_step");
208                }
209            },
210            StreamEvent::Complete {
211                full_text,
212                artifacts,
213            } => {
214                text_state.finalize(&message_id).await;
215
216                let complete_params = HandleCompleteParams {
217                    tx: &tx,
218                    webhook_context: &webhook_context,
219                    full_text,
220                    artifacts,
221                    task_id: &task_id,
222                    context_id: &context_id,
223                    id: message_id.as_str(),
224                    original_message: &original_message,
225                    agent_name: &agent_name,
226                    context: &context,
227                    auth_token: context.auth_token().as_str(),
228                    task_repo: &task_repo,
229                    processor: &processor,
230                };
231                handle_complete(complete_params).await;
232
233                send_a2a_status_event(&SendA2aStatusEventParams {
234                    tx: &tx,
235                    task_id: &task_id,
236                    context_id: &context_id,
237                    state: "completed",
238                    is_final: true,
239                    request_id: &request_id,
240                });
241
242                let a2a_event = A2AEventBuilder::task_status_update(
243                    task_id.clone(),
244                    context_id.clone(),
245                    TaskState::Completed,
246                    None,
247                );
248                if let Err(e) = webhook_context.broadcast_a2a(a2a_event).await {
249                    tracing::error!(error = %e, "Failed to broadcast A2A completed");
250                }
251
252                break;
253            },
254            StreamEvent::Error(error) => {
255                text_state.finalize(&message_id).await;
256                handle_error(HandleErrorParams {
257                    tx: &tx,
258                    webhook_context: &webhook_context,
259                    error,
260                    task_id: &task_id,
261                    context_id: &context_id,
262                    task_repo: &task_repo,
263                })
264                .await;
265
266                send_a2a_status_event(&SendA2aStatusEventParams {
267                    tx: &tx,
268                    task_id: &task_id,
269                    context_id: &context_id,
270                    state: "failed",
271                    is_final: true,
272                    request_id: &request_id,
273                });
274
275                let a2a_event = A2AEventBuilder::task_status_update(
276                    task_id.clone(),
277                    context_id.clone(),
278                    TaskState::Failed,
279                    None,
280                );
281                if let Err(e) = webhook_context.broadcast_a2a(a2a_event).await {
282                    tracing::error!(error = %e, "Failed to broadcast A2A failed");
283                }
284
285                break;
286            },
287        }
288    }
289
290    drop(tx);
291
292    tracing::info!("Stream event loop ended");
293}
294
295pub async fn handle_stream_creation_error(
296    webhook_context: &WebhookContext,
297    error: anyhow::Error,
298    task_id: &TaskId,
299    _context_id: &ContextId,
300    task_repo: &TaskRepository,
301) {
302    let error_msg = format!("Failed to create stream: {}", error);
303    tracing::error!(task_id = %task_id, error = %error, "Failed to create stream");
304
305    let failed_timestamp = chrono::Utc::now();
306    if let Err(e) = task_repo
307        .update_task_failed_with_error(task_id, &error_msg, &failed_timestamp)
308        .await
309    {
310        tracing::error!(task_id = %task_id, error = %e, "Failed to update task to failed state");
311    }
312
313    let error_event = AgUiEventBuilder::run_error(
314        format!("Failed to process message: {error}"),
315        Some("STREAM_CREATION_ERROR".to_string()),
316    );
317    if let Err(e) = webhook_context.broadcast_agui(error_event).await {
318        tracing::error!(error = %e, "Failed to broadcast RUN_ERROR");
319    }
320}