1mod error;
49mod paths;
50mod protocol;
51
52pub use error::Error;
53pub use paths::{PathError, SpnPaths};
54pub use protocol::{
55 IpcJobState, IpcJobStatus, IpcSchedulerStats, ModelProgress, Request, Response,
56 PROTOCOL_VERSION,
57};
58pub use secrecy::{ExposeSecret, SecretString};
59
60pub use spn_core::{
62 find_provider,
63 mask_key,
64 provider_to_env_var,
65 providers_by_category,
66 validate_key_format,
67 BackendError,
68 ChatMessage,
70 ChatOptions,
71 ChatResponse,
72 ChatRole,
73 GpuInfo,
74 LoadConfig,
75 McpConfig,
76 McpServer,
78 McpServerType,
79 McpSource,
80 ModelInfo,
81 PackageManifest,
82 PackageRef,
84 PackageType,
85 Provider,
87 ProviderCategory,
88 PullProgress,
90 RunningModel,
91 Source,
92 ValidationResult,
94 KNOWN_PROVIDERS,
95};
96
97use std::path::PathBuf;
98use std::time::Duration;
99#[cfg(unix)]
100use tokio::io::{AsyncReadExt, AsyncWriteExt};
101#[cfg(unix)]
102use tokio::net::UnixStream;
103use tracing::debug;
104#[cfg(unix)]
105use tracing::warn;
106
/// Default deadline for a single request/response round trip with the daemon.
///
/// Every newly constructed [`SpnClient`] starts with this value; change it per
/// client with [`SpnClient::set_timeout`].
pub const DEFAULT_IPC_TIMEOUT: Duration = Duration::from_millis(30_000);
109
110pub fn socket_path() -> Result<PathBuf, Error> {
117 SpnPaths::new().map(|p| p.socket_file()).map_err(|_| {
118 Error::Configuration("HOME directory not found. Set HOME environment variable.".into())
119 })
120}
121
122pub fn daemon_socket_exists() -> bool {
126 socket_path().map(|p| p.exists()).unwrap_or(false)
127}
128
/// Client for the spn daemon's IPC socket, with an environment-variable fallback.
///
/// Construct it with `connect` / `connect_to` (Unix only) or
/// `connect_with_fallback`. In fallback mode, secret lookups read provider
/// environment variables instead of asking the daemon.
#[derive(Debug)]
pub struct SpnClient {
    /// Open connection to the daemon; `None` when no connection is held
    /// (e.g. in fallback mode).
    #[cfg(unix)]
    stream: Option<UnixStream>,
    /// `true` when secrets are resolved from environment variables rather
    /// than from the daemon.
    fallback_mode: bool,
    /// Deadline applied to each request/response round trip.
    timeout: Duration,
}
144
145impl SpnClient {
146 #[cfg(unix)]
152 pub async fn connect() -> Result<Self, Error> {
153 let path = socket_path()?;
154 Self::connect_to(&path).await
155 }
156
157 #[cfg(unix)]
161 pub async fn connect_to(socket_path: &PathBuf) -> Result<Self, Error> {
162 debug!("Connecting to spn daemon at {:?}", socket_path);
163
164 let stream =
165 UnixStream::connect(socket_path)
166 .await
167 .map_err(|e| Error::ConnectionFailed {
168 path: socket_path.clone(),
169 source: e,
170 })?;
171
172 let mut client = Self {
174 stream: Some(stream),
175 fallback_mode: false,
176 timeout: DEFAULT_IPC_TIMEOUT,
177 };
178
179 client.ping().await?;
180 debug!("Connected to spn daemon");
181
182 Ok(client)
183 }
184
185 pub fn set_timeout(&mut self, timeout: Duration) {
189 self.timeout = timeout;
190 }
191
192 pub fn timeout(&self) -> Duration {
194 self.timeout
195 }
196
197 #[cfg(unix)]
204 pub async fn connect_with_fallback() -> Result<Self, Error> {
205 match Self::connect().await {
206 Ok(client) => Ok(client),
207 Err(e) => {
208 warn!("spn daemon not running, using env var fallback: {}", e);
209 Ok(Self {
210 stream: None,
211 fallback_mode: true,
212 timeout: DEFAULT_IPC_TIMEOUT,
213 })
214 }
215 }
216 }
217
218 #[cfg(not(unix))]
223 pub async fn connect_with_fallback() -> Result<Self, Error> {
224 debug!("Non-Unix platform: using env var fallback mode");
225 Ok(Self {
226 fallback_mode: true,
227 timeout: DEFAULT_IPC_TIMEOUT,
228 })
229 }
230
231 pub fn is_fallback_mode(&self) -> bool {
233 self.fallback_mode
234 }
235
236 #[cfg(unix)]
243 pub async fn ping(&mut self) -> Result<String, Error> {
244 let response = self.send_request(Request::Ping).await?;
245 match response {
246 Response::Pong {
247 protocol_version,
248 version,
249 } => {
250 if protocol_version != protocol::PROTOCOL_VERSION {
252 warn!(
253 "Protocol version mismatch: client v{}, daemon v{}. \
254 Consider updating your daemon with 'spn daemon restart'.",
255 protocol::PROTOCOL_VERSION,
256 protocol_version
257 );
258 }
259 Ok(version)
260 }
261 Response::Error { message } => Err(Error::DaemonError(message)),
262 _ => Err(Error::UnexpectedResponse),
263 }
264 }
265
266 #[cfg(unix)]
271 pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
272 if self.fallback_mode {
273 return self.get_secret_from_env(provider);
274 }
275
276 let response = self
277 .send_request(Request::GetSecret {
278 provider: provider.to_string(),
279 })
280 .await?;
281
282 match response {
283 Response::Secret { value } => Ok(SecretString::from(value)),
284 Response::Error { message } => Err(Error::SecretNotFound {
285 provider: provider.to_string(),
286 details: message,
287 }),
288 _ => Err(Error::UnexpectedResponse),
289 }
290 }
291
292 #[cfg(not(unix))]
296 pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
297 self.get_secret_from_env(provider)
298 }
299
300 #[cfg(unix)]
302 pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
303 if self.fallback_mode {
304 return Ok(self.get_secret_from_env(provider).is_ok());
305 }
306
307 let response = self
308 .send_request(Request::HasSecret {
309 provider: provider.to_string(),
310 })
311 .await?;
312
313 match response {
314 Response::Exists { exists } => Ok(exists),
315 Response::Error { message } => Err(Error::DaemonError(message)),
316 _ => Err(Error::UnexpectedResponse),
317 }
318 }
319
320 #[cfg(not(unix))]
324 pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
325 Ok(self.get_secret_from_env(provider).is_ok())
326 }
327
328 #[cfg(unix)]
330 pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
331 if self.fallback_mode {
332 return Ok(self.list_env_providers());
333 }
334
335 let response = self.send_request(Request::ListProviders).await?;
336
337 match response {
338 Response::Providers { providers } => Ok(providers),
339 Response::Error { message } => Err(Error::DaemonError(message)),
340 _ => Err(Error::UnexpectedResponse),
341 }
342 }
343
344 #[cfg(not(unix))]
348 pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
349 Ok(self.list_env_providers())
350 }
351
352 #[cfg(unix)]
359 pub async fn send_request(&mut self, request: Request) -> Result<Response, Error> {
360 let timeout_duration = self.timeout;
361 let timeout_secs = timeout_duration.as_secs();
362
363 tokio::time::timeout(timeout_duration, self.send_request_inner(request))
365 .await
366 .map_err(|_| Error::Timeout(timeout_secs))?
367 }
368
369 #[cfg(unix)]
371 async fn send_request_inner(&mut self, request: Request) -> Result<Response, Error> {
372 let stream = self.stream.as_mut().ok_or(Error::NotConnected)?;
373
374 let request_json = serde_json::to_vec(&request).map_err(Error::SerializationError)?;
376
377 let len = request_json.len() as u32;
379 stream
380 .write_all(&len.to_be_bytes())
381 .await
382 .map_err(Error::IoError)?;
383 stream
384 .write_all(&request_json)
385 .await
386 .map_err(Error::IoError)?;
387
388 let mut len_buf = [0u8; 4];
390 stream
391 .read_exact(&mut len_buf)
392 .await
393 .map_err(Error::IoError)?;
394 let response_len = u32::from_be_bytes(len_buf) as usize;
395
396 if response_len > 1_048_576 {
398 return Err(Error::ResponseTooLarge(response_len));
399 }
400
401 let mut response_buf = vec![0u8; response_len];
403 stream
404 .read_exact(&mut response_buf)
405 .await
406 .map_err(Error::IoError)?;
407
408 let response: Response =
410 serde_json::from_slice(&response_buf).map_err(Error::DeserializationError)?;
411
412 Ok(response)
413 }
414
415 fn get_secret_from_env(&self, provider: &str) -> Result<SecretString, Error> {
418 let env_var = provider_to_env_var(provider).ok_or_else(|| Error::SecretNotFound {
419 provider: provider.to_string(),
420 details: format!("Unknown provider: {provider}"),
421 })?;
422 std::env::var(env_var)
423 .map(SecretString::from)
424 .map_err(|_| Error::SecretNotFound {
425 provider: provider.to_string(),
426 details: format!("Environment variable {env_var} not set"),
427 })
428 }
429
430 fn list_env_providers(&self) -> Vec<String> {
431 KNOWN_PROVIDERS
432 .iter()
433 .filter(|p| std::env::var(p.env_var).is_ok())
434 .map(|p| p.id.to_string())
435 .collect()
436 }
437}
438
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_to_env_var() {
        // Table of (provider id, expected environment variable name).
        let cases = [
            ("anthropic", Some("ANTHROPIC_API_KEY")),
            ("openai", Some("OPENAI_API_KEY")),
            ("neo4j", Some("NEO4J_PASSWORD")),
            ("github", Some("GITHUB_TOKEN")),
            ("unknown", None),
        ];
        for &(provider, expected) in cases.iter() {
            assert_eq!(provider_to_env_var(provider), expected);
        }
    }

    #[test]
    fn test_socket_path() {
        // Resolution can legitimately fail (e.g. no HOME in the test environment);
        // only check the shape of the path when it succeeds.
        if let Ok(path) = socket_path() {
            let rendered = path.to_string_lossy();
            assert!(rendered.contains(".spn"));
            assert!(rendered.contains("daemon.sock"));
        }
    }

    #[test]
    fn test_daemon_socket_exists() {
        // Smoke test: must not panic whether or not a daemon is running.
        let _ = daemon_socket_exists();
    }
}