1mod error;
49mod protocol;
50
51pub use error::Error;
52pub use protocol::{Request, Response};
53pub use secrecy::{ExposeSecret, SecretString};
54
55use std::path::PathBuf;
56use tokio::io::{AsyncReadExt, AsyncWriteExt};
57use tokio::net::UnixStream;
58use tracing::{debug, warn};
59
60pub fn default_socket_path() -> PathBuf {
62 dirs::home_dir()
63 .map(|h| h.join(".spn").join("daemon.sock"))
64 .unwrap_or_else(|| PathBuf::from("/tmp/spn-daemon.sock"))
65}
66
67pub fn daemon_socket_exists() -> bool {
69 default_socket_path().exists()
70}
71
72#[derive(Debug)]
77pub struct SpnClient {
78 stream: Option<UnixStream>,
79 fallback_mode: bool,
80}
81
82impl SpnClient {
83 pub async fn connect() -> Result<Self, Error> {
87 Self::connect_to(&default_socket_path()).await
88 }
89
90 pub async fn connect_to(socket_path: &PathBuf) -> Result<Self, Error> {
92 debug!("Connecting to spn daemon at {:?}", socket_path);
93
94 let stream = UnixStream::connect(socket_path)
95 .await
96 .map_err(|e| Error::ConnectionFailed {
97 path: socket_path.clone(),
98 source: e,
99 })?;
100
101 let mut client = Self {
103 stream: Some(stream),
104 fallback_mode: false,
105 };
106
107 client.ping().await?;
108 debug!("Connected to spn daemon");
109
110 Ok(client)
111 }
112
113 pub async fn connect_with_fallback() -> Result<Self, Error> {
118 match Self::connect().await {
119 Ok(client) => Ok(client),
120 Err(e) => {
121 warn!("spn daemon not running, using env var fallback: {}", e);
122 Ok(Self {
123 stream: None,
124 fallback_mode: true,
125 })
126 }
127 }
128 }
129
130 pub fn is_fallback_mode(&self) -> bool {
132 self.fallback_mode
133 }
134
135 pub async fn ping(&mut self) -> Result<String, Error> {
137 let response = self.send_request(Request::Ping).await?;
138 match response {
139 Response::Pong { version } => Ok(version),
140 Response::Error { message } => Err(Error::DaemonError(message)),
141 _ => Err(Error::UnexpectedResponse),
142 }
143 }
144
145 pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
150 if self.fallback_mode {
151 return self.get_secret_from_env(provider);
152 }
153
154 let response = self
155 .send_request(Request::GetSecret {
156 provider: provider.to_string(),
157 })
158 .await?;
159
160 match response {
161 Response::Secret { value } => Ok(SecretString::from(value)),
162 Response::Error { message } => Err(Error::SecretNotFound {
163 provider: provider.to_string(),
164 details: message,
165 }),
166 _ => Err(Error::UnexpectedResponse),
167 }
168 }
169
170 pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
172 if self.fallback_mode {
173 return Ok(self.get_secret_from_env(provider).is_ok());
174 }
175
176 let response = self
177 .send_request(Request::HasSecret {
178 provider: provider.to_string(),
179 })
180 .await?;
181
182 match response {
183 Response::Exists { exists } => Ok(exists),
184 Response::Error { message } => Err(Error::DaemonError(message)),
185 _ => Err(Error::UnexpectedResponse),
186 }
187 }
188
189 pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
191 if self.fallback_mode {
192 return Ok(self.list_env_providers());
193 }
194
195 let response = self.send_request(Request::ListProviders).await?;
196
197 match response {
198 Response::Providers { providers } => Ok(providers),
199 Response::Error { message } => Err(Error::DaemonError(message)),
200 _ => Err(Error::UnexpectedResponse),
201 }
202 }
203
204 async fn send_request(&mut self, request: Request) -> Result<Response, Error> {
206 let stream = self
207 .stream
208 .as_mut()
209 .ok_or(Error::NotConnected)?;
210
211 let request_json = serde_json::to_vec(&request).map_err(Error::SerializationError)?;
213
214 let len = request_json.len() as u32;
216 stream
217 .write_all(&len.to_be_bytes())
218 .await
219 .map_err(Error::IoError)?;
220 stream
221 .write_all(&request_json)
222 .await
223 .map_err(Error::IoError)?;
224
225 let mut len_buf = [0u8; 4];
227 stream
228 .read_exact(&mut len_buf)
229 .await
230 .map_err(Error::IoError)?;
231 let response_len = u32::from_be_bytes(len_buf) as usize;
232
233 if response_len > 1_048_576 {
235 return Err(Error::ResponseTooLarge(response_len));
236 }
237
238 let mut response_buf = vec![0u8; response_len];
240 stream
241 .read_exact(&mut response_buf)
242 .await
243 .map_err(Error::IoError)?;
244
245 let response: Response =
247 serde_json::from_slice(&response_buf).map_err(Error::DeserializationError)?;
248
249 Ok(response)
250 }
251
252 fn get_secret_from_env(&self, provider: &str) -> Result<SecretString, Error> {
255 let env_var = provider_to_env_var(provider);
256 std::env::var(&env_var)
257 .map(SecretString::from)
258 .map_err(|_| Error::SecretNotFound {
259 provider: provider.to_string(),
260 details: format!("Environment variable {} not set", env_var),
261 })
262 }
263
264 fn list_env_providers(&self) -> Vec<String> {
265 KNOWN_PROVIDERS
266 .iter()
267 .filter(|p| std::env::var(provider_to_env_var(p)).is_ok())
268 .map(|s| s.to_string())
269 .collect()
270 }
271}
272
/// Provider names probed by the environment-variable fallback.
///
/// The order here is the order in which fallback-mode provider listings are returned.
const KNOWN_PROVIDERS: &[&str] = &[
    "anthropic", "openai", "mistral", "groq", "deepseek",
    "gemini", "ollama", "neo4j", "github", "slack",
    "perplexity", "firecrawl", "supadata",
];
289
/// Maps a provider name to the environment variable consulted in fallback mode.
///
/// Matching is case-insensitive. Known providers use fixed variable names
/// (e.g. `neo4j` → `NEO4J_PASSWORD`, `github` → `GITHUB_TOKEN`,
/// `ollama` → `OLLAMA_HOST`). Any other provider maps to `<NAME>_API_KEY`. Every
/// character outside `[A-Za-z0-9]` becomes `_`, so the result is a name that shells
/// can actually set (e.g. `azure-openai` → `AZURE_OPENAI_API_KEY`).
fn provider_to_env_var(provider: &str) -> String {
    let lowered = provider.to_lowercase();
    match lowered.as_str() {
        "anthropic" => "ANTHROPIC_API_KEY".to_string(),
        "openai" => "OPENAI_API_KEY".to_string(),
        "mistral" => "MISTRAL_API_KEY".to_string(),
        "groq" => "GROQ_API_KEY".to_string(),
        "deepseek" => "DEEPSEEK_API_KEY".to_string(),
        "gemini" => "GEMINI_API_KEY".to_string(),
        "ollama" => "OLLAMA_HOST".to_string(),
        "neo4j" => "NEO4J_PASSWORD".to_string(),
        "github" => "GITHUB_TOKEN".to_string(),
        "slack" => "SLACK_TOKEN".to_string(),
        "perplexity" => "PERPLEXITY_API_KEY".to_string(),
        "firecrawl" => "FIRECRAWL_API_KEY".to_string(),
        "supadata" => "SUPADATA_API_KEY".to_string(),
        other => {
            // Portable env var names use only uppercase letters, digits and '_'.
            let stem: String = other
                .chars()
                .map(|c| {
                    if c.is_ascii_alphanumeric() {
                        c.to_ascii_uppercase()
                    } else {
                        '_'
                    }
                })
                .collect();
            format!("{}_API_KEY", stem)
        }
    }
}
309
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_to_env_var() {
        assert_eq!(provider_to_env_var("anthropic"), "ANTHROPIC_API_KEY");
        assert_eq!(provider_to_env_var("openai"), "OPENAI_API_KEY");
        assert_eq!(provider_to_env_var("neo4j"), "NEO4J_PASSWORD");
        assert_eq!(provider_to_env_var("github"), "GITHUB_TOKEN");
        assert_eq!(provider_to_env_var("unknown"), "UNKNOWN_API_KEY");
    }

    /// Checks both branches of `default_socket_path`, so the test also passes where
    /// no home directory is available (the old ".spn" assertion failed there).
    #[test]
    fn test_default_socket_path() {
        let path = default_socket_path();
        match dirs::home_dir() {
            Some(home) => assert_eq!(path, home.join(".spn").join("daemon.sock")),
            None => assert_eq!(path, PathBuf::from("/tmp/spn-daemon.sock")),
        }
    }

    /// Must not depend on whether a daemon happens to be running on the test machine;
    /// the old `assert!(!daemon_socket_exists())` failed wherever it was.
    #[test]
    fn test_daemon_socket_exists() {
        assert_eq!(daemon_socket_exists(), default_socket_path().exists());
    }
}
335}