rustvello_postgres/orchestrator/
concurrency.rs1use async_trait::async_trait;
2
3use rustvello_core::error::RustvelloResult;
4use rustvello_core::orchestrator::OrchestratorConcurrency;
5use rustvello_proto::call::SerializedArguments;
6use rustvello_proto::config::TaskConfig;
7use rustvello_proto::identifiers::{InvocationId, TaskId};
8use rustvello_proto::status::ConcurrencyControlType;
9
10use super::PostgresOrchestrator;
11use crate::db::pg_err;
12
13#[async_trait]
14impl OrchestratorConcurrency for PostgresOrchestrator {
15 async fn check_running_concurrency(
22 &self,
23 task_id: &TaskId,
24 task_config: &TaskConfig,
25 cc_args: Option<&SerializedArguments>,
26 ) -> RustvelloResult<bool> {
27 if task_config.concurrency_control == ConcurrencyControlType::Unlimited {
28 return Ok(true);
29 }
30
31 let client = self.db.conn().await?;
32 let task_key = task_id.to_string();
33
34 let count: i64 = match cc_args {
35 Some(args) => {
36 let pairs = args.cc_arg_pairs();
37 let n_pairs = pairs.len();
38 let mut params: Vec<Box<dyn tokio_postgres::types::ToSql + Sync + Send>> =
39 Vec::new();
40 let mut idx = 1;
41 let task_p = format!("${idx}");
42 params.push(Box::new(task_key));
43 idx += 1;
44 let pair_conds: Vec<String> = pairs
45 .iter()
46 .map(|(k, v)| {
47 let kp = format!("${idx}");
48 params.push(Box::new(k.clone()));
49 idx += 1;
50 let vp = format!("${idx}");
51 params.push(Box::new(v.clone()));
52 idx += 1;
53 format!("(cp.arg_key = {kp} AND cp.arg_value = {vp})")
54 })
55 .collect();
56 let where_pairs = pair_conds.join(" OR ");
57 let sql = format!(
58 "SELECT COUNT(*) FROM (
59 SELECT cp.invocation_id FROM cc_arg_pairs cp
60 JOIN invocations i ON cp.invocation_id = i.invocation_id
61 WHERE cp.task_id = {task_p} AND ({where_pairs})
62 AND i.status IN ('PENDING', 'RUNNING')
63 GROUP BY cp.invocation_id
64 HAVING COUNT(*) = {n_pairs}
65 ) sub"
66 );
67 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
68 params.iter().map(|p| &**p as _).collect();
69 let row = client.query_one(&sql, ¶m_refs).await.map_err(pg_err)?;
70 row.get(0)
71 }
72 None => {
73 let row = client
74 .query_one(
75 "SELECT COUNT(*) FROM invocations
76 WHERE task_id = $1 AND status IN ('PENDING', 'RUNNING')",
77 &[&task_key],
78 )
79 .await
80 .map_err(pg_err)?;
81 row.get(0)
82 }
83 };
84
85 let limit = task_config.running_concurrency.unwrap_or(1) as i64;
86 Ok(count < limit)
87 }
88
89 async fn index_for_concurrency_control(
90 &self,
91 invocation_id: &InvocationId,
92 task_id: &TaskId,
93 cc_args: Option<&SerializedArguments>,
94 ) -> RustvelloResult<()> {
95 let Some(args) = cc_args else {
96 return Ok(());
97 };
98 let client = self.db.conn().await?;
99 let task_key = task_id.to_string();
100 let pairs = args.cc_arg_pairs();
101 let inv_str = invocation_id.as_str();
102
103 for (k, v) in &pairs {
104 client
105 .execute(
106 "INSERT INTO cc_arg_pairs (invocation_id, task_id, arg_key, arg_value)
107 VALUES ($1, $2, $3, $4)
108 ON CONFLICT (invocation_id, arg_key, arg_value) DO NOTHING",
109 &[&inv_str, &task_key, k, v],
110 )
111 .await
112 .map_err(pg_err)?;
113 }
114
115 Ok(())
116 }
117
118 async fn remove_from_concurrency_index(
119 &self,
120 invocation_id: &InvocationId,
121 ) -> RustvelloResult<()> {
122 let client = self.db.conn().await?;
123
124 client
125 .execute(
126 "DELETE FROM cc_arg_pairs WHERE invocation_id = $1",
127 &[&invocation_id.as_str()],
128 )
129 .await
130 .map_err(pg_err)?;
131
132 Ok(())
133 }
134
135 async fn try_acquire_concurrency_slot(
140 &self,
141 invocation_id: &InvocationId,
142 task_id: &TaskId,
143 task_config: &TaskConfig,
144 cc_args: Option<&SerializedArguments>,
145 ) -> RustvelloResult<bool> {
146 if task_config.concurrency_control == ConcurrencyControlType::Unlimited {
147 self.index_for_concurrency_control(invocation_id, task_id, cc_args)
148 .await?;
149 return Ok(true);
150 }
151
152 let Some(args) = cc_args else {
153 return self
155 .check_running_concurrency(task_id, task_config, cc_args)
156 .await;
157 };
158
159 let mut client = self.db.conn().await?;
160 let tx = client.transaction().await.map_err(pg_err)?;
161 let task_key = task_id.to_string();
162 let pairs = args.cc_arg_pairs();
163 let n_pairs = pairs.len();
164 let limit = task_config.running_concurrency.unwrap_or(1) as i64;
165
166 let mut params: Vec<Box<dyn tokio_postgres::types::ToSql + Sync + Send>> = Vec::new();
168 let mut idx = 1;
169 let task_p = format!("${idx}");
170 params.push(Box::new(task_key.clone()));
171 idx += 1;
172 let pair_conds: Vec<String> = pairs
173 .iter()
174 .map(|(k, v)| {
175 let kp = format!("${idx}");
176 params.push(Box::new(k.clone()));
177 idx += 1;
178 let vp = format!("${idx}");
179 params.push(Box::new(v.clone()));
180 idx += 1;
181 format!("(cp.arg_key = {kp} AND cp.arg_value = {vp})")
182 })
183 .collect();
184 let where_pairs = pair_conds.join(" OR ");
185 let limit_p = format!("${idx}");
186 params.push(Box::new(limit));
187 let check_sql = format!(
188 "SELECT (SELECT COUNT(*) FROM (
189 SELECT cp.invocation_id FROM cc_arg_pairs cp
190 JOIN invocations i ON cp.invocation_id = i.invocation_id
191 WHERE cp.task_id = {task_p} AND ({where_pairs})
192 AND i.status IN ('PENDING', 'RUNNING')
193 GROUP BY cp.invocation_id
194 HAVING COUNT(*) = {n_pairs}
195 ) sub) < {limit_p}"
196 );
197 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
198 params.iter().map(|p| &**p as _).collect();
199 let row = tx
200 .query_one(&check_sql, ¶m_refs)
201 .await
202 .map_err(pg_err)?;
203 let allowed: bool = row.get(0);
204
205 if allowed {
206 let inv_str = invocation_id.as_str();
207 for (k, v) in &pairs {
208 tx.execute(
209 "INSERT INTO cc_arg_pairs (invocation_id, task_id, arg_key, arg_value)
210 VALUES ($1, $2, $3, $4)
211 ON CONFLICT (invocation_id, arg_key, arg_value) DO NOTHING",
212 &[&inv_str, &task_key, k, v],
213 )
214 .await
215 .map_err(pg_err)?;
216 }
217 }
218
219 tx.commit().await.map_err(pg_err)?;
220 Ok(allowed)
221 }
222}