rust_expect/expect/
before_after.rs

1//! Before/after pattern handlers for expect operations.
2//!
3//! This module provides persistent pattern handlers that are automatically
4//! checked during every expect operation. These are useful for handling
5//! common patterns like error messages or prompts that can appear at any time.
6
7use std::collections::HashMap;
8
9use super::pattern::{Pattern, PatternSet};
10
11/// Handler function type for before/after patterns.
12pub type PatternHandler = Box<dyn Fn(&str) -> HandlerAction + Send + Sync>;
13
/// Action requested by a pattern handler after it runs.
///
/// Besides `Debug`/`Clone`/`Default`, the enum derives `PartialEq`/`Eq` so
/// actions can be compared directly (e.g. with `assert_eq!`) rather than only
/// through `matches!`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum HandlerAction {
    /// Continue with the expect operation (the default).
    #[default]
    Continue,
    /// Stop the expect operation and return success with this match.
    Return(String),
    /// Stop the expect operation and return an error with this message.
    Abort(String),
    /// Send this response and continue.
    Respond(String),
}
27
28/// A persistent pattern with its handler.
29pub struct PersistentPattern {
30    /// The pattern to match.
31    pub pattern: Pattern,
32    /// The handler to execute on match.
33    pub handler: PatternHandler,
34    /// Whether this pattern is currently enabled.
35    pub enabled: bool,
36    /// Priority (lower = higher priority).
37    pub priority: i32,
38}
39
40impl PersistentPattern {
41    /// Create a new persistent pattern.
42    #[must_use]
43    pub fn new(pattern: Pattern, handler: PatternHandler) -> Self {
44        Self {
45            pattern,
46            handler,
47            enabled: true,
48            priority: 0,
49        }
50    }
51
52    /// Create a pattern with a simple response.
53    pub fn with_response(pattern: Pattern, response: impl Into<String>) -> Self {
54        let response = response.into();
55        Self::new(
56            pattern,
57            Box::new(move |_| HandlerAction::Respond(response.clone())),
58        )
59    }
60
61    /// Create a pattern that aborts on match.
62    pub fn with_abort(pattern: Pattern, message: impl Into<String>) -> Self {
63        let message = message.into();
64        Self::new(
65            pattern,
66            Box::new(move |_| HandlerAction::Abort(message.clone())),
67        )
68    }
69
70    /// Set the priority for this pattern.
71    #[must_use]
72    pub const fn with_priority(mut self, priority: i32) -> Self {
73        self.priority = priority;
74        self
75    }
76
77    /// Disable this pattern.
78    pub const fn disable(&mut self) {
79        self.enabled = false;
80    }
81
82    /// Enable this pattern.
83    pub const fn enable(&mut self) {
84        self.enabled = true;
85    }
86}
87
88impl std::fmt::Debug for PersistentPattern {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        f.debug_struct("PersistentPattern")
91            .field("pattern", &self.pattern)
92            .field("enabled", &self.enabled)
93            .field("priority", &self.priority)
94            .finish_non_exhaustive()
95    }
96}
97
98/// Manager for before/after patterns.
99///
100/// Before patterns are checked before every expect operation.
101/// After patterns are checked after each expect operation completes.
102#[derive(Default)]
103pub struct PatternManager {
104    /// Patterns checked before each expect.
105    before_patterns: HashMap<String, PersistentPattern>,
106    /// Patterns checked after each expect.
107    after_patterns: HashMap<String, PersistentPattern>,
108    /// Counter for generating unique IDs.
109    next_id: usize,
110}
111
112impl PatternManager {
113    /// Create a new pattern manager.
114    #[must_use]
115    pub fn new() -> Self {
116        Self::default()
117    }
118
119    /// Add a before pattern and return its ID.
120    pub fn add_before(&mut self, pattern: PersistentPattern) -> String {
121        let id = self.generate_id("before");
122        self.before_patterns.insert(id.clone(), pattern);
123        id
124    }
125
126    /// Add an after pattern and return its ID.
127    pub fn add_after(&mut self, pattern: PersistentPattern) -> String {
128        let id = self.generate_id("after");
129        self.after_patterns.insert(id.clone(), pattern);
130        id
131    }
132
133    /// Remove a before pattern by ID.
134    pub fn remove_before(&mut self, id: &str) -> Option<PersistentPattern> {
135        self.before_patterns.remove(id)
136    }
137
138    /// Remove an after pattern by ID.
139    pub fn remove_after(&mut self, id: &str) -> Option<PersistentPattern> {
140        self.after_patterns.remove(id)
141    }
142
143    /// Get a before pattern by ID.
144    #[must_use]
145    pub fn get_before(&self, id: &str) -> Option<&PersistentPattern> {
146        self.before_patterns.get(id)
147    }
148
149    /// Get a mutable before pattern by ID.
150    pub fn get_before_mut(&mut self, id: &str) -> Option<&mut PersistentPattern> {
151        self.before_patterns.get_mut(id)
152    }
153
154    /// Get an after pattern by ID.
155    #[must_use]
156    pub fn get_after(&self, id: &str) -> Option<&PersistentPattern> {
157        self.after_patterns.get(id)
158    }
159
160    /// Get a mutable after pattern by ID.
161    pub fn get_after_mut(&mut self, id: &str) -> Option<&mut PersistentPattern> {
162        self.after_patterns.get_mut(id)
163    }
164
165    /// Check before patterns against the buffer.
166    ///
167    /// Returns the first matching handler action, or None if no patterns match.
168    #[must_use]
169    pub fn check_before(&self, buffer: &str) -> Option<(String, HandlerAction)> {
170        self.check_patterns(&self.before_patterns, buffer)
171    }
172
173    /// Check after patterns against the buffer.
174    ///
175    /// Returns the first matching handler action, or None if no patterns match.
176    #[must_use]
177    pub fn check_after(&self, buffer: &str) -> Option<(String, HandlerAction)> {
178        self.check_patterns(&self.after_patterns, buffer)
179    }
180
181    /// Get all before patterns as a `PatternSet` for matching.
182    #[must_use]
183    pub fn before_pattern_set(&self) -> PatternSet {
184        self.patterns_to_set(&self.before_patterns)
185    }
186
187    /// Get all after patterns as a `PatternSet` for matching.
188    #[must_use]
189    pub fn after_pattern_set(&self) -> PatternSet {
190        self.patterns_to_set(&self.after_patterns)
191    }
192
193    /// Clear all before patterns.
194    pub fn clear_before(&mut self) {
195        self.before_patterns.clear();
196    }
197
198    /// Clear all after patterns.
199    pub fn clear_after(&mut self) {
200        self.after_patterns.clear();
201    }
202
203    /// Clear all patterns.
204    pub fn clear_all(&mut self) {
205        self.before_patterns.clear();
206        self.after_patterns.clear();
207    }
208
209    /// Get the number of before patterns.
210    #[must_use]
211    pub fn before_count(&self) -> usize {
212        self.before_patterns.len()
213    }
214
215    /// Get the number of after patterns.
216    #[must_use]
217    pub fn after_count(&self) -> usize {
218        self.after_patterns.len()
219    }
220
221    fn generate_id(&mut self, prefix: &str) -> String {
222        let id = format!("{prefix}_{}", self.next_id);
223        self.next_id += 1;
224        id
225    }
226
227    #[allow(clippy::unused_self)]
228    fn check_patterns(
229        &self,
230        patterns: &HashMap<String, PersistentPattern>,
231        buffer: &str,
232    ) -> Option<(String, HandlerAction)> {
233        // Collect enabled patterns sorted by priority
234        let mut sorted: Vec<_> = patterns.iter().filter(|(_, p)| p.enabled).collect();
235        sorted.sort_by_key(|(_, p)| p.priority);
236
237        for (id, persistent) in sorted {
238            if persistent.pattern.matches(buffer).is_some() {
239                let action = (persistent.handler)(buffer);
240                if !matches!(action, HandlerAction::Continue) {
241                    return Some((id.clone(), action));
242                }
243            }
244        }
245        None
246    }
247
248    #[allow(clippy::unused_self)]
249    fn patterns_to_set(&self, patterns: &HashMap<String, PersistentPattern>) -> PatternSet {
250        let mut sorted: Vec<_> = patterns.iter().filter(|(_, p)| p.enabled).collect();
251        sorted.sort_by_key(|(_, p)| p.priority);
252
253        PatternSet::from_patterns(sorted.into_iter().map(|(_, p)| p.pattern.clone()).collect())
254    }
255}
256
257impl std::fmt::Debug for PatternManager {
258    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259        f.debug_struct("PatternManager")
260            .field("before_count", &self.before_patterns.len())
261            .field("after_count", &self.after_patterns.len())
262            .finish()
263    }
264}
265
266/// Builder for common before/after pattern configurations.
267pub struct PatternBuilder {
268    manager: PatternManager,
269}
270
271impl PatternBuilder {
272    /// Create a new pattern builder.
273    #[must_use]
274    pub fn new() -> Self {
275        Self {
276            manager: PatternManager::new(),
277        }
278    }
279
280    /// Add a password prompt handler.
281    #[must_use]
282    pub fn with_password_handler(mut self, password: impl Into<String>) -> Self {
283        let password = password.into();
284        let pattern = PersistentPattern::with_response(
285            Pattern::regex(r"[Pp]assword:?\s*$").unwrap_or_else(|_| Pattern::literal("Password:")),
286            format!("{password}\n"),
287        );
288        self.manager.add_before(pattern);
289        self
290    }
291
292    /// Add a sudo password handler.
293    #[must_use]
294    pub fn with_sudo_handler(mut self, password: impl Into<String>) -> Self {
295        let password = password.into();
296        let pattern = PersistentPattern::with_response(
297            Pattern::regex(r"\[sudo\] password")
298                .unwrap_or_else(|_| Pattern::literal("[sudo] password")),
299            format!("{password}\n"),
300        );
301        self.manager.add_before(pattern);
302        self
303    }
304
305    /// Add an error pattern that aborts.
306    #[must_use]
307    pub fn with_error_pattern(mut self, pattern: Pattern, message: impl Into<String>) -> Self {
308        let persistent = PersistentPattern::with_abort(pattern, message);
309        self.manager.add_before(persistent);
310        self
311    }
312
313    /// Add a yes/no prompt handler that responds with yes.
314    #[must_use]
315    pub fn with_yes_handler(mut self) -> Self {
316        let pattern = PersistentPattern::with_response(
317            Pattern::regex(r"\(yes/no\)\??\s*$").unwrap_or_else(|_| Pattern::literal("(yes/no)")),
318            "yes\n",
319        );
320        self.manager.add_before(pattern);
321        self
322    }
323
324    /// Add a y/n prompt handler that responds with y.
325    #[must_use]
326    pub fn with_yn_handler(mut self) -> Self {
327        let pattern = PersistentPattern::with_response(
328            Pattern::regex(r"\[y/n\]\??\s*$").unwrap_or_else(|_| Pattern::literal("[y/n]")),
329            "y\n",
330        );
331        self.manager.add_before(pattern);
332        self
333    }
334
335    /// Add a continue prompt handler.
336    #[must_use]
337    pub fn with_continue_handler(mut self) -> Self {
338        let pattern = PersistentPattern::with_response(
339            Pattern::regex(r"Press (?:Enter|any key) to continue")
340                .unwrap_or_else(|_| Pattern::literal("Press Enter")),
341            "\n",
342        );
343        self.manager.add_before(pattern);
344        self
345    }
346
347    /// Build the pattern manager.
348    #[must_use]
349    pub fn build(self) -> PatternManager {
350        self.manager
351    }
352}
353
354impl Default for PatternBuilder {
355    fn default() -> Self {
356        Self::new()
357    }
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pattern_manager_before() {
        let mut manager = PatternManager::new();

        let pattern = PersistentPattern::with_response(Pattern::literal("password:"), "secret\n");
        let id = manager.add_before(pattern);

        let Some((matched_id, action)) = manager.check_before("Enter password: ") else {
            panic!("expected the password pattern to match");
        };
        assert_eq!(matched_id, id);
        assert!(matches!(action, HandlerAction::Respond(ref r) if r == "secret\n"));
    }

    #[test]
    fn pattern_manager_priority() {
        let mut manager = PatternManager::new();

        let low = PersistentPattern::new(
            Pattern::literal("test"),
            Box::new(|_| HandlerAction::Respond("low".into())),
        )
        .with_priority(10);

        let high = PersistentPattern::new(
            Pattern::literal("test"),
            Box::new(|_| HandlerAction::Respond("high".into())),
        )
        .with_priority(1);

        manager.add_before(low);
        manager.add_before(high);

        let Some((_, HandlerAction::Respond(s))) = manager.check_before("test") else {
            panic!("Expected Respond action");
        };
        assert_eq!(s, "high");
    }

    #[test]
    fn pattern_manager_disable() {
        let mut manager = PatternManager::new();

        let pattern = PersistentPattern::with_response(Pattern::literal("test"), "response");
        let id = manager.add_before(pattern);

        // Should match when enabled.
        assert!(manager.check_before("test").is_some());

        // Disabled patterns are skipped...
        manager.get_before_mut(&id).unwrap().disable();
        assert!(manager.check_before("test").is_none());

        // ...until they are enabled again.
        manager.get_before_mut(&id).unwrap().enable();
        assert!(manager.check_before("test").is_some());
    }

    #[test]
    fn continue_falls_through_to_next_pattern() {
        let mut manager = PatternManager::new();

        manager.add_before(
            PersistentPattern::new(Pattern::literal("test"), Box::new(|_| HandlerAction::Continue))
                .with_priority(0),
        );
        let abort_id = manager.add_before(
            PersistentPattern::with_abort(Pattern::literal("test"), "boom").with_priority(5),
        );

        let Some((matched_id, HandlerAction::Abort(message))) = manager.check_before("a test")
        else {
            panic!("expected the abort pattern to fire after Continue");
        };
        assert_eq!(matched_id, abort_id);
        assert_eq!(message, "boom");
    }

    #[test]
    fn only_continue_handlers_yield_none() {
        let mut manager = PatternManager::new();
        manager.add_before(PersistentPattern::new(
            Pattern::literal("test"),
            Box::new(|_| HandlerAction::Continue),
        ));

        assert!(manager.check_before("test").is_none());
    }

    #[test]
    fn handler_receives_buffer() {
        let mut manager = PatternManager::new();
        manager.add_before(PersistentPattern::new(
            Pattern::literal("ready"),
            Box::new(|buf: &str| HandlerAction::Return(buf.to_uppercase())),
        ));

        let Some((_, HandlerAction::Return(text))) = manager.check_before("system ready") else {
            panic!("expected a Return action");
        };
        assert_eq!(text, "SYSTEM READY");
    }

    #[test]
    fn after_patterns_are_independent_of_before_patterns() {
        let mut manager = PatternManager::new();

        let before_id =
            manager.add_before(PersistentPattern::with_response(Pattern::literal("start"), "go\n"));
        let after_id =
            manager.add_after(PersistentPattern::with_response(Pattern::literal("done"), "ok\n"));
        assert_ne!(before_id, after_id);
        assert_eq!(manager.before_count(), 1);
        assert_eq!(manager.after_count(), 1);

        // Each check only consults its own set.
        assert!(manager.check_before("done").is_none());
        assert!(manager.check_after("start").is_none());

        let Some((matched_id, _)) = manager.check_after("all done") else {
            panic!("expected the after pattern to match");
        };
        assert_eq!(matched_id, after_id);

        assert!(manager.get_after(&after_id).is_some());
        assert!(manager.get_before(&after_id).is_none());

        assert!(manager.remove_after(&after_id).is_some());
        assert!(manager.remove_after(&after_id).is_none());
        assert!(manager.check_after("all done").is_none());
        assert_eq!(manager.after_count(), 0);

        manager.clear_all();
        assert_eq!(manager.before_count(), 0);
        assert!(manager.get_before(&before_id).is_none());
    }

    #[test]
    fn pattern_builder() {
        let manager = PatternBuilder::new()
            .with_password_handler("secret")
            .with_yes_handler()
            .build();

        assert_eq!(manager.before_count(), 2);
        assert_eq!(manager.after_count(), 0);
    }
}