Skip to main content

rustvello_sqlite/orchestrator/
concurrency.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use rustvello_core::error::RustvelloResult;
6use rustvello_core::orchestrator::OrchestratorConcurrency;
7use rustvello_proto::call::SerializedArguments;
8use rustvello_proto::config::TaskConfig;
9use rustvello_proto::identifiers::{InvocationId, TaskId};
10use rustvello_proto::status::ConcurrencyControlType;
11
12use crate::db::{blocking, lock_err, sql_err};
13
14use super::SqliteOrchestrator;
15
16#[async_trait]
17impl OrchestratorConcurrency for SqliteOrchestrator {
18    async fn check_running_concurrency(
19        &self,
20        task_id: &TaskId,
21        task_config: &TaskConfig,
22        cc_args: Option<&SerializedArguments>,
23    ) -> RustvelloResult<bool> {
24        let db = Arc::clone(&self.db);
25        let task_id = task_id.clone();
26        let task_config = task_config.clone();
27        let cc_args = cc_args.cloned();
28        blocking(move || {
29            if task_config.concurrency_control == ConcurrencyControlType::Unlimited {
30                return Ok(true);
31            }
32
33            let conn = db.conn.lock().map_err(lock_err)?;
34            let task_key = task_id.to_string();
35
36            let count: i64 = match cc_args {
37                Some(args) => {
38                    // Arg-level CC: per-pair intersection via GROUP BY/HAVING
39                    let pairs = args.cc_arg_pairs();
40                    let n_pairs = pairs.len();
41                    let pair_conds: Vec<String> = (0..pairs.len())
42                        .map(|i| {
43                            format!(
44                                "(cp.arg_key = ?{} AND cp.arg_value = ?{})",
45                                i * 2 + 2,
46                                i * 2 + 3
47                            )
48                        })
49                        .collect();
50                    let where_pairs = pair_conds.join(" OR ");
51                    let sql = format!(
52                        "SELECT COUNT(*) FROM (
53                             SELECT cp.invocation_id FROM cc_arg_pairs cp
54                             JOIN invocations i ON cp.invocation_id = i.invocation_id
55                             WHERE cp.task_id = ?1 AND ({where_pairs})
56                               AND i.status IN ('PENDING', 'RUNNING')
57                             GROUP BY cp.invocation_id
58                             HAVING COUNT(*) = {n_pairs}
59                         )"
60                    );
61                    let mut params: Vec<String> = vec![task_key];
62                    for (k, v) in &pairs {
63                        params.push(k.clone());
64                        params.push(v.clone());
65                    }
66                    conn.query_row(&sql, rusqlite::params_from_iter(params.iter()), |row| {
67                        row.get(0)
68                    })
69                    .map_err(sql_err)?
70                }
71                None => {
72                    // Task-level CC: count from invocations table directly
73                    conn.query_row(
74                        "SELECT COUNT(*) FROM invocations
75                         WHERE task_id = ?1 AND status IN ('PENDING', 'RUNNING')",
76                        rusqlite::params![&task_key],
77                        |row| row.get(0),
78                    )
79                    .map_err(sql_err)?
80                }
81            };
82
83            let limit = task_config.running_concurrency.unwrap_or(1) as i64;
84            Ok(count < limit)
85        })
86        .await
87    }
88
89    async fn index_for_concurrency_control(
90        &self,
91        invocation_id: &InvocationId,
92        task_id: &TaskId,
93        cc_args: Option<&SerializedArguments>,
94    ) -> RustvelloResult<()> {
95        let db = Arc::clone(&self.db);
96        let invocation_id = invocation_id.clone();
97        let task_id = task_id.clone();
98        let cc_args = cc_args.cloned();
99        blocking(move || {
100
101            let Some(args) = cc_args else {
102                return Ok(());
103            };
104            let conn = db.conn.lock().map_err(lock_err)?;
105            let task_key = task_id.to_string();
106            let pairs = args.cc_arg_pairs();
107
108            for (k, v) in &pairs {
109                conn.execute(
110                    "INSERT OR REPLACE INTO cc_arg_pairs (invocation_id, task_id, arg_key, arg_value)
111                     VALUES (?1, ?2, ?3, ?4)",
112                    rusqlite::params![invocation_id.as_str(), &task_key, k, v],
113                )
114                .map_err(sql_err)?;
115            }
116
117            Ok(())
118
119        })
120        .await
121    }
122
123    async fn remove_from_concurrency_index(
124        &self,
125        invocation_id: &InvocationId,
126    ) -> RustvelloResult<()> {
127        let db = Arc::clone(&self.db);
128        let invocation_id = invocation_id.clone();
129        blocking(move || {
130            let conn = db.conn.lock().map_err(lock_err)?;
131
132            conn.execute(
133                "DELETE FROM cc_arg_pairs WHERE invocation_id = ?1",
134                [invocation_id.as_str()],
135            )
136            .map_err(sql_err)?;
137
138            Ok(())
139        })
140        .await
141    }
142}