1use anyhow::{Context, Result};
2use chrono::{DateTime, Utc};
3use sqlx::PgPool;
4use std::sync::Arc;
5use systemprompt_database::DbPool;
6use systemprompt_identifiers::TaskId;
7use systemprompt_models::{ExecutionStep, PlannedTool, StepContent, StepId, StepStatus};
8
9#[allow(missing_debug_implementations)]
10struct ParseStepParams {
11 step_id: String,
12 task_id: String,
13 status: String,
14 content: serde_json::Value,
15 started_at: DateTime<Utc>,
16 completed_at: Option<DateTime<Utc>>,
17 duration_ms: Option<i32>,
18 error_message: Option<String>,
19}
20
21fn parse_step(params: ParseStepParams) -> Result<ExecutionStep> {
22 let ParseStepParams {
23 step_id,
24 task_id,
25 status,
26 content,
27 started_at,
28 completed_at,
29 duration_ms,
30 error_message,
31 } = params;
32 let status = status
33 .parse::<StepStatus>()
34 .map_err(|e| anyhow::anyhow!("Invalid status: {}", e))?;
35 let content: StepContent =
36 serde_json::from_value(content).map_err(|e| anyhow::anyhow!("Invalid content: {}", e))?;
37 Ok(ExecutionStep {
38 step_id: step_id.into(),
39 task_id: task_id.into(),
40 status,
41 started_at,
42 completed_at,
43 duration_ms,
44 error_message,
45 content,
46 })
47}
48
49#[derive(Debug, Clone)]
50pub struct ExecutionStepRepository {
51 pool: Arc<PgPool>,
52 write_pool: Arc<PgPool>,
53}
54
55impl ExecutionStepRepository {
56 pub fn new(db: &DbPool) -> Result<Self> {
57 let pool = db.pool_arc()?;
58 let write_pool = db.write_pool_arc()?;
59 Ok(Self { pool, write_pool })
60 }
61
62 pub async fn create(&self, step: &ExecutionStep) -> Result<()> {
63 let step_id_str = step.step_id.as_str();
64 let task_id = &step.task_id;
65 let status_str = step.status.to_string();
66 let step_type_str = step.content.step_type().to_string();
67 let title = step.content.title();
68 let content_json =
69 serde_json::to_value(&step.content).context("Failed to serialize step content")?;
70 sqlx::query!(
71 r#"INSERT INTO task_execution_steps (
72 step_id, task_id, step_type, title, status, content, started_at, completed_at, duration_ms, error_message
73 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"#,
74 step_id_str,
75 task_id.as_str(),
76 step_type_str,
77 title,
78 status_str,
79 content_json,
80 step.started_at,
81 step.completed_at,
82 step.duration_ms,
83 step.error_message
84 )
85 .execute(&*self.write_pool)
86 .await
87 .context("Failed to create execution step")?;
88 Ok(())
89 }
90
91 pub async fn get(&self, step_id: &StepId) -> Result<Option<ExecutionStep>> {
92 let step_id_str = step_id.as_str();
93 let row = sqlx::query!(
94 r#"SELECT step_id, task_id, status, content,
95 started_at as "started_at!", completed_at, duration_ms, error_message
96 FROM task_execution_steps WHERE step_id = $1"#,
97 step_id_str
98 )
99 .fetch_optional(&*self.pool)
100 .await
101 .context(format!("Failed to get execution step: {step_id}"))?;
102 row.map(|r| {
103 parse_step(ParseStepParams {
104 step_id: r.step_id,
105 task_id: r.task_id,
106 status: r.status,
107 content: r.content,
108 started_at: r.started_at,
109 completed_at: r.completed_at,
110 duration_ms: r.duration_ms,
111 error_message: r.error_message,
112 })
113 })
114 .transpose()
115 }
116
117 pub async fn list_by_task(&self, task_id: &TaskId) -> Result<Vec<ExecutionStep>> {
118 let rows = sqlx::query!(
119 r#"SELECT step_id, task_id, status, content,
120 started_at as "started_at!", completed_at, duration_ms, error_message
121 FROM task_execution_steps WHERE task_id = $1 ORDER BY started_at ASC"#,
122 task_id.as_str()
123 )
124 .fetch_all(&*self.pool)
125 .await
126 .context(format!(
127 "Failed to list execution steps for task: {}",
128 task_id
129 ))?;
130 rows.into_iter()
131 .map(|r| {
132 parse_step(ParseStepParams {
133 step_id: r.step_id,
134 task_id: r.task_id,
135 status: r.status,
136 content: r.content,
137 started_at: r.started_at,
138 completed_at: r.completed_at,
139 duration_ms: r.duration_ms,
140 error_message: r.error_message,
141 })
142 })
143 .collect()
144 }
145
146 pub async fn complete_step(
147 &self,
148 step_id: &StepId,
149 started_at: DateTime<Utc>,
150 tool_result: Option<serde_json::Value>,
151 ) -> Result<()> {
152 let completed_at = Utc::now();
153 let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
154 let step_id_str = step_id.as_str();
155 let status_str = StepStatus::Completed.to_string();
156
157 if let Some(result) = tool_result {
158 sqlx::query!(
159 r#"UPDATE task_execution_steps SET
160 status = $2,
161 completed_at = $3,
162 duration_ms = $4,
163 content = jsonb_set(content, '{tool_result}', $5)
164 WHERE step_id = $1"#,
165 step_id_str,
166 status_str,
167 completed_at,
168 duration_ms,
169 result
170 )
171 .execute(&*self.write_pool)
172 .await
173 .context(format!("Failed to complete execution step: {step_id}"))?;
174 } else {
175 sqlx::query!(
176 r#"UPDATE task_execution_steps SET
177 status = $2,
178 completed_at = $3,
179 duration_ms = $4
180 WHERE step_id = $1"#,
181 step_id_str,
182 status_str,
183 completed_at,
184 duration_ms
185 )
186 .execute(&*self.write_pool)
187 .await
188 .context(format!("Failed to complete execution step: {step_id}"))?;
189 }
190
191 Ok(())
192 }
193
194 pub async fn fail_step(
195 &self,
196 step_id: &StepId,
197 started_at: DateTime<Utc>,
198 error_message: &str,
199 ) -> Result<()> {
200 let completed_at = Utc::now();
201 let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
202 let step_id_str = step_id.as_str();
203 let status_str = StepStatus::Failed.to_string();
204
205 sqlx::query!(
206 r#"UPDATE task_execution_steps SET
207 status = $2,
208 completed_at = $3,
209 duration_ms = $4,
210 error_message = $5
211 WHERE step_id = $1"#,
212 step_id_str,
213 status_str,
214 completed_at,
215 duration_ms,
216 error_message
217 )
218 .execute(&*self.write_pool)
219 .await
220 .context(format!("Failed to fail execution step: {step_id}"))?;
221
222 Ok(())
223 }
224
225 pub async fn fail_in_progress_steps_for_task(
226 &self,
227 task_id: &TaskId,
228 error_message: &str,
229 ) -> Result<u64> {
230 let completed_at = Utc::now();
231 let in_progress_str = StepStatus::InProgress.to_string();
232 let failed_str = StepStatus::Failed.to_string();
233 let task_id_str = task_id.as_str();
234
235 let result = sqlx::query!(
236 r#"UPDATE task_execution_steps SET
237 status = $3,
238 completed_at = $4,
239 error_message = $5
240 WHERE task_id = $1 AND status = $2"#,
241 task_id_str,
242 in_progress_str,
243 failed_str,
244 completed_at,
245 error_message
246 )
247 .execute(&*self.write_pool)
248 .await
249 .context(format!(
250 "Failed to fail in-progress steps for task: {}",
251 task_id
252 ))?;
253
254 Ok(result.rows_affected())
255 }
256
257 pub async fn complete_planning_step(
258 &self,
259 step_id: &StepId,
260 started_at: DateTime<Utc>,
261 reasoning: Option<String>,
262 planned_tools: Option<Vec<PlannedTool>>,
263 ) -> Result<ExecutionStep> {
264 let completed_at = Utc::now();
265 let duration_ms = (completed_at - started_at).num_milliseconds() as i32;
266 let step_id_str = step_id.as_str();
267 let status_str = StepStatus::Completed.to_string();
268
269 let content = StepContent::planning(reasoning, planned_tools);
270 let content_json =
271 serde_json::to_value(&content).context("Failed to serialize planning content")?;
272
273 let row = sqlx::query!(
274 r#"UPDATE task_execution_steps SET
275 status = $2,
276 completed_at = $3,
277 duration_ms = $4,
278 content = $5
279 WHERE step_id = $1
280 RETURNING step_id, task_id, status, content,
281 started_at as "started_at!", completed_at, duration_ms, error_message"#,
282 step_id_str,
283 status_str,
284 completed_at,
285 duration_ms,
286 content_json
287 )
288 .fetch_one(&*self.write_pool)
289 .await
290 .context(format!("Failed to complete planning step: {step_id}"))?;
291
292 parse_step(ParseStepParams {
293 step_id: row.step_id,
294 task_id: row.task_id,
295 status: row.status,
296 content: row.content,
297 started_at: row.started_at,
298 completed_at: row.completed_at,
299 duration_ms: row.duration_ms,
300 error_message: row.error_message,
301 })
302 }
303
304 pub async fn mcp_execution_id_exists(&self, mcp_execution_id: &str) -> Result<bool> {
305 let exists = sqlx::query_scalar!(
306 r#"SELECT EXISTS(SELECT 1 FROM mcp_tool_executions WHERE mcp_execution_id = $1) as "exists!""#,
307 mcp_execution_id
308 )
309 .fetch_one(&*self.pool)
310 .await
311 .context("Failed to check mcp_execution_id existence")?;
312
313 Ok(exists)
314 }
315}