1use std::io::{BufReader, BufWriter};
7use std::path::PathBuf;
8use std::process::{Child, Command, Stdio};
9use std::sync::Arc;
10use std::time::Duration;
11
12use crate::error::{Error, Result};
13
14use super::protocol::{WorkerCommand, WorkerResponse, read_message, write_message};
15
/// Owning handle to a spawned `venus-worker` child process and its IPC pipes.
///
/// Commands are written to the child's stdin and responses read from its stdout.
/// Once the handle has been killed, further sends/receives are rejected.
#[derive(Debug)]
pub struct WorkerHandle {
    /// The worker child process.
    child: Child,
    /// Buffered writer over the child's stdin, used to send commands.
    stdin: BufWriter<std::process::ChildStdin>,
    /// Buffered reader over the child's stdout, used to receive responses.
    stdout: BufReader<std::process::ChildStdout>,
    /// Set once `kill` has run; guards against using a dead worker.
    killed: bool,
}
29
30impl WorkerHandle {
31 pub fn spawn() -> Result<Self> {
38 let worker_path = Self::find_worker_binary()?;
39
40 let mut child = Command::new(&worker_path)
41 .stdin(Stdio::piped())
42 .stdout(Stdio::piped())
43 .stderr(Stdio::inherit()) .spawn()
45 .map_err(|e| {
46 Error::Ipc(format!(
47 "Failed to spawn worker process '{}': {}",
48 worker_path.display(),
49 e
50 ))
51 })?;
52
53 let stdin = child.stdin.take().ok_or_else(|| {
54 Error::Ipc("Failed to get worker stdin".to_string())
55 })?;
56 let stdout = child.stdout.take().ok_or_else(|| {
57 Error::Ipc("Failed to get worker stdout".to_string())
58 })?;
59
60 let mut handle = Self {
61 child,
62 stdin: BufWriter::new(stdin),
63 stdout: BufReader::new(stdout),
64 killed: false,
65 };
66
67 handle.send_command(&WorkerCommand::Ping)?;
69 match handle.recv_response()? {
70 WorkerResponse::Pong => Ok(handle),
71 other => Err(Error::Ipc(format!(
72 "Unexpected response from worker: {:?}",
73 other
74 ))),
75 }
76 }
77
78 fn find_worker_binary() -> Result<PathBuf> {
80 if let Ok(path) = std::env::var("VENUS_WORKER_PATH") {
82 let path = PathBuf::from(path);
83 if path.exists() {
84 return Ok(path);
85 }
86 }
87
88 if let Ok(exe_path) = std::env::current_exe()
90 && let Some(exe_dir) = exe_path.parent() {
91 let worker_name = if cfg!(windows) {
92 "venus-worker.exe"
93 } else {
94 "venus-worker"
95 };
96 let worker_path = exe_dir.join(worker_name);
97 if worker_path.exists() {
98 return Ok(worker_path);
99 }
100 }
101
102 let worker_name = if cfg!(windows) {
104 "venus-worker.exe"
105 } else {
106 "venus-worker"
107 };
108 if let Ok(path) = which::which(worker_name) {
109 return Ok(path);
110 }
111
112 if let Ok(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR") {
114 for profile in &["debug", "release"] {
115 let worker_name = if cfg!(windows) {
116 "venus-worker.exe"
117 } else {
118 "venus-worker"
119 };
120 let path = PathBuf::from(&manifest_dir)
121 .join("..")
122 .join("..")
123 .join("target")
124 .join(profile)
125 .join(worker_name);
126 if path.exists() {
127 return Ok(path.canonicalize().unwrap_or(path));
128 }
129 }
130 }
131
132 Err(Error::Ipc(
133 "Could not find venus-worker binary. Set VENUS_WORKER_PATH or ensure it's in PATH."
134 .to_string(),
135 ))
136 }
137
138 pub fn send_command(&mut self, cmd: &WorkerCommand) -> Result<()> {
140 if self.killed {
141 return Err(Error::Ipc("Worker has been killed".to_string()));
142 }
143 write_message(&mut self.stdin, cmd)
144 }
145
146 pub fn recv_response(&mut self) -> Result<WorkerResponse> {
148 if self.killed {
149 return Err(Error::Ipc("Worker has been killed".to_string()));
150 }
151 read_message(&mut self.stdout)
152 }
153
154 pub fn load_cell(
156 &mut self,
157 dylib_path: PathBuf,
158 dep_count: usize,
159 entry_symbol: String,
160 name: String,
161 ) -> Result<()> {
162 self.send_command(&WorkerCommand::LoadCell {
163 dylib_path: dylib_path.to_string_lossy().to_string(),
164 dep_count,
165 entry_symbol,
166 name,
167 })?;
168
169 match self.recv_response()? {
170 WorkerResponse::Loaded => Ok(()),
171 WorkerResponse::Error { message } => {
172 Err(Error::Execution(format!("Failed to load cell: {}", message)))
173 }
174 other => Err(Error::Ipc(format!(
175 "Unexpected response when loading cell: {:?}",
176 other
177 ))),
178 }
179 }
180
181 pub fn execute(&mut self, inputs: Vec<Vec<u8>>) -> Result<Vec<u8>> {
185 self.execute_with_widgets(inputs, Vec::new()).map(|(bytes, _)| bytes)
186 }
187
188 pub fn execute_with_widgets(
192 &mut self,
193 inputs: Vec<Vec<u8>>,
194 widget_values_json: Vec<u8>,
195 ) -> Result<(Vec<u8>, Vec<u8>)> {
196 self.send_command(&WorkerCommand::Execute { inputs, widget_values_json })?;
197
198 match self.recv_response()? {
199 WorkerResponse::Output { bytes, widgets_json } => Ok((bytes, widgets_json)),
200 WorkerResponse::Error { message } => {
201 Err(Error::Execution(message))
202 }
203 WorkerResponse::Panic { message } => {
204 Err(Error::Execution(format!(
205 "Cell panicked: {}. Check for unwrap() on None/Err, out-of-bounds access, or other panic sources.",
206 message
207 )))
208 }
209 other => Err(Error::Ipc(format!(
210 "Unexpected response when executing: {:?}",
211 other
212 ))),
213 }
214 }
215
216 pub fn kill(&mut self) -> Result<()> {
221 if self.killed {
222 return Ok(());
223 }
224
225 self.killed = true;
226
227 let _ = self.send_command(&WorkerCommand::Shutdown);
230
231 std::thread::sleep(Duration::from_millis(10));
233
234 if let Err(e) = self.child.kill() {
236 let is_already_dead = e.raw_os_error().map_or(false, |code| {
239 cfg!(unix) && code == 3 || cfg!(windows) && code == 87
240 });
241
242 if !is_already_dead {
243 tracing::warn!("Failed to kill worker: {}", e);
244 }
245 }
246
247 let _ = self.child.wait();
249
250 Ok(())
251 }
252
253 pub fn is_alive(&mut self) -> bool {
255 if self.killed {
256 return false;
257 }
258 matches!(self.child.try_wait(), Ok(None))
259 }
260
261 pub fn pid(&self) -> u32 {
263 self.child.id()
264 }
265
266 pub fn shutdown(mut self) -> Result<()> {
268 if self.killed {
269 return Ok(());
270 }
271
272 let _ = self.send_command(&WorkerCommand::Shutdown);
273
274 match self.child.wait() {
278 Ok(status) => {
279 if status.success() {
280 Ok(())
281 } else {
282 Err(Error::Ipc(format!(
283 "Worker exited with status: {}",
284 status
285 )))
286 }
287 }
288 Err(e) => Err(Error::Ipc(format!("Failed to wait for worker: {}", e))),
289 }
290 }
291}
292
293impl Drop for WorkerHandle {
294 fn drop(&mut self) {
295 let _ = self.kill();
297 }
298}
299
/// A bounded pool of idle worker processes that can be handed out and returned.
pub struct WorkerPool {
    /// Idle workers ready to be handed out by `get`.
    available: Vec<WorkerHandle>,
    /// Maximum number of idle workers retained; `put` drops any extra worker.
    max_size: usize,
}
310
311impl WorkerPool {
312 pub fn new(max_size: usize) -> Self {
314 Self {
315 available: Vec::with_capacity(max_size),
316 max_size,
317 }
318 }
319
320 pub fn with_warm_workers(max_size: usize, warm_count: usize) -> Result<Self> {
322 let mut pool = Self::new(max_size);
323 for _ in 0..warm_count.min(max_size) {
324 let worker = WorkerHandle::spawn()?;
325 pool.available.push(worker);
326 }
327 Ok(pool)
328 }
329
330 pub fn get(&mut self) -> Result<WorkerHandle> {
332 while let Some(mut worker) = self.available.pop() {
334 if worker.is_alive() {
335 return Ok(worker);
336 }
337 }
339
340 WorkerHandle::spawn()
342 }
343
344 pub fn put(&mut self, mut worker: WorkerHandle) {
348 if !worker.is_alive() {
349 return;
350 }
351
352 if self.available.len() < self.max_size {
353 self.available.push(worker);
354 }
355 }
357
358 pub fn shutdown(&mut self) {
360 for mut worker in self.available.drain(..) {
361 let _ = worker.kill();
362 }
363 }
364
365 pub fn available_count(&self) -> usize {
367 self.available.len()
368 }
369}
370
371impl Drop for WorkerPool {
372 fn drop(&mut self) {
373 self.shutdown();
374 }
375}
376
/// A cheap, cloneable handle that can force-kill a worker process by PID.
///
/// It holds only the PID and a shared atomic flag, so it is `Send + Sync` and
/// can be moved to another thread to interrupt a worker that is blocked
/// executing a cell. All clones share the same `killed` flag.
#[derive(Clone, Debug)]
pub struct WorkerKillHandle {
    /// OS process id of the worker.
    pid: u32,
    /// Set once a kill has been requested; shared between clones.
    killed: Arc<std::sync::atomic::AtomicBool>,
}
387
388impl WorkerKillHandle {
389 pub fn new(worker: &WorkerHandle) -> Self {
391 Self {
392 pid: worker.pid(),
393 killed: Arc::new(std::sync::atomic::AtomicBool::new(false)),
394 }
395 }
396
397 pub fn kill(&self) {
402 if self.killed.swap(true, std::sync::atomic::Ordering::SeqCst) {
403 tracing::debug!("Worker {} already killed, skipping", self.pid);
404 return; }
406
407 tracing::info!("Sending SIGKILL to worker process {}", self.pid);
408
409 #[cfg(unix)]
410 {
411 unsafe {
413 let result = libc::kill(self.pid as i32, libc::SIGKILL);
414 if result != 0 {
415 tracing::warn!("Failed to kill worker {}: errno={}", self.pid, *libc::__errno_location());
416 } else {
417 tracing::info!("SIGKILL sent successfully to worker {}", self.pid);
418 }
419 }
420 }
421
422 #[cfg(windows)]
423 {
424 use windows::Win32::Foundation::CloseHandle;
425 use windows::Win32::System::Threading::{OpenProcess, TerminateProcess, PROCESS_TERMINATE};
426
427 unsafe {
428 if let Ok(handle) = OpenProcess(PROCESS_TERMINATE, false, self.pid) {
429 let _ = TerminateProcess(handle, 1);
430 let _ = CloseHandle(handle);
431 }
432 }
433 }
434 }
435
436 pub fn is_killed(&self) -> bool {
438 self.killed.load(std::sync::atomic::Ordering::SeqCst)
439 }
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore = "Requires venus-worker binary"]
    fn test_worker_spawn_and_ping() {
        // `spawn` only succeeds after a Ping/Pong handshake with the worker.
        let worker = WorkerHandle::spawn().unwrap();
        assert!(worker.pid() > 0);
    }

    #[test]
    #[ignore = "Requires venus-worker binary"]
    fn test_worker_pool() {
        let mut pool = WorkerPool::new(4);
        let worker1 = pool.get().unwrap();
        let pid1 = worker1.pid();
        pool.put(worker1);
        assert_eq!(pool.available_count(), 1);

        // A live worker returned to the pool must be reused, not respawned.
        let worker2 = pool.get().unwrap();
        assert_eq!(worker2.pid(), pid1);
        assert_eq!(pool.available_count(), 0);
    }

    #[test]
    fn test_new_pool_is_empty() {
        let pool = WorkerPool::new(4);
        assert_eq!(pool.available_count(), 0);
    }

    #[test]
    fn test_zero_warm_workers_spawns_nothing() {
        // With warm_count == 0 no worker binary is needed.
        let pool = WorkerPool::with_warm_workers(4, 0).unwrap();
        assert_eq!(pool.available_count(), 0);
    }

    #[test]
    fn test_warm_count_is_clamped_to_max_size() {
        // max_size == 0 clamps the warm count to zero, so nothing is spawned.
        let pool = WorkerPool::with_warm_workers(0, 8).unwrap();
        assert_eq!(pool.available_count(), 0);
    }
}