systemprompt_agent/services/mcp/
task_helper.rs1use crate::models::a2a::{Artifact, Message, Part, Task, TaskState, TaskStatus, TextPart};
2use crate::repository::context::ContextRepository;
3use crate::repository::task::TaskRepository;
4use crate::services::MessageService;
5use rmcp::ErrorData as McpError;
6use systemprompt_database::DbPool;
7use systemprompt_identifiers::{ContextId, MessageId, SessionId, TaskId, TraceId, UserId};
8use systemprompt_models::{Config, TaskMetadata};
9
10#[derive(Debug)]
11pub struct TaskResult {
12 pub task_id: TaskId,
13 pub is_owner: bool,
14}
15
16pub async fn ensure_task_exists(
17 db_pool: &DbPool,
18 request_context: &mut systemprompt_models::execution::context::RequestContext,
19 tool_name: &str,
20 mcp_server_name: &str,
21) -> Result<TaskResult, McpError> {
22 if let Some(task_id) = request_context.task_id() {
23 tracing::info!(task_id = %task_id.as_str(), "Task reused from parent");
24 return Ok(TaskResult {
25 task_id: task_id.clone(),
26 is_owner: false,
27 });
28 }
29
30 let context_id = request_context.context_id();
31 let context_repo = ContextRepository::new(db_pool).map_err(|e| {
32 McpError::internal_error(format!("Failed to create context repository: {e}"), None)
33 })?;
34
35 let context_id = if context_id.is_empty() {
36 if let Ok(Some(existing)) = context_repo
37 .find_by_session_id(request_context.session_id())
38 .await
39 {
40 tracing::debug!(
41 context_id = %existing.context_id,
42 session_id = %request_context.session_id(),
43 "Reusing existing context for MCP session"
44 );
45 request_context.execution.context_id = existing.context_id.clone();
46 existing.context_id
47 } else {
48 let new_context_id = context_repo
49 .create_context(
50 request_context.user_id(),
51 Some(request_context.session_id()),
52 &format!("MCP Session: {}", request_context.session_id()),
53 )
54 .await
55 .map_err(|e| {
56 tracing::error!(error = %e, "Failed to auto-create context for MCP session");
57 McpError::internal_error(format!("Failed to create context: {e}"), None)
58 })?;
59
60 request_context.execution.context_id = new_context_id.clone();
61 tracing::info!(
62 context_id = %new_context_id,
63 session_id = %request_context.session_id(),
64 "Auto-created context for MCP session"
65 );
66 new_context_id
67 }
68 } else {
69 let old_context_id = context_id.clone();
70 match context_repo
71 .validate_context_ownership(&old_context_id, request_context.user_id())
72 .await
73 {
74 Ok(()) => old_context_id,
75 Err(e) => {
76 tracing::warn!(
77 context_id = %old_context_id,
78 user_id = %request_context.user_id(),
79 error = %e,
80 "Context validation failed, auto-creating new context"
81 );
82 let new_context_id = context_repo
83 .create_context(
84 request_context.user_id(),
85 Some(request_context.session_id()),
86 &format!("MCP Session: {}", request_context.session_id()),
87 )
88 .await
89 .map_err(|e| {
90 tracing::error!(error = %e, "Failed to auto-create replacement context");
91 McpError::internal_error(format!("Failed to create context: {e}"), None)
92 })?;
93
94 request_context.execution.context_id = new_context_id.clone();
95 tracing::info!(
96 old_context_id = %old_context_id,
97 new_context_id = %new_context_id,
98 session_id = %request_context.session_id(),
99 "Auto-created replacement context for invalid context_id"
100 );
101 new_context_id
102 },
103 }
104 };
105
106 let task_repo = TaskRepository::new(db_pool).map_err(|e| {
107 McpError::internal_error(format!("Failed to create task repository: {e}"), None)
108 })?;
109
110 let task_id = TaskId::generate();
111
112 let agent_name = request_context.agent_name().to_string();
113
114 let metadata = TaskMetadata::new_mcp_execution(
115 agent_name.clone(),
116 tool_name.to_string(),
117 mcp_server_name.to_string(),
118 );
119
120 let task = Task {
121 id: task_id.clone(),
122 context_id: context_id.clone(),
123 status: TaskStatus {
124 state: TaskState::Submitted,
125 message: None,
126 timestamp: Some(chrono::Utc::now()),
127 },
128 history: None,
129 artifacts: None,
130 metadata: Some(metadata),
131 kind: "task".to_string(),
132 };
133
134 task_repo
135 .create_task(crate::repository::task::RepoCreateTaskParams {
136 task: &task,
137 user_id: request_context.user_id(),
138 session_id: request_context.session_id(),
139 trace_id: request_context.trace_id(),
140 agent_name: &agent_name,
141 })
142 .await
143 .map_err(|e| McpError::internal_error(format!("Failed to create task: {e}"), None))?;
144
145 request_context.execution.task_id = Some(task_id.clone());
146
147 tracing::info!(
148 task_id = %task_id.as_str(),
149 tool = %tool_name,
150 agent = %agent_name,
151 "Task created"
152 );
153
154 Ok(TaskResult {
155 task_id,
156 is_owner: true,
157 })
158}
159
160pub async fn complete_task(
161 db_pool: &DbPool,
162 task_id: &TaskId,
163 jwt_token: &str,
164) -> Result<(), McpError> {
165 if let Err(e) = trigger_task_completion_broadcast(db_pool, task_id, jwt_token).await {
166 tracing::error!(
167 task_id = %task_id.as_str(),
168 error = ?e,
169 "Webhook broadcast failed"
170 );
171 }
172
173 Ok(())
174}
175
176async fn trigger_task_completion_broadcast(
177 db_pool: &DbPool,
178 task_id: &TaskId,
179 jwt_token: &str,
180) -> Result<(), McpError> {
181 let task_repo = TaskRepository::new(db_pool).map_err(|e| {
182 McpError::internal_error(format!("Failed to create task repository: {e}"), None)
183 })?;
184
185 let task_info = task_repo
186 .get_task_context_info(task_id)
187 .await
188 .map_err(|e| {
189 McpError::internal_error(format!("Failed to load task for webhook: {e}"), None)
190 })?;
191
192 if let Some(info) = task_info {
193 let context_id = info.context_id;
194 let user_id = info.user_id;
195
196 let config = Config::get().map_err(|e| McpError::internal_error(e.to_string(), None))?;
197 let webhook_url = format!("{}/api/v1/webhook/broadcast", config.api_server_url);
198 let webhook_payload = serde_json::json!({
199 "event_type": "task_completed",
200 "entity_id": task_id.as_str(),
201 "context_id": context_id,
202 "user_id": user_id,
203 });
204
205 tracing::debug!(
206 task_id = %task_id.as_str(),
207 context_id = %context_id,
208 "Webhook triggering"
209 );
210
211 let client = reqwest::Client::new();
212 match client
213 .post(webhook_url)
214 .header("Authorization", format!("Bearer {jwt_token}"))
215 .json(&webhook_payload)
216 .timeout(std::time::Duration::from_secs(5))
217 .send()
218 .await
219 {
220 Ok(response) => {
221 if response.status().is_success() {
222 tracing::debug!(
223 task_id = %task_id.as_str(),
224 "Task completed, webhook success"
225 );
226 } else {
227 let status = response.status();
228 tracing::error!(
229 task_id = %task_id.as_str(),
230 status = %status,
231 "Task completed, webhook failed"
232 );
233 }
234 },
235 Err(e) => {
236 tracing::error!(
237 task_id = %task_id.as_str(),
238 error = %e,
239 "Webhook failed"
240 );
241 },
242 }
243 }
244
245 Ok(())
246}
247
248#[derive(Debug)]
249pub struct SaveMessagesForToolExecutionParams<'a> {
250 pub db_pool: &'a DbPool,
251 pub task_id: &'a TaskId,
252 pub context_id: &'a ContextId,
253 pub tool_name: &'a str,
254 pub tool_result: &'a str,
255 pub artifact: Option<&'a Artifact>,
256 pub user_id: &'a UserId,
257 pub session_id: &'a SessionId,
258 pub trace_id: &'a TraceId,
259}
260
261pub async fn save_messages_for_tool_execution(
262 params: SaveMessagesForToolExecutionParams<'_>,
263) -> Result<(), McpError> {
264 let SaveMessagesForToolExecutionParams {
265 db_pool,
266 task_id,
267 context_id,
268 tool_name,
269 tool_result,
270 artifact,
271 user_id,
272 session_id,
273 trace_id,
274 } = params;
275 let message_service = MessageService::new(db_pool).map_err(|e| {
276 McpError::internal_error(format!("Failed to create message service: {e}"), None)
277 })?;
278
279 let user_message = Message {
280 role: "user".to_string(),
281 parts: vec![Part::Text(TextPart {
282 text: format!("Execute tool: {tool_name}"),
283 })],
284 id: MessageId::generate(),
285 task_id: Some(task_id.clone()),
286 context_id: context_id.clone(),
287 kind: "message".to_string(),
288 metadata: None,
289 extensions: None,
290 reference_task_ids: None,
291 };
292
293 let agent_text = artifact.map_or_else(
294 || format!("Tool execution completed. Result: {tool_result}"),
295 |artifact| {
296 format!(
297 "Tool execution completed. Result: {}\n\nArtifact created: {} (type: {})",
298 tool_result, artifact.id, artifact.metadata.artifact_type
299 )
300 },
301 );
302
303 let agent_message = Message {
304 role: "agent".to_string(),
305 parts: vec![Part::Text(TextPart { text: agent_text })],
306 id: MessageId::generate(),
307 task_id: Some(task_id.clone()),
308 context_id: context_id.clone(),
309 kind: "message".to_string(),
310 metadata: None,
311 extensions: None,
312 reference_task_ids: None,
313 };
314
315 message_service
316 .persist_messages(crate::services::PersistMessagesParams {
317 task_id,
318 context_id,
319 messages: vec![user_message, agent_message],
320 user_id: Some(user_id),
321 session_id,
322 trace_id,
323 })
324 .await
325 .map_err(|e| McpError::internal_error(format!("Failed to save messages: {e}"), None))?;
326
327 Ok(())
328}