systemprompt_agent/services/mcp/
task_helper.rs1use crate::models::a2a::{Artifact, Message, Part, Task, TaskState, TaskStatus, TextPart};
2use crate::repository::context::ContextRepository;
3use crate::repository::task::TaskRepository;
4use crate::services::MessageService;
5use rmcp::ErrorData as McpError;
6use systemprompt_database::DbPool;
7use systemprompt_identifiers::{ContextId, MessageId, SessionId, TaskId, TraceId, UserId};
8use systemprompt_models::{Config, TaskMetadata};
9
10#[derive(Debug)]
11pub struct TaskResult {
12 pub task_id: TaskId,
13 pub is_owner: bool,
14}
15
16pub async fn ensure_task_exists(
17 db_pool: &DbPool,
18 request_context: &mut systemprompt_models::execution::context::RequestContext,
19 tool_name: &str,
20 mcp_server_name: &str,
21) -> Result<TaskResult, McpError> {
22 if let Some(task_id) = request_context.task_id() {
23 tracing::info!(task_id = %task_id.as_str(), "Task reused from parent");
24 return Ok(TaskResult {
25 task_id: task_id.clone(),
26 is_owner: false,
27 });
28 }
29
30 let context_id = request_context.context_id();
31 let context_repo = ContextRepository::new(db_pool.clone());
32
33 let context_id = if context_id.is_empty() {
34 match context_repo
35 .find_by_session_id(request_context.session_id())
36 .await
37 {
38 Ok(Some(existing)) => {
39 tracing::debug!(
40 context_id = %existing.context_id,
41 session_id = %request_context.session_id(),
42 "Reusing existing context for MCP session"
43 );
44 request_context.execution.context_id = existing.context_id.clone();
45 existing.context_id
46 },
47 _ => {
48 let new_context_id = context_repo
49 .create_context(
50 request_context.user_id(),
51 Some(request_context.session_id()),
52 &format!("MCP Session: {}", request_context.session_id()),
53 )
54 .await
55 .map_err(|e| {
56 tracing::error!(error = %e, "Failed to auto-create context for MCP session");
57 McpError::internal_error(format!("Failed to create context: {e}"), None)
58 })?;
59
60 request_context.execution.context_id = new_context_id.clone();
61 tracing::info!(
62 context_id = %new_context_id,
63 session_id = %request_context.session_id(),
64 "Auto-created context for MCP session"
65 );
66 new_context_id
67 },
68 }
69 } else {
70 let old_context_id = context_id.clone();
71 match context_repo
72 .validate_context_ownership(&old_context_id, request_context.user_id())
73 .await
74 {
75 Ok(()) => old_context_id,
76 Err(e) => {
77 tracing::warn!(
78 context_id = %old_context_id,
79 user_id = %request_context.user_id(),
80 error = %e,
81 "Context validation failed, auto-creating new context"
82 );
83 let new_context_id = context_repo
84 .create_context(
85 request_context.user_id(),
86 Some(request_context.session_id()),
87 &format!("MCP Session: {}", request_context.session_id()),
88 )
89 .await
90 .map_err(|e| {
91 tracing::error!(error = %e, "Failed to auto-create replacement context");
92 McpError::internal_error(format!("Failed to create context: {e}"), None)
93 })?;
94
95 request_context.execution.context_id = new_context_id.clone();
96 tracing::info!(
97 old_context_id = %old_context_id,
98 new_context_id = %new_context_id,
99 session_id = %request_context.session_id(),
100 "Auto-created replacement context for invalid context_id"
101 );
102 new_context_id
103 },
104 }
105 };
106
107 let task_repo = TaskRepository::new(db_pool.clone());
108
109 let task_id = TaskId::generate();
110
111 let agent_name = request_context.agent_name().to_string();
112
113 let metadata = TaskMetadata::new_mcp_execution(
114 agent_name.clone(),
115 tool_name.to_string(),
116 mcp_server_name.to_string(),
117 );
118
119 let task = Task {
120 id: task_id.clone(),
121 context_id: context_id.clone(),
122 status: TaskStatus {
123 state: TaskState::Submitted,
124 message: None,
125 timestamp: Some(chrono::Utc::now()),
126 },
127 history: None,
128 artifacts: None,
129 metadata: Some(metadata),
130 kind: "task".to_string(),
131 };
132
133 task_repo
134 .create_task(
135 &task,
136 request_context.user_id(),
137 request_context.session_id(),
138 request_context.trace_id(),
139 &agent_name,
140 )
141 .await
142 .map_err(|e| McpError::internal_error(format!("Failed to create task: {e}"), None))?;
143
144 request_context.execution.task_id = Some(task_id.clone());
145
146 tracing::info!(
147 task_id = %task_id.as_str(),
148 tool = %tool_name,
149 agent = %agent_name,
150 "Task created"
151 );
152
153 Ok(TaskResult {
154 task_id,
155 is_owner: true,
156 })
157}
158
159pub async fn complete_task(
160 db_pool: &DbPool,
161 task_id: &TaskId,
162 jwt_token: &str,
163) -> Result<(), McpError> {
164 if let Err(e) = trigger_task_completion_broadcast(db_pool, task_id, jwt_token).await {
165 tracing::error!(
166 task_id = %task_id.as_str(),
167 error = ?e,
168 "Webhook broadcast failed"
169 );
170 }
171
172 Ok(())
173}
174
175async fn trigger_task_completion_broadcast(
176 db_pool: &DbPool,
177 task_id: &TaskId,
178 jwt_token: &str,
179) -> Result<(), McpError> {
180 let task_repo = TaskRepository::new(db_pool.clone());
181
182 let task_info = task_repo
183 .get_task_context_info(task_id)
184 .await
185 .map_err(|e| {
186 McpError::internal_error(format!("Failed to load task for webhook: {e}"), None)
187 })?;
188
189 if let Some(info) = task_info {
190 let context_id = info.context_id;
191 let user_id = info.user_id;
192
193 let config = Config::get().map_err(|e| McpError::internal_error(e.to_string(), None))?;
194 let webhook_url = format!("{}/api/v1/webhook/broadcast", config.api_server_url);
195 let webhook_payload = serde_json::json!({
196 "event_type": "task_completed",
197 "entity_id": task_id.as_str(),
198 "context_id": context_id,
199 "user_id": user_id,
200 });
201
202 tracing::debug!(
203 task_id = %task_id.as_str(),
204 context_id = %context_id,
205 "Webhook triggering"
206 );
207
208 let client = reqwest::Client::new();
209 match client
210 .post(webhook_url)
211 .header("Authorization", format!("Bearer {jwt_token}"))
212 .json(&webhook_payload)
213 .timeout(std::time::Duration::from_secs(5))
214 .send()
215 .await
216 {
217 Ok(response) => {
218 if response.status().is_success() {
219 tracing::debug!(
220 task_id = %task_id.as_str(),
221 "Task completed, webhook success"
222 );
223 } else {
224 let status = response.status();
225 tracing::error!(
226 task_id = %task_id.as_str(),
227 status = %status,
228 "Task completed, webhook failed"
229 );
230 }
231 },
232 Err(e) => {
233 tracing::error!(
234 task_id = %task_id.as_str(),
235 error = %e,
236 "Webhook failed"
237 );
238 },
239 }
240 }
241
242 Ok(())
243}
244
245pub async fn save_messages_for_tool_execution(
246 db_pool: &DbPool,
247 task_id: &TaskId,
248 context_id: &ContextId,
249 tool_name: &str,
250 tool_result: &str,
251 artifact: Option<&Artifact>,
252 user_id: &UserId,
253 session_id: &SessionId,
254 trace_id: &TraceId,
255) -> Result<(), McpError> {
256 let message_service = MessageService::new(db_pool.clone());
257
258 let user_message = Message {
259 role: "user".to_string(),
260 parts: vec![Part::Text(TextPart {
261 text: format!("Execute tool: {tool_name}"),
262 })],
263 id: MessageId::generate(),
264 task_id: Some(task_id.clone()),
265 context_id: context_id.clone(),
266 kind: "message".to_string(),
267 metadata: None,
268 extensions: None,
269 reference_task_ids: None,
270 };
271
272 let agent_text = artifact.map_or_else(
273 || format!("Tool execution completed. Result: {tool_result}"),
274 |artifact| {
275 format!(
276 "Tool execution completed. Result: {}\n\nArtifact created: {} (type: {})",
277 tool_result, artifact.id, artifact.metadata.artifact_type
278 )
279 },
280 );
281
282 let agent_message = Message {
283 role: "agent".to_string(),
284 parts: vec![Part::Text(TextPart { text: agent_text })],
285 id: MessageId::generate(),
286 task_id: Some(task_id.clone()),
287 context_id: context_id.clone(),
288 kind: "message".to_string(),
289 metadata: None,
290 extensions: None,
291 reference_task_ids: None,
292 };
293
294 message_service
295 .persist_messages(
296 task_id,
297 context_id,
298 vec![user_message, agent_message],
299 Some(user_id),
300 session_id,
301 trace_id,
302 )
303 .await
304 .map_err(|e| McpError::internal_error(format!("Failed to save messages: {e}"), None))?;
305
306 Ok(())
307}