1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3
4use super::execution_plan::{PlannedToolCall, TemplateRef};
5use super::tools::McpTool;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct PlanValidationError {
9 pub tool_index: usize,
10 pub argument: String,
11 pub template: String,
12 pub error: ValidationErrorKind,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16#[serde(tag = "type", rename_all = "snake_case")]
17pub enum ValidationErrorKind {
18 InvalidTemplateSyntax,
19 IndexOutOfBounds {
20 referenced_index: usize,
21 max_valid_index: usize,
22 },
23 SelfReference,
24 ForwardReference {
25 referenced_index: usize,
26 },
27 FieldNotFound {
28 tool_name: String,
29 field: String,
30 available_fields: Vec<String>,
31 },
32 NoOutputSchema {
33 tool_name: String,
34 },
35}
36
37impl std::fmt::Display for PlanValidationError {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 match &self.error {
40 ValidationErrorKind::InvalidTemplateSyntax => {
41 write!(
42 f,
43 "Tool {}: Invalid template syntax '{}' for argument '{}'",
44 self.tool_index, self.template, self.argument
45 )
46 },
47 ValidationErrorKind::IndexOutOfBounds {
48 referenced_index,
49 max_valid_index,
50 } => {
51 write!(
52 f,
53 "Tool {}: Template '{}' references tool {} but only tools 0-{} are available",
54 self.tool_index, self.template, referenced_index, max_valid_index
55 )
56 },
57 ValidationErrorKind::SelfReference => {
58 write!(
59 f,
60 "Tool {}: Template '{}' cannot reference itself",
61 self.tool_index, self.template
62 )
63 },
64 ValidationErrorKind::ForwardReference { referenced_index } => {
65 write!(
66 f,
67 "Tool {}: Template '{}' references tool {} which hasn't executed yet",
68 self.tool_index, self.template, referenced_index
69 )
70 },
71 ValidationErrorKind::FieldNotFound {
72 tool_name,
73 field,
74 available_fields,
75 } => {
76 write!(
77 f,
78 "Tool {}: Template '{}' references field '{}' but tool '{}' outputs: [{}]",
79 self.tool_index,
80 self.template,
81 field,
82 tool_name,
83 available_fields.join(", ")
84 )
85 },
86 ValidationErrorKind::NoOutputSchema { tool_name } => {
87 write!(
88 f,
89 "Tool {}: Template '{}' references '{}' which has no output schema",
90 self.tool_index, self.template, tool_name
91 )
92 },
93 }
94 }
95}
96
97impl std::error::Error for PlanValidationError {}
98
/// Stateless namespace for the plan-template validation routines.
///
/// The type carries no data; all functionality is exposed as associated
/// functions, so it is trivially `Copy`.
#[derive(Clone, Copy, Debug)]
pub struct TemplateValidator;
101
102impl TemplateValidator {
103 pub fn get_tool_output_schemas(
104 calls: &[PlannedToolCall],
105 tools: &[McpTool],
106 ) -> Vec<(String, Option<Value>)> {
107 calls
108 .iter()
109 .map(|call| {
110 let output_schema = tools
111 .iter()
112 .find(|t| t.name == call.tool_name)
113 .and_then(|t| t.output_schema.clone());
114 (call.tool_name.clone(), output_schema)
115 })
116 .collect()
117 }
118
119 pub fn find_templates_in_value(value: &Value) -> Vec<String> {
120 let mut templates = Vec::new();
121 Self::collect_templates(value, &mut templates);
122 templates
123 }
124
125 fn collect_templates(value: &Value, templates: &mut Vec<String>) {
126 match value {
127 Value::String(s) if s.starts_with('$') && s.contains(".output.") => {
128 templates.push(s.clone());
129 },
130 Value::Array(arr) => {
131 for v in arr {
132 Self::collect_templates(v, templates);
133 }
134 },
135 Value::Object(obj) => {
136 for v in obj.values() {
137 Self::collect_templates(v, templates);
138 }
139 },
140 Value::Null | Value::Bool(_) | Value::Number(_) | Value::String(_) => {},
141 }
142 }
143
144 pub fn validate_plan(
145 calls: &[PlannedToolCall],
146 tool_output_schemas: &[(String, Option<Value>)],
147 ) -> Result<(), Vec<PlanValidationError>> {
148 let mut errors = Vec::new();
149
150 for (tool_index, call) in calls.iter().enumerate() {
151 for template in Self::find_templates_in_value(&call.arguments) {
152 if let Some(err) =
153 Self::validate_template(tool_index, call, &template, tool_output_schemas)
154 {
155 errors.push(err);
156 }
157 }
158 }
159
160 if errors.is_empty() {
161 Ok(())
162 } else {
163 Err(errors)
164 }
165 }
166
167 fn validate_template(
168 tool_index: usize,
169 call: &PlannedToolCall,
170 template: &str,
171 tool_output_schemas: &[(String, Option<Value>)],
172 ) -> Option<PlanValidationError> {
173 let make_error = |error: ValidationErrorKind| PlanValidationError {
174 tool_index,
175 argument: Self::find_argument_for_template(&call.arguments, template),
176 template: template.to_string(),
177 error,
178 };
179
180 let Some(template_ref) = TemplateRef::parse(template) else {
181 return Some(make_error(ValidationErrorKind::InvalidTemplateSyntax));
182 };
183
184 if template_ref.tool_index == tool_index {
185 return Some(make_error(ValidationErrorKind::SelfReference));
186 }
187 if template_ref.tool_index > tool_index {
188 return Some(make_error(ValidationErrorKind::ForwardReference {
189 referenced_index: template_ref.tool_index,
190 }));
191 }
192 if template_ref.tool_index >= tool_output_schemas.len() {
193 return Some(make_error(ValidationErrorKind::IndexOutOfBounds {
194 referenced_index: template_ref.tool_index,
195 max_valid_index: tool_output_schemas.len().saturating_sub(1),
196 }));
197 }
198
199 let (ref_tool_name, ref_output_schema) = &tool_output_schemas[template_ref.tool_index];
200
201 ref_output_schema.as_ref().map_or_else(
202 || {
203 Some(make_error(ValidationErrorKind::NoOutputSchema {
204 tool_name: ref_tool_name.clone(),
205 }))
206 },
207 |schema| {
208 Self::validate_field_access(&template_ref, schema, ref_tool_name).map(make_error)
209 },
210 )
211 }
212
213 fn validate_field_access(
214 template_ref: &TemplateRef,
215 schema: &Value,
216 tool_name: &str,
217 ) -> Option<ValidationErrorKind> {
218 let first_field = template_ref.field_path.first()?;
219 let available_fields = Self::get_schema_fields(schema);
220
221 if available_fields.contains(first_field) {
222 None
223 } else {
224 Some(ValidationErrorKind::FieldNotFound {
225 tool_name: tool_name.to_string(),
226 field: first_field.clone(),
227 available_fields,
228 })
229 }
230 }
231
232 fn find_argument_for_template(value: &Value, template: &str) -> String {
233 if let Value::Object(obj) = value {
234 for (key, val) in obj {
235 if let Value::String(s) = val {
236 if s == template {
237 return key.clone();
238 }
239 }
240 let nested = Self::find_argument_for_template(val, template);
241 if !nested.is_empty() {
242 return format!("{key}.{nested}");
243 }
244 }
245 }
246 String::new()
247 }
248
249 fn get_schema_fields(schema: &Value) -> Vec<String> {
250 schema
251 .get("properties")
252 .and_then(|p| p.as_object())
253 .map(|obj| obj.keys().cloned().collect())
254 .unwrap_or_default()
255 }
256}