Skip to main content

vtcode_core/pods/
manager.rs

1use crate::pods::catalog::{PodCatalog, PodProfile};
2use crate::pods::state::{PodGpu, PodHealth, PodState, PodsState, RunningModel};
3use crate::pods::store::PodsStore;
4use crate::pods::transport::{PodTransport, SshTransport};
5use anyhow::{Context, Result, anyhow};
6use parking_lot::RwLock;
7use std::collections::BTreeMap;
8use std::fmt::Write;
9use std::sync::Arc;
10
/// First port tried for a newly launched model; `next_port` scans upward from here, skipping
/// ports already recorded for models on the active pod.
const DEFAULT_START_PORT: u16 = 8001;
/// Log directory, relative to the remote user's home (used as `~/{DEFAULT_LOG_DIR}`), that holds
/// one `<sanitized name>.log` file per launched model.
const DEFAULT_LOG_DIR: &str = ".vllm_logs";
13
/// Request payload for starting a model on a pod.
///
/// The pod-related fields (`pod_name`, `ssh`, `gpus`, `models_path`) override or initialise the
/// active pod; the remaining fields describe the model launch itself.
#[derive(Debug, Clone)]
pub struct PodStartRequest {
    /// Name to give the active pod; replaces the current name when set.
    pub pod_name: Option<String>,
    /// SSH target handed to the transport (the tests use a full command such as
    /// `ssh root@example.com`). Required when there is no active pod yet.
    pub ssh: Option<String>,
    /// GPU inventory of the pod; replaces the active pod's inventory when non-empty.
    pub gpus: Vec<PodGpu>,
    /// Models directory on the pod, exported to the run script as `MODELS_PATH`.
    pub models_path: Option<String>,
    /// Name of the running model: the key in the pod's model map, the basis (after
    /// sanitisation) of the remote script/log file names, and the `{{NAME}}` template value.
    pub name: String,
    /// Model identifier substituted for `{{MODEL_ID}}` and used to look up catalog profiles.
    pub model: String,
    /// Explicit catalog profile name; skips automatic profile selection when set.
    pub profile: Option<String>,
    /// Number of GPUs to use; defaults to the chosen profile's GPU count.
    pub requested_gpu_count: Option<usize>,
    /// GPU memory utilisation as a percentage (divided by 100 and clamped to `0.0..=1.0`).
    pub memory: Option<f32>,
    /// Maximum context length, e.g. `"32k"` or `"131072"`.
    pub context: Option<String>,
}
28
/// Result of a successful model launch.
#[derive(Debug, Clone)]
pub struct PodStartResult {
    /// The active pod after the new model was recorded in it.
    pub pod: PodState,
    /// The state entry stored for the launched model (PID, port, GPUs, profile).
    pub entry: RunningModel,
    /// The catalog (or synthesised fallback) profile used to render the launch script.
    pub profile: PodProfile,
    /// The remote shell command that started the wrapper script.
    pub launch_command: String,
}
37
/// Row returned by `pods list`.
#[derive(Debug, Clone)]
pub struct PodListEntry {
    /// Name the model was started under.
    pub name: String,
    /// Model identifier that was launched.
    pub model: String,
    /// Port assigned to the model server.
    pub port: u16,
    /// PID returned by the remote launch command.
    pub pid: u32,
    /// GPU ids assigned to the model.
    pub gpu_ids: Vec<u32>,
    /// Health derived from the process check, the HTTP health probe and the log tail.
    pub status: PodHealth,
}
48
/// Row returned by `pods known-models`: a summary of one catalog profile.
#[derive(Debug, Clone)]
pub struct PodStatusDetail {
    /// Profile name.
    pub name: String,
    /// Model identifier the profile launches.
    pub model: String,
    /// Number of GPUs the profile is configured for.
    pub gpu_count: usize,
}
56
/// Known catalog profiles split into those compatible and incompatible with the active pod.
#[derive(Debug, Clone)]
pub struct KnownModelsReport {
    /// Profiles the catalog reports as compatible with the active pod.
    pub compatible: Vec<PodStatusDetail>,
    /// Profiles the catalog reports as incompatible with the active pod.
    pub incompatible: Vec<PodStatusDetail>,
}
63
/// `pods list` report: the active pod's name plus one row per recorded model.
#[derive(Debug, Clone)]
pub struct PodStatusReport {
    /// Name of the active pod.
    pub pod_name: String,
    /// One entry per model recorded on the pod, in map (name) order.
    pub entries: Vec<PodListEntry>,
}
70
/// Pod manager coordinating persisted state, catalog lookup, and SSH execution.
///
/// Clones share the transport and both caches, because every shared field is behind an `Arc`.
#[derive(Clone)]
pub struct PodManager {
    /// Store used to load/save the pod state and to load the profile catalog.
    store: PodsStore,
    /// Runs commands and writes files on the pod (`SshTransport` from `new`, mocks in tests).
    transport: Arc<dyn PodTransport>,
    /// State last loaded from or saved to `store`; `None` until first use.
    cached_state: Arc<RwLock<Option<PodsState>>>,
    /// Catalog loaded from `store`; `None` until first use.
    cached_catalog: Arc<RwLock<Option<PodCatalog>>>,
}
79
80impl PodManager {
81    pub fn new() -> Result<Self> {
82        Ok(Self::with_transport(
83            PodsStore::default_store()?,
84            Arc::new(SshTransport),
85        ))
86    }
87
88    pub fn with_transport(store: PodsStore, transport: Arc<dyn PodTransport>) -> Self {
89        Self {
90            store,
91            transport,
92            cached_state: Arc::new(RwLock::new(None)),
93            cached_catalog: Arc::new(RwLock::new(None)),
94        }
95    }
96
97    pub async fn load_state(&self) -> Result<PodsState> {
98        if let Some(state) = self.cached_state.read().clone() {
99            return Ok(state);
100        }
101
102        let state = self.store.load_state().await?;
103        *self.cached_state.write() = Some(state.clone());
104        Ok(state)
105    }
106
107    pub async fn load_catalog(&self) -> Result<PodCatalog> {
108        if let Some(catalog) = self.cached_catalog.read().clone() {
109            return Ok(catalog);
110        }
111
112        let catalog = self.store.load_catalog().await?;
113        *self.cached_catalog.write() = Some(catalog.clone());
114        Ok(catalog)
115    }
116
117    pub async fn start_model(&self, request: PodStartRequest) -> Result<PodStartResult> {
118        let mut state = self.load_state().await?;
119        let catalog = self.load_catalog().await?;
120        let pod = self.resolve_active_pod(&mut state, &request).await?;
121        let profile = self.resolve_profile(&catalog, &pod, &request)?;
122        let gpu_count = request
123            .requested_gpu_count
124            .unwrap_or(profile.gpu_count)
125            .max(1);
126
127        if gpu_count > pod.gpu_count() {
128            return Err(anyhow!(
129                "requested {} GPUs but pod '{}' only has {}",
130                gpu_count,
131                pod.name,
132                pod.gpu_count()
133            ));
134        }
135
136        let selected_gpu_ids = select_gpus(&pod, gpu_count);
137        let port = next_port(&pod);
138        let sanitized_name = sanitize_component(&request.name);
139        let run_path = format!("/tmp/model_run_{sanitized_name}.sh");
140        let wrapper_path = format!("/tmp/model_wrapper_{sanitized_name}.sh");
141        let log_path = format!("~/{DEFAULT_LOG_DIR}/{sanitized_name}.log");
142        let vllm_args = render_args(
143            &profile.vllm_args,
144            request.memory,
145            request.context.as_deref(),
146        )?;
147        let run_script = render_run_script(
148            &profile,
149            &request.model,
150            &request.name,
151            port,
152            &vllm_args,
153            &selected_gpu_ids,
154            pod.models_path.as_deref(),
155        );
156        let wrapper_script = render_wrapper_script(&run_path, &log_path);
157
158        self.transport
159            .write_file(&pod.ssh, &run_path, &run_script)
160            .await?;
161        self.transport
162            .write_file(&pod.ssh, &wrapper_path, &wrapper_script)
163            .await?;
164
165        let chmod = self
166            .transport
167            .exec_capture(&pod.ssh, &format!("chmod +x {run_path} {wrapper_path}"))
168            .await?;
169        if !chmod.success {
170            return Err(anyhow!("failed to chmod remote scripts: {}", chmod.stderr));
171        }
172
173        let launch_command = format!(
174            "mkdir -p ~/{DEFAULT_LOG_DIR} && setsid {wrapper_path} >/dev/null 2>&1 < /dev/null & echo $!"
175        );
176        let launch = self
177            .transport
178            .exec_capture(&pod.ssh, &launch_command)
179            .await?;
180        if !launch.success {
181            return Err(anyhow!("failed to launch remote model: {}", launch.stderr));
182        }
183
184        let pid = parse_pid(&launch.stdout)?;
185        let entry = RunningModel {
186            model: request.model.clone(),
187            port,
188            gpu_ids: selected_gpu_ids.clone(),
189            pid,
190            profile: profile.name.clone(),
191        };
192
193        let mut updated_pod = pod.clone();
194        updated_pod
195            .models
196            .insert(request.name.clone(), entry.clone());
197        state.active_pod = Some(updated_pod.clone());
198        self.persist_state(&state).await?;
199
200        Ok(PodStartResult {
201            pod: updated_pod,
202            entry,
203            profile,
204            launch_command,
205        })
206    }
207
208    pub async fn stop_model(&self, name: &str) -> Result<Option<RunningModel>> {
209        let mut state = self.load_state().await?;
210        let Some(pod) = state.active_pod.as_mut() else {
211            return Ok(None);
212        };
213
214        let Some(entry) = pod.models.remove(name) else {
215            return Ok(None);
216        };
217
218        let command = format!(
219            "pkill -TERM -P {} || true; kill {} || true",
220            entry.pid, entry.pid
221        );
222        let output = self.transport.exec_capture(&pod.ssh, &command).await?;
223        if !output.success {
224            return Err(anyhow!(
225                "failed to stop model '{}': {}",
226                name,
227                output.stderr
228            ));
229        }
230
231        self.persist_state(&state).await?;
232        Ok(Some(entry))
233    }
234
235    pub async fn stop_all_models(&self) -> Result<usize> {
236        let mut state = self.load_state().await?;
237        let Some(pod) = state.active_pod.as_mut() else {
238            return Ok(0);
239        };
240
241        let pids = pod
242            .models
243            .values()
244            .map(|entry| entry.pid.to_string())
245            .collect::<Vec<_>>();
246
247        if pids.is_empty() {
248            return Ok(0);
249        }
250
251        let command = format!(
252            "for PID in {}; do pkill -TERM -P \"$PID\" || true; kill \"$PID\" || true; done",
253            pids.join(" ")
254        );
255        let output = self.transport.exec_capture(&pod.ssh, &command).await?;
256        if !output.success {
257            return Err(anyhow!("failed to stop models: {}", output.stderr));
258        }
259
260        let stopped = pod.models.len();
261        pod.models.clear();
262        self.persist_state(&state).await?;
263        Ok(stopped)
264    }
265
266    pub async fn list_models(&self) -> Result<PodStatusReport> {
267        let state = self.load_state().await?;
268        let Some(pod) = state.active_pod.as_ref() else {
269            return Err(anyhow!("no active pod configured"));
270        };
271
272        let mut entries = Vec::new();
273        for (name, model) in &pod.models {
274            let status = self.inspect_model(pod, name, model).await?;
275            entries.push(PodListEntry {
276                name: name.clone(),
277                model: model.model.clone(),
278                port: model.port,
279                pid: model.pid,
280                gpu_ids: model.gpu_ids.clone(),
281                status,
282            });
283        }
284
285        Ok(PodStatusReport {
286            pod_name: pod.name.clone(),
287            entries,
288        })
289    }
290
291    pub async fn stream_logs(&self, name: &str) -> Result<()> {
292        let state = self.load_state().await?;
293        let Some(pod) = state.active_pod.as_ref() else {
294            return Err(anyhow!("no active pod configured"));
295        };
296        let Some(entry) = pod.models.get(name) else {
297            return Err(anyhow!("unknown model '{}'", name));
298        };
299
300        let log_path = format!("~/{DEFAULT_LOG_DIR}/{}.log", sanitize_component(name));
301        let command = format!("tail -f {log_path}");
302        let _ = entry;
303        self.transport.exec_stream(&pod.ssh, &command).await
304    }
305
306    pub async fn known_models(&self) -> Result<KnownModelsReport> {
307        let state = self.load_state().await?;
308        let Some(pod) = state.active_pod.as_ref() else {
309            return Err(anyhow!("no active pod configured"));
310        };
311        let catalog = self.load_catalog().await?;
312        let (compatible, incompatible) = catalog.compatible_profiles(pod);
313
314        Ok(KnownModelsReport {
315            compatible: compatible
316                .into_iter()
317                .map(|profile| PodStatusDetail {
318                    name: profile.name.clone(),
319                    model: profile.model.clone(),
320                    gpu_count: profile.gpu_count,
321                })
322                .collect(),
323            incompatible: incompatible
324                .into_iter()
325                .map(|profile| PodStatusDetail {
326                    name: profile.name.clone(),
327                    model: profile.model.clone(),
328                    gpu_count: profile.gpu_count,
329                })
330                .collect(),
331        })
332    }
333
334    async fn persist_state(&self, state: &PodsState) -> Result<()> {
335        self.store.save_state(state).await?;
336        *self.cached_state.write() = Some(state.clone());
337        Ok(())
338    }
339
340    async fn resolve_active_pod(
341        &self,
342        state: &mut PodsState,
343        request: &PodStartRequest,
344    ) -> Result<PodState> {
345        let mut pod = state.active_pod.clone().unwrap_or_else(|| PodState {
346            name: request
347                .pod_name
348                .clone()
349                .unwrap_or_else(|| "active-pod".to_string()),
350            ssh: request.ssh.clone().unwrap_or_default(),
351            models_path: request.models_path.clone(),
352            gpus: Vec::new(),
353            models: BTreeMap::new(),
354        });
355
356        if let Some(name) = &request.pod_name {
357            pod.name = name.clone();
358        }
359        if let Some(ssh) = &request.ssh {
360            pod.ssh = ssh.clone();
361        }
362        if let Some(models_path) = &request.models_path {
363            pod.models_path = Some(models_path.clone());
364        }
365        if !request.gpus.is_empty() {
366            pod.gpus = request.gpus.clone();
367        }
368
369        if pod.ssh.is_empty() {
370            return Err(anyhow!(
371                "pod ssh command is required; pass --ssh or reuse the active pod"
372            ));
373        }
374        if pod.gpus.is_empty() {
375            return Err(anyhow!(
376                "pod gpu inventory is required; pass --gpu entries or reuse the active pod"
377            ));
378        }
379
380        state.active_pod = Some(pod.clone());
381        self.persist_state(state).await?;
382        Ok(pod)
383    }
384
385    fn resolve_profile(
386        &self,
387        catalog: &PodCatalog,
388        pod: &PodState,
389        request: &PodStartRequest,
390    ) -> Result<PodProfile> {
391        if let Some(profile_name) = request.profile.as_deref() {
392            let profile = catalog
393                .profiles
394                .iter()
395                .find(|profile| profile.name == profile_name)
396                .cloned()
397                .ok_or_else(|| anyhow!("unknown pod profile '{}'", profile_name))?;
398            return Ok(profile);
399        }
400
401        let mut candidates = catalog.profiles_for_model(&request.model);
402        candidates.retain(|profile| profile.matches_pod(pod));
403
404        if let Some(requested_gpu_count) = request.requested_gpu_count {
405            if let Some(profile) = candidates
406                .iter()
407                .copied()
408                .find(|profile| profile.matches_gpu_count(requested_gpu_count))
409            {
410                return Ok(profile.clone());
411            }
412
413            if !candidates.is_empty() {
414                let valid_counts = candidates
415                    .iter()
416                    .map(|profile| profile.gpu_count.to_string())
417                    .collect::<Vec<_>>();
418                return Err(anyhow!(
419                    "no profile for '{}' with {} GPUs; valid counts: {}",
420                    request.model,
421                    requested_gpu_count,
422                    valid_counts.join(", ")
423                ));
424            }
425        }
426
427        candidates
428            .into_iter()
429            .max_by_key(|profile| profile.gpu_count)
430            .cloned()
431            .or_else(|| {
432                if request.model.is_empty() {
433                    None
434                } else {
435                    Some(PodProfile {
436                        name: request.model.clone(),
437                        model: request.model.clone(),
438                        gpu_count: request.requested_gpu_count.unwrap_or(1),
439                        gpu_types: Vec::new(),
440                        command_template: default_command_template(),
441                        vllm_args: vec![
442                            "--trust-remote-code".to_string(),
443                            "--dtype".to_string(),
444                            "auto".to_string(),
445                        ],
446                        env: BTreeMap::new(),
447                    })
448                }
449            })
450            .ok_or_else(|| anyhow!("no profile found for model '{}'", request.model))
451    }
452
453    async fn inspect_model(
454        &self,
455        pod: &PodState,
456        name: &str,
457        entry: &RunningModel,
458    ) -> Result<PodHealth> {
459        let process = self
460            .transport
461            .exec_capture(&pod.ssh, &format!("ps -p {}", entry.pid))
462            .await?;
463        let health = self
464            .transport
465            .exec_capture(
466                &pod.ssh,
467                &format!("curl -s -f http://localhost:{}/health", entry.port),
468            )
469            .await?;
470        let log_tail = self
471            .transport
472            .exec_capture(
473                &pod.ssh,
474                &format!(
475                    "tail -n 20 ~/{DEFAULT_LOG_DIR}/{}.log",
476                    sanitize_component(name)
477                ),
478            )
479            .await?;
480
481        Ok(classify_status(
482            process.success,
483            health.success,
484            &log_tail.stdout,
485        ))
486    }
487}
488
impl Default for PodManager {
    /// Builds a manager through [`PodManager::new`].
    ///
    /// # Panics
    /// Panics if `PodManager::new` fails; the panic message assumes the cause is a missing home
    /// directory.
    fn default() -> Self {
        Self::new().expect("pod manager should initialize with a home directory")
    }
}
494
495fn render_run_script(
496    profile: &PodProfile,
497    model: &str,
498    name: &str,
499    port: u16,
500    vllm_args: &[String],
501    gpu_ids: &[u32],
502    models_path: Option<&str>,
503) -> String {
504    let mut script = String::new();
505    script.push_str("#!/usr/bin/env bash\n");
506    script.push_str("set -euo pipefail\n");
507    script.push_str("export HF_HUB_ENABLE_HF_TRANSFER=1\n");
508    script.push_str("export VLLM_NO_USAGE_STATS=1\n");
509    script.push_str("export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True\n");
510    script.push_str("export FORCE_COLOR=1\n");
511    script.push_str("export TERM=xterm-256color\n");
512
513    for (key, value) in &profile.env {
514        let _ = writeln!(script, "export {key}={}", shell_quote(value));
515    }
516
517    if gpu_ids.len() == 1 {
518        let _ = writeln!(script, "export CUDA_VISIBLE_DEVICES={}", gpu_ids[0]);
519    }
520
521    if let Some(models_path) = models_path {
522        let _ = writeln!(script, "export MODELS_PATH={}", shell_quote(models_path));
523    }
524
525    let command = render_template(
526        &profile.command_template,
527        model,
528        name,
529        port,
530        &join_args(vllm_args),
531    );
532    let _ = writeln!(script, "exec {command}");
533    script
534}
535
/// Renders the wrapper script that runs `run_path` under `script(1)`, capturing its terminal
/// output to `log_path`.
///
/// The wrapper first creates the directory that contains `log_path`. The directory is derived
/// from `log_path` itself rather than hard-coded, so it always matches the log location the
/// caller chose. If `log_path` has no directory component, no `mkdir` line is emitted.
fn render_wrapper_script(run_path: &str, log_path: &str) -> String {
    let mkdir = match log_path.rsplit_once('/') {
        // Left unquoted on purpose so a leading `~` still expands on the remote shell.
        Some((dir, _)) if !dir.is_empty() => format!("mkdir -p {dir}\n"),
        _ => String::new(),
    };
    format!(
        "#!/usr/bin/env bash\nset -euo pipefail\n{mkdir}script -q -f -c {run_path} {log_path}\n"
    )
}
541
/// Fills the `{{MODEL_ID}}`, `{{NAME}}`, `{{PORT}}` and `{{VLLM_ARGS}}` placeholders of a
/// command template.
///
/// Substitution happens one placeholder at a time in that order, so a substituted value can
/// itself contain a later placeholder that then gets replaced too.
fn render_template(template: &str, model: &str, name: &str, port: u16, vllm_args: &str) -> String {
    let port_text = port.to_string();
    let substitutions = [
        ("{{MODEL_ID}}", model),
        ("{{NAME}}", name),
        ("{{PORT}}", port_text.as_str()),
        ("{{VLLM_ARGS}}", vllm_args),
    ];
    substitutions
        .iter()
        .fold(template.to_string(), |rendered, &(placeholder, value)| {
            rendered.replace(placeholder, value)
        })
}
549
550fn join_args(args: &[String]) -> String {
551    args.iter()
552        .map(|value| shell_quote(value))
553        .collect::<Vec<_>>()
554        .join(" ")
555}
556
/// Quotes `value` for safe use as a single POSIX shell word.
///
/// The rules are:
/// - An empty value becomes `''`.
/// - A value made only of ASCII alphanumerics and `_ - . / : , = @` is returned unchanged.
/// - Anything else is wrapped in single quotes, with each embedded `'` written as `'"'"'`.
fn shell_quote(value: &str) -> String {
    const SAFE_PUNCTUATION: [char; 8] = ['_', '-', '.', '/', ':', ',', '=', '@'];
    let is_safe = |ch: char| ch.is_ascii_alphanumeric() || SAFE_PUNCTUATION.contains(&ch);

    match value {
        "" => "''".to_string(),
        _ if value.chars().all(is_safe) => value.to_string(),
        _ => {
            let escaped = value.replace('\'', r#"'"'"'"#);
            format!("'{escaped}'")
        }
    }
}
570
571fn select_gpus(pod: &PodState, count: usize) -> Vec<u32> {
572    if count >= pod.gpus.len() {
573        return pod.gpus.iter().map(|gpu| gpu.id).collect();
574    }
575
576    let mut usage: BTreeMap<u32, usize> = pod.gpus.iter().map(|gpu| (gpu.id, 0)).collect();
577    for model in pod.models.values() {
578        for gpu_id in &model.gpu_ids {
579            if let Some(slot) = usage.get_mut(gpu_id) {
580                *slot += 1;
581            }
582        }
583    }
584
585    let mut gpus = pod.gpus.clone();
586    gpus.sort_by_key(|gpu| (usage.get(&gpu.id).copied().unwrap_or(0), gpu.id));
587    gpus.into_iter().take(count).map(|gpu| gpu.id).collect()
588}
589
590fn next_port(pod: &PodState) -> u16 {
591    let mut port = DEFAULT_START_PORT;
592    let occupied: std::collections::HashSet<u16> =
593        pod.models.values().map(|model| model.port).collect();
594    while occupied.contains(&port) {
595        port = port.saturating_add(1);
596    }
597    port
598}
599
600fn parse_pid(output: &str) -> Result<u32> {
601    let pid = output
602        .trim()
603        .lines()
604        .find_map(|line| line.trim().parse::<u32>().ok())
605        .ok_or_else(|| anyhow!("launch command did not return a pid"))?;
606    Ok(pid)
607}
608
/// Makes `name` safe for use in a file name.
///
/// ASCII alphanumerics, `_` and `-` are kept; every other character becomes `_`. An empty
/// name yields `"model"`.
fn sanitize_component(name: &str) -> String {
    if name.is_empty() {
        return "model".to_string();
    }
    name.chars()
        .map(|ch| {
            let keep = ch.is_ascii_alphanumeric() || ch == '_' || ch == '-';
            if keep { ch } else { '_' }
        })
        .collect()
}
624
625fn classify_status(process_alive: bool, health_ok: bool, log_tail: &str) -> PodHealth {
626    if process_alive && health_ok {
627        return PodHealth::Running;
628    }
629
630    let lower = log_tail.to_lowercase();
631    let failed = lower.contains("model runner exiting with code")
632        || lower.contains("script exited with code")
633        || lower.contains("torch.outofmemoryerror")
634        || lower.contains("cuda out of memory")
635        || lower.contains("runtimeerror: engine core initialization failed");
636
637    if failed {
638        PodHealth::Crashed
639    } else if process_alive {
640        PodHealth::Starting
641    } else {
642        PodHealth::Dead
643    }
644}
645
646fn render_args(args: &[String], memory: Option<f32>, context: Option<&str>) -> Result<Vec<String>> {
647    let mut rendered = Vec::new();
648    let mut index = 0;
649    while index < args.len() {
650        let arg = &args[index];
651        if arg == "--gpu-memory-utilization" {
652            index += 2;
653            continue;
654        }
655        if arg == "--max-model-len" {
656            index += 2;
657            continue;
658        }
659        rendered.push(arg.clone());
660        index += 1;
661    }
662
663    if let Some(memory) = memory {
664        let utilization = (memory / 100.0).clamp(0.0, 1.0);
665        rendered.push("--gpu-memory-utilization".to_string());
666        rendered.push(format!("{utilization:.2}"));
667    }
668
669    if let Some(context) = context {
670        rendered.push("--max-model-len".to_string());
671        rendered.push(parse_context_size(context)?.to_string());
672    }
673
674    Ok(rendered)
675}
676
677fn parse_context_size(value: &str) -> Result<u32> {
678    let trimmed = value.trim().to_lowercase();
679    if let Some(stripped) = trimmed.strip_suffix('k') {
680        return Ok(stripped
681            .parse::<u32>()
682            .with_context(|| format!("invalid context size '{value}'"))?
683            .saturating_mul(1024));
684    }
685
686    trimmed
687        .parse::<u32>()
688        .with_context(|| format!("invalid context size '{value}'"))
689}
690
/// Command template used when no catalog profile matches: a plain `vllm serve` invocation with
/// the standard `{{MODEL_ID}}`, `{{NAME}}`, `{{PORT}}` and `{{VLLM_ARGS}}` placeholders.
fn default_command_template() -> String {
    [
        "vllm serve {{MODEL_ID}}",
        "--served-model-name {{NAME}}",
        "--port {{PORT}}",
        "{{VLLM_ARGS}}",
    ]
    .join(" ")
}
694
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pods::catalog::PodProfile;
    use crate::pods::state::PodGpu;
    use crate::pods::transport::CommandOutput;
    use anyhow::Result;
    use async_trait::async_trait;
    use parking_lot::Mutex;
    use std::collections::VecDeque;

    /// In-memory `PodTransport` that records commands and file writes instead of using SSH.
    ///
    /// `exec_capture` returns the next queued entry of `responses`. When the queue is empty it
    /// reports success with stdout `"12345\n"`, which `parse_pid` reads as PID 12345.
    #[derive(Clone, Default)]
    struct MockTransport {
        // Commands passed to `exec_capture` and `exec_stream`, in call order.
        commands: Arc<Mutex<Vec<String>>>,
        // `(remote_path, contents)` pairs passed to `write_file`, in call order.
        writes: Arc<Mutex<Vec<(String, String)>>>,
        // Scripted `exec_capture` outputs, consumed front to back.
        responses: Arc<Mutex<VecDeque<CommandOutput>>>,
    }

    #[async_trait]
    impl PodTransport for MockTransport {
        async fn exec_capture(&self, _ssh_target: &str, command: &str) -> Result<CommandOutput> {
            self.commands.lock().push(command.to_string());
            Ok(self
                .responses
                .lock()
                .pop_front()
                .unwrap_or_else(|| CommandOutput {
                    success: true,
                    stdout: "12345\n".to_string(),
                    stderr: String::new(),
                }))
        }

        async fn write_file(
            &self,
            _ssh_target: &str,
            remote_path: &str,
            contents: &str,
        ) -> Result<()> {
            self.writes
                .lock()
                .push((remote_path.to_string(), contents.to_string()));
            Ok(())
        }

        async fn exec_stream(&self, _ssh_target: &str, command: &str) -> Result<()> {
            self.commands.lock().push(command.to_string());
            Ok(())
        }
    }

    // Plain numbers pass through; a `k` suffix multiplies by 1024.
    #[test]
    fn context_parser_handles_shorthand() {
        assert_eq!(parse_context_size("32k").expect("parse"), 32_768);
        assert_eq!(parse_context_size("131072").expect("parse"), 131_072);
    }

    // A failure marker in the log wins over "process alive"; otherwise liveness decides.
    #[test]
    fn classify_status_detects_failure_patterns() {
        assert_eq!(
            classify_status(
                true,
                false,
                "RuntimeError: Engine core initialization failed"
            ),
            PodHealth::Crashed
        );
        assert_eq!(classify_status(false, false, ""), PodHealth::Dead);
        assert_eq!(classify_status(true, false, ""), PodHealth::Starting);
    }

    // GPU 0 already hosts a model, so a single-GPU request should land on GPU 1.
    #[test]
    fn select_gpus_prefers_less_loaded_devices() {
        let pod = PodState {
            name: "pod".to_string(),
            ssh: "ssh root@example.com".to_string(),
            models_path: None,
            gpus: vec![
                PodGpu {
                    id: 0,
                    name: "A100".to_string(),
                },
                PodGpu {
                    id: 1,
                    name: "A100".to_string(),
                },
            ],
            models: BTreeMap::from([(
                "existing".to_string(),
                RunningModel {
                    model: "model".to_string(),
                    port: 8001,
                    gpu_ids: vec![0],
                    pid: 1,
                    profile: "profile".to_string(),
                },
            )]),
        };

        assert_eq!(select_gpus(&pod, 1), vec![1]);
    }

    // End-to-end `start_model` against the mock transport. Every remote command gets the
    // default successful response, so the launch "returns" PID 12345.
    #[tokio::test]
    async fn render_and_launch_flow_updates_state() {
        // NOTE(review): this per-process temp directory is not removed after the test.
        let store = PodsStore::new(
            std::env::temp_dir().join(format!("vtcode-pods-start-test-{}", std::process::id())),
        );
        let transport = Arc::new(MockTransport::default());
        let manager = PodManager::with_transport(store, transport.clone());
        let state = PodsState {
            version: env!("CARGO_PKG_VERSION").to_string(),
            active_pod: Some(PodState {
                name: "gpu-box".to_string(),
                ssh: "ssh root@example.com".to_string(),
                models_path: Some("/models".to_string()),
                gpus: vec![PodGpu {
                    id: 0,
                    name: "A100".to_string(),
                }],
                models: BTreeMap::new(),
            }),
        };
        manager.store.save_state(&state).await.expect("save");
        manager
            .store
            .save_catalog(&PodCatalog {
                version: "1".to_string(),
                profiles: vec![PodProfile {
                    name: "test".to_string(),
                    model: "test/model".to_string(),
                    gpu_count: 1,
                    gpu_types: vec!["A100".to_string()],
                    command_template: default_command_template(),
                    vllm_args: vec!["--max-model-len".to_string(), "4096".to_string()],
                    env: BTreeMap::new(),
                }],
            })
            .await
            .expect("save catalog");

        let result = manager
            .start_model(PodStartRequest {
                pod_name: None,
                ssh: None,
                gpus: Vec::new(),
                models_path: None,
                name: "local".to_string(),
                model: "test/model".to_string(),
                profile: None,
                requested_gpu_count: None,
                memory: Some(75.0),
                context: Some("4k".to_string()),
            })
            .await
            .expect("start");

        // NOTE(review): despite the test name, the persisted state itself is not asserted here.
        assert_eq!(result.entry.pid, 12345);
        assert!(
            transport
                .writes
                .lock()
                .iter()
                .any(|(path, contents)| path.contains("model_run_local.sh")
                    && contents.contains("vllm serve"))
        );
    }

    // A 90% memory override replaces the profile's 0.80 utilisation with 0.90.
    #[test]
    fn memory_override_rewrites_existing_argument() {
        let args = render_args(
            &[
                "--dtype".to_string(),
                "auto".to_string(),
                "--gpu-memory-utilization".to_string(),
                "0.80".to_string(),
            ],
            Some(90.0),
            None,
        )
        .expect("render args");

        assert!(
            args.windows(2)
                .any(|pair| pair == ["--gpu-memory-utilization", "0.90"])
        );
    }
}