1use std::collections::HashMap;
36
37use crate::types::Model;
38
/// Reasons [`resolve_model`] can fail to map a reference onto a known model.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ResolveError {
    /// `name` was found in an alias table, but the model reference it
    /// expands to (`target`) matched nothing in the model list.
    UnknownAlias {
        /// The alias exactly as the caller supplied it.
        name: String,
        /// The model reference the alias expanded to.
        target: String,
        /// Which alias table produced the hit; [`resolve_model`] sets this
        /// to `"project"` or `"global"`.
        scope: &'static str,
    },
    /// `name` matched no alias, and no model matched it (after any
    /// provider restriction was applied).
    UnknownModel { name: String },
}

impl std::fmt::Display for ResolveError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::UnknownModel { name } => write!(f, "unknown model or alias: {}", name),
            Self::UnknownAlias {
                name,
                target,
                scope,
            } => write!(
                f,
                "{} alias '{}' points at unknown model '{}'",
                scope, name, target
            ),
        }
    }
}

/// Marks `ResolveError` as a standard error so it can be boxed as
/// `dyn std::error::Error`; it carries no underlying `source`.
impl std::error::Error for ResolveError {}
82
83pub fn resolve_model<'a>(
92 raw: &str,
93 provider_filter: Option<&str>,
94 project_aliases: Option<&HashMap<String, String>>,
95 global_aliases: &HashMap<String, String>,
96 all_models: &'a [Model],
97) -> Result<&'a Model, ResolveError> {
98 let (target, alias_scope): (&str, Option<&'static str>) = if let Some(map) =
100 project_aliases.and_then(|m| if m.is_empty() { None } else { Some(m) })
101 && let Some(t) = map.get(raw)
102 {
103 (t.as_str(), Some("project"))
104 } else if let Some(t) = global_aliases.get(raw) {
105 (t.as_str(), Some("global"))
106 } else {
107 (raw, None)
108 };
109
110 let (target_provider, target_id) = match target.split_once('/') {
113 Some((p, i)) if !p.is_empty() && !i.is_empty() => (Some(p), i),
114 _ => (None, target),
115 };
116
117 let effective_provider: Option<&str> = provider_filter.or(target_provider);
120
121 let found = all_models.iter().find(|m| {
123 m.id == target_id
124 && match effective_provider {
125 Some(p) => m.provider == p,
126 None => true,
127 }
128 });
129
130 match found {
131 Some(m) => Ok(m),
132 None => match alias_scope {
133 Some(scope) => Err(ResolveError::UnknownAlias {
134 name: raw.to_string(),
135 target: target.to_string(),
136 scope,
137 }),
138 None => Err(ResolveError::UnknownModel {
139 name: raw.to_string(),
140 }),
141 },
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Model, ModelCost, ThinkingStyle};

    /// Builds a mock `Model`; the resolver only inspects `id` and `provider`,
    /// so every other field is a fixed placeholder.
    fn mk_model(id: &str, provider: &str) -> Model {
        Model {
            id: id.into(),
            name: id.into(),
            api: "mock".into(),
            provider: provider.into(),
            base_url: "http://mock".into(),
            thinking: ThinkingStyle::None,
            cost: ModelCost::default(),
            context_window: 100_000,
            max_tokens: 4_096,
            headers: HashMap::new(),
        }
    }

    fn make_models() -> Vec<Model> {
        vec![
            mk_model("opus-4", "anthropic"),
            mk_model("haiku-4", "anthropic"),
            mk_model("gpt-4", "openai"),
            // `dual` exists under two providers so prefix/filter
            // disambiguation can be exercised.
            mk_model("dual", "anthropic"),
            mk_model("dual", "openai"),
        ]
    }

    #[test]
    fn plain_id_pass_through() {
        let models = make_models();
        let aliases = HashMap::new();
        let m = resolve_model("opus-4", None, None, &aliases, &models).unwrap();
        assert_eq!(m.id, "opus-4");
        assert_eq!(m.provider, "anthropic");
    }

    #[test]
    fn unknown_plain_id_returns_unknown_model() {
        let models = make_models();
        let aliases = HashMap::new();
        let err = resolve_model("nope", None, None, &aliases, &models).unwrap_err();
        assert_eq!(
            err,
            ResolveError::UnknownModel {
                name: "nope".into()
            }
        );
    }

    #[test]
    fn global_alias_hit() {
        let models = make_models();
        let mut aliases = HashMap::new();
        aliases.insert("smart".into(), "opus-4".into());
        let m = resolve_model("smart", None, None, &aliases, &models).unwrap();
        assert_eq!(m.id, "opus-4");
    }

    #[test]
    fn project_alias_hit() {
        let models = make_models();
        let global = HashMap::new();
        let mut project = HashMap::new();
        project.insert("smart".into(), "haiku-4".into());
        let m = resolve_model("smart", None, Some(&project), &global, &models).unwrap();
        assert_eq!(m.id, "haiku-4");
    }

    #[test]
    fn project_overrides_global() {
        let models = make_models();
        let mut global = HashMap::new();
        global.insert("smart".into(), "opus-4".into());
        let mut project = HashMap::new();
        project.insert("smart".into(), "haiku-4".into());
        let m = resolve_model("smart", None, Some(&project), &global, &models).unwrap();
        assert_eq!(m.id, "haiku-4");
    }

    #[test]
    fn unknown_alias_target_is_error() {
        let models = make_models();
        let mut global = HashMap::new();
        global.insert("smart".into(), "ghost".into());
        let err = resolve_model("smart", None, None, &global, &models).unwrap_err();
        assert_eq!(
            err,
            ResolveError::UnknownAlias {
                name: "smart".into(),
                target: "ghost".into(),
                scope: "global",
            }
        );
    }

    #[test]
    fn unknown_project_alias_target_reports_project_scope() {
        let models = make_models();
        let global = HashMap::new();
        let mut project = HashMap::new();
        project.insert("planner".into(), "ghost".into());
        let err = resolve_model("planner", None, Some(&project), &global, &models).unwrap_err();
        assert!(matches!(
            err,
            ResolveError::UnknownAlias {
                ref name,
                ref target,
                scope: "project",
            } if name == "planner" && target == "ghost"
        ));
    }

    #[test]
    fn provider_prefixed_alias_target() {
        let models = make_models();
        let mut global = HashMap::new();
        global.insert("dual_a".into(), "anthropic/dual".into());
        global.insert("dual_o".into(), "openai/dual".into());
        let a = resolve_model("dual_a", None, None, &global, &models).unwrap();
        assert_eq!(a.provider, "anthropic");
        let o = resolve_model("dual_o", None, None, &global, &models).unwrap();
        assert_eq!(o.provider, "openai");
    }

    #[test]
    fn explicit_provider_filter_overrides_alias_prefix() {
        let models = make_models();
        let mut global = HashMap::new();
        global.insert("d".into(), "anthropic/dual".into());
        // The caller's filter wins; the alias's "anthropic/" prefix is ignored.
        let m = resolve_model("d", Some("openai"), None, &global, &models).unwrap();
        assert_eq!(m.provider, "openai");
        assert_eq!(m.id, "dual");
    }

    #[test]
    fn explicit_provider_filter_on_plain_id() {
        let models = make_models();
        let global = HashMap::new();
        let m = resolve_model("dual", Some("openai"), None, &global, &models).unwrap();
        assert_eq!(m.provider, "openai");
    }

    #[test]
    fn explicit_provider_filter_no_match_is_error() {
        let models = make_models();
        let global = HashMap::new();
        let err = resolve_model("gpt-4", Some("anthropic"), None, &global, &models).unwrap_err();
        // `gpt-4` exists only under openai; no alias was involved, so the
        // failure is reported as an unknown model.
        assert_eq!(
            err,
            ResolveError::UnknownModel {
                name: "gpt-4".into()
            }
        );
    }

    #[test]
    fn alias_name_matches_real_model_id_alias_wins() {
        let models = make_models();
        let mut global = HashMap::new();
        // Alias expansion happens before model lookup, so an alias named
        // like a real model id shadows that model.
        global.insert("opus-4".into(), "haiku-4".into());
        let m = resolve_model("opus-4", None, None, &global, &models).unwrap();
        assert_eq!(m.id, "haiku-4");
    }

    #[test]
    fn split_on_first_slash_only() {
        let mut models = make_models();
        // A model whose id itself contains a slash: only the first '/' in
        // the target separates provider from id.
        models.push(mk_model("bar/baz", "foo"));
        let mut global = HashMap::new();
        global.insert("weird".into(), "foo/bar/baz".into());
        let m = resolve_model("weird", None, None, &global, &models).unwrap();
        assert_eq!(m.id, "bar/baz");
        assert_eq!(m.provider, "foo");
    }

    #[test]
    fn empty_alias_maps_match_plain_lookup() {
        let models = make_models();
        // Empty alias tables must behave exactly like having no aliases.
        let global = HashMap::new();
        let project = HashMap::new();

        let m = resolve_model("opus-4", None, Some(&project), &global, &models).unwrap();
        assert_eq!(m.id, "opus-4");

        let m = resolve_model("dual", Some("openai"), Some(&project), &global, &models).unwrap();
        assert_eq!(m.provider, "openai");

        let err = resolve_model("ghost", None, Some(&project), &global, &models).unwrap_err();
        assert_eq!(
            err,
            ResolveError::UnknownModel {
                name: "ghost".into()
            }
        );
    }

    #[test]
    fn empty_project_map_falls_through_to_global() {
        let models = make_models();
        let mut global = HashMap::new();
        global.insert("smart".into(), "opus-4".into());
        let project = HashMap::new();
        let m = resolve_model("smart", None, Some(&project), &global, &models).unwrap();
        assert_eq!(m.id, "opus-4");
    }

    #[test]
    fn empty_prefix_half_is_treated_as_bare_id() {
        let models = make_models();
        let mut global = HashMap::new();
        // Neither "/dual" nor "anthropic/" has two non-empty halves, so the
        // whole string is used as the model id and nothing matches.
        global.insert("lead".into(), "/dual".into());
        global.insert("trail".into(), "anthropic/".into());

        let err = resolve_model("lead", None, None, &global, &models).unwrap_err();
        assert_eq!(
            err,
            ResolveError::UnknownAlias {
                name: "lead".into(),
                target: "/dual".into(),
                scope: "global",
            }
        );

        let err = resolve_model("trail", None, None, &global, &models).unwrap_err();
        assert_eq!(
            err,
            ResolveError::UnknownAlias {
                name: "trail".into(),
                target: "anthropic/".into(),
                scope: "global",
            }
        );
    }

    #[test]
    fn display_messages() {
        let alias = ResolveError::UnknownAlias {
            name: "smart".into(),
            target: "ghost".into(),
            scope: "global",
        };
        assert_eq!(
            alias.to_string(),
            "global alias 'smart' points at unknown model 'ghost'"
        );

        let model = ResolveError::UnknownModel {
            name: "nope".into(),
        };
        assert_eq!(model.to_string(), "unknown model or alias: nope");
    }

    #[test]
    fn resolve_error_is_a_std_error() {
        let err: Box<dyn std::error::Error> = Box::new(ResolveError::UnknownModel {
            name: "nope".into(),
        });
        assert_eq!(err.to_string(), "unknown model or alias: nope");
    }
}