shell_server/
server.rs

1use std::{
2    io::Read,
3    os::{
4        fd::AsRawFd,
5        unix::net::{UnixListener, UnixStream},
6    },
7    thread::{spawn, JoinHandle},
8};
9
10use libc::{c_int, close, dup, dup2, STDOUT_FILENO};
11use shell_core::{read_line, write_line};
12
13use crate::shell::Shell;
14
15/// 一个服务器,侦听传入的 Unix 域套接字 (UDS) 连接并处理命令。
16pub struct Server {
17    /// 要在服务器上执行的 shell 实例。
18    shell: Shell,
19
20    /// Unix 域套接字 (UDS) 路径,用于侦听命令。
21    uds_cmd_path: String,
22
23    /// Unix 域套接字 (UDS) 路径,用于侦听输出。
24    uds_output_path: String,
25}
26
27/// 实现 Drop trait,以便在 Server 实例被丢弃时删除 Unix 域套接字 (UDS) 文件。
28impl Drop for Server {
29    /// 当 Server 实例被丢弃时,此函数将被调用。
30    /// 它将删除 `uds_cmd_path` 和 `uds_output_path` 所指向的 Unix 域套接字 (UDS) 文件。
31    fn drop(&mut self) {
32        let _ = std::fs::remove_file(&self.uds_cmd_path);
33        let _ = std::fs::remove_file(&self.uds_output_path);
34    }
35}
36
37impl Server {
38    /// 创建一个新的 Server 实例。
39    ///
40    /// # Arguments
41    ///
42    /// * `shell_` - 要在服务器上执行的 shell 实例。
43    /// * `uds_cmd_path_` - Unix 域套接字 (UDS) 路径,用于侦听命令。
44    /// * `uds_output_path_` - Unix 域套接字 (UDS) 路径,用于侦听输出。
45    ///
46    /// # Returns
47    ///
48    /// 一个新的 Server 实例。
49    pub fn new(shell_: Shell, uds_cmd_path_: String, uds_output_path_: String) -> Server {
50        Server {
51            shell: shell_,
52            uds_cmd_path: uds_cmd_path_,
53            uds_output_path: uds_output_path_,
54        }
55    }
56
57    fn handle_cmd_connect(mut conn: UnixStream, shell: Shell) -> Result<(), String> {
58        write_line(&mut conn, &shell.get_reg_commands().join(" "))?;
59        loop {
60            let s = read_line(&mut conn)?;
61            if let Err(err) = shell.run_command(&s) {
62                println!("Error: {}", err);
63            }
64        }
65    }
66
67    fn cmd_thread(path: &String, shell: &Shell) -> Result<(), String> {
68        let server = UnixListener::bind(path).map_err(|err| format!("bind err: {:?}", err))?;
69        while let Ok(conn) = server.incoming().next().ok_or("listen err")? {
70            spawn({
71                let conn_copy = conn
72                    .try_clone()
73                    .map_err(|err| format!("clone err: {:?}", err))?;
74                let shell_copy = shell.clone();
75                move || {
76                    if let Err(err) = Server::handle_cmd_connect(conn_copy, shell_copy)
77                        .map_err(|err| format!("handle cmd connect err: {:?}", err))
78                    {
79                        println!("handle cmd connect err: {}", err);
80                    }
81                }
82            });
83        }
84
85        Ok(())
86    }
87
88    fn redirect_stdout_to_unix_stream(stream: &UnixStream) -> c_int {
89        let original_fd = unsafe { dup(STDOUT_FILENO) }; // 保存原始stdout的文件描述符
90        let ret = original_fd;
91
92        let stream_fd = stream.as_raw_fd(); // 获取UnixStream的文件描述符
93        unsafe { dup2(stream_fd, STDOUT_FILENO) }; // 将stdout的文件描述符重定向到UnixStream
94
95        ret
96    }
97
98    fn restore_stdout(old: c_int) {
99        unsafe { dup2(old, STDOUT_FILENO) }; // 恢复原始stdout的文件描述符
100        unsafe { close(old) }; // 关闭原始文件描述符
101    }
102
103    fn output_thread(path: &String) -> Result<(), String> {
104        let mut future: Option<JoinHandle<()>> = None;
105        let mut old_conn: Option<UnixStream> = None;
106        let server = UnixListener::bind(path).map_err(|err| format!("bind err: {:?}", err))?;
107        while let Ok(conn) = server.incoming().next().ok_or("listen err")? {
108            if let Some(o) = old_conn.take() {
109                drop(o);
110                future.take().unwrap().join().unwrap();
111            }
112
113            old_conn = Some(conn.try_clone().map_err(|err| err.to_string())?);
114
115            let mut conn_copy = conn.try_clone().map_err(|err| err.to_string())?;
116
117            let old_stdout = Server::redirect_stdout_to_unix_stream(&conn);
118
119            future = Some(spawn(move || {
120                let mut buf = String::new();
121                let _ = conn_copy
122                    .read_to_string(&mut buf)
123                    .map_err(|err| err.to_string());
124                Server::restore_stdout(old_stdout);
125            }));
126        }
127
128        Ok(())
129    }
130
131    /// 在 Server 实例上运行命令并处理输出。
132    ///
133    /// # Returns
134    ///
135    /// 运行结果。
136    ///
137    /// # Errors
138    ///
139    /// 如果命令线程或输出线程返回错误,则返回包含该错误的 Result。
140    pub fn run(&mut self) -> Result<(), String> {
141        let uds_cmd_path = self.uds_cmd_path.clone();
142        let uds_output_path = self.uds_output_path.clone();
143
144        let shell_copy = self.shell.clone();
145
146        let command_thread = spawn(move || Server::cmd_thread(&uds_cmd_path, &shell_copy));
147        let output_thread = spawn(move || Server::output_thread(&uds_output_path));
148
149        let _ = command_thread
150            .join()
151            .map_err(|err| format!("run command err: {:?}", err))?;
152        let _ = output_thread
153            .join()
154            .map_err(|err| format!("run output err: {:?}", err))?;
155        Ok(())
156    }
157}