1mod error;
49mod protocol;
50
51pub use error::Error;
52pub use protocol::{Request, Response};
53pub use secrecy::{ExposeSecret, SecretString};
54
55pub use spn_core::{
57 find_provider,
58 mask_key,
59 provider_to_env_var,
60 providers_by_category,
61 validate_key_format,
62 BackendError,
63 GpuInfo,
64 LoadConfig,
65 McpConfig,
66 McpServer,
68 McpServerType,
69 McpSource,
70 ModelInfo,
71 PackageManifest,
72 PackageRef,
74 PackageType,
75 Provider,
77 ProviderCategory,
78 PullProgress,
80 RunningModel,
81 ValidationResult,
83 KNOWN_PROVIDERS,
84};
85
86use std::path::PathBuf;
87#[cfg(unix)]
88use tokio::io::{AsyncReadExt, AsyncWriteExt};
89#[cfg(unix)]
90use tokio::net::UnixStream;
91use tracing::debug;
92#[cfg(unix)]
93use tracing::warn;
94
95pub fn default_socket_path() -> PathBuf {
97 dirs::home_dir()
98 .map(|h| h.join(".spn").join("daemon.sock"))
99 .unwrap_or_else(|| PathBuf::from("/tmp/spn-daemon.sock"))
100}
101
102pub fn daemon_socket_exists() -> bool {
104 default_socket_path().exists()
105}
106
107#[derive(Debug)]
115pub struct SpnClient {
116 #[cfg(unix)]
117 stream: Option<UnixStream>,
118 fallback_mode: bool,
119}
120
121impl SpnClient {
122 #[cfg(unix)]
128 pub async fn connect() -> Result<Self, Error> {
129 Self::connect_to(&default_socket_path()).await
130 }
131
132 #[cfg(unix)]
136 pub async fn connect_to(socket_path: &PathBuf) -> Result<Self, Error> {
137 debug!("Connecting to spn daemon at {:?}", socket_path);
138
139 let stream =
140 UnixStream::connect(socket_path)
141 .await
142 .map_err(|e| Error::ConnectionFailed {
143 path: socket_path.clone(),
144 source: e,
145 })?;
146
147 let mut client = Self {
149 stream: Some(stream),
150 fallback_mode: false,
151 };
152
153 client.ping().await?;
154 debug!("Connected to spn daemon");
155
156 Ok(client)
157 }
158
159 #[cfg(unix)]
166 pub async fn connect_with_fallback() -> Result<Self, Error> {
167 match Self::connect().await {
168 Ok(client) => Ok(client),
169 Err(e) => {
170 warn!("spn daemon not running, using env var fallback: {}", e);
171 Ok(Self {
172 stream: None,
173 fallback_mode: true,
174 })
175 }
176 }
177 }
178
179 #[cfg(not(unix))]
184 pub async fn connect_with_fallback() -> Result<Self, Error> {
185 debug!("Non-Unix platform: using env var fallback mode");
186 Ok(Self {
187 fallback_mode: true,
188 })
189 }
190
191 pub fn is_fallback_mode(&self) -> bool {
193 self.fallback_mode
194 }
195
196 #[cfg(unix)]
200 pub async fn ping(&mut self) -> Result<String, Error> {
201 let response = self.send_request(Request::Ping).await?;
202 match response {
203 Response::Pong { version } => Ok(version),
204 Response::Error { message } => Err(Error::DaemonError(message)),
205 _ => Err(Error::UnexpectedResponse),
206 }
207 }
208
209 #[cfg(unix)]
214 pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
215 if self.fallback_mode {
216 return self.get_secret_from_env(provider);
217 }
218
219 let response = self
220 .send_request(Request::GetSecret {
221 provider: provider.to_string(),
222 })
223 .await?;
224
225 match response {
226 Response::Secret { value } => Ok(SecretString::from(value)),
227 Response::Error { message } => Err(Error::SecretNotFound {
228 provider: provider.to_string(),
229 details: message,
230 }),
231 _ => Err(Error::UnexpectedResponse),
232 }
233 }
234
235 #[cfg(not(unix))]
239 pub async fn get_secret(&mut self, provider: &str) -> Result<SecretString, Error> {
240 self.get_secret_from_env(provider)
241 }
242
243 #[cfg(unix)]
245 pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
246 if self.fallback_mode {
247 return Ok(self.get_secret_from_env(provider).is_ok());
248 }
249
250 let response = self
251 .send_request(Request::HasSecret {
252 provider: provider.to_string(),
253 })
254 .await?;
255
256 match response {
257 Response::Exists { exists } => Ok(exists),
258 Response::Error { message } => Err(Error::DaemonError(message)),
259 _ => Err(Error::UnexpectedResponse),
260 }
261 }
262
263 #[cfg(not(unix))]
267 pub async fn has_secret(&mut self, provider: &str) -> Result<bool, Error> {
268 Ok(self.get_secret_from_env(provider).is_ok())
269 }
270
271 #[cfg(unix)]
273 pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
274 if self.fallback_mode {
275 return Ok(self.list_env_providers());
276 }
277
278 let response = self.send_request(Request::ListProviders).await?;
279
280 match response {
281 Response::Providers { providers } => Ok(providers),
282 Response::Error { message } => Err(Error::DaemonError(message)),
283 _ => Err(Error::UnexpectedResponse),
284 }
285 }
286
287 #[cfg(not(unix))]
291 pub async fn list_providers(&mut self) -> Result<Vec<String>, Error> {
292 Ok(self.list_env_providers())
293 }
294
295 #[cfg(unix)]
300 pub async fn send_request(&mut self, request: Request) -> Result<Response, Error> {
301 let stream = self.stream.as_mut().ok_or(Error::NotConnected)?;
302
303 let request_json = serde_json::to_vec(&request).map_err(Error::SerializationError)?;
305
306 let len = request_json.len() as u32;
308 stream
309 .write_all(&len.to_be_bytes())
310 .await
311 .map_err(Error::IoError)?;
312 stream
313 .write_all(&request_json)
314 .await
315 .map_err(Error::IoError)?;
316
317 let mut len_buf = [0u8; 4];
319 stream
320 .read_exact(&mut len_buf)
321 .await
322 .map_err(Error::IoError)?;
323 let response_len = u32::from_be_bytes(len_buf) as usize;
324
325 if response_len > 1_048_576 {
327 return Err(Error::ResponseTooLarge(response_len));
328 }
329
330 let mut response_buf = vec![0u8; response_len];
332 stream
333 .read_exact(&mut response_buf)
334 .await
335 .map_err(Error::IoError)?;
336
337 let response: Response =
339 serde_json::from_slice(&response_buf).map_err(Error::DeserializationError)?;
340
341 Ok(response)
342 }
343
344 fn get_secret_from_env(&self, provider: &str) -> Result<SecretString, Error> {
347 let env_var = provider_to_env_var(provider).ok_or_else(|| Error::SecretNotFound {
348 provider: provider.to_string(),
349 details: format!("Unknown provider: {provider}"),
350 })?;
351 std::env::var(env_var)
352 .map(SecretString::from)
353 .map_err(|_| Error::SecretNotFound {
354 provider: provider.to_string(),
355 details: format!("Environment variable {env_var} not set"),
356 })
357 }
358
359 fn list_env_providers(&self) -> Vec<String> {
360 KNOWN_PROVIDERS
361 .iter()
362 .filter(|p| std::env::var(p.env_var).is_ok())
363 .map(|p| p.id.to_string())
364 .collect()
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_to_env_var() {
        assert_eq!(provider_to_env_var("anthropic"), Some("ANTHROPIC_API_KEY"));
        assert_eq!(provider_to_env_var("openai"), Some("OPENAI_API_KEY"));
        assert_eq!(provider_to_env_var("neo4j"), Some("NEO4J_PASSWORD"));
        assert_eq!(provider_to_env_var("github"), Some("GITHUB_TOKEN"));
        assert_eq!(provider_to_env_var("unknown"), None);
    }

    #[test]
    fn test_default_socket_path() {
        let path = default_socket_path();
        // Check path components rather than substrings so that e.g.
        // "/x.spnfoo/daemon.sockets" would not pass.
        assert_eq!(
            path.file_name().and_then(|name| name.to_str()),
            Some("daemon.sock")
        );
        assert_eq!(
            path.parent()
                .and_then(|dir| dir.file_name())
                .and_then(|name| name.to_str()),
            Some(".spn")
        );
    }

    #[test]
    fn test_daemon_socket_exists() {
        // Whether a daemon is running depends on the environment, so only check
        // that the helper agrees with a direct probe of the same path.
        assert_eq!(daemon_socket_exists(), default_socket_path().exists());
    }

    #[test]
    fn test_env_fallback_rejects_unknown_provider() {
        let client = SpnClient {
            #[cfg(unix)]
            stream: None,
            fallback_mode: true,
        };
        assert!(client.is_fallback_mode());
        assert!(matches!(
            client.get_secret_from_env("unknown"),
            Err(Error::SecretNotFound { .. })
        ));
    }
}