systemprompt_agent/services/a2a_server/processing/message/
message_handler.rs1use anyhow::{anyhow, Result};
2use uuid::Uuid;
3
4use super::persistence::{broadcast_completion, persist_completed_task};
5use super::stream_processor::StreamProcessor;
6use super::{MessageProcessor, StreamEvent};
7use crate::models::a2a::{Message, Part, Task, TaskState, TaskStatus, TextPart};
8use crate::services::a2a_server::processing::task_builder::build_completed_task;
9use crate::services::a2a_server::streaming::broadcast::broadcast_task_created;
10use crate::services::a2a_server::streaming::webhook_client::broadcast_agui_event;
11use systemprompt_identifiers::{MessageId, SessionId, TaskId, TraceId, UserId};
12use systemprompt_models::{AgUiEventBuilder, AgUiMessageRole, RequestContext, TaskMetadata};
13
14impl MessageProcessor {
15 pub async fn handle_message(
16 &self,
17 message: Message,
18 agent_name: &str,
19 context: &RequestContext,
20 ) -> Result<Task> {
21 tracing::info!(agent_name = %agent_name, "Handling non-streaming message");
22
23 let agent_runtime = self.load_agent_runtime(agent_name).await?;
24
25 self.context_repo
26 .get_context(&message.context_id, context.user_id())
27 .await
28 .map_err(|e| {
29 anyhow!(
30 "Context validation failed - context_id: {}, user_id: {}, error: {}",
31 message.context_id,
32 context.user_id(),
33 e
34 )
35 })?;
36
37 tracing::info!(
38 context_id = %message.context_id,
39 user_id = %context.user_id(),
40 "Context validated"
41 );
42
43 let task_id = match message.task_id.clone() {
44 Some(existing_task_id) => {
45 tracing::info!(task_id = %existing_task_id, "Continuing existing task");
46 existing_task_id
47 },
48 None => {
49 let new_task_id = TaskId::new(Uuid::new_v4().to_string());
50 tracing::info!(task_id = %new_task_id, "Starting NEW task with generated ID");
51 new_task_id
52 },
53 };
54
55 let metadata = TaskMetadata::new_agent_message(agent_name.to_string());
56
57 let task = Task {
58 id: task_id.clone(),
59 context_id: message.context_id.clone(),
60 status: TaskStatus {
61 state: TaskState::Submitted,
62 message: None,
63 timestamp: Some(chrono::Utc::now()),
64 },
65 history: None,
66 artifacts: None,
67 metadata: Some(metadata),
68 kind: "task".to_string(),
69 };
70
71 if let Err(e) = self
72 .task_repo
73 .create_task(
74 &task,
75 &UserId::new(context.user_id().as_str()),
76 &SessionId::new(context.session_id().as_str()),
77 &TraceId::new(context.trace_id().as_str()),
78 agent_name,
79 )
80 .await
81 {
82 return Err(anyhow!("Failed to persist task at start: {}", e));
83 }
84
85 tracing::info!(task_id = %task_id, "Task persisted to database");
86
87 broadcast_task_created(
88 &task_id,
89 &message.context_id,
90 context.user_id().as_str(),
91 &message,
92 agent_name,
93 context.auth_token().as_str(),
94 )
95 .await;
96
97 let working_timestamp = chrono::Utc::now();
98 if let Err(e) = self
99 .task_repo
100 .update_task_state(&task_id, TaskState::Working, &working_timestamp)
101 .await
102 {
103 tracing::error!(task_id = %task_id, error = %e, "Failed to mark task as working");
104 }
105
106 let stream_processor = StreamProcessor {
107 ai_service: self.ai_service.clone(),
108 context_service: self.context_service.clone(),
109 skill_service: self.skill_service.clone(),
110 execution_step_repo: self.execution_step_repo.clone(),
111 };
112
113 let mut chunk_rx = stream_processor
114 .process_message_stream(
115 &message,
116 &agent_runtime,
117 agent_name,
118 context,
119 task_id.clone(),
120 )
121 .await?;
122
123 let mut response_text = String::new();
124 let mut tool_artifacts = Vec::new();
125
126 while let Some(event) = chunk_rx.recv().await {
127 match event {
128 StreamEvent::Text(text) => {
129 response_text.push_str(&text);
130 },
131 StreamEvent::Complete {
132 full_text,
133 artifacts,
134 } => {
135 response_text = full_text;
136 tool_artifacts = artifacts;
137 },
138 StreamEvent::Error(error) => {
139 let error_event = AgUiEventBuilder::run_error(
140 error.clone(),
141 Some("EXECUTION_ERROR".to_string()),
142 );
143 if let Err(e) = broadcast_agui_event(
144 context.user_id().as_str(),
145 error_event,
146 context.auth_token().as_str(),
147 )
148 .await
149 {
150 tracing::debug!(error = %e, "Failed to broadcast error event");
151 }
152 return Err(anyhow!(error));
153 },
154 _ => {},
155 }
156 }
157
158 let task = build_completed_task(
159 task_id,
160 message.context_id.clone(),
161 response_text.clone(),
162 message.clone(),
163 tool_artifacts,
164 );
165
166 let agent_message = task.status.message.clone().unwrap_or_else(|| {
167 let client_message_id = message
168 .metadata
169 .as_ref()
170 .and_then(|m| m.get("clientMessageId"))
171 .cloned();
172
173 let metadata = client_message_id.map(|id| serde_json::json!({"clientMessageId": id}));
174
175 Message {
176 role: "agent".to_string(),
177 parts: vec![Part::Text(TextPart {
178 text: response_text.clone(),
179 })],
180 id: MessageId::generate(),
181 task_id: Some(task.id.clone()),
182 context_id: task.context_id.clone(),
183 kind: "message".to_string(),
184 metadata,
185 extensions: None,
186 reference_task_ids: None,
187 }
188 });
189
190 if context.user_type() == systemprompt_models::auth::UserType::Anon {
191 tracing::warn!(
192 context_id = %message.context_id,
193 session_id = %context.session_id(),
194 "Saving messages for anonymous user"
195 );
196 }
197
198 if let Err(e) = persist_completed_task(
199 &task,
200 &message,
201 &agent_message,
202 context,
203 &self.task_repo,
204 &self.db_pool,
205 false,
206 )
207 .await
208 {
209 let error_msg = format!("Failed to persist completed task: {}", e);
210 tracing::error!(task_id = %task.id, error = %e, "Failed to persist completed task");
211
212 let failed_timestamp = chrono::Utc::now();
213 if let Err(update_err) = self
214 .task_repo
215 .update_task_failed_with_error(&task.id, &error_msg, &failed_timestamp)
216 .await
217 {
218 tracing::error!(task_id = %task.id, error = %update_err, "Failed to update task to failed state");
219 }
220
221 return Err(e);
222 }
223
224 broadcast_completion(&task, context).await;
225
226 let user_id = context.user_id().as_str();
227 let auth_token = context.auth_token().as_str();
228 let context_id = task.context_id.clone();
229 let task_id = task.id.clone();
230 let message_id = agent_message.id.clone();
231
232 let start_event = AgUiEventBuilder::run_started(context_id.clone(), task_id.clone(), None);
233 if let Err(e) = broadcast_agui_event(user_id, start_event, auth_token).await {
234 tracing::debug!(error = %e, "Failed to broadcast run_started event");
235 }
236
237 let msg_start = AgUiEventBuilder::text_message_start(
238 message_id.to_string(),
239 AgUiMessageRole::Assistant,
240 );
241 if let Err(e) = broadcast_agui_event(user_id, msg_start, auth_token).await {
242 tracing::debug!(error = %e, "Failed to broadcast text_message_start event");
243 }
244
245 let msg_content =
246 AgUiEventBuilder::text_message_content(message_id.to_string(), &response_text);
247 if let Err(e) = broadcast_agui_event(user_id, msg_content, auth_token).await {
248 tracing::debug!(error = %e, "Failed to broadcast text_message_content event");
249 }
250
251 let msg_end = AgUiEventBuilder::text_message_end(message_id.to_string());
252 if let Err(e) = broadcast_agui_event(user_id, msg_end, auth_token).await {
253 tracing::debug!(error = %e, "Failed to broadcast text_message_end event");
254 }
255
256 let result = serde_json::json!({
257 "text": response_text,
258 "artifacts": task.artifacts,
259 });
260 let finish_event = AgUiEventBuilder::run_finished(context_id, task_id, Some(result));
261 if let Err(e) = broadcast_agui_event(user_id, finish_event, auth_token).await {
262 tracing::debug!(error = %e, "Failed to broadcast run_finished event");
263 }
264
265 Ok(task)
266 }
267}