1use crate::agent::{AgentMessage, AgentTool};
4use crate::thinking::ThinkingLevel;
5use crate::types::Model;
6use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8use std::collections::HashSet;
9use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
10
/// Mutable, thread-shareable state of an agent.
///
/// Every field uses interior mutability (`parking_lot::RwLock` or an atomic),
/// so all mutators on this type take `&self`.
#[derive(Debug)]
pub struct AgentState {
    /// System prompt text; replaced by `set_system_prompt`, emptied by `reset`.
    pub system_prompt: RwLock<String>,
    /// Tools available to the agent; replaced wholesale by `set_tools`.
    pub tools: RwLock<Vec<AgentTool>>,
    /// Conversation history, oldest first. `add_message` and `set_max_messages`
    /// drop the oldest entries when it exceeds a non-zero `max_messages`.
    pub messages: RwLock<Vec<AgentMessage>>,
    /// Whether a response is currently streaming; see `is_streaming`/`set_streaming`.
    pub is_streaming: AtomicBool,
    /// Presumably the partially received message while streaming — TODO confirm
    /// with the streaming code; `reset` sets it back to `None`.
    pub stream_message: RwLock<Option<AgentMessage>>,
    /// Identifiers of outstanding tool calls (presumably tool-call IDs — verify
    /// against callers); cleared by `reset`.
    pub pending_tool_calls: RwLock<HashSet<String>>,
    /// Error text, if any; cleared by `reset`.
    pub error: RwLock<Option<String>>,
    /// Maximum number of messages kept in `messages`; `0` means no cap.
    pub max_messages: AtomicUsize,
}
36
37impl AgentState {
38 pub fn new() -> Self {
40 Self {
41 system_prompt: RwLock::new(String::new()),
42 tools: RwLock::new(Vec::new()),
43 messages: RwLock::new(Vec::new()),
44 is_streaming: AtomicBool::new(false),
45 stream_message: RwLock::new(None),
46 pending_tool_calls: RwLock::new(HashSet::new()),
47 error: RwLock::new(None),
48 max_messages: AtomicUsize::new(0), }
50 }
51
52 pub fn set_system_prompt(&self, prompt: impl Into<String>) {
54 *self.system_prompt.write() = prompt.into();
55 }
56
57 pub fn set_tools(&self, tools: Vec<AgentTool>) {
59 *self.tools.write() = tools;
60 }
61
62 pub fn add_message(&self, message: AgentMessage) {
65 let mut msgs = self.messages.write();
66 msgs.push(message);
67 let max = self.max_messages.load(Ordering::SeqCst);
68 if max > 0 && msgs.len() > max {
69 let excess = msgs.len() - max;
70 msgs.drain(..excess);
71 }
72 }
73
74 pub fn set_max_messages(&self, max: usize) {
77 self.max_messages.store(max, Ordering::SeqCst);
78 if max > 0 {
80 let mut msgs = self.messages.write();
81 if msgs.len() > max {
82 let excess = msgs.len() - max;
83 msgs.drain(..excess);
84 }
85 }
86 }
87
88 pub fn get_max_messages(&self) -> usize {
90 self.max_messages.load(Ordering::SeqCst)
91 }
92
93 pub fn replace_messages(&self, messages: Vec<AgentMessage>) {
95 *self.messages.write() = messages;
96 }
97
98 pub fn clear_messages(&self) {
100 self.messages.write().clear();
101 }
102
103 pub fn reset(&self) {
107 *self.system_prompt.write() = String::new();
108 *self.tools.write() = Vec::new();
109 self.messages.write().clear();
110 self.is_streaming.store(false, Ordering::SeqCst);
111 *self.stream_message.write() = None;
112 self.pending_tool_calls.write().clear();
113 *self.error.write() = None;
114 }
115
116 pub fn is_streaming(&self) -> bool {
118 self.is_streaming.load(Ordering::SeqCst)
119 }
120
121 pub fn set_streaming(&self, value: bool) {
123 self.is_streaming.store(value, Ordering::SeqCst);
124 }
125
126 pub fn message_count(&self) -> usize {
128 self.messages.read().len()
129 }
130}
131
132impl Default for AgentState {
133 fn default() -> Self {
134 Self::new()
135 }
136}
137
138impl Clone for AgentState {
142 fn clone(&self) -> Self {
143 Self {
144 system_prompt: RwLock::new(self.system_prompt.read().clone()),
145 tools: RwLock::new(self.tools.read().clone()),
146 messages: RwLock::new(self.messages.read().clone()),
147 is_streaming: AtomicBool::new(self.is_streaming.load(Ordering::SeqCst)),
148 stream_message: RwLock::new(self.stream_message.read().clone()),
149 pending_tool_calls: RwLock::new(self.pending_tool_calls.read().clone()),
150 error: RwLock::new(self.error.read().clone()),
151 max_messages: AtomicUsize::new(self.max_messages.load(Ordering::SeqCst)),
152 }
153 }
154}
155
/// Plain, serializable copy of an agent's state.
///
/// Unlike `AgentState`, the fields here are ordinary owned values (no locks
/// or atomics), so the snapshot can be cloned and serialized/deserialized
/// with serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentStateSnapshot {
    /// System prompt text.
    pub system_prompt: String,
    /// Model in use (not stored on `AgentState`; presumably supplied by the
    /// owning agent when the snapshot is built — TODO confirm).
    pub model: Model,
    /// Thinking level in use (presumably supplied by the owning agent, like
    /// `model` — TODO confirm).
    pub thinking_level: ThinkingLevel,
    /// Conversation history.
    pub messages: Vec<AgentMessage>,
    /// Whether a response was streaming when the snapshot was taken.
    pub is_streaming: bool,
    /// In-flight streamed message, if any.
    pub stream_message: Option<AgentMessage>,
    /// Outstanding tool-call identifiers. NOTE: `HashSet` iteration order is
    /// unspecified, so the serialized order of this set may vary between runs.
    pub pending_tool_calls: HashSet<String>,
    /// Error text, if any.
    pub error: Option<String>,
    /// Number of messages; presumably equal to `messages.len()`, but this type
    /// does not enforce that — verify against the code that builds snapshots.
    pub message_count: usize,
    /// History cap, mirroring `AgentState::max_messages` (`0` means no cap).
    pub max_messages: usize,
}