1use crate::pods::catalog::{PodCatalog, PodProfile};
2use crate::pods::state::{PodGpu, PodHealth, PodState, PodsState, RunningModel};
3use crate::pods::store::PodsStore;
4use crate::pods::transport::{PodTransport, SshTransport};
5use anyhow::{Context, Result, anyhow};
6use parking_lot::RwLock;
7use std::collections::BTreeMap;
8use std::fmt::Write;
9use std::sync::Arc;
10
/// First port handed out to a model; `next_port` scans upward from here,
/// skipping ports already used by tracked models.
const DEFAULT_START_PORT: u16 = 8001;
/// Per-model log directory, relative to the remote user's home (`~/`).
/// Each model logs to `<sanitized-name>.log` inside it.
const DEFAULT_LOG_DIR: &str = ".vllm_logs";
13
/// Parameters for `PodManager::start_model`.
#[derive(Debug, Clone)]
pub struct PodStartRequest {
    /// Pod name; when set it renames the active pod (or names a newly created one).
    pub pod_name: Option<String>,
    /// SSH command used to reach the pod (e.g. `ssh root@host`); required unless
    /// the active pod already has one.
    pub ssh: Option<String>,
    /// GPU inventory; when non-empty it replaces the active pod's list.
    pub gpus: Vec<PodGpu>,
    /// Remote models directory, exported to the run script as `MODELS_PATH`.
    pub models_path: Option<String>,
    /// Instance name: key in the pod's model map, value for `{{NAME}}` in the
    /// command template, and (sanitized) the stem of the script/log file names.
    pub name: String,
    /// Model id substituted for `{{MODEL_ID}}` and used to look up catalog profiles.
    pub model: String,
    /// Explicit catalog profile name; bypasses automatic profile selection.
    pub profile: Option<String>,
    /// GPU count to use; overrides the profile's count (never less than 1).
    pub requested_gpu_count: Option<usize>,
    /// GPU memory percentage; divided by 100, clamped to `[0, 1]` and passed as
    /// `--gpu-memory-utilization`.
    pub memory: Option<f32>,
    /// Context length such as `"32k"` or `"131072"`, passed as `--max-model-len`.
    pub context: Option<String>,
}
28
/// Outcome of a successful `PodManager::start_model` call.
#[derive(Debug, Clone)]
pub struct PodStartResult {
    /// Pod state after the launch, including the new model entry.
    pub pod: PodState,
    /// State entry recorded for the launched model.
    pub entry: RunningModel,
    /// Catalog profile (or synthesized fallback profile) used for the launch.
    pub profile: PodProfile,
    /// Remote shell command that started the wrapper script in the background.
    pub launch_command: String,
}
37
/// One row of a `PodStatusReport`: a tracked model plus its probed health.
#[derive(Debug, Clone)]
pub struct PodListEntry {
    /// Instance name (key in the pod's model map).
    pub name: String,
    /// Model id the instance serves.
    pub model: String,
    /// Port the server was launched on.
    pub port: u16,
    /// PID returned when the wrapper script was launched.
    pub pid: u32,
    /// GPU ids assigned at launch.
    pub gpu_ids: Vec<u32>,
    /// Health derived from process, `/health` and log-tail probes.
    pub status: PodHealth,
}
48
/// Summary of one catalog profile, as reported by `PodManager::known_models`.
#[derive(Debug, Clone)]
pub struct PodStatusDetail {
    /// Profile name.
    pub name: String,
    /// Model id the profile launches.
    pub model: String,
    /// Number of GPUs the profile is configured for.
    pub gpu_count: usize,
}
56
/// Catalog profiles split by whether `PodCatalog::compatible_profiles`
/// considers them usable on the active pod.
#[derive(Debug, Clone)]
pub struct KnownModelsReport {
    /// Profiles that fit the active pod.
    pub compatible: Vec<PodStatusDetail>,
    /// Profiles that do not fit the active pod.
    pub incompatible: Vec<PodStatusDetail>,
}
63
/// Result of `PodManager::list_models`: every tracked model on the active pod.
#[derive(Debug, Clone)]
pub struct PodStatusReport {
    /// Name of the active pod.
    pub pod_name: String,
    /// One entry per tracked model, in map (name) order.
    pub entries: Vec<PodListEntry>,
}
70
/// Manages model processes on a remote GPU pod reached through a `PodTransport`.
///
/// Pod state and the profile catalog are read from the `PodsStore` once and then
/// served from in-memory caches. The caches sit behind `Arc`, so clones of a
/// manager share them.
#[derive(Clone)]
pub struct PodManager {
    /// Persistent storage for pod state and the profile catalog.
    store: PodsStore,
    /// Runs commands and writes files on the pod.
    transport: Arc<dyn PodTransport>,
    /// Cached pod state; filled by `load_state`, refreshed by `persist_state`.
    cached_state: Arc<RwLock<Option<PodsState>>>,
    /// Cached catalog; filled on the first `load_catalog` call.
    cached_catalog: Arc<RwLock<Option<PodCatalog>>>,
}
79
80impl PodManager {
81 pub fn new() -> Result<Self> {
82 Ok(Self::with_transport(
83 PodsStore::default_store()?,
84 Arc::new(SshTransport),
85 ))
86 }
87
88 pub fn with_transport(store: PodsStore, transport: Arc<dyn PodTransport>) -> Self {
89 Self {
90 store,
91 transport,
92 cached_state: Arc::new(RwLock::new(None)),
93 cached_catalog: Arc::new(RwLock::new(None)),
94 }
95 }
96
97 pub async fn load_state(&self) -> Result<PodsState> {
98 if let Some(state) = self.cached_state.read().clone() {
99 return Ok(state);
100 }
101
102 let state = self.store.load_state().await?;
103 *self.cached_state.write() = Some(state.clone());
104 Ok(state)
105 }
106
107 pub async fn load_catalog(&self) -> Result<PodCatalog> {
108 if let Some(catalog) = self.cached_catalog.read().clone() {
109 return Ok(catalog);
110 }
111
112 let catalog = self.store.load_catalog().await?;
113 *self.cached_catalog.write() = Some(catalog.clone());
114 Ok(catalog)
115 }
116
117 pub async fn start_model(&self, request: PodStartRequest) -> Result<PodStartResult> {
118 let mut state = self.load_state().await?;
119 let catalog = self.load_catalog().await?;
120 let pod = self.resolve_active_pod(&mut state, &request).await?;
121 let profile = self.resolve_profile(&catalog, &pod, &request)?;
122 let gpu_count = request
123 .requested_gpu_count
124 .unwrap_or(profile.gpu_count)
125 .max(1);
126
127 if gpu_count > pod.gpu_count() {
128 return Err(anyhow!(
129 "requested {} GPUs but pod '{}' only has {}",
130 gpu_count,
131 pod.name,
132 pod.gpu_count()
133 ));
134 }
135
136 let selected_gpu_ids = select_gpus(&pod, gpu_count);
137 let port = next_port(&pod);
138 let sanitized_name = sanitize_component(&request.name);
139 let run_path = format!("/tmp/model_run_{sanitized_name}.sh");
140 let wrapper_path = format!("/tmp/model_wrapper_{sanitized_name}.sh");
141 let log_path = format!("~/{DEFAULT_LOG_DIR}/{sanitized_name}.log");
142 let vllm_args = render_args(
143 &profile.vllm_args,
144 request.memory,
145 request.context.as_deref(),
146 )?;
147 let run_script = render_run_script(
148 &profile,
149 &request.model,
150 &request.name,
151 port,
152 &vllm_args,
153 &selected_gpu_ids,
154 pod.models_path.as_deref(),
155 );
156 let wrapper_script = render_wrapper_script(&run_path, &log_path);
157
158 self.transport
159 .write_file(&pod.ssh, &run_path, &run_script)
160 .await?;
161 self.transport
162 .write_file(&pod.ssh, &wrapper_path, &wrapper_script)
163 .await?;
164
165 let chmod = self
166 .transport
167 .exec_capture(&pod.ssh, &format!("chmod +x {run_path} {wrapper_path}"))
168 .await?;
169 if !chmod.success {
170 return Err(anyhow!("failed to chmod remote scripts: {}", chmod.stderr));
171 }
172
173 let launch_command = format!(
174 "mkdir -p ~/{DEFAULT_LOG_DIR} && setsid {wrapper_path} >/dev/null 2>&1 < /dev/null & echo $!"
175 );
176 let launch = self
177 .transport
178 .exec_capture(&pod.ssh, &launch_command)
179 .await?;
180 if !launch.success {
181 return Err(anyhow!("failed to launch remote model: {}", launch.stderr));
182 }
183
184 let pid = parse_pid(&launch.stdout)?;
185 let entry = RunningModel {
186 model: request.model.clone(),
187 port,
188 gpu_ids: selected_gpu_ids.clone(),
189 pid,
190 profile: profile.name.clone(),
191 };
192
193 let mut updated_pod = pod.clone();
194 updated_pod
195 .models
196 .insert(request.name.clone(), entry.clone());
197 state.active_pod = Some(updated_pod.clone());
198 self.persist_state(&state).await?;
199
200 Ok(PodStartResult {
201 pod: updated_pod,
202 entry,
203 profile,
204 launch_command,
205 })
206 }
207
208 pub async fn stop_model(&self, name: &str) -> Result<Option<RunningModel>> {
209 let mut state = self.load_state().await?;
210 let Some(pod) = state.active_pod.as_mut() else {
211 return Ok(None);
212 };
213
214 let Some(entry) = pod.models.remove(name) else {
215 return Ok(None);
216 };
217
218 let command = format!(
219 "pkill -TERM -P {} || true; kill {} || true",
220 entry.pid, entry.pid
221 );
222 let output = self.transport.exec_capture(&pod.ssh, &command).await?;
223 if !output.success {
224 return Err(anyhow!(
225 "failed to stop model '{}': {}",
226 name,
227 output.stderr
228 ));
229 }
230
231 self.persist_state(&state).await?;
232 Ok(Some(entry))
233 }
234
235 pub async fn stop_all_models(&self) -> Result<usize> {
236 let mut state = self.load_state().await?;
237 let Some(pod) = state.active_pod.as_mut() else {
238 return Ok(0);
239 };
240
241 let pids = pod
242 .models
243 .values()
244 .map(|entry| entry.pid.to_string())
245 .collect::<Vec<_>>();
246
247 if pids.is_empty() {
248 return Ok(0);
249 }
250
251 let command = format!(
252 "for PID in {}; do pkill -TERM -P \"$PID\" || true; kill \"$PID\" || true; done",
253 pids.join(" ")
254 );
255 let output = self.transport.exec_capture(&pod.ssh, &command).await?;
256 if !output.success {
257 return Err(anyhow!("failed to stop models: {}", output.stderr));
258 }
259
260 let stopped = pod.models.len();
261 pod.models.clear();
262 self.persist_state(&state).await?;
263 Ok(stopped)
264 }
265
266 pub async fn list_models(&self) -> Result<PodStatusReport> {
267 let state = self.load_state().await?;
268 let Some(pod) = state.active_pod.as_ref() else {
269 return Err(anyhow!("no active pod configured"));
270 };
271
272 let mut entries = Vec::new();
273 for (name, model) in &pod.models {
274 let status = self.inspect_model(pod, name, model).await?;
275 entries.push(PodListEntry {
276 name: name.clone(),
277 model: model.model.clone(),
278 port: model.port,
279 pid: model.pid,
280 gpu_ids: model.gpu_ids.clone(),
281 status,
282 });
283 }
284
285 Ok(PodStatusReport {
286 pod_name: pod.name.clone(),
287 entries,
288 })
289 }
290
291 pub async fn stream_logs(&self, name: &str) -> Result<()> {
292 let state = self.load_state().await?;
293 let Some(pod) = state.active_pod.as_ref() else {
294 return Err(anyhow!("no active pod configured"));
295 };
296 let Some(entry) = pod.models.get(name) else {
297 return Err(anyhow!("unknown model '{}'", name));
298 };
299
300 let log_path = format!("~/{DEFAULT_LOG_DIR}/{}.log", sanitize_component(name));
301 let command = format!("tail -f {log_path}");
302 let _ = entry;
303 self.transport.exec_stream(&pod.ssh, &command).await
304 }
305
306 pub async fn known_models(&self) -> Result<KnownModelsReport> {
307 let state = self.load_state().await?;
308 let Some(pod) = state.active_pod.as_ref() else {
309 return Err(anyhow!("no active pod configured"));
310 };
311 let catalog = self.load_catalog().await?;
312 let (compatible, incompatible) = catalog.compatible_profiles(pod);
313
314 Ok(KnownModelsReport {
315 compatible: compatible
316 .into_iter()
317 .map(|profile| PodStatusDetail {
318 name: profile.name.clone(),
319 model: profile.model.clone(),
320 gpu_count: profile.gpu_count,
321 })
322 .collect(),
323 incompatible: incompatible
324 .into_iter()
325 .map(|profile| PodStatusDetail {
326 name: profile.name.clone(),
327 model: profile.model.clone(),
328 gpu_count: profile.gpu_count,
329 })
330 .collect(),
331 })
332 }
333
334 async fn persist_state(&self, state: &PodsState) -> Result<()> {
335 self.store.save_state(state).await?;
336 *self.cached_state.write() = Some(state.clone());
337 Ok(())
338 }
339
340 async fn resolve_active_pod(
341 &self,
342 state: &mut PodsState,
343 request: &PodStartRequest,
344 ) -> Result<PodState> {
345 let mut pod = state.active_pod.clone().unwrap_or_else(|| PodState {
346 name: request
347 .pod_name
348 .clone()
349 .unwrap_or_else(|| "active-pod".to_string()),
350 ssh: request.ssh.clone().unwrap_or_default(),
351 models_path: request.models_path.clone(),
352 gpus: Vec::new(),
353 models: BTreeMap::new(),
354 });
355
356 if let Some(name) = &request.pod_name {
357 pod.name = name.clone();
358 }
359 if let Some(ssh) = &request.ssh {
360 pod.ssh = ssh.clone();
361 }
362 if let Some(models_path) = &request.models_path {
363 pod.models_path = Some(models_path.clone());
364 }
365 if !request.gpus.is_empty() {
366 pod.gpus = request.gpus.clone();
367 }
368
369 if pod.ssh.is_empty() {
370 return Err(anyhow!(
371 "pod ssh command is required; pass --ssh or reuse the active pod"
372 ));
373 }
374 if pod.gpus.is_empty() {
375 return Err(anyhow!(
376 "pod gpu inventory is required; pass --gpu entries or reuse the active pod"
377 ));
378 }
379
380 state.active_pod = Some(pod.clone());
381 self.persist_state(state).await?;
382 Ok(pod)
383 }
384
385 fn resolve_profile(
386 &self,
387 catalog: &PodCatalog,
388 pod: &PodState,
389 request: &PodStartRequest,
390 ) -> Result<PodProfile> {
391 if let Some(profile_name) = request.profile.as_deref() {
392 let profile = catalog
393 .profiles
394 .iter()
395 .find(|profile| profile.name == profile_name)
396 .cloned()
397 .ok_or_else(|| anyhow!("unknown pod profile '{}'", profile_name))?;
398 return Ok(profile);
399 }
400
401 let mut candidates = catalog.profiles_for_model(&request.model);
402 candidates.retain(|profile| profile.matches_pod(pod));
403
404 if let Some(requested_gpu_count) = request.requested_gpu_count {
405 if let Some(profile) = candidates
406 .iter()
407 .copied()
408 .find(|profile| profile.matches_gpu_count(requested_gpu_count))
409 {
410 return Ok(profile.clone());
411 }
412
413 if !candidates.is_empty() {
414 let valid_counts = candidates
415 .iter()
416 .map(|profile| profile.gpu_count.to_string())
417 .collect::<Vec<_>>();
418 return Err(anyhow!(
419 "no profile for '{}' with {} GPUs; valid counts: {}",
420 request.model,
421 requested_gpu_count,
422 valid_counts.join(", ")
423 ));
424 }
425 }
426
427 candidates
428 .into_iter()
429 .max_by_key(|profile| profile.gpu_count)
430 .cloned()
431 .or_else(|| {
432 if request.model.is_empty() {
433 None
434 } else {
435 Some(PodProfile {
436 name: request.model.clone(),
437 model: request.model.clone(),
438 gpu_count: request.requested_gpu_count.unwrap_or(1),
439 gpu_types: Vec::new(),
440 command_template: default_command_template(),
441 vllm_args: vec![
442 "--trust-remote-code".to_string(),
443 "--dtype".to_string(),
444 "auto".to_string(),
445 ],
446 env: BTreeMap::new(),
447 })
448 }
449 })
450 .ok_or_else(|| anyhow!("no profile found for model '{}'", request.model))
451 }
452
453 async fn inspect_model(
454 &self,
455 pod: &PodState,
456 name: &str,
457 entry: &RunningModel,
458 ) -> Result<PodHealth> {
459 let process = self
460 .transport
461 .exec_capture(&pod.ssh, &format!("ps -p {}", entry.pid))
462 .await?;
463 let health = self
464 .transport
465 .exec_capture(
466 &pod.ssh,
467 &format!("curl -s -f http://localhost:{}/health", entry.port),
468 )
469 .await?;
470 let log_tail = self
471 .transport
472 .exec_capture(
473 &pod.ssh,
474 &format!(
475 "tail -n 20 ~/{DEFAULT_LOG_DIR}/{}.log",
476 sanitize_component(name)
477 ),
478 )
479 .await?;
480
481 Ok(classify_status(
482 process.success,
483 health.success,
484 &log_tail.stdout,
485 ))
486 }
487}
488
489impl Default for PodManager {
490 fn default() -> Self {
491 Self::new().expect("pod manager should initialize with a home directory")
492 }
493}
494
495fn render_run_script(
496 profile: &PodProfile,
497 model: &str,
498 name: &str,
499 port: u16,
500 vllm_args: &[String],
501 gpu_ids: &[u32],
502 models_path: Option<&str>,
503) -> String {
504 let mut script = String::new();
505 script.push_str("#!/usr/bin/env bash\n");
506 script.push_str("set -euo pipefail\n");
507 script.push_str("export HF_HUB_ENABLE_HF_TRANSFER=1\n");
508 script.push_str("export VLLM_NO_USAGE_STATS=1\n");
509 script.push_str("export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True\n");
510 script.push_str("export FORCE_COLOR=1\n");
511 script.push_str("export TERM=xterm-256color\n");
512
513 for (key, value) in &profile.env {
514 let _ = writeln!(script, "export {key}={}", shell_quote(value));
515 }
516
517 if gpu_ids.len() == 1 {
518 let _ = writeln!(script, "export CUDA_VISIBLE_DEVICES={}", gpu_ids[0]);
519 }
520
521 if let Some(models_path) = models_path {
522 let _ = writeln!(script, "export MODELS_PATH={}", shell_quote(models_path));
523 }
524
525 let command = render_template(
526 &profile.command_template,
527 model,
528 name,
529 port,
530 &join_args(vllm_args),
531 );
532 let _ = writeln!(script, "exec {command}");
533 script
534}
535
/// Renders the detached wrapper script.
///
/// The wrapper runs `run_path` under `script(1)`, recording its terminal
/// output into `log_path`. When the run script ends, the wrapper appends a
/// `Script exited with code <N>` line to the log and exits with the same code.
/// `classify_status` looks for that line to tell a crashed model from one that
/// is still starting.
///
/// `log_path` is interpolated unquoted so a leading `~` still expands. Callers
/// pass paths built from sanitized names.
fn render_wrapper_script(run_path: &str, log_path: &str) -> String {
    // Create the directory that actually contains `log_path` rather than a
    // hard-coded one, so the two can never drift apart.
    let log_dir = match log_path.rsplit_once('/') {
        Some(("", _)) => "/",
        Some((dir, _)) => dir,
        None => ".",
    };
    // `|| exit_code=$?` keeps `set -e` from aborting before the marker is written.
    format!(
        "#!/usr/bin/env bash\n\
         set -euo pipefail\n\
         mkdir -p {log_dir}\n\
         exit_code=0\n\
         script -q -f -c {run_path} {log_path} || exit_code=$?\n\
         echo \"Script exited with code $exit_code\" >> {log_path}\n\
         exit \"$exit_code\"\n"
    )
}
541
/// Fills the `{{MODEL_ID}}`, `{{NAME}}`, `{{PORT}}` and `{{VLLM_ARGS}}`
/// placeholders of a command template.
///
/// Substitution happens in that fixed order, so text introduced by an earlier
/// value can be matched by a later placeholder. Values are inserted verbatim;
/// callers are responsible for any shell quoting.
fn render_template(template: &str, model: &str, name: &str, port: u16, vllm_args: &str) -> String {
    let port_text = port.to_string();
    let substitutions: [(&str, &str); 4] = [
        ("{{MODEL_ID}}", model),
        ("{{NAME}}", name),
        ("{{PORT}}", port_text.as_str()),
        ("{{VLLM_ARGS}}", vllm_args),
    ];
    substitutions
        .iter()
        .fold(template.to_string(), |text, &(placeholder, value)| {
            text.replace(placeholder, value)
        })
}
549
550fn join_args(args: &[String]) -> String {
551 args.iter()
552 .map(|value| shell_quote(value))
553 .collect::<Vec<_>>()
554 .join(" ")
555}
556
/// Quotes `value` for safe use as a single POSIX shell word.
///
/// Rules:
/// - Values made only of ASCII alphanumerics and `_-./:,=@` pass through unchanged.
/// - The empty string becomes `''`.
/// - Anything else is wrapped in single quotes, with embedded `'` written as
///   `'"'"'`.
fn shell_quote(value: &str) -> String {
    let is_plain = |ch: char| ch.is_ascii_alphanumeric() || "_-./:,=@".contains(ch);

    if value.is_empty() {
        String::from("''")
    } else if value.chars().all(is_plain) {
        value.to_owned()
    } else {
        let escaped = value.replace('\'', "'\"'\"'");
        format!("'{escaped}'")
    }
}
570
571fn select_gpus(pod: &PodState, count: usize) -> Vec<u32> {
572 if count >= pod.gpus.len() {
573 return pod.gpus.iter().map(|gpu| gpu.id).collect();
574 }
575
576 let mut usage: BTreeMap<u32, usize> = pod.gpus.iter().map(|gpu| (gpu.id, 0)).collect();
577 for model in pod.models.values() {
578 for gpu_id in &model.gpu_ids {
579 if let Some(slot) = usage.get_mut(gpu_id) {
580 *slot += 1;
581 }
582 }
583 }
584
585 let mut gpus = pod.gpus.clone();
586 gpus.sort_by_key(|gpu| (usage.get(&gpu.id).copied().unwrap_or(0), gpu.id));
587 gpus.into_iter().take(count).map(|gpu| gpu.id).collect()
588}
589
590fn next_port(pod: &PodState) -> u16 {
591 let mut port = DEFAULT_START_PORT;
592 let occupied: std::collections::HashSet<u16> =
593 pod.models.values().map(|model| model.port).collect();
594 while occupied.contains(&port) {
595 port = port.saturating_add(1);
596 }
597 port
598}
599
600fn parse_pid(output: &str) -> Result<u32> {
601 let pid = output
602 .trim()
603 .lines()
604 .find_map(|line| line.trim().parse::<u32>().ok())
605 .ok_or_else(|| anyhow!("launch command did not return a pid"))?;
606 Ok(pid)
607}
608
/// Makes `name` safe to use as a file-name component.
///
/// Every character other than ASCII alphanumerics, `_` and `-` becomes `_`.
/// An empty name maps to `"model"`.
fn sanitize_component(name: &str) -> String {
    let cleaned: String = name
        .chars()
        .map(|ch| {
            if ch.is_ascii_alphanumeric() || ch == '_' || ch == '-' {
                ch
            } else {
                '_'
            }
        })
        .collect();

    if cleaned.is_empty() {
        String::from("model")
    } else {
        cleaned
    }
}
624
625fn classify_status(process_alive: bool, health_ok: bool, log_tail: &str) -> PodHealth {
626 if process_alive && health_ok {
627 return PodHealth::Running;
628 }
629
630 let lower = log_tail.to_lowercase();
631 let failed = lower.contains("model runner exiting with code")
632 || lower.contains("script exited with code")
633 || lower.contains("torch.outofmemoryerror")
634 || lower.contains("cuda out of memory")
635 || lower.contains("runtimeerror: engine core initialization failed");
636
637 if failed {
638 PodHealth::Crashed
639 } else if process_alive {
640 PodHealth::Starting
641 } else {
642 PodHealth::Dead
643 }
644}
645
646fn render_args(args: &[String], memory: Option<f32>, context: Option<&str>) -> Result<Vec<String>> {
647 let mut rendered = Vec::new();
648 let mut index = 0;
649 while index < args.len() {
650 let arg = &args[index];
651 if arg == "--gpu-memory-utilization" {
652 index += 2;
653 continue;
654 }
655 if arg == "--max-model-len" {
656 index += 2;
657 continue;
658 }
659 rendered.push(arg.clone());
660 index += 1;
661 }
662
663 if let Some(memory) = memory {
664 let utilization = (memory / 100.0).clamp(0.0, 1.0);
665 rendered.push("--gpu-memory-utilization".to_string());
666 rendered.push(format!("{utilization:.2}"));
667 }
668
669 if let Some(context) = context {
670 rendered.push("--max-model-len".to_string());
671 rendered.push(parse_context_size(context)?.to_string());
672 }
673
674 Ok(rendered)
675}
676
677fn parse_context_size(value: &str) -> Result<u32> {
678 let trimmed = value.trim().to_lowercase();
679 if let Some(stripped) = trimmed.strip_suffix('k') {
680 return Ok(stripped
681 .parse::<u32>()
682 .with_context(|| format!("invalid context size '{value}'"))?
683 .saturating_mul(1024));
684 }
685
686 trimmed
687 .parse::<u32>()
688 .with_context(|| format!("invalid context size '{value}'"))
689}
690
/// Command template used by synthesized profiles: `vllm serve` with the model
/// id, served name, port and extra vLLM arguments filled in by `render_template`.
fn default_command_template() -> String {
    String::from(
        "vllm serve {{MODEL_ID}} --served-model-name {{NAME}} --port {{PORT}} {{VLLM_ARGS}}",
    )
}
694
#[cfg(test)]
mod tests {
    //! Unit tests for the pod manager helpers, plus an end-to-end launch driven
    //! through an in-memory mock transport.

    use super::*;
    use crate::pods::catalog::PodProfile;
    use crate::pods::state::PodGpu;
    use crate::pods::transport::CommandOutput;
    use anyhow::Result;
    use async_trait::async_trait;
    use parking_lot::Mutex;
    use std::collections::VecDeque;

    /// Transport double that records every command and file write.
    ///
    /// `exec_capture` pops the next queued response from `responses`. When the
    /// queue is empty it reports success with stdout `"12345\n"`, which
    /// `parse_pid` reads as a PID.
    #[derive(Clone, Default)]
    struct MockTransport {
        commands: Arc<Mutex<Vec<String>>>,
        writes: Arc<Mutex<Vec<(String, String)>>>,
        responses: Arc<Mutex<VecDeque<CommandOutput>>>,
    }

    #[async_trait]
    impl PodTransport for MockTransport {
        async fn exec_capture(&self, _ssh_target: &str, command: &str) -> Result<CommandOutput> {
            self.commands.lock().push(command.to_string());
            Ok(self
                .responses
                .lock()
                .pop_front()
                .unwrap_or_else(|| CommandOutput {
                    success: true,
                    stdout: "12345\n".to_string(),
                    stderr: String::new(),
                }))
        }

        async fn write_file(
            &self,
            _ssh_target: &str,
            remote_path: &str,
            contents: &str,
        ) -> Result<()> {
            self.writes
                .lock()
                .push((remote_path.to_string(), contents.to_string()));
            Ok(())
        }

        async fn exec_stream(&self, _ssh_target: &str, command: &str) -> Result<()> {
            self.commands.lock().push(command.to_string());
            Ok(())
        }
    }

    #[test]
    fn context_parser_handles_shorthand() {
        assert_eq!(parse_context_size("32k").expect("parse"), 32_768);
        assert_eq!(parse_context_size("131072").expect("parse"), 131_072);
    }

    #[test]
    fn classify_status_detects_failure_patterns() {
        assert_eq!(
            classify_status(
                true,
                false,
                "RuntimeError: Engine core initialization failed"
            ),
            PodHealth::Crashed
        );
        assert_eq!(classify_status(false, false, ""), PodHealth::Dead);
        assert_eq!(classify_status(true, false, ""), PodHealth::Starting);
    }

    #[test]
    fn select_gpus_prefers_less_loaded_devices() {
        let pod = PodState {
            name: "pod".to_string(),
            ssh: "ssh root@example.com".to_string(),
            models_path: None,
            gpus: vec![
                PodGpu {
                    id: 0,
                    name: "A100".to_string(),
                },
                PodGpu {
                    id: 1,
                    name: "A100".to_string(),
                },
            ],
            models: BTreeMap::from([(
                "existing".to_string(),
                RunningModel {
                    model: "model".to_string(),
                    port: 8001,
                    gpu_ids: vec![0],
                    pid: 1,
                    profile: "profile".to_string(),
                },
            )]),
        };

        assert_eq!(select_gpus(&pod, 1), vec![1]);
    }

    #[tokio::test]
    async fn render_and_launch_flow_updates_state() {
        let root =
            std::env::temp_dir().join(format!("vtcode-pods-start-test-{}", std::process::id()));
        // Start from a clean directory in case a previous run left files behind.
        let _ = std::fs::remove_dir_all(&root);
        let store = PodsStore::new(root.clone());
        let transport = Arc::new(MockTransport::default());
        let manager = PodManager::with_transport(store, transport.clone());
        let state = PodsState {
            version: env!("CARGO_PKG_VERSION").to_string(),
            active_pod: Some(PodState {
                name: "gpu-box".to_string(),
                ssh: "ssh root@example.com".to_string(),
                models_path: Some("/models".to_string()),
                gpus: vec![PodGpu {
                    id: 0,
                    name: "A100".to_string(),
                }],
                models: BTreeMap::new(),
            }),
        };
        manager.store.save_state(&state).await.expect("save");
        manager
            .store
            .save_catalog(&PodCatalog {
                version: "1".to_string(),
                profiles: vec![PodProfile {
                    name: "test".to_string(),
                    model: "test/model".to_string(),
                    gpu_count: 1,
                    gpu_types: vec!["A100".to_string()],
                    command_template: default_command_template(),
                    vllm_args: vec!["--max-model-len".to_string(), "4096".to_string()],
                    env: BTreeMap::new(),
                }],
            })
            .await
            .expect("save catalog");

        let result = manager
            .start_model(PodStartRequest {
                pod_name: None,
                ssh: None,
                gpus: Vec::new(),
                models_path: None,
                name: "local".to_string(),
                model: "test/model".to_string(),
                profile: None,
                requested_gpu_count: None,
                memory: Some(75.0),
                context: Some("4k".to_string()),
            })
            .await
            .expect("start");

        assert_eq!(result.entry.pid, 12345);
        assert_eq!(result.entry.port, DEFAULT_START_PORT);
        assert_eq!(result.entry.gpu_ids, vec![0]);
        assert_eq!(result.entry.profile, "test");
        assert!(result.pod.models.contains_key("local"));

        // The new entry must reach the store, not just the returned value.
        let persisted = manager.store.load_state().await.expect("reload state");
        let persisted_pod = persisted.active_pod.expect("active pod persisted");
        assert!(persisted_pod.models.contains_key("local"));

        let writes = transport.writes.lock().clone();
        let run_script = writes
            .iter()
            .find(|(path, _)| path.contains("model_run_local.sh"))
            .map(|(_, contents)| contents.as_str())
            .expect("run script should be written");
        assert!(run_script.contains("vllm serve"));
        assert!(run_script.contains("--gpu-memory-utilization 0.75"));
        assert!(run_script.contains("--max-model-len 4096"));

        let _ = std::fs::remove_dir_all(&root);
    }

    #[test]
    fn memory_override_rewrites_existing_argument() {
        let args = render_args(
            &[
                "--dtype".to_string(),
                "auto".to_string(),
                "--gpu-memory-utilization".to_string(),
                "0.80".to_string(),
            ],
            Some(90.0),
            None,
        )
        .expect("render args");

        assert!(
            args.windows(2)
                .any(|pair| pair == ["--gpu-memory-utilization", "0.90"])
        );
    }
}