Skip to main content

smith_protocol/
client.rs

1//! IPC client for connecting to Smith service
2
3use crate::{Command, Event};
4use anyhow::{Context, Result};
5use serde_json;
6use std::path::Path;
7use std::sync::Arc;
8use tokio::io::AsyncWriteExt;
9use tokio::net::{TcpStream, UnixStream};
10use tokio::sync::{mpsc, Mutex, RwLock};
11use tracing::{debug, error, info};
12
13/// Connection type for Smith service
14pub enum Connection {
15    Unix(UnixStream),
16    Tcp(TcpStream),
17}
18
19impl Connection {
20    async fn write_all(&mut self, buf: &[u8]) -> Result<()> {
21        match self {
22            Connection::Unix(stream) => stream.write_all(buf).await.context("Unix write failed"),
23            Connection::Tcp(stream) => stream.write_all(buf).await.context("TCP write failed"),
24        }
25    }
26
27    async fn flush(&mut self) -> Result<()> {
28        match self {
29            Connection::Unix(stream) => stream.flush().await.context("Unix flush failed"),
30            Connection::Tcp(stream) => stream.flush().await.context("TCP flush failed"),
31        }
32    }
33}
34
35/// IPC client for connecting to the Smith service (legacy)
36pub struct IpcClient {
37    connection: Arc<Mutex<Connection>>,
38}
39
40/// Smith service client with both TCP and Unix socket support
41pub struct SmithClient {
42    command_sender: mpsc::Sender<Command>,
43    event_receiver: Arc<RwLock<Option<mpsc::Receiver<Event>>>>,
44}
45
46impl IpcClient {
47    /// Connect to IPC server at socket path
48    pub async fn connect<P: AsRef<Path>>(socket_path: P) -> Result<Self> {
49        let stream = UnixStream::connect(&socket_path)
50            .await
51            .context("Failed to connect to Unix socket")?;
52
53        info!("Connected to IPC server at {:?}", socket_path.as_ref());
54
55        Ok(Self {
56            connection: Arc::new(Mutex::new(Connection::Unix(stream))),
57        })
58    }
59
60    /// Send command to server
61    pub async fn send_command(&mut self, command: &Command) -> Result<()> {
62        let json = serde_json::to_string(command).context("Failed to serialize command")?;
63
64        let mut connection = self.connection.lock().await;
65        connection
66            .write_all(json.as_bytes())
67            .await
68            .context("Failed to write command")?;
69        connection
70            .write_all(b"\n")
71            .await
72            .context("Failed to write newline")?;
73        connection
74            .flush()
75            .await
76            .context("Failed to flush connection")?;
77
78        debug!("Sent command: {}", json);
79        Ok(())
80    }
81
82    /// Process events from server
83    pub async fn process_events<F, Fut>(self, mut _handler: F) -> Result<()>
84    where
85        F: FnMut(Event) -> Fut,
86        Fut: std::future::Future<Output = Result<()>>,
87    {
88        // This is a simplified implementation - would need proper stream splitting in production
89        info!("IPC connection processing started");
90        Ok(())
91    }
92}
93
94impl SmithClient {
95    /// Connect to Smith service via TCP
96    pub async fn connect_tcp(address: &str) -> Result<Self> {
97        let stream = TcpStream::connect(address)
98            .await
99            .context("Failed to connect to TCP address")?;
100
101        info!("Connected to Smith service at {}", address);
102        Self::from_connection(Connection::Tcp(stream)).await
103    }
104
105    /// Connect to Smith service via Unix socket
106    pub async fn connect_unix<P: AsRef<Path>>(socket_path: P) -> Result<Self> {
107        let stream = UnixStream::connect(&socket_path)
108            .await
109            .context("Failed to connect to Unix socket")?;
110
111        info!("Connected to Smith service at {:?}", socket_path.as_ref());
112        Self::from_connection(Connection::Unix(stream)).await
113    }
114
115    /// Create client from existing connection
116    async fn from_connection(connection: Connection) -> Result<Self> {
117        let (command_tx, mut command_rx) = mpsc::channel::<Command>(1000);
118        let (_event_tx, event_rx) = mpsc::channel::<Event>(10000);
119
120        // Start command sending task
121        let connection = Arc::new(Mutex::new(connection));
122        let connection_for_commands = Arc::clone(&connection);
123
124        tokio::spawn(async move {
125            while let Some(command) = command_rx.recv().await {
126                let json = match serde_json::to_string(&command) {
127                    Ok(json) => json,
128                    Err(err) => {
129                        error!("Failed to serialize command: {}", err);
130                        continue;
131                    }
132                };
133
134                let mut conn = connection_for_commands.lock().await;
135                if let Err(err) = conn.write_all(json.as_bytes()).await {
136                    error!("Failed to send command: {}", err);
137                    break;
138                }
139                if let Err(err) = conn.write_all(b"\n").await {
140                    error!("Failed to send newline: {}", err);
141                    break;
142                }
143                if let Err(err) = conn.flush().await {
144                    error!("Failed to flush connection: {}", err);
145                    break;
146                }
147
148                debug!("Sent command: {}", json);
149            }
150        });
151
152        // Start event receiving task - simplified for now
153        // In a full implementation, you'd split the connection and read from it
154        tokio::spawn(async move {
155            // This is where event reading would happen
156            // For now, we'll just keep the channel alive
157            loop {
158                tokio::time::sleep(std::time::Duration::from_secs(1)).await;
159                // In real implementation, read events from connection and send to event_tx
160            }
161        });
162
163        Ok(Self {
164            command_sender: command_tx,
165            event_receiver: Arc::new(RwLock::new(Some(event_rx))),
166        })
167    }
168
169    /// Send command to Smith service
170    pub async fn send_command(&self, command: Command) -> Result<()> {
171        self.command_sender
172            .send(command)
173            .await
174            .context("Failed to send command - connection may be closed")?;
175        Ok(())
176    }
177
178    /// Receive events from Smith service
179    pub async fn receive_events(&self) -> Result<Vec<Event>> {
180        let mut events = Vec::new();
181
182        // Try to get events without blocking
183        if let Some(ref mut receiver) = self.event_receiver.write().await.as_mut() {
184            while let Ok(event) = receiver.try_recv() {
185                events.push(event);
186            }
187        }
188
189        Ok(events)
190    }
191
192    /// Get single event (blocking)
193    pub async fn receive_event(&self) -> Result<Event> {
194        if let Some(ref mut receiver) = self.event_receiver.write().await.as_mut() {
195            receiver
196                .recv()
197                .await
198                .ok_or_else(|| anyhow::anyhow!("Event channel closed"))
199        } else {
200            Err(anyhow::anyhow!("Event receiver not available"))
201        }
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::collections::HashMap;
    use std::time::Duration;
    use tokio::io::{AsyncBufReadExt, BufReader};
    use tokio::net::TcpListener;
    use uuid::Uuid;

    /// Build a unique socket path under /tmp so concurrent test runs cannot collide.
    fn unique_socket_path(prefix: &str) -> String {
        format!("/tmp/{}_{}.sock", prefix, Uuid::new_v4())
    }

    #[tokio::test]
    async fn test_connection_write_and_flush() {
        // Unix connection write operations.
        let (client_stream, _server_stream) = tokio::net::UnixStream::pair().unwrap();
        let mut connection = Connection::Unix(client_stream);

        let test_data = b"test data";
        assert!(connection.write_all(test_data).await.is_ok());
        assert!(connection.flush().await.is_ok());
    }

    #[tokio::test]
    async fn test_tcp_connection_write_and_flush() {
        // Local TCP listener for the client to connect to.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let tcp_stream = TcpStream::connect(addr).await.unwrap();
        let mut connection = Connection::Tcp(tcp_stream);

        let test_data = b"tcp test data";
        assert!(connection.write_all(test_data).await.is_ok());
        assert!(connection.flush().await.is_ok());
    }

    #[tokio::test]
    async fn test_ipc_client_connection() {
        let socket_path = unique_socket_path("smith_test");
        let listener = tokio::net::UnixListener::bind(&socket_path).unwrap();

        // Accept the connection in the background.
        tokio::spawn(async move {
            let _stream = listener.accept().await;
        });

        let client = IpcClient::connect(&socket_path).await;
        // Clean up before asserting so a failing assert does not leak the socket file.
        std::fs::remove_file(&socket_path).ok();
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_ipc_client_send_command() {
        let (client_stream, _server_stream) = tokio::net::UnixStream::pair().unwrap();
        let mut client = IpcClient {
            connection: Arc::new(Mutex::new(Connection::Unix(client_stream))),
        };

        let command = Command::Handshake {
            version: 1,
            capabilities: vec!["shell_exec".to_string(), "nats".to_string()],
        };

        let result = client.send_command(&command).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_ipc_client_process_events() {
        let (client_stream, _server_stream) = tokio::net::UnixStream::pair().unwrap();
        let client = IpcClient {
            connection: Arc::new(Mutex::new(Connection::Unix(client_stream))),
        };

        let handler = |_event: Event| async { Ok(()) };

        let result = client.process_events(handler).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_smith_client_tcp_connection() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            let _accepted = listener.accept().await;
        });

        let client = SmithClient::connect_tcp(&addr.to_string()).await;
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_smith_client_unix_connection() {
        let socket_path = unique_socket_path("smith_test_unix");
        let listener = tokio::net::UnixListener::bind(&socket_path).unwrap();

        tokio::spawn(async move {
            let _accepted = listener.accept().await;
        });

        let client = SmithClient::connect_unix(&socket_path).await;
        // Clean up before asserting so a failing assert does not leak the socket file.
        std::fs::remove_file(&socket_path).ok();
        assert!(client.is_ok());
    }

    #[tokio::test]
    async fn test_smith_client_send_command() {
        let (client_stream, server_stream) = tokio::net::UnixStream::pair().unwrap();
        let client = SmithClient::from_connection(Connection::Unix(client_stream))
            .await
            .unwrap();

        let command = Command::ToolCall {
            request_id: Uuid::new_v4(),
            tool: "test_tool".to_string(),
            args: json!({"param": "value"}),
            timeout_ms: Some(5000),
        };

        let result = client.send_command(command).await;
        assert!(result.is_ok());

        // The background writer must put one newline-terminated JSON frame on the wire.
        let mut lines = BufReader::new(server_stream).lines();
        let line = tokio::time::timeout(Duration::from_secs(5), lines.next_line())
            .await
            .expect("timed out waiting for the command frame")
            .unwrap()
            .expect("connection closed before a frame arrived");
        let decoded: Command = serde_json::from_str(&line).unwrap();
        match decoded {
            Command::ToolCall {
                tool, timeout_ms, ..
            } => {
                assert_eq!(tool, "test_tool");
                assert_eq!(timeout_ms, Some(5000));
            }
            _ => panic!("Expected ToolCall command"),
        }
    }

    #[tokio::test]
    async fn test_smith_client_receive_events() {
        let (client_stream, _server_stream) = tokio::net::UnixStream::pair().unwrap();
        let client = SmithClient::from_connection(Connection::Unix(client_stream))
            .await
            .unwrap();

        // No events are ever produced, so the drain must come back empty.
        let events = client.receive_events().await.unwrap();
        assert!(events.is_empty());
    }

    #[tokio::test]
    async fn test_smith_client_receive_single_event_with_closed_receiver() {
        let (client_stream, _server_stream) = tokio::net::UnixStream::pair().unwrap();
        let client = SmithClient::from_connection(Connection::Unix(client_stream))
            .await
            .unwrap();

        // Remove the receiver to exercise the "not available" error path.
        {
            let mut receiver_guard = client.event_receiver.write().await;
            *receiver_guard = None;
        }

        let result = client.receive_event().await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Event receiver not available"));
    }

    #[tokio::test]
    async fn test_command_variants() {
        // Every command variant must survive a JSON round trip.
        let commands = vec![
            Command::Handshake {
                version: 1,
                capabilities: vec!["test".to_string()],
            },
            Command::Plan {
                request_id: Uuid::new_v4(),
                goal: "test goal".to_string(),
                context: HashMap::new(),
            },
            Command::ToolCall {
                request_id: Uuid::new_v4(),
                tool: "test_tool".to_string(),
                args: json!({"key": "value"}),
                timeout_ms: Some(1000),
            },
            Command::HookLoad {
                request_id: Uuid::new_v4(),
                hook_type: "js".to_string(),
                script: "console.log('test');".to_string(),
            },
            Command::ShellExec {
                request_id: Uuid::new_v4(),
                command: "echo test".to_string(),
                shell: Some("bash".to_string()),
                cwd: Some("/tmp".to_string()),
                env: HashMap::new(),
                timeout_ms: Some(5000),
            },
            Command::Shutdown,
        ];

        for command in commands {
            let json = serde_json::to_string(&command).unwrap();
            let deserialized: Command = serde_json::from_str(&json).unwrap();

            match (command, deserialized) {
                (
                    Command::Handshake { version: v1, .. },
                    Command::Handshake { version: v2, .. },
                ) => {
                    assert_eq!(v1, v2);
                }
                (Command::Plan { goal: g1, .. }, Command::Plan { goal: g2, .. }) => {
                    assert_eq!(g1, g2);
                }
                (Command::ToolCall { tool: t1, .. }, Command::ToolCall { tool: t2, .. }) => {
                    assert_eq!(t1, t2);
                }
                (
                    Command::HookLoad { hook_type: h1, .. },
                    Command::HookLoad { hook_type: h2, .. },
                ) => {
                    assert_eq!(h1, h2);
                }
                (
                    Command::ShellExec { command: c1, .. },
                    Command::ShellExec { command: c2, .. },
                ) => {
                    assert_eq!(c1, c2);
                }
                (Command::Shutdown, Command::Shutdown) => {}
                _ => panic!("Mismatched command variants after serialization"),
            }
        }
    }

    #[tokio::test]
    async fn test_connection_tcp_error_handling() {
        // Port is out of range, so the address cannot even be parsed.
        let result = SmithClient::connect_tcp("invalid.host:999999").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_connection_unix_error_handling() {
        // A fresh UUID path is guaranteed not to exist, unlike a fixed /tmp name.
        let non_existent_path = unique_socket_path("smith_missing");
        let result = SmithClient::connect_unix(&non_existent_path).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_connection_error_contexts() {
        // Reserve an ephemeral port and release it so nothing is listening there.
        // This is far less flaky than assuming port 1 is closed.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        drop(listener);

        match SmithClient::connect_tcp(&addr.to_string()).await {
            Err(err) => {
                let err_msg = format!("{}", err);
                assert!(err_msg.contains("Failed to connect to TCP address"));
            }
            Ok(_) => panic!("Expected connection to fail"),
        }
    }

    #[tokio::test]
    async fn test_ipc_client_connect_error() {
        let non_existent_path = unique_socket_path("smith_missing_ipc");

        match IpcClient::connect(&non_existent_path).await {
            Err(err) => {
                let err_msg = format!("{}", err);
                assert!(err_msg.contains("Failed to connect to Unix socket"));
            }
            Ok(_) => panic!("Expected connection to fail"),
        }
    }

    #[tokio::test]
    async fn test_smith_client_from_connection_task_spawning() {
        let (client_stream, _server_stream) = tokio::net::UnixStream::pair().unwrap();
        let client_result = SmithClient::from_connection(Connection::Unix(client_stream)).await;

        assert!(client_result.is_ok());
        let client = client_result.unwrap();

        // A successful send means the writer task is alive and draining the queue.
        let command = Command::Shutdown;
        let send_result = client.send_command(command).await;
        assert!(send_result.is_ok());
    }

    #[tokio::test]
    async fn test_command_serialization_edge_cases() {
        let mut context = HashMap::new();
        context.insert("key1".to_string(), "value1".to_string());
        context.insert("key2".to_string(), "value2".to_string());

        let mut env = HashMap::new();
        env.insert("PATH".to_string(), "/usr/bin".to_string());

        let shell_exec = Command::ShellExec {
            request_id: Uuid::new_v4(),
            command: "ls -la".to_string(),
            shell: None,
            cwd: None,
            env,
            timeout_ms: None,
        };

        let json = serde_json::to_string(&shell_exec).unwrap();
        let deserialized: Command = serde_json::from_str(&json).unwrap();

        match deserialized {
            Command::ShellExec {
                shell,
                cwd,
                timeout_ms,
                ..
            } => {
                assert_eq!(shell, None);
                assert_eq!(cwd, None);
                assert_eq!(timeout_ms, None);
            }
            _ => panic!("Expected ShellExec command"),
        }
    }

    #[tokio::test]
    async fn test_concurrent_command_sending() {
        let (client_stream, _server_stream) = tokio::net::UnixStream::pair().unwrap();
        let client = Arc::new(
            SmithClient::from_connection(Connection::Unix(client_stream))
                .await
                .unwrap(),
        );

        let mut handles = vec![];

        for i in 0..10 {
            let client_clone = Arc::clone(&client);
            let handle = tokio::spawn(async move {
                let command = Command::Plan {
                    request_id: Uuid::new_v4(),
                    goal: format!("test goal {}", i),
                    context: HashMap::new(),
                };
                client_clone.send_command(command).await
            });
            handles.push(handle);
        }

        for handle in handles {
            assert!(handle.await.unwrap().is_ok());
        }
    }
}