Skip to main content

systemprompt_agent/services/mcp/
task_helper.rs

1use crate::models::a2a::{Artifact, Message, Part, Task, TaskState, TaskStatus, TextPart};
2use crate::repository::context::ContextRepository;
3use crate::repository::task::TaskRepository;
4use crate::services::MessageService;
5use rmcp::ErrorData as McpError;
6use systemprompt_database::DbPool;
7use systemprompt_identifiers::{ContextId, MessageId, SessionId, TaskId, TraceId, UserId};
8use systemprompt_models::{Config, TaskMetadata};
9
10#[derive(Debug)]
11pub struct TaskResult {
12    pub task_id: TaskId,
13    pub is_owner: bool,
14}
15
16pub async fn ensure_task_exists(
17    db_pool: &DbPool,
18    request_context: &mut systemprompt_models::execution::context::RequestContext,
19    tool_name: &str,
20    mcp_server_name: &str,
21) -> Result<TaskResult, McpError> {
22    if let Some(task_id) = request_context.task_id() {
23        tracing::info!(task_id = %task_id.as_str(), "Task reused from parent");
24        return Ok(TaskResult {
25            task_id: task_id.clone(),
26            is_owner: false,
27        });
28    }
29
30    let context_id = request_context.context_id();
31    let context_repo = ContextRepository::new(db_pool.clone());
32
33    let context_id = if context_id.is_empty() {
34        match context_repo
35            .find_by_session_id(request_context.session_id())
36            .await
37        {
38            Ok(Some(existing)) => {
39                tracing::debug!(
40                    context_id = %existing.context_id,
41                    session_id = %request_context.session_id(),
42                    "Reusing existing context for MCP session"
43                );
44                request_context.execution.context_id = existing.context_id.clone();
45                existing.context_id
46            },
47            _ => {
48                let new_context_id = context_repo
49                    .create_context(
50                        request_context.user_id(),
51                        Some(request_context.session_id()),
52                        &format!("MCP Session: {}", request_context.session_id()),
53                    )
54                    .await
55                    .map_err(|e| {
56                        tracing::error!(error = %e, "Failed to auto-create context for MCP session");
57                        McpError::internal_error(format!("Failed to create context: {e}"), None)
58                    })?;
59
60                request_context.execution.context_id = new_context_id.clone();
61                tracing::info!(
62                    context_id = %new_context_id,
63                    session_id = %request_context.session_id(),
64                    "Auto-created context for MCP session"
65                );
66                new_context_id
67            },
68        }
69    } else {
70        let old_context_id = context_id.clone();
71        match context_repo
72            .validate_context_ownership(&old_context_id, request_context.user_id())
73            .await
74        {
75            Ok(()) => old_context_id,
76            Err(e) => {
77                tracing::warn!(
78                    context_id = %old_context_id,
79                    user_id = %request_context.user_id(),
80                    error = %e,
81                    "Context validation failed, auto-creating new context"
82                );
83                let new_context_id = context_repo
84                    .create_context(
85                        request_context.user_id(),
86                        Some(request_context.session_id()),
87                        &format!("MCP Session: {}", request_context.session_id()),
88                    )
89                    .await
90                    .map_err(|e| {
91                        tracing::error!(error = %e, "Failed to auto-create replacement context");
92                        McpError::internal_error(format!("Failed to create context: {e}"), None)
93                    })?;
94
95                request_context.execution.context_id = new_context_id.clone();
96                tracing::info!(
97                    old_context_id = %old_context_id,
98                    new_context_id = %new_context_id,
99                    session_id = %request_context.session_id(),
100                    "Auto-created replacement context for invalid context_id"
101                );
102                new_context_id
103            },
104        }
105    };
106
107    let task_repo = TaskRepository::new(db_pool.clone());
108
109    let task_id = TaskId::generate();
110
111    let agent_name = request_context.agent_name().to_string();
112
113    let metadata = TaskMetadata::new_mcp_execution(
114        agent_name.clone(),
115        tool_name.to_string(),
116        mcp_server_name.to_string(),
117    );
118
119    let task = Task {
120        id: task_id.clone(),
121        context_id: context_id.clone(),
122        status: TaskStatus {
123            state: TaskState::Submitted,
124            message: None,
125            timestamp: Some(chrono::Utc::now()),
126        },
127        history: None,
128        artifacts: None,
129        metadata: Some(metadata),
130        kind: "task".to_string(),
131    };
132
133    task_repo
134        .create_task(
135            &task,
136            request_context.user_id(),
137            request_context.session_id(),
138            request_context.trace_id(),
139            &agent_name,
140        )
141        .await
142        .map_err(|e| McpError::internal_error(format!("Failed to create task: {e}"), None))?;
143
144    request_context.execution.task_id = Some(task_id.clone());
145
146    tracing::info!(
147        task_id = %task_id.as_str(),
148        tool = %tool_name,
149        agent = %agent_name,
150        "Task created"
151    );
152
153    Ok(TaskResult {
154        task_id,
155        is_owner: true,
156    })
157}
158
159pub async fn complete_task(
160    db_pool: &DbPool,
161    task_id: &TaskId,
162    jwt_token: &str,
163) -> Result<(), McpError> {
164    if let Err(e) = trigger_task_completion_broadcast(db_pool, task_id, jwt_token).await {
165        tracing::error!(
166            task_id = %task_id.as_str(),
167            error = ?e,
168            "Webhook broadcast failed"
169        );
170    }
171
172    Ok(())
173}
174
175async fn trigger_task_completion_broadcast(
176    db_pool: &DbPool,
177    task_id: &TaskId,
178    jwt_token: &str,
179) -> Result<(), McpError> {
180    let task_repo = TaskRepository::new(db_pool.clone());
181
182    let task_info = task_repo
183        .get_task_context_info(task_id)
184        .await
185        .map_err(|e| {
186            McpError::internal_error(format!("Failed to load task for webhook: {e}"), None)
187        })?;
188
189    if let Some(info) = task_info {
190        let context_id = info.context_id;
191        let user_id = info.user_id;
192
193        let config = Config::get().map_err(|e| McpError::internal_error(e.to_string(), None))?;
194        let webhook_url = format!("{}/api/v1/webhook/broadcast", config.api_server_url);
195        let webhook_payload = serde_json::json!({
196            "event_type": "task_completed",
197            "entity_id": task_id.as_str(),
198            "context_id": context_id,
199            "user_id": user_id,
200        });
201
202        tracing::debug!(
203            task_id = %task_id.as_str(),
204            context_id = %context_id,
205            "Webhook triggering"
206        );
207
208        let client = reqwest::Client::new();
209        match client
210            .post(webhook_url)
211            .header("Authorization", format!("Bearer {jwt_token}"))
212            .json(&webhook_payload)
213            .timeout(std::time::Duration::from_secs(5))
214            .send()
215            .await
216        {
217            Ok(response) => {
218                if response.status().is_success() {
219                    tracing::debug!(
220                        task_id = %task_id.as_str(),
221                        "Task completed, webhook success"
222                    );
223                } else {
224                    let status = response.status();
225                    tracing::error!(
226                        task_id = %task_id.as_str(),
227                        status = %status,
228                        "Task completed, webhook failed"
229                    );
230                }
231            },
232            Err(e) => {
233                tracing::error!(
234                    task_id = %task_id.as_str(),
235                    error = %e,
236                    "Webhook failed"
237                );
238            },
239        }
240    }
241
242    Ok(())
243}
244
245pub async fn save_messages_for_tool_execution(
246    db_pool: &DbPool,
247    task_id: &TaskId,
248    context_id: &ContextId,
249    tool_name: &str,
250    tool_result: &str,
251    artifact: Option<&Artifact>,
252    user_id: &UserId,
253    session_id: &SessionId,
254    trace_id: &TraceId,
255) -> Result<(), McpError> {
256    let message_service = MessageService::new(db_pool.clone());
257
258    let user_message = Message {
259        role: "user".to_string(),
260        parts: vec![Part::Text(TextPart {
261            text: format!("Execute tool: {tool_name}"),
262        })],
263        id: MessageId::generate(),
264        task_id: Some(task_id.clone()),
265        context_id: context_id.clone(),
266        kind: "message".to_string(),
267        metadata: None,
268        extensions: None,
269        reference_task_ids: None,
270    };
271
272    let agent_text = artifact.map_or_else(
273        || format!("Tool execution completed. Result: {tool_result}"),
274        |artifact| {
275            format!(
276                "Tool execution completed. Result: {}\n\nArtifact created: {} (type: {})",
277                tool_result, artifact.id, artifact.metadata.artifact_type
278            )
279        },
280    );
281
282    let agent_message = Message {
283        role: "agent".to_string(),
284        parts: vec![Part::Text(TextPart { text: agent_text })],
285        id: MessageId::generate(),
286        task_id: Some(task_id.clone()),
287        context_id: context_id.clone(),
288        kind: "message".to_string(),
289        metadata: None,
290        extensions: None,
291        reference_task_ids: None,
292    };
293
294    message_service
295        .persist_messages(
296            task_id,
297            context_id,
298            vec![user_message, agent_message],
299            Some(user_id),
300            session_id,
301            trace_id,
302        )
303        .await
304        .map_err(|e| McpError::internal_error(format!("Failed to save messages: {e}"), None))?;
305
306    Ok(())
307}