// src/events/socket.rs
// Unix socket listener for per-session event delivery
4use std::sync::{atomic::{AtomicBool, Ordering}, Arc};
5
6use tokio::io::AsyncReadExt;
7use tokio::net::UnixListener;
8
9use super::queue::EventQueue;
10use super::types::Event;
11
12const MAX_PAYLOAD: usize = 256 * 1024; // 256KB
13
14/// Remove socket file if it exists. Best-effort, never panics.
15/// Sockets now live in ~/.synaps-cli/run/ (mode 0700), so symlink
16/// attacks from other users are not possible. Still refuse symlinks
17/// as defense-in-depth.
18pub fn cleanup_socket(socket_path: &str) {
19    let path = std::path::Path::new(socket_path);
20    #[cfg(unix)]
21    {
22        if let Ok(meta) = std::fs::symlink_metadata(path) {
23            if meta.file_type().is_symlink() {
24                tracing::warn!("socket: refusing to remove symlink at {}", socket_path);
25                return;
26            }
27        }
28    }
29    match std::fs::remove_file(path) {
30        Ok(()) => {}
31        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
32        Err(e) => tracing::warn!("socket: failed to remove {}: {}", socket_path, e),
33    }
34}
35
36/// Bind a Unix socket at `socket_path`, accept connections, parse incoming
37/// events, push to `queue`. Runs until `shutdown` is set.
38///
39/// Protocol: client connects → sends full JSON event → closes connection.
40/// One event per connection. Max payload 256KB.
41pub fn listen_session_socket(
42    socket_path: String,
43    queue: Arc<EventQueue>,
44    shutdown: Arc<AtomicBool>,
45) -> tokio::task::JoinHandle<()> {
46    tokio::spawn(async move {
47        // Remove stale socket from a previous crash
48        cleanup_socket(&socket_path);
49
50        let listener = match UnixListener::bind(&socket_path) {
51            Ok(l) => l,
52            Err(e) => {
53                tracing::error!("socket: failed to bind {}: {}", socket_path, e);
54                return;
55            }
56        };
57
58        // Lock down the socket — session traffic only
59        #[cfg(unix)]
60        {
61            use std::os::unix::fs::PermissionsExt;
62            if let Ok(meta) = std::fs::metadata(&socket_path) {
63                let mut perms = meta.permissions();
64                perms.set_mode(0o600);
65                let _ = std::fs::set_permissions(&socket_path, perms);
66            }
67        }
68
69        tracing::info!("socket: listening on {}", socket_path);
70
71        loop {
72            if shutdown.load(Ordering::Acquire) {
73                break;
74            }
75
76            // Poll accept with a timeout so we can check shutdown periodically
77            let accept = tokio::time::timeout(
78                std::time::Duration::from_millis(500),
79                listener.accept(),
80            );
81
82            match accept.await {
83                Ok(Ok((mut stream, _addr))) => {
84                    let queue = queue.clone();
85                    tokio::spawn(async move {
86                        // 5s timeout prevents slow-send DoS from parking tasks indefinitely
87                        let _ = tokio::time::timeout(
88                            std::time::Duration::from_secs(5),
89                            handle_connection(&mut stream, &queue),
90                        ).await;
91                    });
92                }
93                Ok(Err(e)) => {
94                    tracing::warn!("socket: accept error: {}", e);
95                }
96                Err(_) => {
97                    // Timeout — loop around and check shutdown flag
98                }
99            }
100        }
101
102        cleanup_socket(&socket_path);
103        tracing::info!("socket: shut down, removed {}", socket_path);
104    })
105}
106
107async fn handle_connection(
108    stream: &mut tokio::net::UnixStream,
109    queue: &EventQueue,
110) {
111    // Read up to MAX_PAYLOAD + 1 so we can detect oversized payloads
112    let mut buf = Vec::with_capacity(4096);
113    let mut chunk = [0u8; 8192];
114
115    loop {
116        match stream.read(&mut chunk).await {
117            Ok(0) => break, // EOF — client closed connection
118            Ok(n) => {
119                if buf.len() + n > MAX_PAYLOAD {
120                    tracing::warn!(
121                        "socket: payload exceeds {}KB limit, dropping connection",
122                        MAX_PAYLOAD / 1024
123                    );
124                    return;
125                }
126                buf.extend_from_slice(&chunk[..n]);
127            }
128            Err(e) => {
129                tracing::warn!("socket: read error: {}", e);
130                return;
131            }
132        }
133    }
134
135    if buf.is_empty() {
136        return;
137    }
138
139    match serde_json::from_slice::<Event>(&buf) {
140        Ok(event) => {
141            tracing::info!(
142                "socket: event {} from {}",
143                event.id,
144                event.source.source_type
145            );
146            if let Err(e) = queue.push(event) {
147                tracing::warn!("socket: queue push failed: {}", e);
148            }
149        }
150        Err(e) => {
151            tracing::warn!("socket: invalid JSON payload: {}", e);
152        }
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::events::{Event, Severity};
    use std::sync::atomic::AtomicBool;
    use tokio::io::AsyncWriteExt;
    use tokio::net::UnixStream;

    /// Unique throwaway socket path so parallel tests never collide.
    fn tmp_socket_path() -> String {
        format!(
            "/tmp/test-session-socket-{}.sock",
            uuid::Uuid::new_v4().simple()
        )
    }

    /// Poll (up to ~1s) until the listener has created its socket file.
    async fn wait_for_socket(path: &str) {
        for _ in 0..50 {
            if std::path::Path::new(path).exists() {
                return;
            }
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        }
        panic!("socket never appeared at {}", path);
    }

    /// Connect to `path`, write `bytes` as one payload, close the write side.
    async fn send_bytes(path: &str, bytes: &[u8]) {
        let mut client = UnixStream::connect(path).await.unwrap();
        client.write_all(bytes).await.unwrap();
        client.shutdown().await.unwrap();
    }

    /// Poll (up to ~1s) for the queue to become non-empty; the caller's
    /// asserts catch the failure case, so this never panics.
    async fn wait_for_event(queue: &EventQueue) {
        for _ in 0..50 {
            if queue.len() > 0 {
                return;
            }
            tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        }
    }

    #[tokio::test]
    async fn delivers_event_to_queue() {
        let path = tmp_socket_path();
        let queue = Arc::new(EventQueue::new(10));
        let shutdown = Arc::new(AtomicBool::new(false));

        let listener = listen_session_socket(path.clone(), queue.clone(), shutdown.clone());
        wait_for_socket(&path).await;

        let event = Event::simple("test", "hello socket", Some(Severity::High));
        send_bytes(&path, &serde_json::to_vec(&event).unwrap()).await;

        wait_for_event(&queue).await;

        shutdown.store(true, Ordering::Release);
        listener.await.unwrap();

        let popped = queue.pop().expect("event should be in queue");
        assert_eq!(popped.content.text, "hello socket");
        assert_eq!(popped.source.source_type, "test");
    }

    #[tokio::test]
    async fn rejects_oversized_payload() {
        let path = tmp_socket_path();
        let queue = Arc::new(EventQueue::new(10));
        let shutdown = Arc::new(AtomicBool::new(false));

        let listener = listen_session_socket(path.clone(), queue.clone(), shutdown.clone());
        wait_for_socket(&path).await;

        // 257KB of junk — over the limit
        send_bytes(&path, &vec![b'x'; MAX_PAYLOAD + 1024]).await;

        tokio::time::sleep(std::time::Duration::from_millis(200)).await;

        shutdown.store(true, Ordering::Release);
        listener.await.unwrap();

        assert_eq!(queue.len(), 0, "oversized payload should not reach queue");
    }

    #[tokio::test]
    async fn invalid_json_does_not_crash() {
        let path = tmp_socket_path();
        let queue = Arc::new(EventQueue::new(10));
        let shutdown = Arc::new(AtomicBool::new(false));

        let listener = listen_session_socket(path.clone(), queue.clone(), shutdown.clone());
        wait_for_socket(&path).await;

        send_bytes(&path, b"this is not json at all").await;

        tokio::time::sleep(std::time::Duration::from_millis(200)).await;

        // Send a valid event after — proves listener is still running
        let event = Event::simple("test", "still alive", None);
        send_bytes(&path, &serde_json::to_vec(&event).unwrap()).await;

        wait_for_event(&queue).await;

        shutdown.store(true, Ordering::Release);
        listener.await.unwrap();

        assert_eq!(queue.len(), 1);
        assert_eq!(queue.pop().unwrap().content.text, "still alive");
    }

    #[tokio::test]
    async fn stale_socket_removed_on_startup() {
        let path = tmp_socket_path();
        // Plant a stale file
        std::fs::write(&path, b"stale").unwrap();
        assert!(std::path::Path::new(&path).exists());

        let queue = Arc::new(EventQueue::new(10));
        let shutdown = Arc::new(AtomicBool::new(false));

        // Should not panic — bind replaces the stale file
        let listener = listen_session_socket(path.clone(), queue.clone(), shutdown.clone());
        wait_for_socket(&path).await;

        shutdown.store(true, Ordering::Release);
        listener.await.unwrap();
    }
}
289}