1use std::io;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11
12use crate::config::SessionConfig;
13use crate::error::{ExpectError, Result, SpawnError};
14
/// Type-erased async transport built from independent read and write halves.
///
/// The `AsyncRead` / `AsyncWrite` impls for this type forward to the boxed
/// halves below.
pub struct PtyTransport {
    /// Boxed read half; `AsyncRead::poll_read` on the transport forwards here.
    reader: Box<dyn AsyncRead + Unpin + Send>,
    /// Boxed write half; `poll_write`, `poll_flush` and `poll_shutdown` forward here.
    writer: Box<dyn AsyncWrite + Unpin + Send>,
    /// Process id recorded via `set_pid`; `None` until that is called.
    pid: Option<u32>,
}
24
25impl PtyTransport {
26 pub fn new<R, W>(reader: R, writer: W) -> Self
28 where
29 R: AsyncRead + Unpin + Send + 'static,
30 W: AsyncWrite + Unpin + Send + 'static,
31 {
32 Self {
33 reader: Box::new(reader),
34 writer: Box::new(writer),
35 pid: None,
36 }
37 }
38
39 pub const fn set_pid(&mut self, pid: u32) {
41 self.pid = Some(pid);
42 }
43
44 #[must_use]
46 pub const fn pid(&self) -> Option<u32> {
47 self.pid
48 }
49}
50
51impl AsyncRead for PtyTransport {
52 fn poll_read(
53 mut self: Pin<&mut Self>,
54 cx: &mut Context<'_>,
55 buf: &mut ReadBuf<'_>,
56 ) -> Poll<io::Result<()>> {
57 Pin::new(&mut self.reader).poll_read(cx, buf)
58 }
59}
60
61impl AsyncWrite for PtyTransport {
62 fn poll_write(
63 mut self: Pin<&mut Self>,
64 cx: &mut Context<'_>,
65 buf: &[u8],
66 ) -> Poll<io::Result<usize>> {
67 Pin::new(&mut self.writer).poll_write(cx, buf)
68 }
69
70 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
71 Pin::new(&mut self.writer).poll_flush(cx)
72 }
73
74 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
75 Pin::new(&mut self.writer).poll_shutdown(cx)
76 }
77}
78
/// Settings applied when spawning a process on a PTY.
#[derive(Debug, Clone)]
pub struct PtyConfig {
    /// Terminal size as `(cols, rows)`; the default is `(80, 24)`.
    pub dimensions: (u16, u16),
    /// Presumably requests that the command be started as a login shell.
    /// NOTE(review): neither `PtySpawner::spawn` implementation in this file reads this flag.
    pub login_shell: bool,
    /// How the child's environment is derived; see [`EnvMode`].
    pub env_mode: EnvMode,
}
89
90impl Default for PtyConfig {
91 fn default() -> Self {
92 Self {
93 dimensions: (80, 24),
94 login_shell: false,
95 env_mode: EnvMode::Inherit,
96 }
97 }
98}
99
100impl From<&SessionConfig> for PtyConfig {
101 fn from(config: &SessionConfig) -> Self {
102 Self {
103 dimensions: config.dimensions,
104 login_shell: false,
105 env_mode: if config.env.is_empty() {
106 EnvMode::Inherit
107 } else {
108 EnvMode::Extend
109 },
110 }
111 }
112}
113
/// Environment policy for a spawned child process.
///
/// NOTE(review): only the Windows `PtySpawner::spawn` consults this value. The
/// Unix path calls `execvp`, which passes this process's environment through
/// regardless of mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnvMode {
    /// Presumably: inherit the parent environment unchanged. `From<&SessionConfig>`
    /// selects this when the session's `env` is empty.
    Inherit,
    /// Start from an empty environment (the Windows spawner passes an empty map).
    Clear,
    /// Presumably: inherit and add the session's variables. `From<&SessionConfig>`
    /// selects this when the session's `env` is non-empty.
    /// NOTE(review): the Windows spawner currently treats this like `Inherit`.
    Extend,
}
124
125pub struct PtySpawner {
127 config: PtyConfig,
128}
129
130impl PtySpawner {
131 #[must_use]
133 pub fn new() -> Self {
134 Self {
135 config: PtyConfig::default(),
136 }
137 }
138
139 #[must_use]
141 pub const fn with_config(config: PtyConfig) -> Self {
142 Self { config }
143 }
144
145 pub const fn set_dimensions(&mut self, cols: u16, rows: u16) {
147 self.config.dimensions = (cols, rows);
148 }
149
150 #[cfg(unix)]
160 #[allow(unsafe_code)]
161 #[allow(clippy::unused_async)]
162 pub async fn spawn(&self, command: &str, args: &[String]) -> Result<PtyHandle> {
163 use std::ffi::CString;
164
165 let cmd_cstring = CString::new(command).map_err(|_| {
167 ExpectError::Spawn(SpawnError::InvalidArgument {
168 kind: "command".to_string(),
169 value: command.to_string(),
170 reason: "command contains null byte".to_string(),
171 })
172 })?;
173
174 let mut argv_cstrings: Vec<CString> = Vec::with_capacity(args.len() + 1);
175 argv_cstrings.push(cmd_cstring.clone());
176
177 for (idx, arg) in args.iter().enumerate() {
178 let arg_cstring = CString::new(arg.as_str()).map_err(|_| {
179 ExpectError::Spawn(SpawnError::InvalidArgument {
180 kind: format!("argument[{idx}]"),
181 value: arg.clone(),
182 reason: "argument contains null byte".to_string(),
183 })
184 })?;
185 argv_cstrings.push(arg_cstring);
186 }
187
188 let pty_result = unsafe {
193 let mut master: libc::c_int = 0;
194 let mut slave: libc::c_int = 0;
195
196 if libc::openpty(
198 &raw mut master,
199 &raw mut slave,
200 std::ptr::null_mut(),
201 std::ptr::null_mut(),
202 std::ptr::null_mut(),
203 ) != 0
204 {
205 return Err(ExpectError::Spawn(SpawnError::PtyAllocation {
206 reason: "Failed to open PTY".to_string(),
207 }));
208 }
209
210 (master, slave)
211 };
212
213 let (master_fd, slave_fd) = pty_result;
214
215 let pid = unsafe { libc::fork() };
220
221 match pid {
222 -1 => Err(ExpectError::Spawn(SpawnError::Io(
223 io::Error::last_os_error(),
224 ))),
225 0 => {
226 unsafe {
237 libc::close(master_fd);
238 libc::setsid();
239 libc::ioctl(slave_fd, libc::TIOCSCTTY as libc::c_ulong, 0);
241
242 libc::dup2(slave_fd, 0);
243 libc::dup2(slave_fd, 1);
244 libc::dup2(slave_fd, 2);
245
246 if slave_fd > 2 {
247 libc::close(slave_fd);
248 }
249
250 let argv_ptrs: Vec<*const libc::c_char> = argv_cstrings
252 .iter()
253 .map(|s| s.as_ptr())
254 .chain(std::iter::once(std::ptr::null()))
255 .collect();
256
257 libc::execvp(cmd_cstring.as_ptr(), argv_ptrs.as_ptr());
258 libc::_exit(1);
259 }
260 }
261 child_pid => {
262 unsafe {
266 libc::close(slave_fd);
267 }
268
269 unsafe {
274 let flags = libc::fcntl(master_fd, libc::F_GETFL);
275 libc::fcntl(master_fd, libc::F_SETFL, flags | libc::O_NONBLOCK);
276 }
277
278 Ok(PtyHandle {
279 master_fd,
280 pid: child_pid as u32,
281 dimensions: self.config.dimensions,
282 })
283 }
284 }
285 }
286
287 #[cfg(windows)]
296 pub async fn spawn(&self, command: &str, args: &[String]) -> Result<WindowsPtyHandle> {
297 use rust_pty::{PtySystem, WindowsPtySystem};
298
299 let pty_config = rust_pty::PtyConfig {
301 window_size: self.config.dimensions,
302 env: match self.config.env_mode {
304 EnvMode::Clear => Some(std::collections::HashMap::new()),
305 _ => None,
306 },
307 ..Default::default()
308 };
309
310 let (master, child) =
312 WindowsPtySystem::spawn(command, args.iter().map(|s| s.as_str()), &pty_config)
313 .await
314 .map_err(|e| {
315 ExpectError::Spawn(SpawnError::PtyAllocation {
316 reason: format!("Windows ConPTY spawn failed: {e}"),
317 })
318 })?;
319
320 Ok(WindowsPtyHandle {
321 master,
322 child,
323 dimensions: self.config.dimensions,
324 })
325 }
326}
327
328impl Default for PtySpawner {
329 fn default() -> Self {
330 Self::new()
331 }
332}
333
/// Blocking handle to a child process spawned on a Unix PTY.
///
/// Owns the PTY master descriptor and closes it on drop.
/// `AsyncPty::from_handle` takes the descriptor over instead.
#[cfg(unix)]
#[derive(Debug)]
pub struct PtyHandle {
    /// Master side of the PTY; closed in `Drop`.
    master_fd: i32,
    /// Child process id as returned by `fork`.
    pid: u32,
    /// Last applied terminal size as `(cols, rows)`.
    dimensions: (u16, u16),
}
345
/// Handle to a child process spawned on a Windows ConPTY through `rust_pty`.
#[cfg(windows)]
pub struct WindowsPtyHandle {
    /// Master side of the ConPTY; used for I/O and resizing.
    pub(crate) master: rust_pty::WindowsPtyMaster,
    /// The spawned child process; used for pid, liveness and kill.
    pub(crate) child: rust_pty::WindowsPtyChild,
    /// Last applied terminal size as `(cols, rows)`.
    dimensions: (u16, u16),
}
356
#[cfg(windows)]
impl std::fmt::Debug for WindowsPtyHandle {
    // Only the dimensions are shown; the rust_pty master/child are elided
    // and rendered as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("WindowsPtyHandle");
        builder.field("dimensions", &self.dimensions);
        builder.finish_non_exhaustive()
    }
}
365
366#[cfg(unix)]
367impl PtyHandle {
368 #[must_use]
370 pub const fn pid(&self) -> u32 {
371 self.pid
372 }
373
374 #[must_use]
376 pub const fn dimensions(&self) -> (u16, u16) {
377 self.dimensions
378 }
379
380 #[allow(unsafe_code)]
382 pub fn resize(&mut self, cols: u16, rows: u16) -> Result<()> {
383 let winsize = libc::winsize {
384 ws_row: rows,
385 ws_col: cols,
386 ws_xpixel: 0,
387 ws_ypixel: 0,
388 };
389
390 let result =
395 unsafe { libc::ioctl(self.master_fd, libc::TIOCSWINSZ as libc::c_ulong, &winsize) };
396
397 if result != 0 {
398 Err(ExpectError::Io(io::Error::last_os_error()))
399 } else {
400 self.dimensions = (cols, rows);
401 Ok(())
402 }
403 }
404
405 #[allow(unsafe_code)]
407 pub fn wait(&self) -> Result<i32> {
408 let mut status: libc::c_int = 0;
409 let result = unsafe { libc::waitpid(self.pid as i32, &raw mut status, 0) };
413
414 if result == -1 {
415 Err(ExpectError::Io(io::Error::last_os_error()))
416 } else if libc::WIFEXITED(status) {
417 Ok(libc::WEXITSTATUS(status))
418 } else if libc::WIFSIGNALED(status) {
419 Ok(128 + libc::WTERMSIG(status))
420 } else {
421 Ok(-1)
422 }
423 }
424
425 #[allow(unsafe_code)]
427 pub fn signal(&self, signal: i32) -> Result<()> {
428 let result = unsafe { libc::kill(self.pid as i32, signal) };
432 if result != 0 {
433 Err(ExpectError::Io(io::Error::last_os_error()))
434 } else {
435 Ok(())
436 }
437 }
438
439 pub fn kill(&self) -> Result<()> {
441 self.signal(libc::SIGKILL)
442 }
443}
444
#[cfg(windows)]
impl WindowsPtyHandle {
    /// Process id of the ConPTY child.
    #[must_use]
    pub fn pid(&self) -> u32 {
        self.child.pid()
    }

    /// Last successfully applied terminal size as `(cols, rows)`.
    #[must_use]
    pub const fn dimensions(&self) -> (u16, u16) {
        self.dimensions
    }

    /// Resizes the ConPTY window.
    ///
    /// The cached dimensions are updated only if the resize succeeds.
    ///
    /// # Errors
    ///
    /// Returns [`ExpectError::Io`] wrapping the `rust_pty` error text.
    pub fn resize(&mut self, cols: u16, rows: u16) -> Result<()> {
        use rust_pty::{PtyMaster, WindowSize};

        if let Err(e) = self.master.resize(WindowSize::new(cols, rows)) {
            return Err(ExpectError::Io(io::Error::other(format!("resize failed: {e}"))));
        }
        self.dimensions = (cols, rows);
        Ok(())
    }

    /// Reports whether the child process is still running, as determined by `rust_pty`.
    #[must_use]
    pub fn is_running(&self) -> bool {
        self.child.is_running()
    }

    /// Terminates the child process.
    ///
    /// # Errors
    ///
    /// Returns [`ExpectError::Io`] wrapping the `rust_pty` error text.
    pub fn kill(&mut self) -> Result<()> {
        let outcome = self.child.kill();
        outcome.map_err(|err| ExpectError::Io(io::Error::other(format!("kill failed: {err}"))))
    }
}
483
484#[cfg(unix)]
485impl Drop for PtyHandle {
486 #[allow(unsafe_code)]
487 fn drop(&mut self) {
488 unsafe {
493 libc::close(self.master_fd);
494 }
495 }
496}
497
498#[cfg(unix)]
503pub struct AsyncPty {
504 inner: tokio::io::unix::AsyncFd<std::os::unix::io::RawFd>,
506 pid: u32,
508 dimensions: (u16, u16),
510}
511
512#[cfg(unix)]
513impl AsyncPty {
514 pub fn from_handle(handle: PtyHandle) -> io::Result<Self> {
522 let fd = handle.master_fd;
523 let pid = handle.pid;
524 let dimensions = handle.dimensions;
525
526 std::mem::forget(handle);
528
529 let inner = tokio::io::unix::AsyncFd::new(fd)?;
530 Ok(Self {
531 inner,
532 pid,
533 dimensions,
534 })
535 }
536
537 #[must_use]
539 pub const fn pid(&self) -> u32 {
540 self.pid
541 }
542
543 #[must_use]
545 pub const fn dimensions(&self) -> (u16, u16) {
546 self.dimensions
547 }
548
549 #[allow(unsafe_code)]
551 pub fn resize(&mut self, cols: u16, rows: u16) -> Result<()> {
552 let winsize = libc::winsize {
553 ws_row: rows,
554 ws_col: cols,
555 ws_xpixel: 0,
556 ws_ypixel: 0,
557 };
558
559 let result = unsafe {
562 libc::ioctl(
563 *self.inner.get_ref(),
564 libc::TIOCSWINSZ as libc::c_ulong,
565 &winsize,
566 )
567 };
568
569 if result != 0 {
570 Err(ExpectError::Io(io::Error::last_os_error()))
571 } else {
572 self.dimensions = (cols, rows);
573 Ok(())
574 }
575 }
576
577 #[allow(unsafe_code)]
579 pub fn signal(&self, signal: i32) -> Result<()> {
580 let result = unsafe { libc::kill(self.pid as i32, signal) };
582 if result != 0 {
583 Err(ExpectError::Io(io::Error::last_os_error()))
584 } else {
585 Ok(())
586 }
587 }
588
589 pub fn kill(&self) -> Result<()> {
591 self.signal(libc::SIGKILL)
592 }
593}
594
595#[cfg(unix)]
596impl AsyncRead for AsyncPty {
597 #[allow(unsafe_code)]
598 fn poll_read(
599 self: Pin<&mut Self>,
600 cx: &mut Context<'_>,
601 buf: &mut ReadBuf<'_>,
602 ) -> Poll<io::Result<()>> {
603 loop {
604 let mut guard = match self.inner.poll_read_ready(cx) {
605 Poll::Ready(Ok(guard)) => guard,
606 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
607 Poll::Pending => return Poll::Pending,
608 };
609
610 let fd = *self.inner.get_ref();
611 let unfilled = buf.initialize_unfilled();
612
613 let result = unsafe {
615 libc::read(
616 fd,
617 unfilled.as_mut_ptr().cast::<libc::c_void>(),
618 unfilled.len(),
619 )
620 };
621
622 if result >= 0 {
623 buf.advance(result as usize);
624 return Poll::Ready(Ok(()));
625 }
626
627 let err = io::Error::last_os_error();
628 if err.kind() == io::ErrorKind::WouldBlock {
629 guard.clear_ready();
630 continue;
631 }
632 return Poll::Ready(Err(err));
633 }
634 }
635}
636
637#[cfg(unix)]
638impl AsyncWrite for AsyncPty {
639 #[allow(unsafe_code)]
640 fn poll_write(
641 self: Pin<&mut Self>,
642 cx: &mut Context<'_>,
643 buf: &[u8],
644 ) -> Poll<io::Result<usize>> {
645 loop {
646 let mut guard = match self.inner.poll_write_ready(cx) {
647 Poll::Ready(Ok(guard)) => guard,
648 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
649 Poll::Pending => return Poll::Pending,
650 };
651
652 let fd = *self.inner.get_ref();
653
654 let result = unsafe { libc::write(fd, buf.as_ptr().cast::<libc::c_void>(), buf.len()) };
656
657 if result >= 0 {
658 return Poll::Ready(Ok(result as usize));
659 }
660
661 let err = io::Error::last_os_error();
662 if err.kind() == io::ErrorKind::WouldBlock {
663 guard.clear_ready();
664 continue;
665 }
666 return Poll::Ready(Err(err));
667 }
668 }
669
670 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
671 Poll::Ready(Ok(()))
673 }
674
675 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
676 Poll::Ready(Ok(()))
678 }
679}
680
681#[cfg(unix)]
682impl Drop for AsyncPty {
683 #[allow(unsafe_code)]
684 fn drop(&mut self) {
685 unsafe {
687 libc::close(*self.inner.get_ref());
688 }
689 }
690}
691
692#[cfg(unix)]
693impl std::fmt::Debug for AsyncPty {
694 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
695 f.debug_struct("AsyncPty")
696 .field("fd", self.inner.get_ref())
697 .field("pid", &self.pid)
698 .field("dimensions", &self.dimensions)
699 .finish()
700 }
701}
702
/// Async view of a Windows ConPTY, built from a [`WindowsPtyHandle`].
///
/// `AsyncRead`/`AsyncWrite` forward to the `rust_pty` master.
#[cfg(windows)]
pub struct WindowsAsyncPty {
    /// ConPTY master; all async I/O and resizing go through it.
    master: rust_pty::WindowsPtyMaster,
    /// The spawned child process.
    child: rust_pty::WindowsPtyChild,
    /// Child pid cached at construction time (from `child.pid()`).
    pid: u32,
    /// Last applied terminal size as `(cols, rows)`.
    dimensions: (u16, u16),
}
718
#[cfg(windows)]
impl WindowsAsyncPty {
    /// Converts a [`WindowsPtyHandle`] into an async PTY, caching the child's pid.
    pub fn from_handle(handle: WindowsPtyHandle) -> Self {
        let WindowsPtyHandle {
            master,
            child,
            dimensions,
        } = handle;
        let pid = child.pid();

        Self {
            master,
            child,
            pid,
            dimensions,
        }
    }

    /// Child pid cached when this value was created.
    #[must_use]
    pub const fn pid(&self) -> u32 {
        self.pid
    }

    /// Last successfully applied terminal size as `(cols, rows)`.
    #[must_use]
    pub const fn dimensions(&self) -> (u16, u16) {
        self.dimensions
    }

    /// Resizes the ConPTY window.
    ///
    /// The cached dimensions are updated only if the resize succeeds.
    ///
    /// # Errors
    ///
    /// Returns [`ExpectError::Io`] wrapping the `rust_pty` error text.
    pub fn resize(&mut self, cols: u16, rows: u16) -> Result<()> {
        use rust_pty::{PtyMaster, WindowSize};

        if let Err(e) = self.master.resize(WindowSize::new(cols, rows)) {
            return Err(ExpectError::Io(io::Error::other(format!("resize failed: {e}"))));
        }
        self.dimensions = (cols, rows);
        Ok(())
    }

    /// Reports whether the child process is still running, as determined by `rust_pty`.
    #[must_use]
    pub fn is_running(&self) -> bool {
        self.child.is_running()
    }

    /// Terminates the child process.
    ///
    /// # Errors
    ///
    /// Returns [`ExpectError::Io`] wrapping the `rust_pty` error text.
    pub fn kill(&mut self) -> Result<()> {
        let outcome = self.child.kill();
        outcome.map_err(|err| ExpectError::Io(io::Error::other(format!("kill failed: {err}"))))
    }
}
771
#[cfg(windows)]
impl AsyncRead for WindowsAsyncPty {
    /// Delegates the read to the ConPTY master.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.master).poll_read(cx, buf)
    }
}
783
#[cfg(windows)]
impl AsyncWrite for WindowsAsyncPty {
    /// Delegates the write to the ConPTY master.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.master).poll_write(cx, buf)
    }

    /// Delegates the flush to the ConPTY master.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.master).poll_flush(cx)
    }

    /// Delegates the shutdown to the ConPTY master.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.master).poll_shutdown(cx)
    }
}
802
#[cfg(windows)]
impl std::fmt::Debug for WindowsAsyncPty {
    // Shows pid and dimensions; the rust_pty master/child are elided as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("WindowsAsyncPty");
        builder.field("pid", &self.pid);
        builder.field("dimensions", &self.dimensions);
        builder.finish_non_exhaustive()
    }
}
812
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that a spawn attempt failed with an error mentioning a null byte.
    #[cfg(unix)]
    fn assert_null_byte_rejection(result: Result<PtyHandle>) {
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(
            message.contains("null byte"),
            "Expected error about null byte, got: {message}"
        );
    }

    #[test]
    fn pty_config_default() {
        let PtyConfig {
            dimensions: (cols, rows),
            env_mode,
            ..
        } = PtyConfig::default();

        assert_eq!(cols, 80);
        assert_eq!(rows, 24);
        assert_eq!(env_mode, EnvMode::Inherit);
    }

    #[test]
    fn pty_config_from_session() {
        let session_config = SessionConfig {
            dimensions: (120, 40),
            ..Default::default()
        };

        let (cols, rows) = PtyConfig::from(&session_config).dimensions;
        assert_eq!(cols, 120);
        assert_eq!(rows, 40);
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn spawn_rejects_null_byte_in_command() {
        let result = PtySpawner::new().spawn("test\0command", &[]).await;
        assert_null_byte_rejection(result);
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn spawn_rejects_null_byte_in_args() {
        let args = ["hello\0world".to_string()];
        let result = PtySpawner::new().spawn("/bin/echo", &args).await;
        assert_null_byte_rejection(result);
    }
}