Skip to main content

systemprompt_agent/services/a2a_server/streaming/
event_loop.rs

1use std::sync::Arc;
2
3use axum::response::sse::Event;
4use serde_json::json;
5use systemprompt_identifiers::{ContextId, MessageId, TaskId};
6use systemprompt_models::{A2AEventBuilder, AgUiEventBuilder, RequestContext};
7use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
8
9use crate::models::a2a::jsonrpc::NumberOrString;
10use crate::models::a2a::{Message, TaskState};
11use crate::repository::task::TaskRepository;
12use crate::services::a2a_server::processing::message::{MessageProcessor, StreamEvent};
13
14use super::handlers::text::TextStreamState;
15use super::handlers::{handle_complete, handle_error, HandleCompleteParams};
16use super::webhook_client::WebhookContext;
17
18pub struct ProcessEventsParams {
19    pub tx: UnboundedSender<Event>,
20    pub chunk_rx: UnboundedReceiver<StreamEvent>,
21    pub task_id: TaskId,
22    pub context_id: ContextId,
23    pub message_id: MessageId,
24    pub original_message: Message,
25    pub agent_name: String,
26    pub context: RequestContext,
27    pub task_repo: TaskRepository,
28    pub processor: Arc<MessageProcessor>,
29    pub request_id: NumberOrString,
30}
31
32impl std::fmt::Debug for ProcessEventsParams {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("ProcessEventsParams")
35            .field("task_id", &self.task_id)
36            .field("context_id", &self.context_id)
37            .field("message_id", &self.message_id)
38            .field("agent_name", &self.agent_name)
39            .finish_non_exhaustive()
40    }
41}
42
43fn send_a2a_status_event(
44    tx: &UnboundedSender<Event>,
45    task_id: &TaskId,
46    context_id: &ContextId,
47    state: &str,
48    is_final: bool,
49    request_id: &NumberOrString,
50) {
51    let event = json!({
52        "jsonrpc": "2.0",
53        "id": request_id,
54        "result": {
55            "kind": "status-update",
56            "taskId": task_id.as_str(),
57            "contextId": context_id.as_str(),
58            "status": {
59                "state": state,
60                "timestamp": chrono::Utc::now().to_rfc3339()
61            },
62            "final": is_final
63        }
64    });
65    if tx.send(Event::default().data(event.to_string())).is_err() {
66        tracing::trace!("Failed to send status event, channel closed");
67    }
68}
69
70pub async fn emit_run_started(
71    tx: &UnboundedSender<Event>,
72    webhook_context: &WebhookContext,
73    context_id: &ContextId,
74    task_id: &TaskId,
75    task_repo: &TaskRepository,
76    request_id: &NumberOrString,
77) {
78    let working_timestamp = chrono::Utc::now();
79    if let Err(e) = task_repo
80        .update_task_state(&task_id, TaskState::Working, &working_timestamp)
81        .await
82    {
83        tracing::error!(task_id = %task_id, error = %e, "Failed to update task state");
84        return;
85    }
86
87    send_a2a_status_event(tx, task_id, context_id, "working", false, request_id);
88
89    let a2a_event = A2AEventBuilder::task_status_update(
90        task_id.clone(),
91        context_id.clone(),
92        TaskState::Working,
93        None,
94    );
95    if let Err(e) = webhook_context.broadcast_a2a(a2a_event).await {
96        tracing::error!(error = %e, "Failed to broadcast A2A working");
97    }
98
99    let event = AgUiEventBuilder::run_started(context_id.clone(), task_id.clone(), None);
100    if let Err(e) = webhook_context.broadcast_agui(event).await {
101        tracing::error!(error = %e, "Failed to broadcast RUN_STARTED");
102    }
103}
104
105pub async fn process_events(params: ProcessEventsParams) {
106    let ProcessEventsParams {
107        tx,
108        mut chunk_rx,
109        task_id,
110        context_id,
111        message_id,
112        original_message,
113        agent_name,
114        context,
115        task_repo,
116        processor,
117        request_id,
118    } = params;
119
120    let webhook_context =
121        WebhookContext::new(context.user_id().as_str(), context.auth_token().as_str());
122
123    emit_run_started(
124        &tx,
125        &webhook_context,
126        &context_id,
127        &task_id,
128        &task_repo,
129        &request_id,
130    )
131    .await;
132
133    tracing::info!("Stream channel received, waiting for events...");
134
135    let mut text_state = TextStreamState::new().with_webhook_context(webhook_context.clone());
136
137    while let Some(event) = chunk_rx.recv().await {
138        match event {
139            StreamEvent::Text(text) => {
140                text_state.handle_text(text, message_id.as_str()).await;
141            },
142            StreamEvent::ToolCallStarted(tool_call) => {
143                let tool_call_id = tool_call.ai_tool_call_id.as_str();
144                let start_event = AgUiEventBuilder::tool_call_start(
145                    tool_call_id,
146                    &tool_call.name,
147                    Some(message_id.to_string()),
148                );
149                if let Err(e) = webhook_context.broadcast_agui(start_event).await {
150                    tracing::error!(error = %e, "Failed to broadcast TOOL_CALL_START");
151                }
152
153                let args_json =
154                    serde_json::to_string(&tool_call.arguments).unwrap_or_else(|_| String::new());
155                let args_event = AgUiEventBuilder::tool_call_args(tool_call_id, &args_json);
156                if let Err(e) = webhook_context.broadcast_agui(args_event).await {
157                    tracing::error!(error = %e, "Failed to broadcast TOOL_CALL_ARGS");
158                }
159
160                let end_event = AgUiEventBuilder::tool_call_end(tool_call_id);
161                if let Err(e) = webhook_context.broadcast_agui(end_event).await {
162                    tracing::error!(error = %e, "Failed to broadcast TOOL_CALL_END");
163                }
164            },
165            StreamEvent::ToolResult { call_id, result } => {
166                let result_value =
167                    serde_json::to_value(&result).unwrap_or_else(|_| serde_json::Value::Null);
168                let result_event = AgUiEventBuilder::tool_call_result(
169                    &uuid::Uuid::new_v4().to_string(),
170                    &call_id,
171                    result_value,
172                );
173                if let Err(e) = webhook_context.broadcast_agui(result_event).await {
174                    tracing::error!(error = %e, "Failed to broadcast TOOL_CALL_RESULT");
175                }
176            },
177            StreamEvent::ExecutionStepUpdate { step } => {
178                let step_event = AgUiEventBuilder::execution_step(step.clone(), context_id.clone());
179                if let Err(e) = webhook_context.broadcast_agui(step_event).await {
180                    tracing::error!(error = %e, "Failed to broadcast execution_step");
181                }
182            },
183            StreamEvent::Complete {
184                full_text,
185                artifacts,
186            } => {
187                text_state.finalize(message_id.as_str()).await;
188
189                let complete_params = HandleCompleteParams {
190                    tx: &tx,
191                    webhook_context: &webhook_context,
192                    full_text,
193                    artifacts,
194                    task_id: &task_id,
195                    context_id: &context_id,
196                    id: message_id.as_str(),
197                    original_message: &original_message,
198                    agent_name: &agent_name,
199                    context: &context,
200                    auth_token: context.auth_token().as_str(),
201                    task_repo: &task_repo,
202                    processor: &processor,
203                };
204                handle_complete(complete_params).await;
205
206                send_a2a_status_event(&tx, &task_id, &context_id, "completed", true, &request_id);
207
208                let a2a_event = A2AEventBuilder::task_status_update(
209                    task_id.clone(),
210                    context_id.clone(),
211                    TaskState::Completed,
212                    None,
213                );
214                if let Err(e) = webhook_context.broadcast_a2a(a2a_event).await {
215                    tracing::error!(error = %e, "Failed to broadcast A2A completed");
216                }
217
218                break;
219            },
220            StreamEvent::Error(error) => {
221                text_state.finalize(message_id.as_str()).await;
222                handle_error(
223                    &tx,
224                    &webhook_context,
225                    error,
226                    &task_id,
227                    &context_id,
228                    &task_repo,
229                )
230                .await;
231
232                send_a2a_status_event(&tx, &task_id, &context_id, "failed", true, &request_id);
233
234                let a2a_event = A2AEventBuilder::task_status_update(
235                    task_id.clone(),
236                    context_id.clone(),
237                    TaskState::Failed,
238                    None,
239                );
240                if let Err(e) = webhook_context.broadcast_a2a(a2a_event).await {
241                    tracing::error!(error = %e, "Failed to broadcast A2A failed");
242                }
243
244                break;
245            },
246        }
247    }
248
249    drop(tx);
250
251    tracing::info!("Stream event loop ended");
252}
253
254pub async fn handle_stream_creation_error(
255    webhook_context: &WebhookContext,
256    error: anyhow::Error,
257    task_id: &TaskId,
258    _context_id: &ContextId,
259    task_repo: &TaskRepository,
260) {
261    let error_msg = format!("Failed to create stream: {}", error);
262    tracing::error!(task_id = %task_id, error = %error, "Failed to create stream");
263
264    let failed_timestamp = chrono::Utc::now();
265    if let Err(e) = task_repo
266        .update_task_failed_with_error(task_id, &error_msg, &failed_timestamp)
267        .await
268    {
269        tracing::error!(task_id = %task_id, error = %e, "Failed to update task to failed state");
270    }
271
272    let error_event = AgUiEventBuilder::run_error(
273        format!("Failed to process message: {error}"),
274        Some("STREAM_CREATION_ERROR".to_string()),
275    );
276    if let Err(e) = webhook_context.broadcast_agui(error_event).await {
277        tracing::error!(error = %e, "Failed to broadcast RUN_ERROR");
278    }
279}