1use std::sync::Arc;
31use std::time::Duration;
32
33use tokio::io::{AsyncReadExt, AsyncWriteExt};
34use tokio::sync::Mutex;
35
36use super::hooks::{HookManager, InteractionEvent};
37use super::mode::InteractionMode;
38use super::terminal::TerminalSize;
39use crate::error::{ExpectError, Result};
40use crate::expect::Pattern;
41
/// Instruction returned by an interaction hook telling the session what to do next.
#[derive(Debug, Clone)]
pub enum InteractAction {
    /// Keep the session running without any extra action.
    Continue,
    /// Write these bytes to the transport, then keep the session running.
    Send(Vec<u8>),
    /// End the session.
    Stop,
    /// End the session, reporting the given message as the reason.
    Error(String),
}

impl InteractAction {
    /// Builds a [`InteractAction::Send`] carrying the UTF-8 bytes of `s`.
    pub fn send(s: impl Into<String>) -> Self {
        let text: String = s.into();
        Self::Send(text.into_bytes())
    }

    /// Builds a [`InteractAction::Send`] carrying `data` verbatim.
    pub fn send_bytes(data: impl Into<Vec<u8>>) -> Self {
        let payload: Vec<u8> = data.into();
        Self::Send(payload)
    }
}
66
67pub struct InteractContext<'a> {
69 pub matched: &'a str,
71 pub before: &'a str,
73 pub after: &'a str,
75 pub buffer: &'a str,
77 pub pattern_index: usize,
79}
80
81impl InteractContext<'_> {
82 pub fn send(&self, data: impl Into<String>) -> InteractAction {
84 InteractAction::send(data)
85 }
86
87 pub fn send_line(&self, data: impl Into<String>) -> InteractAction {
89 let mut s = data.into();
90 s.push('\n');
91 InteractAction::send(s)
92 }
93}
94
95pub type PatternHook = Box<dyn Fn(&InteractContext<'_>) -> InteractAction + Send + Sync>;
97
98#[derive(Debug, Clone, Copy)]
100pub struct ResizeContext {
101 pub size: TerminalSize,
103 pub previous: Option<TerminalSize>,
105}
106
107pub type ResizeHook = Box<dyn Fn(&ResizeContext) -> InteractAction + Send + Sync>;
109
110struct OutputPatternHook {
112 pattern: Pattern,
113 callback: PatternHook,
114}
115
116struct InputPatternHook {
118 pattern: Pattern,
119 callback: PatternHook,
120}
121
122pub struct InteractBuilder<'a, T>
124where
125 T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static,
126{
127 transport: &'a Arc<Mutex<T>>,
129 output_hooks: Vec<OutputPatternHook>,
131 input_hooks: Vec<InputPatternHook>,
133 resize_hook: Option<ResizeHook>,
135 hook_manager: HookManager,
137 mode: InteractionMode,
139 buffer_size: usize,
141 escape_sequence: Option<Vec<u8>>,
143 timeout: Option<Duration>,
145}
146
147impl<'a, T> InteractBuilder<'a, T>
148where
149 T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static,
150{
151 pub(crate) fn new(transport: &'a Arc<Mutex<T>>) -> Self {
153 Self {
154 transport,
155 output_hooks: Vec::new(),
156 input_hooks: Vec::new(),
157 resize_hook: None,
158 hook_manager: HookManager::new(),
159 mode: InteractionMode::default(),
160 buffer_size: 8192,
161 escape_sequence: Some(vec![0x1d]), timeout: None,
163 }
164 }
165
166 #[must_use]
181 pub fn on_output<F>(mut self, pattern: impl Into<Pattern>, callback: F) -> Self
182 where
183 F: Fn(&InteractContext<'_>) -> InteractAction + Send + Sync + 'static,
184 {
185 self.output_hooks.push(OutputPatternHook {
186 pattern: pattern.into(),
187 callback: Box::new(callback),
188 });
189 self
190 }
191
192 #[must_use]
196 pub fn on_input<F>(mut self, pattern: impl Into<Pattern>, callback: F) -> Self
197 where
198 F: Fn(&InteractContext<'_>) -> InteractAction + Send + Sync + 'static,
199 {
200 self.input_hooks.push(InputPatternHook {
201 pattern: pattern.into(),
202 callback: Box::new(callback),
203 });
204 self
205 }
206
207 #[must_use]
230 pub fn on_resize<F>(mut self, callback: F) -> Self
231 where
232 F: Fn(&ResizeContext) -> InteractAction + Send + Sync + 'static,
233 {
234 self.resize_hook = Some(Box::new(callback));
235 self
236 }
237
238 #[must_use]
240 pub const fn with_mode(mut self, mode: InteractionMode) -> Self {
241 self.mode = mode;
242 self
243 }
244
245 #[must_use]
249 pub fn with_escape(mut self, escape: impl Into<Vec<u8>>) -> Self {
250 self.escape_sequence = Some(escape.into());
251 self
252 }
253
254 #[must_use]
256 pub fn no_escape(mut self) -> Self {
257 self.escape_sequence = None;
258 self
259 }
260
261 #[must_use]
263 pub const fn with_timeout(mut self, timeout: Duration) -> Self {
264 self.timeout = Some(timeout);
265 self
266 }
267
268 #[must_use]
270 pub const fn with_buffer_size(mut self, size: usize) -> Self {
271 self.buffer_size = size;
272 self
273 }
274
275 #[must_use]
277 pub fn with_input_hook<F>(mut self, hook: F) -> Self
278 where
279 F: Fn(&[u8]) -> Vec<u8> + Send + Sync + 'static,
280 {
281 self.hook_manager.add_input_hook(hook);
282 self
283 }
284
285 #[must_use]
287 pub fn with_output_hook<F>(mut self, hook: F) -> Self
288 where
289 F: Fn(&[u8]) -> Vec<u8> + Send + Sync + 'static,
290 {
291 self.hook_manager.add_output_hook(hook);
292 self
293 }
294
295 pub async fn start(self) -> Result<InteractResult> {
310 let mut runner = InteractRunner::new(
311 Arc::clone(self.transport),
312 self.output_hooks,
313 self.input_hooks,
314 self.resize_hook,
315 self.hook_manager,
316 self.mode,
317 self.buffer_size,
318 self.escape_sequence,
319 self.timeout,
320 );
321 runner.run().await
322 }
323}
324
325#[derive(Debug, Clone)]
327pub struct InteractResult {
328 pub reason: InteractEndReason,
330 pub buffer: String,
332}
333
/// The reason an interactive session came to an end.
#[derive(Debug, Clone)]
pub enum InteractEndReason {
    /// A hook callback returned [`InteractAction::Stop`].
    PatternStop {
        /// Index of the hook that requested the stop.
        pattern_index: usize,
    },
    /// The configured escape sequence was seen in the user's input.
    Escape,
    /// The session's overall timeout elapsed.
    Timeout,
    /// The transport reported end of output.
    Eof,
    /// A hook callback returned [`InteractAction::Error`] with this message.
    Error(String),
}
351
352struct InteractRunner<T>
354where
355 T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static,
356{
357 transport: Arc<Mutex<T>>,
358 output_hooks: Vec<OutputPatternHook>,
359 input_hooks: Vec<InputPatternHook>,
360 #[cfg_attr(windows, allow(dead_code))]
363 resize_hook: Option<ResizeHook>,
364 hook_manager: HookManager,
365 mode: InteractionMode,
366 buffer: String,
367 buffer_size: usize,
368 escape_sequence: Option<Vec<u8>>,
369 timeout: Option<Duration>,
370 #[cfg_attr(windows, allow(dead_code))]
373 current_size: Option<TerminalSize>,
374}
375
376impl<T> InteractRunner<T>
377where
378 T: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static,
379{
380 #[allow(clippy::too_many_arguments)]
381 fn new(
382 transport: Arc<Mutex<T>>,
383 output_hooks: Vec<OutputPatternHook>,
384 input_hooks: Vec<InputPatternHook>,
385 resize_hook: Option<ResizeHook>,
386 hook_manager: HookManager,
387 mode: InteractionMode,
388 buffer_size: usize,
389 escape_sequence: Option<Vec<u8>>,
390 timeout: Option<Duration>,
391 ) -> Self {
392 let current_size = super::terminal::Terminal::size().ok();
394
395 Self {
396 transport,
397 output_hooks,
398 input_hooks,
399 resize_hook,
400 hook_manager,
401 mode,
402 buffer: String::with_capacity(buffer_size),
403 buffer_size,
404 escape_sequence,
405 timeout,
406 current_size,
407 }
408 }
409
410 async fn run(&mut self) -> Result<InteractResult> {
411 #[cfg(unix)]
412 {
413 self.run_with_signals().await
414 }
415 #[cfg(not(unix))]
416 {
417 self.run_without_signals().await
418 }
419 }
420
421 #[cfg(unix)]
423 #[allow(clippy::significant_drop_tightening)]
424 async fn run_with_signals(&mut self) -> Result<InteractResult> {
425 use tokio::io::{BufReader, stdin, stdout};
426
427 self.hook_manager.notify(&InteractionEvent::Started);
428
429 let mut stdin = BufReader::new(stdin());
430 let mut input_buf = [0u8; 1024];
431 let mut output_buf = [0u8; 4096];
432 let mut escape_buf: Vec<u8> = Vec::new();
433
434 let deadline = self.timeout.map(|t| std::time::Instant::now() + t);
435
436 let mut sigwinch =
438 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::window_change())
439 .map_err(ExpectError::Io)?;
440
441 loop {
442 if let Some(deadline) = deadline
444 && std::time::Instant::now() >= deadline
445 {
446 self.hook_manager.notify(&InteractionEvent::Ended);
447 return Ok(InteractResult {
448 reason: InteractEndReason::Timeout,
449 buffer: self.buffer.clone(),
450 });
451 }
452
453 let read_timeout = self.mode.read_timeout;
454 let mut transport = self.transport.lock().await;
455
456 tokio::select! {
457 _ = sigwinch.recv() => {
459 drop(transport); if let Some(result) = self.handle_resize().await? {
462 return Ok(result);
463 }
464 }
465
466 result = transport.read(&mut output_buf) => {
468 drop(transport); match result {
470 Ok(0) => {
471 self.hook_manager.notify(&InteractionEvent::Ended);
472 return Ok(InteractResult {
473 reason: InteractEndReason::Eof,
474 buffer: self.buffer.clone(),
475 });
476 }
477 Ok(n) => {
478 let data = &output_buf[..n];
479 let processed = self.hook_manager.process_output(data.to_vec());
480
481 self.hook_manager.notify(&InteractionEvent::Output(processed.clone()));
482
483 let mut stdout = stdout();
485 let _ = stdout.write_all(&processed).await;
486 let _ = stdout.flush().await;
487
488 if let Ok(s) = std::str::from_utf8(&processed) {
490 self.buffer.push_str(s);
491 if self.buffer.len() > self.buffer_size {
493 let start = self.buffer.len() - self.buffer_size;
494 self.buffer = self.buffer[start..].to_string();
495 }
496 }
497
498 if let Some(result) = self.check_output_patterns().await? {
500 return Ok(result);
501 }
502 }
503 Err(e) => {
504 self.hook_manager.notify(&InteractionEvent::Ended);
505 return Err(ExpectError::Io(e));
506 }
507 }
508 }
509
510 result = tokio::time::timeout(read_timeout, stdin.read(&mut input_buf)) => {
512 drop(transport); if let Ok(Ok(n)) = result {
515 if n == 0 {
516 continue;
517 }
518
519 let data = &input_buf[..n];
520
521 if let Some(ref esc) = self.escape_sequence {
523 escape_buf.extend_from_slice(data);
524 if escape_buf.ends_with(esc) {
525 self.hook_manager.notify(&InteractionEvent::ExitRequested);
526 self.hook_manager.notify(&InteractionEvent::Ended);
527 return Ok(InteractResult {
528 reason: InteractEndReason::Escape,
529 buffer: self.buffer.clone(),
530 });
531 }
532 if escape_buf.len() > esc.len() {
534 escape_buf = escape_buf[escape_buf.len() - esc.len()..].to_vec();
535 }
536 }
537
538 let processed = self.hook_manager.process_input(data.to_vec());
540
541 self.hook_manager.notify(&InteractionEvent::Input(processed.clone()));
542
543 if let Some(result) = self.check_input_patterns(&processed).await? {
545 return Ok(result);
546 }
547
548 let mut transport = self.transport.lock().await;
550 transport.write_all(&processed).await.map_err(ExpectError::Io)?;
551 transport.flush().await.map_err(ExpectError::Io)?;
552 }
553 }
554 }
555 }
556 }
557
558 #[cfg(not(unix))]
560 #[allow(clippy::significant_drop_tightening)]
561 async fn run_without_signals(&mut self) -> Result<InteractResult> {
562 use tokio::io::{BufReader, stdin, stdout};
563
564 self.hook_manager.notify(&InteractionEvent::Started);
565
566 let mut stdin = BufReader::new(stdin());
567 let mut input_buf = [0u8; 1024];
568 let mut output_buf = [0u8; 4096];
569 let mut escape_buf: Vec<u8> = Vec::new();
570
571 let deadline = self.timeout.map(|t| std::time::Instant::now() + t);
572
573 loop {
574 if let Some(deadline) = deadline {
576 if std::time::Instant::now() >= deadline {
577 self.hook_manager.notify(&InteractionEvent::Ended);
578 return Ok(InteractResult {
579 reason: InteractEndReason::Timeout,
580 buffer: self.buffer.clone(),
581 });
582 }
583 }
584
585 let read_timeout = self.mode.read_timeout;
586 let mut transport = self.transport.lock().await;
587
588 tokio::select! {
589 result = transport.read(&mut output_buf) => {
591 drop(transport); match result {
593 Ok(0) => {
594 self.hook_manager.notify(&InteractionEvent::Ended);
595 return Ok(InteractResult {
596 reason: InteractEndReason::Eof,
597 buffer: self.buffer.clone(),
598 });
599 }
600 Ok(n) => {
601 let data = &output_buf[..n];
602 let processed = self.hook_manager.process_output(data.to_vec());
603
604 self.hook_manager.notify(&InteractionEvent::Output(processed.clone()));
605
606 let mut stdout = stdout();
608 let _ = stdout.write_all(&processed).await;
609 let _ = stdout.flush().await;
610
611 if let Ok(s) = std::str::from_utf8(&processed) {
613 self.buffer.push_str(s);
614 if self.buffer.len() > self.buffer_size {
616 let start = self.buffer.len() - self.buffer_size;
617 self.buffer = self.buffer[start..].to_string();
618 }
619 }
620
621 if let Some(result) = self.check_output_patterns().await? {
623 return Ok(result);
624 }
625 }
626 Err(e) => {
627 self.hook_manager.notify(&InteractionEvent::Ended);
628 return Err(ExpectError::Io(e));
629 }
630 }
631 }
632
633 result = tokio::time::timeout(read_timeout, stdin.read(&mut input_buf)) => {
635 drop(transport); if let Ok(Ok(n)) = result {
638 if n == 0 {
639 continue;
640 }
641
642 let data = &input_buf[..n];
643
644 if let Some(ref esc) = self.escape_sequence {
646 escape_buf.extend_from_slice(data);
647 if escape_buf.ends_with(esc) {
648 self.hook_manager.notify(&InteractionEvent::ExitRequested);
649 self.hook_manager.notify(&InteractionEvent::Ended);
650 return Ok(InteractResult {
651 reason: InteractEndReason::Escape,
652 buffer: self.buffer.clone(),
653 });
654 }
655 if escape_buf.len() > esc.len() {
657 escape_buf = escape_buf[escape_buf.len() - esc.len()..].to_vec();
658 }
659 }
660
661 let processed = self.hook_manager.process_input(data.to_vec());
663
664 self.hook_manager.notify(&InteractionEvent::Input(processed.clone()));
665
666 if let Some(result) = self.check_input_patterns(&processed).await? {
668 return Ok(result);
669 }
670
671 let mut transport = self.transport.lock().await;
673 transport.write_all(&processed).await.map_err(ExpectError::Io)?;
674 transport.flush().await.map_err(ExpectError::Io)?;
675 }
676 }
677 }
678 }
679 }
680
681 #[allow(clippy::significant_drop_tightening)]
682 async fn check_output_patterns(&mut self) -> Result<Option<InteractResult>> {
683 for (index, hook) in self.output_hooks.iter().enumerate() {
684 if let Some(m) = hook.pattern.matches(&self.buffer) {
685 let matched = &self.buffer[m.start..m.end];
686 let before = &self.buffer[..m.start];
687 let after = &self.buffer[m.end..];
688
689 let ctx = InteractContext {
690 matched,
691 before,
692 after,
693 buffer: &self.buffer,
694 pattern_index: index,
695 };
696
697 match (hook.callback)(&ctx) {
698 InteractAction::Continue => {
699 self.buffer = after.to_string();
701 }
702 InteractAction::Send(data) => {
703 let mut transport = self.transport.lock().await;
704 transport.write_all(&data).await.map_err(ExpectError::Io)?;
705 transport.flush().await.map_err(ExpectError::Io)?;
706 self.buffer = after.to_string();
708 }
709 InteractAction::Stop => {
710 self.hook_manager.notify(&InteractionEvent::Ended);
711 return Ok(Some(InteractResult {
712 reason: InteractEndReason::PatternStop {
713 pattern_index: index,
714 },
715 buffer: self.buffer.clone(),
716 }));
717 }
718 InteractAction::Error(msg) => {
719 self.hook_manager.notify(&InteractionEvent::Ended);
720 return Ok(Some(InteractResult {
721 reason: InteractEndReason::Error(msg),
722 buffer: self.buffer.clone(),
723 }));
724 }
725 }
726 }
727 }
728 Ok(None)
729 }
730
731 #[allow(clippy::significant_drop_tightening)]
732 async fn check_input_patterns(&self, input: &[u8]) -> Result<Option<InteractResult>> {
733 let input_str = String::from_utf8_lossy(input);
734
735 for (index, hook) in self.input_hooks.iter().enumerate() {
736 if let Some(m) = hook.pattern.matches(&input_str) {
737 let matched = &input_str[m.start..m.end];
738 let before = &input_str[..m.start];
739 let after = &input_str[m.end..];
740
741 let ctx = InteractContext {
742 matched,
743 before,
744 after,
745 buffer: &input_str,
746 pattern_index: index,
747 };
748
749 match (hook.callback)(&ctx) {
750 InteractAction::Continue => {}
751 InteractAction::Send(data) => {
752 let mut transport = self.transport.lock().await;
753 transport.write_all(&data).await.map_err(ExpectError::Io)?;
754 transport.flush().await.map_err(ExpectError::Io)?;
755 }
756 InteractAction::Stop => {
757 return Ok(Some(InteractResult {
758 reason: InteractEndReason::PatternStop {
759 pattern_index: index,
760 },
761 buffer: self.buffer.clone(),
762 }));
763 }
764 InteractAction::Error(msg) => {
765 return Ok(Some(InteractResult {
766 reason: InteractEndReason::Error(msg),
767 buffer: self.buffer.clone(),
768 }));
769 }
770 }
771 }
772 }
773 Ok(None)
774 }
775
776 #[cfg_attr(windows, allow(dead_code))]
781 #[allow(clippy::significant_drop_tightening)]
782 async fn handle_resize(&mut self) -> Result<Option<InteractResult>> {
783 let Ok(new_size) = super::terminal::Terminal::size() else {
785 return Ok(None); };
787
788 let ctx = ResizeContext {
790 size: new_size,
791 previous: self.current_size,
792 };
793
794 self.hook_manager.notify(&InteractionEvent::Resize {
796 cols: new_size.cols,
797 rows: new_size.rows,
798 });
799
800 self.current_size = Some(new_size);
802
803 if let Some(ref hook) = self.resize_hook {
805 match hook(&ctx) {
806 InteractAction::Continue => {}
807 InteractAction::Send(data) => {
808 let mut transport = self.transport.lock().await;
809 transport.write_all(&data).await.map_err(ExpectError::Io)?;
810 transport.flush().await.map_err(ExpectError::Io)?;
811 }
812 InteractAction::Stop => {
813 self.hook_manager.notify(&InteractionEvent::Ended);
814 return Ok(Some(InteractResult {
815 reason: InteractEndReason::PatternStop { pattern_index: 0 },
816 buffer: self.buffer.clone(),
817 }));
818 }
819 InteractAction::Error(msg) => {
820 self.hook_manager.notify(&InteractionEvent::Ended);
821 return Ok(Some(InteractResult {
822 reason: InteractEndReason::Error(msg),
823 buffer: self.buffer.clone(),
824 }));
825 }
826 }
827 }
828
829 Ok(None)
830 }
831}