systemprompt_models/ai/execution_plan.rs

1use regex::Regex;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
6#[serde(tag = "type", rename_all = "snake_case")]
7pub enum PlanningResult {
8    DirectResponse {
9        content: String,
10    },
11    ToolCalls {
12        reasoning: String,
13        calls: Vec<PlannedToolCall>,
14    },
15}
16
17impl PlanningResult {
18    pub fn direct_response(content: impl Into<String>) -> Self {
19        Self::DirectResponse {
20            content: content.into(),
21        }
22    }
23
24    pub fn tool_calls(reasoning: impl Into<String>, calls: Vec<PlannedToolCall>) -> Self {
25        Self::ToolCalls {
26            reasoning: reasoning.into(),
27            calls,
28        }
29    }
30
31    pub const fn is_direct(&self) -> bool {
32        matches!(self, Self::DirectResponse { .. })
33    }
34
35    pub const fn is_tool_calls(&self) -> bool {
36        matches!(self, Self::ToolCalls { .. })
37    }
38
39    pub fn tool_count(&self) -> usize {
40        match self {
41            Self::DirectResponse { .. } => 0,
42            Self::ToolCalls { calls, .. } => calls.len(),
43        }
44    }
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct PlannedToolCall {
49    pub tool_name: String,
50    pub arguments: Value,
51}
52
53impl PlannedToolCall {
54    pub fn new(tool_name: impl Into<String>, arguments: Value) -> Self {
55        Self {
56            tool_name: tool_name.into(),
57            arguments,
58        }
59    }
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct ToolCallResult {
64    pub tool_name: String,
65    pub arguments: Value,
66    pub success: bool,
67    pub output: Value,
68    pub error: Option<String>,
69    pub duration_ms: u64,
70}
71
72impl ToolCallResult {
73    pub const fn success(
74        tool_name: String,
75        arguments: Value,
76        output: Value,
77        duration_ms: u64,
78    ) -> Self {
79        Self {
80            tool_name,
81            arguments,
82            success: true,
83            output,
84            error: None,
85            duration_ms,
86        }
87    }
88
89    pub fn failure(
90        tool_name: String,
91        arguments: Value,
92        error: impl Into<String>,
93        duration_ms: u64,
94    ) -> Self {
95        Self {
96            tool_name,
97            arguments,
98            success: false,
99            output: Value::Null,
100            error: Some(error.into()),
101            duration_ms,
102        }
103    }
104}
105
106#[derive(Debug, Clone, Default, Serialize, Deserialize)]
107pub struct ExecutionState {
108    pub results: Vec<ToolCallResult>,
109    pub halted: bool,
110    pub halt_reason: Option<String>,
111}
112
113impl ExecutionState {
114    pub fn new() -> Self {
115        Self::default()
116    }
117
118    pub fn add_result(&mut self, result: ToolCallResult) {
119        if !result.success && !self.halted {
120            self.halted = true;
121            self.halt_reason.clone_from(&result.error);
122        }
123        self.results.push(result);
124    }
125
126    pub fn successful_results(&self) -> Vec<&ToolCallResult> {
127        self.results.iter().filter(|r| r.success).collect()
128    }
129
130    pub fn failed_results(&self) -> Vec<&ToolCallResult> {
131        self.results.iter().filter(|r| !r.success).collect()
132    }
133
134    pub fn total_duration_ms(&self) -> u64 {
135        self.results.iter().map(|r| r.duration_ms).sum()
136    }
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct TemplateRef {
141    pub tool_index: usize,
142    pub field_path: Vec<String>,
143}
144
145impl TemplateRef {
146    pub fn parse(template: &str) -> Option<Self> {
147        let re = Regex::new(r"^\$(\d+)\.output\.(.+)$").ok()?;
148        let caps = re.captures(template)?;
149
150        let tool_index = caps.get(1)?.as_str().parse().ok()?;
151        let path = caps.get(2)?.as_str();
152        let field_path = path.split('.').map(String::from).collect();
153
154        Some(Self {
155            tool_index,
156            field_path,
157        })
158    }
159
160    pub fn format(&self) -> String {
161        format!("${}.output.{}", self.tool_index, self.field_path.join("."))
162    }
163}