rust_expect/auto_config/
prompt.rs1use std::sync::LazyLock;
4
5use regex::Regex;
6
7static PROMPT_PATTERNS: LazyLock<Vec<(&'static str, Regex)>> = LazyLock::new(|| {
10 vec![
11 (
13 "python",
14 Regex::new(r">>>\s*$").expect("Python prompt pattern is a valid regex"),
15 ),
16 (
17 "irb",
18 Regex::new(r"irb\([^)]*\):\d+:\d+[>*]\s*$")
19 .expect("IRB prompt pattern is a valid regex"),
20 ),
21 (
22 "powershell",
23 Regex::new(r"PS[^>]*>\s*$").expect("PowerShell prompt pattern is a valid regex"),
24 ),
25 (
26 "mysql",
27 Regex::new(r"mysql>\s*$").expect("MySQL prompt pattern is a valid regex"),
28 ),
29 (
30 "postgres",
31 Regex::new(r"[a-z_]+[=#]\s*$").expect("PostgreSQL prompt pattern is a valid regex"),
32 ),
33 (
35 "root",
36 Regex::new(r"^root@[^#]*#\s*$").expect("Root prompt pattern is a valid regex"),
37 ),
38 (
40 "bash",
41 Regex::new(r"[$#]\s*$").expect("Bash prompt pattern is a valid regex"),
42 ),
43 (
44 "zsh",
45 Regex::new(r"%\s*$").expect("Zsh prompt pattern is a valid regex"),
46 ),
47 (
48 "fish",
49 Regex::new(r"[^>]>\s*$").expect("Fish prompt pattern is a valid regex"),
50 ),
51 (
52 "cmd",
53 Regex::new(r"[^>]>\s*$").expect("CMD prompt pattern is a valid regex"),
54 ),
55 (
56 "node",
57 Regex::new(r"[^>]>\s*$").expect("Node prompt pattern is a valid regex"),
58 ),
59 ]
60});
61
/// A prompt located in a chunk of terminal output.
#[derive(Clone, Debug)]
pub struct PromptInfo {
    /// Label of the pattern that matched.
    ///
    /// For built-in detection this is a pattern name such as `"bash"` or
    /// `"python"`; for a user-supplied regex it is `"custom"`.
    pub prompt_type: String,
    /// The exact text the pattern matched.
    pub matched: String,
    /// Byte offset into the searched text at which the producer located the
    /// start of the match.
    pub position: usize,
    /// Heuristic confidence: `0.8` from built-in detection, `1.0` from a custom regex.
    pub confidence: f32,
}
74
75#[must_use]
77pub fn detect_prompt(text: &str) -> Option<PromptInfo> {
78 let lines: Vec<&str> = text.lines().collect();
80 let last_lines: String = lines
81 .iter()
82 .rev()
83 .take(3)
84 .collect::<Vec<_>>()
85 .into_iter()
86 .rev()
87 .copied()
88 .collect::<Vec<_>>()
89 .join("\n");
90
91 for (name, pattern) in PROMPT_PATTERNS.iter() {
92 if let Some(m) = pattern.find(&last_lines) {
93 return Some(PromptInfo {
94 prompt_type: (*name).to_string(),
95 matched: m.as_str().to_string(),
96 position: text.len() - (last_lines.len() - m.start()),
97 confidence: 0.8,
98 });
99 }
100 }
101
102 None
103}
104
105#[must_use]
107pub fn ends_with_prompt(text: &str) -> bool {
108 detect_prompt(text).is_some()
109}
110
111#[derive(Debug, Clone)]
113pub struct PromptConfig {
114 pub pattern: Option<String>,
116 regex: Option<Regex>,
118 pub wait_for_prompt: bool,
120 pub timeout_ms: u64,
122}
123
124impl Default for PromptConfig {
125 fn default() -> Self {
126 Self {
127 pattern: None,
128 regex: None,
129 wait_for_prompt: true,
130 timeout_ms: 5000,
131 }
132 }
133}
134
135impl PromptConfig {
136 #[must_use]
138 pub fn new() -> Self {
139 Self::default()
140 }
141
142 #[must_use]
144 pub fn with_pattern(mut self, pattern: &str) -> Self {
145 self.pattern = Some(pattern.to_string());
146 self.regex = Regex::new(pattern).ok();
147 self
148 }
149
150 #[must_use]
152 pub const fn with_wait(mut self, wait: bool) -> Self {
153 self.wait_for_prompt = wait;
154 self
155 }
156
157 #[must_use]
159 pub const fn with_timeout(mut self, timeout_ms: u64) -> Self {
160 self.timeout_ms = timeout_ms;
161 self
162 }
163
164 #[must_use]
166 pub fn matches(&self, text: &str) -> bool {
167 if let Some(ref regex) = self.regex {
168 regex.is_match(text)
169 } else {
170 ends_with_prompt(text)
171 }
172 }
173
174 #[must_use]
176 pub fn find(&self, text: &str) -> Option<PromptInfo> {
177 if let Some(ref regex) = self.regex {
178 regex.find(text).map(|m| PromptInfo {
179 prompt_type: "custom".to_string(),
180 matched: m.as_str().to_string(),
181 position: m.start(),
182 confidence: 1.0,
183 })
184 } else {
185 detect_prompt(text)
186 }
187 }
188}
189
/// Generates a distinctive marker string to use as a shell prompt (e.g. via
/// [`set_prompt_command`]), so the prompt can be recognised unambiguously in output.
///
/// The marker has the form `__EXPECT_PROMPT_<nanos>_<seq>__`:
///
/// - `<nanos>` is the current time in nanoseconds since the Unix epoch, or `0`
///   if the system clock reports a time before the epoch;
/// - `<seq>` is a per-process counter.
///
/// The counter makes successive calls within one process produce distinct
/// markers even when the clock's resolution is too coarse to tell them apart.
/// The marker contains only ASCII letters, digits and underscores.
#[must_use]
pub fn generate_prompt_marker() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static SEQUENCE: AtomicU64 = AtomicU64::new(0);

    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |d| d.as_nanos());
    // Only uniqueness of the returned value matters, and `fetch_add` is atomic
    // under any ordering, so `Relaxed` is sufficient.
    let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);
    format!("__EXPECT_PROMPT_{timestamp}_{seq}__")
}
200
/// Builds a POSIX-shell assignment that sets `PS1` to `marker` followed by a
/// single space, e.g. `PS1='__EXPECT_PROMPT_1__ '`.
///
/// The value is wrapped in single quotes so the shell performs no expansion
/// while parsing the assignment. Any single quote inside `marker` is written as
/// `'\''` (close quote, escaped quote, reopen quote). This prevents it from
/// ending the quoted string early and injecting further shell syntax.
///
/// Note: the shell may still apply its own prompt decoding (for example, bash
/// backslash escapes) when it displays `PS1`. Markers from
/// [`generate_prompt_marker`] contain only `[A-Za-z0-9_]` and are unaffected.
#[must_use]
pub fn set_prompt_command(marker: &str) -> String {
    let escaped = marker.replace('\'', r"'\''");
    format!("PS1='{escaped} '")
}
206
/// Unit tests for prompt detection, custom prompt configuration and marker
/// generation.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn detect_bash_prompt() {
        assert!(detect_prompt("user@host:~$ ").is_some());
    }

    #[test]
    fn detect_root_prompt() {
        let info = detect_prompt("root@host:/# ").expect("root prompt should be detected");
        assert_eq!(info.prompt_type, "root");
    }

    #[test]
    fn detect_python_prompt() {
        let info = detect_prompt(">>> ").expect("python prompt should be detected");
        assert_eq!(info.prompt_type, "python");
    }

    #[test]
    fn prompt_config_custom() {
        let config = PromptConfig::new().with_pattern(r"myhost>\s*$");
        assert!(
            config.matches("myhost> "),
            "custom pattern should match its own prompt"
        );
        assert!(
            !config.matches("other> "),
            "custom pattern should not match a foreign prompt"
        );
    }

    #[test]
    fn prompt_marker() {
        let marker = generate_prompt_marker();
        assert!(marker.starts_with("__EXPECT_PROMPT_") && marker.ends_with("__"));
    }
}