sdf_parser_core/config/transform/
code.rs1use std::collections::BTreeMap;
2
3use schemars::JsonSchema;
4use serde::{Serialize, Deserialize};
5
6use crate::config::import::StateImport;
7
8use super::{NamedParameterWrapper, ParameterWrapper};
9
10#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema)]
11#[serde(untagged)]
12pub enum StepInvocationDefinition {
13 Code(Code),
14 Function(FunctionDefinition),
15}
16
17impl StepInvocationDefinition {
18 pub fn extra_deps(&self) -> Vec<Dependency> {
19 match self {
20 StepInvocationDefinition::Code(code) => code.dependencies.clone(),
21 StepInvocationDefinition::Function(function) => function.dependencies.clone(),
22 }
23 }
24
25 pub fn name(&self) -> Option<&str> {
26 match self {
27 StepInvocationDefinition::Code(_) => None,
28 StepInvocationDefinition::Function(function) => Some(&function.uses),
29 }
30 }
31}
32
33#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema)]
34pub struct Dependency {
35 pub name: String,
36 #[serde(flatten)]
37 pub version: DependencyVersion,
38 #[serde(default = "default_features")]
39 pub default_features: bool,
40 #[serde(default)]
41 pub features: Vec<String>,
42}
43
/// Serde default for `Dependency::default_features`.
///
/// A crate keeps its default features unless the config explicitly turns them off.
const fn default_features() -> bool {
    true
}
47
48impl Dependency {
49 pub fn to_rust_dependency(&self) -> String {
50 match &self.version {
51 DependencyVersion::Version { version } => {
52 if self.default_features && self.features.is_empty() {
53 format!("{} = \"{}\"", self.name, version)
54 } else {
55 let mut dep = format!("{} = {{ version = \"{}\"", self.name, version);
56 if !self.default_features {
57 dep.push_str(", default-features = false");
58 }
59 if !self.features.is_empty() {
60 dep.push_str(&format!(
61 ", features = [\"{}\"]",
62 self.features.join("\", \"")
63 ));
64 }
65 dep.push_str(" }");
66 dep
67 }
68 }
69 DependencyVersion::Path { path } => {
70 let mut dep = format!("{} = {{ path = \"{}\"", self.name, path);
71 if !self.default_features {
72 dep.push_str(", default-features = false");
73 }
74
75 if !self.features.is_empty() {
76 dep.push_str(&format!(
77 ", features = [\"{}\"]",
78 self.features.join("\", \"")
79 ));
80 }
81 dep.push_str(" }");
82 dep
83 }
84 DependencyVersion::Git {
85 git,
86 branch,
87 rev,
88 tag,
89 } => {
90 let mut git = format!("{} = {{ git = \"{}\"", self.name, git);
91 if let Some(branch) = branch {
92 git.push_str(&format!(", branch = \"{}\"", branch));
93 }
94 if let Some(rev) = rev {
95 git.push_str(&format!(", rev = \"{}\"", rev));
96 }
97
98 if let Some(tag) = tag {
99 git.push_str(&format!(", tag = \"{}\"", tag));
100 }
101
102 if !self.default_features {
103 git.push_str(", default-features = false");
104 }
105
106 if !self.features.is_empty() {
107 git.push_str(&format!(
108 ", features = [\"{}\"]",
109 self.features.join("\", \"")
110 ));
111 }
112
113 git.push_str(" }");
114 git
115 }
116 }
117 }
118}
119
120#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema)]
121#[serde(untagged)]
122pub enum DependencyVersion {
123 Version {
124 version: String,
125 },
126 Path {
127 path: String,
128 },
129 Git {
130 git: String,
131 branch: Option<String>,
132 rev: Option<String>,
133 tag: Option<String>,
134 },
135}
136
137#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema)]
139#[serde(rename_all = "kebab-case")]
140pub struct Code {
141 #[serde(alias = "$key$")]
142 pub export_name: Option<String>, #[serde(default)]
144 pub lang: Lang,
145 #[serde(skip_serializing_if = "Vec::is_empty", default)]
146 #[serde(rename = "states")]
147 pub state_imports: Vec<StateImport>,
148 #[serde(default)]
149 pub dependencies: Vec<Dependency>,
150 pub run: String,
151}
152
153#[derive(Default, Serialize, Deserialize, Debug, Clone, PartialEq, Eq, JsonSchema)]
155#[serde(rename_all = "kebab-case")]
156pub enum Lang {
157 #[default]
158 Rust,
159}
160
161#[derive(Serialize, Deserialize, Debug, Clone, Default, JsonSchema)]
162#[serde(rename_all = "kebab-case")]
163pub struct FunctionDefinition {
164 #[serde(alias = "$key$")]
165 pub uses: String,
166 #[serde(skip_serializing_if = "Vec::is_empty", default)]
167 #[serde(rename = "states")]
168 pub state_imports: Vec<StateImport>,
169 #[serde(skip_serializing_if = "Vec::is_empty", default)]
170 pub inputs: Vec<NamedParameterWrapper>,
171 #[serde(skip_serializing_if = "Option::is_none", default)]
172 pub output: Option<ParameterWrapper>,
173 #[serde(default)]
174 pub lang: Lang,
175 #[serde(skip_serializing_if = "Vec::is_empty", default)]
176 pub dependencies: Vec<Dependency>,
177 #[serde(default)]
178 pub with: BTreeMap<String, String>,
179}
180
#[cfg(test)]
mod tests {

    use super::*;

    /// Parses `yaml` as a step invocation, panicking unless serde resolved it to the
    /// inline `Code` variant.
    fn parse_code(yaml: &str) -> Code {
        let parsed: StepInvocationDefinition = serde_yaml::from_str(yaml).expect("parse yaml");
        match parsed {
            StepInvocationDefinition::Code(code) => code,
            StepInvocationDefinition::Function(_) => panic!("Invalid parsed code"),
        }
    }

    #[test]
    fn test_deserialize_code() {
        let yaml = r#"
lang: rust
run: |
  fn my_map(my_input: String) -> Result<String, String> {
    println!("Hello, world!");
  }
"#;
        let code = parse_code(yaml);

        assert_eq!(code.lang, Lang::Rust);
        assert!(code.state_imports.is_empty());
        assert!(code.run.contains("Hello, world!"));
    }

    #[test]
    fn test_deserialize_code_with_input_key() {
        let yaml = r#"
lang: rust
run: |
  fn my_map(key: Option<String>, my_input: String) -> Result<String, String> {
    todo!()
  }
"#;
        let code = parse_code(yaml);

        assert_eq!(code.lang, Lang::Rust);
        assert!(code.state_imports.is_empty());
        assert!(code.run.contains("key: Option<String>"));
    }

    #[test]
    fn test_deserialize_code_with_output_key() {
        let yaml = r#"
lang: rust
run: |
  fn my_map(my_input: String) -> Result<(Option<i32>,String), String> {
    println!("Hello, world!");
  }
"#;
        let code = parse_code(yaml);

        assert_eq!(code.lang, Lang::Rust);
        assert!(code.state_imports.is_empty());
        assert!(code.run.contains("Result<(Option<i32>,String), String>"));
    }

    #[test]
    fn test_deserialize_flat_map_code_with_output_key() {
        let yaml = r#"
lang: rust
run: |
  fn my_flatmap(my_input: String) -> Result<Option<(Option<String>,String)>, String> {
    println!("Hello, world!");
  }
"#;
        let code = parse_code(yaml);

        assert_eq!(code.lang, Lang::Rust);
        assert!(code.state_imports.is_empty());
        assert!(code.run.contains("fn my_flatmap"));
    }

    #[test]
    fn test_deserialize_function() {
        let yaml = r#"
lang: rust
uses: my-map
inputs:
  - name: my-input
    type: string
output:
  type: string
"#;

        let parsed_code: StepInvocationDefinition = serde_yaml::from_str(yaml).expect("parse yaml");
        assert_eq!(parsed_code.name(), Some("my-map"));
        assert!(parsed_code.extra_deps().is_empty());

        let function = match parsed_code {
            StepInvocationDefinition::Function(function) => function,
            StepInvocationDefinition::Code(_) => panic!("Invalid parsed code"),
        };
        assert_eq!(function.uses, "my-map");
        assert_eq!(function.inputs.len(), 1);
        assert_eq!(function.inputs[0].name, "my-input");
        assert_eq!(function.inputs[0].ty.ty(), "string");
        assert_eq!(function.output.unwrap().ty.ty(), "string");
    }

    #[test]
    fn test_deserialize_ambiguous_code_takes_priority() {
        let yaml = r#"
lang: rust
run: |
  fn my_map(my_input: String) -> Result<String, String> {
    println!("Hello, world!");
  }
uses: my-map
inputs:
  - name: my-input
    type: string
output:
  type: string
"#;
        let code = parse_code(yaml);

        assert_eq!(code.lang, Lang::Rust);
        assert!(code.state_imports.is_empty());
        // Previously this result was discarded, so the test never checked `run`.
        assert!(code.run.contains("Hello, world!"));
    }

    #[test]
    fn test_dependency_to_rust_dependency() {
        let short = Dependency {
            name: "serde".to_string(),
            version: DependencyVersion::Version {
                version: "1.0".to_string(),
            },
            default_features: true,
            features: vec![],
        };
        assert_eq!(short.to_rust_dependency(), r#"serde = "1.0""#);

        let git = Dependency {
            name: "sdf".to_string(),
            version: DependencyVersion::Git {
                git: "https://example.com/sdf.git".to_string(),
                branch: Some("main".to_string()),
                rev: None,
                tag: None,
            },
            default_features: false,
            features: vec!["a".to_string(), "b".to_string()],
        };
        assert_eq!(
            git.to_rust_dependency(),
            r#"sdf = { git = "https://example.com/sdf.git", branch = "main", default-features = false, features = ["a", "b"] }"#
        );
    }
}