Skip to main content

systemprompt_models/ai/
template_validation.rs

1//! Static validation of cross-tool template references in an execution plan.
2//!
3//! [`TemplateValidator`] walks each [`PlannedToolCall`]'s arguments for
4//! `$N.output.field` templates and checks them against the referenced tools'
5//! output schemas, surfacing each problem as a [`PlanValidationError`] tagged
6//! with a [`ValidationErrorKind`] (bad syntax, self/forward reference,
7//! out-of-bounds index, or a field absent from the target schema).
8
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11
12use super::execution_plan::{PlannedToolCall, TemplateRef};
13use super::tools::McpTool;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct PlanValidationError {
17    pub tool_index: usize,
18    pub argument: String,
19    pub template: String,
20    pub error: ValidationErrorKind,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24#[serde(tag = "type", rename_all = "snake_case")]
25pub enum ValidationErrorKind {
26    InvalidTemplateSyntax,
27    IndexOutOfBounds {
28        referenced_index: usize,
29        max_valid_index: usize,
30    },
31    SelfReference,
32    ForwardReference {
33        referenced_index: usize,
34    },
35    FieldNotFound {
36        tool_name: String,
37        field: String,
38        available_fields: Vec<String>,
39    },
40    NoOutputSchema {
41        tool_name: String,
42    },
43}
44
45impl std::fmt::Display for PlanValidationError {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        match &self.error {
48            ValidationErrorKind::InvalidTemplateSyntax => {
49                write!(
50                    f,
51                    "Tool {}: Invalid template syntax '{}' for argument '{}'",
52                    self.tool_index, self.template, self.argument
53                )
54            },
55            ValidationErrorKind::IndexOutOfBounds {
56                referenced_index,
57                max_valid_index,
58            } => {
59                write!(
60                    f,
61                    "Tool {}: Template '{}' references tool {} but only tools 0-{} are available",
62                    self.tool_index, self.template, referenced_index, max_valid_index
63                )
64            },
65            ValidationErrorKind::SelfReference => {
66                write!(
67                    f,
68                    "Tool {}: Template '{}' cannot reference itself",
69                    self.tool_index, self.template
70                )
71            },
72            ValidationErrorKind::ForwardReference { referenced_index } => {
73                write!(
74                    f,
75                    "Tool {}: Template '{}' references tool {} which hasn't executed yet",
76                    self.tool_index, self.template, referenced_index
77                )
78            },
79            ValidationErrorKind::FieldNotFound {
80                tool_name,
81                field,
82                available_fields,
83            } => {
84                write!(
85                    f,
86                    "Tool {}: Template '{}' references field '{}' but tool '{}' outputs: [{}]",
87                    self.tool_index,
88                    self.template,
89                    field,
90                    tool_name,
91                    available_fields.join(", ")
92                )
93            },
94            ValidationErrorKind::NoOutputSchema { tool_name } => {
95                write!(
96                    f,
97                    "Tool {}: Template '{}' references '{}' which has no output schema",
98                    self.tool_index, self.template, tool_name
99                )
100            },
101        }
102    }
103}
104
105impl std::error::Error for PlanValidationError {}
106
/// Stateless namespace for the plan template-validation routines; it carries
/// no data and every operation is an associated function.
#[derive(Clone, Copy, Debug)]
pub struct TemplateValidator;
109
110impl TemplateValidator {
111    pub fn get_tool_output_schemas(
112        calls: &[PlannedToolCall],
113        tools: &[McpTool],
114    ) -> Vec<(String, Option<Value>)> {
115        calls
116            .iter()
117            .map(|call| {
118                let output_schema = tools
119                    .iter()
120                    .find(|t| t.name == call.tool_name)
121                    .and_then(|t| t.output_schema.clone());
122                (call.tool_name.clone(), output_schema)
123            })
124            .collect()
125    }
126
127    pub fn find_templates_in_value(value: &Value) -> Vec<String> {
128        let mut templates = Vec::new();
129        Self::collect_templates(value, &mut templates);
130        templates
131    }
132
133    fn collect_templates(value: &Value, templates: &mut Vec<String>) {
134        match value {
135            Value::String(s) if s.starts_with('$') && s.contains(".output.") => {
136                templates.push(s.clone());
137            },
138            Value::Array(arr) => {
139                for v in arr {
140                    Self::collect_templates(v, templates);
141                }
142            },
143            Value::Object(obj) => {
144                for v in obj.values() {
145                    Self::collect_templates(v, templates);
146                }
147            },
148            Value::Null | Value::Bool(_) | Value::Number(_) | Value::String(_) => {},
149        }
150    }
151
152    pub fn validate_plan(
153        calls: &[PlannedToolCall],
154        tool_output_schemas: &[(String, Option<Value>)],
155    ) -> Result<(), Vec<PlanValidationError>> {
156        let mut errors = Vec::new();
157
158        for (tool_index, call) in calls.iter().enumerate() {
159            for template in Self::find_templates_in_value(&call.arguments) {
160                if let Some(err) =
161                    Self::validate_template(tool_index, call, &template, tool_output_schemas)
162                {
163                    errors.push(err);
164                }
165            }
166        }
167
168        if errors.is_empty() {
169            Ok(())
170        } else {
171            Err(errors)
172        }
173    }
174
175    fn validate_template(
176        tool_index: usize,
177        call: &PlannedToolCall,
178        template: &str,
179        tool_output_schemas: &[(String, Option<Value>)],
180    ) -> Option<PlanValidationError> {
181        let make_error = |error: ValidationErrorKind| PlanValidationError {
182            tool_index,
183            argument: Self::find_argument_for_template(&call.arguments, template),
184            template: template.to_owned(),
185            error,
186        };
187
188        let Some(template_ref) = TemplateRef::parse(template) else {
189            return Some(make_error(ValidationErrorKind::InvalidTemplateSyntax));
190        };
191
192        if template_ref.tool_index == tool_index {
193            return Some(make_error(ValidationErrorKind::SelfReference));
194        }
195        if template_ref.tool_index > tool_index {
196            return Some(make_error(ValidationErrorKind::ForwardReference {
197                referenced_index: template_ref.tool_index,
198            }));
199        }
200        if template_ref.tool_index >= tool_output_schemas.len() {
201            return Some(make_error(ValidationErrorKind::IndexOutOfBounds {
202                referenced_index: template_ref.tool_index,
203                max_valid_index: tool_output_schemas.len().saturating_sub(1),
204            }));
205        }
206
207        let (ref_tool_name, ref_output_schema) = &tool_output_schemas[template_ref.tool_index];
208
209        ref_output_schema.as_ref().map_or_else(
210            || {
211                Some(make_error(ValidationErrorKind::NoOutputSchema {
212                    tool_name: ref_tool_name.clone(),
213                }))
214            },
215            |schema| {
216                Self::validate_field_access(&template_ref, schema, ref_tool_name).map(make_error)
217            },
218        )
219    }
220
221    fn validate_field_access(
222        template_ref: &TemplateRef,
223        schema: &Value,
224        tool_name: &str,
225    ) -> Option<ValidationErrorKind> {
226        let first_field = template_ref.field_path.first()?;
227        let available_fields = Self::get_schema_fields(schema);
228
229        if available_fields.contains(first_field) {
230            None
231        } else {
232            Some(ValidationErrorKind::FieldNotFound {
233                tool_name: tool_name.to_owned(),
234                field: first_field.clone(),
235                available_fields,
236            })
237        }
238    }
239
240    fn find_argument_for_template(value: &Value, template: &str) -> String {
241        if let Value::Object(obj) = value {
242            for (key, val) in obj {
243                if let Value::String(s) = val {
244                    if s == template {
245                        return key.clone();
246                    }
247                }
248                let nested = Self::find_argument_for_template(val, template);
249                if !nested.is_empty() {
250                    return format!("{key}.{nested}");
251                }
252            }
253        }
254        String::new()
255    }
256
257    fn get_schema_fields(schema: &Value) -> Vec<String> {
258        schema
259            .get("properties")
260            .and_then(|p| p.as_object())
261            .map_or_else(Vec::new, |obj| obj.keys().cloned().collect())
262    }
263}