rust_expect/multi/
select.rs

1//! Async multi-session selection.
2//!
3//! This module provides true async selection across multiple sessions,
4//! allowing you to wait for patterns to match on any of several sessions
5//! simultaneously using Tokio's async primitives.
6//!
7//! # Example
8//!
9//! ```ignore
10//! use rust_expect::multi::{MultiSessionManager, SelectResult};
11//! use rust_expect::Session;
12//!
13//! #[tokio::main]
14//! async fn main() -> Result<(), rust_expect::ExpectError> {
15//!     let mut manager = MultiSessionManager::new();
16//!
17//!     // Add sessions
18//!     let id1 = manager.spawn("bash", &["-c", "echo server1"]).await?;
19//!     let id2 = manager.spawn("bash", &["-c", "echo server2"]).await?;
20//!
21//!     // Wait for any session to produce output
22//!     let result = manager.expect_any("server").await?;
23//!     println!("Session {} matched: {}", result.session_id, result.matched.matched);
24//!
25//!     Ok(())
26//! }
27//! ```
28
29use std::collections::HashMap;
30use std::fmt;
31use std::future::Future;
32use std::pin::Pin;
33use std::sync::Arc;
34use std::time::Duration;
35
36use futures::stream::{FuturesUnordered, StreamExt};
37use tokio::io::{AsyncReadExt, AsyncWriteExt};
38use tokio::sync::Mutex;
39
40use crate::config::SessionConfig;
41use crate::error::{ExpectError, Result};
42use crate::expect::{Pattern, PatternSet};
43use crate::types::Match;
44
45/// Unique identifier for a session within a multi-session manager.
46pub type SessionId = usize;
47
48/// Result of a multi-session select operation.
49#[derive(Debug, Clone)]
50pub struct SelectResult {
51    /// The session that matched.
52    pub session_id: SessionId,
53    /// The match result.
54    pub matched: Match,
55    /// Index of the pattern that matched (if multiple patterns provided).
56    pub pattern_index: usize,
57}
58
59/// Result of a multi-session send operation.
60#[derive(Debug, Clone)]
61pub struct SendResult {
62    /// The session the data was sent to.
63    pub session_id: SessionId,
64    /// Whether the send succeeded.
65    pub success: bool,
66    /// Error message if failed.
67    pub error: Option<String>,
68}
69
/// Kind of readiness event a session can report.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReadyType {
    /// Buffered data matched one of the patterns.
    Matched,
    /// Data is available for reading.
    Readable,
    /// The session can accept writes.
    Writable,
    /// The session reached end-of-file and is closed.
    Closed,
    /// The session ran into an error.
    Error,
}
84
85/// A managed session with its metadata.
86struct ManagedSession<T: AsyncReadExt + AsyncWriteExt + Unpin + Send> {
87    /// The underlying session.
88    session: crate::session::Session<T>,
89    /// Session label for identification.
90    label: String,
91    /// Whether the session is active.
92    active: bool,
93}
94
95impl<T: AsyncReadExt + AsyncWriteExt + Unpin + Send> fmt::Debug for ManagedSession<T> {
96    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97        f.debug_struct("ManagedSession")
98            .field("label", &self.label)
99            .field("active", &self.active)
100            .finish_non_exhaustive()
101    }
102}
103
104/// Manager for multiple async sessions with select capabilities.
105///
106/// This provides the core multi-session functionality, allowing you to:
107/// - Manage multiple sessions simultaneously
108/// - Wait for any session to match a pattern (`expect_any`)
109/// - Wait for all sessions to match patterns (`expect_all`)
110/// - Send to multiple sessions in parallel
111/// - Select on multiple sessions with different patterns per session
112pub struct MultiSessionManager<T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static> {
113    /// Sessions indexed by ID.
114    sessions: HashMap<SessionId, Arc<Mutex<ManagedSession<T>>>>,
115    /// Next session ID to assign.
116    next_id: SessionId,
117    /// Default timeout for operations.
118    default_timeout: Duration,
119    /// Default configuration for spawned sessions.
120    default_config: SessionConfig,
121}
122
123impl<T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static> fmt::Debug
124    for MultiSessionManager<T>
125{
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        f.debug_struct("MultiSessionManager")
128            .field("session_count", &self.sessions.len())
129            .field("next_id", &self.next_id)
130            .field("default_timeout", &self.default_timeout)
131            .finish()
132    }
133}
134
135impl<T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static> Default for MultiSessionManager<T> {
136    fn default() -> Self {
137        Self::new()
138    }
139}
140
141impl<T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static> MultiSessionManager<T> {
142    /// Create a new multi-session manager.
143    #[must_use]
144    pub fn new() -> Self {
145        Self {
146            sessions: HashMap::new(),
147            next_id: 0,
148            default_timeout: Duration::from_secs(30),
149            default_config: SessionConfig::default(),
150        }
151    }
152
153    /// Set the default timeout for operations.
154    #[must_use]
155    pub const fn with_timeout(mut self, timeout: Duration) -> Self {
156        self.default_timeout = timeout;
157        self
158    }
159
160    /// Set the default session configuration.
161    #[must_use]
162    pub fn with_config(mut self, config: SessionConfig) -> Self {
163        self.default_config = config;
164        self
165    }
166
167    /// Add an existing session to the manager.
168    ///
169    /// Returns the assigned session ID.
170    pub fn add(
171        &mut self,
172        session: crate::session::Session<T>,
173        label: impl Into<String>,
174    ) -> SessionId {
175        let id = self.next_id;
176        self.next_id += 1;
177
178        let managed = ManagedSession {
179            session,
180            label: label.into(),
181            active: true,
182        };
183
184        self.sessions.insert(id, Arc::new(Mutex::new(managed)));
185        id
186    }
187
188    /// Remove a session from the manager.
189    ///
190    /// Returns the session if it existed.
191    #[allow(clippy::unused_async)]
192    pub async fn remove(&mut self, id: SessionId) -> Option<crate::session::Session<T>> {
193        if let Some(arc) = self.sessions.remove(&id) {
194            // Try to unwrap the Arc - this will only succeed if we have the only reference
195            match Arc::try_unwrap(arc) {
196                Ok(mutex) => Some(mutex.into_inner().session),
197                Err(arc) => {
198                    // Put it back and return None - someone else has a reference
199                    self.sessions.insert(id, arc);
200                    None
201                }
202            }
203        } else {
204            None
205        }
206    }
207
208    /// Get the number of sessions.
209    #[must_use]
210    pub fn len(&self) -> usize {
211        self.sessions.len()
212    }
213
214    /// Check if there are no sessions.
215    #[must_use]
216    pub fn is_empty(&self) -> bool {
217        self.sessions.is_empty()
218    }
219
220    /// Get all session IDs.
221    #[must_use]
222    pub fn session_ids(&self) -> Vec<SessionId> {
223        self.sessions.keys().copied().collect()
224    }
225
226    /// Get the label for a session.
227    pub async fn label(&self, id: SessionId) -> Option<String> {
228        if let Some(arc) = self.sessions.get(&id) {
229            let guard = arc.lock().await;
230            Some(guard.label.clone())
231        } else {
232            None
233        }
234    }
235
236    /// Check if a session is active.
237    pub async fn is_active(&self, id: SessionId) -> bool {
238        if let Some(arc) = self.sessions.get(&id) {
239            let guard = arc.lock().await;
240            guard.active
241        } else {
242            false
243        }
244    }
245
246    /// Set a session's active state.
247    pub async fn set_active(&self, id: SessionId, active: bool) {
248        if let Some(arc) = self.sessions.get(&id) {
249            let mut guard = arc.lock().await;
250            guard.active = active;
251        }
252    }
253
254    /// Get active session IDs.
255    pub async fn active_ids(&self) -> Vec<SessionId> {
256        let mut active = Vec::new();
257        for &id in self.sessions.keys() {
258            if self.is_active(id).await {
259                active.push(id);
260            }
261        }
262        active
263    }
264
265    /// Send data to a specific session.
266    ///
267    /// # Errors
268    ///
269    /// Returns an error if the session doesn't exist or the send fails.
270    pub async fn send(&self, id: SessionId, data: &[u8]) -> Result<()> {
271        let arc = self
272            .sessions
273            .get(&id)
274            .ok_or(ExpectError::SessionNotFound { id })?;
275
276        let mut guard = arc.lock().await;
277        guard.session.send(data).await
278    }
279
280    /// Send a line to a specific session.
281    ///
282    /// # Errors
283    ///
284    /// Returns an error if the session doesn't exist or the send fails.
285    pub async fn send_line(&self, id: SessionId, line: &str) -> Result<()> {
286        let arc = self
287            .sessions
288            .get(&id)
289            .ok_or(ExpectError::SessionNotFound { id })?;
290
291        let mut guard = arc.lock().await;
292        guard.session.send_line(line).await
293    }
294
295    /// Send data to all active sessions in parallel.
296    ///
297    /// Returns results for each session.
298    pub async fn send_all(&self, data: &[u8]) -> Vec<SendResult> {
299        let mut futures = FuturesUnordered::new();
300
301        for (&id, arc) in &self.sessions {
302            let arc = Arc::clone(arc);
303            let data = data.to_vec();
304
305            futures.push(async move {
306                let mut guard = arc.lock().await;
307                if !guard.active {
308                    return SendResult {
309                        session_id: id,
310                        success: false,
311                        error: Some("Session not active".to_string()),
312                    };
313                }
314
315                match guard.session.send(&data).await {
316                    Ok(()) => SendResult {
317                        session_id: id,
318                        success: true,
319                        error: None,
320                    },
321                    Err(e) => SendResult {
322                        session_id: id,
323                        success: false,
324                        error: Some(e.to_string()),
325                    },
326                }
327            });
328        }
329
330        let mut results = Vec::new();
331        while let Some(result) = futures.next().await {
332            results.push(result);
333        }
334        results
335    }
336
337    /// Expect a pattern on a specific session.
338    ///
339    /// # Errors
340    ///
341    /// Returns an error if the session doesn't exist or expect fails.
342    pub async fn expect(&self, id: SessionId, pattern: impl Into<Pattern>) -> Result<Match> {
343        let arc = self
344            .sessions
345            .get(&id)
346            .ok_or(ExpectError::SessionNotFound { id })?;
347
348        let mut guard = arc.lock().await;
349        guard.session.expect(pattern).await
350    }
351
352    /// Wait for any session to match the given pattern.
353    ///
354    /// Returns as soon as any session matches. This is the primary multi-session
355    /// select operation, equivalent to TCL Expect's multi-spawn expect.
356    ///
357    /// # Errors
358    ///
359    /// Returns an error if all sessions timeout or encounter errors.
360    #[allow(clippy::type_complexity)]
361    pub async fn expect_any(&self, pattern: impl Into<Pattern>) -> Result<SelectResult> {
362        let pattern = pattern.into();
363        self.expect_any_of(&[pattern]).await
364    }
365
366    /// Wait for any session to match any of the given patterns.
367    ///
368    /// # Errors
369    ///
370    /// Returns an error if all sessions timeout or encounter errors.
371    #[allow(clippy::type_complexity)]
372    pub async fn expect_any_of(&self, patterns: &[Pattern]) -> Result<SelectResult> {
373        if self.sessions.is_empty() {
374            return Err(ExpectError::NoSessions);
375        }
376
377        let pattern_set = PatternSet::from_patterns(patterns.to_vec());
378
379        // Create futures for all active sessions
380        let mut futures: FuturesUnordered<
381            Pin<Box<dyn Future<Output = (SessionId, Result<(Match, usize)>)> + Send>>,
382        > = FuturesUnordered::new();
383
384        for (&id, arc) in &self.sessions {
385            let arc = Arc::clone(arc);
386            let patterns = pattern_set.clone();
387
388            let future: Pin<Box<dyn Future<Output = (SessionId, Result<(Match, usize)>)> + Send>> =
389                Box::pin(async move {
390                    let mut guard = arc.lock().await;
391                    if !guard.active {
392                        return (id, Err(ExpectError::SessionClosed));
393                    }
394
395                    match guard.session.expect_any(&patterns).await {
396                        Ok(m) => (id, Ok((m, 0))), // pattern_index 0 for now
397                        Err(e) => (id, Err(e)),
398                    }
399                });
400
401            futures.push(future);
402        }
403
404        // Wait for the first successful match
405        let mut last_error: Option<ExpectError> = None;
406
407        while let Some((session_id, result)) = futures.next().await {
408            match result {
409                Ok((matched, pattern_index)) => {
410                    return Ok(SelectResult {
411                        session_id,
412                        matched,
413                        pattern_index,
414                    });
415                }
416                Err(e) => {
417                    // Store the error but continue waiting for others
418                    // Only timeouts should be ignored; other errors are real failures
419                    if !matches!(e, ExpectError::Timeout { .. }) {
420                        last_error = Some(e);
421                    }
422                }
423            }
424        }
425
426        // All futures completed without a match
427        Err(last_error.unwrap_or_else(|| ExpectError::Timeout {
428            duration: self.default_timeout,
429            pattern: "multi-session expect".to_string(),
430            buffer: String::new(),
431        }))
432    }
433
434    /// Wait for all sessions to match patterns.
435    ///
436    /// Each session must match at least one pattern. Returns results for all sessions.
437    ///
438    /// # Errors
439    ///
440    /// Returns an error if any session fails to match.
441    pub async fn expect_all(&self, pattern: impl Into<Pattern>) -> Result<Vec<SelectResult>> {
442        let pattern = pattern.into();
443        self.expect_all_of(&[pattern]).await
444    }
445
446    /// Wait for all sessions to match any of the given patterns.
447    ///
448    /// # Errors
449    ///
450    /// Returns an error if any session fails.
451    #[allow(clippy::type_complexity)]
452    pub async fn expect_all_of(&self, patterns: &[Pattern]) -> Result<Vec<SelectResult>> {
453        if self.sessions.is_empty() {
454            return Err(ExpectError::NoSessions);
455        }
456
457        let pattern_set = PatternSet::from_patterns(patterns.to_vec());
458
459        // Create futures for all active sessions
460        let mut futures: FuturesUnordered<
461            Pin<Box<dyn Future<Output = (SessionId, Result<(Match, usize)>)> + Send>>,
462        > = FuturesUnordered::new();
463
464        for (&id, arc) in &self.sessions {
465            let arc = Arc::clone(arc);
466            let patterns = pattern_set.clone();
467
468            let future: Pin<Box<dyn Future<Output = (SessionId, Result<(Match, usize)>)> + Send>> =
469                Box::pin(async move {
470                    let mut guard = arc.lock().await;
471                    if !guard.active {
472                        return (id, Err(ExpectError::SessionClosed));
473                    }
474
475                    match guard.session.expect_any(&patterns).await {
476                        Ok(m) => (id, Ok((m, 0))),
477                        Err(e) => (id, Err(e)),
478                    }
479                });
480
481            futures.push(future);
482        }
483
484        // Collect all results
485        let mut results = Vec::new();
486        let mut errors = Vec::new();
487
488        while let Some((session_id, result)) = futures.next().await {
489            match result {
490                Ok((matched, pattern_index)) => {
491                    results.push(SelectResult {
492                        session_id,
493                        matched,
494                        pattern_index,
495                    });
496                }
497                Err(e) => {
498                    errors.push((session_id, e));
499                }
500            }
501        }
502
503        // If any session failed, return the first error
504        if let Some((id, error)) = errors.into_iter().next() {
505            return Err(ExpectError::MultiSessionError {
506                session_id: id,
507                error: Box::new(error),
508            });
509        }
510
511        Ok(results)
512    }
513
514    /// Execute a closure on a specific session.
515    ///
516    /// This provides direct access to the session for operations not covered
517    /// by the manager's API.
518    ///
519    /// # Errors
520    ///
521    /// Returns an error if the session doesn't exist.
522    pub async fn with_session<F, R>(&self, id: SessionId, f: F) -> Result<R>
523    where
524        F: FnOnce(&mut crate::session::Session<T>) -> R,
525    {
526        let arc = self
527            .sessions
528            .get(&id)
529            .ok_or(ExpectError::SessionNotFound { id })?;
530
531        let mut guard = arc.lock().await;
532        Ok(f(&mut guard.session))
533    }
534
535    /// Execute an async closure on a specific session.
536    ///
537    /// # Errors
538    ///
539    /// Returns an error if the session doesn't exist.
540    pub async fn with_session_async<F, Fut, R>(&self, id: SessionId, f: F) -> Result<R>
541    where
542        F: FnOnce(&mut crate::session::Session<T>) -> Fut,
543        Fut: Future<Output = R>,
544    {
545        let arc = self
546            .sessions
547            .get(&id)
548            .ok_or(ExpectError::SessionNotFound { id })?;
549
550        let mut guard = arc.lock().await;
551        Ok(f(&mut guard.session).await)
552    }
553}
554
555/// Builder for creating pattern selectors with per-session patterns.
556///
557/// This allows different patterns for different sessions, enabling
558/// complex multi-session automation scenarios.
559#[derive(Debug, Default)]
560pub struct PatternSelector {
561    /// Patterns per session.
562    patterns: HashMap<SessionId, Vec<Pattern>>,
563    /// Default patterns for sessions not explicitly configured.
564    default_patterns: Vec<Pattern>,
565    /// Timeout for the select operation.
566    timeout: Option<Duration>,
567}
568
569impl PatternSelector {
570    /// Create a new pattern selector.
571    #[must_use]
572    pub fn new() -> Self {
573        Self::default()
574    }
575
576    /// Add a pattern for a specific session.
577    #[must_use]
578    pub fn session(mut self, id: SessionId, pattern: impl Into<Pattern>) -> Self {
579        self.patterns.entry(id).or_default().push(pattern.into());
580        self
581    }
582
583    /// Add multiple patterns for a specific session.
584    #[must_use]
585    pub fn session_patterns(mut self, id: SessionId, patterns: Vec<Pattern>) -> Self {
586        self.patterns.entry(id).or_default().extend(patterns);
587        self
588    }
589
590    /// Set default patterns for sessions not explicitly configured.
591    #[must_use]
592    pub fn default_pattern(mut self, pattern: impl Into<Pattern>) -> Self {
593        self.default_patterns.push(pattern.into());
594        self
595    }
596
597    /// Set timeout for the select operation.
598    #[must_use]
599    pub const fn timeout(mut self, timeout: Duration) -> Self {
600        self.timeout = Some(timeout);
601        self
602    }
603
604    /// Get patterns for a session, falling back to defaults.
605    #[must_use]
606    pub fn patterns_for(&self, id: SessionId) -> &[Pattern] {
607        self.patterns
608            .get(&id)
609            .map_or(&self.default_patterns, Vec::as_slice)
610    }
611
612    /// Execute the select operation on a multi-session manager.
613    ///
614    /// # Errors
615    ///
616    /// Returns an error if no sessions match or all timeout.
617    #[allow(clippy::type_complexity)]
618    pub async fn select<T>(&self, manager: &MultiSessionManager<T>) -> Result<SelectResult>
619    where
620        T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static,
621    {
622        if manager.is_empty() {
623            return Err(ExpectError::NoSessions);
624        }
625
626        let timeout = self.timeout.unwrap_or(manager.default_timeout);
627
628        // Create futures for all configured sessions
629        let mut futures: FuturesUnordered<
630            Pin<Box<dyn Future<Output = (SessionId, Result<(Match, usize)>)> + Send>>,
631        > = FuturesUnordered::new();
632
633        for &id in &manager.session_ids() {
634            let patterns = self.patterns_for(id);
635            if patterns.is_empty() {
636                continue;
637            }
638
639            let arc = match manager.sessions.get(&id) {
640                Some(arc) => Arc::clone(arc),
641                None => continue,
642            };
643
644            let pattern_set = PatternSet::from_patterns(patterns.to_vec());
645
646            let future: Pin<Box<dyn Future<Output = (SessionId, Result<(Match, usize)>)> + Send>> =
647                Box::pin(async move {
648                    let mut guard = arc.lock().await;
649                    if !guard.active {
650                        return (id, Err(ExpectError::SessionClosed));
651                    }
652
653                    match guard.session.expect_any(&pattern_set).await {
654                        Ok(m) => (id, Ok((m, 0))),
655                        Err(e) => (id, Err(e)),
656                    }
657                });
658
659            futures.push(future);
660        }
661
662        // Apply overall timeout
663        let select_future = async {
664            while let Some((session_id, result)) = futures.next().await {
665                if let Ok((matched, pattern_index)) = result {
666                    return Ok(SelectResult {
667                        session_id,
668                        matched,
669                        pattern_index,
670                    });
671                }
672            }
673            Err(ExpectError::Timeout {
674                duration: timeout,
675                pattern: "pattern selector".to_string(),
676                buffer: String::new(),
677            })
678        };
679
680        tokio::time::timeout(timeout, select_future)
681            .await
682            .map_err(|_| ExpectError::Timeout {
683                duration: timeout,
684                pattern: "pattern selector".to_string(),
685                buffer: String::new(),
686            })?
687    }
688}
689
#[cfg(test)]
mod tests {
    use tokio::io::DuplexStream;

    use super::*;

    /// Builds an in-memory duplex pair to stand in for a real transport.
    fn create_mock_transport() -> (DuplexStream, DuplexStream) {
        let max_buf_size = 1024;
        tokio::io::duplex(max_buf_size)
    }

    /// Adding a session makes it visible by id/label; removing it returns it
    /// and leaves the manager empty.
    #[tokio::test]
    async fn manager_add_remove() {
        let mut manager = MultiSessionManager::<DuplexStream>::new();

        // The far end is kept alive for the duration of the test.
        let (near_end, _far_end) = create_mock_transport();
        let id = manager.add(
            crate::session::Session::new(near_end, SessionConfig::default()),
            "test",
        );

        assert_eq!(manager.len(), 1);
        assert_eq!(manager.label(id).await, Some("test".to_string()));

        assert!(manager.remove(id).await.is_some());
        assert!(manager.is_empty());
    }

    /// Sessions start active and drop out of `active_ids` once deactivated.
    #[tokio::test]
    async fn manager_active_state() {
        let mut manager = MultiSessionManager::<DuplexStream>::new();

        let (near_end, _far_end) = create_mock_transport();
        let id = manager.add(
            crate::session::Session::new(near_end, SessionConfig::default()),
            "test",
        );
        assert!(manager.is_active(id).await);

        manager.set_active(id, false).await;
        assert!(!manager.is_active(id).await);

        assert!(manager.active_ids().await.is_empty());
    }

    /// Per-session patterns accumulate; unknown sessions get the defaults.
    #[tokio::test]
    async fn pattern_selector_build() {
        let selector = PatternSelector::new()
            .session(0, "login:")
            .session(0, "password:")
            .session(1, "prompt>")
            .default_pattern("$");

        let count_for = |id: SessionId| selector.patterns_for(id).len();

        assert_eq!(count_for(0), 2);
        assert_eq!(count_for(1), 1);
        // Session 99 was never configured, so it sees the single default.
        assert_eq!(count_for(99), 1);
    }

    /// Selecting on an empty manager fails fast with `NoSessions`.
    #[tokio::test]
    async fn expect_any_no_sessions() {
        let manager = MultiSessionManager::<DuplexStream>::new();
        let outcome = manager.expect_any("test").await;
        assert!(matches!(outcome, Err(ExpectError::NoSessions)));
    }
}
753}