sdf_parser_core/config/transform/
code.rs

1use serde::{Serialize, Deserialize};
2
3use crate::config::import::StateImport;
4
5use super::{NamedParameterWrapper, ParameterWrapper};
6
/// How a step's logic is supplied: inline source (`Code`) or a reference to a
/// named function (`Function`).
///
/// The enum is `#[serde(untagged)]`, so no discriminator key is read or
/// written. Serde tries the variants in declaration order, which means the
/// order below is significant: an input that satisfies both shapes is parsed
/// as `Code`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum StepInvocationDefinition {
    /// Inline code definition (requires a `run` entry).
    Code(Code),
    /// Function reference (requires a `uses` entry).
    Function(FunctionDefinition),
}
13
14impl StepInvocationDefinition {
15    pub fn extra_deps(&self) -> Vec<Dependency> {
16        match self {
17            StepInvocationDefinition::Code(code) => code.dependencies.clone(),
18            StepInvocationDefinition::Function(function) => function.dependencies.clone(),
19        }
20    }
21
22    pub fn name(&self) -> Option<&str> {
23        match self {
24            StepInvocationDefinition::Code(_) => None,
25            StepInvocationDefinition::Function(function) => Some(&function.uses),
26        }
27    }
28}
29
/// A single extra dependency declared by a code or function definition.
///
/// `to_rust_dependency` renders it as one `name = ...` dependency line.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Dependency {
    /// Dependency name; used as the key of the rendered entry.
    pub name: String,
    /// Where the dependency comes from. Flattened, so its keys
    /// (`version`, `path`, or `git`/`branch`/`rev`/`tag`) sit at the same
    /// level as `name` in the serialized form.
    #[serde(flatten)]
    pub version: DependencyVersion,
    /// Whether default features stay enabled; `true` when the key is absent.
    #[serde(default = "default_features")]
    pub default_features: bool,
    /// Additional features to enable; empty when the key is absent.
    #[serde(default)]
    pub features: Vec<String>,
}
40
/// Serde default for `Dependency::default_features`: default features are
/// enabled unless the input explicitly turns them off.
fn default_features() -> bool {
    true
}
44
45impl Dependency {
46    pub fn to_rust_dependency(&self) -> String {
47        match &self.version {
48            DependencyVersion::Version { version } => {
49                if self.default_features && self.features.is_empty() {
50                    format!("{} = \"{}\"", self.name, version)
51                } else {
52                    let mut dep = format!("{} = {{ version = \"{}\"", self.name, version);
53                    if !self.default_features {
54                        dep.push_str(", default-features = false");
55                    }
56                    if !self.features.is_empty() {
57                        dep.push_str(&format!(
58                            ", features = [\"{}\"]",
59                            self.features.join("\", \"")
60                        ));
61                    }
62                    dep.push_str(" }");
63                    dep
64                }
65            }
66            DependencyVersion::Path { path } => {
67                let mut dep = format!("{} = {{ path = \"{}\"", self.name, path);
68                if !self.default_features {
69                    dep.push_str(", default-features = false");
70                }
71
72                if !self.features.is_empty() {
73                    dep.push_str(&format!(
74                        ", features = [\"{}\"]",
75                        self.features.join("\", \"")
76                    ));
77                }
78                dep.push_str(" }");
79                dep
80            }
81            DependencyVersion::Git {
82                git,
83                branch,
84                rev,
85                tag,
86            } => {
87                let mut git = format!("{} = {{ git = \"{}\"", self.name, git);
88                if let Some(branch) = branch {
89                    git.push_str(&format!(", branch = \"{}\"", branch));
90                }
91                if let Some(rev) = rev {
92                    git.push_str(&format!(", rev = \"{}\"", rev));
93                }
94
95                if let Some(tag) = tag {
96                    git.push_str(&format!(", tag = \"{}\"", tag));
97                }
98
99                if !self.default_features {
100                    git.push_str(", default-features = false");
101                }
102
103                if !self.features.is_empty() {
104                    git.push_str(&format!(
105                        ", features = [\"{}\"]",
106                        self.features.join("\", \"")
107                    ));
108                }
109
110                git.push_str(" }");
111                git
112            }
113        }
114    }
115}
116
/// Source of a dependency: a version requirement, a local path, or a git
/// repository.
///
/// The enum is `#[serde(untagged)]`: the variant is chosen by which keys are
/// present, and serde tries the variants in the order declared here.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
pub enum DependencyVersion {
    /// Selected when a `version` key is present.
    Version {
        version: String,
    },
    /// Selected when a `path` key is present (and `version` is not).
    Path {
        path: String,
    },
    /// Selected when a `git` key is present; `branch`, `rev` and `tag` are
    /// each optional.
    Git {
        git: String,
        branch: Option<String>,
        rev: Option<String>,
        tag: Option<String>,
    },
}
133
/// Serialization representation of an inline code definition.
///
/// Field names are kebab-case in the serialized form unless renamed below.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "kebab-case")]
pub struct Code {
    /// Optional export name; also accepted under the key `$key$` when
    /// deserializing.
    #[serde(alias = "$key$")]
    pub export_name: Option<String>, // used to validate pkg exports
    /// Language of `run`; falls back to `Lang::default()` when absent.
    #[serde(default)]
    pub lang: Lang,
    /// State imports, stored under the key `states`; omitted from the output
    /// when empty and defaulting to empty when absent.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    #[serde(rename = "states")]
    pub state_imports: Vec<StateImport>,
    /// Extra dependencies for the code; empty when absent.
    #[serde(default)]
    pub dependencies: Vec<Dependency>,
    /// The code text itself (required).
    pub run: String,
}
149
/// Languages supported by sdf for build and generation.
///
/// Variant names are kebab-case in the serialized form (`Rust` is `rust`).
#[derive(Default, Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum Lang {
    /// Rust; the default when no language is given.
    #[default]
    Rust,
}
157
/// Serialization representation of a step that refers to a function by name
/// instead of carrying inline code.
///
/// Field names are kebab-case in the serialized form unless renamed below.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
#[serde(rename_all = "kebab-case")]
pub struct FunctionDefinition {
    /// Name of the referenced function (required); also accepted under the
    /// key `$key$` when deserializing.
    #[serde(alias = "$key$")]
    pub uses: String,
    /// State imports, stored under the key `states`; omitted from the output
    /// when empty and defaulting to empty when absent.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    #[serde(rename = "states")]
    pub state_imports: Vec<StateImport>,
    /// Named input parameters; omitted from the output when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub inputs: Vec<NamedParameterWrapper>,
    /// Optional output parameter; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub output: Option<ParameterWrapper>,
    /// Language of the function; falls back to `Lang::default()` when absent.
    #[serde(default)]
    pub lang: Lang,
    /// Extra dependencies; omitted from the output when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub dependencies: Vec<Dependency>,
}
175
176#[cfg(test)]
177mod tests {
178
179    use super::*;
180
181    #[test]
182    fn test_deserialize_code() {
183        let yaml = r#"
184lang: rust
185run: |
186    fn my_map(my_input: String) -> Result<String, String> {
187        println!("Hello, world!");
188    }
189"#;
190        let code: StepInvocationDefinition = serde_yaml::from_str(yaml).expect("parse yaml");
191
192        let code = match code {
193            StepInvocationDefinition::Code(code) => code,
194            _ => panic!("Invalid parsed code"),
195        };
196
197        assert_eq!(code.lang, Lang::Rust);
198        assert!(code.state_imports.is_empty());
199    }
200
201    #[test]
202    fn test_deserialize_code_with_input_key() {
203        let yaml = r#"
204lang: rust
205run: |
206    fn my_map(key: Option<String>, my_input: String) -> Result<String, String> {
207        todo!()
208    }
209"#;
210        let code: StepInvocationDefinition = serde_yaml::from_str(yaml).expect("parse yaml");
211
212        let code = match code {
213            StepInvocationDefinition::Code(code) => code,
214            _ => panic!("Invalid parsed code"),
215        };
216
217        assert_eq!(code.lang, Lang::Rust);
218        assert!(code.state_imports.is_empty());
219    }
220
221    #[test]
222    fn test_deserialize_code_with_output_key() {
223        let yaml = r#"
224lang: rust
225run: |
226    fn my_map(my_input: String) -> Result<(Option<i32>,String), String> {
227        println!("Hello, world!");
228    }
229"#;
230        let code: StepInvocationDefinition = serde_yaml::from_str(yaml).expect("parse yaml");
231
232        let code = match code {
233            StepInvocationDefinition::Code(code) => code,
234            _ => panic!("Invalid code"),
235        };
236
237        assert_eq!(code.lang, Lang::Rust);
238        assert!(code.state_imports.is_empty());
239    }
240
241    #[test]
242    fn test_deserialize_flat_map_code_with_output_key() {
243        let yaml = r#"
244        lang: rust
245        run: |
246            fn my_flatmap(my_input: String) -> Result<Option<(Option<String>,String)>, String> {
247                println!("Hello, world!");
248            }
249        "#;
250
251        let code: StepInvocationDefinition = serde_yaml::from_str(yaml).expect("parse yaml");
252
253        let code = match code {
254            StepInvocationDefinition::Code(code) => code,
255            _ => panic!("Invalid parsed code"),
256        };
257
258        assert_eq!(code.lang, Lang::Rust);
259        assert!(code.state_imports.is_empty());
260    }
261
262    #[test]
263    fn test_deserialize_function() {
264        let yaml = r#"
265lang: rust
266uses: my-map
267inputs:
268  - name: my-input
269    type: string
270output:
271    type: string
272"#;
273
274        let parsed_code: StepInvocationDefinition = serde_yaml::from_str(yaml).expect("parse yaml");
275
276        let function = match parsed_code {
277            StepInvocationDefinition::Function(function) => function,
278            _ => panic!("Invalid parsed code"),
279        };
280        assert_eq!(function.uses, "my-map");
281        assert_eq!(function.inputs.len(), 1);
282        assert_eq!(function.inputs[0].name, "my-input");
283        assert_eq!(function.inputs[0].ty.ty(), "string");
284        assert_eq!(function.output.unwrap().ty.ty(), "string");
285    }
286
287    #[test]
288    fn test_deserialize_ambiguous_code_takes_priority() {
289        let yaml = r#"
290lang: rust
291run: |
292    fn my_map(my_input: String) -> Result<String, String> {
293        println!("Hello, world!");
294    }
295uses: my-map
296inputs:
297  - name: my-input
298    type: string
299output:
300    type: string
301"#;
302        let code: StepInvocationDefinition = serde_yaml::from_str(yaml).expect("parse yaml");
303
304        let code = match code {
305            StepInvocationDefinition::Code(code) => code,
306            _ => panic!("Invalid parsed code"),
307        };
308
309        assert_eq!(code.lang, Lang::Rust);
310        assert!(code.state_imports.is_empty());
311        code.run.contains("Hello, world!");
312    }
313}