Skip to main content

systemprompt_models/ai/
template_validation.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3
4use super::execution_plan::{PlannedToolCall, TemplateRef};
5use super::tools::McpTool;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct PlanValidationError {
9    pub tool_index: usize,
10    pub argument: String,
11    pub template: String,
12    pub error: ValidationErrorKind,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16#[serde(tag = "type", rename_all = "snake_case")]
17pub enum ValidationErrorKind {
18    InvalidTemplateSyntax,
19    IndexOutOfBounds {
20        referenced_index: usize,
21        max_valid_index: usize,
22    },
23    SelfReference,
24    ForwardReference {
25        referenced_index: usize,
26    },
27    FieldNotFound {
28        tool_name: String,
29        field: String,
30        available_fields: Vec<String>,
31    },
32    NoOutputSchema {
33        tool_name: String,
34    },
35}
36
37impl std::fmt::Display for PlanValidationError {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        match &self.error {
40            ValidationErrorKind::InvalidTemplateSyntax => {
41                write!(
42                    f,
43                    "Tool {}: Invalid template syntax '{}' for argument '{}'",
44                    self.tool_index, self.template, self.argument
45                )
46            },
47            ValidationErrorKind::IndexOutOfBounds {
48                referenced_index,
49                max_valid_index,
50            } => {
51                write!(
52                    f,
53                    "Tool {}: Template '{}' references tool {} but only tools 0-{} are available",
54                    self.tool_index, self.template, referenced_index, max_valid_index
55                )
56            },
57            ValidationErrorKind::SelfReference => {
58                write!(
59                    f,
60                    "Tool {}: Template '{}' cannot reference itself",
61                    self.tool_index, self.template
62                )
63            },
64            ValidationErrorKind::ForwardReference { referenced_index } => {
65                write!(
66                    f,
67                    "Tool {}: Template '{}' references tool {} which hasn't executed yet",
68                    self.tool_index, self.template, referenced_index
69                )
70            },
71            ValidationErrorKind::FieldNotFound {
72                tool_name,
73                field,
74                available_fields,
75            } => {
76                write!(
77                    f,
78                    "Tool {}: Template '{}' references field '{}' but tool '{}' outputs: [{}]",
79                    self.tool_index,
80                    self.template,
81                    field,
82                    tool_name,
83                    available_fields.join(", ")
84                )
85            },
86            ValidationErrorKind::NoOutputSchema { tool_name } => {
87                write!(
88                    f,
89                    "Tool {}: Template '{}' references '{}' which has no output schema",
90                    self.tool_index, self.template, tool_name
91                )
92            },
93        }
94    }
95}
96
97impl std::error::Error for PlanValidationError {}
98
/// Stateless entry point for validating template references in a tool plan.
///
/// The type holds no data; everything it offers is an associated function.
#[derive(Debug, Clone, Copy)]
pub struct TemplateValidator;
101
102impl TemplateValidator {
103    pub fn get_tool_output_schemas(
104        calls: &[PlannedToolCall],
105        tools: &[McpTool],
106    ) -> Vec<(String, Option<Value>)> {
107        calls
108            .iter()
109            .map(|call| {
110                let output_schema = tools
111                    .iter()
112                    .find(|t| t.name == call.tool_name)
113                    .and_then(|t| t.output_schema.clone());
114                (call.tool_name.clone(), output_schema)
115            })
116            .collect()
117    }
118
119    pub fn find_templates_in_value(value: &Value) -> Vec<String> {
120        let mut templates = Vec::new();
121        Self::collect_templates(value, &mut templates);
122        templates
123    }
124
125    fn collect_templates(value: &Value, templates: &mut Vec<String>) {
126        match value {
127            Value::String(s) if s.starts_with('$') && s.contains(".output.") => {
128                templates.push(s.clone());
129            },
130            Value::Array(arr) => {
131                for v in arr {
132                    Self::collect_templates(v, templates);
133                }
134            },
135            Value::Object(obj) => {
136                for v in obj.values() {
137                    Self::collect_templates(v, templates);
138                }
139            },
140            Value::Null | Value::Bool(_) | Value::Number(_) | Value::String(_) => {},
141        }
142    }
143
144    pub fn validate_plan(
145        calls: &[PlannedToolCall],
146        tool_output_schemas: &[(String, Option<Value>)],
147    ) -> Result<(), Vec<PlanValidationError>> {
148        let mut errors = Vec::new();
149
150        for (tool_index, call) in calls.iter().enumerate() {
151            for template in Self::find_templates_in_value(&call.arguments) {
152                if let Some(err) =
153                    Self::validate_template(tool_index, call, &template, tool_output_schemas)
154                {
155                    errors.push(err);
156                }
157            }
158        }
159
160        if errors.is_empty() {
161            Ok(())
162        } else {
163            Err(errors)
164        }
165    }
166
167    fn validate_template(
168        tool_index: usize,
169        call: &PlannedToolCall,
170        template: &str,
171        tool_output_schemas: &[(String, Option<Value>)],
172    ) -> Option<PlanValidationError> {
173        let make_error = |error: ValidationErrorKind| PlanValidationError {
174            tool_index,
175            argument: Self::find_argument_for_template(&call.arguments, template),
176            template: template.to_string(),
177            error,
178        };
179
180        let Some(template_ref) = TemplateRef::parse(template) else {
181            return Some(make_error(ValidationErrorKind::InvalidTemplateSyntax));
182        };
183
184        if template_ref.tool_index == tool_index {
185            return Some(make_error(ValidationErrorKind::SelfReference));
186        }
187        if template_ref.tool_index > tool_index {
188            return Some(make_error(ValidationErrorKind::ForwardReference {
189                referenced_index: template_ref.tool_index,
190            }));
191        }
192        if template_ref.tool_index >= tool_output_schemas.len() {
193            return Some(make_error(ValidationErrorKind::IndexOutOfBounds {
194                referenced_index: template_ref.tool_index,
195                max_valid_index: tool_output_schemas.len().saturating_sub(1),
196            }));
197        }
198
199        let (ref_tool_name, ref_output_schema) = &tool_output_schemas[template_ref.tool_index];
200
201        ref_output_schema.as_ref().map_or_else(
202            || {
203                Some(make_error(ValidationErrorKind::NoOutputSchema {
204                    tool_name: ref_tool_name.clone(),
205                }))
206            },
207            |schema| {
208                Self::validate_field_access(&template_ref, schema, ref_tool_name).map(make_error)
209            },
210        )
211    }
212
213    fn validate_field_access(
214        template_ref: &TemplateRef,
215        schema: &Value,
216        tool_name: &str,
217    ) -> Option<ValidationErrorKind> {
218        let first_field = template_ref.field_path.first()?;
219        let available_fields = Self::get_schema_fields(schema);
220
221        if available_fields.contains(first_field) {
222            None
223        } else {
224            Some(ValidationErrorKind::FieldNotFound {
225                tool_name: tool_name.to_string(),
226                field: first_field.clone(),
227                available_fields,
228            })
229        }
230    }
231
232    fn find_argument_for_template(value: &Value, template: &str) -> String {
233        if let Value::Object(obj) = value {
234            for (key, val) in obj {
235                if let Value::String(s) = val {
236                    if s == template {
237                        return key.clone();
238                    }
239                }
240                let nested = Self::find_argument_for_template(val, template);
241                if !nested.is_empty() {
242                    return format!("{key}.{nested}");
243                }
244            }
245        }
246        String::new()
247    }
248
249    fn get_schema_fields(schema: &Value) -> Vec<String> {
250        schema
251            .get("properties")
252            .and_then(|p| p.as_object())
253            .map(|obj| obj.keys().cloned().collect())
254            .unwrap_or_default()
255    }
256}