Skip to main content

spn_client/
lib.rs

1//! # spn-client
2//!
3//! Client library for communicating with the spn daemon.
4//!
5//! This crate provides a simple interface for applications (like Nika) to securely
6//! retrieve secrets from the spn daemon without directly accessing the OS keychain.
7//!
8//! ## Usage
9//!
10//! ```rust,no_run
11//! use spn_client::{SpnClient, ExposeSecret};
12//!
13//! # async fn example() -> Result<(), spn_client::Error> {
14//! // Connect to the daemon
15//! let mut client = SpnClient::connect().await?;
16//!
17//! // Get a secret
18//! let api_key = client.get_secret("anthropic").await?;
19//! println!("Got key: {}", api_key.expose_secret());
20//!
21//! // Check if a secret exists
22//! if client.has_secret("openai").await? {
23//!     println!("OpenAI key available");
24//! }
25//!
26//! // List all providers
27//! let providers = client.list_providers().await?;
28//! println!("Available providers: {:?}", providers);
29//! # Ok(())
30//! # }
31//! ```
32//!
33//! ## Fallback Mode
34//!
35//! If the daemon is not running, the client can fall back to reading from
36//! environment variables:
37//!
38//! ```rust,no_run
39//! use spn_client::SpnClient;
40//!
41//! # async fn example() -> Result<(), spn_client::Error> {
42//! let mut client = SpnClient::connect_with_fallback().await?;
43//! // Works even if daemon is not running
44//! # Ok(())
45//! # }
46//! ```
47
48mod error;
49mod protocol;
50
51pub use error::Error;
52pub use protocol::{Request, Response};
53pub use secrecy::{ExposeSecret, SecretString};
54
55use std::path::PathBuf;
56use tokio::io::{AsyncReadExt, AsyncWriteExt};
57use tokio::net::UnixStream;
58use tracing::{debug, warn};
59
60/// Default socket path for the spn daemon.
61pub fn default_socket_path() -> PathBuf {
62    dirs::home_dir()
63        .map(|h| h.join(".spn").join("daemon.sock"))
64        .unwrap_or_else(|| PathBuf::from("/tmp/spn-daemon.sock"))
65}
66
67/// Check if the daemon socket exists.
68pub fn daemon_socket_exists() -> bool {
69    default_socket_path().exists()
70}
71
72/// Client for communicating with the spn daemon.
73///
74/// The client uses Unix socket IPC to communicate with the daemon,
75/// which handles all keychain access to avoid repeated auth prompts.
76#[derive(Debug)]
77pub struct SpnClient {
78    stream: Option<UnixStream>,
79    fallback_mode: bool,
80}
81
82impl SpnClient {
83    /// Connect to the spn daemon.
84    ///
85    /// Returns an error if the daemon is not running.
86    pub async fn connect() -> Result<Self, Error> {
87        Self::connect_to(&default_socket_path()).await
88    }
89
90    /// Connect to the daemon at a specific socket path.
91    pub async fn connect_to(socket_path: &PathBuf) -> Result<Self, Error> {
92        debug!("Connecting to spn daemon at {:?}", socket_path);
93
94        let stream = UnixStream::connect(socket_path)
95            .await
96            .map_err(|e| Error::ConnectionFailed {
97                path: socket_path.clone(),
98                source: e,
99            })?;
100
101        // Verify connection with ping
102        let mut client = Self {
103            stream: Some(stream),
104            fallback_mode: false,
105        };
106
107        client.ping().await?;
108        debug!("Connected to spn daemon");
109
110        Ok(client)
111    }
112
113    /// Connect to the daemon, falling back to env vars if daemon is unavailable.
114    ///
115    /// This is the recommended way to connect in applications that should
116    /// work even without the daemon running.
117    pub async fn connect_with_fallback() -> Result<Self, Error> {
118        match Self::connect().await {
119            Ok(client) => Ok(client),
120            Err(e) => {
121                warn!("spn daemon not running, using env var fallback: {}", e);
122                Ok(Self {
123                    stream: None,
124                    fallback_mode: true,
125                })
126            }
127        }
128    }
129
130    /// Check if the client is in fallback mode (daemon not connected).
131    pub fn is_fallback_mode(&self) -> bool {
132        self.fallback_mode
133    }
134
135    /// Ping the daemon to verify the connection.
136    pub async fn ping(&mut self) -> Result<String, Error> {
137        let response = self.send_request(Request::Ping).await?;
138        match response {
139            Response::Pong { version } => Ok(version),
140            Response::Error { message } => Err(Error::DaemonError(message)),
141            _ => Err(Error::UnexpectedResponse),
142        }
143    }
144
145    /// Get a secret for the given provider.
146    ///
147    /// In fallback mode, attempts to read from the environment variable
148    /// associated with the provider (e.g., ANTHROPIC_API_KEY).
149    pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
150        if self.fallback_mode {
151            return self.get_secret_from_env(provider);
152        }
153
154        let response = self
155            .send_request(Request::GetSecret {
156                provider: provider.to_string(),
157            })
158            .await?;
159
160        match response {
161            Response::Secret { value } => Ok(SecretString::from(value)),
162            Response::Error { message } => Err(Error::SecretNotFound {
163                provider: provider.to_string(),
164                details: message,
165            }),
166            _ => Err(Error::UnexpectedResponse),
167        }
168    }
169
170    /// Check if a secret exists for the given provider.
171    pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
172        if self.fallback_mode {
173            return Ok(self.get_secret_from_env(provider).is_ok());
174        }
175
176        let response = self
177            .send_request(Request::HasSecret {
178                provider: provider.to_string(),
179            })
180            .await?;
181
182        match response {
183            Response::Exists { exists } => Ok(exists),
184            Response::Error { message } => Err(Error::DaemonError(message)),
185            _ => Err(Error::UnexpectedResponse),
186        }
187    }
188
189    /// List all available providers.
190    pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
191        if self.fallback_mode {
192            return Ok(self.list_env_providers());
193        }
194
195        let response = self.send_request(Request::ListProviders).await?;
196
197        match response {
198            Response::Providers { providers } => Ok(providers),
199            Response::Error { message } => Err(Error::DaemonError(message)),
200            _ => Err(Error::UnexpectedResponse),
201        }
202    }
203
204    /// Send a request to the daemon and receive a response.
205    async fn send_request(&mut self, request: Request) -> Result<Response, Error> {
206        let stream = self
207            .stream
208            .as_mut()
209            .ok_or(Error::NotConnected)?;
210
211        // Serialize request
212        let request_json = serde_json::to_vec(&request).map_err(Error::SerializationError)?;
213
214        // Send length-prefixed message
215        let len = request_json.len() as u32;
216        stream
217            .write_all(&len.to_be_bytes())
218            .await
219            .map_err(Error::IoError)?;
220        stream
221            .write_all(&request_json)
222            .await
223            .map_err(Error::IoError)?;
224
225        // Read response length
226        let mut len_buf = [0u8; 4];
227        stream
228            .read_exact(&mut len_buf)
229            .await
230            .map_err(Error::IoError)?;
231        let response_len = u32::from_be_bytes(len_buf) as usize;
232
233        // Sanity check response length (max 1MB)
234        if response_len > 1_048_576 {
235            return Err(Error::ResponseTooLarge(response_len));
236        }
237
238        // Read response
239        let mut response_buf = vec![0u8; response_len];
240        stream
241            .read_exact(&mut response_buf)
242            .await
243            .map_err(Error::IoError)?;
244
245        // Deserialize
246        let response: Response =
247            serde_json::from_slice(&response_buf).map_err(Error::DeserializationError)?;
248
249        Ok(response)
250    }
251
252    // Fallback helpers
253
254    fn get_secret_from_env(&self, provider: &str) -> Result<SecretString, Error> {
255        let env_var = provider_to_env_var(provider);
256        std::env::var(&env_var)
257            .map(SecretString::from)
258            .map_err(|_| Error::SecretNotFound {
259                provider: provider.to_string(),
260                details: format!("Environment variable {} not set", env_var),
261            })
262    }
263
264    fn list_env_providers(&self) -> Vec<String> {
265        KNOWN_PROVIDERS
266            .iter()
267            .filter(|p| std::env::var(provider_to_env_var(p)).is_ok())
268            .map(|s| s.to_string())
269            .collect()
270    }
271}
272
/// Provider names probed when listing providers in env-var fallback mode.
///
/// Entries are lowercase, and the order here is the order in which available
/// providers are reported. Each name is resolved to an environment variable
/// by `provider_to_env_var`.
const KNOWN_PROVIDERS: &[&str] = &[
    "anthropic", "openai", "mistral", "groq", "deepseek",
    "gemini", "ollama", "neo4j", "github", "slack",
    "perplexity", "firecrawl", "supadata",
];
289
/// Convert a provider name to the environment variable consulted in fallback mode.
///
/// Matching is case-insensitive. Known providers map to their conventional
/// variable (e.g. `ollama` → `OLLAMA_HOST`, `github` → `GITHUB_TOKEN`).
///
/// Any other provider maps to `<NAME>_API_KEY`. `<NAME>` is the upper-cased
/// provider name with every non-alphanumeric character replaced by `_`. For
/// example, `my-provider` yields `MY_PROVIDER_API_KEY`, a name that can
/// actually be exported from a POSIX shell.
fn provider_to_env_var(provider: &str) -> String {
    let provider = provider.to_lowercase();

    let known = match provider.as_str() {
        "anthropic" => Some("ANTHROPIC_API_KEY"),
        "openai" => Some("OPENAI_API_KEY"),
        "mistral" => Some("MISTRAL_API_KEY"),
        "groq" => Some("GROQ_API_KEY"),
        "deepseek" => Some("DEEPSEEK_API_KEY"),
        "gemini" => Some("GEMINI_API_KEY"),
        "ollama" => Some("OLLAMA_HOST"),
        "neo4j" => Some("NEO4J_PASSWORD"),
        "github" => Some("GITHUB_TOKEN"),
        "slack" => Some("SLACK_TOKEN"),
        "perplexity" => Some("PERPLEXITY_API_KEY"),
        "firecrawl" => Some("FIRECRAWL_API_KEY"),
        "supadata" => Some("SUPADATA_API_KEY"),
        _ => None,
    };
    if let Some(var) = known {
        return var.to_string();
    }

    // Unknown provider: derive `<NAME>_API_KEY`. Alphanumerics keep the full
    // Unicode upper-casing of the original rule (e.g. `ß` -> `SS`); anything
    // else (`-`, `.`, spaces) becomes `_` so the name is a valid shell variable.
    const SUFFIX: &str = "_API_KEY";
    let mut var = String::with_capacity(provider.len() + SUFFIX.len());
    for c in provider.chars() {
        if c.is_alphanumeric() {
            var.extend(c.to_uppercase());
        } else {
            var.push('_');
        }
    }
    var.push_str(SUFFIX);
    var
}
309
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_to_env_var() {
        assert_eq!(provider_to_env_var("anthropic"), "ANTHROPIC_API_KEY");
        assert_eq!(provider_to_env_var("openai"), "OPENAI_API_KEY");
        assert_eq!(provider_to_env_var("neo4j"), "NEO4J_PASSWORD");
        assert_eq!(provider_to_env_var("github"), "GITHUB_TOKEN");
        assert_eq!(provider_to_env_var("unknown"), "UNKNOWN_API_KEY");
    }

    #[test]
    fn test_provider_to_env_var_is_case_insensitive() {
        assert_eq!(provider_to_env_var("Anthropic"), "ANTHROPIC_API_KEY");
        assert_eq!(provider_to_env_var("OLLAMA"), "OLLAMA_HOST");
    }

    #[test]
    fn test_default_socket_path() {
        let path = default_socket_path();
        // Holds for both the home-based path and the /tmp fallback.
        assert!(path.to_string_lossy().ends_with("daemon.sock"));
        if let Some(home) = dirs::home_dir() {
            assert_eq!(path, home.join(".spn").join("daemon.sock"));
        }
    }

    #[test]
    fn test_daemon_socket_exists() {
        // The daemon may legitimately be running on a developer machine, so
        // check consistency with the resolved path rather than a fixed answer.
        assert_eq!(daemon_socket_exists(), default_socket_path().exists());
    }

    #[test]
    fn test_fallback_client_reports_fallback_mode() {
        let client = SpnClient {
            stream: None,
            fallback_mode: true,
        };
        assert!(client.is_fallback_mode());
    }
}