Skip to main content

zlayer_provisioner/
cloud_init.rs

1//! Provider-agnostic cloud-init reference provisioner.
2//!
3//! [`CloudInitProvisioner`] shells out to operator-supplied commands rather than
4//! linking any cloud SDK. The operator provides a `provision_cmd` (and matching
5//! `terminate_cmd` / optional `describe_cmd`) as shell templates; the
6//! provisioner renders a cloud-init `#cloud-config` user-data document and
7//! substitutes it — along with the requested shape — into those templates before
8//! running them via `sh -c`.
9//!
10//! The generated user-data runs `zlayer node join` on first boot so the freshly
11//! created machine registers itself with the cluster leader.
12
13use std::process::Output;
14
15use crate::{
16    CapacityType, CloudProvisioner, JoinState, NodeHandle, NodeShape, PriceHint, ProviderNodeId,
17    ProvisionerError, Result,
18};
19
/// Capacity types supported by the cloud-init provisioner.
///
/// The reference provisioner makes no assumptions about spot/preemptible
/// markets, so it advertises on-demand only. This slice is what
/// [`CloudProvisioner::capacity_types`] returns.
///
/// NOTE(review): `provision` does not check a shape's capacity type against
/// this list; a `Spot` shape is still provisioned, with `{capacity}` rendered
/// as `spot`.
static SUPPORTED_CAPACITY: &[CapacityType] = &[CapacityType::OnDemand];
25
/// Configuration for [`CloudInitProvisioner`].
///
/// `provision_cmd` and `terminate_cmd` are shell command templates. The
/// following placeholders are substituted before the command is run:
///
/// - `{user_data}` — the rendered cloud-init document (provision only)
/// - `{provider_id}` — the node's provider id (terminate only)
/// - `{cpu}` — requested vCPU count, formatted with `to_string()` so `2.0`
///   renders as `2` (provision only)
/// - `{memory_mb}` — requested memory in whole mebibytes, rounded down
///   (provision only)
/// - `{labels}` — comma-separated `k=v` labels, sorted by key (provision only)
/// - `{zone}` — requested zone, or empty (provision only)
/// - `{capacity}` — `on-demand` or `spot` (provision only)
///
/// Values are inserted verbatim, without any shell quoting or escaping, so the
/// template is responsible for quoting them. This matters most for
/// `{user_data}`, which spans several lines and (with the default template)
/// contains a `$(...)` command substitution — e.g. wrap it in single quotes,
/// provided the document itself contains none.
///
/// NOTE(review): label values and the provider id printed by `provision_cmd`
/// also reach `sh -c` unescaped; only use this with trusted inputs.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct CloudInitConfig {
    /// Address of the cluster leader new nodes should join; substituted for
    /// `{leader_addr}` in [`user_data_template`](Self::user_data_template).
    pub leader_addr: String,
    /// Bootstrap join token embedded in the user-data as `{join_token}`.
    ///
    /// NOTE(review): this secret is printed by the derived `Debug` impl and,
    /// via `{user_data}`, ends up on the provision command line — consider
    /// redacting it / passing it another way.
    pub join_token: String,
    /// Shell template that creates a node and prints its provider id to stdout.
    ///
    /// The trimmed stdout is used as the provider id; a non-zero exit is
    /// reported as a capacity error.
    pub provision_cmd: String,
    /// Shell template that terminates a node by `{provider_id}`. A non-zero
    /// exit is reported as a transport error.
    pub terminate_cmd: String,
    /// Optional shell template that lists nodes (see
    /// [`CloudProvisioner::describe`]).
    ///
    /// It must print one `provider_id[,address]` entry per line. When `None`,
    /// `describe` reports no nodes.
    pub describe_cmd: Option<String>,
    /// cloud-init user-data template (see
    /// [`CloudInitProvisioner::render_user_data`]).
    ///
    /// Recognises `{leader_addr}`, `{join_token}` and `{labels}`;
    /// [`default_user_data_template`] is a ready-made starting point.
    pub user_data_template: String,
    /// Flat hourly price used for [`CloudProvisioner::price_hint`]; reported
    /// for every shape regardless of its size.
    pub hourly_usd: f64,
}
57
/// A [`CloudProvisioner`] that drives operator-supplied shell commands.
///
/// All behaviour comes from the wrapped [`CloudInitConfig`]: user-data is
/// rendered from `user_data_template`, and `provision_cmd`, `terminate_cmd`
/// and `describe_cmd` are run through `sh -c`.
///
/// NOTE(review): the derived `Debug` prints the config, including
/// `join_token`, in clear text — consider a redacting `Debug` impl.
#[derive(Clone, Debug)]
pub struct CloudInitProvisioner {
    /// Templates, join credentials and pricing supplied by the operator.
    config: CloudInitConfig,
}
63
/// Join a label map into `key=value` pairs separated by commas.
///
/// `BTreeMap` iterates in ascending key order, so the output is deterministic
/// (e.g. `a=2,z=1`); an empty map yields an empty string. Keys and values are
/// copied as-is, so a `,` or `=` inside one makes the result ambiguous.
fn format_labels(labels: &std::collections::BTreeMap<String, String>) -> String {
    let mut rendered = String::new();
    for (position, (key, value)) in labels.iter().enumerate() {
        if position > 0 {
            rendered.push(',');
        }
        rendered.push_str(key);
        rendered.push('=');
        rendered.push_str(value);
    }
    rendered
}
73
74/// `capacity_type` as the kebab-case token used in templates.
75fn capacity_token(capacity: CapacityType) -> &'static str {
76    match capacity {
77        CapacityType::OnDemand => "on-demand",
78        CapacityType::Spot => "spot",
79    }
80}
81
/// Fill `{leader_addr}`, `{join_token}` and `{labels}` in a user-data template.
///
/// Every occurrence of each placeholder is replaced, one placeholder at a time
/// in that order. Kept separate from
/// [`CloudInitProvisioner::render_user_data`] so it can be unit tested without
/// constructing a provisioner.
fn substitute_user_data(
    template: &str,
    leader_addr: &str,
    join_token: &str,
    labels: &str,
) -> String {
    // Order matters: each pass runs over the output of the previous one.
    let substitutions = [
        ("{leader_addr}", leader_addr),
        ("{join_token}", join_token),
        ("{labels}", labels),
    ];
    substitutions
        .iter()
        .fold(template.to_owned(), |rendered, &(placeholder, value)| {
            rendered.replace(placeholder, value)
        })
}
97
98/// Substitute the provision placeholders into a command template.
99///
100/// Factored out of [`CloudInitProvisioner::provision`] so the substitution is
101/// unit-testable without spawning a shell.
102fn substitute_provision_cmd(template: &str, user_data: &str, shape: &NodeShape) -> String {
103    // Memory in MiB, saturating the conversion for absurdly large requests.
104    let memory_mb = shape.memory_bytes / (1024 * 1024);
105    template
106        .replace("{user_data}", user_data)
107        .replace("{cpu}", &shape.cpu.to_string())
108        .replace("{memory_mb}", &memory_mb.to_string())
109        .replace("{labels}", &format_labels(&shape.labels))
110        .replace("{zone}", shape.zone.as_deref().unwrap_or(""))
111        .replace("{capacity}", capacity_token(shape.capacity_type))
112}
113
/// Replace every `{provider_id}` in `template` with `provider_id`.
///
/// The id is inserted verbatim, with no shell quoting.
fn substitute_provider_id(template: &str, provider_id: &str) -> String {
    template
        .split("{provider_id}")
        .collect::<Vec<_>>()
        .join(provider_id)
}
118
119/// Run a shell command template via `sh -c` and return its captured output.
120async fn run_shell(cmd: &str) -> Result<Output> {
121    tokio::process::Command::new("sh")
122        .arg("-c")
123        .arg(cmd)
124        .output()
125        .await
126        .map_err(|e| ProvisionerError::Transport(e.to_string()))
127}
128
129/// Parse `describe_cmd` stdout (`provider_id[,address]` per line) into handles.
130fn parse_describe_output(stdout: &str) -> Vec<NodeHandle> {
131    stdout
132        .lines()
133        .map(str::trim)
134        .filter(|line| !line.is_empty())
135        .map(|line| {
136            let mut parts = line.splitn(2, ',');
137            let provider_id = parts.next().unwrap_or("").trim().to_string();
138            let address = parts
139                .next()
140                .map(str::trim)
141                .filter(|a| !a.is_empty())
142                .map(ToString::to_string);
143            NodeHandle {
144                provider_id,
145                address,
146                zone: None,
147                capacity_type: CapacityType::OnDemand,
148                join_state: JoinState::Joined,
149            }
150        })
151        .collect()
152}
153
/// Default cloud-init `#cloud-config` user-data template.
///
/// Produces a single `runcmd:` entry that runs `zlayer node join` against
/// `{leader_addr}` with `--token {join_token}`, `--mode full`,
/// `--labels {labels}` and `--no-ingress`. The address to advertise is looked
/// up on the node at boot from the `EC2`-style instance metadata endpoint
/// (public `IPv4`). Operators may supply their own template via
/// [`CloudInitConfig::user_data_template`]; this is only a sensible default.
///
/// NOTE(review): for a shape without labels, `{labels}` renders as an empty
/// string, leaving `--labels  --no-ingress` on the command line — confirm the
/// CLI does not read `--no-ingress` as the label value. The metadata lookup
/// is a plain unauthenticated `curl -s`; on metadata services that demand a
/// session token it presumably yields no usable address — verify per cloud.
#[must_use]
pub fn default_user_data_template() -> String {
    // Left for the shell on the node to evaluate; `render_user_data` only
    // touches the `{...}` placeholders.
    let advertise_addr = "$(curl -s http://169.254.169.254/latest/meta-data/public-ipv4)";
    let join_command = [
        "zlayer node join {leader_addr}",
        "--token {join_token}",
        "--advertise-addr",
        advertise_addr,
        "--mode full",
        "--labels {labels}",
        "--no-ingress",
    ]
    .join(" ");

    let mut document = String::from("#cloud-config\nruncmd:\n");
    document.push_str("  - ");
    document.push_str(&join_command);
    document.push('\n');
    document
}
173
174impl CloudInitProvisioner {
175    /// Create a new provisioner from `config`.
176    #[must_use]
177    pub fn new(config: CloudInitConfig) -> Self {
178        Self { config }
179    }
180
181    /// Render the cloud-init user-data document for `shape`.
182    ///
183    /// Substitutes `{leader_addr}`, `{join_token}`, and `{labels}` (the latter
184    /// from `shape.labels`) into the configured
185    /// [`user_data_template`](CloudInitConfig::user_data_template).
186    #[must_use]
187    pub fn render_user_data(&self, shape: &NodeShape) -> String {
188        substitute_user_data(
189            &self.config.user_data_template,
190            &self.config.leader_addr,
191            &self.config.join_token,
192            &format_labels(&shape.labels),
193        )
194    }
195}
196
197#[async_trait::async_trait]
198impl CloudProvisioner for CloudInitProvisioner {
199    async fn provision(&self, shape: &NodeShape) -> Result<NodeHandle> {
200        let user_data = self.render_user_data(shape);
201        let cmd = substitute_provision_cmd(&self.config.provision_cmd, &user_data, shape);
202
203        tracing::info!(
204            provisioner = "cloud-init",
205            cpu = shape.cpu,
206            memory_bytes = shape.memory_bytes,
207            zone = shape.zone.as_deref().unwrap_or(""),
208            "provisioning node"
209        );
210
211        let output = run_shell(&cmd).await?;
212        if !output.status.success() {
213            let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
214            return Err(ProvisionerError::Capacity(stderr));
215        }
216
217        let provider_id = String::from_utf8_lossy(&output.stdout).trim().to_string();
218        tracing::info!(provisioner = "cloud-init", %provider_id, "provisioned node");
219
220        Ok(NodeHandle {
221            provider_id,
222            address: None,
223            zone: shape.zone.clone(),
224            capacity_type: shape.capacity_type,
225            join_state: JoinState::Provisioning,
226        })
227    }
228
229    #[allow(clippy::ptr_arg)]
230    async fn terminate(&self, id: &ProviderNodeId) -> Result<()> {
231        let cmd = substitute_provider_id(&self.config.terminate_cmd, id);
232        tracing::info!(provisioner = "cloud-init", provider_id = %id, "terminating node");
233
234        let output = run_shell(&cmd).await?;
235        if !output.status.success() {
236            let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
237            return Err(ProvisionerError::Transport(stderr));
238        }
239        Ok(())
240    }
241
242    async fn describe(&self) -> Result<Vec<NodeHandle>> {
243        let Some(cmd) = self.config.describe_cmd.as_deref() else {
244            return Ok(Vec::new());
245        };
246
247        let output = run_shell(cmd).await?;
248        if !output.status.success() {
249            let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
250            return Err(ProvisionerError::Transport(stderr));
251        }
252
253        let stdout = String::from_utf8_lossy(&output.stdout);
254        Ok(parse_describe_output(&stdout))
255    }
256
257    fn capacity_types(&self) -> &[CapacityType] {
258        SUPPORTED_CAPACITY
259    }
260
261    fn price_hint(&self, shape: &NodeShape) -> Option<PriceHint> {
262        Some(PriceHint {
263            hourly_usd: self.config.hourly_usd,
264            capacity_type: shape.capacity_type,
265        })
266    }
267
268    fn name(&self) -> &'static str {
269        "cloud-init"
270    }
271}
272
// Unit tests for the pure helpers (template rendering, placeholder
// substitution, describe parsing) and the synchronous trait metadata.
// None of these tests spawns a shell.
#[cfg(test)]
mod tests {
    use super::{
        capacity_token, default_user_data_template, format_labels, parse_describe_output,
        substitute_provider_id, substitute_provision_cmd, substitute_user_data, CloudInitConfig,
        CloudInitProvisioner,
    };
    use crate::{CapacityType, CloudProvisioner, JoinState, NodeShape};

    // Config with harmless `echo` commands; the commands are never executed here.
    fn sample_config() -> CloudInitConfig {
        CloudInitConfig {
            leader_addr: "10.0.0.1:3669".to_string(),
            join_token: "tok-abc".to_string(),
            provision_cmd: "echo i-123".to_string(),
            terminate_cmd: "echo gone {provider_id}".to_string(),
            describe_cmd: None,
            user_data_template: default_user_data_template(),
            hourly_usd: 0.10,
        }
    }

    // The default template must be valid cloud-config and carry the join flags.
    #[test]
    fn default_template_contains_join_line() {
        let tpl = default_user_data_template();
        assert!(tpl.contains("zlayer node join"));
        assert!(tpl.contains("#cloud-config"));
        assert!(tpl.contains("--mode full"));
        assert!(tpl.contains("--no-ingress"));
        assert!(tpl.contains("169.254.169.254"));
    }

    // Rendering through the provisioner leaves no user-data placeholders behind.
    #[test]
    fn render_user_data_substitutes_placeholders() {
        let provisioner = CloudInitProvisioner::new(sample_config());
        let mut shape = NodeShape::new(2.0, 4 * 1024 * 1024 * 1024);
        shape
            .labels
            .insert("role".to_string(), "worker".to_string());

        let rendered = provisioner.render_user_data(&shape);
        assert!(rendered.contains("10.0.0.1:3669"));
        assert!(rendered.contains("tok-abc"));
        assert!(rendered.contains("role=worker"));
        assert!(!rendered.contains("{leader_addr}"));
        assert!(!rendered.contains("{join_token}"));
        assert!(!rendered.contains("{labels}"));
    }

    #[test]
    fn substitute_user_data_replaces_all() {
        let out = substitute_user_data(
            "join {leader_addr} tok {join_token} lbl {labels}",
            "host:1",
            "secret",
            "a=b",
        );
        assert_eq!(out, "join host:1 tok secret lbl a=b");
    }

    // Insertion order must not leak into the output: keys come out sorted.
    #[test]
    fn format_labels_is_sorted_and_joined() {
        let mut labels = std::collections::BTreeMap::new();
        labels.insert("z".to_string(), "1".to_string());
        labels.insert("a".to_string(), "2".to_string());
        assert_eq!(format_labels(&labels), "a=2,z=1");
        assert_eq!(format_labels(&std::collections::BTreeMap::new()), "");
    }

    #[test]
    fn capacity_token_maps_variants() {
        assert_eq!(capacity_token(CapacityType::OnDemand), "on-demand");
        assert_eq!(capacity_token(CapacityType::Spot), "spot");
    }

    // Every provision placeholder is filled; `2.0` cpu renders as `2`, and
    // 2048 MiB of bytes renders as `2048`.
    #[test]
    fn substitute_provision_cmd_fills_shape_fields() {
        let mut shape = NodeShape::new(2.0, 2048 * 1024 * 1024);
        shape.zone = Some("us-east-1a".to_string());
        shape.capacity_type = CapacityType::Spot;
        shape.labels.insert("k".to_string(), "v".to_string());

        let cmd = substitute_provision_cmd(
            "run --ud '{user_data}' --cpu {cpu} --mem {memory_mb} --labels {labels} --zone {zone} --cap {capacity}",
            "USERDATA",
            &shape,
        );
        assert!(cmd.contains("--ud 'USERDATA'"));
        assert!(cmd.contains("--cpu 2"));
        assert!(cmd.contains("--mem 2048"));
        assert!(cmd.contains("--labels k=v"));
        assert!(cmd.contains("--zone us-east-1a"));
        assert!(cmd.contains("--cap spot"));
        assert!(!cmd.contains('{'));
    }

    // A shape without a zone substitutes an empty string for `{zone}`.
    #[test]
    fn substitute_provision_cmd_empty_zone() {
        let shape = NodeShape::new(1.0, 1024 * 1024 * 1024);
        let cmd = substitute_provision_cmd("z=[{zone}]", "ud", &shape);
        assert_eq!(cmd, "z=[]");
    }

    #[test]
    fn substitute_provider_id_replaces() {
        assert_eq!(
            substitute_provider_id("delete {provider_id} now", "i-9"),
            "delete i-9 now"
        );
    }

    // Covers id+address, id only (with padding), blank lines, and a padded
    // address after the comma.
    #[test]
    fn parse_describe_output_handles_id_and_address() {
        let out = parse_describe_output("i-1,10.0.0.1\n  i-2  \n\ni-3, 10.0.0.3 \n");
        assert_eq!(out.len(), 3);

        assert_eq!(out[0].provider_id, "i-1");
        assert_eq!(out[0].address.as_deref(), Some("10.0.0.1"));
        assert_eq!(out[0].join_state, JoinState::Joined);

        assert_eq!(out[1].provider_id, "i-2");
        assert!(out[1].address.is_none());

        assert_eq!(out[2].provider_id, "i-3");
        assert_eq!(out[2].address.as_deref(), Some("10.0.0.3"));
    }

    #[test]
    fn parse_describe_output_empty() {
        assert!(parse_describe_output("\n  \n").is_empty());
    }

    // Name, supported capacity and the flat price hint from the config.
    #[test]
    fn metadata_methods() {
        let provisioner = CloudInitProvisioner::new(sample_config());
        assert_eq!(provisioner.name(), "cloud-init");
        assert_eq!(provisioner.capacity_types(), &[CapacityType::OnDemand]);

        let shape = NodeShape::new(1.0, 1024 * 1024 * 1024);
        let hint = provisioner.price_hint(&shape).expect("price hint");
        assert!((hint.hourly_usd - 0.10).abs() < f64::EPSILON);
        assert_eq!(hint.capacity_type, CapacityType::OnDemand);
    }
}