Skip to main content

winx_code_agent/state/
bash_state.rs

1#![allow(clippy::unwrap_used)]
2use anyhow::Result;
3use rand::RngExt;
4use serde::{Deserialize, Serialize};
5use sha2::{Digest, Sha256};
6use std::collections::HashMap;
7use std::path::{Path, PathBuf};
8use std::sync::{Arc, Mutex as StdMutex};
9use tokio::sync::Mutex;
10use tracing::info;
11
12use crate::state::persistence::{
13    delete_bash_state as delete_state_file, load_bash_state as load_state_file,
14    save_bash_state as save_state_file, BashStateSnapshot,
15};
16use crate::state::pty::PtyShell;
17use crate::state::terminal::{
18    incremental_text, TerminalEmulator, TerminalOutputDiff, DEFAULT_MAX_SCREEN_LINES,
19    MAX_OUTPUT_SIZE as TERMINAL_MAX_OUTPUT_SIZE,
20};
21use crate::types::{
22    AllowedCommands, AllowedGlobs, BashCommandMode, BashMode, FileEditMode, Modes, WriteIfEmptyMode,
23};
24
/// Read-coverage record for a single file, as stored in `BashState::whitelist_for_overwrite`.
///
/// Derives serde's `Serialize`/`Deserialize` so it can be persisted together with the rest of
/// the bash state snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileWhitelistData {
    /// Hash of the file's contents at the time coverage was recorded.
    /// NOTE(review): presumably a hex-encoded SHA-256 (`sha2` is imported in this module) —
    /// confirm against the code that constructs this record.
    pub file_hash: String,
    /// Inclusive `(start, end)` line ranges that have been read. Line numbers appear to be
    /// 1-based: unread-range scans in this type run over `1..=total_lines`.
    pub line_ranges_read: Vec<(usize, usize)>,
    /// Number of lines in the file; a value of `0` is treated as "fully read".
    pub total_lines: usize,
}
31
32impl FileWhitelistData {
33    pub fn new(
34        file_hash: String,
35        line_ranges_read: Vec<(usize, usize)>,
36        total_lines: usize,
37    ) -> Self {
38        Self { file_hash, line_ranges_read, total_lines }
39    }
40
41    pub fn is_read_enough(&self) -> bool {
42        self.get_percentage_read() >= 99.0
43    }
44
45    pub fn get_percentage_read(&self) -> f64 {
46        if self.total_lines == 0 {
47            return 100.0;
48        }
49        let mut lines_read = std::collections::HashSet::new();
50        for (start, end) in &self.line_ranges_read {
51            for line in *start..=*end {
52                lines_read.insert(line);
53            }
54        }
55        (lines_read.len() as f64 / self.total_lines as f64) * 100.0
56    }
57
58    pub fn get_unread_ranges(&self) -> Vec<(usize, usize)> {
59        if self.total_lines == 0 {
60            return vec![];
61        }
62        let mut lines_read = std::collections::HashSet::new();
63        for (start, end) in &self.line_ranges_read {
64            for line in *start..=*end {
65                lines_read.insert(line);
66            }
67        }
68        let mut unread = vec![];
69        let mut start_range = None;
70        for i in 1..=self.total_lines {
71            if !lines_read.contains(&i) {
72                if start_range.is_none() {
73                    start_range = Some(i);
74                }
75            } else if let Some(start) = start_range {
76                unread.push((start, i - 1));
77                start_range = None;
78            }
79        }
80        if let Some(start) = start_range {
81            unread.push((start, self.total_lines));
82        }
83        unread
84    }
85
86    pub fn add_range(&mut self, start: usize, end: usize) {
87        self.line_ranges_read.push((start, end));
88    }
89
90    pub fn get_read_error_message(&self, file_path: &Path) -> String {
91        format!(
92            "File {} needs more reading. Coverage: {:.1}%",
93            file_path.display(),
94            self.get_percentage_read()
95        )
96    }
97
98    pub fn needs_more_reading(&self) -> bool {
99        !self.is_read_enough()
100    }
101}
102
/// Per-session terminal bookkeeping: the last command/output seen plus the shared terminal
/// emulator that raw shell output is fed through.
#[derive(Debug, Clone)]
pub struct TerminalState {
    /// Most recent command text. Not written by the methods in this file; presumably set by
    /// the command-execution code — TODO confirm.
    pub last_command: String,
    /// Last raw output chunk seen; used as the baseline by `get_incremental_output`.
    pub last_pending_output: String,
    /// Whether a command is believed to be running. Not updated in this file.
    pub command_running: bool,
    /// Emulator that `process_output` feeds raw output into, behind a `std::sync::Mutex`.
    /// Because it is an `Arc`, cloning a `TerminalState` shares the same emulator.
    pub terminal_emulator: Arc<StdMutex<TerminalEmulator>>,
    /// Optional output-diff helper; `new` initializes it to `Some`.
    pub diff_detector: Option<TerminalOutputDiff>,
    /// Buffer-limiting flag (defaults to `false`); not consulted in this file.
    pub limit_buffer: bool,
    /// Maximum buffered screen lines (defaults to `DEFAULT_MAX_SCREEN_LINES`); not consulted
    /// in this file.
    pub max_buffer_lines: usize,
}
113
114impl Default for TerminalState {
115    fn default() -> Self {
116        Self::new()
117    }
118}
119
120impl TerminalState {
121    pub fn new() -> Self {
122        Self {
123            last_command: String::new(),
124            last_pending_output: String::new(),
125            command_running: false,
126            terminal_emulator: Arc::new(StdMutex::new(TerminalEmulator::new(160))),
127            diff_detector: Some(TerminalOutputDiff::new()),
128            limit_buffer: false,
129            max_buffer_lines: DEFAULT_MAX_SCREEN_LINES,
130        }
131    }
132
133    pub fn process_output(&mut self, output: &str) -> String {
134        self.last_pending_output = output.to_string();
135        if let Ok(mut emulator) = self.terminal_emulator.lock() {
136            emulator.process(output);
137            emulator.display().join("\n")
138        } else {
139            output.to_string()
140        }
141    }
142
143    pub fn get_incremental_output(&mut self, output: &str) -> String {
144        let result = incremental_text(output, &self.last_pending_output);
145        self.last_pending_output = output.to_string();
146        result
147    }
148
149    pub fn smart_truncate(&mut self, max_size: usize) {
150        if let Ok(screen) = self.terminal_emulator.lock() {
151            if let Ok(mut screen_guard) = screen.get_screen().lock() {
152                screen_guard.smart_truncate(max_size);
153            }
154        }
155    }
156}
157
158#[derive(Debug, Clone)]
159pub struct BashState {
160    pub cwd: PathBuf,
161    pub workspace_root: PathBuf,
162    pub current_thread_id: String,
163    pub mode: Modes,
164    pub bash_command_mode: BashCommandMode,
165    pub file_edit_mode: FileEditMode,
166    pub write_if_empty_mode: WriteIfEmptyMode,
167    pub whitelist_for_overwrite: HashMap<String, FileWhitelistData>,
168    pub terminal_state: TerminalState,
169    pub pty_shell: Arc<Mutex<Option<PtyShell>>>,
170    pub initialized: bool,
171}
172
173impl Default for BashState {
174    fn default() -> Self {
175        Self::new()
176    }
177}
178
179impl BashState {
180    pub fn new() -> Self {
181        let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("/tmp"));
182        Self {
183            cwd: cwd.clone(),
184            workspace_root: cwd,
185            current_thread_id: generate_thread_id(),
186            mode: Modes::Wcgw,
187            bash_command_mode: BashCommandMode {
188                bash_mode: BashMode::NormalMode,
189                allowed_commands: AllowedCommands::All("all".to_string()),
190            },
191            file_edit_mode: FileEditMode { allowed_globs: AllowedGlobs::All("all".to_string()) },
192            write_if_empty_mode: WriteIfEmptyMode {
193                allowed_globs: AllowedGlobs::All("all".to_string()),
194            },
195            whitelist_for_overwrite: HashMap::new(),
196            terminal_state: TerminalState::new(),
197            pty_shell: Arc::new(Mutex::new(None)),
198            initialized: false,
199        }
200    }
201
202    pub async fn init_pty_shell(&mut self) -> Result<()> {
203        let shell =
204            PtyShell::new(&self.cwd, self.bash_command_mode.bash_mode == BashMode::RestrictedMode)?;
205        *self.pty_shell.lock().await = Some(shell);
206        Ok(())
207    }
208
209    pub fn update_cwd(&mut self, path: &Path) -> Result<()> {
210        self.cwd = path.to_path_buf();
211        Ok(())
212    }
213
214    pub fn update_workspace_root(&mut self, path: &Path) -> Result<()> {
215        self.workspace_root = path.to_path_buf();
216        Ok(())
217    }
218
219    pub fn is_command_allowed(&self, command: &str) -> bool {
220        self.bash_command_mode.allowed_commands.is_allowed(command)
221    }
222
223    pub fn is_file_edit_allowed(&self, path: &str) -> bool {
224        self.file_edit_mode.allowed_globs.is_allowed(path)
225    }
226
227    pub fn is_file_write_allowed(&self, path: &str) -> bool {
228        self.write_if_empty_mode.allowed_globs.is_allowed(path)
229    }
230    pub fn get_mode_violation_message(&self, op: &str, _target: &str) -> String {
231        format!("Operation {op} not allowed")
232    }
233
234    pub fn save_state_to_disk(&self) -> Result<()> {
235        let snapshot = BashStateSnapshot::from_state(
236            &self.cwd.to_string_lossy(),
237            &self.workspace_root.to_string_lossy(),
238            &self.mode,
239            &self.bash_command_mode,
240            &self.file_edit_mode,
241            &self.write_if_empty_mode,
242            &self.whitelist_for_overwrite,
243            &self.current_thread_id,
244        );
245        save_state_file(&self.current_thread_id, &snapshot)?;
246        Ok(())
247    }
248
249    pub fn load_state_from_disk(&mut self, thread_id: &str) -> Result<bool> {
250        if let Some(snapshot) = load_state_file(thread_id)? {
251            let (cwd, root, mode, bmode, emode, wmode, whitelist, tid) =
252                snapshot.to_state_components();
253
254            self.cwd = PathBuf::from(cwd);
255
256            self.workspace_root = PathBuf::from(root);
257
258            self.mode = mode;
259
260            self.bash_command_mode = bmode;
261
262            self.file_edit_mode = emode;
263
264            self.write_if_empty_mode = wmode;
265
266            self.whitelist_for_overwrite = whitelist;
267
268            self.current_thread_id = tid;
269
270            self.initialized = true;
271
272            Ok(true)
273        } else {
274            Ok(false)
275        }
276    }
277
278    pub fn new_with_thread_id(thread_id: Option<&str>) -> Self {
279        let mut state = Self::new();
280
281        if let Some(tid) = thread_id {
282            if !tid.is_empty() {
283                if let Ok(true) = state.load_state_from_disk(tid) {
284                    info!("Loaded state for thread_id '{}'", tid);
285                } else {
286                    state.current_thread_id = tid.to_string();
287                }
288            }
289        }
290
291        state
292    }
293}
294
295pub fn generate_thread_id() -> String {
296    let mut rng = rand::rng();
297    format!("tid_{:x}", rng.random::<u64>())
298}