Skip to main content

systemprompt_agent/services/mcp/
task_helper.rs

1use crate::models::a2a::{Artifact, Message, Part, Task, TaskState, TaskStatus, TextPart};
2use crate::repository::context::ContextRepository;
3use crate::repository::task::TaskRepository;
4use crate::services::MessageService;
5use rmcp::ErrorData as McpError;
6use systemprompt_database::DbPool;
7use systemprompt_identifiers::{ContextId, MessageId, SessionId, TaskId, TraceId, UserId};
8use systemprompt_models::{Config, TaskMetadata};
9
10#[derive(Debug)]
11pub struct TaskResult {
12    pub task_id: TaskId,
13    pub is_owner: bool,
14}
15
16pub async fn ensure_task_exists(
17    db_pool: &DbPool,
18    request_context: &mut systemprompt_models::execution::context::RequestContext,
19    tool_name: &str,
20    mcp_server_name: &str,
21) -> Result<TaskResult, McpError> {
22    if let Some(task_id) = request_context.task_id() {
23        tracing::info!(task_id = %task_id.as_str(), "Task reused from parent");
24        return Ok(TaskResult {
25            task_id: task_id.clone(),
26            is_owner: false,
27        });
28    }
29
30    let context_id = request_context.context_id();
31    let context_repo = ContextRepository::new(db_pool).map_err(|e| {
32        McpError::internal_error(format!("Failed to create context repository: {e}"), None)
33    })?;
34
35    let context_id = if context_id.is_empty() {
36        if let Ok(Some(existing)) = context_repo
37            .find_by_session_id(request_context.session_id())
38            .await
39        {
40            tracing::debug!(
41                context_id = %existing.context_id,
42                session_id = %request_context.session_id(),
43                "Reusing existing context for MCP session"
44            );
45            request_context.execution.context_id = existing.context_id.clone();
46            existing.context_id
47        } else {
48            let new_context_id = context_repo
49                .create_context(
50                    request_context.user_id(),
51                    Some(request_context.session_id()),
52                    &format!("MCP Session: {}", request_context.session_id()),
53                )
54                .await
55                .map_err(|e| {
56                    tracing::error!(error = %e, "Failed to auto-create context for MCP session");
57                    McpError::internal_error(format!("Failed to create context: {e}"), None)
58                })?;
59
60            request_context.execution.context_id = new_context_id.clone();
61            tracing::info!(
62                context_id = %new_context_id,
63                session_id = %request_context.session_id(),
64                "Auto-created context for MCP session"
65            );
66            new_context_id
67        }
68    } else {
69        let old_context_id = context_id.clone();
70        match context_repo
71            .validate_context_ownership(&old_context_id, request_context.user_id())
72            .await
73        {
74            Ok(()) => old_context_id,
75            Err(e) => {
76                tracing::warn!(
77                    context_id = %old_context_id,
78                    user_id = %request_context.user_id(),
79                    error = %e,
80                    "Context validation failed, auto-creating new context"
81                );
82                let new_context_id = context_repo
83                    .create_context(
84                        request_context.user_id(),
85                        Some(request_context.session_id()),
86                        &format!("MCP Session: {}", request_context.session_id()),
87                    )
88                    .await
89                    .map_err(|e| {
90                        tracing::error!(error = %e, "Failed to auto-create replacement context");
91                        McpError::internal_error(format!("Failed to create context: {e}"), None)
92                    })?;
93
94                request_context.execution.context_id = new_context_id.clone();
95                tracing::info!(
96                    old_context_id = %old_context_id,
97                    new_context_id = %new_context_id,
98                    session_id = %request_context.session_id(),
99                    "Auto-created replacement context for invalid context_id"
100                );
101                new_context_id
102            },
103        }
104    };
105
106    let task_repo = TaskRepository::new(db_pool).map_err(|e| {
107        McpError::internal_error(format!("Failed to create task repository: {e}"), None)
108    })?;
109
110    let task_id = TaskId::generate();
111
112    let agent_name = request_context.agent_name().to_string();
113
114    let metadata = TaskMetadata::new_mcp_execution(
115        agent_name.clone(),
116        tool_name.to_string(),
117        mcp_server_name.to_string(),
118    );
119
120    let task = Task {
121        id: task_id.clone(),
122        context_id: context_id.clone(),
123        status: TaskStatus {
124            state: TaskState::Submitted,
125            message: None,
126            timestamp: Some(chrono::Utc::now()),
127        },
128        history: None,
129        artifacts: None,
130        metadata: Some(metadata),
131        kind: "task".to_string(),
132    };
133
134    task_repo
135        .create_task(crate::repository::task::RepoCreateTaskParams {
136            task: &task,
137            user_id: request_context.user_id(),
138            session_id: request_context.session_id(),
139            trace_id: request_context.trace_id(),
140            agent_name: &agent_name,
141        })
142        .await
143        .map_err(|e| McpError::internal_error(format!("Failed to create task: {e}"), None))?;
144
145    request_context.execution.task_id = Some(task_id.clone());
146
147    tracing::info!(
148        task_id = %task_id.as_str(),
149        tool = %tool_name,
150        agent = %agent_name,
151        "Task created"
152    );
153
154    Ok(TaskResult {
155        task_id,
156        is_owner: true,
157    })
158}
159
160pub async fn complete_task(
161    db_pool: &DbPool,
162    task_id: &TaskId,
163    jwt_token: &str,
164) -> Result<(), McpError> {
165    if let Err(e) = trigger_task_completion_broadcast(db_pool, task_id, jwt_token).await {
166        tracing::error!(
167            task_id = %task_id.as_str(),
168            error = ?e,
169            "Webhook broadcast failed"
170        );
171    }
172
173    Ok(())
174}
175
176async fn trigger_task_completion_broadcast(
177    db_pool: &DbPool,
178    task_id: &TaskId,
179    jwt_token: &str,
180) -> Result<(), McpError> {
181    let task_repo = TaskRepository::new(db_pool).map_err(|e| {
182        McpError::internal_error(format!("Failed to create task repository: {e}"), None)
183    })?;
184
185    let task_info = task_repo
186        .get_task_context_info(task_id)
187        .await
188        .map_err(|e| {
189            McpError::internal_error(format!("Failed to load task for webhook: {e}"), None)
190        })?;
191
192    if let Some(info) = task_info {
193        let context_id = info.context_id;
194        let user_id = info.user_id;
195
196        let config = Config::get().map_err(|e| McpError::internal_error(e.to_string(), None))?;
197        let webhook_url = format!("{}/api/v1/webhook/broadcast", config.api_server_url);
198        let webhook_payload = serde_json::json!({
199            "event_type": "task_completed",
200            "entity_id": task_id.as_str(),
201            "context_id": context_id,
202            "user_id": user_id,
203        });
204
205        tracing::debug!(
206            task_id = %task_id.as_str(),
207            context_id = %context_id,
208            "Webhook triggering"
209        );
210
211        let client = reqwest::Client::new();
212        match client
213            .post(webhook_url)
214            .header("Authorization", format!("Bearer {jwt_token}"))
215            .json(&webhook_payload)
216            .timeout(std::time::Duration::from_secs(5))
217            .send()
218            .await
219        {
220            Ok(response) => {
221                if response.status().is_success() {
222                    tracing::debug!(
223                        task_id = %task_id.as_str(),
224                        "Task completed, webhook success"
225                    );
226                } else {
227                    let status = response.status();
228                    tracing::error!(
229                        task_id = %task_id.as_str(),
230                        status = %status,
231                        "Task completed, webhook failed"
232                    );
233                }
234            },
235            Err(e) => {
236                tracing::error!(
237                    task_id = %task_id.as_str(),
238                    error = %e,
239                    "Webhook failed"
240                );
241            },
242        }
243    }
244
245    Ok(())
246}
247
248#[derive(Debug)]
249pub struct SaveMessagesForToolExecutionParams<'a> {
250    pub db_pool: &'a DbPool,
251    pub task_id: &'a TaskId,
252    pub context_id: &'a ContextId,
253    pub tool_name: &'a str,
254    pub tool_result: &'a str,
255    pub artifact: Option<&'a Artifact>,
256    pub user_id: &'a UserId,
257    pub session_id: &'a SessionId,
258    pub trace_id: &'a TraceId,
259}
260
261pub async fn save_messages_for_tool_execution(
262    params: SaveMessagesForToolExecutionParams<'_>,
263) -> Result<(), McpError> {
264    let SaveMessagesForToolExecutionParams {
265        db_pool,
266        task_id,
267        context_id,
268        tool_name,
269        tool_result,
270        artifact,
271        user_id,
272        session_id,
273        trace_id,
274    } = params;
275    let message_service = MessageService::new(db_pool).map_err(|e| {
276        McpError::internal_error(format!("Failed to create message service: {e}"), None)
277    })?;
278
279    let user_message = Message {
280        role: "user".to_string(),
281        parts: vec![Part::Text(TextPart {
282            text: format!("Execute tool: {tool_name}"),
283        })],
284        id: MessageId::generate(),
285        task_id: Some(task_id.clone()),
286        context_id: context_id.clone(),
287        kind: "message".to_string(),
288        metadata: None,
289        extensions: None,
290        reference_task_ids: None,
291    };
292
293    let agent_text = artifact.map_or_else(
294        || format!("Tool execution completed. Result: {tool_result}"),
295        |artifact| {
296            format!(
297                "Tool execution completed. Result: {}\n\nArtifact created: {} (type: {})",
298                tool_result, artifact.id, artifact.metadata.artifact_type
299            )
300        },
301    );
302
303    let agent_message = Message {
304        role: "agent".to_string(),
305        parts: vec![Part::Text(TextPart { text: agent_text })],
306        id: MessageId::generate(),
307        task_id: Some(task_id.clone()),
308        context_id: context_id.clone(),
309        kind: "message".to_string(),
310        metadata: None,
311        extensions: None,
312        reference_task_ids: None,
313    };
314
315    message_service
316        .persist_messages(crate::services::PersistMessagesParams {
317            task_id,
318            context_id,
319            messages: vec![user_message, agent_message],
320            user_id: Some(user_id),
321            session_id,
322            trace_id,
323        })
324        .await
325        .map_err(|e| McpError::internal_error(format!("Failed to save messages: {e}"), None))?;
326
327    Ok(())
328}