Skip to main content

systemprompt_agent/repository/execution/mod.rs

1use anyhow::{Context, Result};
2use chrono::{DateTime, Utc};
3use sqlx::PgPool;
4use std::sync::Arc;
5use systemprompt_database::DbPool;
6use systemprompt_identifiers::TaskId;
7use systemprompt_models::{ExecutionStep, PlannedTool, StepContent, StepId, StepStatus};
8
9#[allow(missing_debug_implementations)]
10struct ParseStepParams {
11    step_id: String,
12    task_id: String,
13    status: String,
14    content: serde_json::Value,
15    started_at: DateTime<Utc>,
16    completed_at: Option<DateTime<Utc>>,
17    duration_ms: Option<i32>,
18    error_message: Option<String>,
19}
20
21fn parse_step(params: ParseStepParams) -> Result<ExecutionStep> {
22    let ParseStepParams {
23        step_id,
24        task_id,
25        status,
26        content,
27        started_at,
28        completed_at,
29        duration_ms,
30        error_message,
31    } = params;
32    let status = status
33        .parse::<StepStatus>()
34        .map_err(|e| anyhow::anyhow!("Invalid status: {}", e))?;
35    let content: StepContent =
36        serde_json::from_value(content).map_err(|e| anyhow::anyhow!("Invalid content: {}", e))?;
37    Ok(ExecutionStep {
38        step_id: step_id.into(),
39        task_id: task_id.into(),
40        status,
41        started_at,
42        completed_at,
43        duration_ms,
44        error_message,
45        content,
46    })
47}
48
49#[derive(Debug, Clone)]
50pub struct ExecutionStepRepository {
51    pool: Arc<PgPool>,
52    write_pool: Arc<PgPool>,
53}
54
55impl ExecutionStepRepository {
56    pub fn new(db: &DbPool) -> Result<Self> {
57        let pool = db.pool_arc()?;
58        let write_pool = db.write_pool_arc()?;
59        Ok(Self { pool, write_pool })
60    }
61
62    pub async fn create(&self, step: &ExecutionStep) -> Result<()> {
63        let step_id_str = step.step_id.as_str();
64        let task_id = &step.task_id;
65        let status_str = step.status.to_string();
66        let step_type_str = step.content.step_type().to_string();
67        let title = step.content.title();
68        let content_json =
69            serde_json::to_value(&step.content).context("Failed to serialize step content")?;
70        sqlx::query!(
71            r#"INSERT INTO task_execution_steps (
72                step_id, task_id, step_type, title, status, content, started_at, completed_at, duration_ms, error_message
73            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"#,
74            step_id_str,
75            task_id.as_str(),
76            step_type_str,
77            title,
78            status_str,
79            content_json,
80            step.started_at,
81            step.completed_at,
82            step.duration_ms,
83            step.error_message
84        )
85        .execute(&*self.write_pool)
86        .await
87        .context("Failed to create execution step")?;
88        Ok(())
89    }
90
91    pub async fn get(&self, step_id: &StepId) -> Result<Option<ExecutionStep>> {
92        let step_id_str = step_id.as_str();
93        let row = sqlx::query!(
94            r#"SELECT step_id, task_id, status, content,
95                    started_at as "started_at!", completed_at, duration_ms, error_message
96                FROM task_execution_steps WHERE step_id = $1"#,
97            step_id_str
98        )
99        .fetch_optional(&*self.pool)
100        .await
101        .context(format!("Failed to get execution step: {step_id}"))?;
102        row.map(|r| {
103            parse_step(ParseStepParams {
104                step_id: r.step_id,
105                task_id: r.task_id,
106                status: r.status,
107                content: r.content,
108                started_at: r.started_at,
109                completed_at: r.completed_at,
110                duration_ms: r.duration_ms,
111                error_message: r.error_message,
112            })
113        })
114        .transpose()
115    }
116
117    pub async fn list_by_task(&self, task_id: &TaskId) -> Result<Vec<ExecutionStep>> {
118        let rows = sqlx::query!(
119            r#"SELECT step_id, task_id, status, content,
120                    started_at as "started_at!", completed_at, duration_ms, error_message
121                FROM task_execution_steps WHERE task_id = $1 ORDER BY started_at ASC"#,
122            task_id.as_str()
123        )
124        .fetch_all(&*self.pool)
125        .await
126        .context(format!(
127            "Failed to list execution steps for task: {}",
128            task_id
129        ))?;
130        rows.into_iter()
131            .map(|r| {
132                parse_step(ParseStepParams {
133                    step_id: r.step_id,
134                    task_id: r.task_id,
135                    status: r.status,
136                    content: r.content,
137                    started_at: r.started_at,
138                    completed_at: r.completed_at,
139                    duration_ms: r.duration_ms,
140                    error_message: r.error_message,
141                })
142            })
143            .collect()
144    }
145
146    pub async fn complete_step(
147        &self,
148        step_id: &StepId,
149        started_at: DateTime<Utc>,
150        tool_result: Option<serde_json::Value>,
151    ) -> Result<()> {
152        let completed_at = Utc::now();
153        let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
154        let step_id_str = step_id.as_str();
155        let status_str = StepStatus::Completed.to_string();
156
157        if let Some(result) = tool_result {
158            sqlx::query!(
159                r#"UPDATE task_execution_steps SET
160                    status = $2,
161                    completed_at = $3,
162                    duration_ms = $4,
163                    content = jsonb_set(content, '{tool_result}', $5)
164                WHERE step_id = $1"#,
165                step_id_str,
166                status_str,
167                completed_at,
168                duration_ms,
169                result
170            )
171            .execute(&*self.write_pool)
172            .await
173            .context(format!("Failed to complete execution step: {step_id}"))?;
174        } else {
175            sqlx::query!(
176                r#"UPDATE task_execution_steps SET
177                    status = $2,
178                    completed_at = $3,
179                    duration_ms = $4
180                WHERE step_id = $1"#,
181                step_id_str,
182                status_str,
183                completed_at,
184                duration_ms
185            )
186            .execute(&*self.write_pool)
187            .await
188            .context(format!("Failed to complete execution step: {step_id}"))?;
189        }
190
191        Ok(())
192    }
193
194    pub async fn fail_step(
195        &self,
196        step_id: &StepId,
197        started_at: DateTime<Utc>,
198        error_message: &str,
199    ) -> Result<()> {
200        let completed_at = Utc::now();
201        let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
202        let step_id_str = step_id.as_str();
203        let status_str = StepStatus::Failed.to_string();
204
205        sqlx::query!(
206            r#"UPDATE task_execution_steps SET
207                status = $2,
208                completed_at = $3,
209                duration_ms = $4,
210                error_message = $5
211            WHERE step_id = $1"#,
212            step_id_str,
213            status_str,
214            completed_at,
215            duration_ms,
216            error_message
217        )
218        .execute(&*self.write_pool)
219        .await
220        .context(format!("Failed to fail execution step: {step_id}"))?;
221
222        Ok(())
223    }
224
225    pub async fn fail_in_progress_steps_for_task(
226        &self,
227        task_id: &TaskId,
228        error_message: &str,
229    ) -> Result<u64> {
230        let completed_at = Utc::now();
231        let in_progress_str = StepStatus::InProgress.to_string();
232        let failed_str = StepStatus::Failed.to_string();
233        let task_id_str = task_id.as_str();
234
235        let result = sqlx::query!(
236            r#"UPDATE task_execution_steps SET
237                status = $3,
238                completed_at = $4,
239                error_message = $5
240            WHERE task_id = $1 AND status = $2"#,
241            task_id_str,
242            in_progress_str,
243            failed_str,
244            completed_at,
245            error_message
246        )
247        .execute(&*self.write_pool)
248        .await
249        .context(format!(
250            "Failed to fail in-progress steps for task: {}",
251            task_id
252        ))?;
253
254        Ok(result.rows_affected())
255    }
256
257    pub async fn complete_planning_step(
258        &self,
259        step_id: &StepId,
260        started_at: DateTime<Utc>,
261        reasoning: Option<String>,
262        planned_tools: Option<Vec<PlannedTool>>,
263    ) -> Result<ExecutionStep> {
264        let completed_at = Utc::now();
265        let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
266        let step_id_str = step_id.as_str();
267        let status_str = StepStatus::Completed.to_string();
268
269        let content = StepContent::planning(reasoning, planned_tools);
270        let content_json =
271            serde_json::to_value(&content).context("Failed to serialize planning content")?;
272
273        let row = sqlx::query!(
274            r#"UPDATE task_execution_steps SET
275                status = $2,
276                completed_at = $3,
277                duration_ms = $4,
278                content = $5
279            WHERE step_id = $1
280            RETURNING step_id, task_id, status, content,
281                    started_at as "started_at!", completed_at, duration_ms, error_message"#,
282            step_id_str,
283            status_str,
284            completed_at,
285            duration_ms,
286            content_json
287        )
288        .fetch_one(&*self.write_pool)
289        .await
290        .context(format!("Failed to complete planning step: {step_id}"))?;
291
292        parse_step(ParseStepParams {
293            step_id: row.step_id,
294            task_id: row.task_id,
295            status: row.status,
296            content: row.content,
297            started_at: row.started_at,
298            completed_at: row.completed_at,
299            duration_ms: row.duration_ms,
300            error_message: row.error_message,
301        })
302    }
303
304    pub async fn mcp_execution_id_exists(&self, mcp_execution_id: &str) -> Result<bool> {
305        let exists = sqlx::query_scalar!(
306            r#"SELECT EXISTS(SELECT 1 FROM mcp_tool_executions WHERE mcp_execution_id = $1) as "exists!""#,
307            mcp_execution_id
308        )
309        .fetch_one(&*self.pool)
310        .await
311        .context("Failed to check mcp_execution_id existence")?;
312
313        Ok(exists)
314    }
315}