Skip to main content

systemprompt_agent/services/a2a_server/processing/message/
message_handler.rs

1use anyhow::{anyhow, Result};
2use uuid::Uuid;
3
4use super::persistence::{broadcast_completion, persist_completed_task};
5use super::stream_processor::StreamProcessor;
6use super::{MessageProcessor, StreamEvent};
7use crate::models::a2a::{Message, Part, Task, TaskState, TaskStatus, TextPart};
8use crate::services::a2a_server::processing::task_builder::build_completed_task;
9use crate::services::a2a_server::streaming::broadcast::broadcast_task_created;
10use crate::services::a2a_server::streaming::webhook_client::broadcast_agui_event;
11use systemprompt_identifiers::{MessageId, SessionId, TaskId, TraceId, UserId};
12use systemprompt_models::{AgUiEventBuilder, AgUiMessageRole, RequestContext, TaskMetadata};
13
14impl MessageProcessor {
15    pub async fn handle_message(
16        &self,
17        message: Message,
18        agent_name: &str,
19        context: &RequestContext,
20    ) -> Result<Task> {
21        tracing::info!(agent_name = %agent_name, "Handling non-streaming message");
22
23        let agent_runtime = self.load_agent_runtime(agent_name).await?;
24
25        self.context_repo
26            .get_context(&message.context_id, context.user_id())
27            .await
28            .map_err(|e| {
29                anyhow!(
30                    "Context validation failed - context_id: {}, user_id: {}, error: {}",
31                    message.context_id,
32                    context.user_id(),
33                    e
34                )
35            })?;
36
37        tracing::info!(
38            context_id = %message.context_id,
39            user_id = %context.user_id(),
40            "Context validated"
41        );
42
43        let task_id = match message.task_id.clone() {
44            Some(existing_task_id) => {
45                tracing::info!(task_id = %existing_task_id, "Continuing existing task");
46                existing_task_id
47            },
48            None => {
49                let new_task_id = TaskId::new(Uuid::new_v4().to_string());
50                tracing::info!(task_id = %new_task_id, "Starting NEW task with generated ID");
51                new_task_id
52            },
53        };
54
55        let metadata = TaskMetadata::new_agent_message(agent_name.to_string());
56
57        let task = Task {
58            id: task_id.clone(),
59            context_id: message.context_id.clone(),
60            status: TaskStatus {
61                state: TaskState::Submitted,
62                message: None,
63                timestamp: Some(chrono::Utc::now()),
64            },
65            history: None,
66            artifacts: None,
67            metadata: Some(metadata),
68            kind: "task".to_string(),
69        };
70
71        if let Err(e) = self
72            .task_repo
73            .create_task(
74                &task,
75                &UserId::new(context.user_id().as_str()),
76                &SessionId::new(context.session_id().as_str()),
77                &TraceId::new(context.trace_id().as_str()),
78                agent_name,
79            )
80            .await
81        {
82            return Err(anyhow!("Failed to persist task at start: {}", e));
83        }
84
85        tracing::info!(task_id = %task_id, "Task persisted to database");
86
87        broadcast_task_created(
88            &task_id,
89            &message.context_id,
90            context.user_id().as_str(),
91            &message,
92            agent_name,
93            context.auth_token().as_str(),
94        )
95        .await;
96
97        let working_timestamp = chrono::Utc::now();
98        if let Err(e) = self
99            .task_repo
100            .update_task_state(&task_id, TaskState::Working, &working_timestamp)
101            .await
102        {
103            tracing::error!(task_id = %task_id, error = %e, "Failed to mark task as working");
104        }
105
106        let stream_processor = StreamProcessor {
107            ai_service: self.ai_service.clone(),
108            context_service: self.context_service.clone(),
109            skill_service: self.skill_service.clone(),
110            execution_step_repo: self.execution_step_repo.clone(),
111        };
112
113        let mut chunk_rx = stream_processor
114            .process_message_stream(
115                &message,
116                &agent_runtime,
117                agent_name,
118                context,
119                task_id.clone(),
120            )
121            .await?;
122
123        let mut response_text = String::new();
124        let mut tool_artifacts = Vec::new();
125
126        while let Some(event) = chunk_rx.recv().await {
127            match event {
128                StreamEvent::Text(text) => {
129                    response_text.push_str(&text);
130                },
131                StreamEvent::Complete {
132                    full_text,
133                    artifacts,
134                } => {
135                    response_text = full_text;
136                    tool_artifacts = artifacts;
137                },
138                StreamEvent::Error(error) => {
139                    let error_event = AgUiEventBuilder::run_error(
140                        error.clone(),
141                        Some("EXECUTION_ERROR".to_string()),
142                    );
143                    if let Err(e) = broadcast_agui_event(
144                        context.user_id().as_str(),
145                        error_event,
146                        context.auth_token().as_str(),
147                    )
148                    .await
149                    {
150                        tracing::debug!(error = %e, "Failed to broadcast error event");
151                    }
152                    return Err(anyhow!(error));
153                },
154                _ => {},
155            }
156        }
157
158        let task = build_completed_task(
159            task_id,
160            message.context_id.clone(),
161            response_text.clone(),
162            message.clone(),
163            tool_artifacts,
164        );
165
166        let agent_message = task.status.message.clone().unwrap_or_else(|| {
167            let client_message_id = message
168                .metadata
169                .as_ref()
170                .and_then(|m| m.get("clientMessageId"))
171                .cloned();
172
173            let metadata = client_message_id.map(|id| serde_json::json!({"clientMessageId": id}));
174
175            Message {
176                role: "agent".to_string(),
177                parts: vec![Part::Text(TextPart {
178                    text: response_text.clone(),
179                })],
180                id: MessageId::generate(),
181                task_id: Some(task.id.clone()),
182                context_id: task.context_id.clone(),
183                kind: "message".to_string(),
184                metadata,
185                extensions: None,
186                reference_task_ids: None,
187            }
188        });
189
190        if context.user_type() == systemprompt_models::auth::UserType::Anon {
191            tracing::warn!(
192                context_id = %message.context_id,
193                session_id = %context.session_id(),
194                "Saving messages for anonymous user"
195            );
196        }
197
198        if let Err(e) = persist_completed_task(
199            &task,
200            &message,
201            &agent_message,
202            context,
203            &self.task_repo,
204            &self.db_pool,
205            false,
206        )
207        .await
208        {
209            let error_msg = format!("Failed to persist completed task: {}", e);
210            tracing::error!(task_id = %task.id, error = %e, "Failed to persist completed task");
211
212            let failed_timestamp = chrono::Utc::now();
213            if let Err(update_err) = self
214                .task_repo
215                .update_task_failed_with_error(&task.id, &error_msg, &failed_timestamp)
216                .await
217            {
218                tracing::error!(task_id = %task.id, error = %update_err, "Failed to update task to failed state");
219            }
220
221            return Err(e);
222        }
223
224        broadcast_completion(&task, context).await;
225
226        let user_id = context.user_id().as_str();
227        let auth_token = context.auth_token().as_str();
228        let context_id = task.context_id.clone();
229        let task_id = task.id.clone();
230        let message_id = agent_message.id.clone();
231
232        let start_event = AgUiEventBuilder::run_started(context_id.clone(), task_id.clone(), None);
233        if let Err(e) = broadcast_agui_event(user_id, start_event, auth_token).await {
234            tracing::debug!(error = %e, "Failed to broadcast run_started event");
235        }
236
237        let msg_start = AgUiEventBuilder::text_message_start(
238            message_id.to_string(),
239            AgUiMessageRole::Assistant,
240        );
241        if let Err(e) = broadcast_agui_event(user_id, msg_start, auth_token).await {
242            tracing::debug!(error = %e, "Failed to broadcast text_message_start event");
243        }
244
245        let msg_content =
246            AgUiEventBuilder::text_message_content(message_id.to_string(), &response_text);
247        if let Err(e) = broadcast_agui_event(user_id, msg_content, auth_token).await {
248            tracing::debug!(error = %e, "Failed to broadcast text_message_content event");
249        }
250
251        let msg_end = AgUiEventBuilder::text_message_end(message_id.to_string());
252        if let Err(e) = broadcast_agui_event(user_id, msg_end, auth_token).await {
253            tracing::debug!(error = %e, "Failed to broadcast text_message_end event");
254        }
255
256        let result = serde_json::json!({
257            "text": response_text,
258            "artifacts": task.artifacts,
259        });
260        let finish_event = AgUiEventBuilder::run_finished(context_id, task_id, Some(result));
261        if let Err(e) = broadcast_agui_event(user_id, finish_event, auth_token).await {
262            tracing::debug!(error = %e, "Failed to broadcast run_finished event");
263        }
264
265        Ok(task)
266    }
267}