rust_expect/auto_config/
prompt.rs

1//! Prompt detection and configuration.
2
3use std::sync::LazyLock;
4
5use regex::Regex;
6
7/// Common prompt patterns.
8/// Order matters: more specific patterns must come before generic ones.
9static PROMPT_PATTERNS: LazyLock<Vec<(&'static str, Regex)>> = LazyLock::new(|| {
10    vec![
11        // Most specific patterns first
12        (
13            "python",
14            Regex::new(r">>>\s*$").expect("Python prompt pattern is a valid regex"),
15        ),
16        (
17            "irb",
18            Regex::new(r"irb\([^)]*\):\d+:\d+[>*]\s*$")
19                .expect("IRB prompt pattern is a valid regex"),
20        ),
21        (
22            "powershell",
23            Regex::new(r"PS[^>]*>\s*$").expect("PowerShell prompt pattern is a valid regex"),
24        ),
25        (
26            "mysql",
27            Regex::new(r"mysql>\s*$").expect("MySQL prompt pattern is a valid regex"),
28        ),
29        (
30            "postgres",
31            Regex::new(r"[a-z_]+[=#]\s*$").expect("PostgreSQL prompt pattern is a valid regex"),
32        ),
33        // Root before bash/zsh (# is in both [$#] and [%#$>])
34        (
35            "root",
36            Regex::new(r"^root@[^#]*#\s*$").expect("Root prompt pattern is a valid regex"),
37        ),
38        // General shell patterns
39        (
40            "bash",
41            Regex::new(r"[$#]\s*$").expect("Bash prompt pattern is a valid regex"),
42        ),
43        (
44            "zsh",
45            Regex::new(r"%\s*$").expect("Zsh prompt pattern is a valid regex"),
46        ),
47        (
48            "fish",
49            Regex::new(r"[^>]>\s*$").expect("Fish prompt pattern is a valid regex"),
50        ),
51        (
52            "cmd",
53            Regex::new(r"[^>]>\s*$").expect("CMD prompt pattern is a valid regex"),
54        ),
55        (
56            "node",
57            Regex::new(r"[^>]>\s*$").expect("Node prompt pattern is a valid regex"),
58        ),
59    ]
60});
61
/// Result of a successful prompt detection.
///
/// `PartialEq` is derived so callers and tests can compare results directly.
/// `Eq` is not available because `confidence` is a float.
#[derive(Debug, Clone, PartialEq)]
pub struct PromptInfo {
    /// Detected prompt type (e.g. `"bash"`, `"python"`, or `"custom"`).
    pub prompt_type: String,
    /// The exact text the prompt pattern matched.
    pub matched: String,
    /// Byte offset of the match within the searched text.
    pub position: usize,
    /// Confidence (0.0-1.0).
    pub confidence: f32,
}
74
75/// Detect prompt in text.
76#[must_use]
77pub fn detect_prompt(text: &str) -> Option<PromptInfo> {
78    // Look at last few lines
79    let lines: Vec<&str> = text.lines().collect();
80    let last_lines: String = lines
81        .iter()
82        .rev()
83        .take(3)
84        .collect::<Vec<_>>()
85        .into_iter()
86        .rev()
87        .copied()
88        .collect::<Vec<_>>()
89        .join("\n");
90
91    for (name, pattern) in PROMPT_PATTERNS.iter() {
92        if let Some(m) = pattern.find(&last_lines) {
93            return Some(PromptInfo {
94                prompt_type: (*name).to_string(),
95                matched: m.as_str().to_string(),
96                position: text.len() - (last_lines.len() - m.start()),
97                confidence: 0.8,
98            });
99        }
100    }
101
102    None
103}
104
105/// Check if text ends with a prompt.
106#[must_use]
107pub fn ends_with_prompt(text: &str) -> bool {
108    detect_prompt(text).is_some()
109}
110
/// Prompt configuration.
///
/// Built with `PromptConfig::new()` (or `Default`) and the `with_*` builder
/// methods. When no custom pattern has been compiled, prompt matching falls
/// back to the built-in auto-detection.
#[derive(Debug, Clone)]
pub struct PromptConfig {
    /// Custom prompt pattern, as passed to `with_pattern`.
    ///
    /// NOTE(review): only `with_pattern` compiles this into the private
    /// `regex` used for matching; assigning this public field directly does
    /// not change matching behavior.
    pub pattern: Option<String>,
    /// Compiled form of `pattern`; `None` when no pattern was set or the
    /// pattern failed to compile.
    regex: Option<Regex>,
    /// Wait for prompt after commands.
    pub wait_for_prompt: bool,
    /// Timeout for prompt detection, in milliseconds.
    pub timeout_ms: u64,
}
123
124impl Default for PromptConfig {
125    fn default() -> Self {
126        Self {
127            pattern: None,
128            regex: None,
129            wait_for_prompt: true,
130            timeout_ms: 5000,
131        }
132    }
133}
134
135impl PromptConfig {
136    /// Create new prompt config.
137    #[must_use]
138    pub fn new() -> Self {
139        Self::default()
140    }
141
142    /// Set custom prompt pattern.
143    #[must_use]
144    pub fn with_pattern(mut self, pattern: &str) -> Self {
145        self.pattern = Some(pattern.to_string());
146        self.regex = Regex::new(pattern).ok();
147        self
148    }
149
150    /// Set wait for prompt.
151    #[must_use]
152    pub const fn with_wait(mut self, wait: bool) -> Self {
153        self.wait_for_prompt = wait;
154        self
155    }
156
157    /// Set timeout.
158    #[must_use]
159    pub const fn with_timeout(mut self, timeout_ms: u64) -> Self {
160        self.timeout_ms = timeout_ms;
161        self
162    }
163
164    /// Check if text matches prompt.
165    #[must_use]
166    pub fn matches(&self, text: &str) -> bool {
167        if let Some(ref regex) = self.regex {
168            regex.is_match(text)
169        } else {
170            ends_with_prompt(text)
171        }
172    }
173
174    /// Find prompt in text.
175    #[must_use]
176    pub fn find(&self, text: &str) -> Option<PromptInfo> {
177        if let Some(ref regex) = self.regex {
178            regex.find(text).map(|m| PromptInfo {
179                prompt_type: "custom".to_string(),
180                matched: m.as_str().to_string(),
181                position: m.start(),
182                confidence: 1.0,
183            })
184        } else {
185            detect_prompt(text)
186        }
187    }
188}
189
/// Generate a prompt marker that is unique within this process.
///
/// The marker has the form `__EXPECT_PROMPT_<nanos>_<seq>__`, where
/// `<nanos>` is the current Unix time in nanoseconds (`0` if the system clock
/// reads earlier than the epoch) and `<seq>` is a per-process sequence number.
/// It contains only ASCII letters, digits and underscores, so it can be
/// embedded in a shell prompt unquoted or single-quoted.
#[must_use]
pub fn generate_prompt_marker() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Clock resolution is platform-dependent, so back-to-back calls can see
    // the same timestamp (and a pre-epoch clock always yields 0). The counter
    // guarantees distinct markers regardless. `Relaxed` is sufficient because
    // only the atomicity of `fetch_add` matters, not ordering with other memory.
    static SEQUENCE: AtomicU64 = AtomicU64::new(0);

    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |d| d.as_nanos());
    let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);
    format!("__EXPECT_PROMPT_{timestamp}_{seq}__")
}
200
/// Build a shell command that sets `PS1` to `marker` followed by a space.
///
/// The value is single-quoted. Any `'` inside `marker` is escaped as `'\''`
/// (close quote, escaped quote, reopen quote), so arbitrary markers cannot
/// break out of the quoting. Markers without quotes produce the same command
/// as a plain `PS1='<marker> '`.
#[must_use]
pub fn set_prompt_command(marker: &str) -> String {
    let quoted = marker.replace('\'', r"'\''");
    format!("PS1='{quoted} '")
}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Prompt type reported by `detect_prompt` for `text`, if any prompt is found.
    fn kind_of(text: &str) -> Option<String> {
        detect_prompt(text).map(|info| info.prompt_type)
    }

    #[test]
    fn detect_bash_prompt() {
        assert!(kind_of("user@host:~$ ").is_some());
    }

    #[test]
    fn detect_root_prompt() {
        assert_eq!(kind_of("root@host:/# ").as_deref(), Some("root"));
    }

    #[test]
    fn detect_python_prompt() {
        assert_eq!(kind_of(">>> ").as_deref(), Some("python"));
    }

    #[test]
    fn prompt_config_custom() {
        let cfg = PromptConfig::new().with_pattern(r"myhost>\s*$");
        let (hit, miss) = (cfg.matches("myhost> "), cfg.matches("other> "));
        assert!(hit);
        assert!(!miss);
    }

    #[test]
    fn prompt_marker() {
        let m = generate_prompt_marker();
        assert!(m.starts_with("__EXPECT_PROMPT_") && m.ends_with("__"));
    }
}