tsk/context/
tsk_client.rs

1use crate::server::protocol::{Request, Response};
2use crate::storage::XdgDirectories;
3use crate::task::{Task, TaskStatus};
4use async_trait::async_trait;
5use std::path::PathBuf;
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
9use tokio::net::UnixStream;
10use tokio::time::timeout;
11
12/// Trait for communicating with the TSK server
13#[async_trait]
14pub trait TskClient: Send + Sync {
15    /// Check if the server is available
16    async fn is_server_available(&self) -> bool;
17
18    /// Add a task to the server
19    async fn add_task(
20        &self,
21        repo_path: PathBuf,
22        task: Task,
23    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
24
25    /// List all tasks from the server
26    async fn list_tasks(&self) -> Result<Vec<Task>, Box<dyn std::error::Error + Send + Sync>>;
27
28    /// Get the status of a specific task
29    #[allow(dead_code)]
30    async fn get_task_status(
31        &self,
32        task_id: String,
33    ) -> Result<TaskStatus, Box<dyn std::error::Error + Send + Sync>>;
34
35    /// Shutdown the server
36    async fn shutdown_server(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
37}
38
/// [`TskClient`] implementation that talks to the TSK server over a Unix
/// domain socket, exchanging one newline-terminated JSON message per request.
#[derive(Clone, Debug)]
pub struct DefaultTskClient {
    /// Filesystem path of the server's Unix socket, resolved once at construction.
    socket_path: PathBuf,
}

45impl DefaultTskClient {
46    /// Create a new TSK client
47    pub fn new(xdg_directories: Arc<XdgDirectories>) -> Self {
48        Self {
49            socket_path: xdg_directories.socket_path(),
50        }
51    }
52
53    /// Send a request to the server and get a response
54    async fn send_request(
55        &self,
56        request: Request,
57    ) -> Result<Response, Box<dyn std::error::Error + Send + Sync>> {
58        // Connect to server with timeout
59        let stream = timeout(
60            Duration::from_secs(5),
61            UnixStream::connect(&self.socket_path),
62        )
63        .await
64        .map_err(|_| "Connection timeout")?
65        .map_err(|e| format!("Failed to connect to server: {e}"))?;
66
67        let (reader, mut writer) = stream.into_split();
68        let mut reader = BufReader::new(reader);
69
70        // Send request
71        let request_json = serde_json::to_string(&request)?;
72        writer.write_all(request_json.as_bytes()).await?;
73        writer.write_all(b"\n").await?;
74        writer.flush().await?;
75
76        // Read response with timeout
77        let mut response_line = String::new();
78        let bytes_read = timeout(
79            Duration::from_secs(10),
80            reader.read_line(&mut response_line),
81        )
82        .await
83        .map_err(|_| "Response timeout")?
84        .map_err(|e| format!("Failed to read response: {e}"))?;
85
86        // Check if we received any data
87        if bytes_read == 0 || response_line.trim().is_empty() {
88            return Err("Server closed connection without sending a response".into());
89        }
90
91        // Parse response
92        let response: Response = serde_json::from_str(&response_line)?;
93        Ok(response)
94    }
95}
96
97#[async_trait]
98impl TskClient for DefaultTskClient {
99    async fn is_server_available(&self) -> bool {
100        matches!(
101            timeout(
102                Duration::from_secs(1),
103                UnixStream::connect(&self.socket_path),
104            )
105            .await,
106            Ok(Ok(_))
107        )
108    }
109
110    async fn add_task(
111        &self,
112        repo_path: PathBuf,
113        task: Task,
114    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
115        let request = Request::AddTask {
116            repo_path,
117            task: Box::new(task),
118        };
119        let response = self.send_request(request).await?;
120
121        match response {
122            Response::Success { message } => {
123                println!("{message}");
124                Ok(())
125            }
126            Response::Error { message } => Err(message.into()),
127            _ => Err("Unexpected response from server".into()),
128        }
129    }
130
131    async fn list_tasks(&self) -> Result<Vec<Task>, Box<dyn std::error::Error + Send + Sync>> {
132        let request = Request::ListTasks;
133        let response = self.send_request(request).await?;
134
135        match response {
136            Response::TaskList { tasks } => Ok(tasks),
137            Response::Error { message } => Err(message.into()),
138            _ => Err("Unexpected response from server".into()),
139        }
140    }
141
142    async fn get_task_status(
143        &self,
144        task_id: String,
145    ) -> Result<TaskStatus, Box<dyn std::error::Error + Send + Sync>> {
146        let request = Request::GetStatus { task_id };
147        let response = self.send_request(request).await?;
148
149        match response {
150            Response::TaskStatus { status } => Ok(status),
151            Response::Error { message } => Err(message.into()),
152            _ => Err("Unexpected response from server".into()),
153        }
154    }
155
156    async fn shutdown_server(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
157        let request = Request::Shutdown;
158        let response = self.send_request(request).await?;
159
160        match response {
161            Response::Success { message } => {
162                println!("{message}");
163                Ok(())
164            }
165            Response::Error { message } => Err(message.into()),
166            _ => Err("Unexpected response from server".into()),
167        }
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, PoisonError};
    use tempfile::TempDir;

    /// Serializes the process-wide environment mutation done by `client_for`
    /// across the tests in this module, which the harness runs in parallel.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Builds a client whose XDG data/runtime directories live under `temp_dir`.
    ///
    /// The variables are set, and `XdgDirectories` plus the client (whose `new`
    /// resolves the socket path eagerly) are built, all while `ENV_LOCK` is held.
    /// As a result, one test cannot pick up another test's temp directories.
    /// NOTE(review): assumes `XdgDirectories::new` / `ensure_directories` read
    /// these variables, as the original tests relied on.
    fn client_for(temp_dir: &TempDir) -> DefaultTskClient {
        // A panicking test poisons the lock; the guarded data is `()`, so there
        // is no broken state behind the poison and it is safe to ignore.
        let _env_guard = ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner);

        // SAFETY: `set_var` is only unsound when another thread reads or writes
        // the environment concurrently. Tests in this module serialize through
        // `ENV_LOCK`. NOTE(review): tests elsewhere in the crate that touch the
        // environment are not covered by this lock.
        unsafe {
            std::env::set_var("XDG_DATA_HOME", temp_dir.path().join("data"));
            std::env::set_var("XDG_RUNTIME_DIR", temp_dir.path().join("runtime"));
        }

        let xdg = Arc::new(XdgDirectories::new().unwrap());
        xdg.ensure_directories().unwrap();
        DefaultTskClient::new(xdg)
    }

    #[tokio::test]
    async fn test_client_creation() {
        let temp_dir = TempDir::new().unwrap();
        let client = client_for(&temp_dir);

        // No server has been started, so the availability probe must fail.
        assert!(!client.is_server_available().await);
    }

    #[tokio::test]
    async fn test_response_parsing_validates_empty_responses() {
        // Documents that `send_request` reports a missing server as a connection
        // error rather than surfacing a JSON parse error. The EOF path (server
        // closing without replying, the original bug) is only covered implicitly.
        let temp_dir = TempDir::new().unwrap();
        let client = client_for(&temp_dir);

        // Listing tasks with no server running must fail gracefully.
        let result = client.list_tasks().await;
        assert!(result.is_err());

        // The failure must come from connecting, not from JSON parsing.
        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("Failed to connect to server")
                || error_msg.contains("Connection refused"),
            "Expected connection error, got: {error_msg}"
        );
    }
}