rustvello_sqlite/orchestrator/
concurrency.rs1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use rustvello_core::error::RustvelloResult;
6use rustvello_core::orchestrator::OrchestratorConcurrency;
7use rustvello_proto::call::SerializedArguments;
8use rustvello_proto::config::TaskConfig;
9use rustvello_proto::identifiers::{InvocationId, TaskId};
10use rustvello_proto::status::ConcurrencyControlType;
11
12use crate::db::{blocking, lock_err, sql_err};
13
14use super::SqliteOrchestrator;
15
16#[async_trait]
17impl OrchestratorConcurrency for SqliteOrchestrator {
18 async fn check_running_concurrency(
19 &self,
20 task_id: &TaskId,
21 task_config: &TaskConfig,
22 cc_args: Option<&SerializedArguments>,
23 ) -> RustvelloResult<bool> {
24 let db = Arc::clone(&self.db);
25 let task_id = task_id.clone();
26 let task_config = task_config.clone();
27 let cc_args = cc_args.cloned();
28 blocking(move || {
29 if task_config.concurrency_control == ConcurrencyControlType::Unlimited {
30 return Ok(true);
31 }
32
33 let conn = db.conn.lock().map_err(lock_err)?;
34 let task_key = task_id.to_string();
35
36 let count: i64 = match cc_args {
37 Some(args) => {
38 let pairs = args.cc_arg_pairs();
40 let n_pairs = pairs.len();
41 let pair_conds: Vec<String> = (0..pairs.len())
42 .map(|i| {
43 format!(
44 "(cp.arg_key = ?{} AND cp.arg_value = ?{})",
45 i * 2 + 2,
46 i * 2 + 3
47 )
48 })
49 .collect();
50 let where_pairs = pair_conds.join(" OR ");
51 let sql = format!(
52 "SELECT COUNT(*) FROM (
53 SELECT cp.invocation_id FROM cc_arg_pairs cp
54 JOIN invocations i ON cp.invocation_id = i.invocation_id
55 WHERE cp.task_id = ?1 AND ({where_pairs})
56 AND i.status IN ('PENDING', 'RUNNING')
57 GROUP BY cp.invocation_id
58 HAVING COUNT(*) = {n_pairs}
59 )"
60 );
61 let mut params: Vec<String> = vec![task_key];
62 for (k, v) in &pairs {
63 params.push(k.clone());
64 params.push(v.clone());
65 }
66 conn.query_row(&sql, rusqlite::params_from_iter(params.iter()), |row| {
67 row.get(0)
68 })
69 .map_err(sql_err)?
70 }
71 None => {
72 conn.query_row(
74 "SELECT COUNT(*) FROM invocations
75 WHERE task_id = ?1 AND status IN ('PENDING', 'RUNNING')",
76 rusqlite::params![&task_key],
77 |row| row.get(0),
78 )
79 .map_err(sql_err)?
80 }
81 };
82
83 let limit = task_config.running_concurrency.unwrap_or(1) as i64;
84 Ok(count < limit)
85 })
86 .await
87 }
88
89 async fn index_for_concurrency_control(
90 &self,
91 invocation_id: &InvocationId,
92 task_id: &TaskId,
93 cc_args: Option<&SerializedArguments>,
94 ) -> RustvelloResult<()> {
95 let db = Arc::clone(&self.db);
96 let invocation_id = invocation_id.clone();
97 let task_id = task_id.clone();
98 let cc_args = cc_args.cloned();
99 blocking(move || {
100
101 let Some(args) = cc_args else {
102 return Ok(());
103 };
104 let conn = db.conn.lock().map_err(lock_err)?;
105 let task_key = task_id.to_string();
106 let pairs = args.cc_arg_pairs();
107
108 for (k, v) in &pairs {
109 conn.execute(
110 "INSERT OR REPLACE INTO cc_arg_pairs (invocation_id, task_id, arg_key, arg_value)
111 VALUES (?1, ?2, ?3, ?4)",
112 rusqlite::params![invocation_id.as_str(), &task_key, k, v],
113 )
114 .map_err(sql_err)?;
115 }
116
117 Ok(())
118
119 })
120 .await
121 }
122
123 async fn remove_from_concurrency_index(
124 &self,
125 invocation_id: &InvocationId,
126 ) -> RustvelloResult<()> {
127 let db = Arc::clone(&self.db);
128 let invocation_id = invocation_id.clone();
129 blocking(move || {
130 let conn = db.conn.lock().map_err(lock_err)?;
131
132 conn.execute(
133 "DELETE FROM cc_arg_pairs WHERE invocation_id = ?1",
134 [invocation_id.as_str()],
135 )
136 .map_err(sql_err)?;
137
138 Ok(())
139 })
140 .await
141 }
142}