1use anyhow::{Context, Result};
2use chrono::{DateTime, Utc};
3use sqlx::PgPool;
4use std::sync::Arc;
5use systemprompt_database::DbPool;
6use systemprompt_identifiers::TaskId;
7use systemprompt_models::{ExecutionStep, PlannedTool, StepContent, StepId, StepStatus};
8
9fn parse_step(
10 step_id: String,
11 task_id: String,
12 status: String,
13 content: serde_json::Value,
14 started_at: DateTime<Utc>,
15 completed_at: Option<DateTime<Utc>>,
16 duration_ms: Option<i32>,
17 error_message: Option<String>,
18) -> Result<ExecutionStep> {
19 let status = status
20 .parse::<StepStatus>()
21 .map_err(|e| anyhow::anyhow!("Invalid status: {}", e))?;
22 let content: StepContent =
23 serde_json::from_value(content).map_err(|e| anyhow::anyhow!("Invalid content: {}", e))?;
24 Ok(ExecutionStep {
25 step_id: step_id.into(),
26 task_id: task_id.into(),
27 status,
28 started_at,
29 completed_at,
30 duration_ms,
31 error_message,
32 content,
33 })
34}
35
36#[derive(Debug, Clone)]
37pub struct ExecutionStepRepository {
38 pool: Arc<PgPool>,
39}
40
41impl ExecutionStepRepository {
42 pub fn new(db: &DbPool) -> Result<Self> {
43 let pool = db.pool_arc()?;
44 Ok(Self { pool })
45 }
46
47 pub async fn create(&self, step: &ExecutionStep) -> Result<()> {
48 let step_id_str = step.step_id.as_str();
49 let task_id = &step.task_id;
50 let status_str = step.status.to_string();
51 let step_type_str = step.content.step_type().to_string();
52 let title = step.content.title();
53 let content_json =
54 serde_json::to_value(&step.content).context("Failed to serialize step content")?;
55 sqlx::query!(
56 r#"INSERT INTO task_execution_steps (
57 step_id, task_id, step_type, title, status, content, started_at, completed_at, duration_ms, error_message
58 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"#,
59 step_id_str,
60 task_id.as_str(),
61 step_type_str,
62 title,
63 status_str,
64 content_json,
65 step.started_at,
66 step.completed_at,
67 step.duration_ms,
68 step.error_message
69 )
70 .execute(&*self.pool)
71 .await
72 .context("Failed to create execution step")?;
73 Ok(())
74 }
75
76 pub async fn get(&self, step_id: &StepId) -> Result<Option<ExecutionStep>> {
77 let step_id_str = step_id.as_str();
78 let row = sqlx::query!(
79 r#"SELECT step_id, task_id, status, content,
80 started_at as "started_at!", completed_at, duration_ms, error_message
81 FROM task_execution_steps WHERE step_id = $1"#,
82 step_id_str
83 )
84 .fetch_optional(&*self.pool)
85 .await
86 .context(format!("Failed to get execution step: {step_id}"))?;
87 row.map(|r| {
88 parse_step(
89 r.step_id,
90 r.task_id,
91 r.status,
92 r.content,
93 r.started_at,
94 r.completed_at,
95 r.duration_ms,
96 r.error_message,
97 )
98 })
99 .transpose()
100 }
101
102 pub async fn list_by_task(&self, task_id: &TaskId) -> Result<Vec<ExecutionStep>> {
103 let rows = sqlx::query!(
104 r#"SELECT step_id, task_id, status, content,
105 started_at as "started_at!", completed_at, duration_ms, error_message
106 FROM task_execution_steps WHERE task_id = $1 ORDER BY started_at ASC"#,
107 task_id.as_str()
108 )
109 .fetch_all(&*self.pool)
110 .await
111 .context(format!(
112 "Failed to list execution steps for task: {}",
113 task_id
114 ))?;
115 rows.into_iter()
116 .map(|r| {
117 parse_step(
118 r.step_id,
119 r.task_id,
120 r.status,
121 r.content,
122 r.started_at,
123 r.completed_at,
124 r.duration_ms,
125 r.error_message,
126 )
127 })
128 .collect()
129 }
130
131 pub async fn complete_step(
132 &self,
133 step_id: &StepId,
134 started_at: DateTime<Utc>,
135 tool_result: Option<serde_json::Value>,
136 ) -> Result<()> {
137 let completed_at = Utc::now();
138 let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
139 let step_id_str = step_id.as_str();
140 let status_str = StepStatus::Completed.to_string();
141
142 if let Some(result) = tool_result {
143 sqlx::query!(
144 r#"UPDATE task_execution_steps SET
145 status = $2,
146 completed_at = $3,
147 duration_ms = $4,
148 content = jsonb_set(content, '{tool_result}', $5)
149 WHERE step_id = $1"#,
150 step_id_str,
151 status_str,
152 completed_at,
153 duration_ms,
154 result
155 )
156 .execute(&*self.pool)
157 .await
158 .context(format!("Failed to complete execution step: {step_id}"))?;
159 } else {
160 sqlx::query!(
161 r#"UPDATE task_execution_steps SET
162 status = $2,
163 completed_at = $3,
164 duration_ms = $4
165 WHERE step_id = $1"#,
166 step_id_str,
167 status_str,
168 completed_at,
169 duration_ms
170 )
171 .execute(&*self.pool)
172 .await
173 .context(format!("Failed to complete execution step: {step_id}"))?;
174 }
175
176 Ok(())
177 }
178
179 pub async fn fail_step(
180 &self,
181 step_id: &StepId,
182 started_at: DateTime<Utc>,
183 error_message: &str,
184 ) -> Result<()> {
185 let completed_at = Utc::now();
186 let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
187 let step_id_str = step_id.as_str();
188 let status_str = StepStatus::Failed.to_string();
189
190 sqlx::query!(
191 r#"UPDATE task_execution_steps SET
192 status = $2,
193 completed_at = $3,
194 duration_ms = $4,
195 error_message = $5
196 WHERE step_id = $1"#,
197 step_id_str,
198 status_str,
199 completed_at,
200 duration_ms,
201 error_message
202 )
203 .execute(&*self.pool)
204 .await
205 .context(format!("Failed to fail execution step: {step_id}"))?;
206
207 Ok(())
208 }
209
210 pub async fn fail_in_progress_steps_for_task(
211 &self,
212 task_id: &TaskId,
213 error_message: &str,
214 ) -> Result<u64> {
215 let completed_at = Utc::now();
216 let in_progress_str = StepStatus::InProgress.to_string();
217 let failed_str = StepStatus::Failed.to_string();
218 let task_id_str = task_id.as_str();
219
220 let result = sqlx::query!(
221 r#"UPDATE task_execution_steps SET
222 status = $3,
223 completed_at = $4,
224 error_message = $5
225 WHERE task_id = $1 AND status = $2"#,
226 task_id_str,
227 in_progress_str,
228 failed_str,
229 completed_at,
230 error_message
231 )
232 .execute(&*self.pool)
233 .await
234 .context(format!(
235 "Failed to fail in-progress steps for task: {}",
236 task_id
237 ))?;
238
239 Ok(result.rows_affected())
240 }
241
242 pub async fn complete_planning_step(
243 &self,
244 step_id: &StepId,
245 started_at: DateTime<Utc>,
246 reasoning: Option<String>,
247 planned_tools: Option<Vec<PlannedTool>>,
248 ) -> Result<ExecutionStep> {
249 let completed_at = Utc::now();
250 let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
251 let step_id_str = step_id.as_str();
252 let status_str = StepStatus::Completed.to_string();
253
254 let content = StepContent::planning(reasoning, planned_tools);
255 let content_json =
256 serde_json::to_value(&content).context("Failed to serialize planning content")?;
257
258 let row = sqlx::query!(
259 r#"UPDATE task_execution_steps SET
260 status = $2,
261 completed_at = $3,
262 duration_ms = $4,
263 content = $5
264 WHERE step_id = $1
265 RETURNING step_id, task_id, status, content,
266 started_at as "started_at!", completed_at, duration_ms, error_message"#,
267 step_id_str,
268 status_str,
269 completed_at,
270 duration_ms,
271 content_json
272 )
273 .fetch_one(&*self.pool)
274 .await
275 .context(format!("Failed to complete planning step: {step_id}"))?;
276
277 parse_step(
278 row.step_id,
279 row.task_id,
280 row.status,
281 row.content,
282 row.started_at,
283 row.completed_at,
284 row.duration_ms,
285 row.error_message,
286 )
287 }
288
289 pub async fn mcp_execution_id_exists(&self, mcp_execution_id: &str) -> Result<bool> {
290 let exists = sqlx::query_scalar!(
291 r#"SELECT EXISTS(SELECT 1 FROM mcp_tool_executions WHERE mcp_execution_id = $1) as "exists!""#,
292 mcp_execution_id
293 )
294 .fetch_one(&*self.pool)
295 .await
296 .context("Failed to check mcp_execution_id existence")?;
297
298 Ok(exists)
299 }
300}