Skip to main content

systemprompt_agent/services/a2a_server/processing/message/
message_handler.rs

1use std::sync::Arc;
2
3use anyhow::{Result, anyhow};
4use uuid::Uuid;
5
6use super::persistence::{broadcast_completion, persist_completed_task};
7use super::stream_processor::StreamProcessor;
8use super::{MessageProcessor, StreamEvent};
9use crate::models::a2a::{Message, MessageRole, Part, Task, TaskState, TaskStatus, TextPart};
10use crate::services::a2a_server::processing::task_builder::build_completed_task;
11use crate::services::a2a_server::streaming::broadcast::{
12    BroadcastTaskCreatedParams, broadcast_task_created,
13};
14use crate::services::a2a_server::streaming::webhook_client::broadcast_agui_event;
15use systemprompt_identifiers::{MessageId, SessionId, TaskId, TraceId, UserId};
16use systemprompt_models::{AgUiEventBuilder, AgUiMessageRole, RequestContext, TaskMetadata};
17
18impl MessageProcessor {
19    pub async fn handle_message(
20        &self,
21        message: Message,
22        agent_name: &str,
23        context: &RequestContext,
24    ) -> Result<Task> {
25        tracing::info!(agent_name = %agent_name, "Handling non-streaming message");
26
27        let agent_runtime = self.load_agent_runtime(agent_name).await?;
28
29        self.context_repo
30            .get_context(&message.context_id, context.user_id())
31            .await
32            .map_err(|e| {
33                anyhow!(
34                    "Context validation failed - context_id: {}, user_id: {}, error: {}",
35                    message.context_id,
36                    context.user_id(),
37                    e
38                )
39            })?;
40
41        tracing::info!(
42            context_id = %message.context_id,
43            user_id = %context.user_id(),
44            "Context validated"
45        );
46
47        let task_id = message.task_id.clone().map_or_else(
48            || {
49                let new_task_id = TaskId::new(Uuid::new_v4().to_string());
50                tracing::info!(task_id = %new_task_id, "Starting NEW task with generated ID");
51                new_task_id
52            },
53            |existing_task_id| {
54                tracing::info!(task_id = %existing_task_id, "Continuing existing task");
55                existing_task_id
56            },
57        );
58
59        let metadata = TaskMetadata::new_agent_message(agent_name.to_string());
60
61        let task = Task {
62            id: task_id.clone(),
63            context_id: message.context_id.clone(),
64            status: TaskStatus {
65                state: TaskState::Submitted,
66                message: None,
67                timestamp: Some(chrono::Utc::now()),
68            },
69            history: None,
70            artifacts: None,
71            metadata: Some(metadata),
72            created_at: Some(chrono::Utc::now()),
73            last_modified: Some(chrono::Utc::now()),
74        };
75
76        if let Err(e) = self
77            .task_repo
78            .create_task(crate::repository::task::RepoCreateTaskParams {
79                task: &task,
80                user_id: &UserId::new(context.user_id().as_str()),
81                session_id: &SessionId::new(context.session_id().as_str()),
82                trace_id: &TraceId::new(context.trace_id().as_str()),
83                agent_name,
84            })
85            .await
86        {
87            return Err(anyhow!("Failed to persist task at start: {}", e));
88        }
89
90        tracing::info!(task_id = %task_id, "Task persisted to database");
91
92        broadcast_task_created(BroadcastTaskCreatedParams {
93            task_id: &task_id,
94            context_id: &message.context_id,
95            user_id: context.user_id().as_str(),
96            user_message: &message,
97            agent_name,
98            token: context.auth_token().as_str(),
99        })
100        .await;
101
102        let working_timestamp = chrono::Utc::now();
103        if let Err(e) = self
104            .task_repo
105            .update_task_state(&task_id, TaskState::Working, &working_timestamp)
106            .await
107        {
108            tracing::error!(task_id = %task_id, error = %e, "Failed to mark task as working");
109        }
110
111        let stream_processor = StreamProcessor {
112            ai_service: Arc::clone(&self.ai_service),
113            context_service: self.context_service.clone(),
114            skill_service: Arc::clone(&self.skill_service),
115            execution_step_repo: Arc::clone(&self.execution_step_repo),
116        };
117
118        let mut chunk_rx = stream_processor
119            .process_message_stream(super::ProcessMessageStreamParams {
120                a2a_message: &message,
121                agent_runtime: &agent_runtime,
122                agent_name,
123                context,
124                task_id: task_id.clone(),
125            })
126            .await?;
127
128        let mut response_text = String::new();
129        let mut tool_artifacts = Vec::new();
130
131        while let Some(event) = chunk_rx.recv().await {
132            match event {
133                StreamEvent::Text(text) => {
134                    response_text.push_str(&text);
135                },
136                StreamEvent::Complete {
137                    full_text,
138                    artifacts,
139                } => {
140                    response_text = full_text;
141                    tool_artifacts = artifacts;
142                },
143                StreamEvent::Error(error) => {
144                    let error_event = AgUiEventBuilder::run_error(
145                        error.clone(),
146                        Some("EXECUTION_ERROR".to_string()),
147                    );
148                    if let Err(e) = broadcast_agui_event(
149                        context.user_id(),
150                        error_event,
151                        context.auth_token().as_str(),
152                    )
153                    .await
154                    {
155                        tracing::debug!(error = %e, "Failed to broadcast error event");
156                    }
157                    return Err(anyhow!(error));
158                },
159                _ => {},
160            }
161        }
162
163        let task = build_completed_task(
164            task_id,
165            message.context_id.clone(),
166            response_text.clone(),
167            message.clone(),
168            tool_artifacts,
169        );
170
171        let agent_message = task.status.message.clone().unwrap_or_else(|| {
172            let client_message_id = message
173                .metadata
174                .as_ref()
175                .and_then(|m| m.get("clientMessageId"))
176                .cloned();
177
178            let metadata = client_message_id.map(|id| serde_json::json!({"clientMessageId": id}));
179
180            Message {
181                role: MessageRole::Agent,
182                parts: vec![Part::Text(TextPart {
183                    text: response_text.clone(),
184                })],
185                message_id: MessageId::generate(),
186                task_id: Some(task.id.clone()),
187                context_id: task.context_id.clone(),
188                metadata,
189                extensions: None,
190                reference_task_ids: None,
191            }
192        });
193
194        if context.user_type() == systemprompt_models::auth::UserType::Anon {
195            tracing::warn!(
196                context_id = %message.context_id,
197                session_id = %context.session_id(),
198                "Saving messages for anonymous user"
199            );
200        }
201
202        if let Err(e) = persist_completed_task(super::persistence::PersistCompletedTaskParams {
203            task: &task,
204            user_message: &message,
205            agent_message: &agent_message,
206            context,
207            task_repo: &self.task_repo,
208            db_pool: &self.db_pool,
209            artifacts_already_published: false,
210        })
211        .await
212        {
213            let error_msg = format!("Failed to persist completed task: {}", e);
214            tracing::error!(task_id = %task.id, error = %e, "Failed to persist completed task");
215
216            let failed_timestamp = chrono::Utc::now();
217            if let Err(update_err) = self
218                .task_repo
219                .update_task_failed_with_error(&task.id, &error_msg, &failed_timestamp)
220                .await
221            {
222                tracing::error!(task_id = %task.id, error = %update_err, "Failed to update task to failed state");
223            }
224
225            return Err(e);
226        }
227
228        broadcast_completion(&task, context).await;
229
230        let user_id = context.user_id();
231        let auth_token = context.auth_token().as_str();
232        let context_id = task.context_id.clone();
233        let task_id = task.id.clone();
234        let message_id = agent_message.message_id.clone();
235
236        let start_event = AgUiEventBuilder::run_started(context_id.clone(), task_id.clone(), None);
237        if let Err(e) = broadcast_agui_event(user_id, start_event, auth_token).await {
238            tracing::debug!(error = %e, "Failed to broadcast run_started event");
239        }
240
241        let msg_start = AgUiEventBuilder::text_message_start(
242            message_id.to_string(),
243            AgUiMessageRole::Assistant,
244        );
245        if let Err(e) = broadcast_agui_event(user_id, msg_start, auth_token).await {
246            tracing::debug!(error = %e, "Failed to broadcast text_message_start event");
247        }
248
249        let msg_content =
250            AgUiEventBuilder::text_message_content(message_id.to_string(), &response_text);
251        if let Err(e) = broadcast_agui_event(user_id, msg_content, auth_token).await {
252            tracing::debug!(error = %e, "Failed to broadcast text_message_content event");
253        }
254
255        let msg_end = AgUiEventBuilder::text_message_end(message_id.to_string());
256        if let Err(e) = broadcast_agui_event(user_id, msg_end, auth_token).await {
257            tracing::debug!(error = %e, "Failed to broadcast text_message_end event");
258        }
259
260        let result = serde_json::json!({
261            "text": response_text,
262            "artifacts": task.artifacts,
263        });
264        let finish_event = AgUiEventBuilder::run_finished(context_id, task_id, Some(result));
265        if let Err(e) = broadcast_agui_event(user_id, finish_event, auth_token).await {
266            tracing::debug!(error = %e, "Failed to broadcast run_finished event");
267        }
268
269        Ok(task)
270    }
271}