Skip to main content

tau_agent_base/
model_resolve.rs

1//! Model id / alias → concrete `Model` resolver.
2//!
3//! The resolver is intentionally pure (no I/O, no locks) so that callers
4//! can compose it with their own state ownership.  Both `create_session`
5//! and `set_model` flow through it.
6//!
7//! ## Resolution order
8//!
9//! 1. **Project alias map** (if provided): if `raw` matches a key here, the
10//!    target is taken from the project map.
11//! 2. **Global alias map**: same lookup.
//! 3. **Literal model id**: `raw` is treated as a model id directly
//!    (still subject to the `provider/model-id` parsing described below).
13//!
14//! At most one alias hop is performed: alias targets must be model ids
15//! (optionally `provider/model-id`), never another alias.  This makes
16//! cycles impossible by construction.
17//!
18//! ## Alias collisions
19//!
20//! If an alias has the same name as a real model id, the alias wins.
21//! This is documented in `docs/CONFIG.md`.
22//!
23//! ## `provider/model-id` parsing
24//!
//! Alias targets — and literal ids passed directly as `raw` — may be
//! prefixed with a provider name and a `/`.  We split on the **first** `/`
//! only, so `"foo/bar/baz"` is parsed as `provider="foo"`, `id="bar/baz"`.
//! This lets model ids that contain slashes still resolve when they are
//! written with an explicit provider prefix.
29//!
30//! When the caller provides an explicit `provider_filter` (the
31//! `provider_name` argument from a request), it always takes precedence
32//! over a `provider/` prefix in the alias target — matching the existing
33//! behavior of the (Some(model_id), Some(provider)) request branch.
34
35use std::collections::HashMap;
36
37use crate::types::Model;
38
39// ---------------------------------------------------------------------------
40// Error type
41// ---------------------------------------------------------------------------
42
/// Failure modes of [`resolve_model`].
///
/// `name` in either variant is the string the caller asked to resolve, so
/// the rendered message echoes back exactly what was requested.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ResolveError {
    /// An alias matched, but its target names no model in the catalog.
    ///
    /// This is always reported as an error so that a broken alias in the
    /// configuration is surfaced instead of being masked by a fallback
    /// default.
    UnknownAlias {
        /// Alias name as supplied by the caller.
        name: String,
        /// Target string the alias maps to.
        target: String,
        /// Which alias map matched: `"project"` or `"global"`.
        scope: &'static str,
    },
    /// Not an alias, and no model with that id exists (under the requested
    /// provider, if one was given).  Callers decide whether this is fatal or
    /// whether to fall back to a default model.
    UnknownModel {
        /// The string the caller supplied.
        name: String,
    },
}

impl std::fmt::Display for ResolveError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::UnknownModel { ref name } => write!(f, "unknown model or alias: {}", name),
            Self::UnknownAlias {
                ref name,
                ref target,
                scope,
            } => write!(
                f,
                "{} alias '{}' points at unknown model '{}'",
                scope, name, target
            ),
        }
    }
}

// No wrapped cause to expose via `source()`; `Debug` + `Display` suffice.
impl std::error::Error for ResolveError {}
82
83// ---------------------------------------------------------------------------
84// Resolver
85// ---------------------------------------------------------------------------
86
87/// Resolve a `model_id`-shaped string (possibly an alias) to a concrete
88/// `Model` reference from `all_models`.
89///
90/// See the module docs for the lookup order and disambiguation rules.
91pub fn resolve_model<'a>(
92    raw: &str,
93    provider_filter: Option<&str>,
94    project_aliases: Option<&HashMap<String, String>>,
95    global_aliases: &HashMap<String, String>,
96    all_models: &'a [Model],
97) -> Result<&'a Model, ResolveError> {
98    // 1+2. Alias lookup (project takes precedence over global).
99    let (target, alias_scope): (&str, Option<&'static str>) = if let Some(map) =
100        project_aliases.and_then(|m| if m.is_empty() { None } else { Some(m) })
101        && let Some(t) = map.get(raw)
102    {
103        (t.as_str(), Some("project"))
104    } else if let Some(t) = global_aliases.get(raw) {
105        (t.as_str(), Some("global"))
106    } else {
107        (raw, None)
108    };
109
110    // 3. Parse `provider/id` form. Split on the FIRST `/` only so that
111    //    model ids containing slashes are preserved in the id half.
112    let (target_provider, target_id) = match target.split_once('/') {
113        Some((p, i)) if !p.is_empty() && !i.is_empty() => (Some(p), i),
114        _ => (None, target),
115    };
116
117    // 4. Combine with the explicit provider_filter from the request.
118    //    Explicit filter wins over the alias-target's prefix.
119    let effective_provider: Option<&str> = provider_filter.or(target_provider);
120
121    // 5. Look up in all_models.
122    let found = all_models.iter().find(|m| {
123        m.id == target_id
124            && match effective_provider {
125                Some(p) => m.provider == p,
126                None => true,
127            }
128    });
129
130    match found {
131        Some(m) => Ok(m),
132        None => match alias_scope {
133            Some(scope) => Err(ResolveError::UnknownAlias {
134                name: raw.to_string(),
135                target: target.to_string(),
136                scope,
137            }),
138            None => Err(ResolveError::UnknownModel {
139                name: raw.to_string(),
140            }),
141        },
142    }
143}
144
145// ---------------------------------------------------------------------------
146// Tests
147// ---------------------------------------------------------------------------
148
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Model, ModelCost, ThinkingStyle};

    /// Build a minimal mock `Model` with the given id and provider.
    fn mk_model(id: &str, provider: &str) -> Model {
        Model {
            id: id.into(),
            name: id.into(),
            api: "mock".into(),
            provider: provider.into(),
            base_url: "http://mock".into(),
            thinking: ThinkingStyle::None,
            cost: ModelCost::default(),
            context_window: 100_000,
            max_tokens: 4_096,
            headers: HashMap::new(),
        }
    }

    fn make_models() -> Vec<Model> {
        vec![
            mk_model("opus-4", "anthropic"),
            mk_model("haiku-4", "anthropic"),
            mk_model("gpt-4", "openai"),
            // Same id under two providers, used to test the provider filter.
            mk_model("dual", "anthropic"),
            mk_model("dual", "openai"),
        ]
    }

    #[test]
    fn plain_id_pass_through() {
        let models = make_models();
        let aliases = HashMap::new();
        let m = resolve_model("opus-4", None, None, &aliases, &models).unwrap();
        assert_eq!(m.id, "opus-4");
        assert_eq!(m.provider, "anthropic");
    }

    #[test]
    fn unknown_plain_id_returns_unknown_model() {
        let models = make_models();
        let aliases = HashMap::new();
        let err = resolve_model("nope", None, None, &aliases, &models).unwrap_err();
        assert_eq!(
            err,
            ResolveError::UnknownModel {
                name: "nope".into()
            }
        );
    }

    #[test]
    fn global_alias_hit() {
        let models = make_models();
        let mut aliases = HashMap::new();
        aliases.insert("smart".into(), "opus-4".into());
        let m = resolve_model("smart", None, None, &aliases, &models).unwrap();
        assert_eq!(m.id, "opus-4");
    }

    #[test]
    fn project_alias_hit() {
        let models = make_models();
        let global = HashMap::new();
        let mut project = HashMap::new();
        project.insert("smart".into(), "haiku-4".into());
        let m = resolve_model("smart", None, Some(&project), &global, &models).unwrap();
        assert_eq!(m.id, "haiku-4");
    }

    #[test]
    fn project_overrides_global() {
        let models = make_models();
        let mut global = HashMap::new();
        global.insert("smart".into(), "opus-4".into());
        let mut project = HashMap::new();
        project.insert("smart".into(), "haiku-4".into());
        let m = resolve_model("smart", None, Some(&project), &global, &models).unwrap();
        assert_eq!(m.id, "haiku-4");
    }

    #[test]
    fn unknown_alias_target_is_error() {
        let models = make_models();
        let mut global = HashMap::new();
        global.insert("smart".into(), "ghost".into());
        let err = resolve_model("smart", None, None, &global, &models).unwrap_err();
        assert_eq!(
            err,
            ResolveError::UnknownAlias {
                name: "smart".into(),
                target: "ghost".into(),
                scope: "global",
            }
        );
    }

    #[test]
    fn unknown_project_alias_target_reports_project_scope() {
        let models = make_models();
        let global = HashMap::new();
        let mut project = HashMap::new();
        project.insert("planner".into(), "ghost".into());
        let err = resolve_model("planner", None, Some(&project), &global, &models).unwrap_err();
        assert_eq!(
            err,
            ResolveError::UnknownAlias {
                name: "planner".into(),
                target: "ghost".into(),
                scope: "project",
            }
        );
    }

    #[test]
    fn provider_prefixed_alias_target() {
        let models = make_models();
        let mut global = HashMap::new();
        global.insert("dual_a".into(), "anthropic/dual".into());
        global.insert("dual_o".into(), "openai/dual".into());
        let a = resolve_model("dual_a", None, None, &global, &models).unwrap();
        assert_eq!(a.provider, "anthropic");
        let o = resolve_model("dual_o", None, None, &global, &models).unwrap();
        assert_eq!(o.provider, "openai");
    }

    #[test]
    fn explicit_provider_filter_overrides_alias_prefix() {
        let models = make_models();
        let mut global = HashMap::new();
        // Alias points at anthropic/dual, but request asks for openai.
        global.insert("d".into(), "anthropic/dual".into());
        let m = resolve_model("d", Some("openai"), None, &global, &models).unwrap();
        assert_eq!(m.provider, "openai");
        assert_eq!(m.id, "dual");
    }

    #[test]
    fn explicit_provider_filter_on_plain_id() {
        let models = make_models();
        let global = HashMap::new();
        let m = resolve_model("dual", Some("openai"), None, &global, &models).unwrap();
        assert_eq!(m.provider, "openai");
    }

    #[test]
    fn explicit_provider_filter_no_match_is_error() {
        let models = make_models();
        let global = HashMap::new();
        // 'gpt-4' only exists under provider 'openai'; asking for 'anthropic'
        // must fail rather than returning the openai one.
        let err = resolve_model("gpt-4", Some("anthropic"), None, &global, &models).unwrap_err();
        assert_eq!(
            err,
            ResolveError::UnknownModel {
                name: "gpt-4".into()
            }
        );
    }

    #[test]
    fn alias_name_matches_real_model_id_alias_wins() {
        let models = make_models();
        // 'opus-4' is a real model id. We add an alias of the same name
        // that points elsewhere — alias must win.
        let mut global = HashMap::new();
        global.insert("opus-4".into(), "haiku-4".into());
        let m = resolve_model("opus-4", None, None, &global, &models).unwrap();
        assert_eq!(m.id, "haiku-4");
    }

    #[test]
    fn split_on_first_slash_only() {
        // A made-up model id containing a slash, registered under
        // provider "foo".  Alias target is "foo/bar/baz", which should
        // resolve to provider="foo", id="bar/baz".
        let mut models = make_models();
        models.push(mk_model("bar/baz", "foo"));
        let mut global = HashMap::new();
        global.insert("weird".into(), "foo/bar/baz".into());
        let m = resolve_model("weird", None, None, &global, &models).unwrap();
        assert_eq!(m.id, "bar/baz");
        assert_eq!(m.provider, "foo");
    }

    #[test]
    fn empty_alias_maps_match_plain_lookup() {
        // Regression: with both maps empty the resolver behaves identically
        // to a direct `all_models.iter().find` over (id, provider?).
        let models = make_models();
        let global = HashMap::new();
        let project = HashMap::new();

        // Plain id, no provider — finds first match.
        let m = resolve_model("opus-4", None, Some(&project), &global, &models).unwrap();
        assert_eq!(m.id, "opus-4");

        // Plain id + explicit provider.
        let m = resolve_model("dual", Some("openai"), Some(&project), &global, &models).unwrap();
        assert_eq!(m.provider, "openai");

        // Unknown id with empty maps still returns UnknownModel (not UnknownAlias).
        let err = resolve_model("ghost", None, Some(&project), &global, &models).unwrap_err();
        assert_eq!(
            err,
            ResolveError::UnknownModel {
                name: "ghost".into()
            }
        );
    }

    #[test]
    fn empty_project_map_falls_through_to_global() {
        // An empty Some(&project) map should not shadow the global lookup.
        let models = make_models();
        let mut global = HashMap::new();
        global.insert("smart".into(), "opus-4".into());
        let project = HashMap::new();
        let m = resolve_model("smart", None, Some(&project), &global, &models).unwrap();
        assert_eq!(m.id, "opus-4");
    }

    #[test]
    fn project_map_without_key_falls_through_to_global() {
        // A non-empty project map that lacks the requested alias must not
        // shadow the global lookup either.
        let models = make_models();
        let mut global = HashMap::new();
        global.insert("smart".into(), "opus-4".into());
        let mut project = HashMap::new();
        project.insert("fast".into(), "haiku-4".into());
        let m = resolve_model("smart", None, Some(&project), &global, &models).unwrap();
        assert_eq!(m.id, "opus-4");
    }

    #[test]
    fn literal_provider_prefixed_id() {
        // The `provider/id` split applies to literal ids too, not only to
        // alias targets.
        let models = make_models();
        let global = HashMap::new();
        let m = resolve_model("openai/dual", None, None, &global, &models).unwrap();
        assert_eq!(m.provider, "openai");
        assert_eq!(m.id, "dual");
    }

    #[test]
    fn alias_with_excluding_provider_filter_reports_alias() {
        // The alias target exists, but only under a provider the request
        // excludes; the failure must still be attributed to the alias.
        let models = make_models();
        let mut global = HashMap::new();
        global.insert("g".into(), "gpt-4".into());
        let err = resolve_model("g", Some("anthropic"), None, &global, &models).unwrap_err();
        assert_eq!(
            err,
            ResolveError::UnknownAlias {
                name: "g".into(),
                target: "gpt-4".into(),
                scope: "global",
            }
        );
    }

    #[test]
    fn error_display_messages() {
        let alias = ResolveError::UnknownAlias {
            name: "smart".into(),
            target: "ghost".into(),
            scope: "global",
        };
        assert_eq!(
            alias.to_string(),
            "global alias 'smart' points at unknown model 'ghost'"
        );
        let model = ResolveError::UnknownModel {
            name: "nope".into(),
        };
        assert_eq!(model.to_string(), "unknown model or alias: nope");
    }
}