Skip to main content

vtcode_core/pods/
catalog.rs

1use crate::pods::state::PodState;
2use serde::{Deserialize, Serialize};
3use std::collections::BTreeMap;
4
5/// Root catalog describing known deployment profiles.
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct PodCatalog {
8    pub version: String,
9    #[serde(default)]
10    pub profiles: Vec<PodProfile>,
11}
12
13impl Default for PodCatalog {
14    fn default() -> Self {
15        Self::embedded_default()
16    }
17}
18
19/// A single deployment profile for a model.
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct PodProfile {
22    pub name: String,
23    pub model: String,
24    pub gpu_count: usize,
25    #[serde(default)]
26    pub gpu_types: Vec<String>,
27    #[serde(default = "default_command_template")]
28    pub command_template: String,
29    #[serde(default)]
30    pub vllm_args: Vec<String>,
31    #[serde(default)]
32    pub env: BTreeMap<String, String>,
33}
34
35impl PodCatalog {
36    pub fn embedded_default() -> Self {
37        match serde_json::from_str(include_str!("default_catalog.json")) {
38            Ok(catalog) => catalog,
39            Err(_) => Self {
40                version: "1".to_string(),
41                profiles: vec![PodProfile {
42                    name: "generic-8b".to_string(),
43                    model: "meta-llama/Llama-3.1-8B-Instruct".to_string(),
44                    gpu_count: 1,
45                    gpu_types: vec![],
46                    command_template: default_command_template(),
47                    vllm_args: vec![
48                        "--trust-remote-code".to_string(),
49                        "--dtype".to_string(),
50                        "auto".to_string(),
51                        "--gpu-memory-utilization".to_string(),
52                        "0.90".to_string(),
53                        "--max-model-len".to_string(),
54                        "8192".to_string(),
55                    ],
56                    env: BTreeMap::new(),
57                }],
58            },
59        }
60    }
61
62    pub fn profiles_for_model(&self, model: &str) -> Vec<&PodProfile> {
63        self.profiles
64            .iter()
65            .filter(|profile| profile.name == model || profile.model == model)
66            .collect()
67    }
68
69    pub fn compatible_profiles<'a>(
70        &'a self,
71        pod: &PodState,
72    ) -> (Vec<&'a PodProfile>, Vec<&'a PodProfile>) {
73        let mut compatible = Vec::new();
74        let mut incompatible = Vec::new();
75
76        for profile in &self.profiles {
77            if profile.matches_pod(pod) {
78                compatible.push(profile);
79            } else {
80                incompatible.push(profile);
81            }
82        }
83
84        (compatible, incompatible)
85    }
86}
87
88impl PodProfile {
89    pub fn matches_pod(&self, pod: &PodState) -> bool {
90        if self.gpu_count > pod.gpu_count() {
91            return false;
92        }
93
94        if self.gpu_types.is_empty() {
95            return true;
96        }
97
98        let gpu_types = self
99            .gpu_types
100            .iter()
101            .map(|gpu_type| gpu_type.to_lowercase())
102            .collect::<Vec<_>>();
103
104        pod.gpus
105            .iter()
106            .filter(|gpu| {
107                let gpu_name = gpu.name.to_lowercase();
108                gpu_types.iter().any(|gpu_type| gpu_name.contains(gpu_type))
109            })
110            .take(self.gpu_count)
111            .count()
112            >= self.gpu_count
113    }
114
115    pub fn matches_gpu_count(&self, count: usize) -> bool {
116        self.gpu_count == count
117    }
118}
119
/// Returns the command template used when a profile does not supply one.
///
/// The template contains the placeholders `{{MODEL_ID}}`, `{{NAME}}`,
/// `{{PORT}}` and `{{VLLM_ARGS}}` as literal text.
fn default_command_template() -> String {
    const TEMPLATE: &str =
        "vllm serve {{MODEL_ID}} --served-model-name {{NAME}} --port {{PORT}} {{VLLM_ARGS}}";
    String::from(TEMPLATE)
}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pods::state::{PodGpu, PodState};

    /// Builds a profile called `name` that needs `gpu_count` GPUs of type
    /// `A100`.
    fn a100_profile(name: &str, gpu_count: usize) -> PodProfile {
        PodProfile {
            name: name.to_string(),
            model: "model".to_string(),
            gpu_count,
            gpu_types: vec!["A100".to_string()],
            command_template: default_command_template(),
            vllm_args: Vec::new(),
            env: BTreeMap::new(),
        }
    }

    /// Builds a pod fixture that exposes exactly the given GPUs.
    fn pod_with_gpus(gpus: Vec<PodGpu>) -> PodState {
        PodState {
            name: "pod".to_string(),
            ssh: "ssh root@example.com".to_string(),
            models_path: None,
            gpus,
            models: BTreeMap::new(),
        }
    }

    /// A profile type matches when it appears anywhere inside the GPU name,
    /// regardless of case.
    #[test]
    fn profile_matches_gpu_types_by_substring() {
        let profile = a100_profile("test", 1);
        let pod = pod_with_gpus(vec![PodGpu {
            id: 0,
            name: "NVIDIA A100-SXM4-80GB".to_string(),
        }]);

        assert!(profile.matches_pod(&pod));
    }

    /// Enough GPUs in total is not sufficient: enough of them must also be of
    /// an accepted type.
    #[test]
    fn profile_requires_enough_matching_gpu_types() {
        let profile = a100_profile("dual-a100", 2);
        let pod = pod_with_gpus(vec![
            PodGpu {
                id: 0,
                name: "NVIDIA A100-SXM4-80GB".to_string(),
            },
            PodGpu {
                id: 1,
                name: "NVIDIA RTX 4090".to_string(),
            },
        ]);

        assert!(!profile.matches_pod(&pod));
    }
}