Skip to main content

swink_agent/types/model.rs

use std::collections::HashMap;

use serde::{Deserialize, Serialize};

5// ─── Model Capabilities ─────────────────────────────────────────────────────
6
7/// Per-model capability flags and limits.
8///
9/// Populated from the model catalog or set manually. The agent loop can
10/// inspect these before enabling provider-specific features (e.g. skip
11/// thinking blocks when `supports_thinking` is false).
12#[allow(clippy::struct_excessive_bools)]
13#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
14pub struct ModelCapabilities {
15    /// Whether the model supports extended-thinking / chain-of-thought blocks.
16    pub supports_thinking: bool,
17    /// Whether the model accepts image content blocks.
18    pub supports_vision: bool,
19    /// Whether the model can invoke tools.
20    pub supports_tool_use: bool,
21    /// Whether the model supports streaming responses.
22    pub supports_streaming: bool,
23    /// Whether the model supports structured (JSON schema) output.
24    pub supports_structured_output: bool,
25    /// Maximum input context window in tokens, if known.
26    pub max_context_window: Option<u64>,
27    /// Maximum output tokens per response, if known.
28    pub max_output_tokens: Option<u64>,
29}
30
31impl ModelCapabilities {
32    /// Create capabilities with all flags set to false and no limits.
33    #[must_use]
34    pub fn none() -> Self {
35        Self::default()
36    }
37
38    #[must_use]
39    pub const fn with_thinking(mut self, val: bool) -> Self {
40        self.supports_thinking = val;
41        self
42    }
43
44    #[must_use]
45    pub const fn with_vision(mut self, val: bool) -> Self {
46        self.supports_vision = val;
47        self
48    }
49
50    #[must_use]
51    pub const fn with_tool_use(mut self, val: bool) -> Self {
52        self.supports_tool_use = val;
53        self
54    }
55
56    #[must_use]
57    pub const fn with_streaming(mut self, val: bool) -> Self {
58        self.supports_streaming = val;
59        self
60    }
61
62    #[must_use]
63    pub const fn with_structured_output(mut self, val: bool) -> Self {
64        self.supports_structured_output = val;
65        self
66    }
67
68    #[must_use]
69    pub const fn with_max_context_window(mut self, tokens: u64) -> Self {
70        self.max_context_window = Some(tokens);
71        self
72    }
73
74    #[must_use]
75    pub const fn with_max_output_tokens(mut self, tokens: u64) -> Self {
76        self.max_output_tokens = Some(tokens);
77        self
78    }
79}
80
81// ─── Model Specification ────────────────────────────────────────────────────
82
83/// Reasoning depth for models that support configurable thinking.
84#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
85#[serde(rename_all = "snake_case")]
86pub enum ThinkingLevel {
87    #[default]
88    Off,
89    Minimal,
90    Low,
91    Medium,
92    High,
93    ExtraHigh,
94}
95
96/// Optional per-level token budget overrides for providers that support
97/// token-based reasoning control.
98#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
99pub struct ThinkingBudgets {
100    pub budgets: HashMap<ThinkingLevel, u64>,
101}
102
103impl ThinkingBudgets {
104    /// Create a new `ThinkingBudgets` from a map.
105    #[must_use]
106    pub const fn new(budgets: HashMap<ThinkingLevel, u64>) -> Self {
107        Self { budgets }
108    }
109
110    /// Look up the token budget for a given thinking level.
111    #[must_use]
112    pub fn get(&self, level: &ThinkingLevel) -> Option<u64> {
113        self.budgets.get(level).copied()
114    }
115}
116
117/// Identifies the target model for a request.
118#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
119#[allow(clippy::derive_partial_eq_without_eq)]
120pub struct ModelSpec {
121    pub provider: String,
122    pub model_id: String,
123    pub thinking_level: ThinkingLevel,
124    pub thinking_budgets: Option<ThinkingBudgets>,
125    /// Provider-specific configuration (thinking, parameters, etc.).
126    #[serde(default, skip_serializing_if = "Option::is_none")]
127    pub provider_config: Option<serde_json::Value>,
128    /// Per-model capability flags and limits.
129    #[serde(default, skip_serializing_if = "Option::is_none")]
130    pub capabilities: Option<ModelCapabilities>,
131}
132
133impl ModelSpec {
134    /// Create a new `ModelSpec` with thinking disabled and no budgets.
135    #[must_use]
136    pub fn new(provider: impl Into<String>, model_id: impl Into<String>) -> Self {
137        Self {
138            provider: provider.into(),
139            model_id: model_id.into(),
140            thinking_level: ThinkingLevel::Off,
141            thinking_budgets: None,
142            provider_config: None,
143            capabilities: None,
144        }
145    }
146
147    #[must_use]
148    pub const fn with_thinking_level(mut self, level: ThinkingLevel) -> Self {
149        self.thinking_level = level;
150        self
151    }
152
153    #[must_use]
154    pub fn with_thinking_budgets(mut self, budgets: ThinkingBudgets) -> Self {
155        self.thinking_budgets = Some(budgets);
156        self
157    }
158
159    #[must_use]
160    pub fn with_provider_config(mut self, config: serde_json::Value) -> Self {
161        self.provider_config = Some(config);
162        self
163    }
164
165    #[must_use]
166    pub const fn with_capabilities(mut self, capabilities: ModelCapabilities) -> Self {
167        self.capabilities = Some(capabilities);
168        self
169    }
170
171    /// Returns the model capabilities, or a default (all-false) set if none
172    /// were provided.
173    #[must_use]
174    pub fn capabilities(&self) -> ModelCapabilities {
175        self.capabilities.clone().unwrap_or_default()
176    }
177
178    /// Get a typed provider config, deserializing from the stored JSON.
179    #[must_use]
180    pub fn provider_config_as<T: serde::de::DeserializeOwned>(&self) -> Option<T> {
181        self.provider_config
182            .as_ref()
183            .and_then(|v| serde_json::from_value(v.clone()).ok())
184    }
185}