Skip to main content

task_graph_mcp/db/
stats.rs

//! Aggregation queries that compute task statistics.
2
3use super::Database;
4use crate::config::StatesConfig;
5use crate::types::Stats;
6use anyhow::Result;
7use rusqlite::params;
8use std::collections::HashMap;
9
10impl Database {
11    /// Get aggregate statistics with dynamic state counting.
12    pub fn get_stats(
13        &self,
14        agent_id: Option<&str>,
15        task_id: Option<&str>,
16        states_config: &StatesConfig,
17    ) -> Result<Stats> {
18        self.with_conn(|conn| {
19            // First, get the base aggregate stats
20            let (base_sql, params_vec): (String, Vec<String>) = match (agent_id, task_id) {
21                (Some(aid), None) => (
22                    "SELECT
23                        COUNT(*) as total_tasks,
24                        COALESCE(SUM(points), 0) as total_points,
25                        0 as completed_points,
26                        COALESCE(SUM(time_estimate_ms), 0) as total_time_estimate_ms,
27                        COALESCE(SUM(time_actual_ms), 0) as total_time_actual_ms,
28                        COALESCE(SUM(cost_usd), 0.0) as total_cost_usd,
29                        COALESCE(SUM(metric_0), 0) as total_metric_0,
30                        COALESCE(SUM(metric_1), 0) as total_metric_1,
31                        COALESCE(SUM(metric_2), 0) as total_metric_2,
32                        COALESCE(SUM(metric_3), 0) as total_metric_3,
33                        COALESCE(SUM(metric_4), 0) as total_metric_4,
34                        COALESCE(SUM(metric_5), 0) as total_metric_5,
35                        COALESCE(SUM(metric_6), 0) as total_metric_6,
36                        COALESCE(SUM(metric_7), 0) as total_metric_7
37                    FROM tasks WHERE worker_id = ?1"
38                        .to_string(),
39                    vec![aid.to_string()],
40                ),
41                (None, Some(tid)) => (
42                    "WITH RECURSIVE descendants AS (
43                        SELECT id FROM tasks WHERE id = ?1
44                        UNION ALL
45                        SELECT dep.to_task_id FROM dependencies dep
46                        INNER JOIN descendants d ON dep.from_task_id = d.id
47                        WHERE dep.dep_type = 'contains'
48                    )
49                    SELECT
50                        COUNT(*) as total_tasks,
51                        COALESCE(SUM(points), 0) as total_points,
52                        0 as completed_points,
53                        COALESCE(SUM(time_estimate_ms), 0) as total_time_estimate_ms,
54                        COALESCE(SUM(time_actual_ms), 0) as total_time_actual_ms,
55                        COALESCE(SUM(cost_usd), 0.0) as total_cost_usd,
56                        COALESCE(SUM(metric_0), 0) as total_metric_0,
57                        COALESCE(SUM(metric_1), 0) as total_metric_1,
58                        COALESCE(SUM(metric_2), 0) as total_metric_2,
59                        COALESCE(SUM(metric_3), 0) as total_metric_3,
60                        COALESCE(SUM(metric_4), 0) as total_metric_4,
61                        COALESCE(SUM(metric_5), 0) as total_metric_5,
62                        COALESCE(SUM(metric_6), 0) as total_metric_6,
63                        COALESCE(SUM(metric_7), 0) as total_metric_7
64                    FROM tasks WHERE id IN (SELECT id FROM descendants)"
65                        .to_string(),
66                    vec![tid.to_string()],
67                ),
68                (Some(aid), Some(tid)) => (
69                    "WITH RECURSIVE descendants AS (
70                        SELECT id FROM tasks WHERE id = ?2
71                        UNION ALL
72                        SELECT dep.to_task_id FROM dependencies dep
73                        INNER JOIN descendants d ON dep.from_task_id = d.id
74                        WHERE dep.dep_type = 'contains'
75                    )
76                    SELECT
77                        COUNT(*) as total_tasks,
78                        COALESCE(SUM(points), 0) as total_points,
79                        0 as completed_points,
80                        COALESCE(SUM(time_estimate_ms), 0) as total_time_estimate_ms,
81                        COALESCE(SUM(time_actual_ms), 0) as total_time_actual_ms,
82                        COALESCE(SUM(cost_usd), 0.0) as total_cost_usd,
83                        COALESCE(SUM(metric_0), 0) as total_metric_0,
84                        COALESCE(SUM(metric_1), 0) as total_metric_1,
85                        COALESCE(SUM(metric_2), 0) as total_metric_2,
86                        COALESCE(SUM(metric_3), 0) as total_metric_3,
87                        COALESCE(SUM(metric_4), 0) as total_metric_4,
88                        COALESCE(SUM(metric_5), 0) as total_metric_5,
89                        COALESCE(SUM(metric_6), 0) as total_metric_6,
90                        COALESCE(SUM(metric_7), 0) as total_metric_7
91                    FROM tasks WHERE id IN (SELECT id FROM descendants) AND worker_id = ?1"
92                        .to_string(),
93                    vec![aid.to_string(), tid.to_string()],
94                ),
95                (None, None) => (
96                    "SELECT
97                        COUNT(*) as total_tasks,
98                        COALESCE(SUM(points), 0) as total_points,
99                        0 as completed_points,
100                        COALESCE(SUM(time_estimate_ms), 0) as total_time_estimate_ms,
101                        COALESCE(SUM(time_actual_ms), 0) as total_time_actual_ms,
102                        COALESCE(SUM(cost_usd), 0.0) as total_cost_usd,
103                        COALESCE(SUM(metric_0), 0) as total_metric_0,
104                        COALESCE(SUM(metric_1), 0) as total_metric_1,
105                        COALESCE(SUM(metric_2), 0) as total_metric_2,
106                        COALESCE(SUM(metric_3), 0) as total_metric_3,
107                        COALESCE(SUM(metric_4), 0) as total_metric_4,
108                        COALESCE(SUM(metric_5), 0) as total_metric_5,
109                        COALESCE(SUM(metric_6), 0) as total_metric_6,
110                        COALESCE(SUM(metric_7), 0) as total_metric_7
111                    FROM tasks"
112                        .to_string(),
113                    vec![],
114                ),
115            };
116
117            // Query base stats - returns 14 columns now
118            let (
119                total_tasks,
120                total_points,
121                _completed_points_placeholder,
122                total_time_estimate_ms,
123                total_time_actual_ms,
124                total_cost_usd,
125                m0, m1, m2, m3, m4, m5, m6, m7,
126            ): (i64, i64, i64, i64, i64, f64, i64, i64, i64, i64, i64, i64, i64, i64) = if params_vec.is_empty()
127            {
128                conn.query_row(&base_sql, [], |row| {
129                    Ok((
130                        row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?,
131                        row.get(4)?, row.get(5)?, row.get(6)?, row.get(7)?,
132                        row.get(8)?, row.get(9)?, row.get(10)?, row.get(11)?,
133                        row.get(12)?, row.get(13)?,
134                    ))
135                })?
136            } else if params_vec.len() == 1 {
137                conn.query_row(&base_sql, params![params_vec[0]], |row| {
138                    Ok((
139                        row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?,
140                        row.get(4)?, row.get(5)?, row.get(6)?, row.get(7)?,
141                        row.get(8)?, row.get(9)?, row.get(10)?, row.get(11)?,
142                        row.get(12)?, row.get(13)?,
143                    ))
144                })?
145            } else {
146                conn.query_row(&base_sql, params![params_vec[0], params_vec[1]], |row| {
147                    Ok((
148                        row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?,
149                        row.get(4)?, row.get(5)?, row.get(6)?, row.get(7)?,
150                        row.get(8)?, row.get(9)?, row.get(10)?, row.get(11)?,
151                        row.get(12)?, row.get(13)?,
152                    ))
153                })?
154            };
155
156            // Now query task counts by state
157            let count_sql = match (agent_id, task_id) {
158                (Some(_aid), None) => {
159                    "SELECT status, COUNT(*) as cnt FROM tasks WHERE worker_id = ?1 GROUP BY status"
160                }
161                (None, Some(_tid)) => {
162                    "WITH RECURSIVE descendants AS (
163                        SELECT id FROM tasks WHERE id = ?1
164                        UNION ALL
165                        SELECT dep.to_task_id FROM dependencies dep
166                        INNER JOIN descendants d ON dep.from_task_id = d.id
167                        WHERE dep.dep_type = 'contains'
168                    )
169                    SELECT status, COUNT(*) as cnt FROM tasks
170                    WHERE id IN (SELECT id FROM descendants) GROUP BY status"
171                }
172                (Some(_aid), Some(_tid)) => {
173                    "WITH RECURSIVE descendants AS (
174                        SELECT id FROM tasks WHERE id = ?2
175                        UNION ALL
176                        SELECT dep.to_task_id FROM dependencies dep
177                        INNER JOIN descendants d ON dep.from_task_id = d.id
178                        WHERE dep.dep_type = 'contains'
179                    )
180                    SELECT status, COUNT(*) as cnt FROM tasks
181                    WHERE id IN (SELECT id FROM descendants) AND worker_id = ?1 GROUP BY status"
182                }
183                (None, None) => "SELECT status, COUNT(*) as cnt FROM tasks GROUP BY status",
184            };
185
186            let mut tasks_by_status: HashMap<String, i64> = HashMap::new();
187
188            // Initialize all defined states to 0
189            for state in states_config.state_names() {
190                tasks_by_status.insert(state.to_string(), 0);
191            }
192
193            // Query and fill in actual counts
194            let mut stmt = conn.prepare(count_sql)?;
195            let status_counts: Vec<(String, i64)> = if params_vec.is_empty() {
196                stmt.query_map([], |row| {
197                    let status: String = row.get(0)?;
198                    let count: i64 = row.get(1)?;
199                    Ok((status, count))
200                })?.filter_map(|r| r.ok()).collect()
201            } else if params_vec.len() == 1 {
202                stmt.query_map(params![params_vec[0].clone()], |row| {
203                    let status: String = row.get(0)?;
204                    let count: i64 = row.get(1)?;
205                    Ok((status, count))
206                })?.filter_map(|r| r.ok()).collect()
207            } else {
208                stmt.query_map(params![params_vec[0].clone(), params_vec[1].clone()], |row| {
209                    let status: String = row.get(0)?;
210                    let count: i64 = row.get(1)?;
211                    Ok((status, count))
212                })?.filter_map(|r| r.ok()).collect()
213            };
214
215            for (status, count) in status_counts {
216                tasks_by_status.insert(status, count);
217            }
218
219            // Calculate completed_points (points for tasks in non-blocking states)
220            let completed_points_sql = match (agent_id, task_id) {
221                (Some(_aid), None) => {
222                    "SELECT COALESCE(SUM(points), 0) FROM tasks 
223                     WHERE worker_id = ?1 AND status NOT IN (SELECT value FROM json_each(?2))"
224                }
225                (None, Some(_tid)) => {
226                    "WITH RECURSIVE descendants AS (
227                        SELECT id FROM tasks WHERE id = ?1
228                        UNION ALL
229                        SELECT dep.to_task_id FROM dependencies dep
230                        INNER JOIN descendants d ON dep.from_task_id = d.id
231                        WHERE dep.dep_type = 'contains'
232                    )
233                    SELECT COALESCE(SUM(points), 0) FROM tasks
234                    WHERE id IN (SELECT id FROM descendants)
235                    AND status NOT IN (SELECT value FROM json_each(?2))"
236                }
237                (Some(_aid), Some(_tid)) => {
238                    "WITH RECURSIVE descendants AS (
239                        SELECT id FROM tasks WHERE id = ?2
240                        UNION ALL
241                        SELECT dep.to_task_id FROM dependencies dep
242                        INNER JOIN descendants d ON dep.from_task_id = d.id
243                        WHERE dep.dep_type = 'contains'
244                    )
245                    SELECT COALESCE(SUM(points), 0) FROM tasks
246                    WHERE id IN (SELECT id FROM descendants) AND worker_id = ?1
247                    AND status NOT IN (SELECT value FROM json_each(?3))"
248                }
249                (None, None) => {
250                    "SELECT COALESCE(SUM(points), 0) FROM tasks 
251                     WHERE status NOT IN (SELECT value FROM json_each(?1))"
252                }
253            };
254
255            let blocking_states_json = serde_json::to_string(&states_config.blocking_states)?;
256
257            let completed_points: i64 = match (agent_id, task_id) {
258                (Some(aid), None) => conn.query_row(
259                    completed_points_sql,
260                    params![aid, blocking_states_json],
261                    |row| row.get(0),
262                )?,
263                (None, Some(tid)) => conn.query_row(
264                    completed_points_sql,
265                    params![tid, blocking_states_json],
266                    |row| row.get(0),
267                )?,
268                (Some(aid), Some(tid)) => conn.query_row(
269                    completed_points_sql,
270                    params![aid, tid, blocking_states_json],
271                    |row| row.get(0),
272                )?,
273                (None, None) => {
274                    conn.query_row(completed_points_sql, params![blocking_states_json], |row| {
275                        row.get(0)
276                    })?
277                }
278            };
279
280            Ok(Stats {
281                total_tasks,
282                tasks_by_status,
283                total_points,
284                completed_points,
285                total_time_estimate_ms,
286                total_time_actual_ms,
287                total_cost_usd,
288                total_metrics: [m0, m1, m2, m3, m4, m5, m6, m7],
289            })
290        })
291    }
292}