Skip to main content

winx_code_agent/state/
bash_state.rs

1#![allow(clippy::unwrap_used)]
2#![allow(clippy::expect_used)]
3use anyhow::{anyhow, Context as AnyhowContext, Result};
4use glob;
5use lazy_static::lazy_static;
6use rand::RngExt;
7use regex::Regex;
8use serde::{Deserialize, Serialize};
9use sha2::{Digest, Sha256};
10use std::collections::HashMap;
11use std::io::{BufReader, Read, Write};
12use std::path::{Path, PathBuf};
13use std::process::{Child, Command, Stdio};
14use std::sync::{Arc, Mutex};
15use std::time::{Duration, Instant};
16use tracing::{debug, info, warn};
17
18use crate::state::persistence::{
19    delete_bash_state as delete_state_file, load_bash_state as load_state_file,
20    save_bash_state as save_state_file, BashStateSnapshot,
21};
22use crate::state::pty::PtyShell;
23use crate::state::terminal::{
24    incremental_text, TerminalEmulator, TerminalOutputDiff, DEFAULT_MAX_SCREEN_LINES,
25    MAX_OUTPUT_SIZE as TERMINAL_MAX_OUTPUT_SIZE,
26};
27use crate::types::{
28    AllowedCommands, AllowedGlobs, BashCommandMode, BashMode, FileEditMode, Modes, WriteIfEmptyMode,
29};
30
/// Read-coverage record for one file: a hash string for the file, the line
/// ranges that have been read, and the file's total line count.
///
/// Ranges are `(start, end)` pairs that the methods on this type treat as
/// inclusive; `get_unread_ranges` numbers lines from 1.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileWhitelistData {
    /// Hash string recorded for the file (how it is computed is not visible here).
    pub file_hash: String,
    /// Inclusive `(start, end)` line ranges read so far; `add_range` only appends,
    /// so entries may overlap and are not sorted.
    pub line_ranges_read: Vec<(usize, usize)>,
    /// Number of lines in the file.
    pub total_lines: usize,
}
37
38impl FileWhitelistData {
39    pub fn new(
40        file_hash: String,
41        line_ranges_read: Vec<(usize, usize)>,
42        total_lines: usize,
43    ) -> Self {
44        Self { file_hash, line_ranges_read, total_lines }
45    }
46
47    pub fn is_read_enough(&self) -> bool {
48        self.get_percentage_read() >= 99.0
49    }
50
51    pub fn get_percentage_read(&self) -> f64 {
52        if self.total_lines == 0 {
53            return 100.0;
54        }
55        let mut lines_read = std::collections::HashSet::new();
56        for (start, end) in &self.line_ranges_read {
57            for line in *start..=*end {
58                lines_read.insert(line);
59            }
60        }
61        (lines_read.len() as f64 / self.total_lines as f64) * 100.0
62    }
63
64    pub fn get_unread_ranges(&self) -> Vec<(usize, usize)> {
65        if self.total_lines == 0 {
66            return vec![];
67        }
68        let mut lines_read = std::collections::HashSet::new();
69        for (start, end) in &self.line_ranges_read {
70            for line in *start..=*end {
71                lines_read.insert(line);
72            }
73        }
74        let mut unread = vec![];
75        let mut start_range = None;
76        for i in 1..=self.total_lines {
77            if !lines_read.contains(&i) {
78                if start_range.is_none() {
79                    start_range = Some(i);
80                }
81            } else if let Some(start) = start_range {
82                unread.push((start, i - 1));
83                start_range = None;
84            }
85        }
86        if let Some(start) = start_range {
87            unread.push((start, self.total_lines));
88        }
89        unread
90    }
91
92    pub fn add_range(&mut self, start: usize, end: usize) {
93        self.line_ranges_read.push((start, end));
94    }
95
96    pub fn get_read_error_message(&self, file_path: &Path) -> String {
97        format!(
98            "File {} needs more reading. Coverage: {:.1}%",
99            file_path.display(),
100            self.get_percentage_read()
101        )
102    }
103
104    pub fn needs_more_reading(&self) -> bool {
105        !self.is_read_enough()
106    }
107}
108
109#[derive(Debug, Clone)]
110pub struct TerminalState {
111    pub last_command: String,
112    pub last_pending_output: String,
113    pub command_running: bool,
114    pub terminal_emulator: Arc<Mutex<TerminalEmulator>>,
115    pub diff_detector: Option<TerminalOutputDiff>,
116    pub limit_buffer: bool,
117    pub max_buffer_lines: usize,
118}
119
120impl Default for TerminalState {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl TerminalState {
127    pub fn new() -> Self {
128        Self {
129            last_command: String::new(),
130            last_pending_output: String::new(),
131            command_running: false,
132            terminal_emulator: Arc::new(Mutex::new(TerminalEmulator::new(160))),
133            diff_detector: Some(TerminalOutputDiff::new()),
134            limit_buffer: false,
135            max_buffer_lines: DEFAULT_MAX_SCREEN_LINES,
136        }
137    }
138
139    pub fn process_output(&mut self, output: &str) -> String {
140        self.last_pending_output = output.to_string();
141        if let Ok(mut emulator) = self.terminal_emulator.lock() {
142            emulator.process(output);
143            emulator.display().join("\n")
144        } else {
145            output.to_string()
146        }
147    }
148
149    pub fn get_incremental_output(&mut self, output: &str) -> String {
150        let result = incremental_text(output, &self.last_pending_output);
151        self.last_pending_output = output.to_string();
152        result
153    }
154
155    pub fn smart_truncate(&mut self, max_size: usize) {
156        if let Ok(screen) = self.terminal_emulator.lock() {
157            if let Ok(mut screen_guard) = screen.get_screen().lock() {
158                screen_guard.smart_truncate(max_size);
159            }
160        }
161    }
162}
163
/// Regex for the marker the shell prints before every prompt. Group 1 captures the
/// text between `◉ ` and `──➤`, which the prompt commands below fill with the
/// working directory.
const WCGW_PROMPT_PATTERN: &str = r"◉ ([^\n]*)──➤";
/// `PROMPT_COMMAND` handed to the shell through its environment at spawn time.
///
/// The `printf` format is single-quoted, but `$(pwd)` sits in double quotes, so
/// bash substitutes the current directory every time the command runs. Passing
/// the path as a `%s` argument keeps `%` or `\` in directory names from being
/// interpreted as `printf` syntax.
const WCGW_PROMPT_COMMAND: &str = r#"printf '◉ %s──➤ ' "$(pwd)""#;
/// First line written to a new shell: disables pagers and installs the prompt
/// marker.
///
/// The outer single quotes defer `$(pwd)` until each prompt. Because this is a
/// raw string, the inner double quotes must not be backslash-escaped: bash keeps
/// backslashes inside single quotes, which would yield a `printf \"◉ …\"` command
/// whose output never matches `WCGW_PROMPT_PATTERN`.
const BASH_PROMPT_STATEMENT: &str =
    r#"export GIT_PAGER=cat PAGER=cat PROMPT_COMMAND='printf "◉ %s──➤ " "$(pwd)"'"#;
168
169lazy_static! {
170    static ref PROMPT_REGEX: Regex = Regex::new(WCGW_PROMPT_PATTERN).expect("Invalid prompt regex");
171}
172
173fn contains_wcgw_prompt(text: &str) -> bool {
174    PROMPT_REGEX.is_match(text)
175}
176
/// Default for `InteractiveBash::max_output_size` (one million).
const MAX_OUTPUT_SIZE: usize = 1000 * 1000;
/// Upper bound, in seconds, that `read_output` clamps caller timeouts to.
const MAX_COMMAND_TIMEOUT: f32 = 60_f32;
/// Size in bytes of the scratch buffer used for each read of the shell's stdout.
const DEFAULT_BUFFER_SIZE: usize = 8 * 1024;
180
/// Execution state of an `InteractiveBash` session.
#[derive(Debug, Clone, PartialEq)]
pub enum CommandState {
    /// No command is awaiting completion.
    Idle,
    /// Set by `send_command`; `read_output` switches back to `Idle` once it sees
    /// a prompt.
    Running {
        /// Wall-clock time at which the command was written to the shell.
        start_time: std::time::SystemTime,
        /// The command text as sent.
        command: String,
    },
}
186
/// Per-session agent state: directories, permission modes, the overwrite
/// whitelist, terminal bookkeeping and the shell handles.
///
/// `Clone` is derived; the `Arc` fields are shared between clones, so a cloned
/// state refers to the same shell slots.
#[derive(Debug, Clone)]
pub struct BashState {
    /// Current working directory; `init_interactive_bash` starts the shell here.
    pub cwd: PathBuf,
    /// Workspace root directory; persisted by `save_state_to_disk`.
    pub workspace_root: PathBuf,
    /// Identifier under which this state is saved to and loaded from disk.
    pub current_thread_id: String,
    /// Overall agent mode (`Modes::Wcgw` in a fresh state).
    pub mode: Modes,
    /// Bash mode plus the command allowance consulted by `is_command_allowed`.
    pub bash_command_mode: BashCommandMode,
    /// Glob allowance consulted by `is_file_edit_allowed`.
    pub file_edit_mode: FileEditMode,
    /// Glob allowance consulted by `is_file_write_allowed`.
    pub write_if_empty_mode: WriteIfEmptyMode,
    /// Read-coverage records; presumably keyed by file path — confirm with callers.
    pub whitelist_for_overwrite: HashMap<String, FileWhitelistData>,
    /// Terminal output tracking.
    pub terminal_state: TerminalState,
    /// Pipe-driven interactive shell; `None` until `init_interactive_bash` runs.
    pub interactive_bash: Arc<Mutex<Option<InteractiveBash>>>,
    /// PTY-backed shell; `None` in a fresh state and not populated in this file.
    pub pty_shell: Arc<Mutex<Option<PtyShell>>>,
    /// Set to `true` once `load_state_from_disk` has restored a snapshot.
    pub initialized: bool,
}
202
203#[derive(Debug)]
204pub struct InteractiveBash {
205    pub process: Child,
206    pub last_command: String,
207    pub last_output: String,
208    pub output_buffer: String,
209    pub command_state: CommandState,
210    pub max_output_size: usize,
211    pub output_truncated: bool,
212    pub output_chunks: Vec<String>,
213    initial_dir: PathBuf,
214    restricted_mode: bool,
215}
216
217impl InteractiveBash {
218    pub fn is_alive(&mut self) -> bool {
219        matches!(self.process.try_wait(), Ok(None))
220    }
221
222    pub fn reinit(&mut self) -> Result<()> {
223        let mut cmd = Command::new("bash");
224        cmd.arg("-i");
225        if self.restricted_mode {
226            cmd.arg("-r");
227        }
228        let mut process = cmd
229            .env("PAGER", "cat")
230            .env("GIT_PAGER", "cat")
231            .env("PROMPT_COMMAND", WCGW_PROMPT_COMMAND)
232            .env("TERM", "xterm-256color")
233            .current_dir(&self.initial_dir)
234            .stdin(Stdio::piped())
235            .stdout(Stdio::piped())
236            .stderr(Stdio::piped())
237            .spawn()?;
238        let mut stdin = process.stdin.take().ok_or_else(|| anyhow!("No stdin"))?;
239        writeln!(stdin, "{BASH_PROMPT_STATEMENT}")?;
240        stdin.flush()?;
241        process.stdin = Some(stdin);
242        self.process = process;
243        self.command_state = CommandState::Idle;
244        Ok(())
245    }
246
247    pub fn ensure_alive(&mut self) -> Result<()> {
248        if !self.is_alive() {
249            self.reinit()?;
250        }
251        Ok(())
252    }
253
254    pub fn new(initial_dir: &Path, restricted_mode: bool) -> Result<Self> {
255        let mut cmd = Command::new("bash");
256        cmd.arg("-i");
257        if restricted_mode {
258            cmd.arg("-r");
259        }
260        let mut process = cmd
261            .env("PAGER", "cat")
262            .env("GIT_PAGER", "cat")
263            .env("PROMPT_COMMAND", WCGW_PROMPT_COMMAND)
264            .env("TERM", "xterm-256color")
265            .current_dir(initial_dir)
266            .stdin(Stdio::piped())
267            .stdout(Stdio::piped())
268            .stderr(Stdio::piped())
269            .spawn()?;
270        let mut stdin = process.stdin.take().ok_or_else(|| anyhow!("No stdin"))?;
271        writeln!(stdin, "{BASH_PROMPT_STATEMENT}")?;
272        stdin.flush()?;
273        process.stdin = Some(stdin);
274        Ok(Self {
275            process,
276            last_command: String::new(),
277            last_output: String::new(),
278            output_buffer: String::new(),
279            command_state: CommandState::Idle,
280            max_output_size: MAX_OUTPUT_SIZE,
281            output_truncated: false,
282            output_chunks: Vec::new(),
283            initial_dir: initial_dir.to_path_buf(),
284            restricted_mode,
285        })
286    }
287
288    pub fn send_command(&mut self, command: &str) -> Result<()> {
289        self.ensure_alive()?;
290        let mut stdin = self.process.stdin.take().ok_or_else(|| anyhow!("No stdin"))?;
291        writeln!(stdin, "{command}")?;
292        stdin.flush()?;
293        self.process.stdin = Some(stdin);
294        self.last_command = command.to_string();
295        self.command_state = CommandState::Running {
296            start_time: std::time::SystemTime::now(),
297            command: command.to_string(),
298        };
299        Ok(())
300    }
301
302    pub fn read_output(&mut self, timeout_secs: f32) -> Result<(String, bool)> {
303        let timeout = Duration::from_secs_f32(timeout_secs.clamp(0.1, MAX_COMMAND_TIMEOUT));
304        let start = Instant::now();
305        let mut new_output = String::new();
306        let mut complete = false;
307        let mut full_output = self.last_output.clone();
308
309        while start.elapsed() < timeout {
310            let mut buf = vec![0; DEFAULT_BUFFER_SIZE];
311            if let Some(stdout) = self.process.stdout.as_mut() {
312                if let Ok(n) = stdout.read(&mut buf) {
313                    if n > 0 {
314                        let chunk = String::from_utf8_lossy(&buf[..n]);
315                        full_output.push_str(&chunk);
316                        new_output.push_str(&chunk);
317                        if contains_wcgw_prompt(&full_output) {
318                            complete = true;
319                            break;
320                        }
321                    }
322                }
323            }
324            std::thread::sleep(Duration::from_millis(10));
325        }
326
327        if complete {
328            self.command_state = CommandState::Idle;
329        }
330        self.last_output.clone_from(&full_output);
331        Ok((full_output, complete))
332    }
333
334    pub fn send_interrupt(&mut self) -> Result<()> {
335        #[cfg(unix)]
336        {
337            let pid = self.process.id() as i32;
338            unsafe {
339                libc::kill(pid, libc::SIGINT);
340            }
341        }
342        Ok(())
343    }
344}
345
346impl Default for BashState {
347    fn default() -> Self {
348        Self::new()
349    }
350}
351
352impl BashState {
353    pub fn new() -> Self {
354        let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("/tmp"));
355        Self {
356            cwd: cwd.clone(),
357            workspace_root: cwd,
358            current_thread_id: generate_thread_id(),
359            mode: Modes::Wcgw,
360            bash_command_mode: BashCommandMode {
361                bash_mode: BashMode::NormalMode,
362                allowed_commands: AllowedCommands::All("all".to_string()),
363            },
364            file_edit_mode: FileEditMode { allowed_globs: AllowedGlobs::All("all".to_string()) },
365            write_if_empty_mode: WriteIfEmptyMode {
366                allowed_globs: AllowedGlobs::All("all".to_string()),
367            },
368            whitelist_for_overwrite: HashMap::new(),
369            terminal_state: TerminalState::new(),
370            interactive_bash: Arc::new(Mutex::new(None)),
371            pty_shell: Arc::new(Mutex::new(None)),
372            initialized: false,
373        }
374    }
375
376    pub fn init_interactive_bash(&mut self) -> Result<()> {
377        let bash = InteractiveBash::new(
378            &self.cwd,
379            self.bash_command_mode.bash_mode == BashMode::RestrictedMode,
380        )?;
381        *self.interactive_bash.lock().unwrap() = Some(bash);
382        Ok(())
383    }
384
385    pub fn update_cwd(&mut self, path: &Path) -> Result<()> {
386        self.cwd = path.to_path_buf();
387        Ok(())
388    }
389
390    pub fn update_workspace_root(&mut self, path: &Path) -> Result<()> {
391        self.workspace_root = path.to_path_buf();
392        Ok(())
393    }
394
395    pub fn is_command_allowed(&self, command: &str) -> bool {
396        self.bash_command_mode.allowed_commands.is_allowed(command)
397    }
398
399    pub fn is_file_edit_allowed(&self, path: &str) -> bool {
400        self.file_edit_mode.allowed_globs.is_allowed(path)
401    }
402
403    pub fn is_file_write_allowed(&self, path: &str) -> bool {
404        self.write_if_empty_mode.allowed_globs.is_allowed(path)
405    }
406    pub fn get_mode_violation_message(&self, op: &str, _target: &str) -> String {
407        format!("Operation {op} not allowed")
408    }
409
410    pub fn save_state_to_disk(&self) -> Result<()> {
411        let snapshot = BashStateSnapshot::from_state(
412            &self.cwd.to_string_lossy(),
413            &self.workspace_root.to_string_lossy(),
414            &self.mode,
415            &self.bash_command_mode,
416            &self.file_edit_mode,
417            &self.write_if_empty_mode,
418            &self.whitelist_for_overwrite,
419            &self.current_thread_id,
420        );
421        save_state_file(&self.current_thread_id, &snapshot)?;
422        Ok(())
423    }
424
425    pub fn load_state_from_disk(&mut self, thread_id: &str) -> Result<bool> {
426        if let Some(snapshot) = load_state_file(thread_id)? {
427            let (cwd, root, mode, bmode, emode, wmode, whitelist, tid) =
428                snapshot.to_state_components();
429
430            self.cwd = PathBuf::from(cwd);
431
432            self.workspace_root = PathBuf::from(root);
433
434            self.mode = mode;
435
436            self.bash_command_mode = bmode;
437
438            self.file_edit_mode = emode;
439
440            self.write_if_empty_mode = wmode;
441
442            self.whitelist_for_overwrite = whitelist;
443
444            self.current_thread_id = tid;
445
446            self.initialized = true;
447
448            Ok(true)
449        } else {
450            Ok(false)
451        }
452    }
453
454    pub fn new_with_thread_id(thread_id: Option<&str>) -> Self {
455        let mut state = Self::new();
456
457        if let Some(tid) = thread_id {
458            if !tid.is_empty() {
459                if let Ok(true) = state.load_state_from_disk(tid) {
460                    info!("Loaded state for thread_id '{}'", tid);
461                } else {
462                    state.current_thread_id = tid.to_string();
463                }
464            }
465        }
466
467        state
468    }
469}
470
471pub fn generate_thread_id() -> String {
472    let mut rng = rand::rng();
473    format!("tid_{:x}", rng.random::<u64>())
474}