1use crate::{Command, Event};
4use anyhow::{Context, Result};
5use serde_json;
6use std::path::Path;
7use std::sync::Arc;
8use tokio::io::AsyncWriteExt;
9use tokio::net::{TcpStream, UnixStream};
10use tokio::sync::{mpsc, Mutex, RwLock};
11use tracing::{debug, error, info};
12
13pub enum Connection {
15 Unix(UnixStream),
16 Tcp(TcpStream),
17}
18
19impl Connection {
20 async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
21 match self {
22 Connection::Unix(stream) => stream.write_all(buf).await.context("Unix write failed"),
23 Connection::Tcp(stream) => stream.write_all(buf).await.context("TCP write failed"),
24 }
25 }
26
27 async fn flush(&mut self) -> Result<()> {
28 match self {
29 Connection::Unix(stream) => stream.flush().await.context("Unix flush failed"),
30 Connection::Tcp(stream) => stream.flush().await.context("TCP flush failed"),
31 }
32 }
33}
34
/// Client for an IPC server reached over a Unix domain socket; commands are sent as
/// newline-terminated JSON.
pub struct IpcClient {
    /// Connection to the server. `send_command` holds this lock for its whole
    /// write + flush sequence.
    connection: Arc<Mutex<Connection>>,
}
39
40pub struct SmithClient {
42 command_sender: mpsc::Sender<Command>,
43 event_receiver: Arc<RwLock<Option<mpsc::Receiver<Event>>>>,
44}
45
46impl IpcClient {
47 pub async fn connect<P: AsRef<Path>>(socket_path: P) -> Result<Self> {
49 let stream = UnixStream::connect(&socket_path)
50 .await
51 .context("Failed to connect to Unix socket")?;
52
53 info!("Connected to IPC server at {:?}", socket_path.as_ref());
54
55 Ok(Self {
56 connection: Arc::new(Mutex::new(Connection::Unix(stream))),
57 })
58 }
59
60 pub async fn send_command(&mut self, command: &Command) -> Result<()> {
62 let json = serde_json::to_string(command).context("Failed to serialize command")?;
63
64 let mut connection = self.connection.lock().await;
65 connection
66 .write_all(json.as_bytes())
67 .await
68 .context("Failed to write command")?;
69 connection
70 .write_all(b"\n")
71 .await
72 .context("Failed to write newline")?;
73 connection
74 .flush()
75 .await
76 .context("Failed to flush connection")?;
77
78 debug!("Sent command: {}", json);
79 Ok(())
80 }
81
82 pub async fn process_events<F, Fut>(self, mut _handler: F) -> Result<()>
84 where
85 F: FnMut(Event) -> Fut,
86 Fut: std::future::Future<Output = Result<()>>,
87 {
88 info!("IPC connection processing started");
90 Ok(())
91 }
92}
93
94impl SmithClient {
95 pub async fn connect_tcp(address: &str) -> Result<Self> {
97 let stream = TcpStream::connect(address)
98 .await
99 .context("Failed to connect to TCP address")?;
100
101 info!("Connected to Smith service at {}", address);
102 Self::from_connection(Connection::Tcp(stream)).await
103 }
104
105 pub async fn connect_unix<P: AsRef<Path>>(socket_path: P) -> Result<Self> {
107 let stream = UnixStream::connect(&socket_path)
108 .await
109 .context("Failed to connect to Unix socket")?;
110
111 info!("Connected to Smith service at {:?}", socket_path.as_ref());
112 Self::from_connection(Connection::Unix(stream)).await
113 }
114
115 async fn from_connection(connection: Connection) -> Result<Self> {
117 let (command_tx, mut command_rx) = mpsc::channel::<Command>(1000);
118 let (_event_tx, event_rx) = mpsc::channel::<Event>(10000);
119
120 let connection = Arc::new(Mutex::new(connection));
122 let connection_for_commands = Arc::clone(&connection);
123
124 tokio::spawn(async move {
125 while let Some(command) = command_rx.recv().await {
126 let json = match serde_json::to_string(&command) {
127 Ok(json) => json,
128 Err(err) => {
129 error!("Failed to serialize command: {}", err);
130 continue;
131 }
132 };
133
134 let mut conn = connection_for_commands.lock().await;
135 if let Err(err) = conn.write_all(json.as_bytes()).await {
136 error!("Failed to send command: {}", err);
137 break;
138 }
139 if let Err(err) = conn.write_all(b"\n").await {
140 error!("Failed to send newline: {}", err);
141 break;
142 }
143 if let Err(err) = conn.flush().await {
144 error!("Failed to flush connection: {}", err);
145 break;
146 }
147
148 debug!("Sent command: {}", json);
149 }
150 });
151
152 tokio::spawn(async move {
155 loop {
158 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
159 }
161 });
162
163 Ok(Self {
164 command_sender: command_tx,
165 event_receiver: Arc::new(RwLock::new(Some(event_rx))),
166 })
167 }
168
169 pub async fn send_command(&self, command: Command) -> Result<()> {
171 self.command_sender
172 .send(command)
173 .await
174 .context("Failed to send command - connection may be closed")?;
175 Ok(())
176 }
177
178 pub async fn receive_events(&self) -> Result<Vec<Event>> {
180 let mut events = Vec::new();
181
182 if let Some(ref mut receiver) = self.event_receiver.write().await.as_mut() {
184 while let Ok(event) = receiver.try_recv() {
185 events.push(event);
186 }
187 }
188
189 Ok(events)
190 }
191
192 pub async fn receive_event(&self) -> Result<Event> {
194 if let Some(ref mut receiver) = self.event_receiver.write().await.as_mut() {
195 receiver
196 .recv()
197 .await
198 .ok_or_else(|| anyhow::anyhow!("Event channel closed"))
199 } else {
200 Err(anyhow::anyhow!("Event receiver not available"))
201 }
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;
    use std::time::Duration;
    use tokio::io::{AsyncBufReadExt, BufReader};
    use tokio::net::TcpListener;
    use uuid::Uuid;

    /// Unique socket path under /tmp that is removed on drop, so a failing assertion does
    /// not leave a stale socket file behind.
    struct TempSocketPath(String);

    impl TempSocketPath {
        fn new(prefix: &str) -> Self {
            Self(format!("/tmp/{}_{}.sock", prefix, Uuid::new_v4()))
        }
    }

    impl Drop for TempSocketPath {
        fn drop(&mut self) {
            // Best effort: the file may never have been created (e.g. bind failed).
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Reads one `\n`-terminated line from `stream`, failing the test after 5 seconds.
    async fn read_frame(stream: tokio::net::UnixStream) -> String {
        let mut reader = BufReader::new(stream);
        let mut line = String::new();
        tokio::time::timeout(Duration::from_secs(5), reader.read_line(&mut line))
            .await
            .expect("timed out waiting for a command frame")
            .expect("failed to read command frame");
        line
    }

    #[tokio::test]
    async fn test_connection_write_and_flush() {
        let (client_stream, _server_stream) = tokio::net::UnixStream::pair().unwrap();
        let mut connection = Connection::Unix(client_stream);

        let test_data = b"test data";
        assert!(connection.write_all(test_data).await.is_ok());
        assert!(connection.flush().await.is_ok());
    }

    #[tokio::test]
    async fn test_tcp_connection_write_and_flush() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let tcp_stream = TcpStream::connect(addr).await.unwrap();
        let mut connection = Connection::Tcp(tcp_stream);

        let test_data = b"tcp test data";
        assert!(connection.write_all(test_data).await.is_ok());
        assert!(connection.flush().await.is_ok());
    }

    #[tokio::test]
    async fn test_ipc_client_connection() {
        let socket = TempSocketPath::new("smith_test");
        let listener = tokio::net::UnixListener::bind(&socket.0).unwrap();

        tokio::spawn(async move {
            let _stream = listener.accept().await;
        });

        let client = IpcClient::connect(&socket.0).await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_ipc_client_send_command() {
        let (client_stream, _server_stream) = tokio::net::UnixStream::pair().unwrap();
        let mut client = IpcClient {
            connection: Arc::new(Mutex::new(Connection::Unix(client_stream))),
        };

        let command = Command::Handshake {
            version: 1,
            capabilities: vec!["shell_exec".to_string(), "nats".to_string()],
        };

        let result = client.send_command(&command).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_ipc_client_send_command_writes_one_json_line() {
        let (client_stream, server_stream) = tokio::net::UnixStream::pair().unwrap();
        let mut client = IpcClient {
            connection: Arc::new(Mutex::new(Connection::Unix(client_stream))),
        };

        client.send_command(&Command::Shutdown).await.unwrap();

        let line = read_frame(server_stream).await;
        assert!(line.ends_with('\n'));
        let decoded: Command = serde_json::from_str(line.trim_end()).unwrap();
        assert!(matches!(decoded, Command::Shutdown));
    }

    #[tokio::test]
    async fn test_ipc_client_process_events() {
        let (client_stream, _server_stream) = tokio::net::UnixStream::pair().unwrap();
        let client = IpcClient {
            connection: Arc::new(Mutex::new(Connection::Unix(client_stream))),
        };

        let handler = |_event: Event| async { Ok(()) };

        let result = client.process_events(handler).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_smith_client_tcp_connection() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let _accepted = listener.accept().await;
        });

        let client = SmithClient::connect_tcp(&addr.to_string()).await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_smith_client_unix_connection() {
        let socket = TempSocketPath::new("smith_test_unix");
        let listener = tokio::net::UnixListener::bind(&socket.0).unwrap();

        tokio::spawn(async move {
            let _accepted = listener.accept().await;
        });

        let client = SmithClient::connect_unix(&socket.0).await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_smith_client_send_command() {
        let (client_stream, _server_stream) = tokio::net::UnixStream::pair().unwrap();
        let client = SmithClient::from_connection(Connection::Unix(client_stream))
            .await
            .unwrap();

        let command = Command::ToolCall {
            request_id: Uuid::new_v4(),
            tool: "test_tool".to_string(),
            args: json!({"param": "value"}),
            timeout_ms: Some(5000),
        };

        let result = client.send_command(command).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_smith_client_writer_task_sends_newline_terminated_json() {
        let (client_stream, server_stream) = tokio::net::UnixStream::pair().unwrap();
        let client = SmithClient::from_connection(Connection::Unix(client_stream))
            .await
            .unwrap();

        client
            .send_command(Command::Plan {
                request_id: Uuid::new_v4(),
                goal: "framed goal".to_string(),
                context: HashMap::new(),
            })
            .await
            .unwrap();

        let line = read_frame(server_stream).await;
        assert!(line.ends_with('\n'));
        match serde_json::from_str::<Command>(line.trim_end()).unwrap() {
            Command::Plan { goal, .. } => assert_eq!(goal, "framed goal"),
            _ => panic!("Expected Plan command"),
        }
    }

    #[tokio::test]
    async fn test_smith_client_receive_events() {
        let (client_stream, _server_stream) = tokio::net::UnixStream::pair().unwrap();
        let client = SmithClient::from_connection(Connection::Unix(client_stream))
            .await
            .unwrap();

        let events = client.receive_events().await.unwrap();
        assert!(events.is_empty());
    }

    #[tokio::test]
    async fn test_smith_client_receive_event_without_event_source() {
        let (client_stream, _server_stream) = tokio::net::UnixStream::pair().unwrap();
        let client = SmithClient::from_connection(Connection::Unix(client_stream))
            .await
            .unwrap();

        // Nothing feeds the event channel, so it is closed from the start instead of hanging.
        let err = client.receive_event().await.unwrap_err();
        assert!(err.to_string().contains("Event channel closed"));
    }

    #[tokio::test]
    async fn test_smith_client_receive_single_event_with_closed_receiver() {
        let (client_stream, _server_stream) = tokio::net::UnixStream::pair().unwrap();
        let client = SmithClient::from_connection(Connection::Unix(client_stream))
            .await
            .unwrap();

        {
            let mut receiver_guard = client.event_receiver.write().await;
            *receiver_guard = None;
        }

        let result = client.receive_event().await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Event receiver not available"));
    }

    #[tokio::test]
    async fn test_command_variants() {
        let commands = vec![
            Command::Handshake {
                version: 1,
                capabilities: vec!["test".to_string()],
            },
            Command::Plan {
                request_id: Uuid::new_v4(),
                goal: "test goal".to_string(),
                context: HashMap::new(),
            },
            Command::ToolCall {
                request_id: Uuid::new_v4(),
                tool: "test_tool".to_string(),
                args: json!({"key": "value"}),
                timeout_ms: Some(1000),
            },
            Command::HookLoad {
                request_id: Uuid::new_v4(),
                hook_type: "js".to_string(),
                script: "console.log('test');".to_string(),
            },
            Command::ShellExec {
                request_id: Uuid::new_v4(),
                command: "echo test".to_string(),
                shell: Some("bash".to_string()),
                cwd: Some("/tmp".to_string()),
                env: HashMap::new(),
                timeout_ms: Some(5000),
            },
            Command::Shutdown,
        ];

        for command in commands {
            let json = serde_json::to_string(&command).unwrap();
            let deserialized: Command = serde_json::from_str(&json).unwrap();

            match (command, deserialized) {
                (
                    Command::Handshake { version: v1, .. },
                    Command::Handshake { version: v2, .. },
                ) => {
                    assert_eq!(v1, v2);
                }
                (Command::Plan { goal: g1, .. }, Command::Plan { goal: g2, .. }) => {
                    assert_eq!(g1, g2);
                }
                (Command::ToolCall { tool: t1, .. }, Command::ToolCall { tool: t2, .. }) => {
                    assert_eq!(t1, t2);
                }
                (
                    Command::HookLoad { hook_type: h1, .. },
                    Command::HookLoad { hook_type: h2, .. },
                ) => {
                    assert_eq!(h1, h2);
                }
                (
                    Command::ShellExec { command: c1, .. },
                    Command::ShellExec { command: c2, .. },
                ) => {
                    assert_eq!(c1, c2);
                }
                (Command::Shutdown, Command::Shutdown) => {}
                _ => panic!("Mismatched command variants after serialization"),
            }
        }
    }

    #[tokio::test]
    async fn test_connection_tcp_error_handling() {
        let result = SmithClient::connect_tcp("invalid.host:999999").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_connection_unix_error_handling() {
        // Unique path so a stray socket at a fixed location can't make this test flaky.
        let non_existent_path = format!("/tmp/smith_missing_{}.sock", Uuid::new_v4());
        let result = SmithClient::connect_unix(&non_existent_path).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_connection_error_contexts() {
        let result = SmithClient::connect_tcp("0.0.0.0:1").await;
        match result {
            Err(err) => {
                let err_msg = format!("{}", err);
                assert!(err_msg.contains("Failed to connect to TCP address"));
            }
            Ok(_) => panic!("Expected connection to fail"),
        }
    }

    #[tokio::test]
    async fn test_ipc_client_connect_error() {
        let non_existent_path = format!("/tmp/smith_missing_ipc_{}.sock", Uuid::new_v4());
        let result = IpcClient::connect(&non_existent_path).await;

        match result {
            Err(err) => {
                let err_msg = format!("{}", err);
                assert!(err_msg.contains("Failed to connect to Unix socket"));
            }
            Ok(_) => panic!("Expected connection to fail"),
        }
    }

    #[tokio::test]
    async fn test_smith_client_from_connection_task_spawning() {
        let (client_stream, _server_stream) = tokio::net::UnixStream::pair().unwrap();
        let client_result = SmithClient::from_connection(Connection::Unix(client_stream)).await;

        assert!(client_result.is_ok());
        let client = client_result.unwrap();

        let send_result = client.send_command(Command::Shutdown).await;
        assert!(send_result.is_ok());
    }

    #[tokio::test]
    async fn test_command_serialization_edge_cases() {
        let mut env = HashMap::new();
        env.insert("PATH".to_string(), "/usr/bin".to_string());

        let shell_exec = Command::ShellExec {
            request_id: Uuid::new_v4(),
            command: "ls -la".to_string(),
            shell: None,
            cwd: None,
            env,
            timeout_ms: None,
        };

        let json = serde_json::to_string(&shell_exec).unwrap();
        let deserialized: Command = serde_json::from_str(&json).unwrap();

        match deserialized {
            Command::ShellExec {
                shell,
                cwd,
                env,
                timeout_ms,
                ..
            } => {
                assert_eq!(shell, None);
                assert_eq!(cwd, None);
                assert_eq!(timeout_ms, None);
                assert_eq!(env.get("PATH").map(String::as_str), Some("/usr/bin"));
            }
            _ => panic!("Expected ShellExec command"),
        }
    }

    #[tokio::test]
    async fn test_concurrent_command_sending() {
        let (client_stream, _server_stream) = tokio::net::UnixStream::pair().unwrap();
        let client = Arc::new(
            SmithClient::from_connection(Connection::Unix(client_stream))
                .await
                .unwrap(),
        );

        let mut handles = vec![];

        for i in 0..10 {
            let client_clone = Arc::clone(&client);
            let handle = tokio::spawn(async move {
                let command = Command::Plan {
                    request_id: Uuid::new_v4(),
                    goal: format!("test goal {}", i),
                    context: HashMap::new(),
                };
                client_clone.send_command(command).await
            });
            handles.push(handle);
        }

        for handle in handles {
            assert!(handle.await.unwrap().is_ok());
        }
    }
}