Skip to main content

spn_client/
lib.rs

1//! # spn-client
2//!
3//! Client library for communicating with the spn daemon.
4//!
5//! This crate provides a simple interface for applications (like Nika) to securely
6//! retrieve secrets from the spn daemon without directly accessing the OS keychain.
7//!
8//! ## Usage
9//!
10//! ```rust,no_run
11//! use spn_client::{SpnClient, ExposeSecret};
12//!
13//! # async fn example() -> Result<(), spn_client::Error> {
14//! // Connect to the daemon
15//! let mut client = SpnClient::connect().await?;
16//!
17//! // Get a secret
18//! let api_key = client.get_secret("anthropic").await?;
19//! println!("Got key: {}", api_key.expose_secret());
20//!
21//! // Check if a secret exists
22//! if client.has_secret("openai").await? {
23//!     println!("OpenAI key available");
24//! }
25//!
26//! // List all providers
27//! let providers = client.list_providers().await?;
28//! println!("Available providers: {:?}", providers);
29//! # Ok(())
30//! # }
31//! ```
32//!
33//! ## Fallback Mode
34//!
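//! If the daemon is not running, the client can fall back to reading each
//! provider's secret from its environment variable (for example
//! `ANTHROPIC_API_KEY`). On non-Unix platforms the client always uses this
//! fallback: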
37//!
38//! ```rust,no_run
39//! use spn_client::SpnClient;
40//!
41//! # async fn example() -> Result<(), spn_client::Error> {
42//! let mut client = SpnClient::connect_with_fallback().await?;
43//! // Works even if daemon is not running
44//! # Ok(())
45//! # }
46//! ```
47
48mod error;
49mod protocol;
50
51pub use error::Error;
52pub use protocol::{Request, Response};
53pub use secrecy::{ExposeSecret, SecretString};
54
55// Re-export all spn-core types for convenience
56pub use spn_core::{
57    // Providers
58    Provider, ProviderCategory, KNOWN_PROVIDERS,
59    find_provider, provider_to_env_var, providers_by_category,
60    // Validation
61    ValidationResult, validate_key_format, mask_key,
62    // MCP
63    McpServer, McpServerType, McpConfig, McpSource,
64    // Registry
65    PackageRef, PackageManifest, PackageType,
66    // Backend
67    PullProgress, ModelInfo, RunningModel, GpuInfo, LoadConfig, BackendError,
68};
69
70use std::path::PathBuf;
71#[cfg(unix)]
72use tokio::io::{AsyncReadExt, AsyncWriteExt};
73#[cfg(unix)]
74use tokio::net::UnixStream;
75use tracing::debug;
76#[cfg(unix)]
77use tracing::warn;
78
79/// Default socket path for the spn daemon.
80pub fn default_socket_path() -> PathBuf {
81    dirs::home_dir()
82        .map(|h| h.join(".spn").join("daemon.sock"))
83        .unwrap_or_else(|| PathBuf::from("/tmp/spn-daemon.sock"))
84}
85
86/// Check if the daemon socket exists.
87pub fn daemon_socket_exists() -> bool {
88    default_socket_path().exists()
89}
90
91/// Client for communicating with the spn daemon.
92///
93/// The client uses Unix socket IPC to communicate with the daemon,
94/// which handles all keychain access to avoid repeated auth prompts.
95///
96/// On non-Unix platforms (Windows), the client always operates in fallback mode,
97/// reading secrets from environment variables.
98#[derive(Debug)]
99pub struct SpnClient {
100    #[cfg(unix)]
101    stream: Option<UnixStream>,
102    fallback_mode: bool,
103}
104
105impl SpnClient {
106    /// Connect to the spn daemon.
107    ///
108    /// Returns an error if the daemon is not running.
109    ///
110    /// This method is only available on Unix platforms.
111    #[cfg(unix)]
112    pub async fn connect() -> Result<Self, Error> {
113        Self::connect_to(&default_socket_path()).await
114    }
115
116    /// Connect to the daemon at a specific socket path.
117    ///
118    /// This method is only available on Unix platforms.
119    #[cfg(unix)]
120    pub async fn connect_to(socket_path: &PathBuf) -> Result<Self, Error> {
121        debug!("Connecting to spn daemon at {:?}", socket_path);
122
123        let stream = UnixStream::connect(socket_path)
124            .await
125            .map_err(|e| Error::ConnectionFailed {
126                path: socket_path.clone(),
127                source: e,
128            })?;
129
130        // Verify connection with ping
131        let mut client = Self {
132            stream: Some(stream),
133            fallback_mode: false,
134        };
135
136        client.ping().await?;
137        debug!("Connected to spn daemon");
138
139        Ok(client)
140    }
141
142    /// Connect to the daemon, falling back to env vars if daemon is unavailable.
143    ///
144    /// This is the recommended way to connect in applications that should
145    /// work even without the daemon running.
146    ///
147    /// On non-Unix platforms (Windows), this always returns a fallback client.
148    #[cfg(unix)]
149    pub async fn connect_with_fallback() -> Result<Self, Error> {
150        match Self::connect().await {
151            Ok(client) => Ok(client),
152            Err(e) => {
153                warn!("spn daemon not running, using env var fallback: {}", e);
154                Ok(Self {
155                    stream: None,
156                    fallback_mode: true,
157                })
158            }
159        }
160    }
161
162    /// Connect to the daemon, falling back to env vars if daemon is unavailable.
163    ///
164    /// On non-Unix platforms (Windows), this always returns a fallback client
165    /// since Unix sockets are not available.
166    #[cfg(not(unix))]
167    pub async fn connect_with_fallback() -> Result<Self, Error> {
168        debug!("Non-Unix platform: using env var fallback mode");
169        Ok(Self {
170            fallback_mode: true,
171        })
172    }
173
174    /// Check if the client is in fallback mode (daemon not connected).
175    pub fn is_fallback_mode(&self) -> bool {
176        self.fallback_mode
177    }
178
179    /// Ping the daemon to verify the connection.
180    ///
181    /// This method is only available on Unix platforms.
182    #[cfg(unix)]
183    pub async fn ping(&mut self) -> Result<String, Error> {
184        let response = self.send_request(Request::Ping).await?;
185        match response {
186            Response::Pong { version } => Ok(version),
187            Response::Error { message } => Err(Error::DaemonError(message)),
188            _ => Err(Error::UnexpectedResponse),
189        }
190    }
191
192    /// Get a secret for the given provider.
193    ///
194    /// In fallback mode, attempts to read from the environment variable
195    /// associated with the provider (e.g., ANTHROPIC_API_KEY).
196    #[cfg(unix)]
197    pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
198        if self.fallback_mode {
199            return self.get_secret_from_env(provider);
200        }
201
202        let response = self
203            .send_request(Request::GetSecret {
204                provider: provider.to_string(),
205            })
206            .await?;
207
208        match response {
209            Response::Secret { value } => Ok(SecretString::from(value)),
210            Response::Error { message } => Err(Error::SecretNotFound {
211                provider: provider.to_string(),
212                details: message,
213            }),
214            _ => Err(Error::UnexpectedResponse),
215        }
216    }
217
218    /// Get a secret for the given provider.
219    ///
220    /// On non-Unix platforms, always reads from environment variables.
221    #[cfg(not(unix))]
222    pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
223        self.get_secret_from_env(provider)
224    }
225
226    /// Check if a secret exists for the given provider.
227    #[cfg(unix)]
228    pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
229        if self.fallback_mode {
230            return Ok(self.get_secret_from_env(provider).is_ok());
231        }
232
233        let response = self
234            .send_request(Request::HasSecret {
235                provider: provider.to_string(),
236            })
237            .await?;
238
239        match response {
240            Response::Exists { exists } => Ok(exists),
241            Response::Error { message } => Err(Error::DaemonError(message)),
242            _ => Err(Error::UnexpectedResponse),
243        }
244    }
245
246    /// Check if a secret exists for the given provider.
247    ///
248    /// On non-Unix platforms, checks environment variables.
249    #[cfg(not(unix))]
250    pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
251        Ok(self.get_secret_from_env(provider).is_ok())
252    }
253
254    /// List all available providers.
255    #[cfg(unix)]
256    pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
257        if self.fallback_mode {
258            return Ok(self.list_env_providers());
259        }
260
261        let response = self.send_request(Request::ListProviders).await?;
262
263        match response {
264            Response::Providers { providers } => Ok(providers),
265            Response::Error { message } => Err(Error::DaemonError(message)),
266            _ => Err(Error::UnexpectedResponse),
267        }
268    }
269
270    /// List all available providers.
271    ///
272    /// On non-Unix platforms, lists providers from environment variables.
273    #[cfg(not(unix))]
274    pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
275        Ok(self.list_env_providers())
276    }
277
278    /// Send a request to the daemon and receive a response.
279    #[cfg(unix)]
280    async fn send_request(&mut self, request: Request) -> Result<Response, Error> {
281        let stream = self
282            .stream
283            .as_mut()
284            .ok_or(Error::NotConnected)?;
285
286        // Serialize request
287        let request_json = serde_json::to_vec(&request).map_err(Error::SerializationError)?;
288
289        // Send length-prefixed message
290        let len = request_json.len() as u32;
291        stream
292            .write_all(&len.to_be_bytes())
293            .await
294            .map_err(Error::IoError)?;
295        stream
296            .write_all(&request_json)
297            .await
298            .map_err(Error::IoError)?;
299
300        // Read response length
301        let mut len_buf = [0u8; 4];
302        stream
303            .read_exact(&mut len_buf)
304            .await
305            .map_err(Error::IoError)?;
306        let response_len = u32::from_be_bytes(len_buf) as usize;
307
308        // Sanity check response length (max 1MB)
309        if response_len > 1_048_576 {
310            return Err(Error::ResponseTooLarge(response_len));
311        }
312
313        // Read response
314        let mut response_buf = vec![0u8; response_len];
315        stream
316            .read_exact(&mut response_buf)
317            .await
318            .map_err(Error::IoError)?;
319
320        // Deserialize
321        let response: Response =
322            serde_json::from_slice(&response_buf).map_err(Error::DeserializationError)?;
323
324        Ok(response)
325    }
326
327    // Fallback helpers
328
329    fn get_secret_from_env(&self, provider: &str) -> Result<SecretString, Error> {
330        let env_var = provider_to_env_var(provider).ok_or_else(|| Error::SecretNotFound {
331            provider: provider.to_string(),
332            details: format!("Unknown provider: {provider}"),
333        })?;
334        std::env::var(env_var)
335            .map(SecretString::from)
336            .map_err(|_| Error::SecretNotFound {
337                provider: provider.to_string(),
338                details: format!("Environment variable {env_var} not set"),
339            })
340    }
341
342    fn list_env_providers(&self) -> Vec<String> {
343        KNOWN_PROVIDERS
344            .iter()
345            .filter(|p| std::env::var(p.env_var).is_ok())
346            .map(|p| p.id.to_string())
347            .collect()
348    }
349}
350
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client in fallback mode without touching any socket.
    fn fallback_client() -> SpnClient {
        SpnClient {
            #[cfg(unix)]
            stream: None,
            fallback_mode: true,
        }
    }

    #[test]
    fn test_provider_to_env_var() {
        // These now use spn_core::provider_to_env_var which returns Option
        assert_eq!(provider_to_env_var("anthropic"), Some("ANTHROPIC_API_KEY"));
        assert_eq!(provider_to_env_var("openai"), Some("OPENAI_API_KEY"));
        assert_eq!(provider_to_env_var("neo4j"), Some("NEO4J_PASSWORD"));
        assert_eq!(provider_to_env_var("github"), Some("GITHUB_TOKEN"));
        assert_eq!(provider_to_env_var("unknown"), None);
    }

    #[test]
    fn test_default_socket_path() {
        let path = default_socket_path();
        let file_name = path
            .file_name()
            .and_then(|name| name.to_str())
            .unwrap_or_default();
        assert!(file_name.ends_with("daemon.sock"), "unexpected socket file: {path:?}");

        // The location depends on whether a home directory can be resolved on
        // the machine running the tests, so check both branches explicitly.
        match dirs::home_dir() {
            Some(home) => assert_eq!(path, home.join(".spn").join("daemon.sock")),
            None => assert_eq!(path, PathBuf::from("/tmp/spn-daemon.sock")),
        }
    }

    #[test]
    fn test_daemon_socket_exists() {
        // Whether a daemon is running depends on the machine executing the
        // tests, so only check that the helper probes the default path.
        assert_eq!(daemon_socket_exists(), default_socket_path().exists());
    }

    #[test]
    fn test_fallback_unknown_provider() {
        let client = fallback_client();
        assert!(client.is_fallback_mode());
        assert!(matches!(
            client.get_secret_from_env("unknown"),
            Err(Error::SecretNotFound { .. })
        ));
    }
}
374    fn test_daemon_socket_exists() {
375        // Should return false since daemon isn't running in tests
376        assert!(!daemon_socket_exists());
377    }
378}