1mod error;
49mod paths;
50mod protocol;
51
52pub use error::Error;
53pub use paths::{PathError, SpnPaths};
54pub use protocol::{Request, Response};
55pub use secrecy::{ExposeSecret, SecretString};
56
57pub use spn_core::{
59 find_provider,
60 mask_key,
61 provider_to_env_var,
62 providers_by_category,
63 validate_key_format,
64 BackendError,
65 GpuInfo,
66 LoadConfig,
67 McpConfig,
68 McpServer,
70 McpServerType,
71 McpSource,
72 ModelInfo,
73 PackageManifest,
74 PackageRef,
76 PackageType,
77 Provider,
79 ProviderCategory,
80 PullProgress,
82 RunningModel,
83 Source,
84 ValidationResult,
86 KNOWN_PROVIDERS,
87};
88
89use std::path::PathBuf;
90use std::time::Duration;
91#[cfg(unix)]
92use tokio::io::{AsyncReadExt, AsyncWriteExt};
93#[cfg(unix)]
94use tokio::net::UnixStream;
95use tracing::debug;
96#[cfg(unix)]
97use tracing::warn;
98
/// Default time limit for a single IPC request to the daemon (30 seconds).
///
/// Every newly constructed [`SpnClient`] starts with this value; it can be
/// changed per client with [`SpnClient::set_timeout`].
pub const DEFAULT_IPC_TIMEOUT: Duration = Duration::from_secs(30);
101
102pub fn socket_path() -> Result<PathBuf, Error> {
109 SpnPaths::new().map(|p| p.socket_file()).map_err(|_| {
110 Error::Configuration("HOME directory not found. Set HOME environment variable.".into())
111 })
112}
113
114pub fn daemon_socket_exists() -> bool {
118 socket_path().map(|p| p.exists()).unwrap_or(false)
119}
120
/// Client for the spn daemon.
///
/// A client either holds an open socket to the daemon or runs in fallback
/// mode, in which secrets are looked up in environment variables instead.
#[derive(Debug)]
pub struct SpnClient {
    /// Open socket to the daemon; `None` when the client was created in
    /// fallback mode.
    #[cfg(unix)]
    stream: Option<UnixStream>,
    /// When `true`, secret lookups read environment variables instead of
    /// sending requests to the daemon.
    fallback_mode: bool,
    /// Time limit applied to each request sent to the daemon.
    timeout: Duration,
}
136
137impl SpnClient {
138 #[cfg(unix)]
144 pub async fn connect() -> Result<Self, Error> {
145 let path = socket_path()?;
146 Self::connect_to(&path).await
147 }
148
149 #[cfg(unix)]
153 pub async fn connect_to(socket_path: &PathBuf) -> Result<Self, Error> {
154 debug!("Connecting to spn daemon at {:?}", socket_path);
155
156 let stream =
157 UnixStream::connect(socket_path)
158 .await
159 .map_err(|e| Error::ConnectionFailed {
160 path: socket_path.clone(),
161 source: e,
162 })?;
163
164 let mut client = Self {
166 stream: Some(stream),
167 fallback_mode: false,
168 timeout: DEFAULT_IPC_TIMEOUT,
169 };
170
171 client.ping().await?;
172 debug!("Connected to spn daemon");
173
174 Ok(client)
175 }
176
/// Sets the time limit applied to each subsequent daemon request.
pub fn set_timeout(&mut self, timeout: Duration) {
    self.timeout = timeout;
}
183
/// Returns the time limit currently applied to each daemon request.
pub fn timeout(&self) -> Duration {
    self.timeout
}
188
189 #[cfg(unix)]
196 pub async fn connect_with_fallback() -> Result<Self, Error> {
197 match Self::connect().await {
198 Ok(client) => Ok(client),
199 Err(e) => {
200 warn!("spn daemon not running, using env var fallback: {}", e);
201 Ok(Self {
202 stream: None,
203 fallback_mode: true,
204 timeout: DEFAULT_IPC_TIMEOUT,
205 })
206 }
207 }
208 }
209
/// Creates a client in environment-variable fallback mode.
///
/// On non-Unix platforms no daemon connection is attempted, so this always
/// succeeds and returns a fallback client with the default timeout.
#[cfg(not(unix))]
pub async fn connect_with_fallback() -> Result<Self, Error> {
    debug!("Non-Unix platform: using env var fallback mode");

    let client = Self {
        fallback_mode: true,
        timeout: DEFAULT_IPC_TIMEOUT,
    };
    Ok(client)
}
222
/// Returns `true` when this client reads secrets from environment variables
/// instead of querying the daemon.
pub fn is_fallback_mode(&self) -> bool {
    self.fallback_mode
}
227
228 #[cfg(unix)]
232 pub async fn ping(&mut self) -> Result<String, Error> {
233 let response = self.send_request(Request::Ping).await?;
234 match response {
235 Response::Pong { version } => Ok(version),
236 Response::Error { message } => Err(Error::DaemonError(message)),
237 _ => Err(Error::UnexpectedResponse),
238 }
239 }
240
241 #[cfg(unix)]
246 pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
247 if self.fallback_mode {
248 return self.get_secret_from_env(provider);
249 }
250
251 let response = self
252 .send_request(Request::GetSecret {
253 provider: provider.to_string(),
254 })
255 .await?;
256
257 match response {
258 Response::Secret { value } => Ok(SecretString::from(value)),
259 Response::Error { message } => Err(Error::SecretNotFound {
260 provider: provider.to_string(),
261 details: message,
262 }),
263 _ => Err(Error::UnexpectedResponse),
264 }
265 }
266
/// Fetches the secret for `provider` from its environment variable.
///
/// Non-Unix builds have no daemon connection, so the environment is the
/// only source consulted.
#[cfg(not(unix))]
pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
    self.get_secret_from_env(provider)
}
274
275 #[cfg(unix)]
277 pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
278 if self.fallback_mode {
279 return Ok(self.get_secret_from_env(provider).is_ok());
280 }
281
282 let response = self
283 .send_request(Request::HasSecret {
284 provider: provider.to_string(),
285 })
286 .await?;
287
288 match response {
289 Response::Exists { exists } => Ok(exists),
290 Response::Error { message } => Err(Error::DaemonError(message)),
291 _ => Err(Error::UnexpectedResponse),
292 }
293 }
294
/// Reports whether a secret for `provider` can be read from the environment.
///
/// Always returns `Ok`; the value is `true` exactly when the environment
/// lookup for the provider succeeds.
#[cfg(not(unix))]
pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
    Ok(self.get_secret_from_env(provider).is_ok())
}
302
303 #[cfg(unix)]
305 pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
306 if self.fallback_mode {
307 return Ok(self.list_env_providers());
308 }
309
310 let response = self.send_request(Request::ListProviders).await?;
311
312 match response {
313 Response::Providers { providers } => Ok(providers),
314 Response::Error { message } => Err(Error::DaemonError(message)),
315 _ => Err(Error::UnexpectedResponse),
316 }
317 }
318
/// Lists the ids of known providers whose environment variable is set.
///
/// Non-Unix builds have no daemon connection, so the environment is the
/// only source consulted. Always returns `Ok`.
#[cfg(not(unix))]
pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
    Ok(self.list_env_providers())
}
326
327 #[cfg(unix)]
334 pub async fn send_request(&mut self, request: Request) -> Result<Response, Error> {
335 let timeout_duration = self.timeout;
336 let timeout_secs = timeout_duration.as_secs();
337
338 tokio::time::timeout(timeout_duration, self.send_request_inner(request))
340 .await
341 .map_err(|_| Error::Timeout(timeout_secs))?
342 }
343
344 #[cfg(unix)]
346 async fn send_request_inner(&mut self, request: Request) -> Result<Response, Error> {
347 let stream = self.stream.as_mut().ok_or(Error::NotConnected)?;
348
349 let request_json = serde_json::to_vec(&request).map_err(Error::SerializationError)?;
351
352 let len = request_json.len() as u32;
354 stream
355 .write_all(&len.to_be_bytes())
356 .await
357 .map_err(Error::IoError)?;
358 stream
359 .write_all(&request_json)
360 .await
361 .map_err(Error::IoError)?;
362
363 let mut len_buf = [0u8; 4];
365 stream
366 .read_exact(&mut len_buf)
367 .await
368 .map_err(Error::IoError)?;
369 let response_len = u32::from_be_bytes(len_buf) as usize;
370
371 if response_len > 1_048_576 {
373 return Err(Error::ResponseTooLarge(response_len));
374 }
375
376 let mut response_buf = vec![0u8; response_len];
378 stream
379 .read_exact(&mut response_buf)
380 .await
381 .map_err(Error::IoError)?;
382
383 let response: Response =
385 serde_json::from_slice(&response_buf).map_err(Error::DeserializationError)?;
386
387 Ok(response)
388 }
389
390 fn get_secret_from_env(&self, provider: &str) -> Result<SecretString, Error> {
393 let env_var = provider_to_env_var(provider).ok_or_else(|| Error::SecretNotFound {
394 provider: provider.to_string(),
395 details: format!("Unknown provider: {provider}"),
396 })?;
397 std::env::var(env_var)
398 .map(SecretString::from)
399 .map_err(|_| Error::SecretNotFound {
400 provider: provider.to_string(),
401 details: format!("Environment variable {env_var} not set"),
402 })
403 }
404
405 fn list_env_providers(&self) -> Vec<String> {
406 KNOWN_PROVIDERS
407 .iter()
408 .filter(|p| std::env::var(p.env_var).is_ok())
409 .map(|p| p.id.to_string())
410 .collect()
411 }
412}
413
#[cfg(test)]
mod tests {
    use super::*;

    /// Known provider ids map to their environment variable names; an
    /// unknown id maps to `None`.
    #[test]
    fn test_provider_to_env_var() {
        let expectations = [
            ("anthropic", Some("ANTHROPIC_API_KEY")),
            ("openai", Some("OPENAI_API_KEY")),
            ("neo4j", Some("NEO4J_PASSWORD")),
            ("github", Some("GITHUB_TOKEN")),
            ("unknown", None),
        ];

        for (provider, env_var) in expectations {
            assert_eq!(provider_to_env_var(provider), env_var);
        }
    }

    /// When the socket path can be resolved, it points at `daemon.sock`
    /// under a `.spn` directory. Nothing is asserted if it cannot be
    /// resolved.
    #[test]
    fn test_socket_path() {
        if let Ok(path) = socket_path() {
            let rendered = path.to_string_lossy();
            assert!(rendered.contains(".spn"));
            assert!(rendered.contains("daemon.sock"));
        }
    }

    /// Smoke test: the existence check must not panic, whatever it returns.
    #[test]
    fn test_daemon_socket_exists() {
        let _ = daemon_socket_exists();
    }
}