Skip to main content

task_graph_mcp/db/
state_transitions.rs

1//! State transition tracking for automatic time accumulation.
2
3use crate::config::StatesConfig;
4use crate::db::{Database, now_ms};
5use crate::types::TaskStateEvent;
6use anyhow::Result;
7use rusqlite::{Connection, params};
8
9/// Record a state transition and accumulate time if transitioning from a timed state.
10///
11/// Returns the elapsed time added to time_actual_ms (0 if previous state was not timed).
12pub(crate) fn record_state_transition(
13    conn: &Connection,
14    task_id: &str,
15    to_status: &str,
16    worker_id: Option<&str>,
17    reason: Option<&str>,
18    states_config: &StatesConfig,
19) -> Result<i64> {
20    let now = now_ms();
21    let mut elapsed_added = 0i64;
22
23    // Find and close any open transition for this task
24    let open_transition: Option<(i64, String, i64)> = conn
25        .query_row(
26            "SELECT id, event, timestamp FROM task_state_sequence
27             WHERE task_id = ?1 AND end_timestamp IS NULL
28             ORDER BY id DESC LIMIT 1",
29            params![task_id],
30            |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
31        )
32        .ok();
33
34    if let Some((open_id, prev_event_str, start_timestamp)) = open_transition {
35        // Close the previous transition
36        conn.execute(
37            "UPDATE task_state_sequence SET end_timestamp = ?1 WHERE id = ?2",
38            params![now, open_id],
39        )?;
40
41        // If previous state was a timed state, accumulate elapsed time
42        if states_config.is_timed_state(&prev_event_str) {
43            elapsed_added = now - start_timestamp;
44
45            // Add elapsed time to task's time_actual_ms
46            conn.execute(
47                "UPDATE tasks SET time_actual_ms = COALESCE(time_actual_ms, 0) + ?1, updated_at = ?2
48                 WHERE id = ?3",
49                params![elapsed_added, now, task_id],
50            )?;
51        }
52    }
53
54    // Insert the new transition
55    conn.execute(
56        "INSERT INTO task_state_sequence (task_id, worker_id, event, reason, timestamp)
57         VALUES (?1, ?2, ?3, ?4, ?5)",
58        params![task_id, worker_id, to_status, reason, now],
59    )?;
60
61    Ok(elapsed_added)
62}
63
64/// Statistics for project-wide state transitions.
65#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
66pub struct ProjectStateStats {
67    pub total_transitions: i64,
68    pub total_time_ms: i64,
69    pub tasks_affected: i64,
70    pub transitions_by_status: std::collections::HashMap<String, i64>,
71    pub time_by_status_ms: std::collections::HashMap<String, i64>,
72    pub transitions_by_agent: std::collections::HashMap<String, i64>,
73    pub time_by_agent_ms: std::collections::HashMap<String, i64>,
74}
75
76impl Database {
77    /// Get the state transition history for a task.
78    pub fn get_task_state_history(&self, task_id: &str) -> Result<Vec<TaskStateEvent>> {
79        self.with_conn(|conn| {
80            let mut stmt = conn.prepare(
81                "SELECT id, task_id, worker_id, event, reason, timestamp, end_timestamp
82                 FROM task_state_sequence
83                 WHERE task_id = ?1
84                 ORDER BY id ASC",
85            )?;
86
87            let events = stmt
88                .query_map(params![task_id], |row| {
89                    Ok(TaskStateEvent {
90                        id: row.get(0)?,
91                        task_id: row.get(1)?,
92                        worker_id: row.get(2)?,
93                        event: row.get(3)?,
94                        reason: row.get(4)?,
95                        timestamp: row.get(5)?,
96                        end_timestamp: row.get(6)?,
97                    })
98                })?
99                .collect::<Result<Vec<_>, _>>()?;
100
101            Ok(events)
102        })
103    }
104
105    /// Get the current duration in the current state (for active time tracking).
106    /// Only returns a duration if the current state is a timed state.
107    pub fn get_current_state_duration(
108        &self,
109        task_id: &str,
110        states_config: &StatesConfig,
111    ) -> Result<Option<i64>> {
112        self.with_conn(|conn| {
113            let result: Option<(String, i64)> = conn
114                .query_row(
115                    "SELECT event, timestamp FROM task_state_sequence
116                     WHERE task_id = ?1 AND end_timestamp IS NULL
117                     ORDER BY id DESC LIMIT 1",
118                    params![task_id],
119                    |row| Ok((row.get(0)?, row.get(1)?)),
120                )
121                .ok();
122
123            match result {
124                Some((event_str, start_timestamp)) => {
125                    if states_config.is_timed_state(&event_str) {
126                        return Ok(Some(now_ms() - start_timestamp));
127                    }
128                    Ok(None)
129                }
130                None => Ok(None),
131            }
132        })
133    }
134
135    /// Get project-wide state transition history with optional time range filter.
136    /// Returns all state transitions across all tasks within the specified time range.
137    pub fn get_project_state_history(
138        &self,
139        from_timestamp: Option<i64>,
140        to_timestamp: Option<i64>,
141        state_filter: Option<&[String]>,
142        limit: Option<i64>,
143    ) -> Result<Vec<TaskStateEvent>> {
144        self.with_conn(|conn| {
145            // Build query dynamically based on filters
146            let mut sql = String::from(
147                "SELECT id, task_id, worker_id, event, reason, timestamp, end_timestamp
148                 FROM task_state_sequence WHERE 1=1",
149            );
150            let mut param_values: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
151
152            if let Some(from_ts) = from_timestamp {
153                sql.push_str(&format!(" AND timestamp >= ?{}", param_values.len() + 1));
154                param_values.push(Box::new(from_ts));
155            }
156
157            if let Some(to_ts) = to_timestamp {
158                sql.push_str(&format!(" AND timestamp <= ?{}", param_values.len() + 1));
159                param_values.push(Box::new(to_ts));
160            }
161
162            if let Some(states) = state_filter
163                && !states.is_empty() {
164                    let placeholders: Vec<String> = states
165                        .iter()
166                        .enumerate()
167                        .map(|(i, _)| format!("?{}", param_values.len() + i + 1))
168                        .collect();
169                    sql.push_str(&format!(" AND event IN ({})", placeholders.join(", ")));
170                    for state in states {
171                        param_values.push(Box::new(state.clone()));
172                    }
173                }
174
175            sql.push_str(" ORDER BY timestamp DESC, id DESC");
176
177            if let Some(lim) = limit {
178                sql.push_str(&format!(" LIMIT ?{}", param_values.len() + 1));
179                param_values.push(Box::new(lim));
180            }
181
182            let mut stmt = conn.prepare(&sql)?;
183
184            // Convert Vec<Box<dyn ToSql>> to slice of references
185            let param_refs: Vec<&dyn rusqlite::ToSql> =
186                param_values.iter().map(|b| b.as_ref()).collect();
187
188            let events = stmt
189                .query_map(param_refs.as_slice(), |row| {
190                    Ok(TaskStateEvent {
191                        id: row.get(0)?,
192                        task_id: row.get(1)?,
193                        worker_id: row.get(2)?,
194                        event: row.get(3)?,
195                        reason: row.get(4)?,
196                        timestamp: row.get(5)?,
197                        end_timestamp: row.get(6)?,
198                    })
199                })?
200                .collect::<Result<Vec<_>, _>>()?;
201
202            Ok(events)
203        })
204    }
205
206    /// Get aggregate project statistics for state transitions within a time range.
207    /// Returns counts of transitions per state and per agent.
208    pub fn get_project_state_stats(
209        &self,
210        from_timestamp: Option<i64>,
211        to_timestamp: Option<i64>,
212    ) -> Result<ProjectStateStats> {
213        self.with_conn(|conn| {
214            let mut transitions_by_status = std::collections::HashMap::new();
215            let mut time_by_status = std::collections::HashMap::new();
216            let mut transitions_by_agent = std::collections::HashMap::new();
217            let mut time_by_agent = std::collections::HashMap::new();
218            let mut tasks_touched = std::collections::HashSet::new();
219            let mut total_transitions = 0i64;
220            let mut total_time_ms = 0i64;
221
222            // Build base query
223            let mut sql = String::from(
224                "SELECT event, worker_id, task_id, timestamp, end_timestamp FROM task_state_sequence WHERE 1=1"
225            );
226            let mut param_values: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
227
228            if let Some(from_ts) = from_timestamp {
229                sql.push_str(&format!(" AND timestamp >= ?{}", param_values.len() + 1));
230                param_values.push(Box::new(from_ts));
231            }
232
233            if let Some(to_ts) = to_timestamp {
234                sql.push_str(&format!(" AND timestamp <= ?{}", param_values.len() + 1));
235                param_values.push(Box::new(to_ts));
236            }
237
238            let mut stmt = conn.prepare(&sql)?;
239            let param_refs: Vec<&dyn rusqlite::ToSql> = param_values.iter().map(|b| b.as_ref()).collect();
240
241            let mut rows = stmt.query(param_refs.as_slice())?;
242
243            while let Some(row) = rows.next()? {
244                let event: String = row.get(0)?;
245                let worker_id: Option<String> = row.get(1)?;
246                let task_id: String = row.get(2)?;
247                let timestamp: i64 = row.get(3)?;
248                let end_timestamp: Option<i64> = row.get(4)?;
249
250                total_transitions += 1;
251                tasks_touched.insert(task_id);
252
253                *transitions_by_status.entry(event.clone()).or_insert(0i64) += 1;
254
255                if let Some(ref agent) = worker_id {
256                    *transitions_by_agent.entry(agent.clone()).or_insert(0i64) += 1;
257                }
258
259                // Calculate duration if we have an end timestamp
260                if let Some(end_ts) = end_timestamp {
261                    let duration = end_ts - timestamp;
262                    total_time_ms += duration;
263                    *time_by_status.entry(event).or_insert(0i64) += duration;
264
265                    if let Some(agent) = worker_id {
266                        *time_by_agent.entry(agent).or_insert(0i64) += duration;
267                    }
268                }
269            }
270
271            Ok(ProjectStateStats {
272                total_transitions,
273                total_time_ms,
274                tasks_affected: tasks_touched.len() as i64,
275                transitions_by_status,
276                time_by_status_ms: time_by_status,
277                transitions_by_agent,
278                time_by_agent_ms: time_by_agent,
279            })
280        })
281    }
282}