1use std::collections::HashMap;
30use std::fmt;
31use std::future::Future;
32use std::pin::Pin;
33use std::sync::Arc;
34use std::time::Duration;
35
36use futures::stream::{FuturesUnordered, StreamExt};
37use tokio::io::{AsyncReadExt, AsyncWriteExt};
38use tokio::sync::Mutex;
39
40use crate::config::SessionConfig;
41use crate::error::{ExpectError, Result};
42use crate::expect::{Pattern, PatternSet};
43use crate::types::Match;
44
/// Identifier assigned to a session by [`MultiSessionManager::add`].
///
/// Ids are handed out in increasing order starting at 0 and are not reused
/// after a session is removed.
pub type SessionId = usize;
47
/// A successful match on one of several managed sessions.
#[derive(Debug, Clone)]
pub struct SelectResult {
    /// Session on which the match occurred.
    pub session_id: SessionId,
    /// The match reported by that session.
    pub matched: Match,
    /// Index of the pattern that matched.
    ///
    /// NOTE(review): every producer in this file currently records 0 here,
    /// because the per-session `expect_any` result is not inspected for a
    /// pattern index — do not rely on this value yet.
    pub pattern_index: usize,
}
58
59#[derive(Debug, Clone)]
61pub struct SendResult {
62 pub session_id: SessionId,
64 pub success: bool,
66 pub error: Option<String>,
68}
69
/// Readiness classification for a managed session.
///
/// NOTE(review): this enum is not referenced in this part of the file; the
/// per-variant descriptions below are read off the variant names only —
/// confirm against the code that produces these values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadyType {
    /// Presumably: one of the awaited patterns matched.
    Matched,
    /// Presumably: data is available to read.
    Readable,
    /// Presumably: the session can accept writes.
    Writable,
    /// Presumably: the session has been closed.
    Closed,
    /// Presumably: the session reported an error.
    Error,
}
84
/// A session owned by [`MultiSessionManager`] together with its bookkeeping.
struct ManagedSession<T: AsyncReadExt + AsyncWriteExt + Unpin + Send> {
    /// The wrapped session that sends and pattern matching are delegated to.
    session: crate::session::Session<T>,
    /// Caller-supplied label passed to `MultiSessionManager::add`.
    label: String,
    /// Starts `true` in `add` and is toggled by `set_active`. The
    /// multi-session operations (`send_all`, `expect_any_of`, `expect_all_of`,
    /// `PatternSelector::select`) check it; single-session calls such as
    /// `send` and `expect` do not.
    active: bool,
}
94
95impl<T: AsyncReadExt + AsyncWriteExt + Unpin + Send> fmt::Debug for ManagedSession<T> {
96 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97 f.debug_struct("ManagedSession")
98 .field("label", &self.label)
99 .field("active", &self.active)
100 .finish_non_exhaustive()
101 }
102}
103
/// Owns a set of sessions keyed by [`SessionId`] and drives them concurrently.
///
/// Each session sits behind its own `Arc<tokio::sync::Mutex<_>>`, so an
/// operation on one session only holds that session's lock.
pub struct MultiSessionManager<T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static> {
    /// Managed sessions by id; every slot is locked individually.
    sessions: HashMap<SessionId, Arc<Mutex<ManagedSession<T>>>>,
    /// Id handed out by the next call to `add`; only ever incremented.
    next_id: SessionId,
    /// Manager-wide timeout (30 s from `new`, replaced by `with_timeout`).
    default_timeout: Duration,
    /// Default session configuration stored by `with_config`.
    ///
    /// NOTE(review): not read anywhere in this section of the file.
    default_config: SessionConfig,
}
122
123impl<T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static> fmt::Debug
124 for MultiSessionManager<T>
125{
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("MultiSessionManager")
128 .field("session_count", &self.sessions.len())
129 .field("next_id", &self.next_id)
130 .field("default_timeout", &self.default_timeout)
131 .finish()
132 }
133}
134
135impl<T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static> Default for MultiSessionManager<T> {
136 fn default() -> Self {
137 Self::new()
138 }
139}
140
141impl<T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static> MultiSessionManager<T> {
142 #[must_use]
144 pub fn new() -> Self {
145 Self {
146 sessions: HashMap::new(),
147 next_id: 0,
148 default_timeout: Duration::from_secs(30),
149 default_config: SessionConfig::default(),
150 }
151 }
152
153 #[must_use]
155 pub const fn with_timeout(mut self, timeout: Duration) -> Self {
156 self.default_timeout = timeout;
157 self
158 }
159
160 #[must_use]
162 pub fn with_config(mut self, config: SessionConfig) -> Self {
163 self.default_config = config;
164 self
165 }
166
167 pub fn add(
171 &mut self,
172 session: crate::session::Session<T>,
173 label: impl Into<String>,
174 ) -> SessionId {
175 let id = self.next_id;
176 self.next_id += 1;
177
178 let managed = ManagedSession {
179 session,
180 label: label.into(),
181 active: true,
182 };
183
184 self.sessions.insert(id, Arc::new(Mutex::new(managed)));
185 id
186 }
187
188 #[allow(clippy::unused_async)]
192 pub async fn remove(&mut self, id: SessionId) -> Option<crate::session::Session<T>> {
193 if let Some(arc) = self.sessions.remove(&id) {
194 match Arc::try_unwrap(arc) {
196 Ok(mutex) => Some(mutex.into_inner().session),
197 Err(arc) => {
198 self.sessions.insert(id, arc);
200 None
201 }
202 }
203 } else {
204 None
205 }
206 }
207
208 #[must_use]
210 pub fn len(&self) -> usize {
211 self.sessions.len()
212 }
213
214 #[must_use]
216 pub fn is_empty(&self) -> bool {
217 self.sessions.is_empty()
218 }
219
220 #[must_use]
222 pub fn session_ids(&self) -> Vec<SessionId> {
223 self.sessions.keys().copied().collect()
224 }
225
226 pub async fn label(&self, id: SessionId) -> Option<String> {
228 if let Some(arc) = self.sessions.get(&id) {
229 let guard = arc.lock().await;
230 Some(guard.label.clone())
231 } else {
232 None
233 }
234 }
235
236 pub async fn is_active(&self, id: SessionId) -> bool {
238 if let Some(arc) = self.sessions.get(&id) {
239 let guard = arc.lock().await;
240 guard.active
241 } else {
242 false
243 }
244 }
245
246 pub async fn set_active(&self, id: SessionId, active: bool) {
248 if let Some(arc) = self.sessions.get(&id) {
249 let mut guard = arc.lock().await;
250 guard.active = active;
251 }
252 }
253
254 pub async fn active_ids(&self) -> Vec<SessionId> {
256 let mut active = Vec::new();
257 for &id in self.sessions.keys() {
258 if self.is_active(id).await {
259 active.push(id);
260 }
261 }
262 active
263 }
264
265 pub async fn send(&self, id: SessionId, data: &[u8]) -> Result<()> {
271 let arc = self
272 .sessions
273 .get(&id)
274 .ok_or(ExpectError::SessionNotFound { id })?;
275
276 let mut guard = arc.lock().await;
277 guard.session.send(data).await
278 }
279
280 pub async fn send_line(&self, id: SessionId, line: &str) -> Result<()> {
286 let arc = self
287 .sessions
288 .get(&id)
289 .ok_or(ExpectError::SessionNotFound { id })?;
290
291 let mut guard = arc.lock().await;
292 guard.session.send_line(line).await
293 }
294
295 pub async fn send_all(&self, data: &[u8]) -> Vec<SendResult> {
299 let mut futures = FuturesUnordered::new();
300
301 for (&id, arc) in &self.sessions {
302 let arc = Arc::clone(arc);
303 let data = data.to_vec();
304
305 futures.push(async move {
306 let mut guard = arc.lock().await;
307 if !guard.active {
308 return SendResult {
309 session_id: id,
310 success: false,
311 error: Some("Session not active".to_string()),
312 };
313 }
314
315 match guard.session.send(&data).await {
316 Ok(()) => SendResult {
317 session_id: id,
318 success: true,
319 error: None,
320 },
321 Err(e) => SendResult {
322 session_id: id,
323 success: false,
324 error: Some(e.to_string()),
325 },
326 }
327 });
328 }
329
330 let mut results = Vec::new();
331 while let Some(result) = futures.next().await {
332 results.push(result);
333 }
334 results
335 }
336
337 pub async fn expect(&self, id: SessionId, pattern: impl Into<Pattern>) -> Result<Match> {
343 let arc = self
344 .sessions
345 .get(&id)
346 .ok_or(ExpectError::SessionNotFound { id })?;
347
348 let mut guard = arc.lock().await;
349 guard.session.expect(pattern).await
350 }
351
352 #[allow(clippy::type_complexity)]
361 pub async fn expect_any(&self, pattern: impl Into<Pattern>) -> Result<SelectResult> {
362 let pattern = pattern.into();
363 self.expect_any_of(&[pattern]).await
364 }
365
366 #[allow(clippy::type_complexity)]
372 pub async fn expect_any_of(&self, patterns: &[Pattern]) -> Result<SelectResult> {
373 if self.sessions.is_empty() {
374 return Err(ExpectError::NoSessions);
375 }
376
377 let pattern_set = PatternSet::from_patterns(patterns.to_vec());
378
379 let mut futures: FuturesUnordered<
381 Pin<Box<dyn Future<Output = (SessionId, Result<(Match, usize)>)> + Send>>,
382 > = FuturesUnordered::new();
383
384 for (&id, arc) in &self.sessions {
385 let arc = Arc::clone(arc);
386 let patterns = pattern_set.clone();
387
388 let future: Pin<Box<dyn Future<Output = (SessionId, Result<(Match, usize)>)> + Send>> =
389 Box::pin(async move {
390 let mut guard = arc.lock().await;
391 if !guard.active {
392 return (id, Err(ExpectError::SessionClosed));
393 }
394
395 match guard.session.expect_any(&patterns).await {
396 Ok(m) => (id, Ok((m, 0))), Err(e) => (id, Err(e)),
398 }
399 });
400
401 futures.push(future);
402 }
403
404 let mut last_error: Option<ExpectError> = None;
406
407 while let Some((session_id, result)) = futures.next().await {
408 match result {
409 Ok((matched, pattern_index)) => {
410 return Ok(SelectResult {
411 session_id,
412 matched,
413 pattern_index,
414 });
415 }
416 Err(e) => {
417 if !matches!(e, ExpectError::Timeout { .. }) {
420 last_error = Some(e);
421 }
422 }
423 }
424 }
425
426 Err(last_error.unwrap_or_else(|| ExpectError::Timeout {
428 duration: self.default_timeout,
429 pattern: "multi-session expect".to_string(),
430 buffer: String::new(),
431 }))
432 }
433
434 pub async fn expect_all(&self, pattern: impl Into<Pattern>) -> Result<Vec<SelectResult>> {
442 let pattern = pattern.into();
443 self.expect_all_of(&[pattern]).await
444 }
445
446 #[allow(clippy::type_complexity)]
452 pub async fn expect_all_of(&self, patterns: &[Pattern]) -> Result<Vec<SelectResult>> {
453 if self.sessions.is_empty() {
454 return Err(ExpectError::NoSessions);
455 }
456
457 let pattern_set = PatternSet::from_patterns(patterns.to_vec());
458
459 let mut futures: FuturesUnordered<
461 Pin<Box<dyn Future<Output = (SessionId, Result<(Match, usize)>)> + Send>>,
462 > = FuturesUnordered::new();
463
464 for (&id, arc) in &self.sessions {
465 let arc = Arc::clone(arc);
466 let patterns = pattern_set.clone();
467
468 let future: Pin<Box<dyn Future<Output = (SessionId, Result<(Match, usize)>)> + Send>> =
469 Box::pin(async move {
470 let mut guard = arc.lock().await;
471 if !guard.active {
472 return (id, Err(ExpectError::SessionClosed));
473 }
474
475 match guard.session.expect_any(&patterns).await {
476 Ok(m) => (id, Ok((m, 0))),
477 Err(e) => (id, Err(e)),
478 }
479 });
480
481 futures.push(future);
482 }
483
484 let mut results = Vec::new();
486 let mut errors = Vec::new();
487
488 while let Some((session_id, result)) = futures.next().await {
489 match result {
490 Ok((matched, pattern_index)) => {
491 results.push(SelectResult {
492 session_id,
493 matched,
494 pattern_index,
495 });
496 }
497 Err(e) => {
498 errors.push((session_id, e));
499 }
500 }
501 }
502
503 if let Some((id, error)) = errors.into_iter().next() {
505 return Err(ExpectError::MultiSessionError {
506 session_id: id,
507 error: Box::new(error),
508 });
509 }
510
511 Ok(results)
512 }
513
514 pub async fn with_session<F, R>(&self, id: SessionId, f: F) -> Result<R>
523 where
524 F: FnOnce(&mut crate::session::Session<T>) -> R,
525 {
526 let arc = self
527 .sessions
528 .get(&id)
529 .ok_or(ExpectError::SessionNotFound { id })?;
530
531 let mut guard = arc.lock().await;
532 Ok(f(&mut guard.session))
533 }
534
535 pub async fn with_session_async<F, Fut, R>(&self, id: SessionId, f: F) -> Result<R>
541 where
542 F: FnOnce(&mut crate::session::Session<T>) -> Fut,
543 Fut: Future<Output = R>,
544 {
545 let arc = self
546 .sessions
547 .get(&id)
548 .ok_or(ExpectError::SessionNotFound { id })?;
549
550 let mut guard = arc.lock().await;
551 Ok(f(&mut guard.session).await)
552 }
553}
554
/// Builder describing which patterns to wait for on which session. Use
/// [`PatternSelector::select`] to wait for the first match across a
/// [`MultiSessionManager`].
#[derive(Debug, Default)]
pub struct PatternSelector {
    /// Per-session pattern lists. A session with an entry here uses only
    /// these patterns; the defaults are not added to them.
    patterns: HashMap<SessionId, Vec<Pattern>>,
    /// Patterns used for every session without an entry in `patterns`.
    default_patterns: Vec<Pattern>,
    /// Overall deadline for `select`; `None` falls back to the manager's
    /// default timeout.
    timeout: Option<Duration>,
}
568
569impl PatternSelector {
570 #[must_use]
572 pub fn new() -> Self {
573 Self::default()
574 }
575
576 #[must_use]
578 pub fn session(mut self, id: SessionId, pattern: impl Into<Pattern>) -> Self {
579 self.patterns.entry(id).or_default().push(pattern.into());
580 self
581 }
582
583 #[must_use]
585 pub fn session_patterns(mut self, id: SessionId, patterns: Vec<Pattern>) -> Self {
586 self.patterns.entry(id).or_default().extend(patterns);
587 self
588 }
589
590 #[must_use]
592 pub fn default_pattern(mut self, pattern: impl Into<Pattern>) -> Self {
593 self.default_patterns.push(pattern.into());
594 self
595 }
596
597 #[must_use]
599 pub const fn timeout(mut self, timeout: Duration) -> Self {
600 self.timeout = Some(timeout);
601 self
602 }
603
604 #[must_use]
606 pub fn patterns_for(&self, id: SessionId) -> &[Pattern] {
607 self.patterns
608 .get(&id)
609 .map_or(&self.default_patterns, Vec::as_slice)
610 }
611
612 #[allow(clippy::type_complexity)]
618 pub async fn select<T>(&self, manager: &MultiSessionManager<T>) -> Result<SelectResult>
619 where
620 T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static,
621 {
622 if manager.is_empty() {
623 return Err(ExpectError::NoSessions);
624 }
625
626 let timeout = self.timeout.unwrap_or(manager.default_timeout);
627
628 let mut futures: FuturesUnordered<
630 Pin<Box<dyn Future<Output = (SessionId, Result<(Match, usize)>)> + Send>>,
631 > = FuturesUnordered::new();
632
633 for &id in &manager.session_ids() {
634 let patterns = self.patterns_for(id);
635 if patterns.is_empty() {
636 continue;
637 }
638
639 let arc = match manager.sessions.get(&id) {
640 Some(arc) => Arc::clone(arc),
641 None => continue,
642 };
643
644 let pattern_set = PatternSet::from_patterns(patterns.to_vec());
645
646 let future: Pin<Box<dyn Future<Output = (SessionId, Result<(Match, usize)>)> + Send>> =
647 Box::pin(async move {
648 let mut guard = arc.lock().await;
649 if !guard.active {
650 return (id, Err(ExpectError::SessionClosed));
651 }
652
653 match guard.session.expect_any(&pattern_set).await {
654 Ok(m) => (id, Ok((m, 0))),
655 Err(e) => (id, Err(e)),
656 }
657 });
658
659 futures.push(future);
660 }
661
662 let select_future = async {
664 while let Some((session_id, result)) = futures.next().await {
665 if let Ok((matched, pattern_index)) = result {
666 return Ok(SelectResult {
667 session_id,
668 matched,
669 pattern_index,
670 });
671 }
672 }
673 Err(ExpectError::Timeout {
674 duration: timeout,
675 pattern: "pattern selector".to_string(),
676 buffer: String::new(),
677 })
678 };
679
680 tokio::time::timeout(timeout, select_future)
681 .await
682 .map_err(|_| ExpectError::Timeout {
683 duration: timeout,
684 pattern: "pattern selector".to_string(),
685 buffer: String::new(),
686 })?
687 }
688}
689
#[cfg(test)]
mod tests {
    use tokio::io::DuplexStream;

    use super::*;

    /// Returns both ends of an in-memory duplex pipe with a 1 KiB buffer.
    fn create_mock_transport() -> (DuplexStream, DuplexStream) {
        tokio::io::duplex(1024)
    }

    /// Builds a session over a fresh mock transport. The peer end is returned
    /// so the test can keep it alive for as long as the session is used.
    fn mock_session() -> (crate::session::Session<DuplexStream>, DuplexStream) {
        let (client, server) = create_mock_transport();
        (
            crate::session::Session::new(client, SessionConfig::default()),
            server,
        )
    }

    #[tokio::test]
    async fn manager_add_remove() {
        let mut manager: MultiSessionManager<DuplexStream> = MultiSessionManager::new();

        let (session, _server) = mock_session();

        let id = manager.add(session, "test");
        assert_eq!(manager.len(), 1);
        assert_eq!(manager.label(id).await, Some("test".to_string()));

        let removed = manager.remove(id).await;
        assert!(removed.is_some());
        assert!(manager.is_empty());
    }

    #[tokio::test]
    async fn ids_are_sequential_and_not_reused() {
        let mut manager: MultiSessionManager<DuplexStream> = MultiSessionManager::new();

        let (first_session, _first_server) = mock_session();
        let first = manager.add(first_session, "first");
        assert_eq!(first, 0);
        assert!(manager.remove(first).await.is_some());

        let (second_session, _second_server) = mock_session();
        let second = manager.add(second_session, "second");
        assert_eq!(second, 1);
    }

    #[tokio::test]
    async fn unknown_ids() {
        let mut manager: MultiSessionManager<DuplexStream> = MultiSessionManager::new();

        assert!(manager.remove(42).await.is_none());
        assert_eq!(manager.label(7).await, None);
        assert!(!manager.is_active(7).await);
        assert!(matches!(
            manager.send(7, b"x").await,
            Err(ExpectError::SessionNotFound { id: 7 })
        ));
        assert!(matches!(
            manager.send_line(7, "x").await,
            Err(ExpectError::SessionNotFound { id: 7 })
        ));
    }

    #[tokio::test]
    async fn manager_active_state() {
        let mut manager: MultiSessionManager<DuplexStream> = MultiSessionManager::new();

        let (session, _server) = mock_session();

        let id = manager.add(session, "test");
        assert!(manager.is_active(id).await);

        manager.set_active(id, false).await;
        assert!(!manager.is_active(id).await);

        let active = manager.active_ids().await;
        assert!(active.is_empty());
    }

    #[tokio::test]
    async fn send_all_with_no_sessions_is_empty() {
        let manager: MultiSessionManager<DuplexStream> = MultiSessionManager::new();
        assert!(manager.send_all(b"ping").await.is_empty());
    }

    #[tokio::test]
    async fn send_all_reports_inactive_sessions() {
        let mut manager: MultiSessionManager<DuplexStream> = MultiSessionManager::new();
        let (session, _server) = mock_session();
        let id = manager.add(session, "idle");
        manager.set_active(id, false).await;

        let results = manager.send_all(b"ping").await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].session_id, id);
        assert!(!results[0].success);
        assert_eq!(results[0].error.as_deref(), Some("Session not active"));
    }

    // Building a selector is synchronous, so no async runtime is needed.
    #[test]
    fn pattern_selector_build() {
        let selector = PatternSelector::new()
            .session(0, "login:")
            .session(0, "password:")
            .session(1, "prompt>")
            .default_pattern("$");

        // Session-specific lists replace the defaults rather than extend them.
        assert_eq!(selector.patterns_for(0).len(), 2);
        assert_eq!(selector.patterns_for(1).len(), 1);
        // Sessions without their own list fall back to the defaults.
        assert_eq!(selector.patterns_for(99).len(), 1);
    }

    #[tokio::test]
    async fn expect_any_no_sessions() {
        let manager: MultiSessionManager<DuplexStream> = MultiSessionManager::new();
        let result = manager.expect_any("test").await;
        assert!(matches!(result, Err(ExpectError::NoSessions)));
    }

    #[tokio::test]
    async fn expect_all_no_sessions() {
        let manager: MultiSessionManager<DuplexStream> = MultiSessionManager::new();
        let result = manager.expect_all("test").await;
        assert!(matches!(result, Err(ExpectError::NoSessions)));
    }

    #[tokio::test]
    async fn select_no_sessions() {
        let manager: MultiSessionManager<DuplexStream> = MultiSessionManager::new();
        let result = PatternSelector::new()
            .default_pattern("$")
            .select(&manager)
            .await;
        assert!(matches!(result, Err(ExpectError::NoSessions)));
    }

    #[tokio::test]
    async fn expect_any_inactive_only_reports_closed() {
        let mut manager: MultiSessionManager<DuplexStream> = MultiSessionManager::new();
        let (session, _server) = mock_session();
        let id = manager.add(session, "idle");
        manager.set_active(id, false).await;

        let result = manager.expect_any("prompt>").await;
        assert!(matches!(result, Err(ExpectError::SessionClosed)));
    }

    #[tokio::test]
    async fn expect_all_inactive_session_fails() {
        let mut manager: MultiSessionManager<DuplexStream> = MultiSessionManager::new();
        let (session, _server) = mock_session();
        let id = manager.add(session, "idle");
        manager.set_active(id, false).await;

        match manager.expect_all("prompt>").await {
            Err(ExpectError::MultiSessionError { session_id, error }) => {
                assert_eq!(session_id, id);
                assert!(matches!(*error, ExpectError::SessionClosed));
            }
            Err(other) => panic!("unexpected error: {other}"),
            Ok(results) => panic!("expected a failure, got {} matches", results.len()),
        }
    }
}