// sdf_parser_core/config/transform/code.rs

1use std::collections::BTreeMap;
2
3use schemars::JsonSchema;
4use serde::{Serialize, Deserialize};
5
6use crate::config::import::StateImport;
7
8use super::{NamedParameterWrapper, ParameterWrapper};
9
10#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema)]
11#[serde(untagged)]
12pub enum StepInvocationDefinition {
13    Code(Code),
14    Function(FunctionDefinition),
15}
16
17impl StepInvocationDefinition {
18    pub fn extra_deps(&self) -> Vec<Dependency> {
19        match self {
20            StepInvocationDefinition::Code(code) => code.dependencies.clone(),
21            StepInvocationDefinition::Function(function) => function.dependencies.clone(),
22        }
23    }
24
25    pub fn name(&self) -> Option<&str> {
26        match self {
27            StepInvocationDefinition::Code(_) => None,
28            StepInvocationDefinition::Function(function) => Some(&function.uses),
29        }
30    }
31}
32
33#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema)]
34pub struct Dependency {
35    pub name: String,
36    #[serde(flatten)]
37    pub version: DependencyVersion,
38    #[serde(default = "default_features")]
39    pub default_features: bool,
40    #[serde(default)]
41    pub features: Vec<String>,
42}
43
/// Serde default for `Dependency::default_features`: a crate's default
/// features stay enabled unless the config explicitly sets the key to `false`.
fn default_features() -> bool {
    const ENABLED_WHEN_OMITTED: bool = true;
    ENABLED_WHEN_OMITTED
}
47
48impl Dependency {
49    pub fn to_rust_dependency(&self) -> String {
50        match &self.version {
51            DependencyVersion::Version { version } => {
52                if self.default_features && self.features.is_empty() {
53                    format!("{} = \"{}\"", self.name, version)
54                } else {
55                    let mut dep = format!("{} = {{ version = \"{}\"", self.name, version);
56                    if !self.default_features {
57                        dep.push_str(", default-features = false");
58                    }
59                    if !self.features.is_empty() {
60                        dep.push_str(&format!(
61                            ", features = [\"{}\"]",
62                            self.features.join("\", \"")
63                        ));
64                    }
65                    dep.push_str(" }");
66                    dep
67                }
68            }
69            DependencyVersion::Path { path } => {
70                let mut dep = format!("{} = {{ path = \"{}\"", self.name, path);
71                if !self.default_features {
72                    dep.push_str(", default-features = false");
73                }
74
75                if !self.features.is_empty() {
76                    dep.push_str(&format!(
77                        ", features = [\"{}\"]",
78                        self.features.join("\", \"")
79                    ));
80                }
81                dep.push_str(" }");
82                dep
83            }
84            DependencyVersion::Git {
85                git,
86                branch,
87                rev,
88                tag,
89            } => {
90                let mut git = format!("{} = {{ git = \"{}\"", self.name, git);
91                if let Some(branch) = branch {
92                    git.push_str(&format!(", branch = \"{}\"", branch));
93                }
94                if let Some(rev) = rev {
95                    git.push_str(&format!(", rev = \"{}\"", rev));
96                }
97
98                if let Some(tag) = tag {
99                    git.push_str(&format!(", tag = \"{}\"", tag));
100                }
101
102                if !self.default_features {
103                    git.push_str(", default-features = false");
104                }
105
106                if !self.features.is_empty() {
107                    git.push_str(&format!(
108                        ", features = [\"{}\"]",
109                        self.features.join("\", \"")
110                    ));
111                }
112
113                git.push_str(" }");
114                git
115            }
116        }
117    }
118}
119
120#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema)]
121#[serde(untagged)]
122pub enum DependencyVersion {
123    Version {
124        version: String,
125    },
126    Path {
127        path: String,
128    },
129    Git {
130        git: String,
131        branch: Option<String>,
132        rev: Option<String>,
133        tag: Option<String>,
134    },
135}
136
137/// Serialization representation of the code
138#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema)]
139#[serde(rename_all = "kebab-case")]
140pub struct Code {
141    #[serde(alias = "$key$")]
142    pub export_name: Option<String>, // used to validate pkg exports
143    #[serde(default)]
144    pub lang: Lang,
145    #[serde(skip_serializing_if = "Vec::is_empty", default)]
146    #[serde(rename = "states")]
147    pub state_imports: Vec<StateImport>,
148    #[serde(default)]
149    pub dependencies: Vec<Dependency>,
150    pub run: String,
151}
152
153/// Supported lang in sdf for build and generation
154#[derive(Default, Serialize, Deserialize, Debug, Clone, PartialEq, Eq, JsonSchema)]
155#[serde(rename_all = "kebab-case")]
156pub enum Lang {
157    #[default]
158    Rust,
159}
160
161#[derive(Serialize, Deserialize, Debug, Clone, Default, JsonSchema)]
162#[serde(rename_all = "kebab-case")]
163pub struct FunctionDefinition {
164    #[serde(alias = "$key$")]
165    pub uses: String,
166    #[serde(skip_serializing_if = "Vec::is_empty", default)]
167    #[serde(rename = "states")]
168    pub state_imports: Vec<StateImport>,
169    #[serde(skip_serializing_if = "Vec::is_empty", default)]
170    pub inputs: Vec<NamedParameterWrapper>,
171    #[serde(skip_serializing_if = "Option::is_none", default)]
172    pub output: Option<ParameterWrapper>,
173    #[serde(default)]
174    pub lang: Lang,
175    #[serde(skip_serializing_if = "Vec::is_empty", default)]
176    pub dependencies: Vec<Dependency>,
177    #[serde(default)]
178    pub with: BTreeMap<String, String>,
179}
180
#[cfg(test)]
mod tests {

    use super::*;

    /// Parses `yaml` and asserts it resolves to the `Code` variant.
    fn parse_code(yaml: &str) -> Code {
        let parsed: StepInvocationDefinition = serde_yaml::from_str(yaml).expect("parse yaml");
        match parsed {
            StepInvocationDefinition::Code(code) => code,
            _ => panic!("Invalid parsed code"),
        }
    }

    #[test]
    fn test_deserialize_code() {
        let yaml = r#"
lang: rust
run: |
    fn my_map(my_input: String) -> Result<String, String> {
        println!("Hello, world!");
    }
"#;
        let code = parse_code(yaml);

        assert_eq!(code.lang, Lang::Rust);
        assert!(code.state_imports.is_empty());
    }

    #[test]
    fn test_deserialize_code_with_input_key() {
        let yaml = r#"
lang: rust
run: |
    fn my_map(key: Option<String>, my_input: String) -> Result<String, String> {
        todo!()
    }
"#;
        let code = parse_code(yaml);

        assert_eq!(code.lang, Lang::Rust);
        assert!(code.state_imports.is_empty());
    }

    #[test]
    fn test_deserialize_code_with_output_key() {
        let yaml = r#"
lang: rust
run: |
    fn my_map(my_input: String) -> Result<(Option<i32>,String), String> {
        println!("Hello, world!");
    }
"#;
        let code = parse_code(yaml);

        assert_eq!(code.lang, Lang::Rust);
        assert!(code.state_imports.is_empty());
    }

    #[test]
    fn test_deserialize_flat_map_code_with_output_key() {
        let yaml = r#"
        lang: rust
        run: |
            fn my_flatmap(my_input: String) -> Result<Option<(Option<String>,String)>, String> {
                println!("Hello, world!");
            }
        "#;
        let code = parse_code(yaml);

        assert_eq!(code.lang, Lang::Rust);
        assert!(code.state_imports.is_empty());
    }

    #[test]
    fn test_deserialize_function() {
        let yaml = r#"
lang: rust
uses: my-map
inputs:
  - name: my-input
    type: string
output:
    type: string
"#;

        let parsed_code: StepInvocationDefinition = serde_yaml::from_str(yaml).expect("parse yaml");

        let function = match parsed_code {
            StepInvocationDefinition::Function(function) => function,
            _ => panic!("Invalid parsed code"),
        };
        assert_eq!(function.uses, "my-map");
        assert_eq!(function.inputs.len(), 1);
        assert_eq!(function.inputs[0].name, "my-input");
        assert_eq!(function.inputs[0].ty.ty(), "string");
        assert_eq!(function.output.unwrap().ty.ty(), "string");
    }

    #[test]
    fn test_deserialize_ambiguous_code_takes_priority() {
        let yaml = r#"
lang: rust
run: |
    fn my_map(my_input: String) -> Result<String, String> {
        println!("Hello, world!");
    }
uses: my-map
inputs:
  - name: my-input
    type: string
output:
    type: string
"#;
        let code = parse_code(yaml);

        assert_eq!(code.lang, Lang::Rust);
        assert!(code.state_imports.is_empty());
        // Previously the result of `contains` was discarded, so this check never ran.
        assert!(code.run.contains("Hello, world!"));
    }

    #[test]
    fn test_step_invocation_accessors() {
        let dep = Dependency {
            name: "serde".to_string(),
            version: DependencyVersion::Version {
                version: "1.0".to_string(),
            },
            default_features: true,
            features: vec![],
        };

        let code = StepInvocationDefinition::Code(Code {
            export_name: None,
            lang: Lang::Rust,
            state_imports: vec![],
            dependencies: vec![dep.clone()],
            run: String::new(),
        });
        assert_eq!(code.name(), None);
        assert_eq!(code.extra_deps().len(), 1);
        assert_eq!(code.extra_deps()[0].name, "serde");

        let function = StepInvocationDefinition::Function(FunctionDefinition {
            uses: "my-map".to_string(),
            dependencies: vec![dep],
            ..Default::default()
        });
        assert_eq!(function.name(), Some("my-map"));
        assert_eq!(function.extra_deps().len(), 1);
    }

    #[test]
    fn test_dependency_to_rust_dependency() {
        let plain = Dependency {
            name: "serde".to_string(),
            version: DependencyVersion::Version {
                version: "1.0".to_string(),
            },
            default_features: true,
            features: vec![],
        };
        assert_eq!(plain.to_rust_dependency(), r#"serde = "1.0""#);

        let git = Dependency {
            name: "foo".to_string(),
            version: DependencyVersion::Git {
                git: "https://example.com/foo.git".to_string(),
                branch: Some("main".to_string()),
                rev: None,
                tag: None,
            },
            default_features: false,
            features: vec!["a".to_string(), "b".to_string()],
        };
        assert_eq!(
            git.to_rust_dependency(),
            r#"foo = { git = "https://example.com/foo.git", branch = "main", default-features = false, features = ["a", "b"] }"#
        );
    }
}