Skip to main content

spn_client/
lib.rs

1//! # spn-client
2//!
3//! Client library for communicating with the spn daemon.
4//!
5//! This crate provides a simple interface for applications (like Nika) to securely
6//! retrieve secrets from the spn daemon without directly accessing the OS keychain.
7//!
8//! ## Usage
9//!
10//! ```rust,no_run
11//! use spn_client::{SpnClient, ExposeSecret};
12//!
13//! # async fn example() -> Result<(), spn_client::Error> {
14//! // Connect to the daemon
15//! let mut client = SpnClient::connect().await?;
16//!
17//! // Get a secret
18//! let api_key = client.get_secret("anthropic").await?;
19//! println!("Got key: {}", api_key.expose_secret());
20//!
21//! // Check if a secret exists
22//! if client.has_secret("openai").await? {
23//!     println!("OpenAI key available");
24//! }
25//!
26//! // List all providers
27//! let providers = client.list_providers().await?;
28//! println!("Available providers: {:?}", providers);
29//! # Ok(())
30//! # }
31//! ```
32//!
33//! ## Fallback Mode
34//!
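//! If the daemon is not running, the client can fall back to reading from
//! environment variables (e.g. `ANTHROPIC_API_KEY` for the `anthropic` provider).
//! On non-Unix platforms, where `SpnClient::connect` is not available, the client
//! always operates in this mode: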
37//!
38//! ```rust,no_run
39//! use spn_client::SpnClient;
40//!
41//! # async fn example() -> Result<(), spn_client::Error> {
42//! let mut client = SpnClient::connect_with_fallback().await?;
43//! // Works even if daemon is not running
44//! # Ok(())
45//! # }
46//! ```
47
48mod error;
49mod protocol;
50
51pub use error::Error;
52pub use protocol::{Request, Response};
53pub use secrecy::{ExposeSecret, SecretString};
54
55// Re-export all spn-core types for convenience
56pub use spn_core::{
57    find_provider,
58    mask_key,
59    provider_to_env_var,
60    providers_by_category,
61    validate_key_format,
62    BackendError,
63    GpuInfo,
64    LoadConfig,
65    McpConfig,
66    // MCP
67    McpServer,
68    McpServerType,
69    McpSource,
70    ModelInfo,
71    PackageManifest,
72    // Registry
73    PackageRef,
74    PackageType,
75    // Providers
76    Provider,
77    ProviderCategory,
78    // Backend
79    PullProgress,
80    RunningModel,
81    // Validation
82    ValidationResult,
83    KNOWN_PROVIDERS,
84};
85
86use std::path::PathBuf;
87#[cfg(unix)]
88use tokio::io::{AsyncReadExt, AsyncWriteExt};
89#[cfg(unix)]
90use tokio::net::UnixStream;
91use tracing::debug;
92#[cfg(unix)]
93use tracing::warn;
94
95/// Default socket path for the spn daemon.
96pub fn default_socket_path() -> PathBuf {
97    dirs::home_dir()
98        .map(|h| h.join(".spn").join("daemon.sock"))
99        .unwrap_or_else(|| PathBuf::from("/tmp/spn-daemon.sock"))
100}
101
102/// Check if the daemon socket exists.
103pub fn daemon_socket_exists() -> bool {
104    default_socket_path().exists()
105}
106
107/// Client for communicating with the spn daemon.
108///
109/// The client uses Unix socket IPC to communicate with the daemon,
110/// which handles all keychain access to avoid repeated auth prompts.
111///
112/// On non-Unix platforms (Windows), the client always operates in fallback mode,
113/// reading secrets from environment variables.
114#[derive(Debug)]
115pub struct SpnClient {
116    #[cfg(unix)]
117    stream: Option<UnixStream>,
118    fallback_mode: bool,
119}
120
121impl SpnClient {
122    /// Connect to the spn daemon.
123    ///
124    /// Returns an error if the daemon is not running.
125    ///
126    /// This method is only available on Unix platforms.
127    #[cfg(unix)]
128    pub async fn connect() -> Result<Self, Error> {
129        Self::connect_to(&default_socket_path()).await
130    }
131
132    /// Connect to the daemon at a specific socket path.
133    ///
134    /// This method is only available on Unix platforms.
135    #[cfg(unix)]
136    pub async fn connect_to(socket_path: &PathBuf) -> Result<Self, Error> {
137        debug!("Connecting to spn daemon at {:?}", socket_path);
138
139        let stream =
140            UnixStream::connect(socket_path)
141                .await
142                .map_err(|e| Error::ConnectionFailed {
143                    path: socket_path.clone(),
144                    source: e,
145                })?;
146
147        // Verify connection with ping
148        let mut client = Self {
149            stream: Some(stream),
150            fallback_mode: false,
151        };
152
153        client.ping().await?;
154        debug!("Connected to spn daemon");
155
156        Ok(client)
157    }
158
159    /// Connect to the daemon, falling back to env vars if daemon is unavailable.
160    ///
161    /// This is the recommended way to connect in applications that should
162    /// work even without the daemon running.
163    ///
164    /// On non-Unix platforms (Windows), this always returns a fallback client.
165    #[cfg(unix)]
166    pub async fn connect_with_fallback() -> Result<Self, Error> {
167        match Self::connect().await {
168            Ok(client) => Ok(client),
169            Err(e) => {
170                warn!("spn daemon not running, using env var fallback: {}", e);
171                Ok(Self {
172                    stream: None,
173                    fallback_mode: true,
174                })
175            }
176        }
177    }
178
179    /// Connect to the daemon, falling back to env vars if daemon is unavailable.
180    ///
181    /// On non-Unix platforms (Windows), this always returns a fallback client
182    /// since Unix sockets are not available.
183    #[cfg(not(unix))]
184    pub async fn connect_with_fallback() -> Result<Self, Error> {
185        debug!("Non-Unix platform: using env var fallback mode");
186        Ok(Self {
187            fallback_mode: true,
188        })
189    }
190
191    /// Check if the client is in fallback mode (daemon not connected).
192    pub fn is_fallback_mode(&self) -> bool {
193        self.fallback_mode
194    }
195
196    /// Ping the daemon to verify the connection.
197    ///
198    /// This method is only available on Unix platforms.
199    #[cfg(unix)]
200    pub async fn ping(&mut self) -> Result<String, Error> {
201        let response = self.send_request(Request::Ping).await?;
202        match response {
203            Response::Pong { version } => Ok(version),
204            Response::Error { message } => Err(Error::DaemonError(message)),
205            _ => Err(Error::UnexpectedResponse),
206        }
207    }
208
209    /// Get a secret for the given provider.
210    ///
211    /// In fallback mode, attempts to read from the environment variable
212    /// associated with the provider (e.g., ANTHROPIC_API_KEY).
213    #[cfg(unix)]
214    pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
215        if self.fallback_mode {
216            return self.get_secret_from_env(provider);
217        }
218
219        let response = self
220            .send_request(Request::GetSecret {
221                provider: provider.to_string(),
222            })
223            .await?;
224
225        match response {
226            Response::Secret { value } => Ok(SecretString::from(value)),
227            Response::Error { message } => Err(Error::SecretNotFound {
228                provider: provider.to_string(),
229                details: message,
230            }),
231            _ => Err(Error::UnexpectedResponse),
232        }
233    }
234
235    /// Get a secret for the given provider.
236    ///
237    /// On non-Unix platforms, always reads from environment variables.
238    #[cfg(not(unix))]
239    pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
240        self.get_secret_from_env(provider)
241    }
242
243    /// Check if a secret exists for the given provider.
244    #[cfg(unix)]
245    pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
246        if self.fallback_mode {
247            return Ok(self.get_secret_from_env(provider).is_ok());
248        }
249
250        let response = self
251            .send_request(Request::HasSecret {
252                provider: provider.to_string(),
253            })
254            .await?;
255
256        match response {
257            Response::Exists { exists } => Ok(exists),
258            Response::Error { message } => Err(Error::DaemonError(message)),
259            _ => Err(Error::UnexpectedResponse),
260        }
261    }
262
263    /// Check if a secret exists for the given provider.
264    ///
265    /// On non-Unix platforms, checks environment variables.
266    #[cfg(not(unix))]
267    pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
268        Ok(self.get_secret_from_env(provider).is_ok())
269    }
270
271    /// List all available providers.
272    #[cfg(unix)]
273    pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
274        if self.fallback_mode {
275            return Ok(self.list_env_providers());
276        }
277
278        let response = self.send_request(Request::ListProviders).await?;
279
280        match response {
281            Response::Providers { providers } => Ok(providers),
282            Response::Error { message } => Err(Error::DaemonError(message)),
283            _ => Err(Error::UnexpectedResponse),
284        }
285    }
286
287    /// List all available providers.
288    ///
289    /// On non-Unix platforms, lists providers from environment variables.
290    #[cfg(not(unix))]
291    pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
292        Ok(self.list_env_providers())
293    }
294
295    /// Send a request to the daemon and receive a response.
296    ///
297    /// This is a low-level method for sending arbitrary requests.
298    /// For common operations, use the convenience methods like `get_secret()`.
299    #[cfg(unix)]
300    pub async fn send_request(&mut self, request: Request) -> Result<Response, Error> {
301        let stream = self.stream.as_mut().ok_or(Error::NotConnected)?;
302
303        // Serialize request
304        let request_json = serde_json::to_vec(&request).map_err(Error::SerializationError)?;
305
306        // Send length-prefixed message
307        let len = request_json.len() as u32;
308        stream
309            .write_all(&len.to_be_bytes())
310            .await
311            .map_err(Error::IoError)?;
312        stream
313            .write_all(&request_json)
314            .await
315            .map_err(Error::IoError)?;
316
317        // Read response length
318        let mut len_buf = [0u8; 4];
319        stream
320            .read_exact(&mut len_buf)
321            .await
322            .map_err(Error::IoError)?;
323        let response_len = u32::from_be_bytes(len_buf) as usize;
324
325        // Sanity check response length (max 1MB)
326        if response_len > 1_048_576 {
327            return Err(Error::ResponseTooLarge(response_len));
328        }
329
330        // Read response
331        let mut response_buf = vec![0u8; response_len];
332        stream
333            .read_exact(&mut response_buf)
334            .await
335            .map_err(Error::IoError)?;
336
337        // Deserialize
338        let response: Response =
339            serde_json::from_slice(&response_buf).map_err(Error::DeserializationError)?;
340
341        Ok(response)
342    }
343
344    // Fallback helpers
345
346    fn get_secret_from_env(&self, provider: &str) -> Result<SecretString, Error> {
347        let env_var = provider_to_env_var(provider).ok_or_else(|| Error::SecretNotFound {
348            provider: provider.to_string(),
349            details: format!("Unknown provider: {provider}"),
350        })?;
351        std::env::var(env_var)
352            .map(SecretString::from)
353            .map_err(|_| Error::SecretNotFound {
354                provider: provider.to_string(),
355                details: format!("Environment variable {env_var} not set"),
356            })
357    }
358
359    fn list_env_providers(&self) -> Vec<String> {
360        KNOWN_PROVIDERS
361            .iter()
362            .filter(|p| std::env::var(p.env_var).is_ok())
363            .map(|p| p.id.to_string())
364            .collect()
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_to_env_var() {
        // spn_core::provider_to_env_var yields None for providers it does not know.
        let cases = [
            ("anthropic", Some("ANTHROPIC_API_KEY")),
            ("openai", Some("OPENAI_API_KEY")),
            ("neo4j", Some("NEO4J_PASSWORD")),
            ("github", Some("GITHUB_TOKEN")),
            ("unknown", None),
        ];
        for (provider, expected) in cases {
            assert_eq!(provider_to_env_var(provider), expected, "provider: {provider}");
        }
    }

    #[test]
    fn test_default_socket_path() {
        let rendered = default_socket_path().to_string_lossy().into_owned();
        for fragment in [".spn", "daemon.sock"] {
            assert!(rendered.contains(fragment), "{rendered} lacks {fragment}");
        }
    }

    #[test]
    fn test_daemon_socket_exists() {
        // Smoke test only: the answer depends on whether a daemon socket is present.
        let _ = daemon_socket_exists();
    }
}