sk_core/
hooks.rs

1use std::fs;
2use std::process::Stdio;
3
4use anyhow::{
5    anyhow,
6    bail,
7};
8use sk_api::v1::SimulationHooksConfig;
9use tokio::io::{
10    AsyncWriteExt,
11    BufWriter,
12};
13use tokio::process::Command;
14use tracing::*;
15
16use crate::prelude::*;
17
/// Lifecycle point at which a list of simulation hooks is run.
///
/// [`execute`] uses the variant to pick the matching hook list out of the
/// simulation's [`SimulationHooksConfig`]. The type is a plain fieldless tag, so it is
/// `Copy` and comparable.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Type {
    /// Selects `pre_start_hooks`.
    PreStart,
    /// Selects `pre_run_hooks`.
    PreRun,
    /// Selects `post_run_hooks`.
    PostRun,
    /// Selects `post_stop_hooks`.
    PostStop,
}
25
26pub fn merge_hooks(maybe_files: &Option<Vec<String>>) -> anyhow::Result<Option<SimulationHooksConfig>> {
27    let Some(files) = maybe_files else {
28        return Ok(None);
29    };
30    if files.is_empty() {
31        return Ok(None);
32    }
33
34    Some(files.iter().try_fold(SimulationHooksConfig::default(), |mut merged_hooks, f| {
35        let next = serde_yaml::from_slice::<SimulationHooksConfig>(
36            &fs::read(f).map_err(|e| anyhow!("error reading hook {f}: {e}"))?,
37        )
38        .map_err(|e| anyhow!("error parsing hook {f}: {e}"))?;
39        merge_vecs(&mut merged_hooks.pre_start_hooks, next.pre_start_hooks);
40        merge_vecs(&mut merged_hooks.pre_run_hooks, next.pre_run_hooks);
41        merge_vecs(&mut merged_hooks.post_run_hooks, next.post_run_hooks);
42        merge_vecs(&mut merged_hooks.post_stop_hooks, next.post_stop_hooks);
43        Ok(merged_hooks)
44    }))
45    .transpose()
46}
47
48pub async fn execute(sim: &Simulation, type_: Type) -> EmptyResult {
49    let maybe_hooks = match &sim.spec.hooks {
50        Some(hooks_config) => match type_ {
51            Type::PreStart => hooks_config.pre_start_hooks.as_ref(),
52            Type::PreRun => hooks_config.pre_run_hooks.as_ref(),
53            Type::PostRun => hooks_config.post_run_hooks.as_ref(),
54            Type::PostStop => hooks_config.post_stop_hooks.as_ref(),
55        },
56        _ => None,
57    };
58
59    if let Some(hooks) = maybe_hooks {
60        info!("Executing {:?} hooks", type_);
61
62        for hook in hooks {
63            info!("Running `{}` with args {:?}", hook.cmd, hook.args);
64            let mut child = Command::new(hook.cmd.clone())
65                .args(hook.args.clone().unwrap_or_default())
66                .stdin(Stdio::piped())
67                .stdout(Stdio::piped())
68                .stderr(Stdio::piped())
69                .spawn()?;
70            if let Some(true) = hook.send_sim {
71                let mut stdin = BufWriter::new(child.stdin.take().ok_or(anyhow!("could not take stdin"))?);
72                stdin.write_all(&serde_json::to_vec(sim)?).await?;
73                stdin.flush().await?;
74            }
75            let output = child.wait_with_output().await?;
76            info!("Hook output: {:?}", output);
77            match hook.ignore_failure {
78                Some(true) => (),
79                _ => {
80                    if !output.status.success() {
81                        bail!("hook failed");
82                    }
83                },
84            }
85        }
86        info!("Done executing {:?} hooks", type_);
87    };
88
89    Ok(())
90}
91
/// Appends the elements of `maybe_v2` onto `maybe_v1`.
///
/// If `maybe_v1` is `None`, it takes `maybe_v2`'s vector as-is. If `maybe_v2` is `None`,
/// `maybe_v1` is left untouched.
fn merge_vecs<T>(maybe_v1: &mut Option<Vec<T>>, maybe_v2: Option<Vec<T>>) {
    let Some(extra) = maybe_v2 else {
        return;
    };
    match maybe_v1 {
        Some(existing) => existing.extend(extra),
        None => *maybe_v1 = Some(extra),
    }
}
97
#[cfg(test)]
#[cfg_attr(coverage, coverage(off))]
mod test {
    use assert_fs::prelude::*;
    use sk_testutils::*;

    use super::*;

    const HOOK1: &str = r#"
---
preStartHooks:
  - cmd: prestart1
    args:
      - prestart-arg1
      - prestart-arg2
preRunHooks:
  - cmd: prerun1
    args:
      - prerun-arg1
postRunHooks:
  - cmd: postrun1
    args:
      - postrun-arg1
"#;

    const HOOK2: &str = r#"
---
preStartHooks:
  - cmd: prestart2
  - cmd: prestart3
preRunHooks:
  - cmd: prerun2
    args:
      - prerun-arg2
postStopHooks:
  - cmd: poststop1
    args:
      - poststop-arg1
"#;

    const HOOK3: &str = r#"
---
preRunHooks:
  - cmd: prerun3
    args:
      - prerun-arg3
postRunHooks:
  - cmd: postrun2
    args:
      - prerun-arg2
"#;

    const EXPECTED_MERGED: &str = r#"
---
preStartHooks:
  - cmd: prestart1
    args:
      - prestart-arg1
      - prestart-arg2
  - cmd: prestart2
  - cmd: prestart3
preRunHooks:
  - cmd: prerun1
    args:
      - prerun-arg1
  - cmd: prerun2
    args:
      - prerun-arg2
  - cmd: prerun3
    args:
      - prerun-arg3
postRunHooks:
  - cmd: postrun1
    args:
      - postrun-arg1
  - cmd: postrun2
    args:
      - prerun-arg2
postStopHooks:
  - cmd: poststop1
    args:
      - poststop-arg1
"#;

    #[rstest]
    fn test_merge_hooks() {
        let temp = assert_fs::TempDir::new().unwrap();

        // Write each fixture to its own file. The list order is the merge order that
        // EXPECTED_MERGED encodes.
        let fixtures = [("hook1.yml", HOOK1), ("hook2.yml", HOOK2), ("hook3.yml", HOOK3)];
        let paths: Vec<String> = fixtures
            .iter()
            .map(|&(name, contents)| {
                let file = temp.child(name);
                file.write_str(contents).unwrap();
                file.path().to_str().unwrap().into()
            })
            .collect();

        let merged_config = merge_hooks(&Some(paths)).unwrap().unwrap();
        let expected: SimulationHooksConfig = serde_yaml::from_str(EXPECTED_MERGED).unwrap();
        assert_eq!(merged_config, expected);
    }

    #[rstest(tokio::test)]
    async fn test_execute_hooks(test_sim: Simulation) {
        // The fixture's pre-start hook is expected to succeed.
        assert!(execute(&test_sim, Type::PreStart).await.is_ok());

        // No PostStop hook is defined, so this is a no-op.
        assert!(execute(&test_sim, Type::PostStop).await.is_ok());

        // The fixture's pre-run hook calls a bad command, so this must fail.
        assert!(execute(&test_sim, Type::PreRun).await.is_err());
    }
}