Skip to main content

zig_core/workflow/
validate.rs

1use std::collections::{HashMap, HashSet};
2
3use regex::Regex;
4
5use crate::error::ZigError;
6use crate::workflow::model::{
7    FailurePolicy, StepCommand, StorageKind, VarType, Variable, Workflow,
8};
9
10/// Validate a parsed workflow for structural correctness.
11///
12/// Checks:
13/// - At least one step exists
14/// - Step names are unique
15/// - `depends_on` references exist
16/// - No dependency cycles
17/// - `next` references exist
18/// - Variable references in prompts refer to declared variables
19/// - `saves` variable names are declared
20/// - Condition variable references are declared
21pub fn validate(workflow: &Workflow) -> Result<(), Vec<ZigError>> {
22    let mut errors = Vec::new();
23
24    if workflow.steps.is_empty() {
25        errors.push(ZigError::Validation(
26            "workflow must have at least one step".into(),
27        ));
28        return Err(errors);
29    }
30
31    let step_names: HashSet<&str> = workflow.steps.iter().map(|s| s.name.as_str()).collect();
32    let var_names: HashSet<&str> = workflow.vars.keys().map(|k| k.as_str()).collect();
33    let role_names: HashSet<&str> = workflow.roles.keys().map(|k| k.as_str()).collect();
34    let storage_names: HashSet<&str> = workflow.storage.keys().map(|k| k.as_str()).collect();
35
36    // Storage-level checks — spec-internal consistency.
37    for (name, spec) in &workflow.storage {
38        if spec.path.trim().is_empty() {
39            errors.push(ZigError::Validation(format!(
40                "storage '{name}' has an empty path"
41            )));
42        }
43        if matches!(spec.kind, StorageKind::File) && !spec.files.is_empty() {
44            errors.push(ZigError::Validation(format!(
45                "storage '{name}' has type = \"file\" but also declares 'files' hints \
46                 (the 'files' subtable is only valid for type = \"folder\")"
47            )));
48        }
49        for file in &spec.files {
50            if file.name.contains('/') || file.name.contains('\\') {
51                errors.push(ZigError::Validation(format!(
52                    "storage '{name}' file hint '{}' must be a bare filename, not a path",
53                    file.name
54                )));
55            }
56        }
57    }
58
59    // Check unique step names
60    let mut seen_names = HashSet::new();
61    for step in &workflow.steps {
62        if !seen_names.insert(&step.name) {
63            errors.push(ZigError::Validation(format!(
64                "duplicate step name: '{}'",
65                step.name
66            )));
67        }
68    }
69
70    for step in &workflow.steps {
71        // Check depends_on references
72        for dep in &step.depends_on {
73            if !step_names.contains(dep.as_str()) {
74                errors.push(ZigError::Validation(format!(
75                    "step '{}' depends on unknown step '{dep}'",
76                    step.name
77                )));
78            }
79            if dep == &step.name {
80                errors.push(ZigError::Validation(format!(
81                    "step '{}' depends on itself",
82                    step.name
83                )));
84            }
85        }
86
87        // Check next references
88        if let Some(next) = &step.next {
89            if !step_names.contains(next.as_str()) {
90                errors.push(ZigError::Validation(format!(
91                    "step '{}' references unknown next step '{next}'",
92                    step.name
93                )));
94            }
95        }
96
97        // Check variable references in prompt
98        for var_ref in extract_var_refs(&step.prompt) {
99            if !var_names.contains(var_ref.as_str()) {
100                errors.push(ZigError::Validation(format!(
101                    "step '{}' prompt references unknown variable '${{{var_ref}}}'",
102                    step.name
103                )));
104            }
105        }
106
107        // Check variable references in system_prompt
108        if let Some(system_prompt) = &step.system_prompt {
109            for var_ref in extract_var_refs(system_prompt) {
110                if !var_names.contains(var_ref.as_str()) {
111                    errors.push(ZigError::Validation(format!(
112                        "step '{}' system_prompt references unknown variable '${{{var_ref}}}'",
113                        step.name
114                    )));
115                }
116            }
117        }
118
119        // Check storage scoping references
120        if let Some(scope) = &step.storage {
121            for name in scope {
122                if !storage_names.contains(name.as_str()) {
123                    errors.push(ZigError::Validation(format!(
124                        "step '{}' storage scope references unknown storage '{name}'",
125                        step.name
126                    )));
127                }
128            }
129        }
130
131        // Check role and system_prompt are mutually exclusive
132        if step.role.is_some() && step.system_prompt.is_some() {
133            errors.push(ZigError::Validation(format!(
134                "step '{}' sets both 'role' and 'system_prompt' (they are mutually exclusive)",
135                step.name
136            )));
137        }
138
139        // Check role references
140        if let Some(role_ref) = &step.role {
141            let var_refs = extract_var_refs(role_ref);
142            if var_refs.is_empty() {
143                // Static role reference — must exist in [roles]
144                if !role_names.contains(role_ref.as_str()) {
145                    errors.push(ZigError::Validation(format!(
146                        "step '{}' role references unknown role '{role_ref}'",
147                        step.name
148                    )));
149                }
150            } else {
151                // Dynamic role reference — validate variable refs
152                for var_ref in var_refs {
153                    if !var_names.contains(var_ref.as_str()) {
154                        errors.push(ZigError::Validation(format!(
155                            "step '{}' role references unknown variable '${{{var_ref}}}'",
156                            step.name
157                        )));
158                    }
159                }
160            }
161        }
162
163        // Check saves reference declared variables
164        for var_name in step.saves.keys() {
165            if !var_names.contains(var_name.as_str()) {
166                errors.push(ZigError::Validation(format!(
167                    "step '{}' saves to unknown variable '{var_name}'",
168                    step.name
169                )));
170            }
171        }
172
173        // Check condition variable references
174        if let Some(cond) = &step.condition {
175            for var_ref in extract_condition_vars(cond) {
176                if !var_names.contains(var_ref.as_str()) && !step_names.contains(var_ref.as_str()) {
177                    errors.push(ZigError::Validation(format!(
178                        "step '{}' condition references unknown variable '{var_ref}'",
179                        step.name
180                    )));
181                }
182            }
183        }
184
185        // Check retry_model requires on_failure = "retry"
186        if step.retry_model.is_some() && step.on_failure.as_ref() != Some(&FailurePolicy::Retry) {
187            errors.push(ZigError::Validation(format!(
188                "step '{}' sets retry_model but on_failure is not 'retry'",
189                step.name
190            )));
191        }
192
193        // Check mcp_config requires claude provider (or no provider specified).
194        // The effective provider considers both step-level and workflow-level defaults.
195        if step.mcp_config.is_some() {
196            let effective_provider = step
197                .provider
198                .as_ref()
199                .or(workflow.workflow.provider.as_ref());
200            if let Some(provider) = effective_provider {
201                if provider != "claude" {
202                    errors.push(ZigError::Validation(format!(
203                        "step '{}' sets mcp_config but provider is '{}' \
204                         (mcp_config is only supported with the claude provider)",
205                        step.name, provider
206                    )));
207                }
208            }
209        }
210
211        // Check output format is a valid value
212        if let Some(ref output) = step.output {
213            let valid_formats = ["text", "json", "json-pretty", "stream-json", "native-json"];
214            if !valid_formats.contains(&output.as_str()) {
215                errors.push(ZigError::Validation(format!(
216                    "step '{}' has invalid output format '{}' \
217                     (must be one of: text, json, json-pretty, stream-json, native-json)",
218                    step.name, output
219                )));
220            }
221        }
222
223        // Check review-only fields require command = "review"
224        let is_review = step.command.as_ref() == Some(&StepCommand::Review);
225        if !is_review {
226            for (field, set) in [("uncommitted", step.uncommitted)] {
227                if set {
228                    errors.push(ZigError::Validation(format!(
229                        "step '{}' sets '{}' but command is not 'review'",
230                        step.name, field
231                    )));
232                }
233            }
234            for (field, set) in [
235                ("base", step.base.is_some()),
236                ("commit", step.commit.is_some()),
237                ("title", step.title.is_some()),
238            ] {
239                if set {
240                    errors.push(ZigError::Validation(format!(
241                        "step '{}' sets '{}' but command is not 'review'",
242                        step.name, field
243                    )));
244                }
245            }
246        }
247
248        // Check plan-only fields require command = "plan"
249        let is_plan = step.command.as_ref() == Some(&StepCommand::Plan);
250        if !is_plan {
251            for (field, set) in [
252                ("plan_output", step.plan_output.is_some()),
253                ("instructions", step.instructions.is_some()),
254            ] {
255                if set {
256                    errors.push(ZigError::Validation(format!(
257                        "step '{}' sets '{}' but command is not 'plan'",
258                        step.name, field
259                    )));
260                }
261            }
262        }
263
264        // Check pipe/collect/summary require depends_on
265        if let Some(ref cmd) = step.command {
266            match cmd {
267                StepCommand::Pipe | StepCommand::Collect | StepCommand::Summary => {
268                    if step.depends_on.is_empty() {
269                        errors.push(ZigError::Validation(format!(
270                            "step '{}' uses command '{}' but has no depends_on \
271                             (pipe/collect/summary operate on prior session outputs)",
272                            step.name,
273                            match cmd {
274                                StepCommand::Pipe => "pipe",
275                                StepCommand::Collect => "collect",
276                                StepCommand::Summary => "summary",
277                                _ => unreachable!(),
278                            }
279                        )));
280                    }
281                }
282                _ => {}
283            }
284        }
285    }
286
287    // Check role definitions
288    for (role_name, role) in &workflow.roles {
289        // system_prompt and system_prompt_file are mutually exclusive
290        if role.system_prompt.is_some() && role.system_prompt_file.is_some() {
291            errors.push(ZigError::Validation(format!(
292                "role '{role_name}' sets both 'system_prompt' and 'system_prompt_file' \
293                 (they are mutually exclusive)"
294            )));
295        }
296
297        // Validate ${var} references in role system_prompt
298        if let Some(ref sp) = role.system_prompt {
299            for var_ref in extract_var_refs(sp) {
300                if !var_names.contains(var_ref.as_str()) {
301                    errors.push(ZigError::Validation(format!(
302                        "role '{role_name}' system_prompt references unknown variable \
303                         '${{{var_ref}}}'"
304                    )));
305                }
306            }
307        }
308    }
309
310    // Check race_group: steps in the same group must not depend on each other
311    let mut race_groups: HashMap<&str, Vec<&str>> = HashMap::new();
312    for step in &workflow.steps {
313        if let Some(ref group) = step.race_group {
314            race_groups
315                .entry(group.as_str())
316                .or_default()
317                .push(step.name.as_str());
318        }
319    }
320    for (group, members) in &race_groups {
321        let member_set: HashSet<&str> = members.iter().copied().collect();
322        for step in &workflow.steps {
323            if step.race_group.as_deref() == Some(*group) {
324                for dep in &step.depends_on {
325                    if member_set.contains(dep.as_str()) {
326                        errors.push(ZigError::Validation(format!(
327                            "step '{}' depends on '{}' but both are in race_group '{}' \
328                             (race members must not depend on each other)",
329                            step.name, dep, group
330                        )));
331                    }
332                }
333            }
334        }
335    }
336
337    // Check variable constraints
338    validate_var_constraints(&workflow.vars, &mut errors);
339
340    // Check for dependency cycles
341    if let Some(cycle) = detect_cycle(&workflow.steps) {
342        errors.push(ZigError::Validation(format!(
343            "dependency cycle detected: {}",
344            cycle.join(" -> ")
345        )));
346    }
347
348    if errors.is_empty() {
349        Ok(())
350    } else {
351        Err(errors)
352    }
353}
354
355/// Validate variable constraint declarations for structural correctness.
356fn validate_var_constraints(vars: &HashMap<String, Variable>, errors: &mut Vec<ZigError>) {
357    let mut prompt_bound_count = 0;
358
359    for (name, var) in vars {
360        // default and default_file are mutually exclusive
361        if var.default.is_some() && var.default_file.is_some() {
362            errors.push(ZigError::Validation(format!(
363                "variable '{name}' sets both 'default' and 'default_file' \
364                 (they are mutually exclusive)"
365            )));
366        }
367
368        // Validate `from` field
369        if let Some(ref from) = var.from {
370            if from != "prompt" {
371                errors.push(ZigError::Validation(format!(
372                    "variable '{name}' has unsupported from value '{from}' (only 'prompt' is supported)"
373                )));
374            } else {
375                prompt_bound_count += 1;
376            }
377        }
378
379        // String-only constraints on non-string types
380        if var.var_type != VarType::String {
381            if var.min_length.is_some() {
382                errors.push(ZigError::Validation(format!(
383                    "variable '{name}' has min_length but type is '{}' (only valid for 'string')",
384                    var.var_type
385                )));
386            }
387            if var.max_length.is_some() {
388                errors.push(ZigError::Validation(format!(
389                    "variable '{name}' has max_length but type is '{}' (only valid for 'string')",
390                    var.var_type
391                )));
392            }
393            if var.pattern.is_some() {
394                errors.push(ZigError::Validation(format!(
395                    "variable '{name}' has pattern but type is '{}' (only valid for 'string')",
396                    var.var_type
397                )));
398            }
399        }
400
401        // Number-only constraints on non-number types
402        if var.var_type != VarType::Number {
403            if var.min.is_some() {
404                errors.push(ZigError::Validation(format!(
405                    "variable '{name}' has min but type is '{}' (only valid for 'number')",
406                    var.var_type
407                )));
408            }
409            if var.max.is_some() {
410                errors.push(ZigError::Validation(format!(
411                    "variable '{name}' has max but type is '{}' (only valid for 'number')",
412                    var.var_type
413                )));
414            }
415        }
416
417        // Range consistency
418        if let (Some(min_len), Some(max_len)) = (var.min_length, var.max_length) {
419            if min_len > max_len {
420                errors.push(ZigError::Validation(format!(
421                    "variable '{name}' has min_length ({min_len}) greater than max_length ({max_len})"
422                )));
423            }
424        }
425        if let (Some(min), Some(max)) = (var.min, var.max) {
426            if min > max {
427                errors.push(ZigError::Validation(format!(
428                    "variable '{name}' has min ({min}) greater than max ({max})"
429                )));
430            }
431        }
432
433        // Validate pattern compiles
434        if let Some(ref pattern) = var.pattern {
435            if Regex::new(pattern).is_err() {
436                errors.push(ZigError::Validation(format!(
437                    "variable '{name}' has invalid regex pattern: '{pattern}'"
438                )));
439            }
440        }
441
442        // Validate allowed_values type compatibility
443        if let Some(ref allowed) = var.allowed_values {
444            for val in allowed {
445                let ok = match var.var_type {
446                    VarType::String => val.is_str(),
447                    VarType::Number => val.is_integer() || val.is_float(),
448                    VarType::Bool => matches!(val, toml::Value::Boolean(_)),
449                    VarType::Json => true,
450                };
451                if !ok {
452                    errors.push(ZigError::Validation(format!(
453                        "variable '{name}' has allowed_values entry {val} incompatible with type '{}'",
454                        var.var_type
455                    )));
456                }
457            }
458        }
459
460        // Validate default satisfies constraints
461        if let Some(ref default) = var.default {
462            let default_str = toml_value_to_string(default);
463            let constraint_errors = check_value_constraints(name, &default_str, var);
464            for msg in constraint_errors {
465                errors.push(ZigError::Validation(format!(
466                    "variable '{name}' default value violates constraint: {msg}"
467                )));
468            }
469        }
470    }
471
472    if prompt_bound_count > 1 {
473        errors.push(ZigError::Validation(
474            "multiple variables have from = \"prompt\" (only one is allowed)".into(),
475        ));
476    }
477}
478
479/// Convert a TOML value to its string representation for constraint checking.
480fn toml_value_to_string(val: &toml::Value) -> String {
481    match val {
482        toml::Value::String(s) => s.clone(),
483        toml::Value::Integer(n) => n.to_string(),
484        toml::Value::Float(f) => f.to_string(),
485        toml::Value::Boolean(b) => b.to_string(),
486        other => other.to_string(),
487    }
488}
489
490/// Check a single value against a variable's constraints.
491/// Returns a list of human-readable violation messages (empty if valid).
492fn check_value_constraints(name: &str, value: &str, var: &Variable) -> Vec<String> {
493    let mut violations = Vec::new();
494
495    if var.required && value.is_empty() {
496        violations.push(format!(
497            "variable '{name}' is required but was not provided"
498        ));
499    }
500
501    // Skip further checks for empty non-required values
502    if value.is_empty() && !var.required {
503        return violations;
504    }
505
506    if let Some(min_len) = var.min_length {
507        let len = value.len() as u32;
508        if len < min_len {
509            violations.push(format!(
510                "variable '{name}' must be at least {min_len} characters (got {len})"
511            ));
512        }
513    }
514
515    if let Some(max_len) = var.max_length {
516        let len = value.len() as u32;
517        if len > max_len {
518            violations.push(format!(
519                "variable '{name}' must be at most {max_len} characters (got {len})"
520            ));
521        }
522    }
523
524    if let Some(min) = var.min {
525        if let Ok(num) = value.parse::<f64>() {
526            if num < min {
527                violations.push(format!(
528                    "variable '{name}' must be at least {min} (got {num})"
529                ));
530            }
531        }
532    }
533
534    if let Some(max) = var.max {
535        if let Ok(num) = value.parse::<f64>() {
536            if num > max {
537                violations.push(format!(
538                    "variable '{name}' must be at most {max} (got {num})"
539                ));
540            }
541        }
542    }
543
544    if let Some(ref pattern) = var.pattern {
545        if let Ok(re) = Regex::new(pattern) {
546            if !re.is_match(value) {
547                violations.push(format!("variable '{name}' must match pattern '{pattern}'"));
548            }
549        }
550    }
551
552    if let Some(ref allowed) = var.allowed_values {
553        let allowed_strs: Vec<String> = allowed.iter().map(toml_value_to_string).collect();
554        if !allowed_strs.iter().any(|a| a == value) {
555            violations.push(format!(
556                "variable '{name}' must be one of: {}",
557                allowed_strs.join(", ")
558            ));
559        }
560    }
561
562    violations
563}
564
565/// Validate variable values against their declared constraints at runtime.
566///
567/// Called after `init_vars` and prompt binding, before step execution begins.
568pub fn validate_var_values(
569    vars: &HashMap<String, String>,
570    declarations: &HashMap<String, Variable>,
571) -> Result<(), Vec<ZigError>> {
572    let mut errors = Vec::new();
573
574    for (name, decl) in declarations {
575        let value = vars.get(name).map(|s| s.as_str()).unwrap_or("");
576        let violations = check_value_constraints(name, value, decl);
577        for msg in violations {
578            errors.push(ZigError::Validation(msg));
579        }
580    }
581
582    if errors.is_empty() {
583        Ok(())
584    } else {
585        Err(errors)
586    }
587}
588
/// Collect the root name of every `${...}` placeholder in `template`.
///
/// Dotted placeholders such as `${quality.score}` yield only the part before
/// the first `.` (`quality`). Names are returned in order of appearance,
/// duplicates included. An unterminated `${` stops the scan.
fn extract_var_refs(template: &str) -> Vec<String> {
    let mut found = Vec::new();
    let mut cursor = 0;

    while let Some(offset) = template[cursor..].find("${") {
        let body_start = cursor + offset + 2;
        match template[body_start..].find('}') {
            Some(body_len) => {
                let body = &template[body_start..body_start + body_len];
                found.push(body.split('.').next().unwrap_or(body).to_owned());
                cursor = body_start + body_len + 1;
            }
            None => break,
        }
    }

    found
}
607
/// Extract the variable (or step) names referenced by a condition expression.
///
/// This is a lightweight tokenizer, not a parser:
/// - quoted literals (`"..."` or `'...'`) are skipped entirely, so words
///   inside them, including ones separated by spaces, are never reported;
/// - whitespace, parentheses and the operator characters `= ! < > & |`
///   separate tokens, so `score>=5` and `!approved` yield `score` and
///   `approved` even without surrounding spaces;
/// - numeric literals and the keywords `true`/`false` are dropped;
/// - dotted paths keep only their root (`score.value` → `score`).
///
/// Names are returned in order of appearance, duplicates included.
fn extract_condition_vars(condition: &str) -> Vec<String> {
    /// Move a completed token into `out` if it looks like a name, then reset it.
    fn flush(token: &mut String, out: &mut Vec<String>) {
        const KEYWORDS: [&str; 2] = ["true", "false"];
        let text = token.as_str();
        if !text.is_empty() && !KEYWORDS.contains(&text) && text.parse::<f64>().is_err() {
            out.push(text.split('.').next().unwrap_or(text).to_string());
        }
        token.clear();
    }

    let mut names = Vec::new();
    let mut token = String::new();
    let mut chars = condition.chars();

    while let Some(c) = chars.next() {
        match c {
            '"' | '\'' => {
                flush(&mut token, &mut names);
                // Skip the literal up to its matching closing quote (or the
                // end of input if it is unterminated).
                for inner in chars.by_ref() {
                    if inner == c {
                        break;
                    }
                }
            }
            c if c.is_whitespace() || "()=!<>&|".contains(c) => flush(&mut token, &mut names),
            c => token.push(c),
        }
    }
    flush(&mut token, &mut names);

    names
}
632
633/// Detect cycles in the step dependency graph using DFS.
634/// Returns the cycle path if found, or None.
635fn detect_cycle(steps: &[crate::workflow::model::Step]) -> Option<Vec<String>> {
636    let adjacency: HashMap<&str, Vec<&str>> = steps
637        .iter()
638        .map(|s| {
639            (
640                s.name.as_str(),
641                s.depends_on.iter().map(|d| d.as_str()).collect(),
642            )
643        })
644        .collect();
645
646    let mut visited = HashSet::new();
647    let mut in_stack = HashSet::new();
648    let mut path = Vec::new();
649
650    for step in steps {
651        if !visited.contains(step.name.as_str())
652            && dfs_cycle(
653                step.name.as_str(),
654                &adjacency,
655                &mut visited,
656                &mut in_stack,
657                &mut path,
658            )
659        {
660            return Some(path);
661        }
662    }
663    None
664}
665
/// Depth-first walk from `node` that stops at the first back edge.
///
/// `visited` accumulates every node ever entered. `in_stack` holds the nodes
/// on the current recursion path, and `path` mirrors that path in order.
/// When an edge leads to a node still in `in_stack`, that node is appended to
/// `path` and `true` is returned, with `path` and `in_stack` left as they
/// are. Otherwise `node` is removed from both before returning `false`.
fn dfs_cycle<'a>(
    node: &'a str,
    adjacency: &HashMap<&'a str, Vec<&'a str>>,
    visited: &mut HashSet<&'a str>,
    in_stack: &mut HashSet<&'a str>,
    path: &mut Vec<String>,
) -> bool {
    visited.insert(node);
    in_stack.insert(node);
    path.push(node.to_string());

    for &next in adjacency.get(node).into_iter().flatten() {
        // Every node in `in_stack` is also in `visited`, so checking the
        // stack first is equivalent to checking it only for visited nodes.
        if in_stack.contains(next) {
            path.push(next.to_string());
            return true;
        }
        if !visited.contains(next) && dfs_cycle(next, adjacency, visited, in_stack, path) {
            return true;
        }
    }

    path.pop();
    in_stack.remove(node);
    false
}
694
695#[cfg(test)]
696#[path = "validate_tests.rs"]
697mod tests;