1use std::{path::Path, sync::Arc};
15
16use thiserror::Error;
17
18use crate::{
19 action::{ApiAction, ApiResponse},
20 controller::RuntimeApiController,
21 error::ApiError,
22 schemas::{
23 BalloonConfig, BootSourceConfig, ConfigFile, CpuConfig, DriveConfig, EntropyConfig,
24 HotplugMemoryConfig, InstanceAction, LoggerConfig, MachineConfig, MetricsConfig,
25 MmdsConfig, MmdsContents, NetworkInterfaceConfig, PmemConfig, SerialConfig, VsockConfig,
26 common::{MAX_DRIVES, MAX_NICS, MAX_PMEM},
27 },
28};
29
/// Errors produced while loading a `--config-file` and replaying its
/// sections through the runtime API controller.
#[derive(Debug, Error)]
pub enum ReplayError {
    /// The config file could not be read from disk.
    #[error("failed to read --config-file {path}: {source}")]
    Io {
        /// Display form of the path that was being read.
        path: String,
        /// Underlying I/O error reported by the read.
        #[source]
        source: std::io::Error,
    },
    /// The config file was read but is not valid JSON for `ConfigFile`.
    #[error("failed to parse --config-file {path}: {source}")]
    Parse {
        /// Display form of the path that was being parsed.
        path: String,
        /// Underlying `serde_json` deserialization error.
        #[source]
        source: serde_json::Error,
    },
    /// The parsed config has no `boot-source` section; replay refuses to
    /// dispatch anything in that case.
    #[error("--config-file is missing the required `boot-source` member")]
    MissingBootSource,
    /// One of the list-valued sections holds more entries than its cap.
    #[error("--config-file collection {field} exceeds {max} entries")]
    CollectionCap {
        /// Name of the offending collection as it appears in the config file
        /// (`drives`, `network-interfaces` or `pmem`).
        field: &'static str,
        /// Maximum number of entries accepted for that collection.
        max: usize,
    },
    /// Converting a raw section into its validated form failed; the payload
    /// is the message returned by the section's `TryFrom` conversion.
    #[error("--config-file rejected: {0}")]
    Validation(String),
    /// The controller returned an error instead of a response for an action.
    #[error("--config-file action `{label}` failed: {source}")]
    Dispatch {
        /// Label of the action that was being dispatched.
        label: &'static str,
        /// Error returned by the controller.
        #[source]
        source: ApiError,
    },
    /// The controller answered an action with a fault response.
    #[error("--config-file action `{label}` rejected by VMM (status {status}): {fault_message}")]
    Fault {
        /// Label of the action that was rejected.
        label: &'static str,
        /// Status code carried by the fault response.
        status: u16,
        /// Message carried by the fault response.
        fault_message: String,
    },
}
85
86pub async fn parse_config_file(path: impl AsRef<Path>) -> Result<ConfigFile, ReplayError> {
89 let path = path.as_ref();
90 let bytes = tokio::fs::read(path).await.map_err(|e| ReplayError::Io {
91 path: path.display().to_string(),
92 source: e,
93 })?;
94 let cfg: ConfigFile = serde_json::from_slice(&bytes).map_err(|e| ReplayError::Parse {
95 path: path.display().to_string(),
96 source: e,
97 })?;
98 Ok(cfg)
99}
100
101pub async fn replay_config(
105 controller: &Arc<RuntimeApiController>,
106 cfg: ConfigFile,
107 start_microvm: bool,
108) -> Result<(), ReplayError> {
109 if cfg.boot_source.is_none() {
110 return Err(ReplayError::MissingBootSource);
111 }
112
113 if cfg.drives.len() > MAX_DRIVES {
114 return Err(ReplayError::CollectionCap {
115 field: "drives",
116 max: MAX_DRIVES,
117 });
118 }
119 if cfg.network_interfaces.len() > MAX_NICS {
120 return Err(ReplayError::CollectionCap {
121 field: "network-interfaces",
122 max: MAX_NICS,
123 });
124 }
125 if cfg.pmem.len() > MAX_PMEM {
126 return Err(ReplayError::CollectionCap {
127 field: "pmem",
128 max: MAX_PMEM,
129 });
130 }
131
132 if let Some(raw) = cfg.machine_config {
134 let validated = MachineConfig::try_from(raw).map_err(ReplayError::Validation)?;
135 dispatch(controller, ApiAction::PutMachineConfig(validated)).await?;
136 }
137 if let Some(raw) = cfg.cpu_config {
139 let validated = CpuConfig::try_from(raw).map_err(ReplayError::Validation)?;
140 dispatch(controller, ApiAction::PutCpuConfig(validated)).await?;
141 }
142 if let Some(raw) = cfg.boot_source {
144 let validated = BootSourceConfig::try_from(raw).map_err(ReplayError::Validation)?;
145 dispatch(controller, ApiAction::PutBootSource(validated)).await?;
146 }
147 for raw in cfg.drives {
149 let validated = DriveConfig::try_from(raw).map_err(ReplayError::Validation)?;
150 dispatch(controller, ApiAction::PutDrive(validated)).await?;
151 }
152 for raw in cfg.network_interfaces {
154 let validated = NetworkInterfaceConfig::try_from(raw).map_err(ReplayError::Validation)?;
155 dispatch(controller, ApiAction::PutNetwork(validated)).await?;
156 }
157 if let Some(raw) = cfg.vsock {
159 let validated = VsockConfig::try_from(raw).map_err(ReplayError::Validation)?;
160 dispatch(controller, ApiAction::PutVsock(validated)).await?;
161 }
162 if let Some(raw) = cfg.mmds_config {
164 let validated = MmdsConfig::try_from(raw).map_err(ReplayError::Validation)?;
165 dispatch(controller, ApiAction::PutMmdsConfig(validated)).await?;
166 }
167 if let Some(value) = cfg.mmds {
169 dispatch(controller, ApiAction::PutMmds(MmdsContents::new(value))).await?;
170 }
171 if let Some(raw) = cfg.balloon {
173 let validated = BalloonConfig::try_from(raw).map_err(ReplayError::Validation)?;
174 dispatch(controller, ApiAction::PutBalloon(validated)).await?;
175 }
176 if let Some(raw) = cfg.entropy {
178 let validated = EntropyConfig::try_from(raw).map_err(ReplayError::Validation)?;
179 dispatch(controller, ApiAction::PutEntropy(validated)).await?;
180 }
181 if let Some(raw) = cfg.serial {
183 let validated = SerialConfig::try_from(raw).map_err(ReplayError::Validation)?;
184 dispatch(controller, ApiAction::PutSerial(validated)).await?;
185 }
186 for raw in cfg.pmem {
188 let validated = PmemConfig::try_from(raw).map_err(ReplayError::Validation)?;
189 dispatch(controller, ApiAction::PutPmem(validated)).await?;
190 }
191 if let Some(raw) = cfg.hotplug_memory {
193 let validated = HotplugMemoryConfig::try_from(raw).map_err(ReplayError::Validation)?;
194 dispatch(controller, ApiAction::PutHotplugMemory(validated)).await?;
195 }
196 if let Some(raw) = cfg.logger {
198 let validated = LoggerConfig::try_from(raw).map_err(ReplayError::Validation)?;
199 dispatch(controller, ApiAction::PutLogger(validated)).await?;
200 }
201 if let Some(raw) = cfg.metrics {
203 let validated = MetricsConfig::try_from(raw).map_err(ReplayError::Validation)?;
204 dispatch(controller, ApiAction::PutMetrics(validated)).await?;
205 }
206 if start_microvm {
208 dispatch(controller, ApiAction::Action(InstanceAction::InstanceStart)).await?;
209 }
210 Ok(())
211}
212
213async fn dispatch(
214 controller: &Arc<RuntimeApiController>,
215 action: ApiAction,
216) -> Result<(), ReplayError> {
217 let label = action.label();
218 let resp = controller
219 .dispatch(action)
220 .await
221 .map_err(|e| ReplayError::Dispatch { label, source: e })?;
222 match resp {
223 ApiResponse::NoContent | ApiResponse::Json(_) => Ok(()),
224 ApiResponse::Fault {
225 status,
226 fault_message,
227 } => Err(ReplayError::Fault {
228 label,
229 status,
230 fault_message,
231 }),
232 }
233}
234
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::task::JoinHandle;

    use super::*;
    use crate::controller::{ActionReceiver, ControllerSnapshot, TimeoutTable};

    /// Builds a controller with a 64-slot action queue plus the receiving
    /// end that a test can drain.
    fn build_controller() -> (Arc<RuntimeApiController>, ActionReceiver) {
        let (controller, actions) = RuntimeApiController::new(
            ControllerSnapshot::new("anonymous", "1.16.0", "1.16.0 (squib test)"),
            TimeoutTable::from_spec(),
            64,
        );
        (Arc::new(controller), actions)
    }

    /// Spawns a task that acknowledges every received action with
    /// `NoContent` and returns the labels it saw, in arrival order, once
    /// `recv` yields `None`.
    fn drain_acker(mut rx: ActionReceiver) -> JoinHandle<Vec<&'static str>> {
        tokio::spawn(async move {
            let mut seen = Vec::new();
            loop {
                let Some((action, ack)) = rx.recv().await else {
                    break seen;
                };
                seen.push(action.label());
                // The requester may already be gone; that is fine here.
                let _ = ack.send(ApiResponse::NoContent);
            }
        })
    }

    /// Waits (at most 500 ms) for the drain task and returns its labels.
    async fn finish(drain: JoinHandle<Vec<&'static str>>) -> Vec<&'static str> {
        tokio::time::timeout(Duration::from_millis(500), drain)
            .await
            .unwrap()
            .unwrap()
    }

    #[tokio::test]
    async fn test_should_reject_config_without_boot_source() {
        let (controller, _rx) = build_controller();
        let outcome = replay_config(&controller, ConfigFile::default(), false).await;
        assert!(matches!(outcome, Err(ReplayError::MissingBootSource)));
    }

    #[tokio::test]
    async fn test_should_replay_minimal_config_in_order() {
        let (controller, rx) = build_controller();
        let drain = drain_acker(rx);

        let cfg: ConfigFile = serde_json::from_str(
            r#"{
                "boot-source": {"kernel_image_path":"/tmp/k"},
                "machine-config": {"vcpu_count":1,"mem_size_mib":256}
            }"#,
        )
        .unwrap();
        replay_config(&controller, cfg, false).await.unwrap();

        // Presumably releases the sending side so the drain task can finish.
        drop(controller);
        let labels = finish(drain).await;

        // machine-config is listed second in the JSON but must go out first.
        assert_eq!(&labels[..2], ["PUT /machine-config", "PUT /boot-source"]);
    }

    #[tokio::test]
    async fn test_should_dispatch_instance_start_when_requested() {
        let (controller, rx) = build_controller();
        let drain = drain_acker(rx);

        let cfg: ConfigFile = serde_json::from_str(
            r#"{"boot-source":{"kernel_image_path":"/tmp/k"},
                "machine-config":{"vcpu_count":1,"mem_size_mib":256}}"#,
        )
        .unwrap();
        replay_config(&controller, cfg, true).await.unwrap();

        // Presumably releases the sending side so the drain task can finish.
        drop(controller);
        let labels = finish(drain).await;

        assert_eq!(labels.last().copied(), Some("PUT /actions"));
    }

    #[tokio::test]
    async fn test_should_reject_drives_over_cap() {
        use crate::schemas::boot_source::RawBootSourceConfig;
        use crate::schemas::drive::{CacheType, IoEngine, RawDriveConfig};

        let (controller, _rx) = build_controller();

        let mut cfg = ConfigFile {
            boot_source: Some(RawBootSourceConfig {
                kernel_image_path: "/tmp/k".into(),
                initrd_path: None,
                boot_args: None,
            }),
            ..ConfigFile::default()
        };
        // Inclusive range: exactly one drive more than the cap allows.
        cfg.drives.extend((0..=MAX_DRIVES).map(|i| RawDriveConfig {
            drive_id: format!("d_{i}"),
            path_on_host: "/tmp/x".into(),
            is_root_device: false,
            is_read_only: false,
            cache_type: CacheType::Unsafe,
            io_engine: IoEngine::Sync,
            partuuid: None,
            rate_limiter: None,
            socket: None,
        }));

        let outcome = replay_config(&controller, cfg, false).await;
        assert!(matches!(
            outcome,
            Err(ReplayError::CollectionCap {
                field: "drives",
                ..
            })
        ));
    }
}