tiny_std/unix/misc/
openpty.rs

use rusl::ioctl::ioctl;
use rusl::platform::{Fd, OpenFlags, SetAction, TermioFlags, Termios, WindowSize};
use rusl::string::unix_str::UnixStr;
use rusl::termios::tcsetattr;
use rusl::unistd::{close, open, open_raw};
6
7#[derive(Debug, Copy, Clone)]
8pub struct TerminalHandle {
9    pub master: Fd,
10    pub slave: Fd,
11}
12
13/// Attempts to open a pty returning the handles to it
14/// # Errors
15/// Not many errors can occur assuming that (`None`, `None`, `None`) is passed and you have
16/// appropriate permissions.
17/// See the [linux docs for the exceptions](https://man7.org/linux/man-pages/man2/ioctl_tty.2.html)
18pub fn openpty(
19    name: Option<&UnixStr>,
20    termios: Option<&Termios>,
21    winsize: Option<&WindowSize>,
22) -> crate::error::Result<TerminalHandle> {
23    const PTMX: &UnixStr = UnixStr::from_str_checked("/dev/ptmx\0");
24    let use_flags: OpenFlags = OpenFlags::O_RDWR | OpenFlags::O_NOCTTY;
25    unsafe {
26        let master = open(PTMX, use_flags)?;
27        let mut pty_num = 0;
28        let pty_num_addr = core::ptr::addr_of_mut!(pty_num);
29        // Todo: Maybe check if not zero and bail like musl does
30        ioctl(
31            master,
32            TermioFlags::TIOCSPTLCK.bits(),
33            pty_num_addr as usize,
34        )?;
35        ioctl(master, TermioFlags::TIOCGPTN.bits(), pty_num_addr as usize)?;
36        let slave = if let Some(name) = name {
37            open(name, use_flags)?
38        } else {
39            let bytename: u8 = pty_num.try_into().map_err(|_| {
40                crate::error::Error::no_code("Terminal number exceeded u8::MAX or was negative")
41            })?;
42            // To do this without an allocator have to format this string manually
43            // on the stack.
44            let name = create_pty_name(bytename);
45            open_raw(core::ptr::addr_of!(name) as usize, use_flags)?
46        };
47        if let Some(tio) = termios {
48            tcsetattr(slave, SetAction::NOW, tio)?;
49        }
50        if let Some(winsize) = winsize {
51            ioctl(
52                slave,
53                TermioFlags::TIOCSWINSZ.bits(),
54                core::ptr::addr_of!(winsize) as usize,
55            )?;
56        }
57        Ok(TerminalHandle { master, slave })
58    }
59}
60
/// The ASCII decimal digits of a `u8`, most significant first and without
/// leading zeros; the variant records how many digits there are.
#[derive(Clone, Copy, Debug)]
enum ByteChars {
    /// A single digit (value below 10).
    One(u8),
    /// Two digits (value in `10..100`).
    Two([u8; 2]),
    /// Three digits (value of 100 or more).
    Three([u8; 3]),
}
67
68#[inline]
69fn create_pty_name(pty_num: u8) -> [u8; 13] {
70    let mut name = *b"/dev/pts/0\0\0\0";
71    match get_chars(pty_num) {
72        ByteChars::One(byte) => {
73            name[9] = byte;
74            name
75        }
76        ByteChars::Two([b1, b2]) => {
77            name[9] = b1;
78            name[10] = b2;
79            name
80        }
81        ByteChars::Three([b1, b2, b3]) => {
82            name[9] = b1;
83            name[10] = b2;
84            name[11] = b3;
85            name
86        }
87    }
88}
89
90fn get_chars(num: u8) -> ByteChars {
91    if num < 10 {
92        ByteChars::One(num + 48)
93    } else if num < 100 {
94        let rem = num % 10;
95        let base = num / 10;
96        ByteChars::Two([base + 48, rem + 48])
97    } else {
98        let base = num / 100;
99        let next_base = num - base * 100;
100        let nb = next_base / 10;
101        let rem = next_base % 10;
102        ByteChars::Three([base + 48, nb + 48, rem + 48])
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rewrite_name() {
        // Inclusive range: `0..u8::MAX` would never exercise 255.
        for i in 0..=u8::MAX {
            check_name(i);
        }
    }

    /// Asserts that `create_pty_name(num)` equals the `format!`-built path,
    /// NUL-padded to the 13-byte buffer size.
    fn check_name(num: u8) {
        let formatted = std::format!("/dev/pts/{num}");
        let mut expected = [0u8; 13];
        expected[..formatted.len()].copy_from_slice(formatted.as_bytes());
        assert_eq!(expected, create_pty_name(num));
    }

    #[test]
    fn rewrite_single() {
        let bc = get_chars(8);
        if let ByteChars::One(b) = bc {
            assert_eq!('8', b as char);
        } else {
            panic!("Bad match");
        }
    }

    #[test]
    fn rewrite_double() {
        let bc = get_chars(59);
        if let ByteChars::Two([c1, c2]) = bc {
            assert_eq!('5', c1 as char);
            assert_eq!('9', c2 as char);
        } else {
            panic!("Bad match");
        }
    }

    #[test]
    fn rewrite_triple() {
        let bc = get_chars(231);
        if let ByteChars::Three([c1, c2, c3]) = bc {
            assert_eq!('2', c1 as char);
            assert_eq!('3', c2 as char);
            assert_eq!('1', c3 as char);
        } else {
            panic!("Bad match");
        }
    }

    /// Checks the digit-count transitions, including both ends of the u8 range.
    #[test]
    fn rewrite_boundaries() {
        assert!(matches!(get_chars(0), ByteChars::One(b'0')));
        assert!(matches!(get_chars(9), ByteChars::One(b'9')));
        assert!(matches!(get_chars(10), ByteChars::Two([b'1', b'0'])));
        assert!(matches!(get_chars(99), ByteChars::Two([b'9', b'9'])));
        assert!(matches!(get_chars(100), ByteChars::Three([b'1', b'0', b'0'])));
        assert!(matches!(get_chars(255), ByteChars::Three([b'2', b'5', b'5'])));
    }
}