1use crate::error::{InitError, Result};
4use std::collections::HashMap;
5use std::time::Duration;
6use tokio::process::Command;
7use tokio::time::{sleep, timeout};
8
/// Readiness probe that waits until a TCP endpoint accepts connections.
///
/// Run it with [`WaitTcp::execute`], which retries connecting to `host:port`
/// every `interval` until a connection succeeds or `timeout` elapses.
#[derive(Debug, Clone)]
pub struct WaitTcp {
    /// Host name or address to connect to.
    pub host: String,
    /// TCP port to connect to.
    pub port: u16,
    /// Overall time budget for all connection attempts.
    pub timeout: Duration,
    /// Delay between consecutive failed attempts.
    pub interval: Duration,
}
16
17impl WaitTcp {
18 pub async fn execute(&self) -> Result<()> {
21 let start = std::time::Instant::now();
22
23 loop {
24 if tokio::net::TcpStream::connect(&format!("{}:{}", self.host, self.port))
25 .await
26 .is_ok()
27 {
28 return Ok(());
29 }
30
31 if start.elapsed() >= self.timeout {
32 return Err(InitError::TcpFailed {
33 host: self.host.clone(),
34 port: self.port,
35 reason: format!("timeout after {:?}", self.timeout),
36 });
37 }
38
39 sleep(self.interval).await;
40 }
41 }
42}
43
/// Readiness probe that waits until an HTTP endpoint answers with an accepted status.
///
/// Run it with [`WaitHttp::execute`].
#[derive(Debug, Clone)]
pub struct WaitHttp {
    /// URL requested with `GET` on every attempt.
    pub url: String,
    /// Exact status code to wait for. When `None`, any 2xx status is accepted.
    pub expect_status: Option<u16>,
    /// Overall time budget for all requests.
    pub timeout: Duration,
    /// Delay between consecutive unsuccessful requests.
    pub interval: Duration,
}
51
52impl WaitHttp {
53 pub async fn execute(&self) -> Result<()> {
56 let start = std::time::Instant::now();
57 let client = reqwest::Client::builder()
58 .timeout(Duration::from_secs(5))
59 .build()
60 .map_err(|e| InitError::HttpFailed {
61 url: self.url.clone(),
62 reason: format!("failed to create client: {e}"),
63 })?;
64
65 loop {
66 let response = client.get(&self.url).send().await;
67
68 if let Ok(resp) = response {
69 let status = resp.status().as_u16();
70
71 if let Some(expected) = self.expect_status {
72 if status == expected {
73 return Ok(());
74 }
75 } else if (200..300).contains(&status) {
76 return Ok(());
77 }
78 }
79
80 if start.elapsed() >= self.timeout {
81 return Err(InitError::HttpFailed {
82 url: self.url.clone(),
83 reason: format!("timeout after {:?}", self.timeout),
84 });
85 }
86
87 sleep(self.interval).await;
88 }
89 }
90}
91
/// Init step that runs a command line through the platform shell.
///
/// The shell is `sh -c` on Unix and `cmd /C` on Windows. Run the step with
/// [`RunCommand::execute`].
#[derive(Debug, Clone)]
pub struct RunCommand {
    /// Command line handed verbatim to the platform shell.
    pub command: String,
    /// Maximum time the command may run before the step fails.
    pub timeout: Duration,
}
97
98#[cfg(unix)]
99fn build_shell_command(cmd: &str) -> Command {
100 let mut c = Command::new("sh");
101 c.arg("-c").arg(cmd);
102 c
103}
104
/// Wraps `cmd` in a `cmd /C` invocation so that `cmd.exe` interprets it.
#[cfg(windows)]
fn build_shell_command(cmd: &str) -> Command {
    let mut shell = Command::new("cmd");
    shell.args(["/C", cmd]);
    shell
}
111
112impl RunCommand {
113 pub async fn execute(&self) -> Result<()> {
116 match timeout(self.timeout, build_shell_command(&self.command).output()).await {
117 Ok(Ok(output)) => {
118 if output.status.success() {
119 Ok(())
120 } else {
121 Err(InitError::CommandFailed {
122 command: self.command.clone(),
123 code: output.status.code().unwrap_or(-1),
124 stdout: String::from_utf8_lossy(&output.stdout).to_string(),
125 stderr: String::from_utf8_lossy(&output.stderr).to_string(),
126 })
127 }
128 }
129 Ok(Err(_)) => Err(InitError::CommandFailed {
130 command: self.command.clone(),
131 code: -1,
132 stdout: String::new(),
133 stderr: "timeout".to_string(),
134 }),
135 Err(_) => Err(InitError::Timeout {
136 timeout: self.timeout,
137 }),
138 }
139 }
140}
141
/// Init step that uploads a local file, or every file under a local directory, to S3.
///
/// Run it with [`S3Push::execute`].
#[cfg(feature = "s3")]
#[derive(Debug, Clone)]
pub struct S3Push {
    /// Local file or directory to upload.
    pub source: String,
    /// Destination bucket name.
    pub bucket: String,
    /// Object key for a single file, or the key prefix when `source` is a directory.
    pub key: String,
    /// Optional custom endpoint URL (e.g. an S3-compatible store). Path-style
    /// addressing is forced when this is set.
    pub endpoint: Option<String>,
    /// Optional region override. When unset, the region resolved by
    /// `aws_config::defaults` is used.
    pub region: Option<String>,
    /// Time limit for each individual `PutObject` request.
    pub timeout: Duration,
}
158
#[cfg(feature = "s3")]
impl S3Push {
    /// Uploads `self.source` to `self.bucket`.
    ///
    /// A regular file is stored under exactly `self.key`. A directory is walked
    /// recursively, and each regular file in it is stored under
    /// `{key}/{relative path}` using `/` separators.
    ///
    /// # Errors
    ///
    /// - [`InitError::S3Failed`] if the source is missing or is neither a file
    ///   nor a directory, if it cannot be read, or if an upload fails.
    /// - [`InitError::Timeout`] if a single `PutObject` request exceeds
    ///   `self.timeout`.
    pub async fn execute(&self) -> Result<()> {
        use aws_sdk_s3::Client;

        let mut config_loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
        if let Some(ref region) = self.region {
            config_loader = config_loader.region(aws_config::Region::new(region.clone()));
        }
        let sdk_config = config_loader.load().await;

        let mut s3_config = aws_sdk_s3::config::Builder::from(&sdk_config);
        if let Some(ref endpoint) = self.endpoint {
            // Custom (S3-compatible) endpoints generally need path-style URLs.
            s3_config = s3_config.endpoint_url(endpoint).force_path_style(true);
        }
        let client = Client::from_conf(s3_config.build());

        let source_path = std::path::Path::new(&self.source);

        // Async stat instead of the blocking `Path::is_file`/`is_dir`. Like
        // those helpers, `tokio::fs::metadata` follows symlinks.
        match tokio::fs::metadata(source_path).await {
            Ok(meta) if meta.is_file() => {
                self.upload_file(&client, source_path, &self.key).await
            }
            Ok(meta) if meta.is_dir() => {
                self.upload_directory(&client, source_path, &self.key).await
            }
            _ => Err(InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: self.key.clone(),
                reason: format!("source path '{}' does not exist", self.source),
            }),
        }
    }

    /// Uploads one file's entire contents to `key` with a single `PutObject` request.
    ///
    /// The file is read fully into memory first. Only the request itself is
    /// bounded by `self.timeout`.
    async fn upload_file(
        &self,
        client: &aws_sdk_s3::Client,
        path: &std::path::Path,
        key: &str,
    ) -> Result<()> {
        use aws_sdk_s3::primitives::ByteStream;

        tracing::info!(
            bucket = %self.bucket,
            key = %key,
            source = %path.display(),
            "pushing file to S3"
        );

        let data = tokio::fs::read(path)
            .await
            .map_err(|e| InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: key.to_string(),
                reason: format!("failed to read file: {e}"),
            })?;

        tokio::time::timeout(
            self.timeout,
            client
                .put_object()
                .bucket(&self.bucket)
                .key(key)
                .body(ByteStream::from(data))
                .content_type("application/octet-stream")
                .send(),
        )
        .await
        .map_err(|_| InitError::Timeout {
            timeout: self.timeout,
        })?
        .map_err(|e| InitError::S3Failed {
            bucket: self.bucket.clone(),
            key: key.to_string(),
            reason: format!("put_object failed: {e}"),
        })?;

        tracing::info!(bucket = %self.bucket, key = %key, "S3 push complete");
        Ok(())
    }

    /// Recursively uploads every regular file under `dir`, keyed beneath `prefix`.
    ///
    /// Entries whose metadata cannot be read (e.g. dangling symlinks) are
    /// skipped with a warning. Entries that are neither files nor directories
    /// are skipped silently.
    async fn upload_directory(
        &self,
        client: &aws_sdk_s3::Client,
        dir: &std::path::Path,
        prefix: &str,
    ) -> Result<()> {
        let mut entries = tokio::fs::read_dir(dir)
            .await
            .map_err(|e| InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: prefix.to_string(),
                reason: format!("failed to read directory: {e}"),
            })?;

        // Trailing slashes are trimmed so "prefix/" and "prefix" yield the same keys.
        let key_prefix = prefix.trim_end_matches('/');

        while let Some(entry) = entries
            .next_entry()
            .await
            .map_err(|e| InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: prefix.to_string(),
                reason: format!("failed to read directory entry: {e}"),
            })?
        {
            let path = entry.path();
            let file_name = entry.file_name();
            let name = file_name.to_string_lossy();
            // An empty (or "/"-only) prefix puts entries at the bucket root
            // instead of producing keys that start with '/'.
            let key = if key_prefix.is_empty() {
                name.into_owned()
            } else {
                format!("{key_prefix}/{name}")
            };

            match tokio::fs::metadata(&path).await {
                Ok(meta) if meta.is_file() => {
                    self.upload_file(client, &path, &key).await?;
                }
                Ok(meta) if meta.is_dir() => {
                    // Boxed because an `async fn` cannot await itself directly.
                    Box::pin(self.upload_directory(client, &path, &key)).await?;
                }
                Ok(_) => {}
                Err(e) => {
                    tracing::warn!(
                        path = %path.display(),
                        error = %e,
                        "skipping directory entry: cannot read metadata"
                    );
                }
            }
        }

        Ok(())
    }
}
295
/// Init step that downloads a single S3 object to a local file.
///
/// Run it with [`S3Pull::execute`].
#[cfg(feature = "s3")]
#[derive(Debug, Clone)]
pub struct S3Pull {
    /// Source bucket name.
    pub bucket: String,
    /// Key of the object to download.
    pub key: String,
    /// Local file path to write. Missing parent directories are created.
    pub destination: String,
    /// Optional custom endpoint URL (e.g. an S3-compatible store). Path-style
    /// addressing is forced when this is set.
    pub endpoint: Option<String>,
    /// Optional region override. When unset, the region resolved by
    /// `aws_config::defaults` is used.
    pub region: Option<String>,
    /// Timeout applied while retrieving the object (see [`S3Pull::execute`]).
    pub timeout: Duration,
}
312
#[cfg(feature = "s3")]
impl S3Pull {
    /// Downloads `self.key` from `self.bucket` and writes it to `self.destination`.
    ///
    /// `self.timeout` bounds the whole retrieval: the `GetObject` request and
    /// reading the response body. A stalled body stream therefore cannot hang
    /// the step. The object is buffered in memory before being written. Missing
    /// parent directories of the destination are created, and the file is
    /// flushed before returning.
    ///
    /// # Errors
    ///
    /// - [`InitError::Timeout`] if the object is not fully retrieved within
    ///   `self.timeout`.
    /// - [`InitError::S3Failed`] if the request fails, the body cannot be read,
    ///   or the destination cannot be created or written.
    pub async fn execute(&self) -> Result<()> {
        use aws_sdk_s3::Client;
        use tokio::io::AsyncWriteExt;

        let mut config_loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
        if let Some(ref region) = self.region {
            config_loader = config_loader.region(aws_config::Region::new(region.clone()));
        }
        let sdk_config = config_loader.load().await;

        let mut s3_config = aws_sdk_s3::config::Builder::from(&sdk_config);
        if let Some(ref endpoint) = self.endpoint {
            // Custom (S3-compatible) endpoints generally need path-style URLs.
            s3_config = s3_config.endpoint_url(endpoint).force_path_style(true);
        }
        let client = Client::from_conf(s3_config.build());

        tracing::info!(
            bucket = %self.bucket,
            key = %self.key,
            destination = %self.destination,
            "pulling from S3"
        );

        // Request and body download run as one future so a single timeout covers both.
        let download = async {
            let object = client
                .get_object()
                .bucket(&self.bucket)
                .key(&self.key)
                .send()
                .await
                .map_err(|e| InitError::S3Failed {
                    bucket: self.bucket.clone(),
                    key: self.key.clone(),
                    reason: format!("get_object failed: {e}"),
                })?;
            let body = object
                .body
                .collect()
                .await
                .map_err(|e| InitError::S3Failed {
                    bucket: self.bucket.clone(),
                    key: self.key.clone(),
                    reason: format!("failed to read body: {e}"),
                })?;
            Ok::<_, InitError>(body.into_bytes())
        };

        let data = tokio::time::timeout(self.timeout, download)
            .await
            .map_err(|_| InitError::Timeout {
                timeout: self.timeout,
            })??;

        let dest_path = std::path::Path::new(&self.destination);
        if let Some(parent) = dest_path.parent() {
            tokio::fs::create_dir_all(parent)
                .await
                .map_err(|e| InitError::S3Failed {
                    bucket: self.bucket.clone(),
                    key: self.key.clone(),
                    reason: format!("failed to create destination directory: {e}"),
                })?;
        }

        let mut file = tokio::fs::File::create(&self.destination)
            .await
            .map_err(|e| InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: self.key.clone(),
                reason: format!("failed to create file: {e}"),
            })?;

        file.write_all(&data)
            .await
            .map_err(|e| InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: self.key.clone(),
                reason: format!("failed to write file: {e}"),
            })?;

        // tokio's `File` completes writes on a background blocking task. Flushing
        // waits for it, so write errors surface here instead of being lost when
        // the handle is dropped.
        file.flush()
            .await
            .map_err(|e| InitError::S3Failed {
                bucket: self.bucket.clone(),
                key: self.key.clone(),
                reason: format!("failed to write file: {e}"),
            })?;

        tracing::info!(
            bucket = %self.bucket,
            key = %self.key,
            bytes = data.len(),
            "S3 pull complete"
        );

        Ok(())
    }
}
415
416#[allow(clippy::too_many_lines, clippy::implicit_hasher)]
422pub fn from_spec(
423 action: &str,
424 params: &HashMap<String, serde_json::Value>,
425 _default_timeout: Duration,
426) -> Result<InitAction> {
427 match action {
428 "init.wait_tcp" => {
429 let host = params
430 .get("host")
431 .and_then(|v| v.as_str())
432 .ok_or_else(|| InitError::InvalidParams {
433 action: action.to_string(),
434 reason: "missing 'host' parameter".to_string(),
435 })?
436 .to_string();
437
438 #[allow(clippy::cast_possible_truncation)]
439 let port = params
440 .get("port")
441 .and_then(serde_json::Value::as_u64)
442 .ok_or_else(|| InitError::InvalidParams {
443 action: action.to_string(),
444 reason: "missing or invalid 'port' parameter".to_string(),
445 })? as u16;
446
447 let timeout_secs = params
448 .get("timeout")
449 .and_then(serde_json::Value::as_u64)
450 .unwrap_or(30);
451
452 Ok(InitAction::WaitTcp(WaitTcp {
453 host,
454 port,
455 timeout: Duration::from_secs(timeout_secs),
456 interval: Duration::from_secs(2),
457 }))
458 }
459
460 "init.wait_http" => {
461 let url = params
462 .get("url")
463 .and_then(|v| v.as_str())
464 .ok_or_else(|| InitError::InvalidParams {
465 action: action.to_string(),
466 reason: "missing 'url' parameter".to_string(),
467 })?
468 .to_string();
469
470 #[allow(clippy::cast_possible_truncation)]
471 let expect_status = params
472 .get("expect_status")
473 .and_then(serde_json::Value::as_u64)
474 .map(|v| v as u16);
475
476 let timeout_secs = params
477 .get("timeout")
478 .and_then(serde_json::Value::as_u64)
479 .unwrap_or(30);
480
481 Ok(InitAction::WaitHttp(WaitHttp {
482 url,
483 expect_status,
484 timeout: Duration::from_secs(timeout_secs),
485 interval: Duration::from_secs(2),
486 }))
487 }
488
489 "init.run" => {
490 let command = params
491 .get("command")
492 .and_then(|v| v.as_str())
493 .ok_or_else(|| InitError::InvalidParams {
494 action: action.to_string(),
495 reason: "missing 'command' parameter".to_string(),
496 })?
497 .to_string();
498
499 let timeout_secs = params
500 .get("timeout")
501 .and_then(serde_json::Value::as_u64)
502 .unwrap_or(300);
503
504 Ok(InitAction::Run(RunCommand {
505 command,
506 timeout: Duration::from_secs(timeout_secs),
507 }))
508 }
509
510 #[cfg(feature = "s3")]
511 "init.s3_push" => {
512 let source = params
513 .get("source")
514 .and_then(|v| v.as_str())
515 .ok_or_else(|| InitError::InvalidParams {
516 action: action.to_string(),
517 reason: "missing 'source' parameter".to_string(),
518 })?
519 .to_string();
520
521 let bucket = params
522 .get("bucket")
523 .and_then(|v| v.as_str())
524 .ok_or_else(|| InitError::InvalidParams {
525 action: action.to_string(),
526 reason: "missing 'bucket' parameter".to_string(),
527 })?
528 .to_string();
529
530 let key = params
531 .get("key")
532 .and_then(|v| v.as_str())
533 .ok_or_else(|| InitError::InvalidParams {
534 action: action.to_string(),
535 reason: "missing 'key' parameter".to_string(),
536 })?
537 .to_string();
538
539 let endpoint = params
540 .get("endpoint")
541 .and_then(|v| v.as_str())
542 .map(String::from);
543 let region = params
544 .get("region")
545 .and_then(|v| v.as_str())
546 .map(String::from);
547 let timeout_secs = params
548 .get("timeout")
549 .and_then(serde_json::Value::as_u64)
550 .unwrap_or(300);
551
552 Ok(InitAction::S3Push(S3Push {
553 source,
554 bucket,
555 key,
556 endpoint,
557 region,
558 timeout: Duration::from_secs(timeout_secs),
559 }))
560 }
561
562 #[cfg(feature = "s3")]
563 "init.s3_pull" => {
564 let bucket = params
565 .get("bucket")
566 .and_then(|v| v.as_str())
567 .ok_or_else(|| InitError::InvalidParams {
568 action: action.to_string(),
569 reason: "missing 'bucket' parameter".to_string(),
570 })?
571 .to_string();
572
573 let key = params
574 .get("key")
575 .and_then(|v| v.as_str())
576 .ok_or_else(|| InitError::InvalidParams {
577 action: action.to_string(),
578 reason: "missing 'key' parameter".to_string(),
579 })?
580 .to_string();
581
582 let destination = params
583 .get("destination")
584 .and_then(|v| v.as_str())
585 .ok_or_else(|| InitError::InvalidParams {
586 action: action.to_string(),
587 reason: "missing 'destination' parameter".to_string(),
588 })?
589 .to_string();
590
591 let endpoint = params
592 .get("endpoint")
593 .and_then(|v| v.as_str())
594 .map(String::from);
595 let region = params
596 .get("region")
597 .and_then(|v| v.as_str())
598 .map(String::from);
599 let timeout_secs = params
600 .get("timeout")
601 .and_then(serde_json::Value::as_u64)
602 .unwrap_or(300);
603
604 Ok(InitAction::S3Pull(S3Pull {
605 bucket,
606 key,
607 destination,
608 endpoint,
609 region,
610 timeout: Duration::from_secs(timeout_secs),
611 }))
612 }
613
614 _ => Err(InitError::UnknownAction(action.to_string())),
615 }
616}
617
618pub enum InitAction {
620 WaitTcp(WaitTcp),
621 WaitHttp(WaitHttp),
622 Run(RunCommand),
623 #[cfg(feature = "s3")]
624 S3Push(S3Push),
625 #[cfg(feature = "s3")]
626 S3Pull(S3Pull),
627}
628
629impl InitAction {
630 pub async fn execute(&self) -> Result<()> {
633 match self {
634 InitAction::WaitTcp(a) => a.execute().await,
635 InitAction::WaitHttp(a) => a.execute().await,
636 InitAction::Run(a) => a.execute().await,
637 #[cfg(feature = "s3")]
638 InitAction::S3Push(a) => a.execute().await,
639 #[cfg(feature = "s3")]
640 InitAction::S3Pull(a) => a.execute().await,
641 }
642 }
643}
644
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_run_command_success() {
        let action = RunCommand {
            command: "echo hello".to_string(),
            timeout: Duration::from_secs(5),
        };
        action.execute().await.unwrap();
    }

    #[tokio::test]
    async fn test_run_command_failure() {
        let action = RunCommand {
            command: "exit 1".to_string(),
            timeout: Duration::from_secs(5),
        };
        let result = action.execute().await;
        // The shell's exit status must be propagated, not just "some error".
        assert!(
            matches!(result, Err(InitError::CommandFailed { code: 1, .. })),
            "expected CommandFailed with exit code 1"
        );
    }

    #[test]
    fn test_from_spec_wait_tcp() {
        let mut params = HashMap::new();
        params.insert("host".to_string(), serde_json::json!("localhost"));
        params.insert("port".to_string(), serde_json::json!(8080));

        let action = from_spec("init.wait_tcp", &params, Duration::from_secs(30)).unwrap();
        let InitAction::WaitTcp(wait) = action else {
            panic!("Expected WaitTcp action");
        };
        assert_eq!(wait.host, "localhost");
        assert_eq!(wait.port, 8080);
        assert_eq!(wait.timeout, Duration::from_secs(30));
        assert_eq!(wait.interval, Duration::from_secs(2));
    }

    #[test]
    fn test_from_spec_wait_tcp_missing_host() {
        let mut params = HashMap::new();
        params.insert("port".to_string(), serde_json::json!(8080));

        let result = from_spec("init.wait_tcp", &params, Duration::from_secs(30));
        assert!(matches!(result, Err(InitError::InvalidParams { .. })));
    }

    #[test]
    fn test_from_spec_unknown() {
        let params = HashMap::new();
        let result = from_spec("unknown.action", &params, Duration::from_secs(30));
        assert!(matches!(result, Err(InitError::UnknownAction(_))));
    }
}