Skip to main content

sidedns_core/ipc/
server.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use interprocess::local_socket::{GenericFilePath, ListenerOptions, ToFsName, tokio::prelude::*};
5use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
6use tokio::sync::broadcast;
7use tokio_util::sync::CancellationToken;
8
9use super::message::{DnsEvent, IpcRequest, IpcResponse};
10use anyhow::Result;
11
12/// Implemented by the daemon's shared state to handle incoming IPC requests.
13///
14/// The implementation receives decoded requests and returns encoded responses.
15/// It also provides a broadcast receiver for the event stream used by
16/// [`IpcServer`] to forward events to subscribed clients.
17#[async_trait]
18pub trait IpcHandler: Send + Sync + 'static {
19    /// Process a single request and return a response.
20    async fn handle(&self, request: IpcRequest) -> IpcResponse;
21
22    /// Subscribe to the internal event broadcast channel.
23    ///
24    /// Each call returns an independent receiver. The server creates one per
25    /// subscriber connection.
26    fn subscribe_events(&self) -> broadcast::Receiver<DnsEvent>;
27}
28
/// Listens on a local socket and dispatches connections to an [`IpcHandler`].
///
/// Uses a Unix domain socket on Unix platforms and a named pipe on Windows.
/// [`IpcServer::serve`] runs until its [`CancellationToken`] is cancelled.
/// On Unix, the socket file is removed when the server is dropped.
#[derive(Debug)]
pub struct IpcServer {
    /// Path the listener binds to: a socket file on Unix, a named-pipe path on Windows.
    socket_path: String,
}
36
37impl Default for IpcServer {
38    fn default() -> Self {
39        Self {
40            socket_path: crate::IPC_SOCKET_PATH.to_string(),
41        }
42    }
43}
44
45impl Drop for IpcServer {
46    fn drop(&mut self) {
47        #[cfg(not(windows))]
48        if std::path::Path::new(&self.socket_path).exists() {
49            if let Err(e) = std::fs::remove_file(&self.socket_path) {
50                tracing::error!("Failed to remove IpcServer socket file: {e}");
51            } else {
52                tracing::info!("IpcServer socket file removed");
53            }
54        }
55    }
56}
57
58impl IpcServer {
59    /// Create a server bound to a custom socket path.
60    ///
61    /// Prefer [`IpcServer::default`] in production code.
62    /// This method exists primarily for test isolation.
63    pub fn with_path(path: impl Into<String>) -> Self {
64        Self {
65            socket_path: path.into(),
66        }
67    }
68
69    /// Start accepting connections until `token` is cancelled.
70    ///
71    /// Removes a stale socket file at startup on Unix. On shutdown,
72    /// the socket file is removed again so subsequent starts are clean.
73    ///
74    /// # Errors
75    ///
76    /// Returns [`IpcError`] if the socket cannot be bound or if an
77    /// unrecoverable I/O error occurs during the accept loop.
78    #[tracing::instrument(skip(self, handler, token), name = "IPC Server")]
79    pub async fn serve<H: IpcHandler>(
80        &self,
81        handler: Arc<H>,
82        token: CancellationToken,
83    ) -> Result<()> {
84        #[cfg(not(windows))]
85        {
86            if let Err(e) = std::fs::remove_file(&self.socket_path)
87                && e.kind() != std::io::ErrorKind::NotFound
88            {
89                return Err(e.into());
90            }
91        }
92
93        let name = self.socket_path.as_str().to_fs_name::<GenericFilePath>()?;
94        let options = ListenerOptions::new().name(name);
95
96        #[cfg(target_os = "linux")]
97        let options = {
98            use interprocess::os::unix::local_socket::ListenerOptionsExt;
99            options.mode(0o666)
100        };
101
102        #[cfg(target_os = "macos")]
103        {
104            use std::os::unix::fs::PermissionsExt;
105            if let Ok(metadata) = std::fs::metadata(&self.socket_path) {
106                let mut perms = metadata.permissions();
107                perms.set_mode(0o666);
108                let _ = std::fs::set_permissions(&self.socket_path, perms);
109            }
110        }
111
112        #[cfg(windows)]
113        let options = {
114            // On Windows, set a security descriptor that allows all users to connect to the named pipe.
115            // This is necessary for the CLI to work when run from an unelevated prompt.
116            // D:(A;;GA;;;BA)(A;;GA;;;SY)(A;;GA;;;AU)
117
118            use interprocess::os::windows::{
119                local_socket::ListenerOptionsExt, security_descriptor::SecurityDescriptor,
120            };
121            use widestring::u16cstr;
122            let sd =
123                SecurityDescriptor::deserialize(u16cstr!("D:(A;;GA;;;BA)(A;;GA;;;SY)(A;;GA;;;AU)"))
124                    .map_err(std::io::Error::other)?;
125            options.security_descriptor(sd)
126        };
127
128        let listener = options.create_tokio()?;
129
130        tracing::info!("Started");
131
132        let result = self.accept_loop(&listener, handler, token).await;
133
134        tracing::info!("Stopped");
135
136        result
137    }
138
139    async fn accept_loop<H: IpcHandler>(
140        &self,
141        listener: &LocalSocketListener,
142        handler: Arc<H>,
143        token: CancellationToken,
144    ) -> Result<()> {
145        loop {
146            tokio::select! {
147                biased;
148                _ = token.cancelled() => {
149                    tracing::info!("Shutdown requested, stopping...");
150                    break;
151                },
152                result = listener.accept() => {
153                    let conn = result?;
154                    let handler = handler.clone();
155                    let token = token.clone();
156
157                    tokio::spawn(async move {
158                        if let Err(e) = handle_connection(conn, handler, token).await {
159                            tracing::error!("Connection error: {e}");
160                        }
161                    });
162                }
163            }
164        }
165
166        Ok(())
167    }
168}
169
170#[tracing::instrument(skip(conn, handler, token), name = "IPC Server")]
171async fn handle_connection<H: IpcHandler>(
172    conn: LocalSocketStream,
173    handler: Arc<H>,
174    token: CancellationToken,
175) -> Result<()> {
176    let (reader, mut writer) = tokio::io::split(conn);
177    let mut lines = BufReader::new(reader).lines();
178
179    let Some(line) = lines.next_line().await? else {
180        return Ok(());
181    };
182
183    let request: IpcRequest = serde_json::from_str(&line)?;
184
185    if matches!(request, IpcRequest::Subscribe) {
186        handle_subscribe(&mut writer, handler, token).await
187    } else {
188        let response = handler.handle(request).await;
189        write_response(&mut writer, &response).await
190    }
191}
192
193async fn handle_subscribe<W>(
194    writer: &mut W,
195    handler: Arc<impl IpcHandler>,
196    token: CancellationToken,
197) -> Result<()>
198where
199    W: AsyncWriteExt + Unpin,
200{
201    let mut rx = handler.subscribe_events();
202
203    loop {
204        tokio::select! {
205            biased;
206            _ = token.cancelled() => break,
207            result = rx.recv() => {
208                match result {
209                    Ok(event) => {
210                        write_response(writer, &IpcResponse::Event(event)).await?;
211                    }
212                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
213                    Err(broadcast::error::RecvError::Closed) => break,
214                }
215            }
216        }
217    }
218
219    Ok(())
220}
221
222async fn write_response<W>(writer: &mut W, response: &IpcResponse) -> Result<()>
223where
224    W: AsyncWriteExt + Unpin,
225{
226    let mut payload = serde_json::to_string(response)?;
227    payload.push('\n');
228    writer.write_all(payload.as_bytes()).await?;
229    writer.flush().await?;
230    Ok(())
231}