1use std::path::PathBuf;
7
8use anyhow::Result;
9use serde::de::DeserializeOwned;
10use serde::{Deserialize, Serialize};
11
12pub fn state_dir() -> PathBuf {
14 if let Ok(xdg) = std::env::var("XDG_RUNTIME_DIR") {
15 PathBuf::from(xdg).join("tmai")
16 } else {
17 let uid = unsafe { libc::getuid() };
18 PathBuf::from(format!("/tmp/tmai-{}", uid))
19 }
20}
21
22pub fn socket_path() -> PathBuf {
24 state_dir().join("control.sock")
25}
26
27#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
29#[serde(rename_all = "snake_case")]
30pub enum WrapStatus {
31 Processing,
33 #[default]
35 Idle,
36 AwaitingApproval,
38}
39
40#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
42#[serde(rename_all = "snake_case")]
43pub enum WrapApprovalType {
44 FileEdit,
46 ShellCommand,
48 McpTool,
50 UserQuestion,
52 YesNo,
54 Other,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct WrapState {
61 pub status: WrapStatus,
63 #[serde(skip_serializing_if = "Option::is_none")]
65 pub approval_type: Option<WrapApprovalType>,
66 #[serde(skip_serializing_if = "Option::is_none")]
68 pub details: Option<String>,
69 #[serde(skip_serializing_if = "Vec::is_empty", default)]
71 pub choices: Vec<String>,
72 #[serde(default)]
74 pub multi_select: bool,
75 #[serde(default)]
77 pub cursor_position: usize,
78 pub last_output: u64,
80 pub last_input: u64,
82 pub pid: u32,
84 #[serde(skip_serializing_if = "Option::is_none")]
86 pub pane_id: Option<String>,
87 #[serde(skip_serializing_if = "Option::is_none", default)]
89 pub team_name: Option<String>,
90 #[serde(skip_serializing_if = "Option::is_none", default)]
92 pub team_member_name: Option<String>,
93 #[serde(default)]
95 pub is_team_lead: bool,
96}
97
98impl Default for WrapState {
99 fn default() -> Self {
100 let now = current_time_millis();
101 Self {
102 status: WrapStatus::Idle,
103 approval_type: None,
104 details: None,
105 choices: Vec::new(),
106 multi_select: false,
107 cursor_position: 0,
108 last_output: now,
109 last_input: now,
110 pid: 0,
111 pane_id: None,
112 team_name: None,
113 team_member_name: None,
114 is_team_lead: false,
115 }
116 }
117}
118
119impl WrapState {
120 pub fn processing(pid: u32) -> Self {
122 Self {
123 status: WrapStatus::Processing,
124 pid,
125 ..Default::default()
126 }
127 }
128
129 pub fn idle(pid: u32) -> Self {
131 Self {
132 status: WrapStatus::Idle,
133 pid,
134 ..Default::default()
135 }
136 }
137
138 pub fn awaiting_approval(
140 pid: u32,
141 approval_type: WrapApprovalType,
142 details: Option<String>,
143 ) -> Self {
144 Self {
145 status: WrapStatus::AwaitingApproval,
146 approval_type: Some(approval_type),
147 details,
148 pid,
149 ..Default::default()
150 }
151 }
152
153 pub fn user_question(
155 pid: u32,
156 choices: Vec<String>,
157 multi_select: bool,
158 cursor_position: usize,
159 ) -> Self {
160 Self {
161 status: WrapStatus::AwaitingApproval,
162 approval_type: Some(WrapApprovalType::UserQuestion),
163 choices,
164 multi_select,
165 cursor_position,
166 pid,
167 ..Default::default()
168 }
169 }
170
171 pub fn touch_output(&mut self) {
173 self.last_output = current_time_millis();
174 }
175
176 pub fn touch_input(&mut self) {
178 self.last_input = current_time_millis();
179 }
180
181 pub fn with_pane_id(mut self, pane_id: String) -> Self {
183 self.pane_id = Some(pane_id);
184 self
185 }
186}
187
188#[derive(Debug, Clone, Serialize, Deserialize)]
190#[serde(tag = "type")]
191pub enum ClientMessage {
192 Register {
194 pane_id: String,
195 pid: u32,
196 #[serde(skip_serializing_if = "Option::is_none")]
197 team_name: Option<String>,
198 #[serde(skip_serializing_if = "Option::is_none")]
199 team_member_name: Option<String>,
200 #[serde(default)]
201 is_team_lead: bool,
202 },
203 StateUpdate { state: WrapState },
205}
206
207#[derive(Debug, Clone, Serialize, Deserialize)]
209#[serde(tag = "type")]
210pub enum ServerMessage {
211 Registered { connection_id: String },
213 SendKeys { keys: String, literal: bool },
215 SendKeysAndEnter { text: String },
217}
218
219pub fn encode<T: Serialize>(msg: &T) -> Result<Vec<u8>> {
221 let mut json = serde_json::to_vec(msg)?;
222 json.push(b'\n');
223 Ok(json)
224}
225
226pub fn decode<T: DeserializeOwned>(line: &[u8]) -> Result<T> {
228 Ok(serde_json::from_slice(line)?)
229}
230
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// If the system clock reads earlier than the epoch, returns `0` instead of
/// failing.
pub fn current_time_millis() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    // u128 -> u64: milliseconds only exceed u64::MAX after ~584 million years.
    since_epoch.as_millis() as u64
}
239
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wrap_state_serialization() {
        let state = WrapState::processing(1234);
        let json = serde_json::to_string(&state).unwrap();
        assert!(json.contains("\"status\":\"processing\""));
        assert!(json.contains("\"pid\":1234"));
        // `None` options and an empty `choices` list are skipped entirely.
        assert!(!json.contains("approval_type"));
        assert!(!json.contains("choices"));
    }

    #[test]
    fn test_wrap_state_deserialization() {
        let json = r#"{
            "status": "awaiting_approval",
            "approval_type": "user_question",
            "choices": ["Yes", "No"],
            "multi_select": false,
            "cursor_position": 1,
            "last_output": 1234567890,
            "last_input": 1234567890,
            "pid": 5678
        }"#;

        let state: WrapState = serde_json::from_str(json).unwrap();
        assert_eq!(state.status, WrapStatus::AwaitingApproval);
        assert_eq!(state.approval_type, Some(WrapApprovalType::UserQuestion));
        assert_eq!(state.choices, vec!["Yes", "No"]);
        assert_eq!(state.cursor_position, 1);
        assert_eq!(state.pid, 5678);
        // Fields absent from the JSON fall back to their defaults.
        assert!(state.pane_id.is_none());
        assert!(state.team_name.is_none());
        assert!(!state.is_team_lead);
    }

    #[test]
    fn test_current_time_millis() {
        use std::time::{SystemTime, UNIX_EPOCH};

        // Compare against an independent wall-clock reading instead of
        // sleeping and requiring strict growth. The system clock is not
        // monotonic, so a later reading can legitimately be smaller after an
        // NTP step, which made the old test flaky.
        let now = current_time_millis();
        let reference = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is before the Unix epoch")
            .as_millis() as u64;
        assert!(
            now.abs_diff(reference) < 1_000,
            "current_time_millis() = {now}, SystemTime::now() = {reference}"
        );
    }

    #[test]
    fn test_client_message_register_serialization() {
        let msg = ClientMessage::Register {
            pane_id: "5".to_string(),
            pid: 1234,
            team_name: Some("my-team".to_string()),
            team_member_name: Some("dev".to_string()),
            is_team_lead: false,
        };
        let encoded = encode(&msg).unwrap();
        // Internally tagged: the variant name travels in the "type" field.
        let text = std::str::from_utf8(&encoded).unwrap();
        assert!(text.contains("\"type\":\"Register\""));

        let decoded: ClientMessage = decode(encoded.trim_ascii_end()).unwrap();
        let ClientMessage::Register {
            pane_id,
            pid,
            team_name,
            team_member_name,
            is_team_lead,
        } = decoded
        else {
            panic!("Expected Register");
        };
        assert_eq!(pane_id, "5");
        assert_eq!(pid, 1234);
        assert_eq!(team_name.as_deref(), Some("my-team"));
        assert_eq!(team_member_name.as_deref(), Some("dev"));
        assert!(!is_team_lead);
    }

    #[test]
    fn test_client_message_register_without_team_fields() {
        // The optional team fields may be omitted; decoding must fall back to
        // their defaults rather than fail.
        let decoded: ClientMessage =
            decode(br#"{"type":"Register","pane_id":"1","pid":7}"#).unwrap();
        let ClientMessage::Register {
            pane_id,
            pid,
            team_name,
            team_member_name,
            is_team_lead,
        } = decoded
        else {
            panic!("Expected Register");
        };
        assert_eq!(pane_id, "1");
        assert_eq!(pid, 7);
        assert!(team_name.is_none());
        assert!(team_member_name.is_none());
        assert!(!is_team_lead);
    }

    #[test]
    fn test_server_message_send_keys_serialization() {
        let msg = ServerMessage::SendKeys {
            keys: "y".to_string(),
            literal: true,
        };
        let encoded = encode(&msg).unwrap();
        let text = std::str::from_utf8(&encoded).unwrap();
        assert!(text.contains("\"type\":\"SendKeys\""));

        let decoded: ServerMessage = decode(encoded.trim_ascii_end()).unwrap();
        let ServerMessage::SendKeys { keys, literal } = decoded else {
            panic!("Expected SendKeys");
        };
        assert_eq!(keys, "y");
        assert!(literal);
    }

    #[test]
    fn test_decode_accepts_trailing_newline() {
        // `encode` output can be fed straight back to `decode` without trimming.
        let encoded = encode(&ServerMessage::SendKeysAndEnter {
            text: "hi".to_string(),
        })
        .unwrap();
        assert!(encoded.ends_with(b"\n"));
        let decoded: ServerMessage = decode(&encoded).unwrap();
        let ServerMessage::SendKeysAndEnter { text } = decoded else {
            panic!("Expected SendKeysAndEnter");
        };
        assert_eq!(text, "hi");
    }

    #[test]
    fn test_state_dir_default() {
        // temp_env restores the variable afterwards and serializes env mutation
        // across tests.
        temp_env::with_var_unset("XDG_RUNTIME_DIR", || {
            let dir = state_dir();
            // SAFETY: getuid(2) has no preconditions and cannot fail.
            let uid = unsafe { libc::getuid() };
            assert_eq!(dir, PathBuf::from(format!("/tmp/tmai-{}", uid)));
        });
    }

    #[test]
    fn test_state_dir_with_xdg() {
        temp_env::with_var("XDG_RUNTIME_DIR", Some("/run/user/1000"), || {
            let dir = state_dir();
            assert_eq!(dir, PathBuf::from("/run/user/1000/tmai"));
        });
    }

    #[test]
    fn test_socket_path_contains_control_sock() {
        let path = socket_path();
        assert!(path.ends_with("control.sock"));
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let state = WrapState::user_question(42, vec!["A".into(), "B".into()], true, 1);
        let msg = ClientMessage::StateUpdate { state };
        let encoded = encode(&msg).unwrap();
        assert!(encoded.ends_with(b"\n"));

        let decoded: ClientMessage = decode(encoded.trim_ascii_end()).unwrap();
        let ClientMessage::StateUpdate { state } = decoded else {
            panic!("Expected StateUpdate");
        };
        assert_eq!(state.pid, 42);
        assert_eq!(state.status, WrapStatus::AwaitingApproval);
        assert_eq!(state.approval_type, Some(WrapApprovalType::UserQuestion));
        assert_eq!(state.choices, vec!["A", "B"]);
        assert!(state.multi_select);
        assert_eq!(state.cursor_position, 1);
    }
}