Skip to main content

task_graph_mcp/db/
stats.rs

1//! Aggregation queries for statistics.
2
3use super::Database;
4use crate::config::StatesConfig;
5use crate::types::Stats;
6use anyhow::Result;
7use rusqlite::params;
8use std::collections::HashMap;
9
10impl Database {
11    /// Get aggregate statistics with dynamic state counting.
12    pub fn get_stats(
13        &self,
14        agent_id: Option<&str>,
15        task_id: Option<&str>,
16        states_config: &StatesConfig,
17    ) -> Result<Stats> {
18        self.with_conn(|conn| {
19            // First, get the base aggregate stats
20            let (base_sql, params_vec): (String, Vec<String>) = match (agent_id, task_id) {
21                (Some(aid), None) => (
22                    "SELECT
23                        COUNT(*) as total_tasks,
24                        COALESCE(SUM(points), 0) as total_points,
25                        0 as completed_points,
26                        COALESCE(SUM(time_estimate_ms), 0) as total_time_estimate_ms,
27                        COALESCE(SUM(time_actual_ms), 0) as total_time_actual_ms,
28                        COALESCE(SUM(cost_usd), 0.0) as total_cost_usd,
29                        COALESCE(SUM(metric_0), 0) as total_metric_0,
30                        COALESCE(SUM(metric_1), 0) as total_metric_1,
31                        COALESCE(SUM(metric_2), 0) as total_metric_2,
32                        COALESCE(SUM(metric_3), 0) as total_metric_3,
33                        COALESCE(SUM(metric_4), 0) as total_metric_4,
34                        COALESCE(SUM(metric_5), 0) as total_metric_5,
35                        COALESCE(SUM(metric_6), 0) as total_metric_6,
36                        COALESCE(SUM(metric_7), 0) as total_metric_7
37                    FROM tasks WHERE worker_id = ?1"
38                        .to_string(),
39                    vec![aid.to_string()],
40                ),
41                (None, Some(tid)) => (
42                    "WITH RECURSIVE descendants AS (
43                        SELECT id FROM tasks WHERE id = ?1
44                        UNION ALL
45                        SELECT dep.to_task_id FROM dependencies dep
46                        INNER JOIN descendants d ON dep.from_task_id = d.id
47                        WHERE dep.dep_type = 'contains'
48                    )
49                    SELECT
50                        COUNT(*) as total_tasks,
51                        COALESCE(SUM(points), 0) as total_points,
52                        0 as completed_points,
53                        COALESCE(SUM(time_estimate_ms), 0) as total_time_estimate_ms,
54                        COALESCE(SUM(time_actual_ms), 0) as total_time_actual_ms,
55                        COALESCE(SUM(cost_usd), 0.0) as total_cost_usd,
56                        COALESCE(SUM(metric_0), 0) as total_metric_0,
57                        COALESCE(SUM(metric_1), 0) as total_metric_1,
58                        COALESCE(SUM(metric_2), 0) as total_metric_2,
59                        COALESCE(SUM(metric_3), 0) as total_metric_3,
60                        COALESCE(SUM(metric_4), 0) as total_metric_4,
61                        COALESCE(SUM(metric_5), 0) as total_metric_5,
62                        COALESCE(SUM(metric_6), 0) as total_metric_6,
63                        COALESCE(SUM(metric_7), 0) as total_metric_7
64                    FROM tasks WHERE id IN (SELECT id FROM descendants)"
65                        .to_string(),
66                    vec![tid.to_string()],
67                ),
68                (Some(aid), Some(tid)) => (
69                    "WITH RECURSIVE descendants AS (
70                        SELECT id FROM tasks WHERE id = ?2
71                        UNION ALL
72                        SELECT dep.to_task_id FROM dependencies dep
73                        INNER JOIN descendants d ON dep.from_task_id = d.id
74                        WHERE dep.dep_type = 'contains'
75                    )
76                    SELECT
77                        COUNT(*) as total_tasks,
78                        COALESCE(SUM(points), 0) as total_points,
79                        0 as completed_points,
80                        COALESCE(SUM(time_estimate_ms), 0) as total_time_estimate_ms,
81                        COALESCE(SUM(time_actual_ms), 0) as total_time_actual_ms,
82                        COALESCE(SUM(cost_usd), 0.0) as total_cost_usd,
83                        COALESCE(SUM(metric_0), 0) as total_metric_0,
84                        COALESCE(SUM(metric_1), 0) as total_metric_1,
85                        COALESCE(SUM(metric_2), 0) as total_metric_2,
86                        COALESCE(SUM(metric_3), 0) as total_metric_3,
87                        COALESCE(SUM(metric_4), 0) as total_metric_4,
88                        COALESCE(SUM(metric_5), 0) as total_metric_5,
89                        COALESCE(SUM(metric_6), 0) as total_metric_6,
90                        COALESCE(SUM(metric_7), 0) as total_metric_7
91                    FROM tasks WHERE id IN (SELECT id FROM descendants) AND worker_id = ?1"
92                        .to_string(),
93                    vec![aid.to_string(), tid.to_string()],
94                ),
95                (None, None) => (
96                    "SELECT
97                        COUNT(*) as total_tasks,
98                        COALESCE(SUM(points), 0) as total_points,
99                        0 as completed_points,
100                        COALESCE(SUM(time_estimate_ms), 0) as total_time_estimate_ms,
101                        COALESCE(SUM(time_actual_ms), 0) as total_time_actual_ms,
102                        COALESCE(SUM(cost_usd), 0.0) as total_cost_usd,
103                        COALESCE(SUM(metric_0), 0) as total_metric_0,
104                        COALESCE(SUM(metric_1), 0) as total_metric_1,
105                        COALESCE(SUM(metric_2), 0) as total_metric_2,
106                        COALESCE(SUM(metric_3), 0) as total_metric_3,
107                        COALESCE(SUM(metric_4), 0) as total_metric_4,
108                        COALESCE(SUM(metric_5), 0) as total_metric_5,
109                        COALESCE(SUM(metric_6), 0) as total_metric_6,
110                        COALESCE(SUM(metric_7), 0) as total_metric_7
111                    FROM tasks"
112                        .to_string(),
113                    vec![],
114                ),
115            };
116
117            // Query base stats - returns 14 columns now
118            #[allow(clippy::type_complexity)]
119            let (
120                total_tasks,
121                total_points,
122                _completed_points_placeholder,
123                total_time_estimate_ms,
124                total_time_actual_ms,
125                total_cost_usd,
126                m0,
127                m1,
128                m2,
129                m3,
130                m4,
131                m5,
132                m6,
133                m7,
134            ): (
135                i64,
136                i64,
137                i64,
138                i64,
139                i64,
140                f64,
141                i64,
142                i64,
143                i64,
144                i64,
145                i64,
146                i64,
147                i64,
148                i64,
149            ) = if params_vec.is_empty() {
150                conn.query_row(&base_sql, [], |row| {
151                    Ok((
152                        row.get(0)?,
153                        row.get(1)?,
154                        row.get(2)?,
155                        row.get(3)?,
156                        row.get(4)?,
157                        row.get(5)?,
158                        row.get(6)?,
159                        row.get(7)?,
160                        row.get(8)?,
161                        row.get(9)?,
162                        row.get(10)?,
163                        row.get(11)?,
164                        row.get(12)?,
165                        row.get(13)?,
166                    ))
167                })?
168            } else if params_vec.len() == 1 {
169                conn.query_row(&base_sql, params![params_vec[0]], |row| {
170                    Ok((
171                        row.get(0)?,
172                        row.get(1)?,
173                        row.get(2)?,
174                        row.get(3)?,
175                        row.get(4)?,
176                        row.get(5)?,
177                        row.get(6)?,
178                        row.get(7)?,
179                        row.get(8)?,
180                        row.get(9)?,
181                        row.get(10)?,
182                        row.get(11)?,
183                        row.get(12)?,
184                        row.get(13)?,
185                    ))
186                })?
187            } else {
188                conn.query_row(&base_sql, params![params_vec[0], params_vec[1]], |row| {
189                    Ok((
190                        row.get(0)?,
191                        row.get(1)?,
192                        row.get(2)?,
193                        row.get(3)?,
194                        row.get(4)?,
195                        row.get(5)?,
196                        row.get(6)?,
197                        row.get(7)?,
198                        row.get(8)?,
199                        row.get(9)?,
200                        row.get(10)?,
201                        row.get(11)?,
202                        row.get(12)?,
203                        row.get(13)?,
204                    ))
205                })?
206            };
207
208            // Now query task counts by state
209            let count_sql = match (agent_id, task_id) {
210                (Some(_aid), None) => {
211                    "SELECT status, COUNT(*) as cnt FROM tasks WHERE worker_id = ?1 GROUP BY status"
212                }
213                (None, Some(_tid)) => {
214                    "WITH RECURSIVE descendants AS (
215                        SELECT id FROM tasks WHERE id = ?1
216                        UNION ALL
217                        SELECT dep.to_task_id FROM dependencies dep
218                        INNER JOIN descendants d ON dep.from_task_id = d.id
219                        WHERE dep.dep_type = 'contains'
220                    )
221                    SELECT status, COUNT(*) as cnt FROM tasks
222                    WHERE id IN (SELECT id FROM descendants) GROUP BY status"
223                }
224                (Some(_aid), Some(_tid)) => {
225                    "WITH RECURSIVE descendants AS (
226                        SELECT id FROM tasks WHERE id = ?2
227                        UNION ALL
228                        SELECT dep.to_task_id FROM dependencies dep
229                        INNER JOIN descendants d ON dep.from_task_id = d.id
230                        WHERE dep.dep_type = 'contains'
231                    )
232                    SELECT status, COUNT(*) as cnt FROM tasks
233                    WHERE id IN (SELECT id FROM descendants) AND worker_id = ?1 GROUP BY status"
234                }
235                (None, None) => "SELECT status, COUNT(*) as cnt FROM tasks GROUP BY status",
236            };
237
238            let mut tasks_by_status: HashMap<String, i64> = HashMap::new();
239
240            // Initialize all defined states to 0
241            for state in states_config.state_names() {
242                tasks_by_status.insert(state.to_string(), 0);
243            }
244
245            // Query and fill in actual counts
246            let mut stmt = conn.prepare(count_sql)?;
247            let status_counts: Vec<(String, i64)> = if params_vec.is_empty() {
248                stmt.query_map([], |row| {
249                    let status: String = row.get(0)?;
250                    let count: i64 = row.get(1)?;
251                    Ok((status, count))
252                })?
253                .filter_map(|r| r.ok())
254                .collect()
255            } else if params_vec.len() == 1 {
256                stmt.query_map(params![params_vec[0].clone()], |row| {
257                    let status: String = row.get(0)?;
258                    let count: i64 = row.get(1)?;
259                    Ok((status, count))
260                })?
261                .filter_map(|r| r.ok())
262                .collect()
263            } else {
264                stmt.query_map(
265                    params![params_vec[0].clone(), params_vec[1].clone()],
266                    |row| {
267                        let status: String = row.get(0)?;
268                        let count: i64 = row.get(1)?;
269                        Ok((status, count))
270                    },
271                )?
272                .filter_map(|r| r.ok())
273                .collect()
274            };
275
276            for (status, count) in status_counts {
277                tasks_by_status.insert(status, count);
278            }
279
280            // Calculate completed_points (points for tasks in non-blocking states)
281            let completed_points_sql = match (agent_id, task_id) {
282                (Some(_aid), None) => {
283                    "SELECT COALESCE(SUM(points), 0) FROM tasks 
284                     WHERE worker_id = ?1 AND status NOT IN (SELECT value FROM json_each(?2))"
285                }
286                (None, Some(_tid)) => {
287                    "WITH RECURSIVE descendants AS (
288                        SELECT id FROM tasks WHERE id = ?1
289                        UNION ALL
290                        SELECT dep.to_task_id FROM dependencies dep
291                        INNER JOIN descendants d ON dep.from_task_id = d.id
292                        WHERE dep.dep_type = 'contains'
293                    )
294                    SELECT COALESCE(SUM(points), 0) FROM tasks
295                    WHERE id IN (SELECT id FROM descendants)
296                    AND status NOT IN (SELECT value FROM json_each(?2))"
297                }
298                (Some(_aid), Some(_tid)) => {
299                    "WITH RECURSIVE descendants AS (
300                        SELECT id FROM tasks WHERE id = ?2
301                        UNION ALL
302                        SELECT dep.to_task_id FROM dependencies dep
303                        INNER JOIN descendants d ON dep.from_task_id = d.id
304                        WHERE dep.dep_type = 'contains'
305                    )
306                    SELECT COALESCE(SUM(points), 0) FROM tasks
307                    WHERE id IN (SELECT id FROM descendants) AND worker_id = ?1
308                    AND status NOT IN (SELECT value FROM json_each(?3))"
309                }
310                (None, None) => {
311                    "SELECT COALESCE(SUM(points), 0) FROM tasks 
312                     WHERE status NOT IN (SELECT value FROM json_each(?1))"
313                }
314            };
315
316            let blocking_states_json = serde_json::to_string(&states_config.blocking_states)?;
317
318            let completed_points: i64 = match (agent_id, task_id) {
319                (Some(aid), None) => conn.query_row(
320                    completed_points_sql,
321                    params![aid, blocking_states_json],
322                    |row| row.get(0),
323                )?,
324                (None, Some(tid)) => conn.query_row(
325                    completed_points_sql,
326                    params![tid, blocking_states_json],
327                    |row| row.get(0),
328                )?,
329                (Some(aid), Some(tid)) => conn.query_row(
330                    completed_points_sql,
331                    params![aid, tid, blocking_states_json],
332                    |row| row.get(0),
333                )?,
334                (None, None) => {
335                    conn.query_row(completed_points_sql, params![blocking_states_json], |row| {
336                        row.get(0)
337                    })?
338                }
339            };
340
341            Ok(Stats {
342                total_tasks,
343                tasks_by_status,
344                total_points,
345                completed_points,
346                total_time_estimate_ms,
347                total_time_actual_ms,
348                total_cost_usd,
349                total_metrics: [m0, m1, m2, m3, m4, m5, m6, m7],
350            })
351        })
352    }
353}