Skip to main content

zag_agent/
capability.rs

1use anyhow::{Result, bail};
2use serde::{Deserialize, Serialize};
3
/// A feature that can be either natively supported by the provider or implemented by the wrapper.
///
/// The constructors on this type produce three combinations: native
/// (`supported` and `native` both `true`), wrapper (`supported` only) and
/// unsupported (both `false`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureSupport {
    /// Whether the feature is available at all, natively or through the wrapper.
    pub supported: bool,
    /// Whether the provider itself implements the feature rather than the wrapper.
    pub native: bool,
}
10
/// Session log support with completeness level.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionLogSupport {
    /// Whether session logs are available for the provider.
    pub supported: bool,
    /// Set to `true` by the `full` and `partial` constructors.
    /// NOTE(review): presumably mirrors `FeatureSupport::native` (provider-implemented
    /// rather than wrapper-implemented) — confirm.
    pub native: bool,
    /// Completeness level: "full", "partial", or absent when unsupported.
    ///
    /// A `None` value is left out of serialized output entirely.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completeness: Option<String>,
}
20
/// Size alias mappings for a provider.
///
/// Each field holds the string the corresponding size alias maps to. For most
/// providers that is a model name; for ollama `get_capability` stores
/// parameter-size tags such as "2b".
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SizeMappings {
    /// Target of the small size alias.
    pub small: String,
    /// Target of the medium size alias.
    pub medium: String,
    /// Target of the large size alias.
    pub large: String,
}
28
/// All feature flags for a provider.
///
/// Every entry is a `FeatureSupport` except `session_logs`, which uses
/// `SessionLogSupport` so it can also carry a completeness level. The
/// per-provider values are assigned in `get_capability`.
///
/// NOTE(review): the meaning of each flag is only conveyed by its field name
/// here; confirm against the provider implementations before relying on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Features {
    pub interactive: FeatureSupport,
    pub non_interactive: FeatureSupport,
    pub resume: FeatureSupport,
    pub resume_with_prompt: FeatureSupport,
    // The only entry that is not a plain `FeatureSupport`.
    pub session_logs: SessionLogSupport,
    pub json_output: FeatureSupport,
    pub stream_json: FeatureSupport,
    pub json_schema: FeatureSupport,
    pub input_format: FeatureSupport,
    pub streaming_input: FeatureSupport,
    pub worktree: FeatureSupport,
    pub sandbox: FeatureSupport,
    pub system_prompt: FeatureSupport,
    pub auto_approve: FeatureSupport,
    pub review: FeatureSupport,
    pub add_dirs: FeatureSupport,
    pub max_turns: FeatureSupport,
}
50
/// Full capability declaration for a provider.
///
/// Built by `get_capability`, one value per provider name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderCapability {
    /// Provider name, e.g. "claude"; the same string that was used for the lookup.
    pub provider: String,
    /// Copied from the provider module's `DEFAULT_MODEL` constant.
    pub default_model: String,
    /// Copied from the provider module's `AVAILABLE_MODELS` constant
    /// (`AVAILABLE_SIZES` for ollama).
    pub available_models: Vec<String>,
    /// What the small/medium/large size aliases map to for this provider.
    pub size_mappings: SizeMappings,
    /// Per-feature support flags.
    pub features: Features,
}
60
61impl FeatureSupport {
62    pub fn native() -> Self {
63        Self {
64            supported: true,
65            native: true,
66        }
67    }
68
69    pub fn wrapper() -> Self {
70        Self {
71            supported: true,
72            native: false,
73        }
74    }
75
76    pub fn unsupported() -> Self {
77        Self {
78            supported: false,
79            native: false,
80        }
81    }
82}
83
84impl SessionLogSupport {
85    pub fn full() -> Self {
86        Self {
87            supported: true,
88            native: true,
89            completeness: Some("full".to_string()),
90        }
91    }
92
93    pub fn partial() -> Self {
94        Self {
95            supported: true,
96            native: true,
97            completeness: Some("partial".to_string()),
98        }
99    }
100
101    pub fn unsupported() -> Self {
102        Self {
103            supported: false,
104            native: false,
105            completeness: None,
106        }
107    }
108}
109
110/// Get capability declarations for a provider.
111pub fn get_capability(provider: &str) -> Result<ProviderCapability> {
112    use crate::agent::{Agent, ModelSize};
113
114    match provider {
115        "claude" => {
116            use crate::providers::claude::{self, Claude};
117            Ok(ProviderCapability {
118                provider: "claude".to_string(),
119                default_model: claude::DEFAULT_MODEL.to_string(),
120                available_models: models_to_vec(claude::AVAILABLE_MODELS),
121                size_mappings: SizeMappings {
122                    small: Claude::model_for_size(ModelSize::Small).to_string(),
123                    medium: Claude::model_for_size(ModelSize::Medium).to_string(),
124                    large: Claude::model_for_size(ModelSize::Large).to_string(),
125                },
126                features: Features {
127                    interactive: FeatureSupport::native(),
128                    non_interactive: FeatureSupport::native(),
129                    resume: FeatureSupport::native(),
130                    resume_with_prompt: FeatureSupport::native(),
131                    session_logs: SessionLogSupport::full(),
132                    json_output: FeatureSupport::native(),
133                    stream_json: FeatureSupport::native(),
134                    json_schema: FeatureSupport::native(),
135                    input_format: FeatureSupport::native(),
136                    streaming_input: FeatureSupport::native(),
137                    worktree: FeatureSupport::wrapper(),
138                    sandbox: FeatureSupport::wrapper(),
139                    system_prompt: FeatureSupport::native(),
140                    auto_approve: FeatureSupport::native(),
141                    review: FeatureSupport::unsupported(),
142                    add_dirs: FeatureSupport::native(),
143                    max_turns: FeatureSupport::native(),
144                },
145            })
146        }
147        "codex" => {
148            use crate::providers::codex::{self, Codex};
149            Ok(ProviderCapability {
150                provider: "codex".to_string(),
151                default_model: codex::DEFAULT_MODEL.to_string(),
152                available_models: models_to_vec(codex::AVAILABLE_MODELS),
153                size_mappings: SizeMappings {
154                    small: Codex::model_for_size(ModelSize::Small).to_string(),
155                    medium: Codex::model_for_size(ModelSize::Medium).to_string(),
156                    large: Codex::model_for_size(ModelSize::Large).to_string(),
157                },
158                features: Features {
159                    interactive: FeatureSupport::native(),
160                    non_interactive: FeatureSupport::native(),
161                    resume: FeatureSupport::native(),
162                    resume_with_prompt: FeatureSupport::native(),
163                    session_logs: SessionLogSupport::partial(),
164                    json_output: FeatureSupport::native(),
165                    stream_json: FeatureSupport::unsupported(),
166                    json_schema: FeatureSupport::wrapper(),
167                    input_format: FeatureSupport::unsupported(),
168                    streaming_input: FeatureSupport::unsupported(),
169                    worktree: FeatureSupport::wrapper(),
170                    sandbox: FeatureSupport::wrapper(),
171                    system_prompt: FeatureSupport::wrapper(),
172                    auto_approve: FeatureSupport::native(),
173                    review: FeatureSupport::native(),
174                    add_dirs: FeatureSupport::native(),
175                    max_turns: FeatureSupport::native(),
176                },
177            })
178        }
179        "gemini" => {
180            use crate::providers::gemini::{self, Gemini};
181            Ok(ProviderCapability {
182                provider: "gemini".to_string(),
183                default_model: gemini::DEFAULT_MODEL.to_string(),
184                available_models: models_to_vec(gemini::AVAILABLE_MODELS),
185                size_mappings: SizeMappings {
186                    small: Gemini::model_for_size(ModelSize::Small).to_string(),
187                    medium: Gemini::model_for_size(ModelSize::Medium).to_string(),
188                    large: Gemini::model_for_size(ModelSize::Large).to_string(),
189                },
190                features: Features {
191                    interactive: FeatureSupport::native(),
192                    non_interactive: FeatureSupport::native(),
193                    resume: FeatureSupport::native(),
194                    resume_with_prompt: FeatureSupport::unsupported(),
195                    session_logs: SessionLogSupport::full(),
196                    json_output: FeatureSupport::wrapper(),
197                    stream_json: FeatureSupport::unsupported(),
198                    json_schema: FeatureSupport::wrapper(),
199                    input_format: FeatureSupport::unsupported(),
200                    streaming_input: FeatureSupport::unsupported(),
201                    worktree: FeatureSupport::wrapper(),
202                    sandbox: FeatureSupport::wrapper(),
203                    system_prompt: FeatureSupport::wrapper(),
204                    auto_approve: FeatureSupport::native(),
205                    review: FeatureSupport::unsupported(),
206                    add_dirs: FeatureSupport::native(),
207                    max_turns: FeatureSupport::native(),
208                },
209            })
210        }
211        "copilot" => {
212            use crate::providers::copilot::{self, Copilot};
213            Ok(ProviderCapability {
214                provider: "copilot".to_string(),
215                default_model: copilot::DEFAULT_MODEL.to_string(),
216                available_models: models_to_vec(copilot::AVAILABLE_MODELS),
217                size_mappings: SizeMappings {
218                    small: Copilot::model_for_size(ModelSize::Small).to_string(),
219                    medium: Copilot::model_for_size(ModelSize::Medium).to_string(),
220                    large: Copilot::model_for_size(ModelSize::Large).to_string(),
221                },
222                features: Features {
223                    interactive: FeatureSupport::native(),
224                    non_interactive: FeatureSupport::native(),
225                    resume: FeatureSupport::native(),
226                    resume_with_prompt: FeatureSupport::unsupported(),
227                    session_logs: SessionLogSupport::full(),
228                    json_output: FeatureSupport::unsupported(),
229                    stream_json: FeatureSupport::unsupported(),
230                    json_schema: FeatureSupport::unsupported(),
231                    input_format: FeatureSupport::unsupported(),
232                    streaming_input: FeatureSupport::unsupported(),
233                    worktree: FeatureSupport::wrapper(),
234                    sandbox: FeatureSupport::wrapper(),
235                    system_prompt: FeatureSupport::wrapper(),
236                    auto_approve: FeatureSupport::native(),
237                    review: FeatureSupport::unsupported(),
238                    add_dirs: FeatureSupport::native(),
239                    max_turns: FeatureSupport::native(),
240                },
241            })
242        }
243        "ollama" => {
244            use crate::providers::ollama;
245            Ok(ProviderCapability {
246                provider: "ollama".to_string(),
247                default_model: ollama::DEFAULT_MODEL.to_string(),
248                available_models: models_to_vec(ollama::AVAILABLE_SIZES),
249                size_mappings: SizeMappings {
250                    small: "2b".to_string(),
251                    medium: "9b".to_string(),
252                    large: "35b".to_string(),
253                },
254                features: Features {
255                    interactive: FeatureSupport::native(),
256                    non_interactive: FeatureSupport::native(),
257                    resume: FeatureSupport::unsupported(),
258                    resume_with_prompt: FeatureSupport::unsupported(),
259                    session_logs: SessionLogSupport::unsupported(),
260                    json_output: FeatureSupport::wrapper(),
261                    stream_json: FeatureSupport::unsupported(),
262                    json_schema: FeatureSupport::wrapper(),
263                    input_format: FeatureSupport::unsupported(),
264                    streaming_input: FeatureSupport::unsupported(),
265                    worktree: FeatureSupport::wrapper(),
266                    sandbox: FeatureSupport::wrapper(),
267                    system_prompt: FeatureSupport::wrapper(),
268                    auto_approve: FeatureSupport::native(),
269                    review: FeatureSupport::unsupported(),
270                    add_dirs: FeatureSupport::unsupported(),
271                    max_turns: FeatureSupport::unsupported(),
272                },
273            })
274        }
275        _ => bail!(
276            "No capabilities defined for provider '{}'. Available: claude, codex, gemini, copilot, ollama",
277            provider
278        ),
279    }
280}
281
282/// Format a capability struct into the requested output format.
283pub fn format_capability(cap: &ProviderCapability, format: &str, pretty: bool) -> Result<String> {
284    match format {
285        "json" => {
286            if pretty {
287                Ok(serde_json::to_string_pretty(cap)?)
288            } else {
289                Ok(serde_json::to_string(cap)?)
290            }
291        }
292        "yaml" => Ok(serde_yaml::to_string(cap)?),
293        "toml" => Ok(toml::to_string_pretty(cap)?),
294        _ => bail!(
295            "Unsupported format '{}'. Available: json, yaml, toml",
296            format
297        ),
298    }
299}
300
/// Canonical list of provider names (excludes "auto" and "mock").
///
/// `list_providers` and `get_all_capabilities` iterate this slice, and
/// `resolve_model` prints it in its unknown-provider error. Every name listed
/// here has a matching arm in `get_capability` and `resolve_model`; keep the
/// three in step when adding a provider.
pub const PROVIDERS: &[&str] = &["claude", "codex", "gemini", "copilot", "ollama"];
303
304/// List all available provider names.
305pub fn list_providers() -> Vec<String> {
306    PROVIDERS.iter().map(|s| s.to_string()).collect()
307}
308
309/// Get capabilities for all providers.
310pub fn get_all_capabilities() -> Vec<ProviderCapability> {
311    PROVIDERS
312        .iter()
313        .filter_map(|p| get_capability(p).ok())
314        .collect()
315}
316
/// Result of resolving a model alias.
///
/// Produced by `resolve_model`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResolvedModel {
    /// The model name or alias exactly as the caller supplied it.
    pub input: String,
    /// The name returned by the provider's own `resolve_model`.
    pub resolved: String,
    /// `true` when `resolved` differs from `input`.
    pub is_alias: bool,
    /// The provider name the resolution was performed for.
    pub provider: String,
}
325
326/// Resolve a model name or alias for a given provider.
327///
328/// Size aliases (`small`/`s`, `medium`/`m`/`default`, `large`/`l`/`max`) are
329/// resolved to the provider-specific model. Non-alias names pass through unchanged.
330pub fn resolve_model(provider: &str, model_input: &str) -> Result<ResolvedModel> {
331    use crate::agent::Agent;
332    use crate::providers::{
333        claude::Claude, codex::Codex, copilot::Copilot, gemini::Gemini, ollama::Ollama,
334    };
335
336    let resolved = match provider {
337        "claude" => Claude::resolve_model(model_input),
338        "codex" => Codex::resolve_model(model_input),
339        "gemini" => Gemini::resolve_model(model_input),
340        "copilot" => Copilot::resolve_model(model_input),
341        "ollama" => Ollama::resolve_model(model_input),
342        _ => bail!(
343            "Unknown provider '{}'. Available: {}",
344            provider,
345            PROVIDERS.join(", ")
346        ),
347    };
348
349    Ok(ResolvedModel {
350        input: model_input.to_string(),
351        is_alias: resolved != model_input,
352        resolved,
353        provider: provider.to_string(),
354    })
355}
356
357/// Format a resolved model into the requested output format.
358pub fn format_resolved_model(rm: &ResolvedModel, format: &str, pretty: bool) -> Result<String> {
359    match format {
360        "json" => {
361            if pretty {
362                Ok(serde_json::to_string_pretty(rm)?)
363            } else {
364                Ok(serde_json::to_string(rm)?)
365            }
366        }
367        "yaml" => Ok(serde_yaml::to_string(rm)?),
368        "toml" => Ok(toml::to_string_pretty(rm)?),
369        _ => bail!(
370            "Unsupported format '{}'. Available: json, yaml, toml",
371            format
372        ),
373    }
374}
375
376/// Format a list of capabilities into the requested output format.
377pub fn format_capabilities(
378    caps: &[ProviderCapability],
379    format: &str,
380    pretty: bool,
381) -> Result<String> {
382    match format {
383        "json" => {
384            if pretty {
385                Ok(serde_json::to_string_pretty(caps)?)
386            } else {
387                Ok(serde_json::to_string(caps)?)
388            }
389        }
390        "yaml" => Ok(serde_yaml::to_string(caps)?),
391        "toml" => {
392            #[derive(Serialize)]
393            struct Wrapper<'a> {
394                providers: &'a [ProviderCapability],
395            }
396            Ok(toml::to_string_pretty(&Wrapper { providers: caps })?)
397        }
398        _ => bail!(
399            "Unsupported format '{}'. Available: json, yaml, toml",
400            format
401        ),
402    }
403}
404
405/// Format a models listing into the requested output format.
406pub fn format_models(caps: &[ProviderCapability], format: &str, pretty: bool) -> Result<String> {
407    #[derive(Serialize)]
408    struct ModelEntry {
409        provider: String,
410        default_model: String,
411        models: Vec<String>,
412    }
413
414    let entries: Vec<ModelEntry> = caps
415        .iter()
416        .map(|c| ModelEntry {
417            provider: c.provider.clone(),
418            default_model: c.default_model.clone(),
419            models: c.available_models.clone(),
420        })
421        .collect();
422
423    match format {
424        "json" => {
425            if pretty {
426                Ok(serde_json::to_string_pretty(&entries)?)
427            } else {
428                Ok(serde_json::to_string(&entries)?)
429            }
430        }
431        "yaml" => Ok(serde_yaml::to_string(&entries)?),
432        "toml" => bail!("TOML does not support top-level arrays. Use json or yaml"),
433        _ => bail!(
434            "Unsupported format '{}'. Available: json, yaml, toml",
435            format
436        ),
437    }
438}
439
/// Copy each borrowed string in `models` into an owned `String`, keeping the order.
pub fn models_to_vec(models: &[&str]) -> Vec<String> {
    let mut owned = Vec::with_capacity(models.len());
    for &model in models {
        owned.push(String::from(model));
    }
    owned
}
444
445#[cfg(test)]
446#[path = "capability_tests.rs"]
447mod tests;