Skip to main content

task_graph_mcp/db/
state_transitions.rs

1//! State and phase transition tracking for automatic time accumulation.
2
use crate::config::StatesConfig;
use crate::db::{Database, now_ms};
use crate::types::TaskSequenceEvent;
use anyhow::Result;
use rusqlite::{Connection, OptionalExtension, params};
8
9/// Record a state transition and accumulate time if transitioning from a timed state.
10///
11/// Uses snapshot pattern: only records the new status value. Previous status
12/// can be determined by querying the previous row for the same task.
13///
14/// Returns the elapsed time added to time_actual_ms (0 if previous state was not timed).
15pub(crate) fn record_state_transition(
16    conn: &Connection,
17    task_id: &str,
18    status: &str,
19    worker_id: Option<&str>,
20    reason: Option<&str>,
21    states_config: &StatesConfig,
22) -> Result<i64> {
23    let now = now_ms();
24    let mut elapsed_added = 0i64;
25
26    // Find and close any open transition for this task (status-based)
27    let open_transition: Option<(i64, String, i64)> = conn
28        .query_row(
29            "SELECT id, status, timestamp FROM task_sequence
30             WHERE task_id = ?1 AND end_timestamp IS NULL AND status IS NOT NULL
31             ORDER BY id DESC LIMIT 1",
32            params![task_id],
33            |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
34        )
35        .ok();
36
37    if let Some((open_id, prev_status, start_timestamp)) = open_transition {
38        // Count how many other timed tasks this worker had open (for time normalization)
39        let concurrency: i32 = if let Some(wid) = worker_id {
40            conn.query_row(
41                "SELECT COUNT(*) FROM task_sequence
42                 WHERE worker_id = ?1 AND end_timestamp IS NULL AND status IS NOT NULL AND task_id != ?2",
43                params![wid, task_id],
44                |row| row.get(0),
45            )
46            .unwrap_or(0)
47                + 1 // +1 for the current task
48        } else {
49            1
50        };
51
52        // Close the previous transition with concurrency factor
53        conn.execute(
54            "UPDATE task_sequence SET end_timestamp = ?1, concurrency = ?2 WHERE id = ?3",
55            params![now, concurrency, open_id],
56        )?;
57
58        // If previous state was a timed state, accumulate elapsed time
59        if states_config.is_timed_state(&prev_status) {
60            elapsed_added = now - start_timestamp;
61
62            // Add elapsed time to task's time_actual_ms
63            conn.execute(
64                "UPDATE tasks SET time_actual_ms = COALESCE(time_actual_ms, 0) + ?1, updated_at = ?2
65                 WHERE id = ?3",
66                params![elapsed_added, now, task_id],
67            )?;
68        }
69    }
70
71    // Count concurrent timed tasks for the new transition record
72    let new_concurrency: i32 = if let Some(wid) = worker_id {
73        conn.query_row(
74            "SELECT COUNT(*) FROM task_sequence
75             WHERE worker_id = ?1 AND end_timestamp IS NULL AND status IS NOT NULL AND task_id != ?2",
76            params![wid, task_id],
77            |row| row.get(0),
78        )
79        .unwrap_or(0)
80            + 1 // +1 for the current task being inserted
81    } else {
82        1
83    };
84
85    // Insert the new transition (snapshot pattern - only new status)
86    conn.execute(
87        "INSERT INTO task_sequence (task_id, worker_id, status, reason, timestamp, concurrency)
88         VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
89        params![task_id, worker_id, status, reason, now, new_concurrency],
90    )?;
91
92    Ok(elapsed_added)
93}
94
95/// Record a phase transition for audit purposes.
96///
97/// Uses snapshot pattern: only records the new phase value.
98pub(crate) fn record_phase_transition(
99    conn: &Connection,
100    task_id: &str,
101    phase: &str,
102    worker_id: Option<&str>,
103    reason: Option<&str>,
104) -> Result<()> {
105    let now = now_ms();
106
107    conn.execute(
108        "INSERT INTO task_sequence (task_id, worker_id, phase, reason, timestamp)
109         VALUES (?1, ?2, ?3, ?4, ?5)",
110        params![task_id, worker_id, phase, reason, now],
111    )?;
112
113    Ok(())
114}
115
116/// Statistics for project-wide state transitions.
117#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
118pub struct ProjectStateStats {
119    pub total_transitions: i64,
120    pub total_time_ms: i64,
121    pub tasks_affected: i64,
122    pub transitions_by_status: std::collections::HashMap<String, i64>,
123    pub time_by_status_ms: std::collections::HashMap<String, i64>,
124    pub transitions_by_agent: std::collections::HashMap<String, i64>,
125    pub time_by_agent_ms: std::collections::HashMap<String, i64>,
126}
127
128impl Database {
129    /// Get the unified sequence history for a task (both status and phase changes).
130    pub fn get_task_sequence_history(&self, task_id: &str) -> Result<Vec<TaskSequenceEvent>> {
131        self.with_conn(|conn| {
132            let mut stmt = conn.prepare(
133                "SELECT id, task_id, worker_id, status, phase, reason, timestamp, end_timestamp, concurrency
134                 FROM task_sequence
135                 WHERE task_id = ?1
136                 ORDER BY id ASC",
137            )?;
138
139            let events = stmt
140                .query_map(params![task_id], |row| {
141                    Ok(TaskSequenceEvent {
142                        id: row.get(0)?,
143                        task_id: row.get(1)?,
144                        worker_id: row.get(2)?,
145                        status: row.get(3)?,
146                        phase: row.get(4)?,
147                        reason: row.get(5)?,
148                        timestamp: row.get(6)?,
149                        end_timestamp: row.get(7)?,
150                        concurrency: row.get(8)?,
151                    })
152                })?
153                .collect::<Result<Vec<_>, _>>()?;
154
155            Ok(events)
156        })
157    }
158
159    /// Get the state transition history for a task (status changes only, for backward compat).
160    pub fn get_task_state_history(&self, task_id: &str) -> Result<Vec<TaskSequenceEvent>> {
161        self.with_conn(|conn| {
162            let mut stmt = conn.prepare(
163                "SELECT id, task_id, worker_id, status, phase, reason, timestamp, end_timestamp, concurrency
164                 FROM task_sequence
165                 WHERE task_id = ?1 AND status IS NOT NULL
166                 ORDER BY id ASC",
167            )?;
168
169            let events = stmt
170                .query_map(params![task_id], |row| {
171                    Ok(TaskSequenceEvent {
172                        id: row.get(0)?,
173                        task_id: row.get(1)?,
174                        worker_id: row.get(2)?,
175                        status: row.get(3)?,
176                        phase: row.get(4)?,
177                        reason: row.get(5)?,
178                        timestamp: row.get(6)?,
179                        end_timestamp: row.get(7)?,
180                        concurrency: row.get(8)?,
181                    })
182                })?
183                .collect::<Result<Vec<_>, _>>()?;
184
185            Ok(events)
186        })
187    }
188
189    /// Get the current duration in the current state (for active time tracking).
190    /// Only returns a duration if the current state is a timed state.
191    pub fn get_current_state_duration(
192        &self,
193        task_id: &str,
194        states_config: &StatesConfig,
195    ) -> Result<Option<i64>> {
196        self.with_conn(|conn| {
197            let result: Option<(String, i64)> = conn
198                .query_row(
199                    "SELECT status, timestamp FROM task_sequence
200                     WHERE task_id = ?1 AND end_timestamp IS NULL AND status IS NOT NULL
201                     ORDER BY id DESC LIMIT 1",
202                    params![task_id],
203                    |row| Ok((row.get(0)?, row.get(1)?)),
204                )
205                .ok();
206
207            match result {
208                Some((status, start_timestamp)) => {
209                    if states_config.is_timed_state(&status) {
210                        return Ok(Some(now_ms() - start_timestamp));
211                    }
212                    Ok(None)
213                }
214                None => Ok(None),
215            }
216        })
217    }
218
219    /// Get project-wide state transition history with optional time range filter.
220    /// Returns all state transitions across all tasks within the specified time range.
221    pub fn get_project_state_history(
222        &self,
223        from_timestamp: Option<i64>,
224        to_timestamp: Option<i64>,
225        state_filter: Option<&[String]>,
226        limit: Option<i64>,
227    ) -> Result<Vec<TaskSequenceEvent>> {
228        self.with_conn(|conn| {
229            // Build query dynamically based on filters
230            let mut sql = String::from(
231                "SELECT id, task_id, worker_id, status, phase, reason, timestamp, end_timestamp, concurrency
232                 FROM task_sequence WHERE status IS NOT NULL",
233            );
234            let mut param_values: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
235
236            if let Some(from_ts) = from_timestamp {
237                sql.push_str(&format!(" AND timestamp >= ?{}", param_values.len() + 1));
238                param_values.push(Box::new(from_ts));
239            }
240
241            if let Some(to_ts) = to_timestamp {
242                sql.push_str(&format!(" AND timestamp <= ?{}", param_values.len() + 1));
243                param_values.push(Box::new(to_ts));
244            }
245
246            if let Some(states) = state_filter
247                && !states.is_empty()
248            {
249                let placeholders: Vec<String> = states
250                    .iter()
251                    .enumerate()
252                    .map(|(i, _)| format!("?{}", param_values.len() + i + 1))
253                    .collect();
254                sql.push_str(&format!(" AND status IN ({})", placeholders.join(", ")));
255                for state in states {
256                    param_values.push(Box::new(state.clone()));
257                }
258            }
259
260            sql.push_str(" ORDER BY timestamp DESC, id DESC");
261
262            if let Some(lim) = limit {
263                sql.push_str(&format!(" LIMIT ?{}", param_values.len() + 1));
264                param_values.push(Box::new(lim));
265            }
266
267            let mut stmt = conn.prepare(&sql)?;
268
269            // Convert Vec<Box<dyn ToSql>> to slice of references
270            let param_refs: Vec<&dyn rusqlite::ToSql> =
271                param_values.iter().map(|b| b.as_ref()).collect();
272
273            let events = stmt
274                .query_map(param_refs.as_slice(), |row| {
275                    Ok(TaskSequenceEvent {
276                        id: row.get(0)?,
277                        task_id: row.get(1)?,
278                        worker_id: row.get(2)?,
279                        status: row.get(3)?,
280                        phase: row.get(4)?,
281                        reason: row.get(5)?,
282                        timestamp: row.get(6)?,
283                        end_timestamp: row.get(7)?,
284                        concurrency: row.get(8)?,
285                    })
286                })?
287                .collect::<Result<Vec<_>, _>>()?;
288
289            Ok(events)
290        })
291    }
292
293    /// Get project-wide sequence history (both status and phase changes).
294    pub fn get_project_sequence_history(
295        &self,
296        from_timestamp: Option<i64>,
297        to_timestamp: Option<i64>,
298        limit: Option<i64>,
299    ) -> Result<Vec<TaskSequenceEvent>> {
300        self.with_conn(|conn| {
301            let mut sql = String::from(
302                "SELECT id, task_id, worker_id, status, phase, reason, timestamp, end_timestamp, concurrency
303                 FROM task_sequence WHERE 1=1",
304            );
305            let mut param_values: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
306
307            if let Some(from_ts) = from_timestamp {
308                sql.push_str(&format!(" AND timestamp >= ?{}", param_values.len() + 1));
309                param_values.push(Box::new(from_ts));
310            }
311
312            if let Some(to_ts) = to_timestamp {
313                sql.push_str(&format!(" AND timestamp <= ?{}", param_values.len() + 1));
314                param_values.push(Box::new(to_ts));
315            }
316
317            sql.push_str(" ORDER BY timestamp DESC, id DESC");
318
319            if let Some(lim) = limit {
320                sql.push_str(&format!(" LIMIT ?{}", param_values.len() + 1));
321                param_values.push(Box::new(lim));
322            }
323
324            let mut stmt = conn.prepare(&sql)?;
325            let param_refs: Vec<&dyn rusqlite::ToSql> =
326                param_values.iter().map(|b| b.as_ref()).collect();
327
328            let events = stmt
329                .query_map(param_refs.as_slice(), |row| {
330                    Ok(TaskSequenceEvent {
331                        id: row.get(0)?,
332                        task_id: row.get(1)?,
333                        worker_id: row.get(2)?,
334                        status: row.get(3)?,
335                        phase: row.get(4)?,
336                        reason: row.get(5)?,
337                        timestamp: row.get(6)?,
338                        end_timestamp: row.get(7)?,
339                        concurrency: row.get(8)?,
340                    })
341                })?
342                .collect::<Result<Vec<_>, _>>()?;
343
344            Ok(events)
345        })
346    }
347
348    /// Get aggregate project statistics for state transitions within a time range.
349    /// Returns counts of transitions per state and per agent.
350    pub fn get_project_state_stats(
351        &self,
352        from_timestamp: Option<i64>,
353        to_timestamp: Option<i64>,
354    ) -> Result<ProjectStateStats> {
355        self.with_conn(|conn| {
356            let mut transitions_by_status = std::collections::HashMap::new();
357            let mut time_by_status = std::collections::HashMap::new();
358            let mut transitions_by_agent = std::collections::HashMap::new();
359            let mut time_by_agent = std::collections::HashMap::new();
360            let mut tasks_touched = std::collections::HashSet::new();
361            let mut total_transitions = 0i64;
362            let mut total_time_ms = 0i64;
363
364            // Build base query - only count status transitions for stats
365            let mut sql = String::from(
366                "SELECT status, worker_id, task_id, timestamp, end_timestamp
367                 FROM task_sequence WHERE status IS NOT NULL",
368            );
369            let mut param_values: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
370
371            if let Some(from_ts) = from_timestamp {
372                sql.push_str(&format!(" AND timestamp >= ?{}", param_values.len() + 1));
373                param_values.push(Box::new(from_ts));
374            }
375
376            if let Some(to_ts) = to_timestamp {
377                sql.push_str(&format!(" AND timestamp <= ?{}", param_values.len() + 1));
378                param_values.push(Box::new(to_ts));
379            }
380
381            let mut stmt = conn.prepare(&sql)?;
382            let param_refs: Vec<&dyn rusqlite::ToSql> =
383                param_values.iter().map(|b| b.as_ref()).collect();
384
385            let mut rows = stmt.query(param_refs.as_slice())?;
386
387            while let Some(row) = rows.next()? {
388                let status: String = row.get(0)?;
389                let worker_id: Option<String> = row.get(1)?;
390                let task_id: String = row.get(2)?;
391                let timestamp: i64 = row.get(3)?;
392                let end_timestamp: Option<i64> = row.get(4)?;
393
394                total_transitions += 1;
395                tasks_touched.insert(task_id);
396
397                *transitions_by_status.entry(status.clone()).or_insert(0i64) += 1;
398
399                if let Some(ref agent) = worker_id {
400                    *transitions_by_agent.entry(agent.clone()).or_insert(0i64) += 1;
401                }
402
403                // Calculate duration if we have an end timestamp
404                if let Some(end_ts) = end_timestamp {
405                    let duration = end_ts - timestamp;
406                    total_time_ms += duration;
407                    *time_by_status.entry(status).or_insert(0i64) += duration;
408
409                    if let Some(agent) = worker_id {
410                        *time_by_agent.entry(agent).or_insert(0i64) += duration;
411                    }
412                }
413            }
414
415            Ok(ProjectStateStats {
416                total_transitions,
417                total_time_ms,
418                tasks_affected: tasks_touched.len() as i64,
419                transitions_by_status,
420                time_by_status_ms: time_by_status,
421                transitions_by_agent,
422                time_by_agent_ms: time_by_agent,
423            })
424        })
425    }
426}