Skip to main content

tmai_core/ipc/
client.rs

1//! IPC client for PTY wrapper process
2//!
3//! Connects to the tmai parent process via Unix domain socket
4//! to send state updates and receive keystroke commands.
5
6use std::io::{BufRead, BufReader, Write};
7use std::os::unix::net::UnixStream;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10use std::thread;
11use std::time::Duration;
12
13use crate::ipc::protocol::*;
14
/// Identity details the wrapper sends to the tmai server in its `Register`
/// message.
///
/// The connection loop owns one of these for its whole lifetime, so every
/// reconnect re-registers with the same values.
struct RegistrationInfo {
    /// Pane id sent at registration.
    pane_id: String,
    /// PID sent at registration.
    pid: u32,
    /// Team name, if this agent belongs to a team.
    team_name: Option<String>,
    /// Member name within the team, if any.
    team_member_name: Option<String>,
    /// Whether this agent registers as the team lead.
    is_team_lead: bool,
}
23
24/// IPC client that connects to the tmai parent process
25pub struct IpcClient {
26    state_tx: std::sync::mpsc::SyncSender<WrapState>,
27}
28
29impl IpcClient {
30    /// Start the IPC client
31    ///
32    /// Creates a background thread that connects to the IPC server
33    /// and handles bidirectional communication. The `pty_writer` is
34    /// used to forward keystroke commands from the server to the PTY.
35    /// The `analyzer` is notified of IPC-originated input so the echo
36    /// grace period applies equally to remote keystrokes.
37    #[allow(clippy::too_many_arguments)]
38    pub fn start(
39        pane_id: String,
40        pid: u32,
41        team_name: Option<String>,
42        team_member_name: Option<String>,
43        is_team_lead: bool,
44        running: Arc<AtomicBool>,
45        pty_writer: Arc<parking_lot::Mutex<Box<dyn Write + Send>>>,
46        analyzer: Arc<parking_lot::Mutex<crate::wrap::analyzer::Analyzer>>,
47    ) -> Self {
48        // Bounded channel with capacity 2 - only recent states matter
49        let (state_tx, state_rx) = std::sync::mpsc::sync_channel::<WrapState>(2);
50
51        let reg = RegistrationInfo {
52            pane_id,
53            pid,
54            team_name,
55            team_member_name,
56            is_team_lead,
57        };
58        let client_running = running;
59        thread::spawn(move || {
60            Self::connection_loop(reg, state_rx, client_running, pty_writer, analyzer);
61        });
62
63        Self { state_tx }
64    }
65
66    /// Send a state update (non-blocking)
67    ///
68    /// If the channel is full, the update is dropped (the next tick
69    /// will send the latest state anyway).
70    pub fn send_state(&self, state: WrapState) {
71        // try_send: if Full, that's ok - writer thread will catch up
72        let _ = self.state_tx.try_send(state);
73    }
74
75    /// Connection loop with exponential backoff retry
76    fn connection_loop(
77        reg: RegistrationInfo,
78        state_rx: std::sync::mpsc::Receiver<WrapState>,
79        running: Arc<AtomicBool>,
80        pty_writer: Arc<parking_lot::Mutex<Box<dyn Write + Send>>>,
81        analyzer: Arc<parking_lot::Mutex<crate::wrap::analyzer::Analyzer>>,
82    ) {
83        let mut backoff_ms = 100u64;
84
85        while running.load(Ordering::Relaxed) {
86            match UnixStream::connect(socket_path()) {
87                Ok(stream) => {
88                    backoff_ms = 100; // Reset on successful connect
89                    tracing::debug!("IPC connected to server");
90
91                    if let Err(e) = Self::handle_connection(
92                        stream,
93                        &reg,
94                        &state_rx,
95                        &running,
96                        &pty_writer,
97                        &analyzer,
98                    ) {
99                        tracing::debug!("IPC connection lost: {}", e);
100                    }
101                }
102                Err(e) => {
103                    tracing::debug!("IPC connect failed (will retry): {}", e);
104                }
105            }
106
107            if !running.load(Ordering::Relaxed) {
108                break;
109            }
110
111            thread::sleep(Duration::from_millis(backoff_ms));
112            backoff_ms = (backoff_ms * 2).min(2000);
113        }
114    }
115
116    /// Handle a single connection session
117    fn handle_connection(
118        stream: UnixStream,
119        reg: &RegistrationInfo,
120        state_rx: &std::sync::mpsc::Receiver<WrapState>,
121        running: &Arc<AtomicBool>,
122        pty_writer: &Arc<parking_lot::Mutex<Box<dyn Write + Send>>>,
123        analyzer: &Arc<parking_lot::Mutex<crate::wrap::analyzer::Analyzer>>,
124    ) -> anyhow::Result<()> {
125        stream.set_write_timeout(Some(Duration::from_secs(5)))?;
126
127        let mut write_stream = stream.try_clone()?;
128        let read_stream = stream;
129
130        // Send Register message
131        let register = ClientMessage::Register {
132            pane_id: reg.pane_id.clone(),
133            pid: reg.pid,
134            team_name: reg.team_name.clone(),
135            team_member_name: reg.team_member_name.clone(),
136            is_team_lead: reg.is_team_lead,
137        };
138        let msg = encode(&register)?;
139        write_stream.write_all(&msg)?;
140        write_stream.flush()?;
141
142        // Wait for Registered response (with timeout)
143        read_stream.set_read_timeout(Some(Duration::from_secs(5)))?;
144        let mut reader = BufReader::new(read_stream);
145        let mut line = String::new();
146        reader.read_line(&mut line)?;
147        let _response: ServerMessage = decode(line.trim_end().as_bytes())?;
148
149        // Switch to short read timeout for non-blocking reads
150        reader
151            .get_ref()
152            .set_read_timeout(Some(Duration::from_millis(100)))?;
153
154        tracing::debug!("IPC registered as pane_id={}", reg.pane_id);
155
156        // Connection is live flag
157        let connected = Arc::new(AtomicBool::new(true));
158
159        // Reader thread: receive SendKeys from server
160        let reader_connected = connected.clone();
161        let reader_running = running.clone();
162        let pty_writer_clone = pty_writer.clone();
163        let analyzer_clone = analyzer.clone();
164        let reader_thread = thread::spawn(move || {
165            let mut read_line = String::new();
166            while reader_connected.load(Ordering::Relaxed) && reader_running.load(Ordering::Relaxed)
167            {
168                read_line.clear();
169                match reader.read_line(&mut read_line) {
170                    Ok(0) => break, // EOF
171                    Ok(_) => {
172                        if let Ok(msg) = decode::<ServerMessage>(read_line.trim_end().as_bytes()) {
173                            match msg {
174                                ServerMessage::SendKeys { keys, literal } => {
175                                    let data = if literal {
176                                        keys.as_bytes().to_vec()
177                                    } else {
178                                        tmux_key_to_bytes(&keys)
179                                    };
180                                    let mut writer = pty_writer_clone.lock();
181                                    let _ = writer.write_all(&data);
182                                    let _ = writer.flush();
183                                    // Notify analyzer of IPC-originated input for echo grace
184                                    analyzer_clone.lock().process_input(&keys);
185                                }
186                                ServerMessage::SendKeysAndEnter { text } => {
187                                    let mut writer = pty_writer_clone.lock();
188                                    let _ = writer.write_all(text.as_bytes());
189                                    let _ = writer.write_all(b"\r");
190                                    let _ = writer.flush();
191                                    // Notify analyzer of IPC-originated input for echo grace
192                                    analyzer_clone.lock().process_input(&text);
193                                }
194                                ServerMessage::Registered { .. } => {
195                                    // Ignore duplicate
196                                }
197                            }
198                        }
199                    }
200                    Err(ref e)
201                        if e.kind() == std::io::ErrorKind::WouldBlock
202                            || e.kind() == std::io::ErrorKind::TimedOut =>
203                    {
204                        continue;
205                    }
206                    Err(_) => break,
207                }
208            }
209            reader_connected.store(false, Ordering::Relaxed);
210        });
211
212        // Writer loop: send state updates (runs on current thread)
213        while connected.load(Ordering::Relaxed) && running.load(Ordering::Relaxed) {
214            match state_rx.recv_timeout(Duration::from_millis(100)) {
215                Ok(state) => {
216                    let msg = ClientMessage::StateUpdate { state };
217                    match encode(&msg) {
218                        Ok(bytes) => {
219                            if write_stream.write_all(&bytes).is_err() {
220                                break;
221                            }
222                            let _ = write_stream.flush();
223                        }
224                        Err(_) => break,
225                    }
226                }
227                Err(std::sync::mpsc::RecvTimeoutError::Timeout) => continue,
228                Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => break,
229            }
230        }
231
232        connected.store(false, Ordering::Relaxed);
233        let _ = reader_thread.join();
234
235        Ok(())
236    }
237}
238
/// Convert a tmux key name (as passed to `tmux send-keys`) into the bytes a
/// terminal would send for that key, for writing directly to the PTY.
///
/// Recognised names:
/// - `Enter`, `Space`, `BSpace`, `Tab`, `BTab` (Shift+Tab), `Escape`/`Esc`;
/// - the arrow keys, `Home`, `End`, `IC`, `DC`;
/// - `PPage`/`PageUp`/`PgUp` and `NPage`/`PageDown`/`PgDn`;
/// - `F1`–`F12`;
/// - control chords written as `C-<char>` or `C-Space`.
///
/// Cursor and function keys use xterm's normal-mode sequences. Any other
/// string is treated as literal text and sent unchanged.
fn tmux_key_to_bytes(key: &str) -> Vec<u8> {
    match key {
        "Enter" => vec![b'\r'],
        "Space" => vec![b' '],
        "BSpace" => vec![0x7f],
        "Tab" => vec![b'\t'],
        "BTab" => b"\x1b[Z".to_vec(),
        "Escape" | "Esc" => vec![0x1b],
        "Up" => b"\x1b[A".to_vec(),
        "Down" => b"\x1b[B".to_vec(),
        "Right" => b"\x1b[C".to_vec(),
        "Left" => b"\x1b[D".to_vec(),
        "Home" => b"\x1b[H".to_vec(),
        "End" => b"\x1b[F".to_vec(),
        "IC" => b"\x1b[2~".to_vec(),
        "DC" => b"\x1b[3~".to_vec(),
        "PPage" | "PageUp" | "PgUp" => b"\x1b[5~".to_vec(),
        "NPage" | "PageDown" | "PgDn" => b"\x1b[6~".to_vec(),
        "F1" => b"\x1bOP".to_vec(),
        "F2" => b"\x1bOQ".to_vec(),
        "F3" => b"\x1bOR".to_vec(),
        "F4" => b"\x1bOS".to_vec(),
        "F5" => b"\x1b[15~".to_vec(),
        "F6" => b"\x1b[17~".to_vec(),
        "F7" => b"\x1b[18~".to_vec(),
        "F8" => b"\x1b[19~".to_vec(),
        "F9" => b"\x1b[20~".to_vec(),
        "F10" => b"\x1b[21~".to_vec(),
        "F11" => b"\x1b[23~".to_vec(),
        "F12" => b"\x1b[24~".to_vec(),
        // Control encodings the bitmask below cannot produce: C-Space uses the
        // `Space` key name, and C-? is DEL by convention (the mask would give 0x1f).
        "C-Space" => vec![0x00],
        "C-?" => vec![0x7f],
        s if s.starts_with("C-") && s.len() == 3 => {
            // Control character via bitmask: C-a/C-A = 0x01, C-@ = 0x00, C-[ = 0x1b
            vec![s.as_bytes()[2] & 0x1f]
        }
        // For literal text like "y"
        other => other.as_bytes().to_vec(),
    }
}
265
#[cfg(test)]
mod tests {
    use super::*;

    /// Named keys, control chords, and plain text all map to the expected bytes.
    #[test]
    fn test_tmux_key_to_bytes() {
        let cases: [(&str, Vec<u8>); 8] = [
            ("Enter", vec![b'\r']),
            ("Space", vec![b' ']),
            ("Up", vec![0x1b, b'[', b'A']),
            ("C-c", vec![0x03]),
            ("C-A", vec![0x01]), // uppercase: same as C-a
            ("C-@", vec![0x00]), // NUL
            ("C-[", vec![0x1b]), // ESC
            ("y", vec![b'y']),
        ];

        for (key, expected) in cases.iter() {
            assert_eq!(tmux_key_to_bytes(key), *expected, "key {:?}", key);
        }
    }
}