Skip to main content

zig_core/workflow/
validate.rs

1use std::collections::{HashMap, HashSet};
2
3use regex::Regex;
4
5use crate::error::ZigError;
6use crate::workflow::model::{FailurePolicy, StepCommand, VarType, Variable, Workflow};
7
8/// Validate a parsed workflow for structural correctness.
9///
10/// Checks:
11/// - At least one step exists
12/// - Step names are unique
13/// - `depends_on` references exist
14/// - No dependency cycles
15/// - `next` references exist
16/// - Variable references in prompts refer to declared variables
17/// - `saves` variable names are declared
18/// - Condition variable references are declared
19pub fn validate(workflow: &Workflow) -> Result<(), Vec<ZigError>> {
20    let mut errors = Vec::new();
21
22    if workflow.steps.is_empty() {
23        errors.push(ZigError::Validation(
24            "workflow must have at least one step".into(),
25        ));
26        return Err(errors);
27    }
28
29    let step_names: HashSet<&str> = workflow.steps.iter().map(|s| s.name.as_str()).collect();
30    let var_names: HashSet<&str> = workflow.vars.keys().map(|k| k.as_str()).collect();
31    let role_names: HashSet<&str> = workflow.roles.keys().map(|k| k.as_str()).collect();
32
33    // Check unique step names
34    let mut seen_names = HashSet::new();
35    for step in &workflow.steps {
36        if !seen_names.insert(&step.name) {
37            errors.push(ZigError::Validation(format!(
38                "duplicate step name: '{}'",
39                step.name
40            )));
41        }
42    }
43
44    for step in &workflow.steps {
45        // Check depends_on references
46        for dep in &step.depends_on {
47            if !step_names.contains(dep.as_str()) {
48                errors.push(ZigError::Validation(format!(
49                    "step '{}' depends on unknown step '{dep}'",
50                    step.name
51                )));
52            }
53            if dep == &step.name {
54                errors.push(ZigError::Validation(format!(
55                    "step '{}' depends on itself",
56                    step.name
57                )));
58            }
59        }
60
61        // Check next references
62        if let Some(next) = &step.next {
63            if !step_names.contains(next.as_str()) {
64                errors.push(ZigError::Validation(format!(
65                    "step '{}' references unknown next step '{next}'",
66                    step.name
67                )));
68            }
69        }
70
71        // Check variable references in prompt
72        for var_ref in extract_var_refs(&step.prompt) {
73            if !var_names.contains(var_ref.as_str()) {
74                errors.push(ZigError::Validation(format!(
75                    "step '{}' prompt references unknown variable '${{{var_ref}}}'",
76                    step.name
77                )));
78            }
79        }
80
81        // Check variable references in system_prompt
82        if let Some(system_prompt) = &step.system_prompt {
83            for var_ref in extract_var_refs(system_prompt) {
84                if !var_names.contains(var_ref.as_str()) {
85                    errors.push(ZigError::Validation(format!(
86                        "step '{}' system_prompt references unknown variable '${{{var_ref}}}'",
87                        step.name
88                    )));
89                }
90            }
91        }
92
93        // Check role and system_prompt are mutually exclusive
94        if step.role.is_some() && step.system_prompt.is_some() {
95            errors.push(ZigError::Validation(format!(
96                "step '{}' sets both 'role' and 'system_prompt' (they are mutually exclusive)",
97                step.name
98            )));
99        }
100
101        // Check role references
102        if let Some(role_ref) = &step.role {
103            let var_refs = extract_var_refs(role_ref);
104            if var_refs.is_empty() {
105                // Static role reference — must exist in [roles]
106                if !role_names.contains(role_ref.as_str()) {
107                    errors.push(ZigError::Validation(format!(
108                        "step '{}' role references unknown role '{role_ref}'",
109                        step.name
110                    )));
111                }
112            } else {
113                // Dynamic role reference — validate variable refs
114                for var_ref in var_refs {
115                    if !var_names.contains(var_ref.as_str()) {
116                        errors.push(ZigError::Validation(format!(
117                            "step '{}' role references unknown variable '${{{var_ref}}}'",
118                            step.name
119                        )));
120                    }
121                }
122            }
123        }
124
125        // Check saves reference declared variables
126        for var_name in step.saves.keys() {
127            if !var_names.contains(var_name.as_str()) {
128                errors.push(ZigError::Validation(format!(
129                    "step '{}' saves to unknown variable '{var_name}'",
130                    step.name
131                )));
132            }
133        }
134
135        // Check condition variable references
136        if let Some(cond) = &step.condition {
137            for var_ref in extract_condition_vars(cond) {
138                if !var_names.contains(var_ref.as_str()) && !step_names.contains(var_ref.as_str()) {
139                    errors.push(ZigError::Validation(format!(
140                        "step '{}' condition references unknown variable '{var_ref}'",
141                        step.name
142                    )));
143                }
144            }
145        }
146
147        // Check retry_model requires on_failure = "retry"
148        if step.retry_model.is_some() && step.on_failure.as_ref() != Some(&FailurePolicy::Retry) {
149            errors.push(ZigError::Validation(format!(
150                "step '{}' sets retry_model but on_failure is not 'retry'",
151                step.name
152            )));
153        }
154
155        // Check mcp_config requires claude provider (or no provider specified).
156        // The effective provider considers both step-level and workflow-level defaults.
157        if step.mcp_config.is_some() {
158            let effective_provider = step
159                .provider
160                .as_ref()
161                .or(workflow.workflow.provider.as_ref());
162            if let Some(provider) = effective_provider {
163                if provider != "claude" {
164                    errors.push(ZigError::Validation(format!(
165                        "step '{}' sets mcp_config but provider is '{}' \
166                         (mcp_config is only supported with the claude provider)",
167                        step.name, provider
168                    )));
169                }
170            }
171        }
172
173        // Check output format is a valid value
174        if let Some(ref output) = step.output {
175            let valid_formats = ["text", "json", "json-pretty", "stream-json", "native-json"];
176            if !valid_formats.contains(&output.as_str()) {
177                errors.push(ZigError::Validation(format!(
178                    "step '{}' has invalid output format '{}' \
179                     (must be one of: text, json, json-pretty, stream-json, native-json)",
180                    step.name, output
181                )));
182            }
183        }
184
185        // Check review-only fields require command = "review"
186        let is_review = step.command.as_ref() == Some(&StepCommand::Review);
187        if !is_review {
188            for (field, set) in [("uncommitted", step.uncommitted)] {
189                if set {
190                    errors.push(ZigError::Validation(format!(
191                        "step '{}' sets '{}' but command is not 'review'",
192                        step.name, field
193                    )));
194                }
195            }
196            for (field, set) in [
197                ("base", step.base.is_some()),
198                ("commit", step.commit.is_some()),
199                ("title", step.title.is_some()),
200            ] {
201                if set {
202                    errors.push(ZigError::Validation(format!(
203                        "step '{}' sets '{}' but command is not 'review'",
204                        step.name, field
205                    )));
206                }
207            }
208        }
209
210        // Check plan-only fields require command = "plan"
211        let is_plan = step.command.as_ref() == Some(&StepCommand::Plan);
212        if !is_plan {
213            for (field, set) in [
214                ("plan_output", step.plan_output.is_some()),
215                ("instructions", step.instructions.is_some()),
216            ] {
217                if set {
218                    errors.push(ZigError::Validation(format!(
219                        "step '{}' sets '{}' but command is not 'plan'",
220                        step.name, field
221                    )));
222                }
223            }
224        }
225
226        // Check pipe/collect/summary require depends_on
227        if let Some(ref cmd) = step.command {
228            match cmd {
229                StepCommand::Pipe | StepCommand::Collect | StepCommand::Summary => {
230                    if step.depends_on.is_empty() {
231                        errors.push(ZigError::Validation(format!(
232                            "step '{}' uses command '{}' but has no depends_on \
233                             (pipe/collect/summary operate on prior session outputs)",
234                            step.name,
235                            match cmd {
236                                StepCommand::Pipe => "pipe",
237                                StepCommand::Collect => "collect",
238                                StepCommand::Summary => "summary",
239                                _ => unreachable!(),
240                            }
241                        )));
242                    }
243                }
244                _ => {}
245            }
246        }
247    }
248
249    // Check role definitions
250    for (role_name, role) in &workflow.roles {
251        // system_prompt and system_prompt_file are mutually exclusive
252        if role.system_prompt.is_some() && role.system_prompt_file.is_some() {
253            errors.push(ZigError::Validation(format!(
254                "role '{role_name}' sets both 'system_prompt' and 'system_prompt_file' \
255                 (they are mutually exclusive)"
256            )));
257        }
258
259        // Validate ${var} references in role system_prompt
260        if let Some(ref sp) = role.system_prompt {
261            for var_ref in extract_var_refs(sp) {
262                if !var_names.contains(var_ref.as_str()) {
263                    errors.push(ZigError::Validation(format!(
264                        "role '{role_name}' system_prompt references unknown variable \
265                         '${{{var_ref}}}'"
266                    )));
267                }
268            }
269        }
270    }
271
272    // Check race_group: steps in the same group must not depend on each other
273    let mut race_groups: HashMap<&str, Vec<&str>> = HashMap::new();
274    for step in &workflow.steps {
275        if let Some(ref group) = step.race_group {
276            race_groups
277                .entry(group.as_str())
278                .or_default()
279                .push(step.name.as_str());
280        }
281    }
282    for (group, members) in &race_groups {
283        let member_set: HashSet<&str> = members.iter().copied().collect();
284        for step in &workflow.steps {
285            if step.race_group.as_deref() == Some(*group) {
286                for dep in &step.depends_on {
287                    if member_set.contains(dep.as_str()) {
288                        errors.push(ZigError::Validation(format!(
289                            "step '{}' depends on '{}' but both are in race_group '{}' \
290                             (race members must not depend on each other)",
291                            step.name, dep, group
292                        )));
293                    }
294                }
295            }
296        }
297    }
298
299    // Check variable constraints
300    validate_var_constraints(&workflow.vars, &mut errors);
301
302    // Check for dependency cycles
303    if let Some(cycle) = detect_cycle(&workflow.steps) {
304        errors.push(ZigError::Validation(format!(
305            "dependency cycle detected: {}",
306            cycle.join(" -> ")
307        )));
308    }
309
310    if errors.is_empty() {
311        Ok(())
312    } else {
313        Err(errors)
314    }
315}
316
317/// Validate variable constraint declarations for structural correctness.
318fn validate_var_constraints(vars: &HashMap<String, Variable>, errors: &mut Vec<ZigError>) {
319    let mut prompt_bound_count = 0;
320
321    for (name, var) in vars {
322        // default and default_file are mutually exclusive
323        if var.default.is_some() && var.default_file.is_some() {
324            errors.push(ZigError::Validation(format!(
325                "variable '{name}' sets both 'default' and 'default_file' \
326                 (they are mutually exclusive)"
327            )));
328        }
329
330        // Validate `from` field
331        if let Some(ref from) = var.from {
332            if from != "prompt" {
333                errors.push(ZigError::Validation(format!(
334                    "variable '{name}' has unsupported from value '{from}' (only 'prompt' is supported)"
335                )));
336            } else {
337                prompt_bound_count += 1;
338            }
339        }
340
341        // String-only constraints on non-string types
342        if var.var_type != VarType::String {
343            if var.min_length.is_some() {
344                errors.push(ZigError::Validation(format!(
345                    "variable '{name}' has min_length but type is '{}' (only valid for 'string')",
346                    var.var_type
347                )));
348            }
349            if var.max_length.is_some() {
350                errors.push(ZigError::Validation(format!(
351                    "variable '{name}' has max_length but type is '{}' (only valid for 'string')",
352                    var.var_type
353                )));
354            }
355            if var.pattern.is_some() {
356                errors.push(ZigError::Validation(format!(
357                    "variable '{name}' has pattern but type is '{}' (only valid for 'string')",
358                    var.var_type
359                )));
360            }
361        }
362
363        // Number-only constraints on non-number types
364        if var.var_type != VarType::Number {
365            if var.min.is_some() {
366                errors.push(ZigError::Validation(format!(
367                    "variable '{name}' has min but type is '{}' (only valid for 'number')",
368                    var.var_type
369                )));
370            }
371            if var.max.is_some() {
372                errors.push(ZigError::Validation(format!(
373                    "variable '{name}' has max but type is '{}' (only valid for 'number')",
374                    var.var_type
375                )));
376            }
377        }
378
379        // Range consistency
380        if let (Some(min_len), Some(max_len)) = (var.min_length, var.max_length) {
381            if min_len > max_len {
382                errors.push(ZigError::Validation(format!(
383                    "variable '{name}' has min_length ({min_len}) greater than max_length ({max_len})"
384                )));
385            }
386        }
387        if let (Some(min), Some(max)) = (var.min, var.max) {
388            if min > max {
389                errors.push(ZigError::Validation(format!(
390                    "variable '{name}' has min ({min}) greater than max ({max})"
391                )));
392            }
393        }
394
395        // Validate pattern compiles
396        if let Some(ref pattern) = var.pattern {
397            if Regex::new(pattern).is_err() {
398                errors.push(ZigError::Validation(format!(
399                    "variable '{name}' has invalid regex pattern: '{pattern}'"
400                )));
401            }
402        }
403
404        // Validate allowed_values type compatibility
405        if let Some(ref allowed) = var.allowed_values {
406            for val in allowed {
407                let ok = match var.var_type {
408                    VarType::String => val.is_str(),
409                    VarType::Number => val.is_integer() || val.is_float(),
410                    VarType::Bool => matches!(val, toml::Value::Boolean(_)),
411                    VarType::Json => true,
412                };
413                if !ok {
414                    errors.push(ZigError::Validation(format!(
415                        "variable '{name}' has allowed_values entry {val} incompatible with type '{}'",
416                        var.var_type
417                    )));
418                }
419            }
420        }
421
422        // Validate default satisfies constraints
423        if let Some(ref default) = var.default {
424            let default_str = toml_value_to_string(default);
425            let constraint_errors = check_value_constraints(name, &default_str, var);
426            for msg in constraint_errors {
427                errors.push(ZigError::Validation(format!(
428                    "variable '{name}' default value violates constraint: {msg}"
429                )));
430            }
431        }
432    }
433
434    if prompt_bound_count > 1 {
435        errors.push(ZigError::Validation(
436            "multiple variables have from = \"prompt\" (only one is allowed)".into(),
437        ));
438    }
439}
440
441/// Convert a TOML value to its string representation for constraint checking.
442fn toml_value_to_string(val: &toml::Value) -> String {
443    match val {
444        toml::Value::String(s) => s.clone(),
445        toml::Value::Integer(n) => n.to_string(),
446        toml::Value::Float(f) => f.to_string(),
447        toml::Value::Boolean(b) => b.to_string(),
448        other => other.to_string(),
449    }
450}
451
452/// Check a single value against a variable's constraints.
453/// Returns a list of human-readable violation messages (empty if valid).
454fn check_value_constraints(name: &str, value: &str, var: &Variable) -> Vec<String> {
455    let mut violations = Vec::new();
456
457    if var.required && value.is_empty() {
458        violations.push(format!(
459            "variable '{name}' is required but was not provided"
460        ));
461    }
462
463    // Skip further checks for empty non-required values
464    if value.is_empty() && !var.required {
465        return violations;
466    }
467
468    if let Some(min_len) = var.min_length {
469        let len = value.len() as u32;
470        if len < min_len {
471            violations.push(format!(
472                "variable '{name}' must be at least {min_len} characters (got {len})"
473            ));
474        }
475    }
476
477    if let Some(max_len) = var.max_length {
478        let len = value.len() as u32;
479        if len > max_len {
480            violations.push(format!(
481                "variable '{name}' must be at most {max_len} characters (got {len})"
482            ));
483        }
484    }
485
486    if let Some(min) = var.min {
487        if let Ok(num) = value.parse::<f64>() {
488            if num < min {
489                violations.push(format!(
490                    "variable '{name}' must be at least {min} (got {num})"
491                ));
492            }
493        }
494    }
495
496    if let Some(max) = var.max {
497        if let Ok(num) = value.parse::<f64>() {
498            if num > max {
499                violations.push(format!(
500                    "variable '{name}' must be at most {max} (got {num})"
501                ));
502            }
503        }
504    }
505
506    if let Some(ref pattern) = var.pattern {
507        if let Ok(re) = Regex::new(pattern) {
508            if !re.is_match(value) {
509                violations.push(format!("variable '{name}' must match pattern '{pattern}'"));
510            }
511        }
512    }
513
514    if let Some(ref allowed) = var.allowed_values {
515        let allowed_strs: Vec<String> = allowed.iter().map(toml_value_to_string).collect();
516        if !allowed_strs.iter().any(|a| a == value) {
517            violations.push(format!(
518                "variable '{name}' must be one of: {}",
519                allowed_strs.join(", ")
520            ));
521        }
522    }
523
524    violations
525}
526
527/// Validate variable values against their declared constraints at runtime.
528///
529/// Called after `init_vars` and prompt binding, before step execution begins.
530pub fn validate_var_values(
531    vars: &HashMap<String, String>,
532    declarations: &HashMap<String, Variable>,
533) -> Result<(), Vec<ZigError>> {
534    let mut errors = Vec::new();
535
536    for (name, decl) in declarations {
537        let value = vars.get(name).map(|s| s.as_str()).unwrap_or("");
538        let violations = check_value_constraints(name, value, decl);
539        for msg in violations {
540            errors.push(ZigError::Validation(msg));
541        }
542    }
543
544    if errors.is_empty() {
545        Ok(())
546    } else {
547        Err(errors)
548    }
549}
550
/// Extract `${var_name}` references from a prompt template.
///
/// Returns the root name of each reference in order of appearance, with
/// repeats kept. For dotted paths such as `${quality.score}` only `quality`
/// is returned. The text between `${` and the next `}` is taken verbatim (no
/// trimming), and an unterminated `${` ends the scan.
fn extract_var_refs(template: &str) -> Vec<String> {
    let mut found = Vec::new();
    let mut remaining = template;
    while let Some((_, opened)) = remaining.split_once("${") {
        match opened.split_once('}') {
            Some((inner, rest)) => {
                let root = inner.split('.').next().unwrap_or(inner);
                found.push(root.to_owned());
                remaining = rest;
            }
            None => break,
        }
    }
    found
}
569
/// Extract variable names from a condition expression.
///
/// This is a simple heuristic, not a parser:
/// - Quoted string literals (`"..."` or `'...'`) are blanked out first, so
///   words inside them are never mistaken for identifiers, even when the
///   literal contains spaces.
/// - The remainder is split on whitespace, parentheses, and the operator
///   characters `= ! < > & |`, so `!done`, `score>=3`, and `a||b` yield bare
///   identifiers.
/// - Tokens that are `true`/`false` or parse as numbers are dropped.
/// - Dotted paths keep only their root (`score.value` -> `score`).
fn extract_condition_vars(condition: &str) -> Vec<String> {
    const KEYWORDS: [&str; 2] = ["true", "false"];
    // Characters that never belong to an identifier but may be written with
    // no surrounding whitespace.
    const SEPARATORS: [char; 8] = ['(', ')', '=', '!', '<', '>', '&', '|'];

    // Replace every quoted literal (quotes included) with spaces. An
    // unterminated quote blanks out the rest of the condition.
    let mut unquoted = String::with_capacity(condition.len());
    let mut open_quote: Option<char> = None;
    for c in condition.chars() {
        match open_quote {
            Some(q) => {
                if c == q {
                    open_quote = None;
                }
                unquoted.push(' ');
            }
            None if c == '"' || c == '\'' => {
                open_quote = Some(c);
                unquoted.push(' ');
            }
            None => unquoted.push(c),
        }
    }

    unquoted
        .split(|c: char| c.is_whitespace() || SEPARATORS.contains(&c))
        .filter(|token| {
            !token.is_empty() && !KEYWORDS.contains(token) && token.parse::<f64>().is_err()
        })
        .map(|token| token.split('.').next().unwrap_or(token).to_string())
        .collect()
}
594
595/// Detect cycles in the step dependency graph using DFS.
596/// Returns the cycle path if found, or None.
597fn detect_cycle(steps: &[crate::workflow::model::Step]) -> Option<Vec<String>> {
598    let adjacency: HashMap<&str, Vec<&str>> = steps
599        .iter()
600        .map(|s| {
601            (
602                s.name.as_str(),
603                s.depends_on.iter().map(|d| d.as_str()).collect(),
604            )
605        })
606        .collect();
607
608    let mut visited = HashSet::new();
609    let mut in_stack = HashSet::new();
610    let mut path = Vec::new();
611
612    for step in steps {
613        if !visited.contains(step.name.as_str())
614            && dfs_cycle(
615                step.name.as_str(),
616                &adjacency,
617                &mut visited,
618                &mut in_stack,
619                &mut path,
620            )
621        {
622            return Some(path);
623        }
624    }
625    None
626}
627
/// Depth-first search from `node` that stops at the first back edge.
///
/// `visited` accumulates every node ever entered, while `in_stack` and `path`
/// hold the current DFS chain. On reaching a neighbor that is still on the
/// chain, that neighbor is appended to `path` (so `path` ends with the
/// repeated node), the chain is left in place, and `true` is returned.
/// Otherwise `node` is removed from `in_stack` and `path` on the way out, and
/// `false` is returned.
fn dfs_cycle<'a>(
    node: &'a str,
    adjacency: &HashMap<&'a str, Vec<&'a str>>,
    visited: &mut HashSet<&'a str>,
    in_stack: &mut HashSet<&'a str>,
    path: &mut Vec<String>,
) -> bool {
    visited.insert(node);
    in_stack.insert(node);
    path.push(node.to_owned());

    let neighbors = adjacency.get(node).map(Vec::as_slice).unwrap_or(&[]);
    for &next in neighbors {
        // Every node on the chain has also been visited, so test the chain first.
        if in_stack.contains(next) {
            path.push(next.to_owned());
            return true;
        }
        if !visited.contains(next) && dfs_cycle(next, adjacency, visited, in_stack, path) {
            return true;
        }
    }

    in_stack.remove(node);
    path.pop();
    false
}
656
// Unit tests for this module live in the sibling file `validate_tests.rs`.
#[cfg(test)]
#[path = "validate_tests.rs"]
mod tests;