Skip to main content

rustvello_postgres/orchestrator/
concurrency.rs

1use async_trait::async_trait;
2
3use rustvello_core::error::RustvelloResult;
4use rustvello_core::orchestrator::OrchestratorConcurrency;
5use rustvello_proto::call::SerializedArguments;
6use rustvello_proto::config::TaskConfig;
7use rustvello_proto::identifiers::{InvocationId, TaskId};
8use rustvello_proto::status::ConcurrencyControlType;
9
10use super::PostgresOrchestrator;
11use crate::db::pg_err;
12
13#[async_trait]
14impl OrchestratorConcurrency for PostgresOrchestrator {
15    /// **Note:** This check-and-decide pattern is inherently subject to
16    /// TOCTOU races in multi-node PostgreSQL deployments. Two concurrent
17    /// callers may both read the same count and both admit a new invocation,
18    /// briefly exceeding the concurrency limit. An advisory lock or
19    /// `INSERT … WHERE (SELECT COUNT …) < limit` would be needed for
20    /// strict enforcement, which is a trait-level design change.
21    async fn check_running_concurrency(
22        &self,
23        task_id: &TaskId,
24        task_config: &TaskConfig,
25        cc_args: Option<&SerializedArguments>,
26    ) -> RustvelloResult<bool> {
27        if task_config.concurrency_control == ConcurrencyControlType::Unlimited {
28            return Ok(true);
29        }
30
31        let client = self.db.conn().await?;
32        let task_key = task_id.to_string();
33
34        let count: i64 = match cc_args {
35            Some(args) => {
36                let pairs = args.cc_arg_pairs();
37                let n_pairs = pairs.len();
38                let mut params: Vec<Box<dyn tokio_postgres::types::ToSql + Sync + Send>> =
39                    Vec::new();
40                let mut idx = 1;
41                let task_p = format!("${idx}");
42                params.push(Box::new(task_key));
43                idx += 1;
44                let pair_conds: Vec<String> = pairs
45                    .iter()
46                    .map(|(k, v)| {
47                        let kp = format!("${idx}");
48                        params.push(Box::new(k.clone()));
49                        idx += 1;
50                        let vp = format!("${idx}");
51                        params.push(Box::new(v.clone()));
52                        idx += 1;
53                        format!("(cp.arg_key = {kp} AND cp.arg_value = {vp})")
54                    })
55                    .collect();
56                let where_pairs = pair_conds.join(" OR ");
57                let sql = format!(
58                    "SELECT COUNT(*) FROM (
59                         SELECT cp.invocation_id FROM cc_arg_pairs cp
60                         JOIN invocations i ON cp.invocation_id = i.invocation_id
61                         WHERE cp.task_id = {task_p} AND ({where_pairs})
62                           AND i.status IN ('PENDING', 'RUNNING')
63                         GROUP BY cp.invocation_id
64                         HAVING COUNT(*) = {n_pairs}
65                     ) sub"
66                );
67                let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
68                    params.iter().map(|p| &**p as _).collect();
69                let row = client.query_one(&sql, &param_refs).await.map_err(pg_err)?;
70                row.get(0)
71            }
72            None => {
73                let row = client
74                    .query_one(
75                        "SELECT COUNT(*) FROM invocations
76                         WHERE task_id = $1 AND status IN ('PENDING', 'RUNNING')",
77                        &[&task_key],
78                    )
79                    .await
80                    .map_err(pg_err)?;
81                row.get(0)
82            }
83        };
84
85        let limit = task_config.running_concurrency.unwrap_or(1) as i64;
86        Ok(count < limit)
87    }
88
89    async fn index_for_concurrency_control(
90        &self,
91        invocation_id: &InvocationId,
92        task_id: &TaskId,
93        cc_args: Option<&SerializedArguments>,
94    ) -> RustvelloResult<()> {
95        let Some(args) = cc_args else {
96            return Ok(());
97        };
98        let client = self.db.conn().await?;
99        let task_key = task_id.to_string();
100        let pairs = args.cc_arg_pairs();
101        let inv_str = invocation_id.as_str();
102
103        for (k, v) in &pairs {
104            client
105                .execute(
106                    "INSERT INTO cc_arg_pairs (invocation_id, task_id, arg_key, arg_value)
107                     VALUES ($1, $2, $3, $4)
108                     ON CONFLICT (invocation_id, arg_key, arg_value) DO NOTHING",
109                    &[&inv_str, &task_key, k, v],
110                )
111                .await
112                .map_err(pg_err)?;
113        }
114
115        Ok(())
116    }
117
118    async fn remove_from_concurrency_index(
119        &self,
120        invocation_id: &InvocationId,
121    ) -> RustvelloResult<()> {
122        let client = self.db.conn().await?;
123
124        client
125            .execute(
126                "DELETE FROM cc_arg_pairs WHERE invocation_id = $1",
127                &[&invocation_id.as_str()],
128            )
129            .await
130            .map_err(pg_err)?;
131
132        Ok(())
133    }
134
135    /// Atomic check-and-index via a single INSERT … SELECT … WHERE count < limit.
136    ///
137    /// For per-pair CC, indexes each arg pair atomically. Checks all pairs
138    /// collectively (GROUP BY/HAVING intersection) before allowing the insert.
139    async fn try_acquire_concurrency_slot(
140        &self,
141        invocation_id: &InvocationId,
142        task_id: &TaskId,
143        task_config: &TaskConfig,
144        cc_args: Option<&SerializedArguments>,
145    ) -> RustvelloResult<bool> {
146        if task_config.concurrency_control == ConcurrencyControlType::Unlimited {
147            self.index_for_concurrency_control(invocation_id, task_id, cc_args)
148                .await?;
149            return Ok(true);
150        }
151
152        let Some(args) = cc_args else {
153            // Task-level CC: no per-pair index, just check invocations directly
154            return self
155                .check_running_concurrency(task_id, task_config, cc_args)
156                .await;
157        };
158
159        let mut client = self.db.conn().await?;
160        let tx = client.transaction().await.map_err(pg_err)?;
161        let task_key = task_id.to_string();
162        let pairs = args.cc_arg_pairs();
163        let n_pairs = pairs.len();
164        let limit = task_config.running_concurrency.unwrap_or(1) as i64;
165
166        // Build the per-pair count check
167        let mut params: Vec<Box<dyn tokio_postgres::types::ToSql + Sync + Send>> = Vec::new();
168        let mut idx = 1;
169        let task_p = format!("${idx}");
170        params.push(Box::new(task_key.clone()));
171        idx += 1;
172        let pair_conds: Vec<String> = pairs
173            .iter()
174            .map(|(k, v)| {
175                let kp = format!("${idx}");
176                params.push(Box::new(k.clone()));
177                idx += 1;
178                let vp = format!("${idx}");
179                params.push(Box::new(v.clone()));
180                idx += 1;
181                format!("(cp.arg_key = {kp} AND cp.arg_value = {vp})")
182            })
183            .collect();
184        let where_pairs = pair_conds.join(" OR ");
185        let limit_p = format!("${idx}");
186        params.push(Box::new(limit));
187        let check_sql = format!(
188            "SELECT (SELECT COUNT(*) FROM (
189                 SELECT cp.invocation_id FROM cc_arg_pairs cp
190                 JOIN invocations i ON cp.invocation_id = i.invocation_id
191                 WHERE cp.task_id = {task_p} AND ({where_pairs})
192                   AND i.status IN ('PENDING', 'RUNNING')
193                 GROUP BY cp.invocation_id
194                 HAVING COUNT(*) = {n_pairs}
195             ) sub) < {limit_p}"
196        );
197        let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
198            params.iter().map(|p| &**p as _).collect();
199        let row = tx
200            .query_one(&check_sql, &param_refs)
201            .await
202            .map_err(pg_err)?;
203        let allowed: bool = row.get(0);
204
205        if allowed {
206            let inv_str = invocation_id.as_str();
207            for (k, v) in &pairs {
208                tx.execute(
209                    "INSERT INTO cc_arg_pairs (invocation_id, task_id, arg_key, arg_value)
210                     VALUES ($1, $2, $3, $4)
211                     ON CONFLICT (invocation_id, arg_key, arg_value) DO NOTHING",
212                    &[&inv_str, &task_key, k, v],
213                )
214                .await
215                .map_err(pg_err)?;
216            }
217        }
218
219        tx.commit().await.map_err(pg_err)?;
220        Ok(allowed)
221    }
222}