1use anyhow::{Context, Result};
2use chrono::{DateTime, Utc};
3use sqlx::PgPool;
4use std::sync::Arc;
5use systemprompt_database::DbPool;
6use systemprompt_identifiers::TaskId;
7use systemprompt_models::{ExecutionStep, PlannedTool, StepContent, StepId, StepStatus};
8
9fn parse_step(
10 step_id: String,
11 task_id: String,
12 status: String,
13 content: serde_json::Value,
14 started_at: DateTime<Utc>,
15 completed_at: Option<DateTime<Utc>>,
16 duration_ms: Option<i32>,
17 error_message: Option<String>,
18) -> Result<ExecutionStep> {
19 let status = status
20 .parse::<StepStatus>()
21 .map_err(|e| anyhow::anyhow!("Invalid status: {}", e))?;
22 let content: StepContent =
23 serde_json::from_value(content).map_err(|e| anyhow::anyhow!("Invalid content: {}", e))?;
24 Ok(ExecutionStep {
25 step_id: step_id.into(),
26 task_id: task_id.into(),
27 status,
28 started_at,
29 completed_at,
30 duration_ms,
31 error_message,
32 content,
33 })
34}
35
36#[derive(Debug, Clone)]
37pub struct ExecutionStepRepository {
38 pool: Arc<PgPool>,
39 write_pool: Arc<PgPool>,
40}
41
42impl ExecutionStepRepository {
43 pub fn new(db: &DbPool) -> Result<Self> {
44 let pool = db.pool_arc()?;
45 let write_pool = db.write_pool_arc()?;
46 Ok(Self { pool, write_pool })
47 }
48
49 pub async fn create(&self, step: &ExecutionStep) -> Result<()> {
50 let step_id_str = step.step_id.as_str();
51 let task_id = &step.task_id;
52 let status_str = step.status.to_string();
53 let step_type_str = step.content.step_type().to_string();
54 let title = step.content.title();
55 let content_json =
56 serde_json::to_value(&step.content).context("Failed to serialize step content")?;
57 sqlx::query!(
58 r#"INSERT INTO task_execution_steps (
59 step_id, task_id, step_type, title, status, content, started_at, completed_at, duration_ms, error_message
60 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"#,
61 step_id_str,
62 task_id.as_str(),
63 step_type_str,
64 title,
65 status_str,
66 content_json,
67 step.started_at,
68 step.completed_at,
69 step.duration_ms,
70 step.error_message
71 )
72 .execute(&*self.write_pool)
73 .await
74 .context("Failed to create execution step")?;
75 Ok(())
76 }
77
78 pub async fn get(&self, step_id: &StepId) -> Result<Option<ExecutionStep>> {
79 let step_id_str = step_id.as_str();
80 let row = sqlx::query!(
81 r#"SELECT step_id, task_id, status, content,
82 started_at as "started_at!", completed_at, duration_ms, error_message
83 FROM task_execution_steps WHERE step_id = $1"#,
84 step_id_str
85 )
86 .fetch_optional(&*self.pool)
87 .await
88 .context(format!("Failed to get execution step: {step_id}"))?;
89 row.map(|r| {
90 parse_step(
91 r.step_id,
92 r.task_id,
93 r.status,
94 r.content,
95 r.started_at,
96 r.completed_at,
97 r.duration_ms,
98 r.error_message,
99 )
100 })
101 .transpose()
102 }
103
104 pub async fn list_by_task(&self, task_id: &TaskId) -> Result<Vec<ExecutionStep>> {
105 let rows = sqlx::query!(
106 r#"SELECT step_id, task_id, status, content,
107 started_at as "started_at!", completed_at, duration_ms, error_message
108 FROM task_execution_steps WHERE task_id = $1 ORDER BY started_at ASC"#,
109 task_id.as_str()
110 )
111 .fetch_all(&*self.pool)
112 .await
113 .context(format!(
114 "Failed to list execution steps for task: {}",
115 task_id
116 ))?;
117 rows.into_iter()
118 .map(|r| {
119 parse_step(
120 r.step_id,
121 r.task_id,
122 r.status,
123 r.content,
124 r.started_at,
125 r.completed_at,
126 r.duration_ms,
127 r.error_message,
128 )
129 })
130 .collect()
131 }
132
133 pub async fn complete_step(
134 &self,
135 step_id: &StepId,
136 started_at: DateTime<Utc>,
137 tool_result: Option<serde_json::Value>,
138 ) -> Result<()> {
139 let completed_at = Utc::now();
140 let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
141 let step_id_str = step_id.as_str();
142 let status_str = StepStatus::Completed.to_string();
143
144 if let Some(result) = tool_result {
145 sqlx::query!(
146 r#"UPDATE task_execution_steps SET
147 status = $2,
148 completed_at = $3,
149 duration_ms = $4,
150 content = jsonb_set(content, '{tool_result}', $5)
151 WHERE step_id = $1"#,
152 step_id_str,
153 status_str,
154 completed_at,
155 duration_ms,
156 result
157 )
158 .execute(&*self.write_pool)
159 .await
160 .context(format!("Failed to complete execution step: {step_id}"))?;
161 } else {
162 sqlx::query!(
163 r#"UPDATE task_execution_steps SET
164 status = $2,
165 completed_at = $3,
166 duration_ms = $4
167 WHERE step_id = $1"#,
168 step_id_str,
169 status_str,
170 completed_at,
171 duration_ms
172 )
173 .execute(&*self.write_pool)
174 .await
175 .context(format!("Failed to complete execution step: {step_id}"))?;
176 }
177
178 Ok(())
179 }
180
181 pub async fn fail_step(
182 &self,
183 step_id: &StepId,
184 started_at: DateTime<Utc>,
185 error_message: &str,
186 ) -> Result<()> {
187 let completed_at = Utc::now();
188 let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
189 let step_id_str = step_id.as_str();
190 let status_str = StepStatus::Failed.to_string();
191
192 sqlx::query!(
193 r#"UPDATE task_execution_steps SET
194 status = $2,
195 completed_at = $3,
196 duration_ms = $4,
197 error_message = $5
198 WHERE step_id = $1"#,
199 step_id_str,
200 status_str,
201 completed_at,
202 duration_ms,
203 error_message
204 )
205 .execute(&*self.write_pool)
206 .await
207 .context(format!("Failed to fail execution step: {step_id}"))?;
208
209 Ok(())
210 }
211
212 pub async fn fail_in_progress_steps_for_task(
213 &self,
214 task_id: &TaskId,
215 error_message: &str,
216 ) -> Result<u64> {
217 let completed_at = Utc::now();
218 let in_progress_str = StepStatus::InProgress.to_string();
219 let failed_str = StepStatus::Failed.to_string();
220 let task_id_str = task_id.as_str();
221
222 let result = sqlx::query!(
223 r#"UPDATE task_execution_steps SET
224 status = $3,
225 completed_at = $4,
226 error_message = $5
227 WHERE task_id = $1 AND status = $2"#,
228 task_id_str,
229 in_progress_str,
230 failed_str,
231 completed_at,
232 error_message
233 )
234 .execute(&*self.write_pool)
235 .await
236 .context(format!(
237 "Failed to fail in-progress steps for task: {}",
238 task_id
239 ))?;
240
241 Ok(result.rows_affected())
242 }
243
244 pub async fn complete_planning_step(
245 &self,
246 step_id: &StepId,
247 started_at: DateTime<Utc>,
248 reasoning: Option<String>,
249 planned_tools: Option<Vec<PlannedTool>>,
250 ) -> Result<ExecutionStep> {
251 let completed_at = Utc::now();
252 let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
253 let step_id_str = step_id.as_str();
254 let status_str = StepStatus::Completed.to_string();
255
256 let content = StepContent::planning(reasoning, planned_tools);
257 let content_json =
258 serde_json::to_value(&content).context("Failed to serialize planning content")?;
259
260 let row = sqlx::query!(
261 r#"UPDATE task_execution_steps SET
262 status = $2,
263 completed_at = $3,
264 duration_ms = $4,
265 content = $5
266 WHERE step_id = $1
267 RETURNING step_id, task_id, status, content,
268 started_at as "started_at!", completed_at, duration_ms, error_message"#,
269 step_id_str,
270 status_str,
271 completed_at,
272 duration_ms,
273 content_json
274 )
275 .fetch_one(&*self.write_pool)
276 .await
277 .context(format!("Failed to complete planning step: {step_id}"))?;
278
279 parse_step(
280 row.step_id,
281 row.task_id,
282 row.status,
283 row.content,
284 row.started_at,
285 row.completed_at,
286 row.duration_ms,
287 row.error_message,
288 )
289 }
290
291 pub async fn mcp_execution_id_exists(&self, mcp_execution_id: &str) -> Result<bool> {
292 let exists = sqlx::query_scalar!(
293 r#"SELECT EXISTS(SELECT 1 FROM mcp_tool_executions WHERE mcp_execution_id = $1) as "exists!""#,
294 mcp_execution_id
295 )
296 .fetch_one(&*self.pool)
297 .await
298 .context("Failed to check mcp_execution_id existence")?;
299
300 Ok(exists)
301 }
302}