Skip to main content

tmai_core/ipc/
server.rs

//! IPC server for the tmai parent process.
//!
//! Listens on a Unix domain socket for wrapper connections and maintains
//! a registry of connected wrapper states.
5
6use std::collections::HashMap;
7use std::os::unix::fs::PermissionsExt;
8use std::sync::Arc;
9
10use anyhow::{Context, Result};
11use parking_lot::RwLock;
12use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
13use tokio::net::UnixListener;
14use tokio::sync::mpsc;
15
16use crate::ipc::protocol::*;
17
18/// Registry of connected wrapper states, keyed by pane_id
19pub type IpcRegistry = Arc<RwLock<HashMap<String, WrapState>>>;
20
21/// Handle to a connected wrapper, allowing server-to-wrapper messaging
22struct ConnectionHandle {
23    pane_id: String,
24    tx: mpsc::Sender<ServerMessage>,
25}
26
27/// IPC server that manages wrapper connections
28pub struct IpcServer {
29    registry: IpcRegistry,
30    connections: Arc<RwLock<HashMap<String, ConnectionHandle>>>,
31}
32
33impl IpcServer {
34    /// Start the IPC server, binding to the Unix domain socket
35    pub async fn start() -> Result<Self> {
36        let registry: IpcRegistry = Arc::new(RwLock::new(HashMap::new()));
37        let connections: Arc<RwLock<HashMap<String, ConnectionHandle>>> =
38            Arc::new(RwLock::new(HashMap::new()));
39
40        // Ensure state directory exists with 0o700 permissions
41        ensure_state_dir()?;
42
43        // Clean up stale socket
44        let sock = socket_path();
45        if sock.exists() {
46            match tokio::net::UnixStream::connect(&sock).await {
47                Ok(_) => {
48                    anyhow::bail!(
49                        "Another tmai instance is already running (socket {} is active)",
50                        sock.display()
51                    );
52                }
53                Err(_) => {
54                    // Stale socket, safe to remove
55                    std::fs::remove_file(&sock).with_context(|| {
56                        format!("Failed to remove stale socket: {}", sock.display())
57                    })?;
58                }
59            }
60        }
61
62        let listener = UnixListener::bind(&sock).context("Failed to bind IPC Unix socket")?;
63
64        // Set socket permissions to owner-only
65        std::fs::set_permissions(&sock, std::fs::Permissions::from_mode(0o700))
66            .context("Failed to set socket permissions")?;
67
68        let server = Self {
69            registry: registry.clone(),
70            connections: connections.clone(),
71        };
72
73        // Spawn accept loop
74        tokio::spawn(async move {
75            Self::accept_loop(listener, registry, connections).await;
76        });
77
78        tracing::debug!("IPC server started on {}", sock.display());
79        Ok(server)
80    }
81
82    /// Accept loop for incoming wrapper connections
83    async fn accept_loop(
84        listener: UnixListener,
85        registry: IpcRegistry,
86        connections: Arc<RwLock<HashMap<String, ConnectionHandle>>>,
87    ) {
88        loop {
89            match listener.accept().await {
90                Ok((stream, _)) => {
91                    let registry = registry.clone();
92                    let connections = connections.clone();
93                    tokio::spawn(async move {
94                        if let Err(e) = Self::handle_connection(stream, registry, connections).await
95                        {
96                            tracing::debug!("IPC connection ended: {}", e);
97                        }
98                    });
99                }
100                Err(e) => {
101                    tracing::warn!("IPC accept error: {}", e);
102                }
103            }
104        }
105    }
106
107    /// Handle a single wrapper connection
108    async fn handle_connection(
109        stream: tokio::net::UnixStream,
110        registry: IpcRegistry,
111        connections: Arc<RwLock<HashMap<String, ConnectionHandle>>>,
112    ) -> Result<()> {
113        let (reader, mut writer) = stream.into_split();
114        let mut buf_reader = BufReader::new(reader);
115        let mut line_buf = String::new();
116
117        // First message must be Register
118        buf_reader.read_line(&mut line_buf).await?;
119        if line_buf.is_empty() {
120            anyhow::bail!("Connection closed before registration");
121        }
122        let first_msg: ClientMessage = decode(line_buf.trim_end().as_bytes())?;
123
124        let pane_id = match first_msg {
125            ClientMessage::Register {
126                pane_id,
127                pid,
128                team_name,
129                team_member_name,
130                is_team_lead,
131            } => {
132                let state = WrapState {
133                    pid,
134                    pane_id: Some(pane_id.clone()),
135                    team_name,
136                    team_member_name,
137                    is_team_lead,
138                    ..Default::default()
139                };
140                registry.write().insert(pane_id.clone(), state);
141                pane_id
142            }
143            _ => anyhow::bail!("First message must be Register"),
144        };
145
146        // Create channel for server → wrapper messages
147        let (tx, mut rx) = mpsc::channel::<ServerMessage>(32);
148        let connection_id = uuid::Uuid::new_v4().to_string();
149
150        // Send Registered response
151        let registered = ServerMessage::Registered {
152            connection_id: connection_id.clone(),
153        };
154        let msg_bytes = encode(&registered)?;
155        writer.write_all(&msg_bytes).await?;
156        writer.flush().await?;
157
158        // Remove any existing connection for this pane_id (reconnect scenario)
159        // then store the new connection handle
160        {
161            let mut conns = connections.write();
162            conns.retain(|_, handle| handle.pane_id != pane_id);
163            conns.insert(
164                connection_id.clone(),
165                ConnectionHandle {
166                    pane_id: pane_id.clone(),
167                    tx,
168                },
169            );
170        }
171
172        tracing::debug!("IPC client registered: pane_id={}", pane_id);
173
174        // Main loop: read from client OR send to client
175        line_buf.clear();
176        loop {
177            tokio::select! {
178                result = buf_reader.read_line(&mut line_buf) => {
179                    match result {
180                        Ok(0) => break, // EOF
181                        Ok(_) => {
182                            if let Ok(msg) = decode::<ClientMessage>(line_buf.trim_end().as_bytes()) {
183                                match msg {
184                                    ClientMessage::StateUpdate { state } => {
185                                        registry.write().insert(pane_id.clone(), state);
186                                    }
187                                    ClientMessage::Register { .. } => {
188                                        // Ignore duplicate register
189                                    }
190                                }
191                            }
192                            line_buf.clear();
193                        }
194                        Err(e) => {
195                            tracing::debug!("IPC read error for pane {}: {}", pane_id, e);
196                            break;
197                        }
198                    }
199                }
200                msg = rx.recv() => {
201                    match msg {
202                        Some(server_msg) => {
203                            match encode(&server_msg) {
204                                Ok(msg_bytes) => {
205                                    if writer.write_all(&msg_bytes).await.is_err() {
206                                        break;
207                                    }
208                                    let _ = writer.flush().await;
209                                }
210                                Err(_) => break,
211                            }
212                        }
213                        None => break, // Channel closed
214                    }
215                }
216            }
217        }
218
219        // Cleanup on disconnect
220        registry.write().remove(&pane_id);
221        connections.write().remove(&connection_id);
222        tracing::debug!("IPC client disconnected: pane_id={}", pane_id);
223
224        Ok(())
225    }
226
227    /// Get the registry for reading wrapper states
228    pub fn registry(&self) -> IpcRegistry {
229        self.registry.clone()
230    }
231
232    /// Check if a wrapper with the given pane_id is connected
233    pub fn has_connection(&self, pane_id: &str) -> bool {
234        self.connections
235            .read()
236            .values()
237            .any(|c| c.pane_id == pane_id)
238    }
239
240    /// Send keys to a wrapper via IPC. Returns true if sent successfully.
241    pub fn try_send_keys(&self, pane_id: &str, keys: &str, literal: bool) -> bool {
242        let connections = self.connections.read();
243        for handle in connections.values() {
244            if handle.pane_id == pane_id {
245                let msg = ServerMessage::SendKeys {
246                    keys: keys.to_string(),
247                    literal,
248                };
249                return handle.tx.try_send(msg).is_ok();
250            }
251        }
252        false
253    }
254
255    /// Send text + Enter to a wrapper via IPC. Returns true if sent successfully.
256    pub fn try_send_keys_and_enter(&self, pane_id: &str, text: &str) -> bool {
257        let connections = self.connections.read();
258        for handle in connections.values() {
259            if handle.pane_id == pane_id {
260                let msg = ServerMessage::SendKeysAndEnter {
261                    text: text.to_string(),
262                };
263                return handle.tx.try_send(msg).is_ok();
264            }
265        }
266        false
267    }
268}
269
270/// Ensure state directory exists with proper permissions
271fn ensure_state_dir() -> Result<()> {
272    let dir = state_dir();
273    // Check for symlink attack before creating
274    if dir.exists() {
275        let meta = std::fs::symlink_metadata(&dir)
276            .with_context(|| format!("Failed to read metadata for: {}", dir.display()))?;
277        if meta.is_symlink() {
278            anyhow::bail!(
279                "State directory is a symlink (possible attack): {}",
280                dir.display()
281            );
282        }
283    }
284    std::fs::create_dir_all(&dir)
285        .with_context(|| format!("Failed to create state directory: {}", dir.display()))?;
286    let metadata = std::fs::metadata(&dir)
287        .with_context(|| format!("Failed to read metadata for: {}", dir.display()))?;
288    if !metadata.is_dir() {
289        anyhow::bail!("State path is not a directory: {}", dir.display());
290    }
291    let mode = metadata.permissions().mode() & 0o777;
292    if mode != 0o700 {
293        std::fs::set_permissions(&dir, std::fs::Permissions::from_mode(0o700))
294            .with_context(|| format!("Failed to set permissions on: {}", dir.display()))?;
295    }
296    Ok(())
297}
298
299impl Drop for IpcServer {
300    fn drop(&mut self) {
301        let _ = std::fs::remove_file(socket_path());
302    }
303}