1#![allow(clippy::unwrap_used)]
2use anyhow::Result;
3use rand::RngExt;
4use serde::{Deserialize, Serialize};
5use sha2::{Digest, Sha256};
6use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8use std::sync::{Arc, Mutex as StdMutex};
9use tokio::sync::Mutex;
10use tracing::info;
11
12use crate::state::persistence::{
13 delete_bash_state as delete_state_file, load_bash_state as load_state_file,
14 save_bash_state as save_state_file, BashStateSnapshot,
15};
16use crate::state::pty::PtyShell;
17use crate::state::terminal::{
18 incremental_text, TerminalEmulator, TerminalOutputDiff, DEFAULT_MAX_SCREEN_LINES,
19 MAX_OUTPUT_SIZE as TERMINAL_MAX_OUTPUT_SIZE,
20};
21use crate::types::{
22 AllowedCommands, AllowedGlobs, BashCommandMode, BashMode, FileEditMode, Modes, WriteIfEmptyMode,
23};
24
/// Per-file read-coverage record, stored in `BashState::whitelist_for_overwrite`.
///
/// It records which lines of a file have been read, so callers can check coverage
/// (see `is_read_enough` / `get_unread_ranges`) before allowing the file to be changed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileWhitelistData {
    /// Hash identifying the file contents that were read.
    /// NOTE(review): presumably a SHA-256 digest (sha2 is imported) — confirm at the call site.
    pub file_hash: String,
    /// Inclusive `(start, end)` line ranges that have been read. Lines are 1-based, matching
    /// how `get_unread_ranges` walks `1..=total_lines`.
    pub line_ranges_read: Vec<(usize, usize)>,
    /// Number of lines in the file; `0` is treated as fully read.
    pub total_lines: usize,
}
31
32impl FileWhitelistData {
33 pub fn new(
34 file_hash: String,
35 line_ranges_read: Vec<(usize, usize)>,
36 total_lines: usize,
37 ) -> Self {
38 Self { file_hash, line_ranges_read, total_lines }
39 }
40
41 pub fn is_read_enough(&self) -> bool {
42 self.get_percentage_read() >= 99.0
43 }
44
45 pub fn get_percentage_read(&self) -> f64 {
46 if self.total_lines == 0 {
47 return 100.0;
48 }
49 let mut lines_read = std::collections::HashSet::new();
50 for (start, end) in &self.line_ranges_read {
51 for line in *start..=*end {
52 lines_read.insert(line);
53 }
54 }
55 (lines_read.len() as f64 / self.total_lines as f64) * 100.0
56 }
57
58 pub fn get_unread_ranges(&self) -> Vec<(usize, usize)> {
59 if self.total_lines == 0 {
60 return vec![];
61 }
62 let mut lines_read = std::collections::HashSet::new();
63 for (start, end) in &self.line_ranges_read {
64 for line in *start..=*end {
65 lines_read.insert(line);
66 }
67 }
68 let mut unread = vec![];
69 let mut start_range = None;
70 for i in 1..=self.total_lines {
71 if !lines_read.contains(&i) {
72 if start_range.is_none() {
73 start_range = Some(i);
74 }
75 } else if let Some(start) = start_range {
76 unread.push((start, i - 1));
77 start_range = None;
78 }
79 }
80 if let Some(start) = start_range {
81 unread.push((start, self.total_lines));
82 }
83 unread
84 }
85
86 pub fn add_range(&mut self, start: usize, end: usize) {
87 self.line_ranges_read.push((start, end));
88 }
89
90 pub fn get_read_error_message(&self, file_path: &Path) -> String {
91 format!(
92 "File {} needs more reading. Coverage: {:.1}%",
93 file_path.display(),
94 self.get_percentage_read()
95 )
96 }
97
98 pub fn needs_more_reading(&self) -> bool {
99 !self.is_read_enough()
100 }
101}
102
/// Terminal-related bookkeeping for a shell session: the emulator that renders
/// output, plus the last command/output seen.
///
/// Cloning shares the same emulator, because `terminal_emulator` is an `Arc`.
#[derive(Debug, Clone)]
pub struct TerminalState {
    /// Most recent command text. Not written by this type's own methods;
    /// presumably maintained by callers — confirm.
    pub last_command: String,
    /// Last raw output chunk passed to `process_output` / `get_incremental_output`.
    /// It is the baseline against which `get_incremental_output` computes new text.
    pub last_pending_output: String,
    /// Whether a command is in progress (starts `false`; presumably toggled by callers).
    pub command_running: bool,
    /// Shared terminal emulator (created with width 160 in `TerminalState::new`).
    pub terminal_emulator: Arc<StdMutex<TerminalEmulator>>,
    /// Optional output-diff helper; not read by this type's own methods.
    pub diff_detector: Option<TerminalOutputDiff>,
    /// Whether buffer limiting is enabled (starts `false`).
    pub limit_buffer: bool,
    /// Line cap used when buffer limiting applies (starts at `DEFAULT_MAX_SCREEN_LINES`).
    pub max_buffer_lines: usize,
}
113
114impl Default for TerminalState {
115 fn default() -> Self {
116 Self::new()
117 }
118}
119
120impl TerminalState {
121 pub fn new() -> Self {
122 Self {
123 last_command: String::new(),
124 last_pending_output: String::new(),
125 command_running: false,
126 terminal_emulator: Arc::new(StdMutex::new(TerminalEmulator::new(160))),
127 diff_detector: Some(TerminalOutputDiff::new()),
128 limit_buffer: false,
129 max_buffer_lines: DEFAULT_MAX_SCREEN_LINES,
130 }
131 }
132
133 pub fn process_output(&mut self, output: &str) -> String {
134 self.last_pending_output = output.to_string();
135 if let Ok(mut emulator) = self.terminal_emulator.lock() {
136 emulator.process(output);
137 emulator.display().join("\n")
138 } else {
139 output.to_string()
140 }
141 }
142
143 pub fn get_incremental_output(&mut self, output: &str) -> String {
144 let result = incremental_text(output, &self.last_pending_output);
145 self.last_pending_output = output.to_string();
146 result
147 }
148
149 pub fn smart_truncate(&mut self, max_size: usize) {
150 if let Ok(screen) = self.terminal_emulator.lock() {
151 if let Ok(mut screen_guard) = screen.get_screen().lock() {
152 screen_guard.smart_truncate(max_size);
153 }
154 }
155 }
156}
157
/// Complete state of a bash tool session: working directories, permission modes,
/// file-read coverage, terminal bookkeeping and the (lazily started) PTY shell.
///
/// Cloning shares the same PTY handle and terminal emulator, because both are held
/// behind `Arc`s.
#[derive(Debug, Clone)]
pub struct BashState {
    /// Current working directory used when spawning the PTY shell.
    pub cwd: PathBuf,
    /// Workspace root directory.
    pub workspace_root: PathBuf,
    /// Identifier under which state is persisted (see `save_state_to_disk`).
    pub current_thread_id: String,
    /// Overall operating mode.
    pub mode: Modes,
    /// Bash mode and the set of allowed commands (checked by `is_command_allowed`).
    pub bash_command_mode: BashCommandMode,
    /// Globs gating file edits (checked by `is_file_edit_allowed`).
    pub file_edit_mode: FileEditMode,
    /// Globs gating file writes (checked by `is_file_write_allowed`).
    pub write_if_empty_mode: WriteIfEmptyMode,
    /// Read-coverage records per file.
    /// NOTE(review): keys presumably are file paths — confirm against callers.
    pub whitelist_for_overwrite: HashMap<String, FileWhitelistData>,
    /// Terminal emulator and output bookkeeping.
    pub terminal_state: TerminalState,
    /// PTY shell; `None` until `init_pty_shell` runs. Uses an async (tokio) mutex.
    pub pty_shell: Arc<Mutex<Option<PtyShell>>>,
    /// Set to `true` after state is restored by `load_state_from_disk`.
    pub initialized: bool,
}
172
173impl Default for BashState {
174 fn default() -> Self {
175 Self::new()
176 }
177}
178
179impl BashState {
180 pub fn new() -> Self {
181 let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("/tmp"));
182 Self {
183 cwd: cwd.clone(),
184 workspace_root: cwd,
185 current_thread_id: generate_thread_id(),
186 mode: Modes::Wcgw,
187 bash_command_mode: BashCommandMode {
188 bash_mode: BashMode::NormalMode,
189 allowed_commands: AllowedCommands::All("all".to_string()),
190 },
191 file_edit_mode: FileEditMode { allowed_globs: AllowedGlobs::All("all".to_string()) },
192 write_if_empty_mode: WriteIfEmptyMode {
193 allowed_globs: AllowedGlobs::All("all".to_string()),
194 },
195 whitelist_for_overwrite: HashMap::new(),
196 terminal_state: TerminalState::new(),
197 pty_shell: Arc::new(Mutex::new(None)),
198 initialized: false,
199 }
200 }
201
202 pub async fn init_pty_shell(&mut self) -> Result<()> {
203 let shell =
204 PtyShell::new(&self.cwd, self.bash_command_mode.bash_mode == BashMode::RestrictedMode)?;
205 *self.pty_shell.lock().await = Some(shell);
206 Ok(())
207 }
208
209 pub fn update_cwd(&mut self, path: &Path) -> Result<()> {
210 self.cwd = path.to_path_buf();
211 Ok(())
212 }
213
214 pub fn update_workspace_root(&mut self, path: &Path) -> Result<()> {
215 self.workspace_root = path.to_path_buf();
216 Ok(())
217 }
218
219 pub fn is_command_allowed(&self, command: &str) -> bool {
220 self.bash_command_mode.allowed_commands.is_allowed(command)
221 }
222
223 pub fn is_file_edit_allowed(&self, path: &str) -> bool {
224 self.file_edit_mode.allowed_globs.is_allowed(path)
225 }
226
227 pub fn is_file_write_allowed(&self, path: &str) -> bool {
228 self.write_if_empty_mode.allowed_globs.is_allowed(path)
229 }
230 pub fn get_mode_violation_message(&self, op: &str, _target: &str) -> String {
231 format!("Operation {op} not allowed")
232 }
233
234 pub fn save_state_to_disk(&self) -> Result<()> {
235 let snapshot = BashStateSnapshot::from_state(
236 &self.cwd.to_string_lossy(),
237 &self.workspace_root.to_string_lossy(),
238 &self.mode,
239 &self.bash_command_mode,
240 &self.file_edit_mode,
241 &self.write_if_empty_mode,
242 &self.whitelist_for_overwrite,
243 &self.current_thread_id,
244 );
245 save_state_file(&self.current_thread_id, &snapshot)?;
246 Ok(())
247 }
248
249 pub fn load_state_from_disk(&mut self, thread_id: &str) -> Result<bool> {
250 if let Some(snapshot) = load_state_file(thread_id)? {
251 let (cwd, root, mode, bmode, emode, wmode, whitelist, tid) =
252 snapshot.to_state_components();
253
254 self.cwd = PathBuf::from(cwd);
255
256 self.workspace_root = PathBuf::from(root);
257
258 self.mode = mode;
259
260 self.bash_command_mode = bmode;
261
262 self.file_edit_mode = emode;
263
264 self.write_if_empty_mode = wmode;
265
266 self.whitelist_for_overwrite = whitelist;
267
268 self.current_thread_id = tid;
269
270 self.initialized = true;
271
272 Ok(true)
273 } else {
274 Ok(false)
275 }
276 }
277
278 pub fn new_with_thread_id(thread_id: Option<&str>) -> Self {
279 let mut state = Self::new();
280
281 if let Some(tid) = thread_id {
282 if !tid.is_empty() {
283 if let Ok(true) = state.load_state_from_disk(tid) {
284 info!("Loaded state for thread_id '{}'", tid);
285 } else {
286 state.current_thread_id = tid.to_string();
287 }
288 }
289 }
290
291 state
292 }
293}
294
295pub fn generate_thread_id() -> String {
296 let mut rng = rand::rng();
297 format!("tid_{:x}", rng.random::<u64>())
298}