1use serde::{Deserialize, Serialize};
10use serde_json::Value;
11
12use super::execution_plan::{PlannedToolCall, TemplateRef};
13use super::tools::McpTool;
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct PlanValidationError {
17 pub tool_index: usize,
18 pub argument: String,
19 pub template: String,
20 pub error: ValidationErrorKind,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24#[serde(tag = "type", rename_all = "snake_case")]
25pub enum ValidationErrorKind {
26 InvalidTemplateSyntax,
27 IndexOutOfBounds {
28 referenced_index: usize,
29 max_valid_index: usize,
30 },
31 SelfReference,
32 ForwardReference {
33 referenced_index: usize,
34 },
35 FieldNotFound {
36 tool_name: String,
37 field: String,
38 available_fields: Vec<String>,
39 },
40 NoOutputSchema {
41 tool_name: String,
42 },
43}
44
45impl std::fmt::Display for PlanValidationError {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 match &self.error {
48 ValidationErrorKind::InvalidTemplateSyntax => {
49 write!(
50 f,
51 "Tool {}: Invalid template syntax '{}' for argument '{}'",
52 self.tool_index, self.template, self.argument
53 )
54 },
55 ValidationErrorKind::IndexOutOfBounds {
56 referenced_index,
57 max_valid_index,
58 } => {
59 write!(
60 f,
61 "Tool {}: Template '{}' references tool {} but only tools 0-{} are available",
62 self.tool_index, self.template, referenced_index, max_valid_index
63 )
64 },
65 ValidationErrorKind::SelfReference => {
66 write!(
67 f,
68 "Tool {}: Template '{}' cannot reference itself",
69 self.tool_index, self.template
70 )
71 },
72 ValidationErrorKind::ForwardReference { referenced_index } => {
73 write!(
74 f,
75 "Tool {}: Template '{}' references tool {} which hasn't executed yet",
76 self.tool_index, self.template, referenced_index
77 )
78 },
79 ValidationErrorKind::FieldNotFound {
80 tool_name,
81 field,
82 available_fields,
83 } => {
84 write!(
85 f,
86 "Tool {}: Template '{}' references field '{}' but tool '{}' outputs: [{}]",
87 self.tool_index,
88 self.template,
89 field,
90 tool_name,
91 available_fields.join(", ")
92 )
93 },
94 ValidationErrorKind::NoOutputSchema { tool_name } => {
95 write!(
96 f,
97 "Tool {}: Template '{}' references '{}' which has no output schema",
98 self.tool_index, self.template, tool_name
99 )
100 },
101 }
102 }
103}
104
105impl std::error::Error for PlanValidationError {}
106
/// Stateless holder for the plan-template validation routines.
///
/// It carries no data; all functionality is exposed as associated functions.
#[derive(Clone, Copy, Debug)]
pub struct TemplateValidator;
109
110impl TemplateValidator {
111 pub fn get_tool_output_schemas(
112 calls: &[PlannedToolCall],
113 tools: &[McpTool],
114 ) -> Vec<(String, Option<Value>)> {
115 calls
116 .iter()
117 .map(|call| {
118 let output_schema = tools
119 .iter()
120 .find(|t| t.name == call.tool_name)
121 .and_then(|t| t.output_schema.clone());
122 (call.tool_name.clone(), output_schema)
123 })
124 .collect()
125 }
126
127 pub fn find_templates_in_value(value: &Value) -> Vec<String> {
128 let mut templates = Vec::new();
129 Self::collect_templates(value, &mut templates);
130 templates
131 }
132
133 fn collect_templates(value: &Value, templates: &mut Vec<String>) {
134 match value {
135 Value::String(s) if s.starts_with('$') && s.contains(".output.") => {
136 templates.push(s.clone());
137 },
138 Value::Array(arr) => {
139 for v in arr {
140 Self::collect_templates(v, templates);
141 }
142 },
143 Value::Object(obj) => {
144 for v in obj.values() {
145 Self::collect_templates(v, templates);
146 }
147 },
148 Value::Null | Value::Bool(_) | Value::Number(_) | Value::String(_) => {},
149 }
150 }
151
152 pub fn validate_plan(
153 calls: &[PlannedToolCall],
154 tool_output_schemas: &[(String, Option<Value>)],
155 ) -> Result<(), Vec<PlanValidationError>> {
156 let mut errors = Vec::new();
157
158 for (tool_index, call) in calls.iter().enumerate() {
159 for template in Self::find_templates_in_value(&call.arguments) {
160 if let Some(err) =
161 Self::validate_template(tool_index, call, &template, tool_output_schemas)
162 {
163 errors.push(err);
164 }
165 }
166 }
167
168 if errors.is_empty() {
169 Ok(())
170 } else {
171 Err(errors)
172 }
173 }
174
175 fn validate_template(
176 tool_index: usize,
177 call: &PlannedToolCall,
178 template: &str,
179 tool_output_schemas: &[(String, Option<Value>)],
180 ) -> Option<PlanValidationError> {
181 let make_error = |error: ValidationErrorKind| PlanValidationError {
182 tool_index,
183 argument: Self::find_argument_for_template(&call.arguments, template),
184 template: template.to_owned(),
185 error,
186 };
187
188 let Some(template_ref) = TemplateRef::parse(template) else {
189 return Some(make_error(ValidationErrorKind::InvalidTemplateSyntax));
190 };
191
192 if template_ref.tool_index == tool_index {
193 return Some(make_error(ValidationErrorKind::SelfReference));
194 }
195 if template_ref.tool_index > tool_index {
196 return Some(make_error(ValidationErrorKind::ForwardReference {
197 referenced_index: template_ref.tool_index,
198 }));
199 }
200 if template_ref.tool_index >= tool_output_schemas.len() {
201 return Some(make_error(ValidationErrorKind::IndexOutOfBounds {
202 referenced_index: template_ref.tool_index,
203 max_valid_index: tool_output_schemas.len().saturating_sub(1),
204 }));
205 }
206
207 let (ref_tool_name, ref_output_schema) = &tool_output_schemas[template_ref.tool_index];
208
209 ref_output_schema.as_ref().map_or_else(
210 || {
211 Some(make_error(ValidationErrorKind::NoOutputSchema {
212 tool_name: ref_tool_name.clone(),
213 }))
214 },
215 |schema| {
216 Self::validate_field_access(&template_ref, schema, ref_tool_name).map(make_error)
217 },
218 )
219 }
220
221 fn validate_field_access(
222 template_ref: &TemplateRef,
223 schema: &Value,
224 tool_name: &str,
225 ) -> Option<ValidationErrorKind> {
226 let first_field = template_ref.field_path.first()?;
227 let available_fields = Self::get_schema_fields(schema);
228
229 if available_fields.contains(first_field) {
230 None
231 } else {
232 Some(ValidationErrorKind::FieldNotFound {
233 tool_name: tool_name.to_owned(),
234 field: first_field.clone(),
235 available_fields,
236 })
237 }
238 }
239
240 fn find_argument_for_template(value: &Value, template: &str) -> String {
241 if let Value::Object(obj) = value {
242 for (key, val) in obj {
243 if let Value::String(s) = val {
244 if s == template {
245 return key.clone();
246 }
247 }
248 let nested = Self::find_argument_for_template(val, template);
249 if !nested.is_empty() {
250 return format!("{key}.{nested}");
251 }
252 }
253 }
254 String::new()
255 }
256
257 fn get_schema_fields(schema: &Value) -> Vec<String> {
258 schema
259 .get("properties")
260 .and_then(|p| p.as_object())
261 .map_or_else(Vec::new, |obj| obj.keys().cloned().collect())
262 }
263}