systemprompt_agent/services/mcp/
task_helper.rs1use crate::models::a2a::{
2 Artifact, Message, MessageRole, Part, Task, TaskState, TaskStatus, TextPart,
3};
4use crate::repository::context::ContextRepository;
5use crate::repository::task::TaskRepository;
6use crate::services::MessageService;
7use rmcp::ErrorData as McpError;
8use systemprompt_database::DbPool;
9use systemprompt_identifiers::{ContextId, MessageId, SessionId, TaskId, TraceId, UserId};
10use systemprompt_models::{Config, TaskMetadata};
11
12#[derive(Debug)]
13pub struct TaskResult {
14 pub task_id: TaskId,
15 pub is_owner: bool,
16}
17
18pub async fn ensure_task_exists(
19 db_pool: &DbPool,
20 request_context: &mut systemprompt_models::execution::context::RequestContext,
21 tool_name: &str,
22 mcp_server_name: &str,
23) -> Result<TaskResult, McpError> {
24 if let Some(task_id) = request_context.task_id() {
25 tracing::info!(task_id = %task_id.as_str(), "Task reused from parent");
26 return Ok(TaskResult {
27 task_id: task_id.clone(),
28 is_owner: false,
29 });
30 }
31
32 let context_id = request_context.context_id();
33 let context_repo = ContextRepository::new(db_pool).map_err(|e| {
34 McpError::internal_error(format!("Failed to create context repository: {e}"), None)
35 })?;
36
37 let context_id = if context_id.is_empty() {
38 if let Ok(Some(existing)) = context_repo
39 .find_by_session_id(request_context.session_id())
40 .await
41 {
42 tracing::debug!(
43 context_id = %existing.context_id,
44 session_id = %request_context.session_id(),
45 "Reusing existing context for MCP session"
46 );
47 request_context.execution.context_id = existing.context_id.clone();
48 existing.context_id
49 } else {
50 let new_context_id = context_repo
51 .create_context(
52 request_context.user_id(),
53 Some(request_context.session_id()),
54 &format!("MCP Session: {}", request_context.session_id()),
55 )
56 .await
57 .map_err(|e| {
58 tracing::error!(error = %e, "Failed to auto-create context for MCP session");
59 McpError::internal_error(format!("Failed to create context: {e}"), None)
60 })?;
61
62 request_context.execution.context_id = new_context_id.clone();
63 tracing::info!(
64 context_id = %new_context_id,
65 session_id = %request_context.session_id(),
66 "Auto-created context for MCP session"
67 );
68 new_context_id
69 }
70 } else {
71 let old_context_id = context_id.clone();
72 match context_repo
73 .validate_context_ownership(&old_context_id, request_context.user_id())
74 .await
75 {
76 Ok(()) => old_context_id,
77 Err(e) => {
78 tracing::warn!(
79 context_id = %old_context_id,
80 user_id = %request_context.user_id(),
81 error = %e,
82 "Context validation failed, auto-creating new context"
83 );
84 let new_context_id = context_repo
85 .create_context(
86 request_context.user_id(),
87 Some(request_context.session_id()),
88 &format!("MCP Session: {}", request_context.session_id()),
89 )
90 .await
91 .map_err(|e| {
92 tracing::error!(error = %e, "Failed to auto-create replacement context");
93 McpError::internal_error(format!("Failed to create context: {e}"), None)
94 })?;
95
96 request_context.execution.context_id = new_context_id.clone();
97 tracing::info!(
98 old_context_id = %old_context_id,
99 new_context_id = %new_context_id,
100 session_id = %request_context.session_id(),
101 "Auto-created replacement context for invalid context_id"
102 );
103 new_context_id
104 },
105 }
106 };
107
108 let task_repo = TaskRepository::new(db_pool).map_err(|e| {
109 McpError::internal_error(format!("Failed to create task repository: {e}"), None)
110 })?;
111
112 let task_id = TaskId::generate();
113
114 let agent_name = request_context.agent_name().to_string();
115
116 let metadata = TaskMetadata::new_mcp_execution(
117 agent_name.clone(),
118 tool_name.to_string(),
119 mcp_server_name.to_string(),
120 );
121
122 let task = Task {
123 id: task_id.clone(),
124 context_id: context_id.clone(),
125 status: TaskStatus {
126 state: TaskState::Submitted,
127 message: None,
128 timestamp: Some(chrono::Utc::now()),
129 },
130 history: None,
131 artifacts: None,
132 metadata: Some(metadata),
133 created_at: Some(chrono::Utc::now()),
134 last_modified: Some(chrono::Utc::now()),
135 };
136
137 task_repo
138 .create_task(crate::repository::task::RepoCreateTaskParams {
139 task: &task,
140 user_id: request_context.user_id(),
141 session_id: request_context.session_id(),
142 trace_id: request_context.trace_id(),
143 agent_name: &agent_name,
144 })
145 .await
146 .map_err(|e| McpError::internal_error(format!("Failed to create task: {e}"), None))?;
147
148 request_context.execution.task_id = Some(task_id.clone());
149
150 tracing::info!(
151 task_id = %task_id.as_str(),
152 tool = %tool_name,
153 agent = %agent_name,
154 "Task created"
155 );
156
157 Ok(TaskResult {
158 task_id,
159 is_owner: true,
160 })
161}
162
163pub async fn complete_task(
164 db_pool: &DbPool,
165 task_id: &TaskId,
166 jwt_token: &str,
167) -> Result<(), McpError> {
168 if let Err(e) = trigger_task_completion_broadcast(db_pool, task_id, jwt_token).await {
169 tracing::error!(
170 task_id = %task_id.as_str(),
171 error = ?e,
172 "Webhook broadcast failed"
173 );
174 }
175
176 Ok(())
177}
178
179async fn trigger_task_completion_broadcast(
180 db_pool: &DbPool,
181 task_id: &TaskId,
182 jwt_token: &str,
183) -> Result<(), McpError> {
184 let task_repo = TaskRepository::new(db_pool).map_err(|e| {
185 McpError::internal_error(format!("Failed to create task repository: {e}"), None)
186 })?;
187
188 let task_info = task_repo
189 .get_task_context_info(task_id)
190 .await
191 .map_err(|e| {
192 McpError::internal_error(format!("Failed to load task for webhook: {e}"), None)
193 })?;
194
195 if let Some(info) = task_info {
196 let context_id = info.context_id;
197 let user_id = info.user_id;
198
199 let config = Config::get().map_err(|e| McpError::internal_error(e.to_string(), None))?;
200 let webhook_url = format!("{}/api/v1/webhook/broadcast", config.api_server_url);
201 let webhook_payload = serde_json::json!({
202 "event_type": "task_completed",
203 "entity_id": task_id.as_str(),
204 "context_id": context_id,
205 "user_id": user_id,
206 });
207
208 tracing::debug!(
209 task_id = %task_id.as_str(),
210 context_id = %context_id,
211 "Webhook triggering"
212 );
213
214 let client = reqwest::Client::new();
215 match client
216 .post(webhook_url)
217 .header("Authorization", format!("Bearer {jwt_token}"))
218 .json(&webhook_payload)
219 .timeout(std::time::Duration::from_secs(5))
220 .send()
221 .await
222 {
223 Ok(response) => {
224 if response.status().is_success() {
225 tracing::debug!(
226 task_id = %task_id.as_str(),
227 "Task completed, webhook success"
228 );
229 } else {
230 let status = response.status();
231 tracing::error!(
232 task_id = %task_id.as_str(),
233 status = %status,
234 "Task completed, webhook failed"
235 );
236 }
237 },
238 Err(e) => {
239 tracing::error!(
240 task_id = %task_id.as_str(),
241 error = %e,
242 "Webhook failed"
243 );
244 },
245 }
246 }
247
248 Ok(())
249}
250
251#[derive(Debug)]
252pub struct SaveMessagesForToolExecutionParams<'a> {
253 pub db_pool: &'a DbPool,
254 pub task_id: &'a TaskId,
255 pub context_id: &'a ContextId,
256 pub tool_name: &'a str,
257 pub tool_result: &'a str,
258 pub artifact: Option<&'a Artifact>,
259 pub user_id: &'a UserId,
260 pub session_id: &'a SessionId,
261 pub trace_id: &'a TraceId,
262}
263
264pub async fn save_messages_for_tool_execution(
265 params: SaveMessagesForToolExecutionParams<'_>,
266) -> Result<(), McpError> {
267 let SaveMessagesForToolExecutionParams {
268 db_pool,
269 task_id,
270 context_id,
271 tool_name,
272 tool_result,
273 artifact,
274 user_id,
275 session_id,
276 trace_id,
277 } = params;
278 let message_service = MessageService::new(db_pool).map_err(|e| {
279 McpError::internal_error(format!("Failed to create message service: {e}"), None)
280 })?;
281
282 let user_message = Message {
283 role: MessageRole::User,
284 parts: vec![Part::Text(TextPart {
285 text: format!("Execute tool: {tool_name}"),
286 })],
287 message_id: MessageId::generate(),
288 task_id: Some(task_id.clone()),
289 context_id: context_id.clone(),
290 metadata: None,
291 extensions: None,
292 reference_task_ids: None,
293 };
294
295 let agent_text = artifact.map_or_else(
296 || format!("Tool execution completed. Result: {tool_result}"),
297 |artifact| {
298 format!(
299 "Tool execution completed. Result: {}\n\nArtifact created: {} (type: {})",
300 tool_result, artifact.id, artifact.metadata.artifact_type
301 )
302 },
303 );
304
305 let agent_message = Message {
306 role: MessageRole::Agent,
307 parts: vec![Part::Text(TextPart { text: agent_text })],
308 message_id: MessageId::generate(),
309 task_id: Some(task_id.clone()),
310 context_id: context_id.clone(),
311 metadata: None,
312 extensions: None,
313 reference_task_ids: None,
314 };
315
316 message_service
317 .persist_messages(crate::services::PersistMessagesParams {
318 task_id,
319 context_id,
320 messages: vec![user_message, agent_message],
321 user_id: Some(user_id),
322 session_id,
323 trace_id,
324 })
325 .await
326 .map_err(|e| McpError::internal_error(format!("Failed to save messages: {e}"), None))?;
327
328 Ok(())
329}