systemprompt_models/ai/
execution_plan.rs1use regex::Regex;
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14#[serde(tag = "type", rename_all = "snake_case")]
15pub enum PlanningResult {
16 DirectResponse {
17 content: String,
18 },
19 ToolCalls {
20 reasoning: String,
21 calls: Vec<PlannedToolCall>,
22 },
23}
24
25impl PlanningResult {
26 pub fn direct_response(content: impl Into<String>) -> Self {
27 Self::DirectResponse {
28 content: content.into(),
29 }
30 }
31
32 pub fn tool_calls(reasoning: impl Into<String>, calls: Vec<PlannedToolCall>) -> Self {
33 Self::ToolCalls {
34 reasoning: reasoning.into(),
35 calls,
36 }
37 }
38
39 pub const fn is_direct(&self) -> bool {
40 matches!(self, Self::DirectResponse { .. })
41 }
42
43 pub const fn is_tool_calls(&self) -> bool {
44 matches!(self, Self::ToolCalls { .. })
45 }
46
47 pub fn tool_count(&self) -> usize {
48 match self {
49 Self::DirectResponse { .. } => 0,
50 Self::ToolCalls { calls, .. } => calls.len(),
51 }
52 }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct PlannedToolCall {
57 pub tool_name: String,
58 pub arguments: Value,
59}
60
61impl PlannedToolCall {
62 pub fn new(tool_name: impl Into<String>, arguments: Value) -> Self {
63 Self {
64 tool_name: tool_name.into(),
65 arguments,
66 }
67 }
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct ToolCallResult {
72 pub tool_name: String,
73 pub arguments: Value,
74 pub success: bool,
75 pub output: Value,
76 pub error: Option<String>,
77 pub duration_ms: u64,
78}
79
80impl ToolCallResult {
81 pub const fn success(
82 tool_name: String,
83 arguments: Value,
84 output: Value,
85 duration_ms: u64,
86 ) -> Self {
87 Self {
88 tool_name,
89 arguments,
90 success: true,
91 output,
92 error: None,
93 duration_ms,
94 }
95 }
96
97 pub fn failure(
98 tool_name: String,
99 arguments: Value,
100 error: impl Into<String>,
101 duration_ms: u64,
102 ) -> Self {
103 Self {
104 tool_name,
105 arguments,
106 success: false,
107 output: Value::Null,
108 error: Some(error.into()),
109 duration_ms,
110 }
111 }
112}
113
114#[derive(Debug, Clone, Default, Serialize, Deserialize)]
115pub struct ExecutionState {
116 pub results: Vec<ToolCallResult>,
117 pub halted: bool,
118 pub halt_reason: Option<String>,
119}
120
121impl ExecutionState {
122 pub fn new() -> Self {
123 Self::default()
124 }
125
126 pub fn add_result(&mut self, result: ToolCallResult) {
127 if !result.success && !self.halted {
128 self.halted = true;
129 self.halt_reason.clone_from(&result.error);
130 }
131 self.results.push(result);
132 }
133
134 pub fn successful_results(&self) -> Vec<&ToolCallResult> {
135 self.results.iter().filter(|r| r.success).collect()
136 }
137
138 pub fn failed_results(&self) -> Vec<&ToolCallResult> {
139 self.results.iter().filter(|r| !r.success).collect()
140 }
141
142 pub fn total_duration_ms(&self) -> u64 {
143 self.results.iter().map(|r| r.duration_ms).sum()
144 }
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct TemplateRef {
149 pub tool_index: usize,
150 pub field_path: Vec<String>,
151}
152
153impl TemplateRef {
154 pub fn parse(template: &str) -> Option<Self> {
155 let re = Regex::new(r"^\$(\d+)\.output\.(.+)$").ok()?;
156 let caps = re.captures(template)?;
157
158 let tool_index = caps.get(1)?.as_str().parse().ok()?;
159 let path = caps.get(2)?.as_str();
160 let field_path = path.split('.').map(String::from).collect();
161
162 Some(Self {
163 tool_index,
164 field_path,
165 })
166 }
167
168 pub fn format(&self) -> String {
169 format!("${}.output.{}", self.tool_index, self.field_path.join("."))
170 }
171}