Skip to main content

spn_client/
lib.rs

1//! # spn-client
2//!
3//! Client library for communicating with the spn daemon.
4//!
5//! This crate provides a simple interface for applications (like Nika) to securely
6//! retrieve secrets from the spn daemon without directly accessing the OS keychain.
7//!
8//! ## Usage
9//!
10//! ```rust,no_run
11//! use spn_client::{SpnClient, ExposeSecret};
12//!
13//! # async fn example() -> Result<(), spn_client::Error> {
14//! // Connect to the daemon
15//! let mut client = SpnClient::connect().await?;
16//!
17//! // Get a secret
18//! let api_key = client.get_secret("anthropic").await?;
19//! println!("Got key: {}", api_key.expose_secret());
20//!
21//! // Check if a secret exists
22//! if client.has_secret("openai").await? {
23//!     println!("OpenAI key available");
24//! }
25//!
26//! // List all providers
27//! let providers = client.list_providers().await?;
28//! println!("Available providers: {:?}", providers);
29//! # Ok(())
30//! # }
31//! ```
32//!
33//! ## Fallback Mode
34//!
35//! If the daemon is not running, the client can fall back to reading from
36//! environment variables:
37//!
38//! ```rust,no_run
39//! use spn_client::SpnClient;
40//!
41//! # async fn example() -> Result<(), spn_client::Error> {
42//! let mut client = SpnClient::connect_with_fallback().await?;
43//! // Works even if daemon is not running
44//! # Ok(())
45//! # }
46//! ```
47
48mod error;
49mod paths;
50mod protocol;
51
52pub use error::Error;
53pub use paths::{PathError, SpnPaths};
54pub use protocol::{
55    ForeignMcpInfo, IpcJobState, IpcJobStatus, IpcSchedulerStats, ModelProgress, RecentProjectInfo,
56    Request, Response, WatcherStatusInfo, PROTOCOL_VERSION,
57};
58pub use secrecy::{ExposeSecret, SecretString};
59
60// Re-export all spn-core types for convenience
61pub use spn_core::{
62    find_provider,
63    mask_key,
64    provider_to_env_var,
65    providers_by_category,
66    validate_key_format,
67    BackendError,
68    // Chat types
69    ChatMessage,
70    ChatOptions,
71    ChatResponse,
72    ChatRole,
73    GpuInfo,
74    LoadConfig,
75    McpConfig,
76    // MCP
77    McpServer,
78    McpServerType,
79    McpSource,
80    ModelInfo,
81    PackageManifest,
82    // Registry
83    PackageRef,
84    PackageType,
85    // Providers
86    Provider,
87    ProviderCategory,
88    // Backend
89    PullProgress,
90    RunningModel,
91    Source,
92    // Validation
93    ValidationResult,
94    KNOWN_PROVIDERS,
95};
96
97use std::path::PathBuf;
98use std::time::Duration;
99#[cfg(unix)]
100use tokio::io::{AsyncReadExt, AsyncWriteExt};
101#[cfg(unix)]
102use tokio::net::UnixStream;
103use tracing::debug;
104#[cfg(unix)]
105use tracing::warn;
106
/// Default time budget for a single IPC operation with the daemon: thirty seconds.
///
/// Override it per client with `SpnClient::set_timeout`.
pub const DEFAULT_IPC_TIMEOUT: Duration = Duration::from_millis(30_000);
109
110/// Get socket path for the spn daemon, returning an error if HOME is unavailable.
111///
112/// Use this function when you need to ensure a secure socket path.
113/// Returns an error instead of falling back to `/tmp`.
114///
115/// This is a convenience wrapper around `SpnPaths::new()?.socket_file()`.
116pub fn socket_path() -> Result<PathBuf, Error> {
117    SpnPaths::new().map(|p| p.socket_file()).map_err(|_| {
118        Error::Configuration("HOME directory not found. Set HOME environment variable.".into())
119    })
120}
121
122/// Check if the daemon socket exists.
123///
124/// Returns `false` if HOME directory is unavailable.
125pub fn daemon_socket_exists() -> bool {
126    socket_path().map(|p| p.exists()).unwrap_or(false)
127}
128
129/// Client for communicating with the spn daemon.
130///
131/// The client uses Unix socket IPC to communicate with the daemon,
132/// which handles all keychain access to avoid repeated auth prompts.
133///
134/// On non-Unix platforms (Windows), the client always operates in fallback mode,
135/// reading secrets from environment variables.
136#[derive(Debug)]
137pub struct SpnClient {
138    #[cfg(unix)]
139    stream: Option<UnixStream>,
140    fallback_mode: bool,
141    /// Timeout for IPC operations.
142    timeout: Duration,
143}
144
145impl SpnClient {
146    /// Connect to the spn daemon.
147    ///
148    /// Returns an error if the daemon is not running.
149    ///
150    /// This method is only available on Unix platforms.
151    #[cfg(unix)]
152    pub async fn connect() -> Result<Self, Error> {
153        let path = socket_path()?;
154        Self::connect_to(&path).await
155    }
156
157    /// Connect to the daemon at a specific socket path.
158    ///
159    /// This method is only available on Unix platforms.
160    #[cfg(unix)]
161    pub async fn connect_to(socket_path: &PathBuf) -> Result<Self, Error> {
162        debug!("Connecting to spn daemon at {:?}", socket_path);
163
164        let stream =
165            UnixStream::connect(socket_path)
166                .await
167                .map_err(|e| Error::ConnectionFailed {
168                    path: socket_path.clone(),
169                    source: e,
170                })?;
171
172        // Verify connection with ping
173        let mut client = Self {
174            stream: Some(stream),
175            fallback_mode: false,
176            timeout: DEFAULT_IPC_TIMEOUT,
177        };
178
179        client.ping().await?;
180        debug!("Connected to spn daemon");
181
182        Ok(client)
183    }
184
185    /// Set the timeout for IPC operations.
186    ///
187    /// The default timeout is 30 seconds.
188    pub fn set_timeout(&mut self, timeout: Duration) {
189        self.timeout = timeout;
190    }
191
192    /// Get the current timeout for IPC operations.
193    pub fn timeout(&self) -> Duration {
194        self.timeout
195    }
196
197    /// Connect to the daemon, falling back to env vars if daemon is unavailable.
198    ///
199    /// This is the recommended way to connect in applications that should
200    /// work even without the daemon running.
201    ///
202    /// On non-Unix platforms (Windows), this always returns a fallback client.
203    #[cfg(unix)]
204    pub async fn connect_with_fallback() -> Result<Self, Error> {
205        match Self::connect().await {
206            Ok(client) => Ok(client),
207            Err(e) => {
208                warn!("spn daemon not running, using env var fallback: {}", e);
209                Ok(Self {
210                    stream: None,
211                    fallback_mode: true,
212                    timeout: DEFAULT_IPC_TIMEOUT,
213                })
214            }
215        }
216    }
217
218    /// Connect to the daemon, falling back to env vars if daemon is unavailable.
219    ///
220    /// On non-Unix platforms (Windows), this always returns a fallback client
221    /// since Unix sockets are not available.
222    #[cfg(not(unix))]
223    pub async fn connect_with_fallback() -> Result<Self, Error> {
224        debug!("Non-Unix platform: using env var fallback mode");
225        Ok(Self {
226            fallback_mode: true,
227            timeout: DEFAULT_IPC_TIMEOUT,
228        })
229    }
230
231    /// Check if the client is in fallback mode (daemon not connected).
232    pub fn is_fallback_mode(&self) -> bool {
233        self.fallback_mode
234    }
235
236    /// Ping the daemon to verify the connection and check protocol compatibility.
237    ///
238    /// Returns the daemon's CLI version string on success.
239    /// Warns if the protocol version doesn't match but allows connection.
240    ///
241    /// This method is only available on Unix platforms.
242    #[cfg(unix)]
243    pub async fn ping(&mut self) -> Result<String, Error> {
244        let response = self.send_request(Request::Ping).await?;
245        match response {
246            Response::Pong {
247                protocol_version,
248                version,
249            } => {
250                // Check protocol version compatibility
251                if protocol_version != protocol::PROTOCOL_VERSION {
252                    warn!(
253                        "Protocol version mismatch: client v{}, daemon v{}. \
254                        Consider updating your daemon with 'spn daemon restart'.",
255                        protocol::PROTOCOL_VERSION,
256                        protocol_version
257                    );
258                }
259                Ok(version)
260            }
261            Response::Error { message } => Err(Error::DaemonError(message)),
262            _ => Err(Error::UnexpectedResponse),
263        }
264    }
265
266    /// Get a secret for the given provider.
267    ///
268    /// In fallback mode, attempts to read from the environment variable
269    /// associated with the provider (e.g., ANTHROPIC_API_KEY).
270    #[cfg(unix)]
271    pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
272        if self.fallback_mode {
273            return self.get_secret_from_env(provider);
274        }
275
276        let response = self
277            .send_request(Request::GetSecret {
278                provider: provider.to_string(),
279            })
280            .await?;
281
282        match response {
283            Response::Secret { value } => Ok(SecretString::from(value)),
284            Response::Error { message } => Err(Error::SecretNotFound {
285                provider: provider.to_string(),
286                details: message,
287            }),
288            _ => Err(Error::UnexpectedResponse),
289        }
290    }
291
292    /// Get a secret for the given provider.
293    ///
294    /// On non-Unix platforms, always reads from environment variables.
295    #[cfg(not(unix))]
296    pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
297        self.get_secret_from_env(provider)
298    }
299
300    /// Check if a secret exists for the given provider.
301    #[cfg(unix)]
302    pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
303        if self.fallback_mode {
304            return Ok(self.get_secret_from_env(provider).is_ok());
305        }
306
307        let response = self
308            .send_request(Request::HasSecret {
309                provider: provider.to_string(),
310            })
311            .await?;
312
313        match response {
314            Response::Exists { exists } => Ok(exists),
315            Response::Error { message } => Err(Error::DaemonError(message)),
316            _ => Err(Error::UnexpectedResponse),
317        }
318    }
319
320    /// Check if a secret exists for the given provider.
321    ///
322    /// On non-Unix platforms, checks environment variables.
323    #[cfg(not(unix))]
324    pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
325        Ok(self.get_secret_from_env(provider).is_ok())
326    }
327
328    /// List all available providers.
329    #[cfg(unix)]
330    pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
331        if self.fallback_mode {
332            return Ok(self.list_env_providers());
333        }
334
335        let response = self.send_request(Request::ListProviders).await?;
336
337        match response {
338            Response::Providers { providers } => Ok(providers),
339            Response::Error { message } => Err(Error::DaemonError(message)),
340            _ => Err(Error::UnexpectedResponse),
341        }
342    }
343
344    /// List all available providers.
345    ///
346    /// On non-Unix platforms, lists providers from environment variables.
347    #[cfg(not(unix))]
348    pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
349        Ok(self.list_env_providers())
350    }
351
352    /// Refresh a secret in the daemon cache.
353    ///
354    /// This should be called after `spn provider set` to invalidate stale values.
355    /// Returns `true` if the secret was found and reloaded.
356    #[cfg(unix)]
357    pub async fn refresh_secret(&mut self, provider: &str) -> Result<bool, Error> {
358        if self.fallback_mode {
359            // In fallback mode, there's no cache to refresh
360            return Ok(true);
361        }
362
363        let response = self
364            .send_request(Request::RefreshSecret {
365                provider: provider.to_string(),
366            })
367            .await?;
368
369        match response {
370            Response::Refreshed { refreshed, .. } => Ok(refreshed),
371            Response::Error { message } => Err(Error::DaemonError(message)),
372            _ => Err(Error::UnexpectedResponse),
373        }
374    }
375
376    /// Refresh a secret in the daemon cache.
377    ///
378    /// On non-Unix platforms, this is a no-op since there's no daemon.
379    #[cfg(not(unix))]
380    pub async fn refresh_secret(&mut self, _provider: &str) -> Result<bool, Error> {
381        Ok(true) // No daemon cache to refresh
382    }
383
384    /// Get watcher status (recent projects, foreign MCPs, watched paths).
385    ///
386    /// Returns status information about the MCP config watcher.
387    #[cfg(unix)]
388    pub async fn watcher_status(&mut self) -> Result<WatcherStatusInfo, Error> {
389        if self.fallback_mode {
390            return Err(Error::DaemonError(
391                "Watcher status not available in fallback mode".into(),
392            ));
393        }
394
395        let response = self.send_request(Request::WatcherStatus).await?;
396
397        match response {
398            Response::WatcherStatusResult { status } => Ok(status),
399            Response::Error { message } => Err(Error::DaemonError(message)),
400            _ => Err(Error::UnexpectedResponse),
401        }
402    }
403
404    /// Get watcher status.
405    ///
406    /// On non-Unix platforms, returns an error since the daemon is not available.
407    #[cfg(not(unix))]
408    pub async fn watcher_status(&mut self) -> Result<WatcherStatusInfo, Error> {
409        Err(Error::DaemonError(
410            "Watcher status not available on non-Unix platforms".into(),
411        ))
412    }
413
414    /// Send a request to the daemon and receive a response.
415    ///
416    /// This is a low-level method for sending arbitrary requests.
417    /// For common operations, use the convenience methods like `get_secret()`.
418    ///
419    /// The request will time out after the configured timeout (default 30 seconds).
420    #[cfg(unix)]
421    pub async fn send_request(&mut self, request: Request) -> Result<Response, Error> {
422        let timeout_duration = self.timeout;
423        let timeout_secs = timeout_duration.as_secs();
424
425        // Wrap the entire operation in a timeout
426        tokio::time::timeout(timeout_duration, self.send_request_inner(request))
427            .await
428            .map_err(|_| Error::Timeout(timeout_secs))?
429    }
430
431    /// Inner implementation of send_request without timeout.
432    #[cfg(unix)]
433    async fn send_request_inner(&mut self, request: Request) -> Result<Response, Error> {
434        let stream = self.stream.as_mut().ok_or(Error::NotConnected)?;
435
436        // Serialize request
437        let request_json = serde_json::to_vec(&request).map_err(Error::SerializationError)?;
438
439        // Send length-prefixed message
440        let len = request_json.len() as u32;
441        stream
442            .write_all(&len.to_be_bytes())
443            .await
444            .map_err(Error::IoError)?;
445        stream
446            .write_all(&request_json)
447            .await
448            .map_err(Error::IoError)?;
449
450        // Read response length
451        let mut len_buf = [0u8; 4];
452        stream
453            .read_exact(&mut len_buf)
454            .await
455            .map_err(Error::IoError)?;
456        let response_len = u32::from_be_bytes(len_buf) as usize;
457
458        // Sanity check response length (max 1MB)
459        if response_len > 1_048_576 {
460            return Err(Error::ResponseTooLarge(response_len));
461        }
462
463        // Read response
464        let mut response_buf = vec![0u8; response_len];
465        stream
466            .read_exact(&mut response_buf)
467            .await
468            .map_err(Error::IoError)?;
469
470        // Deserialize
471        let response: Response =
472            serde_json::from_slice(&response_buf).map_err(Error::DeserializationError)?;
473
474        Ok(response)
475    }
476
477    // Fallback helpers
478
479    fn get_secret_from_env(&self, provider: &str) -> Result<SecretString, Error> {
480        let env_var = provider_to_env_var(provider).ok_or_else(|| Error::SecretNotFound {
481            provider: provider.to_string(),
482            details: format!("Unknown provider: {provider}"),
483        })?;
484        std::env::var(env_var)
485            .map(SecretString::from)
486            .map_err(|_| Error::SecretNotFound {
487                provider: provider.to_string(),
488                details: format!("Environment variable {env_var} not set"),
489            })
490    }
491
492    fn list_env_providers(&self) -> Vec<String> {
493        KNOWN_PROVIDERS
494            .iter()
495            .filter(|p| std::env::var(p.env_var).is_ok())
496            .map(|p| p.id.to_string())
497            .collect()
498    }
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_to_env_var() {
        // `provider_to_env_var` is re-exported from spn_core and yields `None`
        // for providers it does not know.
        let cases = [
            ("anthropic", Some("ANTHROPIC_API_KEY")),
            ("openai", Some("OPENAI_API_KEY")),
            ("neo4j", Some("NEO4J_PASSWORD")),
            ("github", Some("GITHUB_TOKEN")),
            ("unknown", None),
        ];
        for (provider, expected) in cases {
            assert_eq!(provider_to_env_var(provider), expected, "provider: {provider}");
        }
    }

    #[test]
    fn test_socket_path() {
        // Only meaningful when HOME resolves; otherwise there is nothing to check.
        match socket_path() {
            Ok(path) => {
                let rendered = path.to_string_lossy();
                assert!(rendered.contains(".spn"));
                assert!(rendered.contains("daemon.sock"));
            }
            Err(_) => {}
        }
    }

    #[test]
    fn test_daemon_socket_exists() {
        // The answer depends on whether a daemon is running on this machine;
        // the test only checks that the call completes without panicking.
        let _ = daemon_socket_exists();
    }
}