Skip to main content

systemprompt_agent/services/a2a_server/processing/message/
message_handler.rs

1use std::sync::Arc;
2
3use anyhow::{Result, anyhow};
4use uuid::Uuid;
5
6use super::persistence::{broadcast_completion, persist_completed_task};
7use super::stream_processor::StreamProcessor;
8use super::{MessageProcessor, StreamEvent};
9use crate::models::a2a::{Message, Part, Task, TaskState, TaskStatus, TextPart};
10use crate::services::a2a_server::processing::task_builder::build_completed_task;
11use crate::services::a2a_server::streaming::broadcast::{
12    BroadcastTaskCreatedParams, broadcast_task_created,
13};
14use crate::services::a2a_server::streaming::webhook_client::broadcast_agui_event;
15use systemprompt_identifiers::{MessageId, SessionId, TaskId, TraceId, UserId};
16use systemprompt_models::{AgUiEventBuilder, AgUiMessageRole, RequestContext, TaskMetadata};
17
18impl MessageProcessor {
19    pub async fn handle_message(
20        &self,
21        message: Message,
22        agent_name: &str,
23        context: &RequestContext,
24    ) -> Result<Task> {
25        tracing::info!(agent_name = %agent_name, "Handling non-streaming message");
26
27        let agent_runtime = self.load_agent_runtime(agent_name).await?;
28
29        self.context_repo
30            .get_context(&message.context_id, context.user_id())
31            .await
32            .map_err(|e| {
33                anyhow!(
34                    "Context validation failed - context_id: {}, user_id: {}, error: {}",
35                    message.context_id,
36                    context.user_id(),
37                    e
38                )
39            })?;
40
41        tracing::info!(
42            context_id = %message.context_id,
43            user_id = %context.user_id(),
44            "Context validated"
45        );
46
47        let task_id = message.task_id.clone().map_or_else(
48            || {
49                let new_task_id = TaskId::new(Uuid::new_v4().to_string());
50                tracing::info!(task_id = %new_task_id, "Starting NEW task with generated ID");
51                new_task_id
52            },
53            |existing_task_id| {
54                tracing::info!(task_id = %existing_task_id, "Continuing existing task");
55                existing_task_id
56            },
57        );
58
59        let metadata = TaskMetadata::new_agent_message(agent_name.to_string());
60
61        let task = Task {
62            id: task_id.clone(),
63            context_id: message.context_id.clone(),
64            status: TaskStatus {
65                state: TaskState::Submitted,
66                message: None,
67                timestamp: Some(chrono::Utc::now()),
68            },
69            history: None,
70            artifacts: None,
71            metadata: Some(metadata),
72            kind: "task".to_string(),
73        };
74
75        if let Err(e) = self
76            .task_repo
77            .create_task(crate::repository::task::RepoCreateTaskParams {
78                task: &task,
79                user_id: &UserId::new(context.user_id().as_str()),
80                session_id: &SessionId::new(context.session_id().as_str()),
81                trace_id: &TraceId::new(context.trace_id().as_str()),
82                agent_name,
83            })
84            .await
85        {
86            return Err(anyhow!("Failed to persist task at start: {}", e));
87        }
88
89        tracing::info!(task_id = %task_id, "Task persisted to database");
90
91        broadcast_task_created(BroadcastTaskCreatedParams {
92            task_id: &task_id,
93            context_id: &message.context_id,
94            user_id: context.user_id().as_str(),
95            user_message: &message,
96            agent_name,
97            token: context.auth_token().as_str(),
98        })
99        .await;
100
101        let working_timestamp = chrono::Utc::now();
102        if let Err(e) = self
103            .task_repo
104            .update_task_state(&task_id, TaskState::Working, &working_timestamp)
105            .await
106        {
107            tracing::error!(task_id = %task_id, error = %e, "Failed to mark task as working");
108        }
109
110        let stream_processor = StreamProcessor {
111            ai_service: Arc::clone(&self.ai_service),
112            context_service: self.context_service.clone(),
113            skill_service: Arc::clone(&self.skill_service),
114            execution_step_repo: Arc::clone(&self.execution_step_repo),
115        };
116
117        let mut chunk_rx = stream_processor
118            .process_message_stream(super::ProcessMessageStreamParams {
119                a2a_message: &message,
120                agent_runtime: &agent_runtime,
121                agent_name,
122                context,
123                task_id: task_id.clone(),
124            })
125            .await?;
126
127        let mut response_text = String::new();
128        let mut tool_artifacts = Vec::new();
129
130        while let Some(event) = chunk_rx.recv().await {
131            match event {
132                StreamEvent::Text(text) => {
133                    response_text.push_str(&text);
134                },
135                StreamEvent::Complete {
136                    full_text,
137                    artifacts,
138                } => {
139                    response_text = full_text;
140                    tool_artifacts = artifacts;
141                },
142                StreamEvent::Error(error) => {
143                    let error_event = AgUiEventBuilder::run_error(
144                        error.clone(),
145                        Some("EXECUTION_ERROR".to_string()),
146                    );
147                    if let Err(e) = broadcast_agui_event(
148                        context.user_id().as_str(),
149                        error_event,
150                        context.auth_token().as_str(),
151                    )
152                    .await
153                    {
154                        tracing::debug!(error = %e, "Failed to broadcast error event");
155                    }
156                    return Err(anyhow!(error));
157                },
158                _ => {},
159            }
160        }
161
162        let task = build_completed_task(
163            task_id,
164            message.context_id.clone(),
165            response_text.clone(),
166            message.clone(),
167            tool_artifacts,
168        );
169
170        let agent_message = task.status.message.clone().unwrap_or_else(|| {
171            let client_message_id = message
172                .metadata
173                .as_ref()
174                .and_then(|m| m.get("clientMessageId"))
175                .cloned();
176
177            let metadata = client_message_id.map(|id| serde_json::json!({"clientMessageId": id}));
178
179            Message {
180                role: "agent".to_string(),
181                parts: vec![Part::Text(TextPart {
182                    text: response_text.clone(),
183                })],
184                id: MessageId::generate(),
185                task_id: Some(task.id.clone()),
186                context_id: task.context_id.clone(),
187                kind: "message".to_string(),
188                metadata,
189                extensions: None,
190                reference_task_ids: None,
191            }
192        });
193
194        if context.user_type() == systemprompt_models::auth::UserType::Anon {
195            tracing::warn!(
196                context_id = %message.context_id,
197                session_id = %context.session_id(),
198                "Saving messages for anonymous user"
199            );
200        }
201
202        if let Err(e) = persist_completed_task(super::persistence::PersistCompletedTaskParams {
203            task: &task,
204            user_message: &message,
205            agent_message: &agent_message,
206            context,
207            task_repo: &self.task_repo,
208            db_pool: &self.db_pool,
209            artifacts_already_published: false,
210        })
211        .await
212        {
213            let error_msg = format!("Failed to persist completed task: {}", e);
214            tracing::error!(task_id = %task.id, error = %e, "Failed to persist completed task");
215
216            let failed_timestamp = chrono::Utc::now();
217            if let Err(update_err) = self
218                .task_repo
219                .update_task_failed_with_error(&task.id, &error_msg, &failed_timestamp)
220                .await
221            {
222                tracing::error!(task_id = %task.id, error = %update_err, "Failed to update task to failed state");
223            }
224
225            return Err(e);
226        }
227
228        broadcast_completion(&task, context).await;
229
230        let user_id = context.user_id().as_str();
231        let auth_token = context.auth_token().as_str();
232        let context_id = task.context_id.clone();
233        let task_id = task.id.clone();
234        let message_id = agent_message.id.clone();
235
236        let start_event = AgUiEventBuilder::run_started(context_id.clone(), task_id.clone(), None);
237        if let Err(e) = broadcast_agui_event(user_id, start_event, auth_token).await {
238            tracing::debug!(error = %e, "Failed to broadcast run_started event");
239        }
240
241        let msg_start = AgUiEventBuilder::text_message_start(
242            message_id.to_string(),
243            AgUiMessageRole::Assistant,
244        );
245        if let Err(e) = broadcast_agui_event(user_id, msg_start, auth_token).await {
246            tracing::debug!(error = %e, "Failed to broadcast text_message_start event");
247        }
248
249        let msg_content =
250            AgUiEventBuilder::text_message_content(message_id.to_string(), &response_text);
251        if let Err(e) = broadcast_agui_event(user_id, msg_content, auth_token).await {
252            tracing::debug!(error = %e, "Failed to broadcast text_message_content event");
253        }
254
255        let msg_end = AgUiEventBuilder::text_message_end(message_id.to_string());
256        if let Err(e) = broadcast_agui_event(user_id, msg_end, auth_token).await {
257            tracing::debug!(error = %e, "Failed to broadcast text_message_end event");
258        }
259
260        let result = serde_json::json!({
261            "text": response_text,
262            "artifacts": task.artifacts,
263        });
264        let finish_event = AgUiEventBuilder::run_finished(context_id, task_id, Some(result));
265        if let Err(e) = broadcast_agui_event(user_id, finish_event, auth_token).await {
266            tracing::debug!(error = %e, "Failed to broadcast run_finished event");
267        }
268
269        Ok(task)
270    }
271}