1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::RwLock;
4
5use super::disk;
6use super::platform::create_platform_driver;
7use crate::config::{VmConfig, VmHandle, VmState};
8use crate::driver::{VmDriver, VmError};
9
/// Coordinates VM lifecycle (boot, stop, pause, state polling) on top of a
/// platform-specific hypervisor driver, caching a handle per VM name.
pub struct VmManager {
    /// Hypervisor backend that actually boots/stops/queries VMs.
    driver: Box<dyn VmDriver>,
    /// Registry of known VMs keyed by name. Also acts as the name
    /// "reservation" table: `start` inserts a `Starting` placeholder under
    /// this lock before booting, so concurrent starts of the same name fail.
    vms: RwLock<HashMap<String, VmHandle>>,
    /// Root directory; each VM gets its own subdirectory under it.
    base_dir: PathBuf,
}
18
19impl VmManager {
20 pub fn new(base_dir: PathBuf) -> Result<Self, VmError> {
25 let driver = create_platform_driver()?;
26 std::fs::create_dir_all(&base_dir).map_err(VmError::Io)?;
27 Ok(Self {
28 driver,
29 vms: RwLock::new(HashMap::new()),
30 base_dir,
31 })
32 }
33
34 pub fn with_driver(driver: Box<dyn VmDriver>, base_dir: PathBuf) -> Result<Self, VmError> {
36 std::fs::create_dir_all(&base_dir).map_err(VmError::Io)?;
37 Ok(Self {
38 driver,
39 vms: RwLock::new(HashMap::new()),
40 base_dir,
41 })
42 }
43
44 pub fn base_dir(&self) -> &Path {
46 &self.base_dir
47 }
48
49 pub fn vm_dir(&self, name: &str) -> PathBuf {
51 self.base_dir.join(name)
52 }
53
54 pub fn start(&self, config: &VmConfig) -> Result<VmHandle, VmError> {
56 config.validate()?;
57
58 {
59 let mut vms = self
60 .vms
61 .write()
62 .map_err(|e| VmError::Hypervisor(format!("lock poisoned: {}", e)))?;
63 if let Some(existing) = vms.get(&config.name) {
64 if !matches!(existing.state, VmState::Stopped | VmState::Failed { .. }) {
65 return Err(VmError::BootFailed {
66 name: config.name.clone(),
67 detail: format!("VM already exists in state: {}", existing.state),
68 });
69 }
70 }
71 vms.insert(
72 config.name.clone(),
73 VmHandle {
74 name: config.name.clone(),
75 namespace: config.namespace.clone(),
76 state: VmState::Starting,
77 process: None,
78 serial_log: config.serial_log.clone(),
79 machine_id: None,
80 },
81 );
82 }
83
84 let vm_dir = self.vm_dir(&config.name);
85 if let Err(e) = std::fs::create_dir_all(&vm_dir) {
86 let mut vms = self
87 .vms
88 .write()
89 .map_err(|e| VmError::Hypervisor(format!("lock poisoned: {}", e)))?;
90 vms.remove(&config.name);
91 return Err(VmError::Io(e));
92 }
93
94 tracing::info!(
95 vm = %config.name,
96 cpus = config.cpus,
97 memory_mb = config.memory_mb,
98 "booting VM"
99 );
100
101 let handle = match self.driver.boot(config) {
102 Ok(handle) => handle,
103 Err(err) => {
104 tracing::warn!(vm = %config.name, error = %err, "VM boot failed, cleaning up reservation");
105 let mut vms = self
106 .vms
107 .write()
108 .map_err(|e| VmError::Hypervisor(format!("lock poisoned: {}", e)))?;
109 vms.remove(&config.name);
110 return Err(err);
111 }
112 };
113
114 let mut vms = self
115 .vms
116 .write()
117 .map_err(|e| VmError::Hypervisor(format!("lock poisoned: {}", e)))?;
118 vms.insert(config.name.clone(), handle.clone());
119
120 Ok(handle)
121 }
122
123 pub fn stop(&self, name: &str) -> Result<(), VmError> {
125 tracing::info!(vm = %name, "stopping VM");
126 let handle = self.get_handle(name)?;
127 self.driver.stop(&handle)?;
128 self.update_cached_state(name, VmState::Stopped)
129 }
130
131 pub fn kill(&self, name: &str) -> Result<(), VmError> {
133 tracing::info!(vm = %name, "force-killing VM");
134 let handle = self.get_handle(name)?;
135 self.driver.kill(&handle)?;
136 self.update_cached_state(name, VmState::Stopped)
137 }
138
139 pub fn stop_by_handle(&self, handle: &VmHandle) -> Result<(), VmError> {
141 tracing::info!(vm = %handle.name, "stopping VM by handle");
142 self.driver.stop(handle)?;
143 self.update_cached_state(&handle.name, VmState::Stopped)
144 }
145
146 pub fn kill_by_handle(&self, handle: &VmHandle) -> Result<(), VmError> {
148 tracing::info!(vm = %handle.name, "force-killing VM by handle");
149 self.driver.kill(handle)?;
150 self.update_cached_state(&handle.name, VmState::Stopped)
151 }
152
153 pub fn pause(&self, name: &str) -> Result<(), VmError> {
155 let handle = self.get_handle(name)?;
156 self.driver.pause(&handle)?;
157 self.update_cached_state(name, VmState::Paused)
158 }
159
160 pub fn resume(&self, name: &str) -> Result<(), VmError> {
162 let handle = self.get_handle(name)?;
163 self.driver.resume(&handle)?;
164 let resumed_state = self.driver.state(&handle)?;
165 self.update_cached_state(name, resumed_state)
166 }
167
168 pub fn state(&self, name: &str) -> Result<VmState, VmError> {
170 let handle = self.get_handle(name)?;
171 let state = self.driver.state(&handle)?;
172 self.update_cached_state(name, state.clone())?;
173 Ok(state)
174 }
175
176 pub fn get_ip(&self, name: &str) -> Result<Option<String>, VmError> {
178 Ok(self.state(name)?.ip().map(ToOwned::to_owned))
179 }
180
181 pub fn list(&self) -> Result<Vec<VmHandle>, VmError> {
183 let vms = self
184 .vms
185 .read()
186 .map_err(|e| VmError::Hypervisor(format!("lock poisoned: {}", e)))?;
187 Ok(vms.values().cloned().collect())
188 }
189
190 pub fn wait_all_ready(&self, timeout_secs: u64) -> Result<(), VmError> {
192 let start = std::time::Instant::now();
193 let timeout = std::time::Duration::from_secs(timeout_secs);
194
195 loop {
196 if start.elapsed() > timeout {
197 let pending = self.pending_names(|state| state.is_ready())?;
198 return Err(VmError::Hypervisor(format!(
199 "timeout waiting for VMs: {}",
200 pending.join(", ")
201 )));
202 }
203
204 let mut all_ready = true;
205 let names = self.vm_names()?;
206
207 for name in &names {
208 match self.state(name)? {
209 state if state.is_ready() => {}
210 VmState::Failed { reason } => {
211 return Err(VmError::BootFailed {
212 name: name.clone(),
213 detail: reason,
214 });
215 }
216 _ => {
217 all_ready = false;
218 }
219 }
220 }
221
222 if all_ready {
223 return Ok(());
224 }
225
226 let elapsed = start.elapsed().as_secs();
227 if elapsed > 0 && elapsed.is_multiple_of(10) {
228 let pending = self.pending_names(|state| state.is_ready())?;
229 tracing::info!(
230 pending = ?pending,
231 elapsed_secs = elapsed,
232 "waiting for VMs to become ready"
233 );
234 }
235
236 std::thread::sleep(std::time::Duration::from_secs(1));
237 }
238 }
239
240 pub fn clone_disk(base: &Path, target: &Path) -> Result<(), VmError> {
242 disk::clone_disk(base, target)
243 }
244
245 fn get_handle(&self, name: &str) -> Result<VmHandle, VmError> {
246 let vms = self
247 .vms
248 .read()
249 .map_err(|e| VmError::Hypervisor(format!("lock poisoned: {}", e)))?;
250 vms.get(name).cloned().ok_or_else(|| VmError::NotFound {
251 name: name.to_string(),
252 })
253 }
254
255 fn update_cached_state(&self, name: &str, state: VmState) -> Result<(), VmError> {
256 let mut vms = self
257 .vms
258 .write()
259 .map_err(|e| VmError::Hypervisor(format!("lock poisoned: {}", e)))?;
260 if let Some(handle) = vms.get_mut(name) {
261 handle.state = state;
262 }
263 Ok(())
264 }
265
266 fn vm_names(&self) -> Result<Vec<String>, VmError> {
267 let vms = self
268 .vms
269 .read()
270 .map_err(|e| VmError::Hypervisor(format!("lock poisoned: {}", e)))?;
271 Ok(vms.keys().cloned().collect())
272 }
273
274 fn pending_names(&self, predicate: impl Fn(&VmState) -> bool) -> Result<Vec<String>, VmError> {
275 let vms = self
276 .vms
277 .read()
278 .map_err(|e| VmError::Hypervisor(format!("lock poisoned: {}", e)))?;
279 Ok(vms
280 .iter()
281 .filter(|(_, handle)| !predicate(&handle.state))
282 .map(|(name, _)| name.clone())
283 .collect())
284 }
285}
286
#[cfg(test)]
mod tests {
    use std::path::{Path, PathBuf};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{mpsc, Arc, Mutex};

    use super::*;

    // Driver whose `boot` blocks until a message arrives on `release_rx`,
    // letting a test hold one boot in flight while racing a second start.
    struct BlockingDriver {
        boot_calls: AtomicUsize,
        release_rx: Mutex<Option<mpsc::Receiver<()>>>,
    }

    // Driver whose `state` always reports `Failed`.
    struct FailedStateDriver;
    // Driver whose `state` reports `Running` on the first poll and
    // `Ready { ip }` on every poll after that.
    struct ReadyAfterTwoPollsDriver {
        polls: AtomicUsize,
    }

    // Implemented on `Arc<BlockingDriver>` so the test can keep a second
    // Arc to inspect `boot_calls` while the manager owns the driver.
    impl VmDriver for Arc<BlockingDriver> {
        fn boot(&self, config: &VmConfig) -> Result<VmHandle, VmError> {
            self.boot_calls.fetch_add(1, Ordering::SeqCst);
            // Take the receiver so only the FIRST boot blocks; later boots
            // find `None` and return immediately.
            if let Some(rx) = self.release_rx.lock().expect("release lock").take() {
                rx.recv().expect("release boot");
            }
            Ok(VmHandle {
                name: config.name.clone(),
                namespace: config.namespace.clone(),
                state: VmState::Starting,
                process: None,
                serial_log: config.serial_log.clone(),
                machine_id: None,
            })
        }

        fn stop(&self, _handle: &VmHandle) -> Result<(), VmError> {
            Ok(())
        }

        fn kill(&self, _handle: &VmHandle) -> Result<(), VmError> {
            Ok(())
        }

        fn state(&self, _handle: &VmHandle) -> Result<VmState, VmError> {
            Ok(VmState::Stopped)
        }
    }

    impl VmDriver for FailedStateDriver {
        fn boot(&self, config: &VmConfig) -> Result<VmHandle, VmError> {
            Ok(VmHandle {
                name: config.name.clone(),
                namespace: config.namespace.clone(),
                state: VmState::Starting,
                process: None,
                serial_log: config.serial_log.clone(),
                machine_id: None,
            })
        }

        fn stop(&self, _handle: &VmHandle) -> Result<(), VmError> {
            Ok(())
        }

        fn kill(&self, _handle: &VmHandle) -> Result<(), VmError> {
            Ok(())
        }

        fn state(&self, _handle: &VmHandle) -> Result<VmState, VmError> {
            Ok(VmState::Failed {
                reason: "crashed".into(),
            })
        }
    }

    impl VmDriver for ReadyAfterTwoPollsDriver {
        fn boot(&self, config: &VmConfig) -> Result<VmHandle, VmError> {
            Ok(VmHandle {
                name: config.name.clone(),
                namespace: config.namespace.clone(),
                state: VmState::Starting,
                process: None,
                serial_log: config.serial_log.clone(),
                machine_id: None,
            })
        }

        fn stop(&self, _handle: &VmHandle) -> Result<(), VmError> {
            Ok(())
        }

        fn kill(&self, _handle: &VmHandle) -> Result<(), VmError> {
            Ok(())
        }

        fn state(&self, _handle: &VmHandle) -> Result<VmState, VmError> {
            // First poll: Running (booted but not ready). After: Ready.
            let poll = self.polls.fetch_add(1, Ordering::SeqCst);
            if poll == 0 {
                Ok(VmState::Running)
            } else {
                Ok(VmState::Ready {
                    ip: "10.0.0.2".into(),
                })
            }
        }
    }

    // Minimal valid config with all paths rooted in the test's temp dir.
    fn test_config(base_dir: &Path) -> VmConfig {
        VmConfig {
            name: "test-vm".into(),
            namespace: "tests".into(),
            kernel: PathBuf::from("/tmp/kernel"),
            initramfs: None,
            root_disk: None,
            data_disk: None,
            seed_iso: None,
            cpus: 1,
            memory_mb: 256,
            networks: vec![],
            shared_dirs: vec![],
            serial_log: base_dir.join("serial.log"),
            cmdline: None,
            netns: None,
            vsock: false,
            machine_id: None,
            efi_variable_store: None,
            rosetta: false,
        }
    }

    // Verifies the anti-double-boot reservation: while the first start is
    // stuck inside the driver's `boot`, a second start for the same name
    // must fail with "already exists" and never reach the driver.
    #[test]
    fn start_reserves_name_before_driver_boot() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let (release_tx, release_rx) = mpsc::channel();
        let driver = Arc::new(BlockingDriver {
            boot_calls: AtomicUsize::new(0),
            release_rx: Mutex::new(Some(release_rx)),
        });
        let manager = Arc::new(
            VmManager::with_driver(Box::new(driver.clone()), tmp.path().to_path_buf())
                .expect("manager"),
        );
        let config = test_config(tmp.path());

        let manager_clone = Arc::clone(&manager);
        let config_clone = config.clone();
        let boot_thread = std::thread::spawn(move || manager_clone.start(&config_clone));

        // Spin until the first boot is definitely inside the driver.
        while driver.boot_calls.load(Ordering::SeqCst) == 0 {
            std::thread::sleep(std::time::Duration::from_millis(10));
        }

        let err = manager
            .start(&config)
            .expect_err("second concurrent start should fail");
        assert!(err.to_string().contains("already exists"));

        // Unblock the first boot and confirm it completed alone.
        release_tx.send(()).expect("release first boot");
        boot_thread
            .join()
            .expect("join")
            .expect("first boot should succeed");

        assert_eq!(driver.boot_calls.load(Ordering::SeqCst), 1);
    }

    // A name whose VM has been observed in the Failed state must be
    // reusable by a subsequent start.
    #[test]
    fn restart_is_allowed_after_failed_state_is_observed() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let manager = VmManager::with_driver(Box::new(FailedStateDriver), tmp.path().to_path_buf())
            .expect("manager");
        let config = test_config(tmp.path());

        manager.start(&config).expect("first boot");
        // `state` refreshes the cached entry to Failed, unlocking restart.
        let state = manager.state(&config.name).expect("state query");
        assert!(matches!(state, VmState::Failed { .. }));

        manager
            .start(&config)
            .expect("restart after failed state should be allowed");
    }

    // wait_all_ready must keep polling past Running (first poll) and only
    // return once the driver reports Ready (second poll, carrying the IP).
    #[test]
    fn wait_all_ready_waits_for_ready_not_just_running() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let manager = VmManager::with_driver(
            Box::new(ReadyAfterTwoPollsDriver {
                polls: AtomicUsize::new(0),
            }),
            tmp.path().to_path_buf(),
        )
        .expect("manager");
        let config = test_config(tmp.path());

        manager.start(&config).expect("boot");
        manager
            .wait_all_ready(2)
            .expect("wait_all_ready should wait until ready");
        assert_eq!(
            manager.get_ip(&config.name).expect("ip query"),
            Some("10.0.0.2".into())
        );
    }
}