systemprompt_models/ai/
execution_plan.rs1use regex::Regex;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
6#[serde(tag = "type", rename_all = "snake_case")]
7pub enum PlanningResult {
8 DirectResponse {
9 content: String,
10 },
11 ToolCalls {
12 reasoning: String,
13 calls: Vec<PlannedToolCall>,
14 },
15}
16
17impl PlanningResult {
18 pub fn direct_response(content: impl Into<String>) -> Self {
19 Self::DirectResponse {
20 content: content.into(),
21 }
22 }
23
24 pub fn tool_calls(reasoning: impl Into<String>, calls: Vec<PlannedToolCall>) -> Self {
25 Self::ToolCalls {
26 reasoning: reasoning.into(),
27 calls,
28 }
29 }
30
31 pub const fn is_direct(&self) -> bool {
32 matches!(self, Self::DirectResponse { .. })
33 }
34
35 pub const fn is_tool_calls(&self) -> bool {
36 matches!(self, Self::ToolCalls { .. })
37 }
38
39 pub fn tool_count(&self) -> usize {
40 match self {
41 Self::DirectResponse { .. } => 0,
42 Self::ToolCalls { calls, .. } => calls.len(),
43 }
44 }
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct PlannedToolCall {
49 pub tool_name: String,
50 pub arguments: Value,
51}
52
53impl PlannedToolCall {
54 pub fn new(tool_name: impl Into<String>, arguments: Value) -> Self {
55 Self {
56 tool_name: tool_name.into(),
57 arguments,
58 }
59 }
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct ToolCallResult {
64 pub tool_name: String,
65 pub arguments: Value,
66 pub success: bool,
67 pub output: Value,
68 pub error: Option<String>,
69 pub duration_ms: u64,
70}
71
72impl ToolCallResult {
73 pub const fn success(
74 tool_name: String,
75 arguments: Value,
76 output: Value,
77 duration_ms: u64,
78 ) -> Self {
79 Self {
80 tool_name,
81 arguments,
82 success: true,
83 output,
84 error: None,
85 duration_ms,
86 }
87 }
88
89 pub fn failure(
90 tool_name: String,
91 arguments: Value,
92 error: impl Into<String>,
93 duration_ms: u64,
94 ) -> Self {
95 Self {
96 tool_name,
97 arguments,
98 success: false,
99 output: Value::Null,
100 error: Some(error.into()),
101 duration_ms,
102 }
103 }
104}
105
106#[derive(Debug, Clone, Default, Serialize, Deserialize)]
107pub struct ExecutionState {
108 pub results: Vec<ToolCallResult>,
109 pub halted: bool,
110 pub halt_reason: Option<String>,
111}
112
113impl ExecutionState {
114 pub fn new() -> Self {
115 Self::default()
116 }
117
118 pub fn add_result(&mut self, result: ToolCallResult) {
119 if !result.success && !self.halted {
120 self.halted = true;
121 self.halt_reason.clone_from(&result.error);
122 }
123 self.results.push(result);
124 }
125
126 pub fn successful_results(&self) -> Vec<&ToolCallResult> {
127 self.results.iter().filter(|r| r.success).collect()
128 }
129
130 pub fn failed_results(&self) -> Vec<&ToolCallResult> {
131 self.results.iter().filter(|r| !r.success).collect()
132 }
133
134 pub fn total_duration_ms(&self) -> u64 {
135 self.results.iter().map(|r| r.duration_ms).sum()
136 }
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct TemplateRef {
141 pub tool_index: usize,
142 pub field_path: Vec<String>,
143}
144
145impl TemplateRef {
146 pub fn parse(template: &str) -> Option<Self> {
147 let re = Regex::new(r"^\$(\d+)\.output\.(.+)$").ok()?;
148 let caps = re.captures(template)?;
149
150 let tool_index = caps.get(1)?.as_str().parse().ok()?;
151 let path = caps.get(2)?.as_str();
152 let field_path = path.split('.').map(String::from).collect();
153
154 Some(Self {
155 tool_index,
156 field_path,
157 })
158 }
159
160 pub fn format(&self) -> String {
161 format!("${}.output.{}", self.tool_index, self.field_path.join("."))
162 }
163}