Skip to main content

visual_rubric/
pool.rs

1use std::ffi::OsString;
2use std::fs;
3use std::path::{Path, PathBuf};
4use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
5use std::sync::{Arc, Mutex, mpsc};
6use std::thread::{self, JoinHandle};
7use std::time::Duration;
8
9use rand::RngExt as _;
10use tempfile::TempDir;
11
12mod codex_home;
13mod config;
14
15pub use config::{LogCaptureConfig, LogPathMode, PoolConfig, PoolStats};
16
17use crate::{
18    AcpClient, DEFAULT_CODEX_ACP_MODEL, DEFAULT_CODEX_ACP_REASONING_EFFORT, DEFAULT_SYSTEM_PROMPT,
19    PoolError, RateLimitEvent, RubricOptions, RubricVerdict, build_codex_acp_args, encode_png,
20    parse_verdict,
21};
22use codex_home::seed_codex_home;
23
/// Number of times a worker tries to spawn a replacement runtime during one
/// recycle before it gives up and marks itself dead.
const RECYCLE_SPAWN_ATTEMPTS: u32 = 2;
25
/// Reusable worker pool for evaluating screenshot rubrics through Codex ACP.
///
/// Every worker runs on its own thread and receives jobs over a dedicated
/// channel; submissions rotate across the workers whose alive bit is set.
#[derive(Debug)]
pub struct RubricPool {
    /// Job channel of each worker, indexed by worker id.
    senders: Vec<mpsc::Sender<Job>>,
    /// Join handles of the worker threads, joined on shutdown.
    handles: Mutex<Vec<JoinHandle<()>>>,
    /// Ever-increasing counter used to rotate through workers on submit.
    next: AtomicUsize,
    /// Configuration the pool was started with.
    config: PoolConfig,
    /// Counters and flags shared with every worker thread.
    shared: Arc<SharedPoolState>,
}
35
/// One rubric evaluation request handed to a worker thread.
struct Job {
    /// Path of the PNG to evaluate.
    png_path: PathBuf,
    /// Question appended to the system prompt.
    question: String,
    /// Per-job options, already merged with the pool defaults by `submit`.
    options: RubricOptions,
    /// Channel on which the worker sends the job's result back.
    reply: mpsc::Sender<Result<RubricVerdict, PoolError>>,
}
42
/// State shared between the pool handle and all of its worker threads.
#[derive(Default, Debug)]
struct SharedPoolState {
    /// Number of jobs that produced a verdict.
    completed: AtomicU64,
    /// Number of jobs recorded as failed.
    failures: AtomicU64,
    /// Number of runtime recycles that workers have started.
    worker_recycles: AtomicU64,
    /// Every rate-limit response observed by any worker, in arrival order.
    rate_limit_events: Mutex<Vec<RateLimitEvent>>,
    /// Set once a worker reports `QuotaExceeded`; later submissions are
    /// rejected up front.
    fatal_quota: AtomicBool,
    /// Bit `i` is set while worker `i` is considered alive.
    alive_mask: AtomicU64,
}
52
53impl RubricPool {
54    /// Starts a worker pool from the supplied configuration.
55    ///
56    /// # Errors
57    ///
58    /// Returns [`PoolError`] when configuration is invalid or worker startup
59    /// fails.
60    pub fn new(config: PoolConfig) -> Result<Self, PoolError> {
61        if config.workers == 0 {
62            return Err(PoolError::Spawn(
63                "workers must be greater than zero".to_string(),
64            ));
65        }
66        if config.workers > u64::BITS as usize {
67            return Err(PoolError::Spawn(format!(
68                "workers={} exceeds alive bitmask capacity {}",
69                config.workers,
70                u64::BITS
71            )));
72        }
73
74        let shared = Arc::new(SharedPoolState::default());
75        shared
76            .alive_mask
77            .store(alive_mask(config.workers), Ordering::Release);
78
79        let mut senders = Vec::with_capacity(config.workers);
80        let mut handles = Vec::with_capacity(config.workers);
81        for worker_id in 0..config.workers {
82            let (job_tx, job_rx) = mpsc::channel();
83            let (ready_tx, ready_rx) = mpsc::channel();
84            let worker = Worker {
85                id: worker_id,
86                config: config.clone(),
87                shared: Arc::clone(&shared),
88                jobs: job_rx,
89            };
90            let handle = thread::spawn(move || worker.run(ready_tx));
91            match ready_rx.recv() {
92                Ok(Ok(())) => {
93                    senders.push(job_tx);
94                    handles.push(handle);
95                }
96                Ok(Err(error)) => {
97                    shared
98                        .alive_mask
99                        .fetch_and(!worker_bit(worker_id), Ordering::AcqRel);
100                    drop(senders);
101                    join_handles(handles);
102                    let _ = handle.join();
103                    return Err(error);
104                }
105                Err(error) => {
106                    shared
107                        .alive_mask
108                        .fetch_and(!worker_bit(worker_id), Ordering::AcqRel);
109                    drop(senders);
110                    join_handles(handles);
111                    let _ = handle.join();
112                    return Err(PoolError::WorkerCrashed {
113                        worker_id,
114                        message: format!("worker exited before startup result: {error}"),
115                    });
116                }
117            }
118        }
119
120        Ok(Self {
121            senders,
122            handles: Mutex::new(handles),
123            next: AtomicUsize::new(0),
124            config,
125            shared,
126        })
127    }
128
129    /// Submits one PNG rubric job to a live worker.
130    ///
131    /// # Errors
132    ///
133    /// Returns [`PoolError`] for missing workers, worker crashes, timeouts, PNG
134    /// IO, Codex ACP failures, or verdict parsing failures.
135    pub fn submit(
136        &self,
137        png_path: &Path,
138        question: &str,
139        opts: RubricOptions,
140    ) -> Result<RubricVerdict, PoolError> {
141        if self.shared.fatal_quota.load(Ordering::Acquire) {
142            return Err(PoolError::QuotaExceeded);
143        }
144
145        let worker_id = self.next_live_worker()?;
146        let (reply_tx, reply_rx) = mpsc::channel();
147        let job = Job {
148            png_path: png_path.to_path_buf(),
149            question: question.to_string(),
150            options: merge_options(opts, &self.config.default_options),
151            reply: reply_tx,
152        };
153
154        self.senders[worker_id]
155            .send(job)
156            .map_err(|_| PoolError::WorkerCrashed {
157                worker_id,
158                message: "worker channel closed".to_string(),
159            })?;
160
161        match reply_rx.recv_timeout(self.config.submit_timeout) {
162            Ok(result) => result,
163            Err(mpsc::RecvTimeoutError::Timeout) => Err(PoolError::Timeout {
164                worker_id,
165                timeout: self.config.submit_timeout,
166            }),
167            Err(mpsc::RecvTimeoutError::Disconnected) => Err(PoolError::WorkerCrashed {
168                worker_id,
169                message: "worker dropped reply channel".to_string(),
170            }),
171        }
172    }
173
174    /// Stops workers and returns final pool statistics.
175    #[must_use]
176    pub fn shutdown(self) -> PoolStats {
177        let Self {
178            senders,
179            handles,
180            shared,
181            ..
182        } = self;
183        drop(senders);
184        if let Ok(handles) = handles.into_inner() {
185            join_handles(handles);
186        }
187        shared.stats()
188    }
189
190    /// Returns current pool statistics without shutting the pool down.
191    #[must_use]
192    pub fn stats(&self) -> PoolStats {
193        self.shared.stats()
194    }
195
196    fn next_live_worker(&self) -> Result<usize, PoolError> {
197        let worker_count = self.senders.len();
198        for _ in 0..worker_count {
199            let idx = self.next.fetch_add(1, Ordering::AcqRel) % worker_count;
200            let mask = self.shared.alive_mask.load(Ordering::Acquire);
201            if mask & worker_bit(idx) != 0 {
202                return Ok(idx);
203            }
204        }
205        Err(PoolError::NoLiveWorkers)
206    }
207}
208
/// Per-thread worker state; consumed by `Worker::run` on the worker thread.
struct Worker {
    /// Index of this worker; also its bit position in the alive mask.
    id: usize,
    /// Private copy of the pool configuration.
    config: PoolConfig,
    /// Counters and flags shared with the pool and the other workers.
    shared: Arc<SharedPoolState>,
    /// Receiving end of this worker's job channel.
    jobs: mpsc::Receiver<Job>,
}
215
/// A running Codex ACP client together with the settings it was spawned with.
struct WorkerRuntime {
    /// Client used to send prompts to the spawned Codex ACP process.
    acp: AcpClient,
    /// Temporary `CODEX_HOME` of the process; never read, only held so the
    /// directory stays around for as long as the runtime does.
    _codex_home: TempDir,
    /// Number of prompts answered successfully by this runtime.
    prompts: u32,
    /// Model the process was started with.
    model: String,
    /// Reasoning effort the process was started with.
    effort: String,
}
223
224impl Worker {
225    fn run(self, ready: mpsc::Sender<Result<(), PoolError>>) {
226        let mut runtime = match self.spawn_runtime(&self.config.default_options) {
227            Ok(runtime) => {
228                let _ = ready.send(Ok(()));
229                runtime
230            }
231            Err(error) => {
232                self.mark_dead();
233                let _ = ready.send(Err(error));
234                return;
235            }
236        };
237
238        while let Ok(job) = self.jobs.recv() {
239            let result = self.handle_job(&mut runtime, &job);
240            let fatal_quota = matches!(result, Err(PoolError::QuotaExceeded));
241            let _ = job.reply.send(result);
242            if fatal_quota {
243                self.shared.fatal_quota.store(true, Ordering::Release);
244            }
245            if self.shared.alive_mask.load(Ordering::Acquire) & worker_bit(self.id) == 0 {
246                break;
247            }
248        }
249    }
250
251    fn handle_job(
252        &self,
253        runtime: &mut WorkerRuntime,
254        job: &Job,
255    ) -> Result<RubricVerdict, PoolError> {
256        let mut last_error = None;
257        for attempt in 0..=self.config.max_retries {
258            if !runtime.matches_options(&job.options) {
259                self.recycle_runtime(runtime, &job.options)?;
260            }
261
262            match self.evaluate_once(runtime, job) {
263                Ok(verdict) => {
264                    self.shared.completed.fetch_add(1, Ordering::AcqRel);
265                    runtime.prompts += 1;
266                    if runtime.prompts >= self.config.max_prompts_per_worker {
267                        self.recycle_runtime(runtime, &job.options)?;
268                    }
269                    return Ok(verdict);
270                }
271                Err(PoolError::QuotaExceeded) => {
272                    self.shared.failures.fetch_add(1, Ordering::AcqRel);
273                    return Err(PoolError::QuotaExceeded);
274                }
275                Err(PoolError::RateLimited { retry_after }) => {
276                    let delay =
277                        backoff_delay(attempt, self.config.backoff_base, self.config.backoff_cap);
278                    self.shared.push_rate_limit_event(RateLimitEvent {
279                        worker_id: self.id,
280                        attempt,
281                        delay,
282                        retry_after,
283                    });
284                    last_error = Some(PoolError::RateLimited { retry_after });
285                    if attempt < self.config.max_retries {
286                        thread::sleep(delay);
287                    }
288                }
289                Err(error) => {
290                    last_error = Some(error);
291                    self.recycle_runtime(runtime, &job.options)?;
292                }
293            }
294        }
295
296        self.shared.failures.fetch_add(1, Ordering::AcqRel);
297        Err(last_error.unwrap_or_else(|| PoolError::Rpc("retry loop exhausted".to_string())))
298    }
299
300    fn evaluate_once(
301        &self,
302        runtime: &mut WorkerRuntime,
303        job: &Job,
304    ) -> Result<RubricVerdict, PoolError> {
305        let b64 = encode_png(&job.png_path)?;
306        let system_prompt = job
307            .options
308            .system_prompt
309            .as_deref()
310            .map_or(DEFAULT_SYSTEM_PROMPT, |system_prompt| system_prompt);
311        let prompt = format!("{system_prompt}\n\nQuestion: {}", job.question);
312        let text = runtime.acp.prompt_image(&prompt, &b64)?;
313        parse_verdict(&text).map_err(|e| PoolError::ParseVerdict(format!("from {text:?}: {e}")))
314    }
315
316    fn recycle_runtime(
317        &self,
318        runtime: &mut WorkerRuntime,
319        options: &RubricOptions,
320    ) -> Result<(), PoolError> {
321        self.shared.worker_recycles.fetch_add(1, Ordering::AcqRel);
322        let mut last_error = None;
323        for _ in 0..RECYCLE_SPAWN_ATTEMPTS {
324            match self.spawn_runtime(options) {
325                Ok(new_runtime) => {
326                    *runtime = new_runtime;
327                    return Ok(());
328                }
329                Err(error) => {
330                    last_error = Some(error);
331                }
332            }
333        }
334        self.mark_dead();
335        Err(last_error.unwrap_or_else(|| PoolError::Spawn("recycle failed".to_string())))
336    }
337
338    fn spawn_runtime(&self, options: &RubricOptions) -> Result<WorkerRuntime, PoolError> {
339        let codex_home =
340            TempDir::new().map_err(|e| PoolError::Spawn(format!("create CODEX_HOME: {e}")))?;
341        seed_codex_home(codex_home.path(), self.config.source_codex_home.as_deref())?;
342        let mut env = self.config.extra_env.clone();
343        env.push((
344            OsString::from("CODEX_HOME"),
345            codex_home.path().as_os_str().to_os_string(),
346        ));
347        if let Some(log_capture) = &self.config.log_capture {
348            fs::create_dir_all(&log_capture.temp_dir).map_err(|e| {
349                PoolError::Spawn(format!(
350                    "create ACP TMPDIR {}: {e}",
351                    log_capture.temp_dir.display()
352                ))
353            })?;
354            env.push((
355                OsString::from("TMPDIR"),
356                log_capture.temp_dir.as_os_str().to_os_string(),
357            ));
358        }
359
360        let model = options
361            .model
362            .as_deref()
363            .map_or(DEFAULT_CODEX_ACP_MODEL, |model| model);
364        let effort = options
365            .effort
366            .as_deref()
367            .map_or(DEFAULT_CODEX_ACP_REASONING_EFFORT, |effort| effort);
368        let acp_args = build_codex_acp_args(model, effort);
369        let mut acp = AcpClient::spawn(&self.config.codex_acp_binary, &acp_args, &env, None)?;
370        acp.start_session(None)?;
371
372        Ok(WorkerRuntime {
373            acp,
374            _codex_home: codex_home,
375            prompts: 0,
376            model: model.to_string(),
377            effort: effort.to_string(),
378        })
379    }
380
381    fn mark_dead(&self) {
382        self.shared
383            .alive_mask
384            .fetch_and(!worker_bit(self.id), Ordering::AcqRel);
385    }
386}
387
388impl WorkerRuntime {
389    fn matches_options(&self, options: &RubricOptions) -> bool {
390        options.model.as_deref() == Some(self.model.as_str())
391            && options.effort.as_deref() == Some(self.effort.as_str())
392    }
393}
394
395impl SharedPoolState {
396    fn stats(&self) -> PoolStats {
397        PoolStats {
398            completed: self.completed.load(Ordering::Acquire),
399            failures: self.failures.load(Ordering::Acquire),
400            rate_limit_events: self
401                .rate_limit_events
402                .lock()
403                .map(|events| events.clone())
404                .unwrap_or_default(),
405            worker_recycles: self.worker_recycles.load(Ordering::Acquire),
406        }
407    }
408
409    fn push_rate_limit_event(&self, event: RateLimitEvent) {
410        if let Ok(mut events) = self.rate_limit_events.lock() {
411            events.push(event);
412        }
413    }
414}
415
416fn merge_options(mut opts: RubricOptions, defaults: &RubricOptions) -> RubricOptions {
417    if opts.model.is_none() {
418        opts.model.clone_from(&defaults.model);
419    }
420    if opts.effort.is_none() {
421        opts.effort.clone_from(&defaults.effort);
422    }
423    if opts.system_prompt.is_none() {
424        opts.system_prompt.clone_from(&defaults.system_prompt);
425    }
426    opts
427}
428
429fn backoff_delay(attempt: u32, base: Duration, cap: Duration) -> Duration {
430    let multiplier = 1u32 << attempt.min(6);
431    let capped = base.saturating_mul(multiplier).min(cap);
432    let capped_millis = u64::try_from(capped.as_millis()).map_or(u64::MAX, |millis| millis);
433    let jitter_cap = capped_millis / 4;
434    let jitter_ms = rand::rng().random_range(0..=jitter_cap);
435    capped.saturating_add(Duration::from_millis(jitter_ms))
436}
437
/// Returns a bitmask with the lowest `workers` bits set.
///
/// The full-width case is handled separately because shifting a `u64` by 64
/// bits overflows.
fn alive_mask(workers: usize) -> u64 {
    if workers != u64::BITS as usize {
        return (1u64 << workers) - 1;
    }
    u64::MAX
}
445
/// Returns the mask with only the bit for `worker_id` set.
fn worker_bit(worker_id: usize) -> u64 {
    let lowest_bit: u64 = 1;
    lowest_bit << worker_id
}
449
/// Waits for every thread in `handles` to finish, ignoring the join result
/// (including panics) of each one.
fn join_handles(handles: Vec<JoinHandle<()>>) {
    handles.into_iter().for_each(|handle| {
        let _ = handle.join();
    });
}