Skip to main content

zlayer_init_actions/
actions.rs

1//! Built-in init actions
2
3use crate::error::{InitError, Result};
4use std::collections::HashMap;
5use std::time::Duration;
6use tokio::process::Command;
7use tokio::time::{sleep, timeout};
8
9/// Wait for a TCP port to be open
10pub struct WaitTcp {
11    pub host: String,
12    pub port: u16,
13    pub timeout: Duration,
14    pub interval: Duration,
15}
16
17impl WaitTcp {
18    /// # Errors
19    /// Returns `InitError::TcpFailed` if the connection times out.
20    pub async fn execute(&self) -> Result<()> {
21        let start = std::time::Instant::now();
22
23        loop {
24            if tokio::net::TcpStream::connect(&format!("{}:{}", self.host, self.port))
25                .await
26                .is_ok()
27            {
28                return Ok(());
29            }
30
31            if start.elapsed() >= self.timeout {
32                return Err(InitError::TcpFailed {
33                    host: self.host.clone(),
34                    port: self.port,
35                    reason: format!("timeout after {:?}", self.timeout),
36                });
37            }
38
39            sleep(self.interval).await;
40        }
41    }
42}
43
44/// Wait for an HTTP endpoint to respond
45pub struct WaitHttp {
46    pub url: String,
47    pub expect_status: Option<u16>,
48    pub timeout: Duration,
49    pub interval: Duration,
50}
51
52impl WaitHttp {
53    /// # Errors
54    /// Returns `InitError::HttpFailed` if the request times out or the expected status is not received.
55    pub async fn execute(&self) -> Result<()> {
56        let start = std::time::Instant::now();
57        let client = reqwest::Client::builder()
58            .timeout(Duration::from_secs(5))
59            .build()
60            .map_err(|e| InitError::HttpFailed {
61                url: self.url.clone(),
62                reason: format!("failed to create client: {e}"),
63            })?;
64
65        loop {
66            let response = client.get(&self.url).send().await;
67
68            if let Ok(resp) = response {
69                let status = resp.status().as_u16();
70
71                if let Some(expected) = self.expect_status {
72                    if status == expected {
73                        return Ok(());
74                    }
75                } else if (200..300).contains(&status) {
76                    return Ok(());
77                }
78            }
79
80            if start.elapsed() >= self.timeout {
81                return Err(InitError::HttpFailed {
82                    url: self.url.clone(),
83                    reason: format!("timeout after {:?}", self.timeout),
84                });
85            }
86
87            sleep(self.interval).await;
88        }
89    }
90}
91
92/// Run a shell command
93pub struct RunCommand {
94    pub command: String,
95    pub timeout: Duration,
96}
97
98#[cfg(unix)]
99fn build_shell_command(cmd: &str) -> Command {
100    let mut c = Command::new("sh");
101    c.arg("-c").arg(cmd);
102    c
103}
104
105#[cfg(windows)]
106fn build_shell_command(cmd: &str) -> Command {
107    let mut c = Command::new("cmd");
108    c.arg("/C").arg(cmd);
109    c
110}
111
112impl RunCommand {
113    /// # Errors
114    /// Returns an error if the command fails, exits non-zero, or times out.
115    pub async fn execute(&self) -> Result<()> {
116        match timeout(self.timeout, build_shell_command(&self.command).output()).await {
117            Ok(Ok(output)) => {
118                if output.status.success() {
119                    Ok(())
120                } else {
121                    Err(InitError::CommandFailed {
122                        command: self.command.clone(),
123                        code: output.status.code().unwrap_or(-1),
124                        stdout: String::from_utf8_lossy(&output.stdout).to_string(),
125                        stderr: String::from_utf8_lossy(&output.stderr).to_string(),
126                    })
127                }
128            }
129            Ok(Err(_)) => Err(InitError::CommandFailed {
130                command: self.command.clone(),
131                code: -1,
132                stdout: String::new(),
133                stderr: "timeout".to_string(),
134            }),
135            Err(_) => Err(InitError::Timeout {
136                timeout: self.timeout,
137            }),
138        }
139    }
140}
141
/// Push files to S3 from a local path.
///
/// A single file is uploaded to exactly `key`. A directory is uploaded
/// recursively, with each entry's relative path appended to `key` using `/`.
#[cfg(feature = "s3")]
pub struct S3Push {
    /// Local source path (file or directory)
    pub source: String,
    /// S3 bucket name
    pub bucket: String,
    /// Object key for a single file, or key prefix for a directory upload
    /// (an empty prefix uploads to the bucket root)
    pub key: String,
    /// Custom S3 endpoint (for S3-compatible services); enables path-style addressing
    pub endpoint: Option<String>,
    /// Region; when `None`, whatever `aws_config::defaults` resolves is used
    pub region: Option<String>,
    /// Timeout applied to each individual `put_object` request
    pub timeout: Duration,
}

#[cfg(feature = "s3")]
impl S3Push {
    /// Execute the S3 push action, uploading files to the configured bucket.
    ///
    /// # Errors
    ///
    /// Returns `InitError::S3Failed` if the source path does not exist, a file or
    /// directory cannot be read, or an upload fails, and `InitError::Timeout` if a
    /// single upload exceeds `timeout`.
    pub async fn execute(&self) -> Result<()> {
        use aws_sdk_s3::Client;

        // Build AWS config
        let mut config_loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
        if let Some(ref region) = self.region {
            config_loader = config_loader.region(aws_config::Region::new(region.clone()));
        }
        let sdk_config = config_loader.load().await;

        // Build S3 client
        let mut s3_config = aws_sdk_s3::config::Builder::from(&sdk_config);
        if let Some(ref endpoint) = self.endpoint {
            s3_config = s3_config.endpoint_url(endpoint).force_path_style(true);
        }
        let client = Client::from_conf(s3_config.build());

        let source_path = std::path::Path::new(&self.source);

        if source_path.is_file() {
            self.upload_file(&client, source_path, &self.key).await
        } else if source_path.is_dir() {
            self.upload_directory(&client, source_path, &self.key).await
        } else {
            Err(InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: self.key.clone(),
                reason: format!("source path '{}' does not exist", self.source),
            })
        }
    }

    /// Read `path` fully into memory and upload it to `key` with a single `put_object`.
    async fn upload_file(
        &self,
        client: &aws_sdk_s3::Client,
        path: &std::path::Path,
        key: &str,
    ) -> Result<()> {
        use aws_sdk_s3::primitives::ByteStream;

        tracing::info!(
            bucket = %self.bucket,
            key = %key,
            source = %path.display(),
            "pushing file to S3"
        );

        let data = tokio::fs::read(path)
            .await
            .map_err(|e| InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: key.to_string(),
                reason: format!("failed to read file: {e}"),
            })?;

        tokio::time::timeout(
            self.timeout,
            client
                .put_object()
                .bucket(&self.bucket)
                .key(key)
                .body(ByteStream::from(data))
                .content_type("application/octet-stream")
                .send(),
        )
        .await
        .map_err(|_| InitError::Timeout {
            timeout: self.timeout,
        })?
        .map_err(|e| InitError::S3Failed {
            bucket: self.bucket.clone(),
            key: key.to_string(),
            reason: format!("put_object failed: {e}"),
        })?;

        tracing::info!(bucket = %self.bucket, key = %key, "S3 push complete");
        Ok(())
    }

    /// Upload every file under `dir`, recursing into subdirectories, with keys
    /// formed as `<prefix>/<name>`.
    ///
    /// NOTE(review): `Path::is_file`/`is_dir` follow symlinks, so a symlink cycle
    /// would recurse indefinitely; consider `entry.file_type()` if that matters.
    async fn upload_directory(
        &self,
        client: &aws_sdk_s3::Client,
        dir: &std::path::Path,
        prefix: &str,
    ) -> Result<()> {
        let mut entries = tokio::fs::read_dir(dir)
            .await
            .map_err(|e| InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: prefix.to_string(),
                reason: format!("failed to read directory: {e}"),
            })?;

        while let Some(entry) = entries
            .next_entry()
            .await
            .map_err(|e| InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: prefix.to_string(),
                reason: format!("failed to read directory entry: {e}"),
            })?
        {
            let path = entry.path();
            let file_name = entry.file_name();
            let file_name = file_name.to_string_lossy();

            // An empty (or all-slash) prefix means the bucket root: don't emit
            // keys with a leading '/', which S3 would store literally.
            let parent = prefix.trim_end_matches('/');
            let key = if parent.is_empty() {
                file_name.into_owned()
            } else {
                format!("{parent}/{file_name}")
            };

            if path.is_file() {
                self.upload_file(client, &path, &key).await?;
            } else if path.is_dir() {
                // Box::pin is required for a recursive async fn.
                Box::pin(self.upload_directory(client, &path, &key)).await?;
            }
        }

        Ok(())
    }
}
295
/// Pull a single object from S3 to a local file.
#[cfg(feature = "s3")]
pub struct S3Pull {
    /// S3 bucket name
    pub bucket: String,
    /// S3 object key to download (one object via `get_object`; prefixes are not expanded)
    pub key: String,
    /// Local destination file path; missing parent directories are created
    pub destination: String,
    /// Custom S3 endpoint (for S3-compatible services); enables path-style addressing
    pub endpoint: Option<String>,
    /// Region; when `None`, whatever `aws_config::defaults` resolves is used
    pub region: Option<String>,
    /// Download timeout, covering both the request and reading the response body
    pub timeout: Duration,
}

#[cfg(feature = "s3")]
impl S3Pull {
    /// Execute the S3 pull action, downloading the object to `destination`.
    ///
    /// The whole object is buffered in memory before being written to disk.
    ///
    /// # Errors
    ///
    /// Returns `InitError::Timeout` if the request plus body download exceed
    /// `timeout`. Returns `InitError::S3Failed` if the request fails, the body
    /// cannot be read, or the destination cannot be created or written.
    pub async fn execute(&self) -> Result<()> {
        use aws_sdk_s3::Client;
        use tokio::io::AsyncWriteExt;

        // Build AWS config
        let mut config_loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
        if let Some(ref region) = self.region {
            config_loader = config_loader.region(aws_config::Region::new(region.clone()));
        }
        let sdk_config = config_loader.load().await;

        // Build S3 client
        let mut s3_config = aws_sdk_s3::config::Builder::from(&sdk_config);
        if let Some(ref endpoint) = self.endpoint {
            s3_config = s3_config.endpoint_url(endpoint).force_path_style(true);
        }
        let client = Client::from_conf(s3_config.build());

        // Every failure below is reported against this bucket/key.
        let s3_error = |reason: String| InitError::S3Failed {
            bucket: self.bucket.clone(),
            key: self.key.clone(),
            reason,
        };

        tracing::info!(
            bucket = %self.bucket,
            key = %self.key,
            destination = %self.destination,
            "pulling from S3"
        );

        // The response body is streamed after `send()` returns, so reading it can
        // stall just like the request itself; keep both under the timeout.
        let download = async {
            let output = client
                .get_object()
                .bucket(&self.bucket)
                .key(&self.key)
                .send()
                .await
                .map_err(|e| s3_error(format!("get_object failed: {e}")))?;
            let body = output
                .body
                .collect()
                .await
                .map_err(|e| s3_error(format!("failed to read body: {e}")))?;
            Ok::<_, InitError>(body.into_bytes())
        };
        let data = tokio::time::timeout(self.timeout, download)
            .await
            .map_err(|_| InitError::Timeout {
                timeout: self.timeout,
            })??;

        // Write to destination
        let dest_path = std::path::Path::new(&self.destination);
        if let Some(parent) = dest_path.parent() {
            tokio::fs::create_dir_all(parent)
                .await
                .map_err(|e| s3_error(format!("failed to create destination directory: {e}")))?;
        }

        let mut file = tokio::fs::File::create(dest_path)
            .await
            .map_err(|e| s3_error(format!("failed to create file: {e}")))?;
        file.write_all(&data)
            .await
            .map_err(|e| s3_error(format!("failed to write file: {e}")))?;
        // tokio's File finishes writes on a background thread; without an explicit
        // flush the last write may still be in flight when we report success, and
        // any error from it would be silently lost on drop.
        file.flush()
            .await
            .map_err(|e| s3_error(format!("failed to write file: {e}")))?;

        tracing::info!(
            bucket = %self.bucket,
            key = %self.key,
            bytes = data.len(),
            "S3 pull complete"
        );

        Ok(())
    }
}
415
416/// Create an init action from the spec
417///
418/// # Errors
419/// Returns `InitError::InvalidParams` if required parameters are missing or invalid,
420/// or `InitError::UnknownAction` if the action type is not recognized.
421#[allow(clippy::too_many_lines, clippy::implicit_hasher)]
422pub fn from_spec(
423    action: &str,
424    params: &HashMap<String, serde_json::Value>,
425    _default_timeout: Duration,
426) -> Result<InitAction> {
427    match action {
428        "init.wait_tcp" => {
429            let host = params
430                .get("host")
431                .and_then(|v| v.as_str())
432                .ok_or_else(|| InitError::InvalidParams {
433                    action: action.to_string(),
434                    reason: "missing 'host' parameter".to_string(),
435                })?
436                .to_string();
437
438            #[allow(clippy::cast_possible_truncation)]
439            let port = params
440                .get("port")
441                .and_then(serde_json::Value::as_u64)
442                .ok_or_else(|| InitError::InvalidParams {
443                    action: action.to_string(),
444                    reason: "missing or invalid 'port' parameter".to_string(),
445                })? as u16;
446
447            let timeout_secs = params
448                .get("timeout")
449                .and_then(serde_json::Value::as_u64)
450                .unwrap_or(30);
451
452            Ok(InitAction::WaitTcp(WaitTcp {
453                host,
454                port,
455                timeout: Duration::from_secs(timeout_secs),
456                interval: Duration::from_secs(2),
457            }))
458        }
459
460        "init.wait_http" => {
461            let url = params
462                .get("url")
463                .and_then(|v| v.as_str())
464                .ok_or_else(|| InitError::InvalidParams {
465                    action: action.to_string(),
466                    reason: "missing 'url' parameter".to_string(),
467                })?
468                .to_string();
469
470            #[allow(clippy::cast_possible_truncation)]
471            let expect_status = params
472                .get("expect_status")
473                .and_then(serde_json::Value::as_u64)
474                .map(|v| v as u16);
475
476            let timeout_secs = params
477                .get("timeout")
478                .and_then(serde_json::Value::as_u64)
479                .unwrap_or(30);
480
481            Ok(InitAction::WaitHttp(WaitHttp {
482                url,
483                expect_status,
484                timeout: Duration::from_secs(timeout_secs),
485                interval: Duration::from_secs(2),
486            }))
487        }
488
489        "init.run" => {
490            let command = params
491                .get("command")
492                .and_then(|v| v.as_str())
493                .ok_or_else(|| InitError::InvalidParams {
494                    action: action.to_string(),
495                    reason: "missing 'command' parameter".to_string(),
496                })?
497                .to_string();
498
499            let timeout_secs = params
500                .get("timeout")
501                .and_then(serde_json::Value::as_u64)
502                .unwrap_or(300);
503
504            Ok(InitAction::Run(RunCommand {
505                command,
506                timeout: Duration::from_secs(timeout_secs),
507            }))
508        }
509
510        #[cfg(feature = "s3")]
511        "init.s3_push" => {
512            let source = params
513                .get("source")
514                .and_then(|v| v.as_str())
515                .ok_or_else(|| InitError::InvalidParams {
516                    action: action.to_string(),
517                    reason: "missing 'source' parameter".to_string(),
518                })?
519                .to_string();
520
521            let bucket = params
522                .get("bucket")
523                .and_then(|v| v.as_str())
524                .ok_or_else(|| InitError::InvalidParams {
525                    action: action.to_string(),
526                    reason: "missing 'bucket' parameter".to_string(),
527                })?
528                .to_string();
529
530            let key = params
531                .get("key")
532                .and_then(|v| v.as_str())
533                .ok_or_else(|| InitError::InvalidParams {
534                    action: action.to_string(),
535                    reason: "missing 'key' parameter".to_string(),
536                })?
537                .to_string();
538
539            let endpoint = params
540                .get("endpoint")
541                .and_then(|v| v.as_str())
542                .map(String::from);
543            let region = params
544                .get("region")
545                .and_then(|v| v.as_str())
546                .map(String::from);
547            let timeout_secs = params
548                .get("timeout")
549                .and_then(serde_json::Value::as_u64)
550                .unwrap_or(300);
551
552            Ok(InitAction::S3Push(S3Push {
553                source,
554                bucket,
555                key,
556                endpoint,
557                region,
558                timeout: Duration::from_secs(timeout_secs),
559            }))
560        }
561
562        #[cfg(feature = "s3")]
563        "init.s3_pull" => {
564            let bucket = params
565                .get("bucket")
566                .and_then(|v| v.as_str())
567                .ok_or_else(|| InitError::InvalidParams {
568                    action: action.to_string(),
569                    reason: "missing 'bucket' parameter".to_string(),
570                })?
571                .to_string();
572
573            let key = params
574                .get("key")
575                .and_then(|v| v.as_str())
576                .ok_or_else(|| InitError::InvalidParams {
577                    action: action.to_string(),
578                    reason: "missing 'key' parameter".to_string(),
579                })?
580                .to_string();
581
582            let destination = params
583                .get("destination")
584                .and_then(|v| v.as_str())
585                .ok_or_else(|| InitError::InvalidParams {
586                    action: action.to_string(),
587                    reason: "missing 'destination' parameter".to_string(),
588                })?
589                .to_string();
590
591            let endpoint = params
592                .get("endpoint")
593                .and_then(|v| v.as_str())
594                .map(String::from);
595            let region = params
596                .get("region")
597                .and_then(|v| v.as_str())
598                .map(String::from);
599            let timeout_secs = params
600                .get("timeout")
601                .and_then(serde_json::Value::as_u64)
602                .unwrap_or(300);
603
604            Ok(InitAction::S3Pull(S3Pull {
605                bucket,
606                key,
607                destination,
608                endpoint,
609                region,
610                timeout: Duration::from_secs(timeout_secs),
611            }))
612        }
613
614        _ => Err(InitError::UnknownAction(action.to_string())),
615    }
616}
617
618/// Enum of all init actions
619pub enum InitAction {
620    WaitTcp(WaitTcp),
621    WaitHttp(WaitHttp),
622    Run(RunCommand),
623    #[cfg(feature = "s3")]
624    S3Push(S3Push),
625    #[cfg(feature = "s3")]
626    S3Pull(S3Pull),
627}
628
629impl InitAction {
630    /// # Errors
631    /// Returns an error if the underlying action fails.
632    pub async fn execute(&self) -> Result<()> {
633        match self {
634            InitAction::WaitTcp(a) => a.execute().await,
635            InitAction::WaitHttp(a) => a.execute().await,
636            InitAction::Run(a) => a.execute().await,
637            #[cfg(feature = "s3")]
638            InitAction::S3Push(a) => a.execute().await,
639            #[cfg(feature = "s3")]
640            InitAction::S3Pull(a) => a.execute().await,
641        }
642    }
643}
644
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_run_command_success() {
        let action = RunCommand {
            command: "echo hello".to_string(),
            timeout: Duration::from_secs(5),
        };
        action.execute().await.unwrap();
    }

    #[tokio::test]
    async fn test_run_command_failure() {
        let action = RunCommand {
            command: "exit 1".to_string(),
            timeout: Duration::from_secs(5),
        };
        let result = action.execute().await;
        assert!(result.is_err());
    }

    /// A command that outlives its timeout must surface `InitError::Timeout`.
    #[cfg(unix)]
    #[tokio::test]
    async fn test_run_command_timeout() {
        let action = RunCommand {
            command: "sleep 5".to_string(),
            timeout: Duration::from_millis(100),
        };
        let result = action.execute().await;
        assert!(matches!(result, Err(InitError::Timeout { .. })));
    }

    #[test]
    fn test_from_spec_wait_tcp() {
        let mut params = HashMap::new();
        params.insert("host".to_string(), serde_json::json!("localhost"));
        params.insert("port".to_string(), serde_json::json!(8080));

        let action = from_spec("init.wait_tcp", &params, Duration::from_secs(30)).unwrap();
        match action {
            InitAction::WaitTcp(wait) => {
                assert_eq!(wait.host, "localhost");
                assert_eq!(wait.port, 8080);
                assert_eq!(wait.timeout, Duration::from_secs(30));
            }
            _ => panic!("Expected WaitTcp action"),
        }
    }

    #[test]
    fn test_from_spec_missing_required_param() {
        let mut params = HashMap::new();
        params.insert("port".to_string(), serde_json::json!(8080));

        let result = from_spec("init.wait_tcp", &params, Duration::from_secs(30));
        assert!(matches!(result, Err(InitError::InvalidParams { .. })));
    }

    #[test]
    fn test_from_spec_wait_http_defaults() {
        let mut params = HashMap::new();
        params.insert(
            "url".to_string(),
            serde_json::json!("http://localhost/health"),
        );

        let action = from_spec("init.wait_http", &params, Duration::from_secs(30)).unwrap();
        match action {
            InitAction::WaitHttp(wait) => {
                assert_eq!(wait.url, "http://localhost/health");
                assert_eq!(wait.expect_status, None);
                assert_eq!(wait.timeout, Duration::from_secs(30));
            }
            _ => panic!("Expected WaitHttp action"),
        }
    }

    #[test]
    fn test_from_spec_run_timeout_override() {
        let mut params = HashMap::new();
        params.insert("command".to_string(), serde_json::json!("true"));
        params.insert("timeout".to_string(), serde_json::json!(7));

        let action = from_spec("init.run", &params, Duration::from_secs(30)).unwrap();
        match action {
            InitAction::Run(run) => {
                assert_eq!(run.command, "true");
                assert_eq!(run.timeout, Duration::from_secs(7));
            }
            _ => panic!("Expected Run action"),
        }
    }

    #[test]
    fn test_from_spec_unknown() {
        let params = HashMap::new();
        let result = from_spec("unknown.action", &params, Duration::from_secs(30));
        assert!(result.is_err());
    }
}