Skip to main content

tmai_core/ipc/
protocol.rs

//! IPC protocol definitions for communication between the tmai wrapper and
//! its parent process.
//!
//! Messages are exchanged in both directions as newline-delimited JSON
//! (ndjson) over Unix domain sockets.
5
6use std::path::PathBuf;
7
8use anyhow::Result;
9use serde::de::DeserializeOwned;
10use serde::{Deserialize, Serialize};
11
12/// Get the base state directory, preferring XDG_RUNTIME_DIR for security
13pub fn state_dir() -> PathBuf {
14    if let Ok(xdg) = std::env::var("XDG_RUNTIME_DIR") {
15        PathBuf::from(xdg).join("tmai")
16    } else {
17        let uid = unsafe { libc::getuid() };
18        PathBuf::from(format!("/tmp/tmai-{}", uid))
19    }
20}
21
22/// Get the IPC socket path
23pub fn socket_path() -> PathBuf {
24    state_dir().join("control.sock")
25}
26
/// Status of a wrapped agent
///
/// Serialized as a snake_case string (`"processing"`, `"idle"`,
/// `"awaiting_approval"`). [`WrapStatus::Idle`] is the `Default` value.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WrapStatus {
    /// Agent is actively outputting (last output within 200ms)
    Processing,
    /// Agent is idle (output stopped, no approval detected)
    #[default]
    Idle,
    /// Agent is awaiting approval (output stopped with approval pattern)
    AwaitingApproval,
}
39
/// Type of approval being requested (for wrapped agents)
///
/// Serialized as a snake_case string: `"file_edit"`, `"shell_command"`,
/// `"mcp_tool"`, `"user_question"`, `"yes_no"`, `"other"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WrapApprovalType {
    /// File edit/create/delete
    FileEdit,
    /// Shell command execution
    ShellCommand,
    /// MCP tool invocation
    McpTool,
    /// User question with selectable choices
    UserQuestion,
    /// Yes/No confirmation
    YesNo,
    /// Other/unknown
    Other,
}
57
/// State data for a wrapped agent
///
/// This is the payload of `ClientMessage::StateUpdate`.
///
/// Wire-format notes (JSON via serde):
/// - `status`, `last_output`, `last_input` and `pid` carry no serde default
///   and must be present when deserializing.
/// - `Option` fields are omitted from the output when `None`.
/// - `choices` is omitted when empty and defaults to empty when absent.
/// - `multi_select`, `cursor_position` and `is_team_lead` default to
///   `false` / `0` / `false` when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WrapState {
    /// Current status
    pub status: WrapStatus,
    /// Type of approval (if awaiting approval)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub approval_type: Option<WrapApprovalType>,
    /// Details about the current state
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<String>,
    /// Available choices (for UserQuestion)
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub choices: Vec<String>,
    /// Whether multiple selections are allowed
    #[serde(default)]
    pub multi_select: bool,
    /// Current cursor position (1-indexed, for UserQuestion)
    ///
    /// `Default` and a missing field both yield `0`.
    #[serde(default)]
    pub cursor_position: usize,
    /// Timestamp of last output (Unix millis)
    pub last_output: u64,
    /// Timestamp of last input (Unix millis)
    pub last_input: u64,
    /// Process ID of the wrapped command
    pub pid: u32,
    /// Tmux pane ID (if known)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pane_id: Option<String>,
    /// Team name (if this agent is part of a team)
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub team_name: Option<String>,
    /// Team member name
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub team_member_name: Option<String>,
    /// Whether this agent is the team lead
    #[serde(default)]
    pub is_team_lead: bool,
}
97
98impl Default for WrapState {
99    fn default() -> Self {
100        let now = current_time_millis();
101        Self {
102            status: WrapStatus::Idle,
103            approval_type: None,
104            details: None,
105            choices: Vec::new(),
106            multi_select: false,
107            cursor_position: 0,
108            last_output: now,
109            last_input: now,
110            pid: 0,
111            pane_id: None,
112            team_name: None,
113            team_member_name: None,
114            is_team_lead: false,
115        }
116    }
117}
118
119impl WrapState {
120    /// Create a new state for processing
121    pub fn processing(pid: u32) -> Self {
122        Self {
123            status: WrapStatus::Processing,
124            pid,
125            ..Default::default()
126        }
127    }
128
129    /// Create a new state for idle
130    pub fn idle(pid: u32) -> Self {
131        Self {
132            status: WrapStatus::Idle,
133            pid,
134            ..Default::default()
135        }
136    }
137
138    /// Create a new state for awaiting approval
139    pub fn awaiting_approval(
140        pid: u32,
141        approval_type: WrapApprovalType,
142        details: Option<String>,
143    ) -> Self {
144        Self {
145            status: WrapStatus::AwaitingApproval,
146            approval_type: Some(approval_type),
147            details,
148            pid,
149            ..Default::default()
150        }
151    }
152
153    /// Create a state for user question
154    pub fn user_question(
155        pid: u32,
156        choices: Vec<String>,
157        multi_select: bool,
158        cursor_position: usize,
159    ) -> Self {
160        Self {
161            status: WrapStatus::AwaitingApproval,
162            approval_type: Some(WrapApprovalType::UserQuestion),
163            choices,
164            multi_select,
165            cursor_position,
166            pid,
167            ..Default::default()
168        }
169    }
170
171    /// Update last output timestamp
172    pub fn touch_output(&mut self) {
173        self.last_output = current_time_millis();
174    }
175
176    /// Update last input timestamp
177    pub fn touch_input(&mut self) {
178        self.last_input = current_time_millis();
179    }
180
181    /// Set pane ID
182    pub fn with_pane_id(mut self, pane_id: String) -> Self {
183        self.pane_id = Some(pane_id);
184        self
185    }
186}
187
/// Message from wrapper to tmai parent (upstream)
///
/// Serialized as an internally tagged JSON object: the `"type"` field holds
/// the variant name verbatim (`"Register"` or `"StateUpdate"`) and the
/// variant's fields sit alongside it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ClientMessage {
    /// Initial registration message
    Register {
        /// Pane identifier of the registering wrapper (presumably the tmux
        /// pane ID, as in `WrapState::pane_id` — confirm against the sender)
        pane_id: String,
        /// Process ID (presumably of the wrapped command, as in
        /// `WrapState::pid` — confirm against the sender)
        pid: u32,
        /// Team name, omitted from the output when `None`
        #[serde(skip_serializing_if = "Option::is_none")]
        team_name: Option<String>,
        /// Team member name, omitted from the output when `None`
        #[serde(skip_serializing_if = "Option::is_none")]
        team_member_name: Option<String>,
        /// Whether the agent is the team lead; `false` when absent
        #[serde(default)]
        is_team_lead: bool,
    },
    /// State update from wrapper
    StateUpdate { state: WrapState },
}
206
/// Message from tmai parent to wrapper (downstream)
///
/// Serialized as an internally tagged JSON object: the `"type"` field holds
/// the variant name verbatim (`"Registered"`, `"SendKeys"` or
/// `"SendKeysAndEnter"`) and the variant's fields sit alongside it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ServerMessage {
    /// Registration acknowledgement
    Registered { connection_id: String },
    /// Send keys to the wrapped process
    ///
    /// NOTE(review): `literal` presumably selects sending `keys` verbatim
    /// rather than interpreting them as key names — confirm in the wrapper.
    SendKeys { keys: String, literal: bool },
    /// Send text followed by Enter
    SendKeysAndEnter { text: String },
}
218
219/// Encode a message as ndjson (JSON + newline)
220pub fn encode<T: Serialize>(msg: &T) -> Result<Vec<u8>> {
221    let mut json = serde_json::to_vec(msg)?;
222    json.push(b'\n');
223    Ok(json)
224}
225
226/// Decode a message from a JSON line
227pub fn decode<T: DeserializeOwned>(line: &[u8]) -> Result<T> {
228    Ok(serde_json::from_slice(line)?)
229}
230
/// Get current time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch, and
/// saturates at `u64::MAX` instead of silently truncating if the millisecond
/// count does not fit in a `u64`.
pub fn current_time_millis() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    // `duration_since` fails only when the clock is set before the epoch;
    // treat that as a zero duration rather than propagating an error.
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();

    // `as_millis` yields a u128; convert explicitly instead of using a lossy cast.
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// A processing state serializes its status in snake_case and keeps the pid.
    #[test]
    fn test_wrap_state_serialization() {
        let serialized = serde_json::to_string(&WrapState::processing(1234)).unwrap();
        assert!(serialized.contains("\"status\":\"processing\""));
        assert!(serialized.contains("\"pid\":1234"));
    }

    /// A payload without the optional fields still deserializes.
    #[test]
    fn test_wrap_state_deserialization() {
        let payload = r#"{
            "status": "awaiting_approval",
            "approval_type": "user_question",
            "choices": ["Yes", "No"],
            "multi_select": false,
            "cursor_position": 1,
            "last_output": 1234567890,
            "last_input": 1234567890,
            "pid": 5678
        }"#;

        let parsed = serde_json::from_str::<WrapState>(payload).unwrap();
        assert_eq!(parsed.status, WrapStatus::AwaitingApproval);
        assert_eq!(parsed.approval_type, Some(WrapApprovalType::UserQuestion));
        assert_eq!(parsed.choices, vec!["Yes", "No"]);
        assert_eq!(parsed.cursor_position, 1);
        assert_eq!(parsed.pid, 5678);
    }

    /// The clock advances across a short sleep.
    #[test]
    fn test_current_time_millis() {
        let before = current_time_millis();
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(current_time_millis() > before);
    }

    /// `Register` survives an encode/decode round trip.
    #[test]
    fn test_client_message_register_serialization() {
        let bytes = encode(&ClientMessage::Register {
            pane_id: "5".to_string(),
            pid: 1234,
            team_name: Some("my-team".to_string()),
            team_member_name: Some("dev".to_string()),
            is_team_lead: false,
        })
        .unwrap();

        let ClientMessage::Register { pane_id, pid, .. } =
            decode::<ClientMessage>(bytes.trim_ascii_end()).unwrap()
        else {
            panic!("Expected Register");
        };
        assert_eq!(pane_id, "5");
        assert_eq!(pid, 1234);
    }

    /// `SendKeys` survives an encode/decode round trip.
    #[test]
    fn test_server_message_send_keys_serialization() {
        let bytes = encode(&ServerMessage::SendKeys {
            keys: "y".to_string(),
            literal: true,
        })
        .unwrap();

        let ServerMessage::SendKeys { keys, literal } =
            decode::<ServerMessage>(bytes.trim_ascii_end()).unwrap()
        else {
            panic!("Expected SendKeys");
        };
        assert_eq!(keys, "y");
        assert!(literal);
    }

    /// Without XDG_RUNTIME_DIR the state directory is /tmp/tmai-UID.
    #[test]
    fn test_state_dir_default() {
        let uid = unsafe { libc::getuid() };
        let expected = PathBuf::from(format!("/tmp/tmai-{}", uid));
        temp_env::with_var_unset("XDG_RUNTIME_DIR", || {
            assert_eq!(state_dir(), expected);
        });
    }

    /// With XDG_RUNTIME_DIR set, the state directory lives underneath it.
    #[test]
    fn test_state_dir_with_xdg() {
        temp_env::with_var("XDG_RUNTIME_DIR", Some("/run/user/1000"), || {
            assert_eq!(state_dir(), PathBuf::from("/run/user/1000/tmai"));
        });
    }

    /// The socket path always ends in the control socket file name.
    #[test]
    fn test_socket_path_contains_control_sock() {
        assert!(socket_path().ends_with("control.sock"));
    }

    /// `StateUpdate` is newline-terminated and round-trips its payload.
    #[test]
    fn test_encode_decode_roundtrip() {
        let question = WrapState::user_question(42, vec!["A".into(), "B".into()], true, 1);
        let bytes = encode(&ClientMessage::StateUpdate { state: question }).unwrap();
        assert!(bytes.ends_with(b"\n"));

        let ClientMessage::StateUpdate { state } =
            decode::<ClientMessage>(bytes.trim_ascii_end()).unwrap()
        else {
            panic!("Expected StateUpdate");
        };
        assert_eq!(state.pid, 42);
        assert_eq!(state.choices, vec!["A", "B"]);
        assert!(state.multi_select);
    }
}