rust_network_mgr/
socket.rs

use crate::types::{AppError, ControlCommand, ControlCommandSender, Result};
use directories::ProjectDirs; // Changed from BaseDirs to ProjectDirs for runtime path
use std::path::{Path, PathBuf};
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::{UnixListener, UnixStream};
6
7const SOCKET_FILE: &str = "rust-network-manager.sock";
8
9/// Gets the recommended path for the Unix domain socket.
10/// Prefers /run/ RUST_PROJECT_NAME if possible, otherwise falls back to /tmp.
11fn get_socket_path(config_path: Option<&str>) -> Result<PathBuf> {
12    if let Some(path_str) = config_path {
13        return Ok(PathBuf::from(path_str));
14    }
15
16    // Try to use /run/rust-network-manager/
17    let run_dir = Path::new("/run");
18    if run_dir.exists() && run_dir.is_dir() {
19        // Basic check for write access might be needed in a real scenario
20        let app_run_dir = run_dir.join("rust-network-manager");
21        if std::fs::create_dir_all(&app_run_dir).is_ok() {
22             // Set appropriate permissions if needed (e.g., only accessible by root/group)
23             // For simplicity, we don't do this here.
24             return Ok(app_run_dir.join(SOCKET_FILE));
25        }
26    }
27    
28    // Fallback using ProjectDirs (might place it in user's runtime dir)
29    if let Some(proj_dirs) = ProjectDirs::from("", "", "RustNetworkManager") { // qualifier, organization, application
30        if let Some(runtime_dir) = proj_dirs.runtime_dir() {
31            if std::fs::create_dir_all(runtime_dir).is_ok() {
32                return Ok(runtime_dir.join(SOCKET_FILE));
33            }
34        }
35    }
36
37    // Absolute fallback to /tmp (less ideal for system services)
38    Ok(Path::new("/tmp").join(SOCKET_FILE))
39}
40
41pub struct SocketHandler {
42    listener: UnixListener,
43    command_sender: ControlCommandSender,
44}
45
46impl SocketHandler {
47    pub async fn new(config_socket_path: Option<&str>, command_sender: ControlCommandSender) -> Result<Self> {
48        let socket_path = get_socket_path(config_socket_path)?;
49        tracing::info!("Attempting to bind control socket at: {:?}", socket_path);
50
51        // Ensure the path exists and clean up old socket if present
52        if socket_path.exists() {
53            tracing::warn!("Existing socket file found at {:?}. Removing.", socket_path);
54            std::fs::remove_file(&socket_path)
55                .map_err(|e| AppError::Init(format!("Failed to remove old socket: {}", e)))?;
56        }
57        if let Some(parent) = socket_path.parent() {
58             if !parent.exists() {
59                 std::fs::create_dir_all(parent).map_err(|e| AppError::Init(format!("Failed to create socket directory: {}", e)))?;
60             }
61        }
62
63        let listener = UnixListener::bind(&socket_path)
64            .map_err(|e| AppError::Socket(e))?;
65
66        // TODO: Set permissions on the socket file (e.g., only allow specific user/group)
67
68        tracing::info!("Control socket listening at: {:?}", socket_path);
69        Ok(SocketHandler { listener, command_sender })
70    }
71
72    pub async fn start(self) {
73        tracing::info!("Starting socket command listener loop...");
74        loop {
75            match self.listener.accept().await {
76                Ok((stream, _addr)) => {
77                    tracing::debug!("Accepted new socket connection");
78                    let sender = self.command_sender.clone();
79                    tokio::spawn(async move {
80                        if let Err(e) = Self::handle_connection(stream, sender).await {
81                            tracing::error!("Error handling socket connection: {}", e);
82                        }
83                    });
84                }
85                Err(e) => {
86                    tracing::error!("Failed to accept socket connection: {}. Stopping listener.", e);
87                    break; // Stop listening on error
88                }
89            }
90        }
91    }
92
93    async fn handle_connection(stream: UnixStream, sender: ControlCommandSender) -> Result<()> {
94        let mut reader = BufReader::new(stream);
95        let mut line = String::new();
96
97        loop {
98            line.clear();
99            match reader.read_line(&mut line).await {
100                Ok(0) => { // Connection closed
101                    tracing::debug!("Socket connection closed by peer.");
102                    break;
103                }
104                Ok(_) => {
105                    let command_str = line.trim();
106                    tracing::info!("Received command via socket: {}", command_str);
107                    let command = match command_str {
108                        "reload" => Some(ControlCommand::Reload),
109                        "status" => Some(ControlCommand::Status),
110                        "ping" => Some(ControlCommand::Ping),
111                        "shutdown" => Some(ControlCommand::Shutdown),
112                        _ => {
113                            tracing::warn!("Received unknown command: {}", command_str);
114                             let stream_ref = reader.get_mut(); // Get ref to write response
115                             stream_ref.write_all(b"ERROR: Unknown command\n").await?;
116                            None
117                        }
118                    };
119
120                    if let Some(cmd) = command {
121                         let stream_ref = reader.get_mut(); // Get ref to write response
122                         match sender.send(cmd.clone()).await {
123                            Ok(_) => {
124                                tracing::debug!("Sent command {:?} to main loop", cmd);
125                                // Simple ACK for most commands
126                                let response_str: &'static str = match cmd {
127                                    ControlCommand::Ping => "PONG\n",
128                                    ControlCommand::Status => "STATUS command received (response handled by main loop)\n", // Status response is async
129                                    _ => "OK\n",
130                                };
131                                stream_ref.write_all(response_str.as_bytes()).await?;
132                                if matches!(cmd, ControlCommand::Shutdown) {
133                                     tracing::info!("Shutdown command received, closing connection.");
134                                     break; // Close connection after shutdown cmd
135                                }
136                            }
137                            Err(e) => {
138                                tracing::error!("Failed to send command {:?} to main loop: {}", cmd, e);
139                                stream_ref.write_all(b"ERROR: Failed to process command internally\n").await?;
140                            }
141                         }
142                    }
143                }
144                Err(e) => { // Read error
145                    tracing::error!("Error reading from socket: {}", e);
146                    break;
147                }
148            }
149        }
150        Ok(())
151    }
152}
153
154// Note: Testing socket interaction often requires integration tests or mocking frameworks.
155// Basic unit tests might focus on command parsing if extracted.