Skip to main content

systemprompt_agent/services/mcp/
task_helper.rs

1use crate::models::a2a::{
2    Artifact, Message, MessageRole, Part, Task, TaskState, TaskStatus, TextPart,
3};
4use crate::repository::context::ContextRepository;
5use crate::repository::task::TaskRepository;
6use crate::services::MessageService;
7use rmcp::ErrorData as McpError;
8use systemprompt_database::DbPool;
9use systemprompt_identifiers::{ContextId, MessageId, SessionId, TaskId, TraceId, UserId};
10use systemprompt_models::{Config, TaskMetadata};
11
12#[derive(Debug)]
13pub struct TaskResult {
14    pub task_id: TaskId,
15    pub is_owner: bool,
16}
17
18pub async fn ensure_task_exists(
19    db_pool: &DbPool,
20    request_context: &mut systemprompt_models::execution::context::RequestContext,
21    tool_name: &str,
22    mcp_server_name: &str,
23) -> Result<TaskResult, McpError> {
24    if let Some(task_id) = request_context.task_id() {
25        tracing::info!(task_id = %task_id.as_str(), "Task reused from parent");
26        return Ok(TaskResult {
27            task_id: task_id.clone(),
28            is_owner: false,
29        });
30    }
31
32    let context_id = request_context.context_id();
33    let context_repo = ContextRepository::new(db_pool).map_err(|e| {
34        McpError::internal_error(format!("Failed to create context repository: {e}"), None)
35    })?;
36
37    let context_id = if context_id.is_empty() {
38        if let Ok(Some(existing)) = context_repo
39            .find_by_session_id(request_context.session_id())
40            .await
41        {
42            tracing::debug!(
43                context_id = %existing.context_id,
44                session_id = %request_context.session_id(),
45                "Reusing existing context for MCP session"
46            );
47            request_context.execution.context_id = existing.context_id.clone();
48            existing.context_id
49        } else {
50            let new_context_id = context_repo
51                .create_context(
52                    request_context.user_id(),
53                    Some(request_context.session_id()),
54                    &format!("MCP Session: {}", request_context.session_id()),
55                )
56                .await
57                .map_err(|e| {
58                    tracing::error!(error = %e, "Failed to auto-create context for MCP session");
59                    McpError::internal_error(format!("Failed to create context: {e}"), None)
60                })?;
61
62            request_context.execution.context_id = new_context_id.clone();
63            tracing::info!(
64                context_id = %new_context_id,
65                session_id = %request_context.session_id(),
66                "Auto-created context for MCP session"
67            );
68            new_context_id
69        }
70    } else {
71        let old_context_id = context_id.clone();
72        match context_repo
73            .validate_context_ownership(&old_context_id, request_context.user_id())
74            .await
75        {
76            Ok(()) => old_context_id,
77            Err(e) => {
78                tracing::warn!(
79                    context_id = %old_context_id,
80                    user_id = %request_context.user_id(),
81                    error = %e,
82                    "Context validation failed, auto-creating new context"
83                );
84                let new_context_id = context_repo
85                    .create_context(
86                        request_context.user_id(),
87                        Some(request_context.session_id()),
88                        &format!("MCP Session: {}", request_context.session_id()),
89                    )
90                    .await
91                    .map_err(|e| {
92                        tracing::error!(error = %e, "Failed to auto-create replacement context");
93                        McpError::internal_error(format!("Failed to create context: {e}"), None)
94                    })?;
95
96                request_context.execution.context_id = new_context_id.clone();
97                tracing::info!(
98                    old_context_id = %old_context_id,
99                    new_context_id = %new_context_id,
100                    session_id = %request_context.session_id(),
101                    "Auto-created replacement context for invalid context_id"
102                );
103                new_context_id
104            },
105        }
106    };
107
108    let task_repo = TaskRepository::new(db_pool).map_err(|e| {
109        McpError::internal_error(format!("Failed to create task repository: {e}"), None)
110    })?;
111
112    let task_id = TaskId::generate();
113
114    let agent_name = request_context.agent_name().to_string();
115
116    let metadata = TaskMetadata::new_mcp_execution(
117        agent_name.clone(),
118        tool_name.to_string(),
119        mcp_server_name.to_string(),
120    );
121
122    let task = Task {
123        id: task_id.clone(),
124        context_id: context_id.clone(),
125        status: TaskStatus {
126            state: TaskState::Submitted,
127            message: None,
128            timestamp: Some(chrono::Utc::now()),
129        },
130        history: None,
131        artifacts: None,
132        metadata: Some(metadata),
133        created_at: Some(chrono::Utc::now()),
134        last_modified: Some(chrono::Utc::now()),
135    };
136
137    task_repo
138        .create_task(crate::repository::task::RepoCreateTaskParams {
139            task: &task,
140            user_id: request_context.user_id(),
141            session_id: request_context.session_id(),
142            trace_id: request_context.trace_id(),
143            agent_name: &agent_name,
144        })
145        .await
146        .map_err(|e| McpError::internal_error(format!("Failed to create task: {e}"), None))?;
147
148    request_context.execution.task_id = Some(task_id.clone());
149
150    tracing::info!(
151        task_id = %task_id.as_str(),
152        tool = %tool_name,
153        agent = %agent_name,
154        "Task created"
155    );
156
157    Ok(TaskResult {
158        task_id,
159        is_owner: true,
160    })
161}
162
163pub async fn complete_task(
164    db_pool: &DbPool,
165    task_id: &TaskId,
166    jwt_token: &str,
167) -> Result<(), McpError> {
168    if let Err(e) = trigger_task_completion_broadcast(db_pool, task_id, jwt_token).await {
169        tracing::error!(
170            task_id = %task_id.as_str(),
171            error = ?e,
172            "Webhook broadcast failed"
173        );
174    }
175
176    Ok(())
177}
178
179async fn trigger_task_completion_broadcast(
180    db_pool: &DbPool,
181    task_id: &TaskId,
182    jwt_token: &str,
183) -> Result<(), McpError> {
184    let task_repo = TaskRepository::new(db_pool).map_err(|e| {
185        McpError::internal_error(format!("Failed to create task repository: {e}"), None)
186    })?;
187
188    let task_info = task_repo
189        .get_task_context_info(task_id)
190        .await
191        .map_err(|e| {
192            McpError::internal_error(format!("Failed to load task for webhook: {e}"), None)
193        })?;
194
195    if let Some(info) = task_info {
196        let context_id = info.context_id;
197        let user_id = info.user_id;
198
199        let config = Config::get().map_err(|e| McpError::internal_error(e.to_string(), None))?;
200        let webhook_url = format!("{}/api/v1/webhook/broadcast", config.api_server_url);
201        let webhook_payload = serde_json::json!({
202            "event_type": "task_completed",
203            "entity_id": task_id.as_str(),
204            "context_id": context_id,
205            "user_id": user_id,
206        });
207
208        tracing::debug!(
209            task_id = %task_id.as_str(),
210            context_id = %context_id,
211            "Webhook triggering"
212        );
213
214        let client = reqwest::Client::new();
215        match client
216            .post(webhook_url)
217            .header("Authorization", format!("Bearer {jwt_token}"))
218            .json(&webhook_payload)
219            .timeout(std::time::Duration::from_secs(5))
220            .send()
221            .await
222        {
223            Ok(response) => {
224                if response.status().is_success() {
225                    tracing::debug!(
226                        task_id = %task_id.as_str(),
227                        "Task completed, webhook success"
228                    );
229                } else {
230                    let status = response.status();
231                    tracing::error!(
232                        task_id = %task_id.as_str(),
233                        status = %status,
234                        "Task completed, webhook failed"
235                    );
236                }
237            },
238            Err(e) => {
239                tracing::error!(
240                    task_id = %task_id.as_str(),
241                    error = %e,
242                    "Webhook failed"
243                );
244            },
245        }
246    }
247
248    Ok(())
249}
250
251#[derive(Debug)]
252pub struct SaveMessagesForToolExecutionParams<'a> {
253    pub db_pool: &'a DbPool,
254    pub task_id: &'a TaskId,
255    pub context_id: &'a ContextId,
256    pub tool_name: &'a str,
257    pub tool_result: &'a str,
258    pub artifact: Option<&'a Artifact>,
259    pub user_id: &'a UserId,
260    pub session_id: &'a SessionId,
261    pub trace_id: &'a TraceId,
262}
263
264pub async fn save_messages_for_tool_execution(
265    params: SaveMessagesForToolExecutionParams<'_>,
266) -> Result<(), McpError> {
267    let SaveMessagesForToolExecutionParams {
268        db_pool,
269        task_id,
270        context_id,
271        tool_name,
272        tool_result,
273        artifact,
274        user_id,
275        session_id,
276        trace_id,
277    } = params;
278    let message_service = MessageService::new(db_pool).map_err(|e| {
279        McpError::internal_error(format!("Failed to create message service: {e}"), None)
280    })?;
281
282    let user_message = Message {
283        role: MessageRole::User,
284        parts: vec![Part::Text(TextPart {
285            text: format!("Execute tool: {tool_name}"),
286        })],
287        message_id: MessageId::generate(),
288        task_id: Some(task_id.clone()),
289        context_id: context_id.clone(),
290        metadata: None,
291        extensions: None,
292        reference_task_ids: None,
293    };
294
295    let agent_text = artifact.map_or_else(
296        || format!("Tool execution completed. Result: {tool_result}"),
297        |artifact| {
298            format!(
299                "Tool execution completed. Result: {}\n\nArtifact created: {} (type: {})",
300                tool_result, artifact.id, artifact.metadata.artifact_type
301            )
302        },
303    );
304
305    let agent_message = Message {
306        role: MessageRole::Agent,
307        parts: vec![Part::Text(TextPart { text: agent_text })],
308        message_id: MessageId::generate(),
309        task_id: Some(task_id.clone()),
310        context_id: context_id.clone(),
311        metadata: None,
312        extensions: None,
313        reference_task_ids: None,
314    };
315
316    message_service
317        .persist_messages(crate::services::PersistMessagesParams {
318            task_id,
319            context_id,
320            messages: vec![user_message, agent_message],
321            user_id: Some(user_id),
322            session_id,
323            trace_id,
324        })
325        .await
326        .map_err(|e| McpError::internal_error(format!("Failed to save messages: {e}"), None))?;
327
328    Ok(())
329}