Skip to main content

systemprompt_agent/repository/execution/
mod.rs

1use anyhow::{Context, Result};
2use chrono::{DateTime, Utc};
3use sqlx::PgPool;
4use std::sync::Arc;
5use systemprompt_database::DbPool;
6use systemprompt_identifiers::TaskId;
7use systemprompt_models::{ExecutionStep, PlannedTool, StepContent, StepId, StepStatus};
8
9fn parse_step(
10    step_id: String,
11    task_id: String,
12    status: String,
13    content: serde_json::Value,
14    started_at: DateTime<Utc>,
15    completed_at: Option<DateTime<Utc>>,
16    duration_ms: Option<i32>,
17    error_message: Option<String>,
18) -> Result<ExecutionStep> {
19    let status = status
20        .parse::<StepStatus>()
21        .map_err(|e| anyhow::anyhow!("Invalid status: {}", e))?;
22    let content: StepContent =
23        serde_json::from_value(content).map_err(|e| anyhow::anyhow!("Invalid content: {}", e))?;
24    Ok(ExecutionStep {
25        step_id: step_id.into(),
26        task_id: task_id.into(),
27        status,
28        started_at,
29        completed_at,
30        duration_ms,
31        error_message,
32        content,
33    })
34}
35
36#[derive(Debug, Clone)]
37pub struct ExecutionStepRepository {
38    pool: Arc<PgPool>,
39}
40
41impl ExecutionStepRepository {
42    pub fn new(db: &DbPool) -> Result<Self> {
43        let pool = db.pool_arc()?;
44        Ok(Self { pool })
45    }
46
47    pub async fn create(&self, step: &ExecutionStep) -> Result<()> {
48        let step_id_str = step.step_id.as_str();
49        let task_id = &step.task_id;
50        let status_str = step.status.to_string();
51        let step_type_str = step.content.step_type().to_string();
52        let title = step.content.title();
53        let content_json =
54            serde_json::to_value(&step.content).context("Failed to serialize step content")?;
55        sqlx::query!(
56            r#"INSERT INTO task_execution_steps (
57                step_id, task_id, step_type, title, status, content, started_at, completed_at, duration_ms, error_message
58            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"#,
59            step_id_str,
60            task_id.as_str(),
61            step_type_str,
62            title,
63            status_str,
64            content_json,
65            step.started_at,
66            step.completed_at,
67            step.duration_ms,
68            step.error_message
69        )
70        .execute(&*self.pool)
71        .await
72        .context("Failed to create execution step")?;
73        Ok(())
74    }
75
76    pub async fn get(&self, step_id: &StepId) -> Result<Option<ExecutionStep>> {
77        let step_id_str = step_id.as_str();
78        let row = sqlx::query!(
79            r#"SELECT step_id, task_id, status, content,
80                    started_at as "started_at!", completed_at, duration_ms, error_message
81                FROM task_execution_steps WHERE step_id = $1"#,
82            step_id_str
83        )
84        .fetch_optional(&*self.pool)
85        .await
86        .context(format!("Failed to get execution step: {step_id}"))?;
87        row.map(|r| {
88            parse_step(
89                r.step_id,
90                r.task_id,
91                r.status,
92                r.content,
93                r.started_at,
94                r.completed_at,
95                r.duration_ms,
96                r.error_message,
97            )
98        })
99        .transpose()
100    }
101
102    pub async fn list_by_task(&self, task_id: &TaskId) -> Result<Vec<ExecutionStep>> {
103        let rows = sqlx::query!(
104            r#"SELECT step_id, task_id, status, content,
105                    started_at as "started_at!", completed_at, duration_ms, error_message
106                FROM task_execution_steps WHERE task_id = $1 ORDER BY started_at ASC"#,
107            task_id.as_str()
108        )
109        .fetch_all(&*self.pool)
110        .await
111        .context(format!(
112            "Failed to list execution steps for task: {}",
113            task_id
114        ))?;
115        rows.into_iter()
116            .map(|r| {
117                parse_step(
118                    r.step_id,
119                    r.task_id,
120                    r.status,
121                    r.content,
122                    r.started_at,
123                    r.completed_at,
124                    r.duration_ms,
125                    r.error_message,
126                )
127            })
128            .collect()
129    }
130
131    pub async fn complete_step(
132        &self,
133        step_id: &StepId,
134        started_at: DateTime<Utc>,
135        tool_result: Option<serde_json::Value>,
136    ) -> Result<()> {
137        let completed_at = Utc::now();
138        let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
139        let step_id_str = step_id.as_str();
140        let status_str = StepStatus::Completed.to_string();
141
142        if let Some(result) = tool_result {
143            sqlx::query!(
144                r#"UPDATE task_execution_steps SET
145                    status = $2,
146                    completed_at = $3,
147                    duration_ms = $4,
148                    content = jsonb_set(content, '{tool_result}', $5)
149                WHERE step_id = $1"#,
150                step_id_str,
151                status_str,
152                completed_at,
153                duration_ms,
154                result
155            )
156            .execute(&*self.pool)
157            .await
158            .context(format!("Failed to complete execution step: {step_id}"))?;
159        } else {
160            sqlx::query!(
161                r#"UPDATE task_execution_steps SET
162                    status = $2,
163                    completed_at = $3,
164                    duration_ms = $4
165                WHERE step_id = $1"#,
166                step_id_str,
167                status_str,
168                completed_at,
169                duration_ms
170            )
171            .execute(&*self.pool)
172            .await
173            .context(format!("Failed to complete execution step: {step_id}"))?;
174        }
175
176        Ok(())
177    }
178
179    pub async fn fail_step(
180        &self,
181        step_id: &StepId,
182        started_at: DateTime<Utc>,
183        error_message: &str,
184    ) -> Result<()> {
185        let completed_at = Utc::now();
186        let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
187        let step_id_str = step_id.as_str();
188        let status_str = StepStatus::Failed.to_string();
189
190        sqlx::query!(
191            r#"UPDATE task_execution_steps SET
192                status = $2,
193                completed_at = $3,
194                duration_ms = $4,
195                error_message = $5
196            WHERE step_id = $1"#,
197            step_id_str,
198            status_str,
199            completed_at,
200            duration_ms,
201            error_message
202        )
203        .execute(&*self.pool)
204        .await
205        .context(format!("Failed to fail execution step: {step_id}"))?;
206
207        Ok(())
208    }
209
210    pub async fn fail_in_progress_steps_for_task(
211        &self,
212        task_id: &TaskId,
213        error_message: &str,
214    ) -> Result<u64> {
215        let completed_at = Utc::now();
216        let in_progress_str = StepStatus::InProgress.to_string();
217        let failed_str = StepStatus::Failed.to_string();
218        let task_id_str = task_id.as_str();
219
220        let result = sqlx::query!(
221            r#"UPDATE task_execution_steps SET
222                status = $3,
223                completed_at = $4,
224                error_message = $5
225            WHERE task_id = $1 AND status = $2"#,
226            task_id_str,
227            in_progress_str,
228            failed_str,
229            completed_at,
230            error_message
231        )
232        .execute(&*self.pool)
233        .await
234        .context(format!(
235            "Failed to fail in-progress steps for task: {}",
236            task_id
237        ))?;
238
239        Ok(result.rows_affected())
240    }
241
242    pub async fn complete_planning_step(
243        &self,
244        step_id: &StepId,
245        started_at: DateTime<Utc>,
246        reasoning: Option<String>,
247        planned_tools: Option<Vec<PlannedTool>>,
248    ) -> Result<ExecutionStep> {
249        let completed_at = Utc::now();
250        let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
251        let step_id_str = step_id.as_str();
252        let status_str = StepStatus::Completed.to_string();
253
254        let content = StepContent::planning(reasoning, planned_tools);
255        let content_json =
256            serde_json::to_value(&content).context("Failed to serialize planning content")?;
257
258        let row = sqlx::query!(
259            r#"UPDATE task_execution_steps SET
260                status = $2,
261                completed_at = $3,
262                duration_ms = $4,
263                content = $5
264            WHERE step_id = $1
265            RETURNING step_id, task_id, status, content,
266                    started_at as "started_at!", completed_at, duration_ms, error_message"#,
267            step_id_str,
268            status_str,
269            completed_at,
270            duration_ms,
271            content_json
272        )
273        .fetch_one(&*self.pool)
274        .await
275        .context(format!("Failed to complete planning step: {step_id}"))?;
276
277        parse_step(
278            row.step_id,
279            row.task_id,
280            row.status,
281            row.content,
282            row.started_at,
283            row.completed_at,
284            row.duration_ms,
285            row.error_message,
286        )
287    }
288
289    pub async fn mcp_execution_id_exists(&self, mcp_execution_id: &str) -> Result<bool> {
290        let exists = sqlx::query_scalar!(
291            r#"SELECT EXISTS(SELECT 1 FROM mcp_tool_executions WHERE mcp_execution_id = $1) as "exists!""#,
292            mcp_execution_id
293        )
294        .fetch_one(&*self.pool)
295        .await
296        .context("Failed to check mcp_execution_id existence")?;
297
298        Ok(exists)
299    }
300}