1use serde::{Deserialize, Serialize};
35use spn_core::{LoadConfig, ModelInfo, PullProgress, RunningModel};
36
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
43#[serde(rename_all = "lowercase")]
44pub enum IpcJobState {
45 Pending,
46 Running,
47 Completed,
48 Failed,
49 Cancelled,
50}
51
52impl std::fmt::Display for IpcJobState {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 match self {
55 IpcJobState::Pending => write!(f, "pending"),
56 IpcJobState::Running => write!(f, "running"),
57 IpcJobState::Completed => write!(f, "completed"),
58 IpcJobState::Failed => write!(f, "failed"),
59 IpcJobState::Cancelled => write!(f, "cancelled"),
60 }
61 }
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct IpcJobStatus {
67 pub id: String,
69 pub workflow: String,
71 pub state: IpcJobState,
73 pub name: Option<String>,
75 pub progress: u8,
77 pub error: Option<String>,
79 pub output: Option<String>,
81 pub created_at: u64,
83 pub started_at: Option<u64>,
85 pub ended_at: Option<u64>,
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct IpcSchedulerStats {
92 pub total: usize,
94 pub pending: usize,
96 pub running: usize,
98 pub completed: usize,
100 pub failed: usize,
102 pub cancelled: usize,
104 pub has_nika: bool,
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct ModelProgress {
113 pub status: String,
115 pub completed: Option<u64>,
117 pub total: Option<u64>,
119 pub digest: Option<String>,
121}
122
123impl ModelProgress {
124 pub fn percentage(&self) -> Option<f64> {
127 match (self.completed, self.total) {
128 (Some(completed), Some(total)) if total > 0 => {
129 Some((completed as f64 / total as f64) * 100.0)
130 }
131 _ => None,
132 }
133 }
134
135 pub fn indeterminate(status: impl Into<String>) -> Self {
137 Self {
138 status: status.into(),
139 completed: None,
140 total: None,
141 digest: None,
142 }
143 }
144
145 pub fn determinate(status: impl Into<String>, completed: u64, total: u64) -> Self {
147 Self {
148 status: status.into(),
149 completed: Some(completed),
150 total: Some(total),
151 digest: None,
152 }
153 }
154
155 pub fn from_pull_progress(p: &PullProgress) -> Self {
157 Self {
158 status: p.status.clone(),
159 completed: Some(p.completed),
160 total: Some(p.total),
161 digest: None, }
163 }
164}
165
/// Current version of this IPC wire protocol.
pub const PROTOCOL_VERSION: u32 = 1;

/// Serde default for the `protocol_version` field of `Response::Pong`.
///
/// A `Pong` payload without that key is read as protocol version 0 —
/// presumably a peer from before the field was introduced.
fn default_protocol_version() -> u32 {
    const UNVERSIONED: u32 = 0;
    UNVERSIONED
}
181
182#[derive(Debug, Clone, Serialize, Deserialize)]
184#[serde(tag = "cmd")]
185pub enum Request {
186 #[serde(rename = "PING")]
188 Ping,
189
190 #[serde(rename = "GET_SECRET")]
192 GetSecret { provider: String },
193
194 #[serde(rename = "HAS_SECRET")]
196 HasSecret { provider: String },
197
198 #[serde(rename = "LIST_PROVIDERS")]
200 ListProviders,
201
202 #[serde(rename = "MODEL_LIST")]
205 ModelList,
206
207 #[serde(rename = "MODEL_PULL")]
209 ModelPull { name: String },
210
211 #[serde(rename = "MODEL_LOAD")]
213 ModelLoad {
214 name: String,
215 #[serde(default)]
216 config: Option<LoadConfig>,
217 },
218
219 #[serde(rename = "MODEL_UNLOAD")]
221 ModelUnload { name: String },
222
223 #[serde(rename = "MODEL_STATUS")]
225 ModelStatus,
226
227 #[serde(rename = "MODEL_DELETE")]
229 ModelDelete { name: String },
230
231 #[serde(rename = "MODEL_RUN")]
233 ModelRun {
234 model: String,
236 prompt: String,
238 #[serde(default)]
240 system: Option<String>,
241 #[serde(default)]
243 temperature: Option<f32>,
244 #[serde(default)]
246 stream: bool,
247 },
248
249 #[serde(rename = "JOB_SUBMIT")]
252 JobSubmit {
253 workflow: String,
255 #[serde(default)]
257 args: Vec<String>,
258 #[serde(default)]
260 name: Option<String>,
261 #[serde(default)]
263 priority: i32,
264 },
265
266 #[serde(rename = "JOB_STATUS")]
268 JobStatus {
269 job_id: String,
271 },
272
273 #[serde(rename = "JOB_LIST")]
275 JobList {
276 #[serde(default)]
278 state: Option<String>,
279 },
280
281 #[serde(rename = "JOB_CANCEL")]
283 JobCancel {
284 job_id: String,
286 },
287
288 #[serde(rename = "JOB_STATS")]
290 JobStats,
291}
292
293#[derive(Debug, Clone, Serialize, Deserialize)]
295#[serde(untagged)]
296pub enum Response {
297 Pong {
299 #[serde(default = "default_protocol_version")]
302 protocol_version: u32,
303 version: String,
305 },
306
307 Secret { value: String },
316
317 Exists { exists: bool },
319
320 Providers { providers: Vec<String> },
322
323 Models { models: Vec<ModelInfo> },
326
327 RunningModels { running: Vec<RunningModel> },
329
330 Success { success: bool },
332
333 ModelRunResult {
335 content: String,
337 #[serde(default)]
339 stats: Option<serde_json::Value>,
340 },
341
342 Error { message: String },
344
345 Progress {
348 progress: ModelProgress,
350 },
351
352 StreamEnd {
354 success: bool,
356 #[serde(default)]
358 error: Option<String>,
359 },
360
361 JobSubmitted {
364 job: IpcJobStatus,
366 },
367
368 JobStatusResult {
370 job: Option<IpcJobStatus>,
372 },
373
374 JobListResult {
376 jobs: Vec<IpcJobStatus>,
378 },
379
380 JobCancelled {
382 cancelled: bool,
384 job_id: String,
386 },
387
388 JobStatsResult {
390 stats: IpcSchedulerStats,
392 },
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `request`, parses it back, and returns both the JSON text and
    /// the parsed value so a test can pin the wire form and the round trip.
    fn round_trip_request(request: &Request) -> (String, Request) {
        let json = serde_json::to_string(request).unwrap();
        let parsed = serde_json::from_str(&json).unwrap();
        (json, parsed)
    }

    #[test]
    fn test_request_serialization() {
        let ping = Request::Ping;
        let json = serde_json::to_string(&ping).unwrap();
        assert_eq!(json, r#"{"cmd":"PING"}"#);

        let get_secret = Request::GetSecret {
            provider: "anthropic".to_string(),
        };
        let json = serde_json::to_string(&get_secret).unwrap();
        assert_eq!(json, r#"{"cmd":"GET_SECRET","provider":"anthropic"}"#);

        let has_secret = Request::HasSecret {
            provider: "openai".to_string(),
        };
        let json = serde_json::to_string(&has_secret).unwrap();
        assert_eq!(json, r#"{"cmd":"HAS_SECRET","provider":"openai"}"#);

        let list = Request::ListProviders;
        let json = serde_json::to_string(&list).unwrap();
        assert_eq!(json, r#"{"cmd":"LIST_PROVIDERS"}"#);
    }

    #[test]
    fn test_request_field_defaults() {
        // Fields marked `#[serde(default)]` may be omitted by the sender.
        let run: Request =
            serde_json::from_str(r#"{"cmd":"MODEL_RUN","model":"llama3.2","prompt":"hi"}"#)
                .unwrap();
        assert!(matches!(
            run,
            Request::ModelRun {
                ref model,
                ref prompt,
                system: None,
                temperature: None,
                stream: false,
            } if model == "llama3.2" && prompt == "hi"
        ));

        let load: Request =
            serde_json::from_str(r#"{"cmd":"MODEL_LOAD","name":"llama3.2"}"#).unwrap();
        assert!(matches!(
            load,
            Request::ModelLoad { ref name, config: None } if name == "llama3.2"
        ));

        let submit: Request =
            serde_json::from_str(r#"{"cmd":"JOB_SUBMIT","workflow":"w.yaml"}"#).unwrap();
        assert!(matches!(
            submit,
            Request::JobSubmit { ref workflow, ref args, name: None, priority: 0 }
                if workflow == "w.yaml" && args.is_empty()
        ));
    }

    #[test]
    fn test_response_deserialization() {
        let json = r#"{"protocol_version":1,"version":"0.14.2"}"#;
        let response: Response = serde_json::from_str(json).unwrap();
        assert!(
            matches!(response, Response::Pong { protocol_version, version }
                if protocol_version == 1 && version == "0.14.2")
        );

        // A Pong without `protocol_version` reads as version 0.
        let json = r#"{"version":"0.9.0"}"#;
        let response: Response = serde_json::from_str(json).unwrap();
        assert!(
            matches!(response, Response::Pong { protocol_version, version }
                if protocol_version == 0 && version == "0.9.0")
        );

        let json = r#"{"value":"sk-test-123"}"#;
        let response: Response = serde_json::from_str(json).unwrap();
        assert!(matches!(response, Response::Secret { value } if value == "sk-test-123"));

        let json = r#"{"exists":true}"#;
        let response: Response = serde_json::from_str(json).unwrap();
        assert!(matches!(response, Response::Exists { exists } if exists));

        let json = r#"{"providers":["anthropic","openai"]}"#;
        let response: Response = serde_json::from_str(json).unwrap();
        assert!(
            matches!(response, Response::Providers { providers } if providers == vec!["anthropic", "openai"])
        );

        let json = r#"{"message":"Not found"}"#;
        let response: Response = serde_json::from_str(json).unwrap();
        assert!(matches!(response, Response::Error { message } if message == "Not found"));
    }

    #[test]
    fn test_model_progress_serialization() {
        let progress = ModelProgress {
            status: "downloading".into(),
            completed: Some(50),
            total: Some(100),
            digest: Some("sha256:abc123".into()),
        };

        let json = serde_json::to_string(&progress).unwrap();
        let parsed: ModelProgress = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.status, "downloading");
        assert_eq!(parsed.completed, Some(50));
        assert_eq!(parsed.total, Some(100));
        assert_eq!(parsed.digest.as_deref(), Some("sha256:abc123"));
    }

    #[test]
    fn test_model_progress_percentage() {
        let progress = ModelProgress {
            status: "downloading".into(),
            completed: Some(75),
            total: Some(100),
            digest: None,
        };

        assert_eq!(progress.percentage(), Some(75.0));

        let no_total = ModelProgress {
            status: "starting".into(),
            completed: None,
            total: None,
            digest: None,
        };

        assert_eq!(no_total.percentage(), None);

        // A zero total must not divide by zero.
        let zero_total = ModelProgress::determinate("verifying", 0, 0);
        assert_eq!(zero_total.percentage(), None);
    }

    #[test]
    fn test_model_progress_constructors() {
        let indeterminate = ModelProgress::indeterminate("loading");
        assert_eq!(indeterminate.status, "loading");
        assert_eq!(indeterminate.completed, None);
        assert_eq!(indeterminate.total, None);
        assert!(indeterminate.percentage().is_none());

        let determinate = ModelProgress::determinate("downloading", 50, 100);
        assert_eq!(determinate.completed, Some(50));
        assert_eq!(determinate.total, Some(100));
        assert_eq!(determinate.digest, None);
        assert_eq!(determinate.percentage(), Some(50.0));
    }

    #[test]
    fn test_response_progress_variant() {
        let progress = ModelProgress::determinate("downloading", 50, 100);
        let response = Response::Progress { progress };

        let json = serde_json::to_string(&response).unwrap();
        assert_eq!(
            json,
            r#"{"progress":{"status":"downloading","completed":50,"total":100,"digest":null}}"#
        );

        let parsed: Response = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            parsed,
            Response::Progress { ref progress } if progress.percentage() == Some(50.0)
        ));
    }

    #[test]
    fn test_response_stream_end_variant() {
        let success_response = Response::StreamEnd {
            success: true,
            error: None,
        };
        let json = serde_json::to_string(&success_response).unwrap();
        assert_eq!(json, r#"{"success":true,"error":null}"#);

        let error_response = Response::StreamEnd {
            success: false,
            error: Some("Connection lost".into()),
        };
        let json = serde_json::to_string(&error_response).unwrap();
        assert_eq!(json, r#"{"success":false,"error":"Connection lost"}"#);
    }

    #[test]
    fn test_job_request_serialization() {
        let submit = Request::JobSubmit {
            workflow: "/path/to/workflow.yaml".into(),
            args: vec!["--verbose".into()],
            name: Some("Test Job".into()),
            priority: 5,
        };
        let (json, parsed) = round_trip_request(&submit);
        assert!(json.contains(r#""cmd":"JOB_SUBMIT""#));
        assert!(matches!(
            parsed,
            Request::JobSubmit { ref workflow, ref args, ref name, priority: 5 }
                if workflow == "/path/to/workflow.yaml"
                    && args.len() == 1
                    && args[0] == "--verbose"
                    && name.as_deref() == Some("Test Job")
        ));

        let (json, parsed) = round_trip_request(&Request::JobStatus {
            job_id: "abc12345".into(),
        });
        assert_eq!(json, r#"{"cmd":"JOB_STATUS","job_id":"abc12345"}"#);
        assert!(matches!(parsed, Request::JobStatus { ref job_id } if job_id == "abc12345"));

        let (json, parsed) = round_trip_request(&Request::JobList { state: None });
        assert_eq!(json, r#"{"cmd":"JOB_LIST","state":null}"#);
        assert!(matches!(parsed, Request::JobList { state: None }));

        let (json, parsed) = round_trip_request(&Request::JobCancel {
            job_id: "def67890".into(),
        });
        assert_eq!(json, r#"{"cmd":"JOB_CANCEL","job_id":"def67890"}"#);
        assert!(matches!(parsed, Request::JobCancel { ref job_id } if job_id == "def67890"));

        let (json, parsed) = round_trip_request(&Request::JobStats);
        assert_eq!(json, r#"{"cmd":"JOB_STATS"}"#);
        assert!(matches!(parsed, Request::JobStats));
    }

    #[test]
    fn test_ipc_job_state_serialization() {
        assert_eq!(
            serde_json::to_string(&IpcJobState::Pending).unwrap(),
            r#""pending""#
        );
        assert_eq!(
            serde_json::to_string(&IpcJobState::Running).unwrap(),
            r#""running""#
        );
        assert_eq!(
            serde_json::to_string(&IpcJobState::Completed).unwrap(),
            r#""completed""#
        );
        assert_eq!(
            serde_json::to_string(&IpcJobState::Failed).unwrap(),
            r#""failed""#
        );
        assert_eq!(
            serde_json::to_string(&IpcJobState::Cancelled).unwrap(),
            r#""cancelled""#
        );
    }

    #[test]
    fn test_ipc_job_state_display_matches_wire_name() {
        let states = [
            IpcJobState::Pending,
            IpcJobState::Running,
            IpcJobState::Completed,
            IpcJobState::Failed,
            IpcJobState::Cancelled,
        ];
        for state in states {
            let wire = serde_json::to_string(&state).unwrap();
            assert_eq!(wire, format!("\"{}\"", state));
            let parsed: IpcJobState = serde_json::from_str(&wire).unwrap();
            assert_eq!(parsed, state);
        }
    }

    #[test]
    fn test_ipc_job_status_serialization() {
        let status = IpcJobStatus {
            id: "abc12345".into(),
            workflow: "/path/to/test.yaml".into(),
            state: IpcJobState::Running,
            name: Some("Test Job".into()),
            progress: 50,
            error: None,
            output: None,
            created_at: 1710000000000,
            started_at: Some(1710000001000),
            ended_at: None,
        };

        let json = serde_json::to_string(&status).unwrap();
        assert!(json.contains(r#""id":"abc12345""#));
        assert!(json.contains(r#""state":"running""#));
        assert!(json.contains(r#""name":"Test Job""#));

        let parsed: IpcJobStatus = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.id, "abc12345");
        assert_eq!(parsed.workflow, "/path/to/test.yaml");
        assert_eq!(parsed.state, IpcJobState::Running);
        assert_eq!(parsed.name.as_deref(), Some("Test Job"));
        assert_eq!(parsed.progress, 50);
        assert_eq!(parsed.created_at, 1710000000000);
        assert_eq!(parsed.started_at, Some(1710000001000));
        assert_eq!(parsed.ended_at, None);
    }

    #[test]
    fn test_ipc_scheduler_stats_serialization() {
        let stats = IpcSchedulerStats {
            total: 10,
            pending: 2,
            running: 3,
            completed: 4,
            failed: 1,
            cancelled: 0,
            has_nika: true,
        };

        let json = serde_json::to_string(&stats).unwrap();
        let parsed: IpcSchedulerStats = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.total, 10);
        assert_eq!(parsed.pending, 2);
        assert_eq!(parsed.running, 3);
        assert_eq!(parsed.completed, 4);
        assert_eq!(parsed.failed, 1);
        assert_eq!(parsed.cancelled, 0);
        assert!(parsed.has_nika);
    }

    #[test]
    fn test_job_response_variants() {
        let status = IpcJobStatus {
            id: "abc12345".into(),
            workflow: "/test.yaml".into(),
            state: IpcJobState::Pending,
            name: None,
            progress: 0,
            error: None,
            output: None,
            created_at: 1710000000000,
            started_at: None,
            ended_at: None,
        };
        let response = Response::JobSubmitted { job: status };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.starts_with(r#"{"job":{"#));
        assert!(json.contains(r#""id":"abc12345""#));
        assert!(json.contains(r#""state":"pending""#));

        // Exact match: a bare `contains("cancelled")` would pass on the key name alone.
        let response = Response::JobCancelled {
            cancelled: true,
            job_id: "def67890".into(),
        };
        let json = serde_json::to_string(&response).unwrap();
        assert_eq!(json, r#"{"cancelled":true,"job_id":"def67890"}"#);
    }
}