1use std::ffi::OsString;
2use std::path::{Path, PathBuf};
3use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
4use std::sync::{Arc, Mutex, mpsc};
5use std::thread::{self, JoinHandle};
6use std::time::Duration;
7
8use rand::Rng as _;
9use tempfile::TempDir;
10
11mod codex_home;
12
13use crate::{
14 AcpClient, DEFAULT_CODEX_ACP_MODEL, DEFAULT_CODEX_ACP_REASONING_EFFORT, DEFAULT_SYSTEM_PROMPT,
15 PoolError, RateLimitEvent, RubricOptions, RubricVerdict, default_codex_acp_binary,
16 default_options, encode_png, parse_verdict,
17};
18use codex_home::seed_codex_home;
19
/// Default upper bound on how long a submitter waits for a worker's reply (ten minutes).
const DEFAULT_SUBMIT_TIMEOUT: Duration = Duration::from_secs(10 * 60);
/// How many times a worker tries to spawn a replacement runtime before giving up.
const RECYCLE_SPAWN_ATTEMPTS: u32 = 2;
22
23pub struct RubricPool {
25 senders: Vec<mpsc::Sender<Job>>,
26 handles: Mutex<Vec<JoinHandle<()>>>,
27 next: AtomicUsize,
28 config: PoolConfig,
29 shared: Arc<SharedPoolState>,
30}
31
32#[derive(Clone, Debug)]
34pub struct PoolConfig {
35 pub workers: usize,
37 pub max_prompts_per_worker: u32,
39 pub max_retries: u32,
41 pub backoff_base: Duration,
43 pub backoff_cap: Duration,
45 pub default_options: RubricOptions,
47 pub codex_acp_binary: PathBuf,
49 pub extra_env: Vec<(OsString, OsString)>,
51 pub submit_timeout: Duration,
53 pub source_codex_home: Option<PathBuf>,
55}
56
57impl Default for PoolConfig {
58 fn default() -> Self {
59 Self {
60 workers: 4,
61 max_prompts_per_worker: 50,
62 max_retries: 4,
63 backoff_base: Duration::from_secs(30),
64 backoff_cap: Duration::from_secs(300),
65 default_options: default_options(),
66 codex_acp_binary: default_codex_acp_binary(),
67 extra_env: Vec::new(),
68 submit_timeout: DEFAULT_SUBMIT_TIMEOUT,
69 source_codex_home: None,
70 }
71 }
72}
73
74struct Job {
75 png_path: PathBuf,
76 question: String,
77 options: RubricOptions,
78 reply: mpsc::Sender<Result<RubricVerdict, PoolError>>,
79}
80
81#[derive(Clone, Debug, Default)]
83pub struct PoolStats {
84 pub completed: u64,
86 pub failures: u64,
88 pub rate_limit_events: Vec<RateLimitEvent>,
90 pub worker_recycles: u64,
92}
93
94#[derive(Default)]
95struct SharedPoolState {
96 completed: AtomicU64,
97 failures: AtomicU64,
98 worker_recycles: AtomicU64,
99 rate_limit_events: Mutex<Vec<RateLimitEvent>>,
100 fatal_quota: AtomicBool,
101 alive_mask: AtomicU64,
102}
103
104impl RubricPool {
105 pub fn new(config: PoolConfig) -> Result<Self, PoolError> {
112 if config.workers == 0 {
113 return Err(PoolError::Spawn(
114 "workers must be greater than zero".to_string(),
115 ));
116 }
117 if config.workers > u64::BITS as usize {
118 return Err(PoolError::Spawn(format!(
119 "workers={} exceeds alive bitmask capacity {}",
120 config.workers,
121 u64::BITS
122 )));
123 }
124
125 let shared = Arc::new(SharedPoolState::default());
126 shared
127 .alive_mask
128 .store(alive_mask(config.workers), Ordering::Release);
129
130 let mut senders = Vec::with_capacity(config.workers);
131 let mut handles = Vec::with_capacity(config.workers);
132 for worker_id in 0..config.workers {
133 let (job_tx, job_rx) = mpsc::channel();
134 let (ready_tx, ready_rx) = mpsc::channel();
135 let worker = Worker {
136 id: worker_id,
137 config: config.clone(),
138 shared: Arc::clone(&shared),
139 jobs: job_rx,
140 };
141 let handle = thread::spawn(move || worker.run(ready_tx));
142 match ready_rx.recv() {
143 Ok(Ok(())) => {
144 senders.push(job_tx);
145 handles.push(handle);
146 }
147 Ok(Err(error)) => {
148 shared
149 .alive_mask
150 .fetch_and(!worker_bit(worker_id), Ordering::AcqRel);
151 drop(senders);
152 join_handles(handles);
153 let _ = handle.join();
154 return Err(error);
155 }
156 Err(error) => {
157 shared
158 .alive_mask
159 .fetch_and(!worker_bit(worker_id), Ordering::AcqRel);
160 drop(senders);
161 join_handles(handles);
162 let _ = handle.join();
163 return Err(PoolError::WorkerCrashed {
164 worker_id,
165 message: format!("worker exited before startup result: {error}"),
166 });
167 }
168 }
169 }
170
171 Ok(Self {
172 senders,
173 handles: Mutex::new(handles),
174 next: AtomicUsize::new(0),
175 config,
176 shared,
177 })
178 }
179
180 pub fn submit(
187 &self,
188 png_path: &Path,
189 question: &str,
190 opts: RubricOptions,
191 ) -> Result<RubricVerdict, PoolError> {
192 if self.shared.fatal_quota.load(Ordering::Acquire) {
193 return Err(PoolError::QuotaExceeded);
194 }
195
196 let worker_id = self.next_live_worker()?;
197 let (reply_tx, reply_rx) = mpsc::channel();
198 let job = Job {
199 png_path: png_path.to_path_buf(),
200 question: question.to_string(),
201 options: merge_options(opts, &self.config.default_options),
202 reply: reply_tx,
203 };
204
205 self.senders[worker_id]
206 .send(job)
207 .map_err(|_| PoolError::WorkerCrashed {
208 worker_id,
209 message: "worker channel closed".to_string(),
210 })?;
211
212 match reply_rx.recv_timeout(self.config.submit_timeout) {
213 Ok(result) => result,
214 Err(mpsc::RecvTimeoutError::Timeout) => Err(PoolError::Timeout {
215 worker_id,
216 timeout: self.config.submit_timeout,
217 }),
218 Err(mpsc::RecvTimeoutError::Disconnected) => Err(PoolError::WorkerCrashed {
219 worker_id,
220 message: "worker dropped reply channel".to_string(),
221 }),
222 }
223 }
224
225 #[must_use]
227 pub fn shutdown(self) -> PoolStats {
228 let Self {
229 senders,
230 handles,
231 shared,
232 ..
233 } = self;
234 drop(senders);
235 if let Ok(handles) = handles.into_inner() {
236 join_handles(handles);
237 }
238 shared.stats()
239 }
240
241 #[must_use]
243 pub fn stats(&self) -> PoolStats {
244 self.shared.stats()
245 }
246
247 fn next_live_worker(&self) -> Result<usize, PoolError> {
248 let worker_count = self.senders.len();
249 for _ in 0..worker_count {
250 let idx = self.next.fetch_add(1, Ordering::AcqRel) % worker_count;
251 let mask = self.shared.alive_mask.load(Ordering::Acquire);
252 if mask & worker_bit(idx) != 0 {
253 return Ok(idx);
254 }
255 }
256 Err(PoolError::NoLiveWorkers)
257 }
258}
259
260struct Worker {
261 id: usize,
262 config: PoolConfig,
263 shared: Arc<SharedPoolState>,
264 jobs: mpsc::Receiver<Job>,
265}
266
267struct WorkerRuntime {
268 acp: AcpClient,
269 _codex_home: TempDir,
270 prompts: u32,
271 model: String,
272 effort: String,
273}
274
275impl Worker {
276 fn run(self, ready: mpsc::Sender<Result<(), PoolError>>) {
277 let mut runtime = match self.spawn_runtime(&self.config.default_options) {
278 Ok(runtime) => {
279 let _ = ready.send(Ok(()));
280 runtime
281 }
282 Err(error) => {
283 self.mark_dead();
284 let _ = ready.send(Err(error));
285 return;
286 }
287 };
288
289 while let Ok(job) = self.jobs.recv() {
290 let result = self.handle_job(&mut runtime, &job);
291 let fatal_quota = matches!(result, Err(PoolError::QuotaExceeded));
292 let _ = job.reply.send(result);
293 if fatal_quota {
294 self.shared.fatal_quota.store(true, Ordering::Release);
295 }
296 if self.shared.alive_mask.load(Ordering::Acquire) & worker_bit(self.id) == 0 {
297 break;
298 }
299 }
300 }
301
302 fn handle_job(
303 &self,
304 runtime: &mut WorkerRuntime,
305 job: &Job,
306 ) -> Result<RubricVerdict, PoolError> {
307 let mut last_error = None;
308 for attempt in 0..=self.config.max_retries {
309 if !runtime.matches_options(&job.options) {
310 self.recycle_runtime(runtime, &job.options)?;
311 }
312
313 match self.evaluate_once(runtime, job) {
314 Ok(verdict) => {
315 self.shared.completed.fetch_add(1, Ordering::AcqRel);
316 runtime.prompts += 1;
317 if runtime.prompts >= self.config.max_prompts_per_worker {
318 self.recycle_runtime(runtime, &job.options)?;
319 }
320 return Ok(verdict);
321 }
322 Err(PoolError::QuotaExceeded) => {
323 self.shared.failures.fetch_add(1, Ordering::AcqRel);
324 return Err(PoolError::QuotaExceeded);
325 }
326 Err(PoolError::RateLimited { retry_after }) => {
327 let delay =
328 backoff_delay(attempt, self.config.backoff_base, self.config.backoff_cap);
329 self.shared.push_rate_limit_event(RateLimitEvent {
330 worker_id: self.id,
331 attempt,
332 delay,
333 retry_after,
334 });
335 last_error = Some(PoolError::RateLimited { retry_after });
336 if attempt < self.config.max_retries {
337 thread::sleep(delay);
338 }
339 }
340 Err(error) => {
341 last_error = Some(error);
342 self.recycle_runtime(runtime, &job.options)?;
343 }
344 }
345 }
346
347 self.shared.failures.fetch_add(1, Ordering::AcqRel);
348 Err(last_error.unwrap_or_else(|| PoolError::Rpc("retry loop exhausted".to_string())))
349 }
350
351 fn evaluate_once(
352 &self,
353 runtime: &mut WorkerRuntime,
354 job: &Job,
355 ) -> Result<RubricVerdict, PoolError> {
356 let b64 = encode_png(&job.png_path)?;
357 let system_prompt = job
358 .options
359 .system_prompt
360 .as_deref()
361 .unwrap_or(DEFAULT_SYSTEM_PROMPT);
362 let prompt = format!("{system_prompt}\n\nQuestion: {}", job.question);
363 let text = runtime.acp.prompt_image(&prompt, &b64)?;
364 parse_verdict(&text).map_err(|e| PoolError::ParseVerdict(format!("from {text:?}: {e}")))
365 }
366
367 fn recycle_runtime(
368 &self,
369 runtime: &mut WorkerRuntime,
370 options: &RubricOptions,
371 ) -> Result<(), PoolError> {
372 self.shared.worker_recycles.fetch_add(1, Ordering::AcqRel);
373 let mut last_error = None;
374 for _ in 0..RECYCLE_SPAWN_ATTEMPTS {
375 match self.spawn_runtime(options) {
376 Ok(new_runtime) => {
377 *runtime = new_runtime;
378 return Ok(());
379 }
380 Err(error) => {
381 last_error = Some(error);
382 }
383 }
384 }
385 self.mark_dead();
386 Err(last_error.unwrap_or_else(|| PoolError::Spawn("recycle failed".to_string())))
387 }
388
389 fn spawn_runtime(&self, options: &RubricOptions) -> Result<WorkerRuntime, PoolError> {
390 let codex_home =
391 TempDir::new().map_err(|e| PoolError::Spawn(format!("create CODEX_HOME: {e}")))?;
392 seed_codex_home(codex_home.path(), self.config.source_codex_home.as_deref())?;
393 let mut env = self.config.extra_env.clone();
394 env.push((
395 OsString::from("CODEX_HOME"),
396 codex_home.path().as_os_str().to_os_string(),
397 ));
398
399 let model = options.model.as_deref().unwrap_or(DEFAULT_CODEX_ACP_MODEL);
400 let effort = options
401 .effort
402 .as_deref()
403 .unwrap_or(DEFAULT_CODEX_ACP_REASONING_EFFORT);
404 let mut acp = AcpClient::spawn(&self.config.codex_acp_binary, model, effort, &env, None)?;
405 acp.start_session(None)?;
406
407 Ok(WorkerRuntime {
408 acp,
409 _codex_home: codex_home,
410 prompts: 0,
411 model: model.to_string(),
412 effort: effort.to_string(),
413 })
414 }
415
416 fn mark_dead(&self) {
417 self.shared
418 .alive_mask
419 .fetch_and(!worker_bit(self.id), Ordering::AcqRel);
420 }
421}
422
423impl WorkerRuntime {
424 fn matches_options(&self, options: &RubricOptions) -> bool {
425 options.model.as_deref() == Some(self.model.as_str())
426 && options.effort.as_deref() == Some(self.effort.as_str())
427 }
428}
429
430impl SharedPoolState {
431 fn stats(&self) -> PoolStats {
432 PoolStats {
433 completed: self.completed.load(Ordering::Acquire),
434 failures: self.failures.load(Ordering::Acquire),
435 rate_limit_events: self
436 .rate_limit_events
437 .lock()
438 .map(|events| events.clone())
439 .unwrap_or_default(),
440 worker_recycles: self.worker_recycles.load(Ordering::Acquire),
441 }
442 }
443
444 fn push_rate_limit_event(&self, event: RateLimitEvent) {
445 if let Ok(mut events) = self.rate_limit_events.lock() {
446 events.push(event);
447 }
448 }
449}
450
451fn merge_options(mut opts: RubricOptions, defaults: &RubricOptions) -> RubricOptions {
452 if opts.model.is_none() {
453 opts.model.clone_from(&defaults.model);
454 }
455 if opts.effort.is_none() {
456 opts.effort.clone_from(&defaults.effort);
457 }
458 if opts.system_prompt.is_none() {
459 opts.system_prompt.clone_from(&defaults.system_prompt);
460 }
461 opts
462}
463
464fn backoff_delay(attempt: u32, base: Duration, cap: Duration) -> Duration {
465 let multiplier = 1u32 << attempt.min(6);
466 let capped = base.saturating_mul(multiplier).min(cap);
467 let jitter_cap = capped.as_millis() as u64 / 4;
468 let jitter_ms = rand::rng().random_range(0..=jitter_cap);
469 capped + Duration::from_millis(jitter_ms)
470}
471
/// Returns a bitmask with the lowest `workers` bits set.
///
/// Counts of 64 or more yield a fully set mask instead of overflowing the shift.
fn alive_mask(workers: usize) -> u64 {
    if workers >= u64::BITS as usize {
        u64::MAX
    } else {
        (1u64 << workers) - 1
    }
}
479
/// Returns the mask bit for `worker_id`.
///
/// Ids that do not fit in the 64-bit mask map to an empty mask rather than overflowing
/// the shift. Such an id never tests as alive, and clearing it leaves the mask unchanged.
fn worker_bit(worker_id: usize) -> u64 {
    if worker_id < u64::BITS as usize {
        1u64 << worker_id
    } else {
        0
    }
}
483
/// Waits for every thread in `handles` to finish, ignoring threads that panicked.
fn join_handles(handles: Vec<JoinHandle<()>>) {
    handles.into_iter().for_each(|worker| {
        let _ = worker.join();
    });
}