1use std::process::Output;
14
15use crate::{
16 CapacityType, CloudProvisioner, JoinState, NodeHandle, NodeShape, PriceHint, ProviderNodeId,
17 ProvisionerError, Result,
18};
19
/// Capacity types returned by `CloudInitProvisioner::capacity_types`: on-demand only.
static SUPPORTED_CAPACITY: &[CapacityType] = &[CapacityType::OnDemand];
25
/// Settings for [`CloudInitProvisioner`]: the shell command templates it runs and the values
/// it substitutes into the cloud-init user-data handed to new nodes.
///
/// All `*_cmd` fields are executed through `sh -c`; substituted values are inserted verbatim,
/// so any quoting must be part of the template itself.
// NOTE(review): the derived `Debug` prints `join_token` verbatim — confirm that is acceptable
// wherever this config is logged.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct CloudInitConfig {
    /// Value substituted for `{leader_addr}` in `user_data_template`.
    pub leader_addr: String,
    /// Value substituted for `{join_token}` in `user_data_template`.
    pub join_token: String,
    /// Shell command template that creates a node. Recognised placeholders: `{user_data}`,
    /// `{cpu}`, `{memory_mb}`, `{labels}`, `{zone}` and `{capacity}`. On success the trimmed
    /// stdout of the command is used as the provider id of the new node.
    pub provision_cmd: String,
    /// Shell command template that destroys a node; `{provider_id}` is replaced with the id
    /// of the node to terminate.
    pub terminate_cmd: String,
    /// Optional shell command that lists nodes, one per line, as `provider_id[,address]`.
    /// When `None`, `describe` returns an empty list.
    pub describe_cmd: Option<String>,
    /// User-data template; `{leader_addr}`, `{join_token}` and `{labels}` are substituted
    /// before the result is passed to `provision_cmd` as `{user_data}`.
    pub user_data_template: String,
    /// Hourly price in USD reported by `price_hint`, independent of the requested shape.
    pub hourly_usd: f64,
}
57
/// [`CloudProvisioner`] implementation that provisions, terminates and lists nodes by running
/// the shell command templates held in a [`CloudInitConfig`].
#[derive(Clone, Debug)]
pub struct CloudInitProvisioner {
    /// Command templates and join parameters used by every operation.
    config: CloudInitConfig,
}
63
/// Renders a label map as a comma-separated list of `key=value` pairs.
///
/// Pairs are emitted in the map's iteration order (ascending by key); an empty map yields an
/// empty string. Keys and values are written verbatim, without any escaping.
fn format_labels(labels: &std::collections::BTreeMap<String, String>) -> String {
    let mut rendered = String::new();
    for (index, (key, value)) in labels.iter().enumerate() {
        // Separator goes between pairs only, never before the first one.
        if index > 0 {
            rendered.push(',');
        }
        rendered.push_str(key);
        rendered.push('=');
        rendered.push_str(value);
    }
    rendered
}
73
74fn capacity_token(capacity: CapacityType) -> &'static str {
76 match capacity {
77 CapacityType::OnDemand => "on-demand",
78 CapacityType::Spot => "spot",
79 }
80}
81
/// Fills the `{leader_addr}`, `{join_token}` and `{labels}` placeholders of a user-data
/// template.
///
/// The template is scanned once, left to right. Text that comes from a substituted value is
/// never scanned again, so a value that happens to contain placeholder syntax (for example a
/// label value of `{join_token}`) is emitted verbatim instead of being expanded. Braces that
/// do not start a known placeholder (such as a shell `${HOME}`) are copied through unchanged.
///
/// # Parameters
/// - `template`: user-data text containing zero or more placeholders.
/// - `leader_addr`, `join_token`, `labels`: replacement text for the matching placeholder.
///
/// # Returns
/// The rendered user-data.
fn substitute_user_data(
    template: &str,
    leader_addr: &str,
    join_token: &str,
    labels: &str,
) -> String {
    let substitutions = [
        ("{leader_addr}", leader_addr),
        ("{join_token}", join_token),
        ("{labels}", labels),
    ];

    let mut rendered = String::with_capacity(template.len());
    let mut rest = template;
    while let Some(open) = rest.find('{') {
        let (literal, candidate) = rest.split_at(open);
        rendered.push_str(literal);
        match substitutions
            .iter()
            .find(|(placeholder, _)| candidate.starts_with(*placeholder))
        {
            Some((placeholder, value)) => {
                rendered.push_str(value);
                rest = &candidate[placeholder.len()..];
            }
            None => {
                // Not a known placeholder: keep the brace and resume right after it
                // ('{' is a single byte, so the slice stays on a char boundary).
                rendered.push('{');
                rest = &candidate[1..];
            }
        }
    }
    rendered.push_str(rest);
    rendered
}
97
98fn substitute_provision_cmd(template: &str, user_data: &str, shape: &NodeShape) -> String {
103 let memory_mb = shape.memory_bytes / (1024 * 1024);
105 template
106 .replace("{user_data}", user_data)
107 .replace("{cpu}", &shape.cpu.to_string())
108 .replace("{memory_mb}", &memory_mb.to_string())
109 .replace("{labels}", &format_labels(&shape.labels))
110 .replace("{zone}", shape.zone.as_deref().unwrap_or(""))
111 .replace("{capacity}", capacity_token(shape.capacity_type))
112}
113
/// Replaces every `{provider_id}` occurrence in `template` with `provider_id`.
///
/// The id is inserted verbatim; a template without the placeholder is returned unchanged.
fn substitute_provider_id(template: &str, provider_id: &str) -> String {
    // Cutting the template at each placeholder and re-joining the pieces with the id is
    // equivalent to a plain search-and-replace for a non-empty pattern.
    let pieces: Vec<&str> = template.split("{provider_id}").collect();
    pieces.join(provider_id)
}
118
119async fn run_shell(cmd: &str) -> Result<Output> {
121 tokio::process::Command::new("sh")
122 .arg("-c")
123 .arg(cmd)
124 .output()
125 .await
126 .map_err(|e| ProvisionerError::Transport(e.to_string()))
127}
128
129fn parse_describe_output(stdout: &str) -> Vec<NodeHandle> {
131 stdout
132 .lines()
133 .map(str::trim)
134 .filter(|line| !line.is_empty())
135 .map(|line| {
136 let mut parts = line.splitn(2, ',');
137 let provider_id = parts.next().unwrap_or("").trim().to_string();
138 let address = parts
139 .next()
140 .map(str::trim)
141 .filter(|a| !a.is_empty())
142 .map(ToString::to_string);
143 NodeHandle {
144 provider_id,
145 address,
146 zone: None,
147 capacity_type: CapacityType::OnDemand,
148 join_state: JoinState::Joined,
149 }
150 })
151 .collect()
152}
153
/// Builds the default cloud-init user-data template.
///
/// The result is a `#cloud-config` document with a single `runcmd` entry that runs
/// `zlayer node join`. The `{leader_addr}`, `{join_token}` and `{labels}` placeholders are
/// left in the text as literals (the doubled braces below are `format!` escapes) so they can
/// be substituted later; only the advertise-address expression is filled in here.
#[must_use]
pub fn default_user_data_template() -> String {
    // Shell command substitution that is embedded as text, not executed here.
    // NOTE(review): 169.254.169.254 looks like the EC2-style instance metadata endpoint and
    // the request carries no session token — confirm this works on the targeted clouds.
    let advertise = "$(curl -s http://169.254.169.254/latest/meta-data/public-ipv4)";
    format!(
        "#cloud-config\n\
        runcmd:\n \
        - zlayer node join {{leader_addr}} --token {{join_token}} --advertise-addr {advertise} --mode full --labels {{labels}} --no-ingress\n"
    )
}
173
174impl CloudInitProvisioner {
175 #[must_use]
177 pub fn new(config: CloudInitConfig) -> Self {
178 Self { config }
179 }
180
181 #[must_use]
187 pub fn render_user_data(&self, shape: &NodeShape) -> String {
188 substitute_user_data(
189 &self.config.user_data_template,
190 &self.config.leader_addr,
191 &self.config.join_token,
192 &format_labels(&shape.labels),
193 )
194 }
195}
196
197#[async_trait::async_trait]
198impl CloudProvisioner for CloudInitProvisioner {
199 async fn provision(&self, shape: &NodeShape) -> Result<NodeHandle> {
200 let user_data = self.render_user_data(shape);
201 let cmd = substitute_provision_cmd(&self.config.provision_cmd, &user_data, shape);
202
203 tracing::info!(
204 provisioner = "cloud-init",
205 cpu = shape.cpu,
206 memory_bytes = shape.memory_bytes,
207 zone = shape.zone.as_deref().unwrap_or(""),
208 "provisioning node"
209 );
210
211 let output = run_shell(&cmd).await?;
212 if !output.status.success() {
213 let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
214 return Err(ProvisionerError::Capacity(stderr));
215 }
216
217 let provider_id = String::from_utf8_lossy(&output.stdout).trim().to_string();
218 tracing::info!(provisioner = "cloud-init", %provider_id, "provisioned node");
219
220 Ok(NodeHandle {
221 provider_id,
222 address: None,
223 zone: shape.zone.clone(),
224 capacity_type: shape.capacity_type,
225 join_state: JoinState::Provisioning,
226 })
227 }
228
229 #[allow(clippy::ptr_arg)]
230 async fn terminate(&self, id: &ProviderNodeId) -> Result<()> {
231 let cmd = substitute_provider_id(&self.config.terminate_cmd, id);
232 tracing::info!(provisioner = "cloud-init", provider_id = %id, "terminating node");
233
234 let output = run_shell(&cmd).await?;
235 if !output.status.success() {
236 let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
237 return Err(ProvisionerError::Transport(stderr));
238 }
239 Ok(())
240 }
241
242 async fn describe(&self) -> Result<Vec<NodeHandle>> {
243 let Some(cmd) = self.config.describe_cmd.as_deref() else {
244 return Ok(Vec::new());
245 };
246
247 let output = run_shell(cmd).await?;
248 if !output.status.success() {
249 let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
250 return Err(ProvisionerError::Transport(stderr));
251 }
252
253 let stdout = String::from_utf8_lossy(&output.stdout);
254 Ok(parse_describe_output(&stdout))
255 }
256
257 fn capacity_types(&self) -> &[CapacityType] {
258 SUPPORTED_CAPACITY
259 }
260
261 fn price_hint(&self, shape: &NodeShape) -> Option<PriceHint> {
262 Some(PriceHint {
263 hourly_usd: self.config.hourly_usd,
264 capacity_type: shape.capacity_type,
265 })
266 }
267
268 fn name(&self) -> &'static str {
269 "cloud-init"
270 }
271}
272
// Unit tests for the pure helpers and the metadata methods. Nothing here runs a shell
// command: `provision`, `terminate` and `describe` are not exercised.
#[cfg(test)]
mod tests {
    use super::{
        capacity_token, default_user_data_template, format_labels, parse_describe_output,
        substitute_provider_id, substitute_provision_cmd, substitute_user_data, CloudInitConfig,
        CloudInitProvisioner,
    };
    use crate::{CapacityType, CloudProvisioner, JoinState, NodeShape};

    /// Config with harmless `echo` commands, the default template and no describe command.
    fn sample_config() -> CloudInitConfig {
        CloudInitConfig {
            leader_addr: "10.0.0.1:3669".to_string(),
            join_token: "tok-abc".to_string(),
            provision_cmd: "echo i-123".to_string(),
            terminate_cmd: "echo gone {provider_id}".to_string(),
            describe_cmd: None,
            user_data_template: default_user_data_template(),
            hourly_usd: 0.10,
        }
    }

    /// The default template carries the join command line and the metadata lookup.
    #[test]
    fn default_template_contains_join_line() {
        let tpl = default_user_data_template();
        assert!(tpl.contains("zlayer node join"));
        assert!(tpl.contains("#cloud-config"));
        assert!(tpl.contains("--mode full"));
        assert!(tpl.contains("--no-ingress"));
        assert!(tpl.contains("169.254.169.254"));
    }

    /// Rendering fills all three placeholders and leaves none behind.
    #[test]
    fn render_user_data_substitutes_placeholders() {
        let provisioner = CloudInitProvisioner::new(sample_config());
        let mut shape = NodeShape::new(2.0, 4 * 1024 * 1024 * 1024);
        shape
            .labels
            .insert("role".to_string(), "worker".to_string());

        let rendered = provisioner.render_user_data(&shape);
        assert!(rendered.contains("10.0.0.1:3669"));
        assert!(rendered.contains("tok-abc"));
        assert!(rendered.contains("role=worker"));
        assert!(!rendered.contains("{leader_addr}"));
        assert!(!rendered.contains("{join_token}"));
        assert!(!rendered.contains("{labels}"));
    }

    /// Exact output for a template that uses every user-data placeholder once.
    #[test]
    fn substitute_user_data_replaces_all() {
        let out = substitute_user_data(
            "join {leader_addr} tok {join_token} lbl {labels}",
            "host:1",
            "secret",
            "a=b",
        );
        assert_eq!(out, "join host:1 tok secret lbl a=b");
    }

    /// Labels come out ordered by key regardless of insertion order; empty map is "".
    #[test]
    fn format_labels_is_sorted_and_joined() {
        let mut labels = std::collections::BTreeMap::new();
        labels.insert("z".to_string(), "1".to_string());
        labels.insert("a".to_string(), "2".to_string());
        assert_eq!(format_labels(&labels), "a=2,z=1");
        assert_eq!(format_labels(&std::collections::BTreeMap::new()), "");
    }

    /// Both capacity variants map to their command-line token.
    #[test]
    fn capacity_token_maps_variants() {
        assert_eq!(capacity_token(CapacityType::OnDemand), "on-demand");
        assert_eq!(capacity_token(CapacityType::Spot), "spot");
    }

    /// Every provision placeholder is filled; memory is rendered in whole mebibytes.
    #[test]
    fn substitute_provision_cmd_fills_shape_fields() {
        let mut shape = NodeShape::new(2.0, 2048 * 1024 * 1024);
        shape.zone = Some("us-east-1a".to_string());
        shape.capacity_type = CapacityType::Spot;
        shape.labels.insert("k".to_string(), "v".to_string());

        let cmd = substitute_provision_cmd(
            "run --ud '{user_data}' --cpu {cpu} --mem {memory_mb} --labels {labels} --zone {zone} --cap {capacity}",
            "USERDATA",
            &shape,
        );
        assert!(cmd.contains("--ud 'USERDATA'"));
        assert!(cmd.contains("--cpu 2"));
        assert!(cmd.contains("--mem 2048"));
        assert!(cmd.contains("--labels k=v"));
        assert!(cmd.contains("--zone us-east-1a"));
        assert!(cmd.contains("--cap spot"));
        // No placeholder of any kind may survive.
        assert!(!cmd.contains('{'));
    }

    /// A shape without a zone renders `{zone}` as an empty string.
    #[test]
    fn substitute_provision_cmd_empty_zone() {
        let shape = NodeShape::new(1.0, 1024 * 1024 * 1024);
        let cmd = substitute_provision_cmd("z=[{zone}]", "ud", &shape);
        assert_eq!(cmd, "z=[]");
    }

    /// The terminate template gets the provider id inserted in place.
    #[test]
    fn substitute_provider_id_replaces() {
        assert_eq!(
            substitute_provider_id("delete {provider_id} now", "i-9"),
            "delete i-9 now"
        );
    }

    /// Covers id+address, id only (with padding), a blank line, and padded fields.
    #[test]
    fn parse_describe_output_handles_id_and_address() {
        let out = parse_describe_output("i-1,10.0.0.1\n i-2 \n\ni-3, 10.0.0.3 \n");
        assert_eq!(out.len(), 3);

        assert_eq!(out[0].provider_id, "i-1");
        assert_eq!(out[0].address.as_deref(), Some("10.0.0.1"));
        assert_eq!(out[0].join_state, JoinState::Joined);

        assert_eq!(out[1].provider_id, "i-2");
        assert!(out[1].address.is_none());

        assert_eq!(out[2].provider_id, "i-3");
        assert_eq!(out[2].address.as_deref(), Some("10.0.0.3"));
    }

    /// Whitespace-only output yields no handles.
    #[test]
    fn parse_describe_output_empty() {
        assert!(parse_describe_output("\n \n").is_empty());
    }

    /// Name, advertised capacity types and price hint reflect the static configuration.
    #[test]
    fn metadata_methods() {
        let provisioner = CloudInitProvisioner::new(sample_config());
        assert_eq!(provisioner.name(), "cloud-init");
        assert_eq!(provisioner.capacity_types(), &[CapacityType::OnDemand]);

        let shape = NodeShape::new(1.0, 1024 * 1024 * 1024);
        let hint = provisioner.price_hint(&shape).expect("price hint");
        assert!((hint.hourly_usd - 0.10).abs() < f64::EPSILON);
        assert_eq!(hint.capacity_type, CapacityType::OnDemand);
    }
}