Skip to main content

stepflow_flow/workflow/
variable_schema.rs

1// Copyright 2025 DataStax Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4// in compliance with the License. You may obtain a copy of the License at
5//
6//     http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software distributed under the License
9// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
10// or implied. See the License for the specific language governing permissions and limitations under
11// the License.
12
13use crate::{
14    schema::SchemaRef,
15    values::{Secrets, ValueRef},
16};
17use log::debug;
18use serde::{Deserialize, Serialize};
19use std::collections::{HashMap, HashSet};
20
21/// Variable schema for workflow variables using JSON Schema format.
22///
23/// This allows flows to declare required variables with their types,
24/// default values, descriptions, secret annotations, and environment
25/// variable mappings.
26///
27/// Example:
28/// ```yaml
29/// variables:
30///   type: object
31///   properties:
32///     api_key:
33///       type: string
34///       is_secret: true
35///       env_var: "OPENAI_API_KEY"
36///       description: "OpenAI API key"
37///     temperature:
38///       type: number
39///       default: 0.7
40///       minimum: 0
41///       maximum: 2
42///   required: ["api_key"]
43/// ```
44#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
45#[serde(from = "SchemaRef", into = "SchemaRef")]
46pub struct VariableSchema {
47    schema: SchemaRef,
48    variables: Vec<String>,
49    defaults: HashMap<String, ValueRef>,
50    secrets: Secrets,
51    required: HashSet<String>,
52    /// Mapping from variable name to the environment variable name
53    /// that should be used to populate it when `--env-variables` is enabled.
54    env_vars: HashMap<String, String>,
55}
56
57impl schemars::JsonSchema for VariableSchema {
58    fn schema_name() -> std::borrow::Cow<'static, str> {
59        <crate::schema::SchemaRef as schemars::JsonSchema>::schema_name()
60    }
61
62    fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
63        <crate::schema::SchemaRef as schemars::JsonSchema>::json_schema(generator)
64    }
65}
66
67impl From<SchemaRef> for VariableSchema {
68    fn from(schema: SchemaRef) -> Self {
69        Self::new(schema)
70    }
71}
72
73impl From<VariableSchema> for SchemaRef {
74    fn from(var_schema: VariableSchema) -> Self {
75        var_schema.schema
76    }
77}
78
79impl VariableSchema {
80    /// Create a new variable schema from a JSON Schema.
81    pub fn new(schema: SchemaRef) -> Self {
82        let schema_value = schema.as_value();
83
84        let mut required = HashSet::new();
85        if let Some(required_array) = schema_value.get("required").and_then(|r| r.as_array()) {
86            for req in required_array {
87                if let Some(req_str) = req.as_str() {
88                    required.insert(req_str.to_string());
89                }
90            }
91        }
92
93        let mut variables = Vec::new();
94        let mut defaults = HashMap::new();
95        let mut env_vars = HashMap::new();
96        if let Some(properties) = schema_value.get("properties").and_then(|p| p.as_object()) {
97            for (var_name, var_schema) in properties {
98                variables.push(var_name.clone());
99
100                // Parse env_var annotation
101                if let Some(env_var) = var_schema.get("env_var").and_then(|v| v.as_str()) {
102                    env_vars.insert(var_name.clone(), env_var.to_string());
103                }
104
105                let var_type = var_schema.get("type");
106                let var_default = if let Some(default_value) = var_schema.get("default") {
107                    Some(default_value.clone())
108                } else if !required.contains(var_name) {
109                    match var_type {
110                        Some(serde_json::Value::String(type_str)) => match type_str.as_str() {
111                            "string" => Some(serde_json::Value::String("".to_string())),
112                            "number" | "integer" => Some(serde_json::Value::Number(0.into())),
113                            "boolean" => Some(serde_json::Value::Bool(false)),
114                            _ => None,
115                        },
116                        Some(serde_json::Value::Array(type_array)) => {
117                            if type_array
118                                .iter()
119                                .any(|t| t.as_str().is_some_and(|t| t == "null"))
120                            {
121                                Some(serde_json::Value::Null)
122                            } else {
123                                None
124                            }
125                        }
126                        _ => None,
127                    }
128                } else {
129                    None
130                };
131
132                if let Some(var_default) = var_default {
133                    defaults.insert(var_name.clone(), ValueRef::new(var_default));
134                } else {
135                    debug!(
136                        "Variable '{}' has no default and is not required; no default value inferred.",
137                        var_name
138                    );
139                }
140            }
141        }
142
143        let secrets = Secrets::from_schema(schema_value);
144        Self {
145            schema,
146            variables,
147            defaults,
148            secrets,
149            required,
150            env_vars,
151        }
152    }
153
154    pub fn secrets(&self) -> &Secrets {
155        &self.secrets
156    }
157
158    /// Return variable names from the schema properties.
159    pub fn variables(&self) -> &'_ [String] {
160        &self.variables
161    }
162
163    /// Get the environment variable name for a given variable, if annotated.
164    pub fn env_var_name(&self, variable_name: &str) -> Option<&str> {
165        self.env_vars.get(variable_name).map(|s| s.as_str())
166    }
167
168    /// Get the full mapping of variable names to environment variable names.
169    pub fn env_var_map(&self) -> &HashMap<String, String> {
170        &self.env_vars
171    }
172
173    /// Get the list of required variables.
174    pub fn required_variables(&self) -> impl Iterator<Item = &'_ str> + '_ {
175        // Variables in `required` that don't have a default value.
176        self.required.iter().map(|s| s.as_ref())
177    }
178
179    /// Get the default value for a variable, if specified.
180    pub fn default_value(&self, variable_name: &str) -> Option<ValueRef> {
181        self.defaults.get(variable_name).cloned()
182    }
183
184    /// Validate that provided variable values match the schema requirements.
185    pub fn validate_variables(
186        &self,
187        variables: &HashMap<String, serde_json::Value>,
188    ) -> Result<(), VariableValidationError> {
189        // Check that all required variables are provided
190        for required_var in self.required_variables() {
191            if !variables.contains_key(required_var) {
192                return Err(VariableValidationError::MissingVariable(
193                    required_var.to_string(),
194                ));
195            }
196        }
197
198        // TODO: Add JSON Schema validation for variable values
199        // This would require integrating with a JSON Schema validation library
200
201        Ok(())
202    }
203}
204
205/// Errors that can occur during variable validation.
206#[derive(Debug, thiserror::Error, PartialEq)]
207pub enum VariableValidationError {
208    #[error("Missing required variable: {0}")]
209    MissingVariable(String),
210    #[error("Invalid variable value for '{variable}': {message}")]
211    InvalidValue { variable: String, message: String },
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Parse a JSON schema literal into a `VariableSchema`, panicking if it is invalid.
    fn schema_from(schema_json: serde_json::Value) -> VariableSchema {
        let schema = SchemaRef::parse_json(&schema_json.to_string()).unwrap();
        VariableSchema::new(schema)
    }

    /// The default for `name` as a plain JSON value, if there is one.
    fn default_json(var_schema: &VariableSchema, name: &str) -> Option<serde_json::Value> {
        var_schema.default_value(name).map(|v| v.clone_value())
    }

    /// Build the variable map passed to `validate_variables`.
    fn vars(pairs: &[(&str, serde_json::Value)]) -> HashMap<String, serde_json::Value> {
        pairs
            .iter()
            .map(|(name, value)| (name.to_string(), value.clone()))
            .collect()
    }

    #[test]
    fn test_variable_schema_creation() {
        let var_schema = schema_from(json!({
            "type": "object",
            "properties": {
                "api_key": {
                    "type": "string",
                    "is_secret": true,
                    "description": "API key for external service"
                },
                "temperature": {
                    "type": "number",
                    "default": 0.7,
                    "minimum": 0,
                    "maximum": 2
                }
            },
            "required": ["api_key"]
        }));

        let names = var_schema.variables();
        assert_eq!(names.len(), 2);
        for expected in ["api_key", "temperature"] {
            assert!(
                names.iter().any(|n| n == expected),
                "missing variable {expected}"
            );
        }

        assert_eq!(
            var_schema.required_variables().collect::<Vec<_>>(),
            ["api_key"]
        );

        assert!(var_schema.secrets.field("api_key").is_secret());
        assert!(!var_schema.secrets.field("temperature").is_secret());

        assert_eq!(default_json(&var_schema, "temperature"), Some(json!(0.7)));
        assert!(var_schema.default_value("api_key").is_none());
    }

    #[test]
    fn test_env_var_annotation() {
        let var_schema = schema_from(json!({
            "type": "object",
            "properties": {
                "api_key": {
                    "type": "string",
                    "is_secret": true,
                    "env_var": "OPENAI_API_KEY"
                },
                "temperature": {
                    "type": "number",
                    "default": 0.7
                },
                "db_url": {
                    "type": "string",
                    "env_var": "DATABASE_URL"
                }
            },
            "required": ["api_key"]
        }));

        let expectations = [
            ("api_key", Some("OPENAI_API_KEY")),
            ("temperature", None),
            ("db_url", Some("DATABASE_URL")),
            ("nonexistent", None),
        ];
        for (name, expected) in expectations {
            assert_eq!(var_schema.env_var_name(name), expected, "env var for {name}");
        }

        let env_map = var_schema.env_var_map();
        assert_eq!(env_map.len(), 2);
        assert_eq!(env_map["api_key"], "OPENAI_API_KEY");
        assert_eq!(env_map["db_url"], "DATABASE_URL");
    }

    #[test]
    fn test_variable_validation() {
        let var_schema = schema_from(json!({
            "type": "object",
            "properties": {
                "api_key": { "type": "string" },
                "temperature": { "type": "number", "default": 0.7 }
            },
            "required": ["api_key"]
        }));

        // All variables supplied.
        let complete = vars(&[("api_key", json!("test-key")), ("temperature", json!(0.8))]);
        assert!(var_schema.validate_variables(&complete).is_ok());

        // The required variable is absent.
        let missing_required = vars(&[("temperature", json!(0.8))]);
        assert_eq!(
            var_schema.validate_variables(&missing_required),
            Err(VariableValidationError::MissingVariable(
                "api_key".to_string()
            ))
        );

        // Leaving out an optional variable is fine.
        let only_required = vars(&[("api_key", json!("test-key"))]);
        assert!(var_schema.validate_variables(&only_required).is_ok());
    }

    #[test]
    fn test_default_variable_schema() {
        let empty = VariableSchema::default();
        assert!(empty.variables().is_empty());
        assert_eq!(empty.required_variables().count(), 0);
    }

    #[test]
    fn test_default_value() {
        let variable_schema = schema_from(json!({
            "type": "object",
            "properties": {
                "default_bool": { "type": "boolean", "default": true },
                "default_str": { "type": "string", "default": "hello" },
                "default_num": { "type": "number", "default": 3.15 },
                "optional_bool": { "type": "boolean" },
                "optional_str": { "type": "string" },
                "optional_num": { "type": "number" },
                "optional_str_or_none": { "type": ["string", "null"] },
                "required_str": { "type": "string" }
            },
            "required": ["required_str"]
        }));

        let expected = [
            ("default_bool", json!(true)),
            ("default_str", json!("hello")),
            ("default_num", json!(3.15)),
            ("optional_bool", json!(false)),
            ("optional_str", json!("")),
            ("optional_num", json!(0)),
            ("optional_str_or_none", serde_json::Value::Null),
        ];
        for (name, value) in expected {
            assert_eq!(
                default_json(&variable_schema, name),
                Some(value),
                "default for {name}"
            );
        }
        assert!(variable_schema.default_value("required_str").is_none());
    }
}