// tandem_server/app/state/automation/scheduler.rs

use crate::app::state::automation::rate_limit::RateLimitManager;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tandem_types::TenantContext;

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutomationSchedulerMetrics {
    pub active_runs: usize,
    pub queued_runs_by_reason: HashMap<String, usize>,
    pub admitted_total: u64,
    pub completed_total: u64,
    pub avg_wait_ms: u64,
    pub p95_wait_ms: u64,
}

// ──────────────────────────────────────────────────────────────
// Queue metadata
// ──────────────────────────────────────────────────────────────

#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum QueueReason {
    Capacity,
    WorkspaceLock,
    RateLimit,
}

impl QueueReason {
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Capacity => "capacity",
            Self::WorkspaceLock => "workspace_lock",
            Self::RateLimit => "rate_limit",
        }
    }
}
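
// Illustrative consistency check (assumes `serde_json` is available as a
// dev dependency): the serde `snake_case` encoding of `QueueReason` should
// agree with `as_str`, since metrics keys are built from `as_str` while
// queued metadata travels through serde.
#[cfg(test)]
mod queue_reason_repr_tests {
    use super::*;

    #[test]
    fn serde_and_as_str_agree() {
        for (reason, expected) in [
            (QueueReason::Capacity, "capacity"),
            (QueueReason::WorkspaceLock, "workspace_lock"),
            (QueueReason::RateLimit, "rate_limit"),
        ] {
            assert_eq!(reason.as_str(), expected);
            assert_eq!(
                serde_json::to_string(&reason).unwrap(),
                format!("\"{expected}\"")
            );
        }
    }
}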

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct SchedulerMetadata {
    #[serde(default = "default_tenant_context")]
    pub tenant_context: TenantContext,
    pub queue_reason: Option<QueueReason>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub resource_key: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rate_limited_provider: Option<String>,
    #[serde(default)]
    pub queued_at_ms: u64,
}

fn default_tenant_context() -> TenantContext {
    TenantContext::local_implicit()
}
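
// Sketch of the intended wire behavior (again assuming `serde_json` as a
// dev dependency): every `SchedulerMetadata` field is optional on the wire,
// so deserializing an empty object should fall back to the implicit local
// tenant, no queue reason, and a zero enqueue timestamp.
#[cfg(test)]
mod scheduler_metadata_serde_tests {
    use super::*;

    #[test]
    fn empty_object_uses_defaults() {
        let meta: SchedulerMetadata = serde_json::from_str("{}").unwrap();
        assert_eq!(meta.tenant_context, default_tenant_context());
        assert_eq!(meta.queue_reason, None);
        assert_eq!(meta.resource_key, None);
        assert_eq!(meta.rate_limited_provider, None);
        assert_eq!(meta.queued_at_ms, 0);
    }
}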

// ──────────────────────────────────────────────────────────────
// Preexisting Artifact Registry (MWF-300)
// ──────────────────────────────────────────────────────────────

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidatedArtifact {
    pub path: String,
    pub content_digest: String,
}

#[derive(Debug, Default)]
pub struct PreexistingArtifactRegistry {
    /// run_id -> node_id -> artifact
    pub artifacts: HashMap<String, HashMap<String, ValidatedArtifact>>,
}

impl PreexistingArtifactRegistry {
    pub fn new() -> Self {
        Self {
            artifacts: HashMap::new(),
        }
    }

    /// Record a validated artifact for a node in a run.
    pub fn register_validated(&mut self, run_id: &str, node_id: &str, artifact: ValidatedArtifact) {
        self.artifacts
            .entry(run_id.to_string())
            .or_default()
            .insert(node_id.to_string(), artifact);
    }

    /// Returns true if a validated artifact is registered for this (run, node) pair.
    pub fn is_artifact_prevalidated(&self, run_id: &str, node_id: &str) -> bool {
        self.artifacts
            .get(run_id)
            .and_then(|nodes| nodes.get(node_id))
            .is_some()
    }

    /// Returns the file path of the prevalidated artifact, if any.
    pub fn get_prevalidated_path(&self, run_id: &str, node_id: &str) -> Option<&str> {
        self.artifacts
            .get(run_id)
            .and_then(|nodes| nodes.get(node_id))
            .map(|a| a.path.as_str())
    }

    /// Remove all artifacts for a run (call on run completion/failure).
    pub fn clear_run(&mut self, run_id: &str) {
        self.artifacts.remove(run_id);
    }
}
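
// Usage sketch for the registry (run and node ids, paths, and the digest
// are illustrative): record a validated artifact once, look it up by
// (run, node), then drop everything for the run when it completes.
#[cfg(test)]
mod preexisting_registry_tests {
    use super::*;

    #[test]
    fn register_lookup_clear() {
        let mut registry = PreexistingArtifactRegistry::new();
        registry.register_validated(
            "run-1",
            "node-a",
            ValidatedArtifact {
                path: "out/report.md".to_string(),
                content_digest: "sha256:abc123".to_string(),
            },
        );

        assert!(registry.is_artifact_prevalidated("run-1", "node-a"));
        assert_eq!(
            registry.get_prevalidated_path("run-1", "node-a"),
            Some("out/report.md")
        );
        assert!(!registry.is_artifact_prevalidated("run-1", "node-b"));

        registry.clear_run("run-1");
        assert!(!registry.is_artifact_prevalidated("run-1", "node-a"));
    }
}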

// ──────────────────────────────────────────────────────────────
// Multi-run scheduler
// ──────────────────────────────────────────────────────────────

pub struct AutomationScheduler {
    pub max_concurrent_runs: usize,
    /// run_id -> workspace_root (empty string if no workspace root)
    pub active_runs: HashMap<String, String>,
    /// workspace_root -> run_id
    pub locked_workspaces: HashMap<String, String>,
    pub rate_limits: RateLimitManager,
    pub preexisting_registry: PreexistingArtifactRegistry,
    pub admitted_total: u64,
    pub completed_total: u64,
    /// run_id -> metadata
    pub queue_state: HashMap<String, SchedulerMetadata>,
    /// Queue wait times in ms (window of the last 1000 admitted runs)
    pub wait_times: std::collections::VecDeque<u64>,
}

impl AutomationScheduler {
    pub fn new(max_concurrent_runs: usize) -> Self {
        Self {
            max_concurrent_runs,
            active_runs: HashMap::new(),
            locked_workspaces: HashMap::new(),
            rate_limits: RateLimitManager::new(),
            preexisting_registry: PreexistingArtifactRegistry::new(),
            admitted_total: 0,
            completed_total: 0,
            queue_state: HashMap::new(),
            wait_times: std::collections::VecDeque::with_capacity(1000),
        }
    }

    /// Returns `Ok(())` if the run can be admitted right now, or
    /// `Err(SchedulerMetadata)` describing why it must wait.
    pub fn can_admit(
        &self,
        run_id: &str,
        workspace_root: Option<&str>,
        required_providers: &[String],
    ) -> Result<(), SchedulerMetadata> {
        // 1. Check rate limits
        for provider in required_providers {
            if self.rate_limits.is_provider_throttled(provider) {
                return Err(SchedulerMetadata {
                    tenant_context: TenantContext::local_implicit(),
                    queue_reason: Some(QueueReason::RateLimit),
                    resource_key: None,
                    rate_limited_provider: Some(provider.clone()),
                    queued_at_ms: self.get_queued_at(run_id),
                });
            }
        }

        // 2. Check workspace lock (at most one active run per workspace root)
        if let Some(root) = workspace_root {
            if self.locked_workspaces.contains_key(root) {
                return Err(SchedulerMetadata {
                    tenant_context: TenantContext::local_implicit(),
                    queue_reason: Some(QueueReason::WorkspaceLock),
                    resource_key: Some(root.to_string()),
                    rate_limited_provider: None,
                    queued_at_ms: self.get_queued_at(run_id),
                });
            }
        }

        // 3. Check global capacity
        if self.active_runs.len() >= self.max_concurrent_runs {
            return Err(SchedulerMetadata {
                tenant_context: TenantContext::local_implicit(),
                queue_reason: Some(QueueReason::Capacity),
                resource_key: None,
                rate_limited_provider: None,
                queued_at_ms: self.get_queued_at(run_id),
            });
        }

        Ok(())
    }

    fn get_queued_at(&self, run_id: &str) -> u64 {
        self.queue_state
            .get(run_id)
            .map(|m| m.queued_at_ms)
            .unwrap_or_else(crate::util::time::now_ms)
    }

    pub fn track_queue_state(&mut self, run_id: &str, metadata: SchedulerMetadata) {
        self.queue_state.insert(run_id.to_string(), metadata);
    }

    /// Admit a run: records the active slot and workspace lock.
    pub fn admit_run(&mut self, run_id: &str, workspace_root: Option<&str>) {
        let root = workspace_root.unwrap_or("").to_string();
        if !root.is_empty() {
            self.locked_workspaces
                .insert(root.clone(), run_id.to_string());
        }
        self.active_runs.insert(run_id.to_string(), root);
        self.admitted_total += 1;

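        // If the run had been waiting in the queue, fold its wait into a
        // bounded window (last 1000 samples) so metrics stay O(1) in memory.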
        if let Some(meta) = self.queue_state.remove(run_id) {
            let wait_ms = crate::util::time::now_ms().saturating_sub(meta.queued_at_ms);
            if self.wait_times.len() >= 1000 {
                self.wait_times.pop_front();
            }
            self.wait_times.push_back(wait_ms);
        }
    }

    pub fn reserve_workspace(&mut self, run_id: &str, workspace_root: Option<&str>) {
        let root = workspace_root.unwrap_or("").to_string();
        if root.is_empty() {
            return;
        }
        self.locked_workspaces.insert(root, run_id.to_string());
    }

    pub fn release_capacity(&mut self, run_id: &str) {
        if self.active_runs.remove(run_id).is_some() {
            self.completed_total += 1;
        }
    }

    pub fn release_workspace(&mut self, run_id: &str) {
        self.locked_workspaces.retain(|_, holder| holder != run_id);
        self.preexisting_registry.clear_run(run_id);
        self.queue_state.remove(run_id);
    }

    /// Release a run: frees its capacity slot and workspace lock, and drops
    /// its prevalidated artifacts and queue-state entry.
    pub fn release_run(&mut self, run_id: &str) {
        self.release_capacity(run_id);
        self.release_workspace(run_id);
    }

    pub fn metrics(&self) -> AutomationSchedulerMetrics {
        let mut reasons = HashMap::new();
        for meta in self.queue_state.values() {
            if let Some(reason) = meta.queue_reason {
                *reasons.entry(reason.as_str().to_string()).or_default() += 1;
            }
        }

        let mut wait_times: Vec<u64> = self.wait_times.iter().cloned().collect();
        wait_times.sort_unstable();

        let avg_wait = if wait_times.is_empty() {
            0
        } else {
            wait_times.iter().sum::<u64>() / wait_times.len() as u64
        };

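        // Take the sample roughly 95% of the way into the sorted window,
        // clamped to the last element so small windows still yield a value.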
        let p95_wait = if wait_times.is_empty() {
            0
        } else {
            let idx = (wait_times.len() as f64 * 0.95).round() as usize;
            wait_times
                .get(idx.min(wait_times.len() - 1))
                .cloned()
                .unwrap_or(0)
        };

        AutomationSchedulerMetrics {
            active_runs: self.active_runs.len(),
            queued_runs_by_reason: reasons,
            admitted_total: self.admitted_total,
            completed_total: self.completed_total,
            avg_wait_ms: avg_wait,
            p95_wait_ms: p95_wait,
        }
    }

    pub fn active_count(&self) -> usize {
        self.active_runs.len()
    }

    pub fn is_at_capacity(&self) -> bool {
        self.active_runs.len() >= self.max_concurrent_runs
    }
}
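
// End-to-end sketch of the admission path (run ids and workspace paths are
// illustrative, and `RateLimitManager::new()` is assumed to start with no
// throttled providers): a second run is turned away while another run holds
// its workspace or the only slot, and a tracked rejection feeds the
// wait-time window once the run is finally admitted.
#[cfg(test)]
mod scheduler_admission_tests {
    use super::*;

    #[test]
    fn workspace_lock_then_capacity_then_release() {
        let mut sched = AutomationScheduler::new(1);

        // First run takes the only slot and locks its workspace.
        assert!(sched.can_admit("run-1", Some("/ws/a"), &[]).is_ok());
        sched.admit_run("run-1", Some("/ws/a"));
        assert!(sched.is_at_capacity());

        // Same workspace: rejected with WorkspaceLock before capacity is
        // even consulted.
        let meta = sched.can_admit("run-2", Some("/ws/a"), &[]).unwrap_err();
        assert_eq!(meta.queue_reason, Some(QueueReason::WorkspaceLock));
        assert_eq!(meta.resource_key.as_deref(), Some("/ws/a"));

        // Different workspace: rejected on global capacity instead.
        let meta = sched.can_admit("run-3", Some("/ws/b"), &[]).unwrap_err();
        assert_eq!(meta.queue_reason, Some(QueueReason::Capacity));
        sched.track_queue_state("run-3", meta);

        // Releasing run-1 frees both the slot and the lock; run-3 is then
        // admitted and its queue wait lands in the metrics window.
        sched.release_run("run-1");
        assert!(sched.can_admit("run-3", Some("/ws/b"), &[]).is_ok());
        sched.admit_run("run-3", Some("/ws/b"));

        let metrics = sched.metrics();
        assert_eq!(metrics.active_runs, 1);
        assert_eq!(metrics.admitted_total, 2);
        assert_eq!(metrics.completed_total, 1);
    }
}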