tsk/context/
tsk_client.rs1use crate::server::protocol::{Request, Response};
2use crate::storage::XdgDirectories;
3use crate::task::{Task, TaskStatus};
4use async_trait::async_trait;
5use std::path::PathBuf;
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
9use tokio::net::UnixStream;
10use tokio::time::timeout;
11
12#[async_trait]
14pub trait TskClient: Send + Sync {
15 async fn is_server_available(&self) -> bool;
17
18 async fn add_task(
20 &self,
21 repo_path: PathBuf,
22 task: Task,
23 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
24
25 async fn list_tasks(&self) -> Result<Vec<Task>, Box<dyn std::error::Error + Send + Sync>>;
27
28 #[allow(dead_code)]
30 async fn get_task_status(
31 &self,
32 task_id: String,
33 ) -> Result<TaskStatus, Box<dyn std::error::Error + Send + Sync>>;
34
35 async fn shutdown_server(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;
37}
38
/// Default client that talks to the TSK server over a Unix domain socket.
#[derive(Clone, Debug)]
pub struct DefaultTskClient {
    /// Filesystem path of the server's Unix socket.
    socket_path: PathBuf,
}
44
45impl DefaultTskClient {
46 pub fn new(xdg_directories: Arc<XdgDirectories>) -> Self {
48 Self {
49 socket_path: xdg_directories.socket_path(),
50 }
51 }
52
53 async fn send_request(
55 &self,
56 request: Request,
57 ) -> Result<Response, Box<dyn std::error::Error + Send + Sync>> {
58 let stream = timeout(
60 Duration::from_secs(5),
61 UnixStream::connect(&self.socket_path),
62 )
63 .await
64 .map_err(|_| "Connection timeout")?
65 .map_err(|e| format!("Failed to connect to server: {e}"))?;
66
67 let (reader, mut writer) = stream.into_split();
68 let mut reader = BufReader::new(reader);
69
70 let request_json = serde_json::to_string(&request)?;
72 writer.write_all(request_json.as_bytes()).await?;
73 writer.write_all(b"\n").await?;
74 writer.flush().await?;
75
76 let mut response_line = String::new();
78 let bytes_read = timeout(
79 Duration::from_secs(10),
80 reader.read_line(&mut response_line),
81 )
82 .await
83 .map_err(|_| "Response timeout")?
84 .map_err(|e| format!("Failed to read response: {e}"))?;
85
86 if bytes_read == 0 || response_line.trim().is_empty() {
88 return Err("Server closed connection without sending a response".into());
89 }
90
91 let response: Response = serde_json::from_str(&response_line)?;
93 Ok(response)
94 }
95}
96
97#[async_trait]
98impl TskClient for DefaultTskClient {
99 async fn is_server_available(&self) -> bool {
100 matches!(
101 timeout(
102 Duration::from_secs(1),
103 UnixStream::connect(&self.socket_path),
104 )
105 .await,
106 Ok(Ok(_))
107 )
108 }
109
110 async fn add_task(
111 &self,
112 repo_path: PathBuf,
113 task: Task,
114 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
115 let request = Request::AddTask {
116 repo_path,
117 task: Box::new(task),
118 };
119 let response = self.send_request(request).await?;
120
121 match response {
122 Response::Success { message } => {
123 println!("{message}");
124 Ok(())
125 }
126 Response::Error { message } => Err(message.into()),
127 _ => Err("Unexpected response from server".into()),
128 }
129 }
130
131 async fn list_tasks(&self) -> Result<Vec<Task>, Box<dyn std::error::Error + Send + Sync>> {
132 let request = Request::ListTasks;
133 let response = self.send_request(request).await?;
134
135 match response {
136 Response::TaskList { tasks } => Ok(tasks),
137 Response::Error { message } => Err(message.into()),
138 _ => Err("Unexpected response from server".into()),
139 }
140 }
141
142 async fn get_task_status(
143 &self,
144 task_id: String,
145 ) -> Result<TaskStatus, Box<dyn std::error::Error + Send + Sync>> {
146 let request = Request::GetStatus { task_id };
147 let response = self.send_request(request).await?;
148
149 match response {
150 Response::TaskStatus { status } => Ok(status),
151 Response::Error { message } => Err(message.into()),
152 _ => Err("Unexpected response from server".into()),
153 }
154 }
155
156 async fn shutdown_server(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
157 let request = Request::Shutdown;
158 let response = self.send_request(request).await?;
159
160 match response {
161 Response::Success { message } => {
162 println!("{message}");
163 Ok(())
164 }
165 Response::Error { message } => Err(message.into()),
166 _ => Err("Unexpected response from server".into()),
167 }
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;
    use tokio::net::UnixListener;

    /// Points the XDG variables read by `XdgDirectories::new` into `root`.
    fn point_xdg_at(root: &std::path::Path) {
        // SAFETY: mutating the process environment is only sound while no other thread
        // reads or writes it concurrently. In this module only `test_client_creation`
        // calls this.
        // NOTE(review): other tests in the crate may set the same variables in
        // parallel — confirm they are serialized.
        unsafe {
            std::env::set_var("XDG_DATA_HOME", root.join("data"));
            std::env::set_var("XDG_RUNTIME_DIR", root.join("runtime"));
        }
    }

    /// Builds a client bound directly to `socket_path`, bypassing the environment.
    fn client_at(socket_path: std::path::PathBuf) -> DefaultTskClient {
        DefaultTskClient { socket_path }
    }

    #[tokio::test]
    async fn test_client_creation() {
        let temp_dir = TempDir::new().unwrap();
        point_xdg_at(temp_dir.path());

        let xdg = Arc::new(XdgDirectories::new().unwrap());
        xdg.ensure_directories().unwrap();

        let client = DefaultTskClient::new(xdg);

        assert!(!client.is_server_available().await);
    }

    #[tokio::test]
    async fn test_server_available_when_socket_listening() {
        let temp_dir = TempDir::new().unwrap();
        let socket_path = temp_dir.path().join("tsk.sock");
        // Keep the listener alive for the duration of the probe.
        let _listener = UnixListener::bind(&socket_path).unwrap();

        assert!(client_at(socket_path).is_server_available().await);
    }

    #[tokio::test]
    async fn test_list_tasks_fails_without_server() {
        let temp_dir = TempDir::new().unwrap();
        let client = client_at(temp_dir.path().join("missing.sock"));

        let result = client.list_tasks().await;
        assert!(result.is_err());

        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("Failed to connect to server")
                || error_msg.contains("Connection refused"),
            "Expected connection error, got: {error_msg}"
        );
    }

    #[tokio::test]
    async fn test_response_parsing_validates_empty_responses() {
        let temp_dir = TempDir::new().unwrap();
        let socket_path = temp_dir.path().join("tsk.sock");
        let listener = UnixListener::bind(&socket_path).unwrap();

        // Fake server: read exactly one request line, then hang up without replying.
        // Reading first ensures the client's write cannot fail with a broken pipe.
        let server = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut reader = BufReader::new(stream);
            let mut request = String::new();
            reader.read_line(&mut request).await.unwrap();
        });

        let error_msg = client_at(socket_path)
            .list_tasks()
            .await
            .unwrap_err()
            .to_string();
        assert!(
            error_msg.contains("Server closed connection without sending a response"),
            "Expected empty-response error, got: {error_msg}"
        );

        server.await.unwrap();
    }
}