rust_expect/expect/
pattern.rs

1//! Pattern types for expect operations.
2//!
3//! This module defines the pattern types that can be used with expect operations,
4//! including literal strings, regular expressions, globs, EOF, and timeout.
5//!
6//! # Examples
7//!
8//! ```
9//! use rust_expect::Pattern;
10//! use std::time::Duration;
11//!
12//! // Literal pattern - matches exact text
13//! let prompt = Pattern::literal("$ ");
14//! assert!(prompt.matches("user@host:~ $ ").is_some());
15//!
16//! // Regex pattern - matches regular expressions
17//! let version = Pattern::regex(r"\d+\.\d+\.\d+").unwrap();
18//! assert!(version.matches("Version: 1.2.3").is_some());
19//!
20//! // Glob pattern - matches shell-style wildcards
21//! let log = Pattern::glob("Error:*");
22//! assert!(log.matches("Error: connection failed").is_some());
23//!
24//! // Timeout pattern - used with expect_any for timeouts
25//! let timeout = Pattern::timeout(Duration::from_secs(5));
26//! assert!(timeout.is_timeout());
27//!
28//! // EOF pattern - matches process termination
29//! let eof = Pattern::eof();
30//! assert!(eof.is_eof());
31//! ```
32
33use std::fmt;
34use std::time::Duration;
35
36use regex::Regex;
37
38/// A pattern that can be matched against terminal output.
39#[derive(Clone)]
40pub enum Pattern {
41    /// Match an exact string.
42    Literal(String),
43
44    /// Match a regular expression.
45    Regex(CompiledRegex),
46
47    /// Match a glob pattern.
48    Glob(String),
49
50    /// Match end of file (process terminated).
51    Eof,
52
53    /// Match after a timeout.
54    Timeout(Duration),
55
56    /// Match when N bytes have been received.
57    Bytes(usize),
58}
59
60impl Pattern {
61    /// Create a literal pattern.
62    #[must_use]
63    pub fn literal(s: impl Into<String>) -> Self {
64        Self::Literal(s.into())
65    }
66
67    /// Create a regex pattern.
68    ///
69    /// # Errors
70    ///
71    /// Returns an error if the regex pattern is invalid.
72    pub fn regex(pattern: &str) -> Result<Self, regex::Error> {
73        let regex = Regex::new(pattern)?;
74        Ok(Self::Regex(CompiledRegex::new(pattern.to_string(), regex)))
75    }
76
77    /// Create a glob pattern.
78    #[must_use]
79    pub fn glob(pattern: impl Into<String>) -> Self {
80        Self::Glob(pattern.into())
81    }
82
83    /// Create an EOF pattern.
84    #[must_use]
85    pub const fn eof() -> Self {
86        Self::Eof
87    }
88
89    /// Create a timeout pattern.
90    #[must_use]
91    pub const fn timeout(duration: Duration) -> Self {
92        Self::Timeout(duration)
93    }
94
95    /// Create a bytes pattern.
96    #[must_use]
97    pub const fn bytes(n: usize) -> Self {
98        Self::Bytes(n)
99    }
100
101    /// Get the pattern as a string for display purposes.
102    #[must_use]
103    pub fn as_str(&self) -> &str {
104        match self {
105            Self::Literal(s) => s,
106            Self::Regex(r) => r.pattern(),
107            Self::Glob(s) => s,
108            Self::Eof => "<EOF>",
109            Self::Timeout(_) => "<TIMEOUT>",
110            Self::Bytes(_) => "<BYTES>",
111        }
112    }
113
114    /// Check if this pattern matches the given text.
115    ///
116    /// Returns the match position and captures if successful.
117    #[must_use]
118    pub fn matches(&self, text: &str) -> Option<PatternMatch> {
119        match self {
120            Self::Literal(s) => text.find(s).map(|pos| PatternMatch {
121                start: pos,
122                end: pos + s.len(),
123                captures: Vec::new(),
124            }),
125            Self::Regex(r) => r.find(text).map(|m| PatternMatch {
126                start: m.start(),
127                end: m.end(),
128                captures: r.captures(text),
129            }),
130            Self::Glob(pattern) => glob_match(pattern, text).map(|pos| PatternMatch {
131                start: pos,
132                end: text.len(),
133                captures: Vec::new(),
134            }),
135            Self::Eof | Self::Timeout(_) | Self::Bytes(_) => None,
136        }
137    }
138
139    /// Check if this is a timeout pattern.
140    #[must_use]
141    pub const fn is_timeout(&self) -> bool {
142        matches!(self, Self::Timeout(_))
143    }
144
145    /// Check if this is an EOF pattern.
146    #[must_use]
147    pub const fn is_eof(&self) -> bool {
148        matches!(self, Self::Eof)
149    }
150
151    /// Get the timeout duration if this is a timeout pattern.
152    #[must_use]
153    pub const fn timeout_duration(&self) -> Option<Duration> {
154        match self {
155            Self::Timeout(d) => Some(*d),
156            _ => None,
157        }
158    }
159
160    // =========================================================================
161    // Convenience pattern constructors
162    // =========================================================================
163
164    /// Create a pattern that matches common shell prompts.
165    ///
166    /// Matches prompts ending with `$`, `#`, `>`, or `%` followed by optional whitespace.
167    /// This handles most Unix shells (bash, zsh, sh) and root prompts.
168    ///
169    /// # Examples
170    ///
171    /// ```
172    /// use rust_expect::Pattern;
173    ///
174    /// let prompt = Pattern::shell_prompt();
175    /// assert!(prompt.matches("user@host:~$ ").is_some());
176    /// assert!(prompt.matches("root@host:~# ").is_some());
177    /// assert!(prompt.matches("> ").is_some());
178    /// ```
179    #[must_use]
180    pub fn shell_prompt() -> Self {
181        // Use a fallback to literal if regex somehow fails (it won't for this pattern)
182        Self::regex(r"[$#>%]\s*$").unwrap_or_else(|_| Self::Literal("$ ".to_string()))
183    }
184
185    /// Create a pattern that matches any common prompt character.
186    ///
187    /// A simpler alternative to `shell_prompt()` that uses glob matching.
188    /// Less precise but faster for simple cases.
189    #[must_use]
190    pub fn any_prompt() -> Self {
191        Self::Glob("*$*".to_string())
192    }
193
194    /// Create a pattern that matches IPv4 addresses.
195    ///
196    /// # Errors
197    ///
198    /// Returns an error if the regex compilation fails (should not happen).
199    ///
200    /// # Examples
201    ///
202    /// ```
203    /// use rust_expect::Pattern;
204    ///
205    /// let ipv4 = Pattern::ipv4().unwrap();
206    /// assert!(ipv4.matches("Server IP: 192.168.1.1").is_some());
207    /// assert!(ipv4.matches("10.0.0.255 is local").is_some());
208    /// ```
209    pub fn ipv4() -> Result<Self, regex::Error> {
210        Self::regex(
211            r"\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b",
212        )
213    }
214
215    /// Create a pattern that matches email addresses.
216    ///
217    /// # Errors
218    ///
219    /// Returns an error if the regex compilation fails (should not happen).
220    ///
221    /// # Examples
222    ///
223    /// ```
224    /// use rust_expect::Pattern;
225    ///
226    /// let email = Pattern::email().unwrap();
227    /// assert!(email.matches("Contact: user@example.com").is_some());
228    /// ```
229    pub fn email() -> Result<Self, regex::Error> {
230        Self::regex(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")
231    }
232
233    /// Create a pattern that matches ISO 8601 timestamps.
234    ///
235    /// Matches formats like `2024-01-15T10:30:00` or `2024-01-15 10:30:00`.
236    ///
237    /// # Errors
238    ///
239    /// Returns an error if the regex compilation fails (should not happen).
240    pub fn timestamp_iso8601() -> Result<Self, regex::Error> {
241        Self::regex(r"\d{4}-\d{2}-\d{2}[T ]\d{2}:\d{2}:\d{2}")
242    }
243
244    /// Create a pattern that matches common error indicators.
245    ///
246    /// Matches words like "error", "failed", "fatal" (case-insensitive).
247    ///
248    /// # Examples
249    ///
250    /// ```
251    /// use rust_expect::Pattern;
252    ///
253    /// let error = Pattern::error_indicator();
254    /// assert!(error.matches("Error: connection refused").is_some());
255    /// assert!(error.matches("Command FAILED").is_some());
256    /// ```
257    #[must_use]
258    pub fn error_indicator() -> Self {
259        Self::regex(r"(?i)\b(?:error|failed|fatal)\b")
260            .unwrap_or_else(|_| Self::Glob("*[Ee]rror*".to_string()))
261    }
262
263    /// Create a pattern that matches common success indicators.
264    ///
265    /// Matches words like "success", "passed", "complete", "ok" (case-insensitive).
266    #[must_use]
267    pub fn success_indicator() -> Self {
268        Self::regex(r"(?i)\b(?:success|successful|passed|complete|ok)\b")
269            .unwrap_or_else(|_| Self::Glob("*[Ss]uccess*".to_string()))
270    }
271
272    /// Create a pattern that matches common password prompts.
273    ///
274    /// Matches prompts like "Password:", "password: ", "Passphrase:".
275    ///
276    /// # Examples
277    ///
278    /// ```
279    /// use rust_expect::Pattern;
280    ///
281    /// let pwd = Pattern::password_prompt();
282    /// assert!(pwd.matches("Password: ").is_some());
283    /// assert!(pwd.matches("Enter passphrase: ").is_some());
284    /// ```
285    #[must_use]
286    pub fn password_prompt() -> Self {
287        Self::regex(r"(?i)(?:password|passphrase)\s*:\s*$")
288            .unwrap_or_else(|_| Self::Literal("password:".to_string()))
289    }
290
291    /// Create a pattern that matches common login/username prompts.
292    ///
293    /// Matches prompts like "login:", "Username:", "user: ".
294    #[must_use]
295    pub fn login_prompt() -> Self {
296        Self::regex(r"(?i)(?:login|username|user)\s*:\s*$")
297            .unwrap_or_else(|_| Self::Literal("login:".to_string()))
298    }
299
300    /// Create a pattern that matches common yes/no confirmation prompts.
301    ///
302    /// Matches prompts like "[y/n]", "(yes/no)", "[Y/n]".
303    #[must_use]
304    pub fn confirmation_prompt() -> Self {
305        Self::regex(r"\[([yYnN])/([yYnN])\]|\(([yY]es)/([nN]o)\)")
306            .unwrap_or_else(|_| Self::Glob("*[y/n]*".to_string()))
307    }
308
309    /// Create a pattern that matches common "continue?" prompts.
310    ///
311    /// Matches prompts like "Continue?", "Do you want to continue?", "Press any key".
312    #[must_use]
313    pub fn continue_prompt() -> Self {
314        Self::regex(r"(?i)(?:continue\s*\?|press any key|hit enter)")
315            .unwrap_or_else(|_| Self::Glob("*continue*".to_string()))
316    }
317}
318
319impl fmt::Debug for Pattern {
320    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
321        match self {
322            Self::Literal(s) => write!(f, "Literal({s:?})"),
323            Self::Regex(r) => write!(f, "Regex({:?})", r.pattern()),
324            Self::Glob(s) => write!(f, "Glob({s:?})"),
325            Self::Eof => write!(f, "Eof"),
326            Self::Timeout(d) => write!(f, "Timeout({d:?})"),
327            Self::Bytes(n) => write!(f, "Bytes({n})"),
328        }
329    }
330}
331
332impl From<&str> for Pattern {
333    fn from(s: &str) -> Self {
334        Self::Literal(s.to_string())
335    }
336}
337
338impl From<String> for Pattern {
339    fn from(s: String) -> Self {
340        Self::Literal(s)
341    }
342}
343
344/// A compiled regular expression with its source pattern.
345#[derive(Clone)]
346pub struct CompiledRegex {
347    pattern: String,
348    regex: Regex,
349}
350
351impl CompiledRegex {
352    /// Create a new compiled regex.
353    #[must_use]
354    pub const fn new(pattern: String, regex: Regex) -> Self {
355        Self { pattern, regex }
356    }
357
358    /// Get the source pattern.
359    #[must_use]
360    pub fn pattern(&self) -> &str {
361        &self.pattern
362    }
363
364    /// Find the first match in the text.
365    #[must_use]
366    pub fn find<'a>(&self, text: &'a str) -> Option<regex::Match<'a>> {
367        self.regex.find(text)
368    }
369
370    /// Get capture groups from a match.
371    #[must_use]
372    pub fn captures(&self, text: &str) -> Vec<String> {
373        self.regex
374            .captures(text)
375            .map(|caps| {
376                caps.iter()
377                    .skip(1) // Skip the full match
378                    .filter_map(|m| m.map(|m| m.as_str().to_string()))
379                    .collect()
380            })
381            .unwrap_or_default()
382    }
383}
384
/// Location (and optional captures) of a successful pattern match.
#[derive(Debug, Clone)]
pub struct PatternMatch {
    /// Offset in the text where the match begins.
    pub start: usize,
    /// Offset in the text just past the end of the match.
    pub end: usize,
    /// Capture groups (for regex patterns).
    pub captures: Vec<String>,
}

impl PatternMatch {
    /// Returns the part of `text` spanned by this match.
    ///
    /// `text` should be the same string the pattern was matched against.
    ///
    /// # Panics
    ///
    /// Panics if `start..end` is out of bounds for `text`, if `start > end`, or
    /// if either offset does not fall on a UTF-8 character boundary.
    #[must_use]
    pub fn as_str<'a>(&self, text: &'a str) -> &'a str {
        let span = self.start..self.end;
        &text[span]
    }

    /// Returns the length of the matched span (`end - start`).
    #[must_use]
    pub const fn len(&self) -> usize {
        let span_start = self.start;
        self.end - span_start
    }

    /// Returns `true` when the match covers no text.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        self.end == self.start
    }
}
415
416/// A set of patterns for multi-pattern matching.
417#[derive(Debug, Clone, Default)]
418pub struct PatternSet {
419    patterns: Vec<NamedPattern>,
420}
421
422/// A pattern with an optional name.
423#[derive(Clone)]
424pub struct NamedPattern {
425    /// The pattern.
426    pub pattern: Pattern,
427    /// Optional name for the pattern.
428    pub name: Option<String>,
429    /// Index in the pattern set.
430    pub index: usize,
431}
432
433impl fmt::Debug for NamedPattern {
434    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
435        f.debug_struct("NamedPattern")
436            .field("pattern", &self.pattern)
437            .field("name", &self.name)
438            .field("index", &self.index)
439            .finish()
440    }
441}
442
443impl PatternSet {
444    /// Create a new empty pattern set.
445    #[must_use]
446    pub fn new() -> Self {
447        Self::default()
448    }
449
450    /// Create a pattern set from a vector of patterns.
451    #[must_use]
452    pub fn from_patterns(patterns: Vec<Pattern>) -> Self {
453        let patterns = patterns
454            .into_iter()
455            .enumerate()
456            .map(|(index, pattern)| NamedPattern {
457                pattern,
458                name: None,
459                index,
460            })
461            .collect();
462        Self { patterns }
463    }
464
465    /// Add a pattern to the set.
466    pub fn add(&mut self, pattern: Pattern) -> &mut Self {
467        let index = self.patterns.len();
468        self.patterns.push(NamedPattern {
469            pattern,
470            name: None,
471            index,
472        });
473        self
474    }
475
476    /// Add a named pattern to the set.
477    pub fn add_named(&mut self, name: impl Into<String>, pattern: Pattern) -> &mut Self {
478        let index = self.patterns.len();
479        self.patterns.push(NamedPattern {
480            pattern,
481            name: Some(name.into()),
482            index,
483        });
484        self
485    }
486
487    /// Get the number of patterns in the set.
488    #[must_use]
489    pub const fn len(&self) -> usize {
490        self.patterns.len()
491    }
492
493    /// Check if the set is empty.
494    #[must_use]
495    pub const fn is_empty(&self) -> bool {
496        self.patterns.is_empty()
497    }
498
499    /// Find the first matching pattern in the text.
500    ///
501    /// Returns the pattern index and match details.
502    #[must_use]
503    pub fn find_match(&self, text: &str) -> Option<(usize, PatternMatch)> {
504        let mut best_match: Option<(usize, PatternMatch)> = None;
505
506        for (idx, named) in self.patterns.iter().enumerate() {
507            if let Some(m) = named.pattern.matches(text) {
508                match &best_match {
509                    None => best_match = Some((idx, m)),
510                    Some((_, current)) if m.start < current.start => {
511                        best_match = Some((idx, m));
512                    }
513                    _ => {}
514                }
515            }
516        }
517
518        best_match
519    }
520
521    /// Get a pattern by index.
522    #[must_use]
523    pub fn get(&self, index: usize) -> Option<&NamedPattern> {
524        self.patterns.get(index)
525    }
526
527    /// Get the minimum timeout from timeout patterns.
528    #[must_use]
529    pub fn min_timeout(&self) -> Option<Duration> {
530        self.patterns
531            .iter()
532            .filter_map(|p| p.pattern.timeout_duration())
533            .min()
534    }
535
536    /// Check if any pattern is an EOF pattern.
537    #[must_use]
538    pub fn has_eof(&self) -> bool {
539        self.patterns.iter().any(|p| p.pattern.is_eof())
540    }
541
542    /// Get iterator over patterns.
543    pub fn iter(&self) -> impl Iterator<Item = &NamedPattern> {
544        self.patterns.iter()
545    }
546}
547
/// Finds the first position in `text` where the glob `pattern` matches.
///
/// Supports `*` (any run of characters, including none) and `?` (exactly one
/// character); all other characters match literally. The match does not need
/// to reach the end of `text`.
///
/// Returns the **byte** offset of the match start, so it can be used directly
/// as a `PatternMatch` position and for slicing `text`. The matching itself
/// works on chars, so multi-byte UTF-8 text is handled correctly.
fn glob_match(pattern: &str, text: &str) -> Option<usize> {
    let pattern_chars: Vec<char> = pattern.chars().collect();
    let text_chars: Vec<char> = text.chars().collect();

    // A leading `*` can absorb any prefix. If the pattern matches at some
    // offset, it therefore also matches at offset 0, so probing every offset
    // (quadratic on non-matching input) is unnecessary.
    let candidates = if pattern_chars.first() == Some(&'*') {
        1
    } else {
        text_chars.len() + 1
    };

    // Byte offset of every char boundary, plus `text.len()` for the empty
    // suffix. The i-th item is the byte offset of char index i.
    text.char_indices()
        .map(|(byte_pos, _)| byte_pos)
        .chain(std::iter::once(text.len()))
        .take(candidates)
        .enumerate()
        .find(|&(char_pos, _)| glob_match_from(&pattern_chars, &text_chars[char_pos..]))
        .map(|(_, byte_pos)| byte_pos)
}

/// Returns `true` if `pattern` matches a prefix of `text` (both as chars).
///
/// Single-backtrack wildcard matching: on a mismatch after a `*`, that star is
/// made to absorb one more character and matching resumes just after it.
const fn glob_match_from(pattern: &[char], text: &[char]) -> bool {
    let mut p = 0;
    let mut t = 0;
    // Index of the most recent `*`, and how far into `text` it has absorbed.
    let mut star_p = None;
    let mut star_t = 0;

    while p < pattern.len() {
        if pattern[p] == '*' {
            star_p = Some(p);
            star_t = t;
            p += 1;
        } else if t < text.len() && (pattern[p] == '?' || pattern[p] == text[t]) {
            p += 1;
            t += 1;
        } else if let Some(sp) = star_p {
            // Let the last `*` swallow one more character and retry after it.
            p = sp + 1;
            star_t += 1;
            if star_t > text.len() {
                return false;
            }
            t = star_t;
        } else {
            return false;
        }
    }

    // Pattern matched - we don't require text to be fully consumed
    // (we're looking for the pattern within the text)
    true
}
588
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn literal_pattern_matches() {
        let m = Pattern::literal("hello")
            .matches("say hello world")
            .expect("literal should be found");
        assert_eq!(m.start, 4);
        assert_eq!(m.end, 9);
        assert!(m.captures.is_empty());
    }

    #[test]
    fn regex_pattern_matches() {
        let text = "test 123 value";
        let m = Pattern::regex(r"\d+")
            .unwrap()
            .matches(text)
            .expect("digits should match");
        assert_eq!(m.as_str(text), "123");
    }

    #[test]
    fn regex_pattern_captures() {
        let m = Pattern::regex(r"(\w+)@(\w+)")
            .unwrap()
            .matches("email: user@domain here")
            .expect("address should match");
        assert_eq!(m.captures, vec!["user", "domain"]);
    }

    #[test]
    fn glob_pattern_matches() {
        let m = Pattern::glob("hello*world")
            .matches("say hello beautiful world!")
            .expect("glob should match");
        assert_eq!(m.start, 4);
    }

    #[test]
    fn glob_question_mark_matches_single_char() {
        let pattern = Pattern::glob("a?c");
        assert_eq!(pattern.matches("xxabc").map(|m| m.start), Some(2));
        assert!(pattern.matches("ac").is_none());
    }

    #[test]
    fn non_text_patterns_never_match() {
        assert!(Pattern::eof().matches("anything").is_none());
        assert!(Pattern::timeout(Duration::from_secs(1)).matches("anything").is_none());
        assert!(Pattern::bytes(4).matches("anything").is_none());
        assert_eq!(Pattern::eof().as_str(), "<EOF>");
    }

    #[test]
    fn debug_output() {
        assert_eq!(format!("{:?}", Pattern::literal("a")), "Literal(\"a\")");
        assert_eq!(format!("{:?}", Pattern::glob("a*")), "Glob(\"a*\")");
        assert_eq!(format!("{:?}", Pattern::regex("a+").unwrap()), "Regex(\"a+\")");
        assert_eq!(format!("{:?}", Pattern::eof()), "Eof");
        assert_eq!(format!("{:?}", Pattern::bytes(3)), "Bytes(3)");
    }

    #[test]
    fn from_conversions_build_literals() {
        let borrowed: Pattern = "x".into();
        let owned: Pattern = String::from("y").into();
        assert!(matches!(borrowed, Pattern::Literal(ref s) if s == "x"));
        assert!(matches!(owned, Pattern::Literal(ref s) if s == "y"));
    }

    #[test]
    fn pattern_set_finds_first() {
        let mut set = PatternSet::new();
        set.add(Pattern::literal("world"))
            .add(Pattern::literal("hello"));

        let (idx, _) = set.find_match("hello world").expect("one pattern should match");
        // "hello" comes first in the text
        assert_eq!(idx, 1);
    }

    #[test]
    fn pattern_set_ties_go_to_earliest_added() {
        let set = PatternSet::from_patterns(vec![
            Pattern::literal("hello"),
            Pattern::literal("hell"),
        ]);
        assert_eq!(set.find_match("hello").map(|(idx, _)| idx), Some(0));
        assert!(set.find_match("nothing here").is_none());
    }

    #[test]
    fn pattern_set_named_and_eof() {
        let mut set = PatternSet::new();
        assert!(set.is_empty());
        set.add_named("prompt", Pattern::literal("$ "))
            .add(Pattern::eof());

        assert_eq!(set.len(), 2);
        assert!(set.has_eof());
        assert_eq!(set.get(0).and_then(|p| p.name.as_deref()), Some("prompt"));
        assert_eq!(set.get(1).map(|p| p.index), Some(1));
        assert!(set.get(2).is_none());
    }

    #[test]
    fn pattern_set_min_timeout() {
        let mut set = PatternSet::new();
        assert_eq!(set.min_timeout(), None);
        set.add(Pattern::timeout(Duration::from_secs(10)))
            .add(Pattern::timeout(Duration::from_secs(5)));

        assert_eq!(set.min_timeout(), Some(Duration::from_secs(5)));
    }
}