Skip to main content

spn_client/
protocol.rs

1//! Protocol types for daemon communication.
2//!
3//! The protocol uses length-prefixed JSON over Unix sockets.
4//!
5//! ## Wire Format
6//!
7//! ```text
8//! [4 bytes: message length (big-endian u32)][JSON payload]
9//! ```
10//!
11//! ## Protocol Versioning
12//!
13//! The protocol version is exchanged during the initial PING/PONG handshake.
14//! This allows clients and daemons to detect incompatible versions early.
15//!
16//! - `protocol_version`: Integer version for wire protocol changes
17//! - `version`: CLI version string for display purposes
18//!
19//! When the protocol version doesn't match, clients should warn and may
20//! fall back to environment variables.
21//!
22//! ## Example
23//!
24//! Request:
25//! ```json
26//! { "cmd": "GET_SECRET", "provider": "anthropic" }
27//! ```
28//!
29//! Response:
30//! ```json
31//! { "value": "sk-ant-..." }
32//! ```
33
34use serde::{Deserialize, Serialize};
35use spn_core::{LoadConfig, ModelInfo, PullProgress, RunningModel};
36
37// ============================================================================
38// JOB TYPES (IPC-friendly versions)
39// ============================================================================
40
41/// Job state in the scheduler (IPC version).
42#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
43#[serde(rename_all = "lowercase")]
44pub enum IpcJobState {
45    Pending,
46    Running,
47    Completed,
48    Failed,
49    Cancelled,
50}
51
52impl std::fmt::Display for IpcJobState {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        match self {
55            IpcJobState::Pending => write!(f, "pending"),
56            IpcJobState::Running => write!(f, "running"),
57            IpcJobState::Completed => write!(f, "completed"),
58            IpcJobState::Failed => write!(f, "failed"),
59            IpcJobState::Cancelled => write!(f, "cancelled"),
60        }
61    }
62}
63
64/// Job status for IPC responses.
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct IpcJobStatus {
67    /// Job ID (8-char UUID prefix)
68    pub id: String,
69    /// Workflow path
70    pub workflow: String,
71    /// Current state
72    pub state: IpcJobState,
73    /// Optional job name
74    pub name: Option<String>,
75    /// Progress percentage (0-100)
76    pub progress: u8,
77    /// Error message (if failed)
78    pub error: Option<String>,
79    /// Output from the workflow (if completed)
80    pub output: Option<String>,
81    /// Creation timestamp (Unix epoch millis)
82    pub created_at: u64,
83    /// Start timestamp (Unix epoch millis, if started)
84    pub started_at: Option<u64>,
85    /// End timestamp (Unix epoch millis, if finished)
86    pub ended_at: Option<u64>,
87}
88
89/// Scheduler statistics for IPC responses.
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct IpcSchedulerStats {
92    /// Total jobs (all states)
93    pub total: usize,
94    /// Pending jobs
95    pub pending: usize,
96    /// Currently running jobs
97    pub running: usize,
98    /// Completed jobs
99    pub completed: usize,
100    /// Failed jobs
101    pub failed: usize,
102    /// Cancelled jobs
103    pub cancelled: usize,
104    /// Whether nika binary is available
105    pub has_nika: bool,
106}
107
108/// Progress update for model operations (pull, load, delete).
109///
110/// Used for streaming progress from daemon to CLI during long-running operations.
111#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct ModelProgress {
113    /// Current status message (e.g., "downloading", "verifying", "extracting")
114    pub status: String,
115    /// Bytes/units completed (optional for indeterminate operations)
116    pub completed: Option<u64>,
117    /// Total bytes/units (optional for indeterminate operations)
118    pub total: Option<u64>,
119    /// Model digest (for pull operations)
120    pub digest: Option<String>,
121}
122
123impl ModelProgress {
124    /// Calculate completion percentage (0.0 - 100.0).
125    /// Returns None if total is unknown or zero.
126    pub fn percentage(&self) -> Option<f64> {
127        match (self.completed, self.total) {
128            (Some(completed), Some(total)) if total > 0 => {
129                Some((completed as f64 / total as f64) * 100.0)
130            }
131            _ => None,
132        }
133    }
134
135    /// Create a new indeterminate progress (spinner mode).
136    pub fn indeterminate(status: impl Into<String>) -> Self {
137        Self {
138            status: status.into(),
139            completed: None,
140            total: None,
141            digest: None,
142        }
143    }
144
145    /// Create a determinate progress (progress bar mode).
146    pub fn determinate(status: impl Into<String>, completed: u64, total: u64) -> Self {
147        Self {
148            status: status.into(),
149            completed: Some(completed),
150            total: Some(total),
151            digest: None,
152        }
153    }
154
155    /// Create from PullProgress (from spn_core/spn_ollama).
156    pub fn from_pull_progress(p: &PullProgress) -> Self {
157        Self {
158            status: p.status.clone(),
159            completed: Some(p.completed),
160            total: Some(p.total),
161            digest: None, // PullProgress doesn't have digest field
162        }
163    }
164}
165
/// Current protocol version.
///
/// Increment when making a breaking change to the wire protocol:
/// - Adding required fields to requests/responses
/// - Changing the serialization format
/// - Removing commands or response variants
///
/// Do NOT increment for:
/// - Adding new optional fields
/// - Adding new commands (backwards compatible)
pub const PROTOCOL_VERSION: u32 = 1;
175
/// Fallback used when a PONG payload has no `protocol_version` field.
///
/// Daemons that predate protocol versioning do not send the field, so they
/// are treated as protocol version 0.
fn default_protocol_version() -> u32 {
    const LEGACY_PROTOCOL_VERSION: u32 = 0;
    LEGACY_PROTOCOL_VERSION
}
181
182/// Request sent to the daemon.
183#[derive(Debug, Clone, Serialize, Deserialize)]
184#[serde(tag = "cmd")]
185pub enum Request {
186    /// Ping the daemon to check it's alive.
187    #[serde(rename = "PING")]
188    Ping,
189
190    /// Get a secret for a provider.
191    #[serde(rename = "GET_SECRET")]
192    GetSecret { provider: String },
193
194    /// Check if a secret exists.
195    #[serde(rename = "HAS_SECRET")]
196    HasSecret { provider: String },
197
198    /// List all available providers.
199    #[serde(rename = "LIST_PROVIDERS")]
200    ListProviders,
201
202    // ==================== Model Commands ====================
203    /// List all installed models.
204    #[serde(rename = "MODEL_LIST")]
205    ModelList,
206
207    /// Pull/download a model.
208    #[serde(rename = "MODEL_PULL")]
209    ModelPull { name: String },
210
211    /// Load a model into memory.
212    #[serde(rename = "MODEL_LOAD")]
213    ModelLoad {
214        name: String,
215        #[serde(default)]
216        config: Option<LoadConfig>,
217    },
218
219    /// Unload a model from memory.
220    #[serde(rename = "MODEL_UNLOAD")]
221    ModelUnload { name: String },
222
223    /// Get status of running models.
224    #[serde(rename = "MODEL_STATUS")]
225    ModelStatus,
226
227    /// Delete a model.
228    #[serde(rename = "MODEL_DELETE")]
229    ModelDelete { name: String },
230
231    /// Run inference on a model.
232    #[serde(rename = "MODEL_RUN")]
233    ModelRun {
234        /// Model name (e.g., llama3.2)
235        model: String,
236        /// User prompt
237        prompt: String,
238        /// System prompt (optional)
239        #[serde(default)]
240        system: Option<String>,
241        /// Temperature (0.0 - 2.0)
242        #[serde(default)]
243        temperature: Option<f32>,
244        /// Enable streaming (not yet supported via IPC)
245        #[serde(default)]
246        stream: bool,
247    },
248
249    // ==================== Job Commands ====================
250    /// Submit a workflow job for background execution.
251    #[serde(rename = "JOB_SUBMIT")]
252    JobSubmit {
253        /// Path to workflow file
254        workflow: String,
255        /// Optional workflow arguments
256        #[serde(default)]
257        args: Vec<String>,
258        /// Optional job name for display
259        #[serde(default)]
260        name: Option<String>,
261        /// Job priority (higher = more urgent)
262        #[serde(default)]
263        priority: i32,
264    },
265
266    /// Get status of a specific job.
267    #[serde(rename = "JOB_STATUS")]
268    JobStatus {
269        /// Job ID (8-character short UUID)
270        job_id: String,
271    },
272
273    /// List all jobs (optionally filtered by state).
274    #[serde(rename = "JOB_LIST")]
275    JobList {
276        /// Filter by state (pending, running, completed, failed, cancelled)
277        #[serde(default)]
278        state: Option<String>,
279    },
280
281    /// Cancel a running or pending job.
282    #[serde(rename = "JOB_CANCEL")]
283    JobCancel {
284        /// Job ID to cancel
285        job_id: String,
286    },
287
288    /// Get scheduler statistics.
289    #[serde(rename = "JOB_STATS")]
290    JobStats,
291}
292
/// Response from the daemon.
///
/// The enum is `#[serde(untagged)]`: no discriminator is written to the wire,
/// and on deserialization serde tries the variants in declaration order and
/// returns the first one whose fields fit the payload. Variant order is
/// therefore part of the wire contract; do not reorder variants casually.
///
/// NOTE(review): no `deny_unknown_fields` is set, so extra keys are ignored
/// while matching. As a result some variants appear to be shadowed when
/// *deserializing* (serialization is unaffected):
/// - `Success` precedes `StreamEnd` and needs only `success`, so a
///   `{"success":..,"error":..}` payload looks like it will come back as
///   `Success`, dropping `error`.
/// - `JobSubmitted` precedes `JobStatusResult`, so a `{"job":{..}}` payload
///   looks like it will come back as `JobSubmitted`; only `{"job":null}`
///   should reach `JobStatusResult`.
/// Confirm with a round-trip test before relying on either variant client-side.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Response {
    /// Successful ping response with version info.
    Pong {
        /// Protocol version for compatibility checking.
        /// Clients should verify this matches PROTOCOL_VERSION.
        /// Falls back to `default_protocol_version()` when the field is absent.
        #[serde(default = "default_protocol_version")]
        protocol_version: u32,
        /// CLI version string for display.
        version: String,
    },

    /// Secret value response.
    ///
    /// # Security Note
    ///
    /// The secret is transmitted as plain JSON over the Unix socket. This is secure because:
    /// - Unix socket requires peer credential verification (same UID only)
    /// - Socket permissions are 0600 (owner-only)
    /// - Connection is local-only (no network exposure)
    Secret { value: String },

    /// Secret existence check response.
    Exists { exists: bool },

    /// Provider list response.
    Providers { providers: Vec<String> },

    // ==================== Model Responses ====================
    /// List of installed models.
    Models { models: Vec<ModelInfo> },

    /// List of currently running/loaded models.
    RunningModels { running: Vec<RunningModel> },

    /// Generic success response.
    ///
    /// NOTE(review): declared before `StreamEnd`, which also carries a
    /// `success` field; see the enum-level note about matching order.
    Success { success: bool },

    /// Model run result with generated content.
    ModelRunResult {
        /// Generated content from the model.
        content: String,
        /// Optional stats (tokens_per_second, etc.)
        #[serde(default)]
        stats: Option<serde_json::Value>,
    },

    /// Error response.
    Error { message: String },

    // ==================== Streaming Responses ====================
    /// Progress update for model operations (streaming).
    Progress {
        /// Progress details
        progress: ModelProgress,
    },

    /// End of stream marker.
    StreamEnd {
        /// Whether the operation succeeded
        success: bool,
        /// Error message if failed
        #[serde(default)]
        error: Option<String>,
    },

    // ==================== Job Responses ====================
    /// Job submitted response with initial status.
    JobSubmitted {
        /// The job status
        job: IpcJobStatus,
    },

    /// Single job status response.
    ///
    /// NOTE(review): shares the `job` key with `JobSubmitted`, which is
    /// declared first; see the enum-level note about matching order.
    JobStatusResult {
        /// The job status (None if job not found)
        job: Option<IpcJobStatus>,
    },

    /// Job list response.
    JobListResult {
        /// List of jobs
        jobs: Vec<IpcJobStatus>,
    },

    /// Job cancelled response.
    JobCancelled {
        /// Whether cancellation succeeded
        cancelled: bool,
        /// Job ID that was cancelled
        job_id: String,
    },

    /// Scheduler statistics response.
    JobStatsResult {
        /// Scheduler stats
        stats: IpcSchedulerStats,
    },
}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Secret/provider requests serialize to the exact documented wire JSON.
    #[test]
    fn test_request_serialization() {
        let ping = Request::Ping;
        let json = serde_json::to_string(&ping).unwrap();
        assert_eq!(json, r#"{"cmd":"PING"}"#);

        let get_secret = Request::GetSecret {
            provider: "anthropic".to_string(),
        };
        let json = serde_json::to_string(&get_secret).unwrap();
        assert_eq!(json, r#"{"cmd":"GET_SECRET","provider":"anthropic"}"#);

        let has_secret = Request::HasSecret {
            provider: "openai".to_string(),
        };
        let json = serde_json::to_string(&has_secret).unwrap();
        assert_eq!(json, r#"{"cmd":"HAS_SECRET","provider":"openai"}"#);

        let list = Request::ListProviders;
        let json = serde_json::to_string(&list).unwrap();
        assert_eq!(json, r#"{"cmd":"LIST_PROVIDERS"}"#);
    }

    /// Fields marked `#[serde(default)]` may be omitted from incoming requests.
    #[test]
    fn test_request_optional_fields_default() {
        let request: Request =
            serde_json::from_str(r#"{"cmd":"JOB_SUBMIT","workflow":"w.yaml"}"#).unwrap();
        assert!(matches!(
            request,
            Request::JobSubmit { workflow, args, name, priority }
                if workflow == "w.yaml" && args.is_empty() && name.is_none() && priority == 0
        ));

        let request: Request = serde_json::from_str(r#"{"cmd":"JOB_LIST"}"#).unwrap();
        assert!(matches!(request, Request::JobList { state: None }));
    }

    /// Untagged responses are recognised purely from the fields they carry.
    #[test]
    fn test_response_deserialization() {
        // Pong with protocol version
        let json = r#"{"protocol_version":1,"version":"0.14.2"}"#;
        let response: Response = serde_json::from_str(json).unwrap();
        assert!(
            matches!(response, Response::Pong { protocol_version, version }
                if protocol_version == 1 && version == "0.14.2")
        );

        // Pong without protocol version (backwards compatibility)
        let json = r#"{"version":"0.9.0"}"#;
        let response: Response = serde_json::from_str(json).unwrap();
        assert!(
            matches!(response, Response::Pong { protocol_version, version }
                if protocol_version == 0 && version == "0.9.0")
        );

        // Secret
        let json = r#"{"value":"sk-test-123"}"#;
        let response: Response = serde_json::from_str(json).unwrap();
        assert!(matches!(response, Response::Secret { value } if value == "sk-test-123"));

        // Exists
        let json = r#"{"exists":true}"#;
        let response: Response = serde_json::from_str(json).unwrap();
        assert!(matches!(response, Response::Exists { exists } if exists));

        // Providers
        let json = r#"{"providers":["anthropic","openai"]}"#;
        let response: Response = serde_json::from_str(json).unwrap();
        assert!(
            matches!(response, Response::Providers { providers } if providers == vec!["anthropic", "openai"])
        );

        // Error
        let json = r#"{"message":"Not found"}"#;
        let response: Response = serde_json::from_str(json).unwrap();
        assert!(matches!(response, Response::Error { message } if message == "Not found"));
    }

    /// Every `ModelProgress` field, including `digest`, survives a JSON round trip.
    #[test]
    fn test_model_progress_serialization() {
        let progress = ModelProgress {
            status: "downloading".into(),
            completed: Some(50),
            total: Some(100),
            digest: Some("sha256:abc123".into()),
        };

        let json = serde_json::to_string(&progress).unwrap();
        let parsed: ModelProgress = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.status, "downloading");
        assert_eq!(parsed.completed, Some(50));
        assert_eq!(parsed.total, Some(100));
        assert_eq!(parsed.digest.as_deref(), Some("sha256:abc123"));
    }

    /// `percentage` yields a value only when both counters are known and total is non-zero.
    #[test]
    fn test_model_progress_percentage() {
        let progress = ModelProgress {
            status: "downloading".into(),
            completed: Some(75),
            total: Some(100),
            digest: None,
        };

        assert_eq!(progress.percentage(), Some(75.0));

        let no_total = ModelProgress {
            status: "starting".into(),
            completed: None,
            total: None,
            digest: None,
        };

        assert_eq!(no_total.percentage(), None);

        // Only one of the two counters known.
        let partial = ModelProgress {
            status: "starting".into(),
            completed: Some(10),
            total: None,
            digest: None,
        };
        assert_eq!(partial.percentage(), None);

        // A zero total must not divide by zero.
        let zero_total = ModelProgress::determinate("starting", 0, 0);
        assert_eq!(zero_total.percentage(), None);
    }

    /// The convenience constructors fill in the expected fields.
    #[test]
    fn test_model_progress_constructors() {
        let indeterminate = ModelProgress::indeterminate("loading");
        assert_eq!(indeterminate.status, "loading");
        assert!(indeterminate.percentage().is_none());

        let determinate = ModelProgress::determinate("downloading", 50, 100);
        assert_eq!(determinate.percentage(), Some(50.0));
    }

    /// A `Progress` response embeds the progress payload when serialized.
    #[test]
    fn test_response_progress_variant() {
        let progress = ModelProgress::determinate("downloading", 50, 100);
        let response = Response::Progress { progress };

        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("downloading"));
    }

    /// `StreamEnd` serializes both its success flag and its error text.
    #[test]
    fn test_response_stream_end_variant() {
        let success_response = Response::StreamEnd {
            success: true,
            error: None,
        };
        let json = serde_json::to_string(&success_response).unwrap();
        assert!(json.contains("success"));

        let error_response = Response::StreamEnd {
            success: false,
            error: Some("Connection lost".into()),
        };
        let json = serde_json::to_string(&error_response).unwrap();
        assert!(json.contains("Connection lost"));
    }

    // ==================== Job Protocol Tests ====================

    /// Each job request carries its command tag and parameters on the wire.
    #[test]
    fn test_job_request_serialization() {
        let submit = Request::JobSubmit {
            workflow: "/path/to/workflow.yaml".into(),
            args: vec!["--verbose".into()],
            name: Some("Test Job".into()),
            priority: 5,
        };
        let json = serde_json::to_string(&submit).unwrap();
        assert!(json.contains("JOB_SUBMIT"));
        assert!(json.contains("workflow.yaml"));

        let status = Request::JobStatus {
            job_id: "abc12345".into(),
        };
        let json = serde_json::to_string(&status).unwrap();
        assert!(json.contains("JOB_STATUS"));
        assert!(json.contains("abc12345"));

        let list = Request::JobList { state: None };
        let json = serde_json::to_string(&list).unwrap();
        assert!(json.contains("JOB_LIST"));

        let cancel = Request::JobCancel {
            job_id: "def67890".into(),
        };
        let json = serde_json::to_string(&cancel).unwrap();
        assert!(json.contains("JOB_CANCEL"));

        let stats = Request::JobStats;
        let json = serde_json::to_string(&stats).unwrap();
        assert!(json.contains("JOB_STATS"));
    }

    /// Job states use the same lowercase name for serde (both directions) and `Display`.
    #[test]
    fn test_ipc_job_state_serialization() {
        let cases = [
            (IpcJobState::Pending, "pending"),
            (IpcJobState::Running, "running"),
            (IpcJobState::Completed, "completed"),
            (IpcJobState::Failed, "failed"),
            (IpcJobState::Cancelled, "cancelled"),
        ];

        for &(state, name) in &cases {
            let json = serde_json::to_string(&state).unwrap();
            assert_eq!(json, format!("\"{}\"", name));

            let parsed: IpcJobState = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, state);

            assert_eq!(state.to_string(), name);
        }
    }

    /// A job status serializes its key fields and parses back unchanged.
    #[test]
    fn test_ipc_job_status_serialization() {
        let status = IpcJobStatus {
            id: "abc12345".into(),
            workflow: "/path/to/test.yaml".into(),
            state: IpcJobState::Running,
            name: Some("Test Job".into()),
            progress: 50,
            error: None,
            output: None,
            created_at: 1710000000000,
            started_at: Some(1710000001000),
            ended_at: None,
        };

        let json = serde_json::to_string(&status).unwrap();
        assert!(json.contains("abc12345"));
        assert!(json.contains("running"));
        assert!(json.contains("Test Job"));

        let parsed: IpcJobStatus = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, "abc12345");
        assert_eq!(parsed.state, IpcJobState::Running);
        assert_eq!(parsed.progress, 50);
        assert_eq!(parsed.created_at, 1710000000000);
        assert_eq!(parsed.started_at, Some(1710000001000));
        assert_eq!(parsed.ended_at, None);
    }

    /// Scheduler statistics survive a JSON round trip.
    #[test]
    fn test_ipc_scheduler_stats_serialization() {
        let stats = IpcSchedulerStats {
            total: 10,
            pending: 2,
            running: 3,
            completed: 4,
            failed: 1,
            cancelled: 0,
            has_nika: true,
        };

        let json = serde_json::to_string(&stats).unwrap();
        let parsed: IpcSchedulerStats = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.total, 10);
        assert_eq!(parsed.running, 3);
        assert!(parsed.has_nika);
    }

    /// Job response variants include their payload when serialized.
    #[test]
    fn test_job_response_variants() {
        // JobSubmitted
        let status = IpcJobStatus {
            id: "abc12345".into(),
            workflow: "/test.yaml".into(),
            state: IpcJobState::Pending,
            name: None,
            progress: 0,
            error: None,
            output: None,
            created_at: 1710000000000,
            started_at: None,
            ended_at: None,
        };
        let response = Response::JobSubmitted { job: status };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("abc12345"));

        // JobCancelled
        let response = Response::JobCancelled {
            cancelled: true,
            job_id: "def67890".into(),
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("cancelled"));
        assert!(json.contains("def67890"));
    }
}