Skip to main content

zag_agent/
capability.rs

1use anyhow::{Result, bail};
2use serde::{Deserialize, Serialize};
3
/// A feature that can be either natively supported by the provider or implemented by the wrapper.
///
/// The constructors `native()`, `wrapper()` and `unsupported()` produce the three
/// combinations used in the capability tables; `native` is only `true` together
/// with `supported`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureSupport {
    /// Whether the feature is available at all, natively or via the wrapper.
    pub supported: bool,
    /// `true` when the provider implements the feature itself; `false` when the
    /// wrapper emulates it or when the feature is unsupported.
    pub native: bool,
}
10
/// Session log support with completeness level.
///
/// Unlike `FeatureSupport`, this carries an extra `completeness` label that the
/// text renderer prints instead of a plain "native"/"wrapper" marker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionLogSupport {
    /// Whether session logs are available for the provider.
    pub supported: bool,
    /// Whether the provider produces the logs itself. When `false` (but
    /// supported), the text renderer appends "(wrapper)" to the completeness.
    pub native: bool,
    /// Completeness level: "full", "partial", or absent when unsupported.
    // Skipped when `None` so TOML/JSON output omits the key instead of emitting null.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completeness: Option<String>,
}
20
/// Streaming input support with mid-turn injection semantics.
///
/// Describes what happens when `StreamingSession::send_user_message` is called
/// while the agent is already producing a response on the current turn.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamingInputSupport {
    /// Whether the provider accepts streamed user input at all.
    pub supported: bool,
    /// Whether the provider implements streaming input itself rather than the wrapper.
    pub native: bool,
    /// Mid-turn semantics when `send_user_message` is called while the agent
    /// is already producing a response. One of:
    /// - `"queue"` — message is buffered and delivered at the next turn boundary
    ///   (the current turn runs to completion before the new message is processed).
    /// - `"interrupt"` — message cancels the current turn and starts a new one
    ///   with the new input.
    /// - `"between-turns-only"` — calling mid-turn is an error or no-op; callers
    ///   must wait for the current turn to finish before sending.
    ///
    /// Absent when `supported == false`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub semantics: Option<String>,
}
42
/// Size alias mappings for a provider.
///
/// Each field holds the value the corresponding size alias maps to for one
/// provider. For most providers this is a model name; for `ollama` the
/// capability table stores parameter-size tags (e.g. `"9b"`) instead.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SizeMappings {
    /// Target of the `small` alias.
    pub small: String,
    /// Target of the `medium` alias.
    pub medium: String,
    /// Target of the `large` alias.
    pub large: String,
}
50
/// All feature flags for a provider.
///
/// Field declaration order is also the serialized key order. Each field is
/// rendered by `format_provider_detail` under its kebab-case name
/// (e.g. `non_interactive` -> `non-interactive`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Features {
    pub interactive: FeatureSupport,
    pub non_interactive: FeatureSupport,
    pub resume: FeatureSupport,
    pub resume_with_prompt: FeatureSupport,
    // Carries a completeness level in addition to supported/native.
    pub session_logs: SessionLogSupport,
    pub json_output: FeatureSupport,
    pub stream_json: FeatureSupport,
    pub json_schema: FeatureSupport,
    pub input_format: FeatureSupport,
    // Carries mid-turn semantics in addition to supported/native.
    pub streaming_input: StreamingInputSupport,
    pub worktree: FeatureSupport,
    pub sandbox: FeatureSupport,
    pub system_prompt: FeatureSupport,
    pub auto_approve: FeatureSupport,
    pub review: FeatureSupport,
    pub add_dirs: FeatureSupport,
    pub max_turns: FeatureSupport,
}
72
/// Full capability declaration for a provider.
///
/// Produced by `get_capability` and serialized by the `format_*` helpers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderCapability {
    /// Provider name, matching an entry of `PROVIDERS`.
    pub provider: String,
    /// Model used when the caller does not pick one.
    pub default_model: String,
    /// Models (or, for `ollama`, size tags) the provider advertises.
    pub available_models: Vec<String>,
    /// What the `small`/`medium`/`large` aliases resolve to.
    pub size_mappings: SizeMappings,
    /// Per-feature support flags.
    pub features: Features,
}
82
83impl FeatureSupport {
84    pub fn native() -> Self {
85        Self {
86            supported: true,
87            native: true,
88        }
89    }
90
91    pub fn wrapper() -> Self {
92        Self {
93            supported: true,
94            native: false,
95        }
96    }
97
98    pub fn unsupported() -> Self {
99        Self {
100            supported: false,
101            native: false,
102        }
103    }
104}
105
106impl SessionLogSupport {
107    pub fn full() -> Self {
108        Self {
109            supported: true,
110            native: true,
111            completeness: Some("full".to_string()),
112        }
113    }
114
115    pub fn partial() -> Self {
116        Self {
117            supported: true,
118            native: true,
119            completeness: Some("partial".to_string()),
120        }
121    }
122
123    pub fn unsupported() -> Self {
124        Self {
125            supported: false,
126            native: false,
127            completeness: None,
128        }
129    }
130}
131
132impl StreamingInputSupport {
133    /// Mid-turn messages are queued and delivered at the next turn boundary.
134    /// The currently running turn is not interrupted.
135    pub fn queue() -> Self {
136        Self {
137            supported: true,
138            native: true,
139            semantics: Some("queue".to_string()),
140        }
141    }
142
143    /// Mid-turn messages cancel the current turn and start a new one.
144    pub fn interrupt() -> Self {
145        Self {
146            supported: true,
147            native: true,
148            semantics: Some("interrupt".to_string()),
149        }
150    }
151
152    /// Messages may only be sent between turns; mid-turn sends are an error.
153    pub fn between_turns_only() -> Self {
154        Self {
155            supported: true,
156            native: true,
157            semantics: Some("between-turns-only".to_string()),
158        }
159    }
160
161    /// The provider does not support streaming input at all.
162    pub fn unsupported() -> Self {
163        Self {
164            supported: false,
165            native: false,
166            semantics: None,
167        }
168    }
169}
170
171/// Get capability declarations for a provider.
172pub fn get_capability(provider: &str) -> Result<ProviderCapability> {
173    use crate::agent::{Agent, ModelSize};
174
175    match provider {
176        "claude" => {
177            use crate::providers::claude::{self, Claude};
178            Ok(ProviderCapability {
179                provider: "claude".to_string(),
180                default_model: claude::DEFAULT_MODEL.to_string(),
181                available_models: models_to_vec(claude::AVAILABLE_MODELS),
182                size_mappings: SizeMappings {
183                    small: Claude::model_for_size(ModelSize::Small).to_string(),
184                    medium: Claude::model_for_size(ModelSize::Medium).to_string(),
185                    large: Claude::model_for_size(ModelSize::Large).to_string(),
186                },
187                features: Features {
188                    interactive: FeatureSupport::native(),
189                    non_interactive: FeatureSupport::native(),
190                    resume: FeatureSupport::native(),
191                    resume_with_prompt: FeatureSupport::native(),
192                    session_logs: SessionLogSupport::full(),
193                    json_output: FeatureSupport::native(),
194                    stream_json: FeatureSupport::native(),
195                    json_schema: FeatureSupport::native(),
196                    input_format: FeatureSupport::native(),
197                    streaming_input: StreamingInputSupport::queue(),
198                    worktree: FeatureSupport::wrapper(),
199                    sandbox: FeatureSupport::wrapper(),
200                    system_prompt: FeatureSupport::native(),
201                    auto_approve: FeatureSupport::native(),
202                    review: FeatureSupport::unsupported(),
203                    add_dirs: FeatureSupport::native(),
204                    max_turns: FeatureSupport::native(),
205                },
206            })
207        }
208        "codex" => {
209            use crate::providers::codex::{self, Codex};
210            Ok(ProviderCapability {
211                provider: "codex".to_string(),
212                default_model: codex::DEFAULT_MODEL.to_string(),
213                available_models: models_to_vec(codex::AVAILABLE_MODELS),
214                size_mappings: SizeMappings {
215                    small: Codex::model_for_size(ModelSize::Small).to_string(),
216                    medium: Codex::model_for_size(ModelSize::Medium).to_string(),
217                    large: Codex::model_for_size(ModelSize::Large).to_string(),
218                },
219                features: Features {
220                    interactive: FeatureSupport::native(),
221                    non_interactive: FeatureSupport::native(),
222                    resume: FeatureSupport::native(),
223                    resume_with_prompt: FeatureSupport::native(),
224                    session_logs: SessionLogSupport::partial(),
225                    json_output: FeatureSupport::native(),
226                    stream_json: FeatureSupport::unsupported(),
227                    json_schema: FeatureSupport::wrapper(),
228                    input_format: FeatureSupport::unsupported(),
229                    streaming_input: StreamingInputSupport::unsupported(),
230                    worktree: FeatureSupport::wrapper(),
231                    sandbox: FeatureSupport::wrapper(),
232                    system_prompt: FeatureSupport::wrapper(),
233                    auto_approve: FeatureSupport::native(),
234                    review: FeatureSupport::native(),
235                    add_dirs: FeatureSupport::native(),
236                    max_turns: FeatureSupport::native(),
237                },
238            })
239        }
240        "gemini" => {
241            use crate::providers::gemini::{self, Gemini};
242            Ok(ProviderCapability {
243                provider: "gemini".to_string(),
244                default_model: gemini::DEFAULT_MODEL.to_string(),
245                available_models: models_to_vec(gemini::AVAILABLE_MODELS),
246                size_mappings: SizeMappings {
247                    small: Gemini::model_for_size(ModelSize::Small).to_string(),
248                    medium: Gemini::model_for_size(ModelSize::Medium).to_string(),
249                    large: Gemini::model_for_size(ModelSize::Large).to_string(),
250                },
251                features: Features {
252                    interactive: FeatureSupport::native(),
253                    non_interactive: FeatureSupport::native(),
254                    resume: FeatureSupport::native(),
255                    resume_with_prompt: FeatureSupport::unsupported(),
256                    session_logs: SessionLogSupport::full(),
257                    json_output: FeatureSupport::wrapper(),
258                    stream_json: FeatureSupport::unsupported(),
259                    json_schema: FeatureSupport::wrapper(),
260                    input_format: FeatureSupport::unsupported(),
261                    streaming_input: StreamingInputSupport::unsupported(),
262                    worktree: FeatureSupport::wrapper(),
263                    sandbox: FeatureSupport::wrapper(),
264                    system_prompt: FeatureSupport::wrapper(),
265                    auto_approve: FeatureSupport::native(),
266                    review: FeatureSupport::unsupported(),
267                    add_dirs: FeatureSupport::native(),
268                    max_turns: FeatureSupport::native(),
269                },
270            })
271        }
272        "copilot" => {
273            use crate::providers::copilot::{self, Copilot};
274            Ok(ProviderCapability {
275                provider: "copilot".to_string(),
276                default_model: copilot::DEFAULT_MODEL.to_string(),
277                available_models: models_to_vec(copilot::AVAILABLE_MODELS),
278                size_mappings: SizeMappings {
279                    small: Copilot::model_for_size(ModelSize::Small).to_string(),
280                    medium: Copilot::model_for_size(ModelSize::Medium).to_string(),
281                    large: Copilot::model_for_size(ModelSize::Large).to_string(),
282                },
283                features: Features {
284                    interactive: FeatureSupport::native(),
285                    non_interactive: FeatureSupport::native(),
286                    resume: FeatureSupport::native(),
287                    resume_with_prompt: FeatureSupport::unsupported(),
288                    session_logs: SessionLogSupport::full(),
289                    json_output: FeatureSupport::unsupported(),
290                    stream_json: FeatureSupport::unsupported(),
291                    json_schema: FeatureSupport::unsupported(),
292                    input_format: FeatureSupport::unsupported(),
293                    streaming_input: StreamingInputSupport::unsupported(),
294                    worktree: FeatureSupport::wrapper(),
295                    sandbox: FeatureSupport::wrapper(),
296                    system_prompt: FeatureSupport::wrapper(),
297                    auto_approve: FeatureSupport::native(),
298                    review: FeatureSupport::unsupported(),
299                    add_dirs: FeatureSupport::native(),
300                    max_turns: FeatureSupport::native(),
301                },
302            })
303        }
304        "ollama" => {
305            use crate::providers::ollama;
306            Ok(ProviderCapability {
307                provider: "ollama".to_string(),
308                default_model: ollama::DEFAULT_MODEL.to_string(),
309                available_models: models_to_vec(ollama::AVAILABLE_SIZES),
310                size_mappings: SizeMappings {
311                    small: "2b".to_string(),
312                    medium: "9b".to_string(),
313                    large: "35b".to_string(),
314                },
315                features: Features {
316                    interactive: FeatureSupport::native(),
317                    non_interactive: FeatureSupport::native(),
318                    resume: FeatureSupport::unsupported(),
319                    resume_with_prompt: FeatureSupport::unsupported(),
320                    session_logs: SessionLogSupport::unsupported(),
321                    json_output: FeatureSupport::wrapper(),
322                    stream_json: FeatureSupport::unsupported(),
323                    json_schema: FeatureSupport::wrapper(),
324                    input_format: FeatureSupport::unsupported(),
325                    streaming_input: StreamingInputSupport::unsupported(),
326                    worktree: FeatureSupport::wrapper(),
327                    sandbox: FeatureSupport::wrapper(),
328                    system_prompt: FeatureSupport::wrapper(),
329                    auto_approve: FeatureSupport::native(),
330                    review: FeatureSupport::unsupported(),
331                    add_dirs: FeatureSupport::unsupported(),
332                    max_turns: FeatureSupport::unsupported(),
333                },
334            })
335        }
336        _ => bail!(
337            "No capabilities defined for provider '{provider}'. Available: claude, codex, gemini, copilot, ollama"
338        ),
339    }
340}
341
342/// Format a capability struct into the requested output format.
343pub fn format_capability(cap: &ProviderCapability, format: &str, pretty: bool) -> Result<String> {
344    match format {
345        "json" => {
346            if pretty {
347                Ok(serde_json::to_string_pretty(cap)?)
348            } else {
349                Ok(serde_json::to_string(cap)?)
350            }
351        }
352        "yaml" => Ok(serde_yaml::to_string(cap)?),
353        "toml" => Ok(toml::to_string_pretty(cap)?),
354        _ => bail!("Unsupported format '{format}'. Available: json, yaml, toml"),
355    }
356}
357
/// Canonical list of provider names (excludes "auto" and "mock").
///
/// Drives `list_providers`, `get_all_capabilities`, and the unknown-provider
/// error in `resolve_model`. Every name here should have a matching arm in
/// `get_capability`: `get_all_capabilities` silently skips names that fail.
pub const PROVIDERS: &[&str] = &["claude", "codex", "gemini", "copilot", "ollama"];
360
361/// List all available provider names.
362pub fn list_providers() -> Vec<String> {
363    PROVIDERS.iter().map(|s| s.to_string()).collect()
364}
365
366/// Get capabilities for all providers.
367pub fn get_all_capabilities() -> Vec<ProviderCapability> {
368    PROVIDERS
369        .iter()
370        .filter_map(|p| get_capability(p).ok())
371        .collect()
372}
373
/// Result of resolving a model alias.
///
/// Produced by `resolve_model`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResolvedModel {
    /// The model name or alias exactly as supplied by the caller.
    pub input: String,
    /// The provider-specific model the input resolved to.
    pub resolved: String,
    /// `true` when `resolved` differs from `input`, i.e. resolution rewrote it.
    pub is_alias: bool,
    /// The provider the resolution was performed for.
    pub provider: String,
}
382
383/// Resolve a model name or alias for a given provider.
384///
385/// Size aliases (`small`/`s`, `medium`/`m`/`default`, `large`/`l`/`max`) are
386/// resolved to the provider-specific model. Non-alias names pass through unchanged.
387pub fn resolve_model(provider: &str, model_input: &str) -> Result<ResolvedModel> {
388    use crate::agent::Agent;
389    use crate::providers::{
390        claude::Claude, codex::Codex, copilot::Copilot, gemini::Gemini, ollama::Ollama,
391    };
392
393    let resolved = match provider {
394        "claude" => Claude::resolve_model(model_input),
395        "codex" => Codex::resolve_model(model_input),
396        "gemini" => Gemini::resolve_model(model_input),
397        "copilot" => Copilot::resolve_model(model_input),
398        "ollama" => Ollama::resolve_model(model_input),
399        _ => bail!(
400            "Unknown provider '{}'. Available: {}",
401            provider,
402            PROVIDERS.join(", ")
403        ),
404    };
405
406    Ok(ResolvedModel {
407        input: model_input.to_string(),
408        is_alias: resolved != model_input,
409        resolved,
410        provider: provider.to_string(),
411    })
412}
413
414/// Format a resolved model into the requested output format.
415pub fn format_resolved_model(rm: &ResolvedModel, format: &str, pretty: bool) -> Result<String> {
416    match format {
417        "json" => {
418            if pretty {
419                Ok(serde_json::to_string_pretty(rm)?)
420            } else {
421                Ok(serde_json::to_string(rm)?)
422            }
423        }
424        "yaml" => Ok(serde_yaml::to_string(rm)?),
425        "toml" => Ok(toml::to_string_pretty(rm)?),
426        _ => bail!("Unsupported format '{format}'. Available: json, yaml, toml"),
427    }
428}
429
430/// Format a list of capabilities into the requested output format.
431pub fn format_capabilities(
432    caps: &[ProviderCapability],
433    format: &str,
434    pretty: bool,
435) -> Result<String> {
436    match format {
437        "json" => {
438            if pretty {
439                Ok(serde_json::to_string_pretty(caps)?)
440            } else {
441                Ok(serde_json::to_string(caps)?)
442            }
443        }
444        "yaml" => Ok(serde_yaml::to_string(caps)?),
445        "toml" => {
446            #[derive(Serialize)]
447            struct Wrapper<'a> {
448                providers: &'a [ProviderCapability],
449            }
450            Ok(toml::to_string_pretty(&Wrapper { providers: caps })?)
451        }
452        _ => bail!("Unsupported format '{format}'. Available: json, yaml, toml"),
453    }
454}
455
456/// Format a models listing into the requested output format.
457pub fn format_models(caps: &[ProviderCapability], format: &str, pretty: bool) -> Result<String> {
458    #[derive(Serialize)]
459    struct ModelEntry {
460        provider: String,
461        default_model: String,
462        models: Vec<String>,
463    }
464
465    let entries: Vec<ModelEntry> = caps
466        .iter()
467        .map(|c| ModelEntry {
468            provider: c.provider.clone(),
469            default_model: c.default_model.clone(),
470            models: c.available_models.clone(),
471        })
472        .collect();
473
474    match format {
475        "json" => {
476            if pretty {
477                Ok(serde_json::to_string_pretty(&entries)?)
478            } else {
479                Ok(serde_json::to_string(&entries)?)
480            }
481        }
482        "yaml" => Ok(serde_yaml::to_string(&entries)?),
483        "toml" => bail!("TOML does not support top-level arrays. Use json or yaml"),
484        _ => bail!("Unsupported format '{format}'. Available: json, yaml, toml"),
485    }
486}
487
/// Copy each borrowed model name into an owned `String`, preserving order.
pub fn models_to_vec(models: &[&str]) -> Vec<String> {
    models.iter().copied().map(String::from).collect()
}
492
493/// Render a provider-summary text table identical to the one printed by
494/// `zag discover`. Columns: provider, default model, model count, resume,
495/// json output, session-log completeness.
496pub fn format_summary_table(caps: &[ProviderCapability]) -> String {
497    use std::fmt::Write;
498    let mut out = String::new();
499    let _ = writeln!(
500        out,
501        "{:<10} {:<28} {:>6}  {:<6} {:<6} {:<7}",
502        "PROVIDER", "DEFAULT MODEL", "MODELS", "RESUME", "JSON", "LOGS"
503    );
504    let _ = writeln!(out, "{}", "-".repeat(70));
505    for cap in caps {
506        let resume = if cap.features.resume.supported {
507            "yes"
508        } else {
509            "no"
510        };
511        let json_out = if cap.features.json_output.supported {
512            "yes"
513        } else {
514            "no"
515        };
516        let logs = cap
517            .features
518            .session_logs
519            .completeness
520            .as_deref()
521            .unwrap_or("-");
522        let _ = writeln!(
523            out,
524            "{:<10} {:<28} {:>6}  {:<6} {:<6} {:<7}",
525            cap.provider,
526            cap.default_model,
527            cap.available_models.len(),
528            resume,
529            json_out,
530            logs,
531        );
532    }
533    out
534}
535
536/// Render the per-provider detail block printed by `zag discover -p <name>`.
537pub fn format_provider_detail(cap: &ProviderCapability) -> String {
538    use std::fmt::Write;
539    let mut out = String::new();
540    let _ = writeln!(out, "Provider: {}", cap.provider);
541    let _ = writeln!(out, "Default model: {}", cap.default_model);
542    let _ = writeln!(
543        out,
544        "Size mappings: small={}, medium={}, large={}",
545        cap.size_mappings.small, cap.size_mappings.medium, cap.size_mappings.large
546    );
547    let _ = writeln!(out, "Available models:");
548    for m in &cap.available_models {
549        let _ = writeln!(out, "  - {m}");
550    }
551    let _ = writeln!(out);
552    let _ = writeln!(out, "Features:");
553    format_feature(&mut out, "  interactive", &cap.features.interactive);
554    format_feature(&mut out, "  non-interactive", &cap.features.non_interactive);
555    format_feature(&mut out, "  resume", &cap.features.resume);
556    format_feature(
557        &mut out,
558        "  resume-with-prompt",
559        &cap.features.resume_with_prompt,
560    );
561    format_session_log(&mut out, "  session-logs", &cap.features.session_logs);
562    format_feature(&mut out, "  json-output", &cap.features.json_output);
563    format_feature(&mut out, "  stream-json", &cap.features.stream_json);
564    format_feature(&mut out, "  json-schema", &cap.features.json_schema);
565    format_feature(&mut out, "  input-format", &cap.features.input_format);
566    format_streaming_input(&mut out, "  streaming-input", &cap.features.streaming_input);
567    format_feature(&mut out, "  worktree", &cap.features.worktree);
568    format_feature(&mut out, "  sandbox", &cap.features.sandbox);
569    format_feature(&mut out, "  system-prompt", &cap.features.system_prompt);
570    format_feature(&mut out, "  auto-approve", &cap.features.auto_approve);
571    format_feature(&mut out, "  review", &cap.features.review);
572    format_feature(&mut out, "  add-dirs", &cap.features.add_dirs);
573    format_feature(&mut out, "  max-turns", &cap.features.max_turns);
574    out
575}
576
577/// Render a plain text listing of models grouped by provider, as printed by
578/// `zag discover --models`.
579pub fn format_models_text(caps: &[ProviderCapability]) -> String {
580    use std::fmt::Write;
581    let mut out = String::new();
582    for cap in caps {
583        let _ = writeln!(out, "{}:", cap.provider);
584        for m in &cap.available_models {
585            let _ = writeln!(out, "  {m}");
586        }
587    }
588    out
589}
590
591fn format_feature(out: &mut String, label: &str, f: &FeatureSupport) {
592    use std::fmt::Write;
593    let status = if f.supported {
594        if f.native { "native" } else { "wrapper" }
595    } else {
596        "no"
597    };
598    let _ = writeln!(out, "{label:<24} {status}");
599}
600
601fn format_streaming_input(out: &mut String, label: &str, f: &StreamingInputSupport) {
602    use std::fmt::Write;
603    let status = if f.supported {
604        let base = if f.native { "native" } else { "wrapper" };
605        match f.semantics.as_deref() {
606            Some(s) => format!("{base} ({s})"),
607            None => base.to_string(),
608        }
609    } else {
610        "no".to_string()
611    };
612    let _ = writeln!(out, "{label:<24} {status}");
613}
614
615fn format_session_log(out: &mut String, label: &str, f: &SessionLogSupport) {
616    use std::fmt::Write;
617    let status = if f.supported {
618        match f.completeness.as_deref() {
619            Some(c) => {
620                if f.native {
621                    c.to_string()
622                } else {
623                    format!("{c} (wrapper)")
624                }
625            }
626            None => "yes".to_string(),
627        }
628    } else {
629        "no".to_string()
630    };
631    let _ = writeln!(out, "{label:<24} {status}");
632}
633
// Unit tests live in the sibling file `capability_tests.rs` and are compiled
// only for test builds.
#[cfg(test)]
#[path = "capability_tests.rs"]
mod tests;