Skip to main content

systemprompt_agent/repository/execution/
mod.rs

1use anyhow::{Context, Result};
2use chrono::{DateTime, Utc};
3use sqlx::PgPool;
4use std::sync::Arc;
5use systemprompt_database::DbPool;
6use systemprompt_identifiers::TaskId;
7use systemprompt_models::{ExecutionStep, PlannedTool, StepContent, StepId, StepStatus};
8
9fn parse_step(
10    step_id: String,
11    task_id: String,
12    status: String,
13    content: serde_json::Value,
14    started_at: DateTime<Utc>,
15    completed_at: Option<DateTime<Utc>>,
16    duration_ms: Option<i32>,
17    error_message: Option<String>,
18) -> Result<ExecutionStep> {
19    let status = status
20        .parse::<StepStatus>()
21        .map_err(|e| anyhow::anyhow!("Invalid status: {}", e))?;
22    let content: StepContent =
23        serde_json::from_value(content).map_err(|e| anyhow::anyhow!("Invalid content: {}", e))?;
24    Ok(ExecutionStep {
25        step_id: step_id.into(),
26        task_id: task_id.into(),
27        status,
28        started_at,
29        completed_at,
30        duration_ms,
31        error_message,
32        content,
33    })
34}
35
36#[derive(Debug, Clone)]
37pub struct ExecutionStepRepository {
38    pool: Arc<PgPool>,
39    write_pool: Arc<PgPool>,
40}
41
42impl ExecutionStepRepository {
43    pub fn new(db: &DbPool) -> Result<Self> {
44        let pool = db.pool_arc()?;
45        let write_pool = db.write_pool_arc()?;
46        Ok(Self { pool, write_pool })
47    }
48
49    pub async fn create(&self, step: &ExecutionStep) -> Result<()> {
50        let step_id_str = step.step_id.as_str();
51        let task_id = &step.task_id;
52        let status_str = step.status.to_string();
53        let step_type_str = step.content.step_type().to_string();
54        let title = step.content.title();
55        let content_json =
56            serde_json::to_value(&step.content).context("Failed to serialize step content")?;
57        sqlx::query!(
58            r#"INSERT INTO task_execution_steps (
59                step_id, task_id, step_type, title, status, content, started_at, completed_at, duration_ms, error_message
60            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"#,
61            step_id_str,
62            task_id.as_str(),
63            step_type_str,
64            title,
65            status_str,
66            content_json,
67            step.started_at,
68            step.completed_at,
69            step.duration_ms,
70            step.error_message
71        )
72        .execute(&*self.write_pool)
73        .await
74        .context("Failed to create execution step")?;
75        Ok(())
76    }
77
78    pub async fn get(&self, step_id: &StepId) -> Result<Option<ExecutionStep>> {
79        let step_id_str = step_id.as_str();
80        let row = sqlx::query!(
81            r#"SELECT step_id, task_id, status, content,
82                    started_at as "started_at!", completed_at, duration_ms, error_message
83                FROM task_execution_steps WHERE step_id = $1"#,
84            step_id_str
85        )
86        .fetch_optional(&*self.pool)
87        .await
88        .context(format!("Failed to get execution step: {step_id}"))?;
89        row.map(|r| {
90            parse_step(
91                r.step_id,
92                r.task_id,
93                r.status,
94                r.content,
95                r.started_at,
96                r.completed_at,
97                r.duration_ms,
98                r.error_message,
99            )
100        })
101        .transpose()
102    }
103
104    pub async fn list_by_task(&self, task_id: &TaskId) -> Result<Vec<ExecutionStep>> {
105        let rows = sqlx::query!(
106            r#"SELECT step_id, task_id, status, content,
107                    started_at as "started_at!", completed_at, duration_ms, error_message
108                FROM task_execution_steps WHERE task_id = $1 ORDER BY started_at ASC"#,
109            task_id.as_str()
110        )
111        .fetch_all(&*self.pool)
112        .await
113        .context(format!(
114            "Failed to list execution steps for task: {}",
115            task_id
116        ))?;
117        rows.into_iter()
118            .map(|r| {
119                parse_step(
120                    r.step_id,
121                    r.task_id,
122                    r.status,
123                    r.content,
124                    r.started_at,
125                    r.completed_at,
126                    r.duration_ms,
127                    r.error_message,
128                )
129            })
130            .collect()
131    }
132
133    pub async fn complete_step(
134        &self,
135        step_id: &StepId,
136        started_at: DateTime<Utc>,
137        tool_result: Option<serde_json::Value>,
138    ) -> Result<()> {
139        let completed_at = Utc::now();
140        let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
141        let step_id_str = step_id.as_str();
142        let status_str = StepStatus::Completed.to_string();
143
144        if let Some(result) = tool_result {
145            sqlx::query!(
146                r#"UPDATE task_execution_steps SET
147                    status = $2,
148                    completed_at = $3,
149                    duration_ms = $4,
150                    content = jsonb_set(content, '{tool_result}', $5)
151                WHERE step_id = $1"#,
152                step_id_str,
153                status_str,
154                completed_at,
155                duration_ms,
156                result
157            )
158            .execute(&*self.write_pool)
159            .await
160            .context(format!("Failed to complete execution step: {step_id}"))?;
161        } else {
162            sqlx::query!(
163                r#"UPDATE task_execution_steps SET
164                    status = $2,
165                    completed_at = $3,
166                    duration_ms = $4
167                WHERE step_id = $1"#,
168                step_id_str,
169                status_str,
170                completed_at,
171                duration_ms
172            )
173            .execute(&*self.write_pool)
174            .await
175            .context(format!("Failed to complete execution step: {step_id}"))?;
176        }
177
178        Ok(())
179    }
180
181    pub async fn fail_step(
182        &self,
183        step_id: &StepId,
184        started_at: DateTime<Utc>,
185        error_message: &str,
186    ) -> Result<()> {
187        let completed_at = Utc::now();
188        let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
189        let step_id_str = step_id.as_str();
190        let status_str = StepStatus::Failed.to_string();
191
192        sqlx::query!(
193            r#"UPDATE task_execution_steps SET
194                status = $2,
195                completed_at = $3,
196                duration_ms = $4,
197                error_message = $5
198            WHERE step_id = $1"#,
199            step_id_str,
200            status_str,
201            completed_at,
202            duration_ms,
203            error_message
204        )
205        .execute(&*self.write_pool)
206        .await
207        .context(format!("Failed to fail execution step: {step_id}"))?;
208
209        Ok(())
210    }
211
212    pub async fn fail_in_progress_steps_for_task(
213        &self,
214        task_id: &TaskId,
215        error_message: &str,
216    ) -> Result<u64> {
217        let completed_at = Utc::now();
218        let in_progress_str = StepStatus::InProgress.to_string();
219        let failed_str = StepStatus::Failed.to_string();
220        let task_id_str = task_id.as_str();
221
222        let result = sqlx::query!(
223            r#"UPDATE task_execution_steps SET
224                status = $3,
225                completed_at = $4,
226                error_message = $5
227            WHERE task_id = $1 AND status = $2"#,
228            task_id_str,
229            in_progress_str,
230            failed_str,
231            completed_at,
232            error_message
233        )
234        .execute(&*self.write_pool)
235        .await
236        .context(format!(
237            "Failed to fail in-progress steps for task: {}",
238            task_id
239        ))?;
240
241        Ok(result.rows_affected())
242    }
243
244    pub async fn complete_planning_step(
245        &self,
246        step_id: &StepId,
247        started_at: DateTime<Utc>,
248        reasoning: Option<String>,
249        planned_tools: Option<Vec<PlannedTool>>,
250    ) -> Result<ExecutionStep> {
251        let completed_at = Utc::now();
252        let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
253        let step_id_str = step_id.as_str();
254        let status_str = StepStatus::Completed.to_string();
255
256        let content = StepContent::planning(reasoning, planned_tools);
257        let content_json =
258            serde_json::to_value(&content).context("Failed to serialize planning content")?;
259
260        let row = sqlx::query!(
261            r#"UPDATE task_execution_steps SET
262                status = $2,
263                completed_at = $3,
264                duration_ms = $4,
265                content = $5
266            WHERE step_id = $1
267            RETURNING step_id, task_id, status, content,
268                    started_at as "started_at!", completed_at, duration_ms, error_message"#,
269            step_id_str,
270            status_str,
271            completed_at,
272            duration_ms,
273            content_json
274        )
275        .fetch_one(&*self.write_pool)
276        .await
277        .context(format!("Failed to complete planning step: {step_id}"))?;
278
279        parse_step(
280            row.step_id,
281            row.task_id,
282            row.status,
283            row.content,
284            row.started_at,
285            row.completed_at,
286            row.duration_ms,
287            row.error_message,
288        )
289    }
290
291    pub async fn mcp_execution_id_exists(&self, mcp_execution_id: &str) -> Result<bool> {
292        let exists = sqlx::query_scalar!(
293            r#"SELECT EXISTS(SELECT 1 FROM mcp_tool_executions WHERE mcp_execution_id = $1) as "exists!""#,
294            mcp_execution_id
295        )
296        .fetch_one(&*self.pool)
297        .await
298        .context("Failed to check mcp_execution_id existence")?;
299
300        Ok(exists)
301    }
302}