sgr_agent/types.rs

use serde::{Deserialize, Serialize};

/// Inline image data for multimodal messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImagePart {
    /// Base64-encoded image data.
    pub data: String,
    /// MIME type (e.g. "image/jpeg", "image/png").
    pub mime_type: String,
}

impl ImagePart {
    /// `data:<mime>;base64,<data>` URL form consumed by OpenAI vision APIs.
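    ///
    /// A minimal sketch (assumes `ImagePart` is re-exported at the crate root,
    /// like `LlmConfig` below):
    ///
    /// ```no_run
    /// use sgr_agent::ImagePart;
    ///
    /// let img = ImagePart { data: "aGVsbG8=".into(), mime_type: "image/png".into() };
    /// assert_eq!(img.data_url(), "data:image/png;base64,aGVsbG8=");
    /// ```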
    pub fn data_url(&self) -> String {
        format!("data:{};base64,{}", self.mime_type, self.data)
    }
}

/// A chat message in the conversation history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    pub role: Role,
    pub content: String,
    /// ID of the tool call this message answers (only for Role::Tool).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Tool calls made by the assistant (only for Role::Assistant with function calling).
    /// Gemini API requires model turns to include functionCall parts.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Inline images (for multimodal VLM input).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub images: Vec<ImagePart>,
    /// Whether this message can be dropped during context compaction.
    /// false (default) = critical — never remove (inbox, instruction, system).
    /// true = compactable — can be summarized or dropped when context overflows.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub compactable: bool,
}

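/// Speaker role for a chat message, serialized in lowercase to match
/// OpenAI-style wire formats. A sketch of the wire form (assumes `Role` is
/// re-exported at the crate root):
///
/// ```no_run
/// use sgr_agent::Role;
///
/// assert_eq!(serde_json::to_string(&Role::Assistant).unwrap(), "\"assistant\"");
/// ```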
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
}

impl Message {
    pub fn system(content: impl Into<String>) -> Self {
        Self {
            role: Role::System,
            content: content.into(),
            tool_call_id: None,
            tool_calls: vec![],
            images: vec![],
            compactable: false,
        }
    }
    pub fn user(content: impl Into<String>) -> Self {
        Self {
            role: Role::User,
            content: content.into(),
            tool_call_id: None,
            tool_calls: vec![],
            images: vec![],
            compactable: false,
        }
    }
    pub fn assistant(content: impl Into<String>) -> Self {
        Self {
            role: Role::Assistant,
            content: content.into(),
            tool_call_id: None,
            tool_calls: vec![],
            images: vec![],
            compactable: false,
        }
    }
    /// Create an assistant message that includes function calls (for Gemini FC protocol).
    pub fn assistant_with_tool_calls(
        content: impl Into<String>,
        tool_calls: Vec<ToolCall>,
    ) -> Self {
        Self {
            role: Role::Assistant,
            content: content.into(),
            tool_call_id: None,
            tool_calls,
            images: vec![],
            compactable: false,
        }
    }
    pub fn tool(call_id: impl Into<String>, content: impl Into<String>) -> Self {
        Self {
            role: Role::Tool,
            content: content.into(),
            tool_call_id: Some(call_id.into()),
            tool_calls: vec![],
            images: vec![],
            compactable: false,
        }
    }
    /// Tool result with inline images (for VLM — Gemini sees the images).
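    ///
    /// A minimal sketch (crate-root paths assumed, as in the `LlmConfig` example;
    /// the call ID is illustrative):
    ///
    /// ```no_run
    /// use sgr_agent::{ImagePart, Message};
    ///
    /// let shot = ImagePart { data: "aGVsbG8=".into(), mime_type: "image/png".into() };
    /// let msg = Message::tool_with_images("call_1", "screenshot attached", vec![shot]);
    /// ```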
    pub fn tool_with_images(
        call_id: impl Into<String>,
        content: impl Into<String>,
        images: Vec<ImagePart>,
    ) -> Self {
        Self {
            role: Role::Tool,
            content: content.into(),
            tool_call_id: Some(call_id.into()),
            tool_calls: vec![],
            images,
            compactable: false,
        }
    }
    /// User message with inline images.
    pub fn user_with_images(content: impl Into<String>, images: Vec<ImagePart>) -> Self {
        Self {
            role: Role::User,
            content: content.into(),
            tool_call_id: None,
            tool_calls: vec![],
            images,
            compactable: false,
        }
    }
    /// Mark this message as compactable (safe to drop during context overflow).
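    ///
    /// Builder-style, so it chains onto any constructor (crate-root path assumed):
    ///
    /// ```no_run
    /// use sgr_agent::Message;
    ///
    /// let note = Message::assistant("intermediate reasoning").compactable();
    /// assert!(note.compactable);
    /// ```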
    pub fn compactable(mut self) -> Self {
        self.compactable = true;
        self
    }
}

/// A tool call returned by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Unique ID for matching with tool results.
    pub id: String,
    /// Tool/function name.
    pub name: String,
    /// JSON-encoded arguments.
    pub arguments: serde_json::Value,
}

/// Response from an SGR call — structured output + optional tool calls.
#[derive(Debug, Clone)]
pub struct SgrResponse<T> {
    /// Parsed structured output (SGR envelope).
    /// `None` if the model only returned tool calls without structured content.
    pub output: Option<T>,
    /// Tool calls the model wants to execute.
    pub tool_calls: Vec<ToolCall>,
    /// Raw text (for streaming / debugging).
    pub raw_text: String,
    /// Token usage.
    pub usage: Option<Usage>,
    /// Rate limit info from response headers (if provider sends them).
    pub rate_limit: Option<RateLimitInfo>,
}

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// Rate limit info extracted from response headers and/or error body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitInfo {
    /// Requests remaining in current window.
    pub requests_remaining: Option<u32>,
    /// Tokens remaining in current window.
    pub tokens_remaining: Option<u32>,
    /// Seconds until limit resets.
    pub retry_after_secs: Option<u64>,
    /// Unix timestamp when limit resets.
    pub resets_at: Option<u64>,
    /// Provider error type (e.g. "usage_limit_reached", "rate_limit_exceeded").
    pub error_type: Option<String>,
    /// Human-readable message from provider.
    pub message: Option<String>,
}

impl RateLimitInfo {
    /// Parse from HTTP response headers (OpenAI/Gemini/OpenRouter standard).
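    ///
    /// A sketch of the expected header shape (crate-root path assumed):
    ///
    /// ```no_run
    /// use sgr_agent::RateLimitInfo;
    ///
    /// let mut headers = reqwest::header::HeaderMap::new();
    /// headers.insert("x-ratelimit-remaining-requests", "5".parse().unwrap());
    /// let info = RateLimitInfo::from_headers(&headers).unwrap();
    /// assert_eq!(info.requests_remaining, Some(5));
    /// ```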
    pub fn from_headers(headers: &reqwest::header::HeaderMap) -> Option<Self> {
        let get_u32 =
            |name: &str| -> Option<u32> { headers.get(name)?.to_str().ok()?.parse().ok() };
        let get_u64 =
            |name: &str| -> Option<u64> { headers.get(name)?.to_str().ok()?.parse().ok() };

        let requests_remaining = get_u32("x-ratelimit-remaining-requests");
        let tokens_remaining = get_u32("x-ratelimit-remaining-tokens");
        let retry_after_secs =
            get_u64("retry-after").or_else(|| get_u64("x-ratelimit-reset-requests"));
        let resets_at = get_u64("x-ratelimit-reset-tokens");

        if requests_remaining.is_some() || tokens_remaining.is_some() || retry_after_secs.is_some()
        {
            Some(Self {
                requests_remaining,
                tokens_remaining,
                retry_after_secs,
                resets_at,
                error_type: None,
                message: None,
            })
        } else {
            None
        }
    }

    /// Parse from JSON error body (OpenAI, Codex, Gemini error responses).
    pub fn from_error_body(body: &str) -> Option<Self> {
        let json: serde_json::Value = serde_json::from_str(body).ok()?;
        let err = json.get("error")?;

        let error_type = err.get("type").and_then(|v| v.as_str()).map(String::from);
        let message = err
            .get("message")
            .and_then(|v| v.as_str())
            .map(String::from);
        let resets_at = err.get("resets_at").and_then(|v| v.as_u64());
        let retry_after_secs = err.get("resets_in_seconds").and_then(|v| v.as_u64());

        Some(Self {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs,
            resets_at,
            error_type,
            message,
        })
    }

    /// Human-readable description of when limit resets.
    pub fn reset_display(&self) -> String {
        if let Some(secs) = self.retry_after_secs {
            let hours = secs / 3600;
            let mins = (secs % 3600) / 60;
            if hours >= 24 {
                format!("{}d {}h", hours / 24, hours % 24)
            } else if hours > 0 {
                format!("{}h {}m", hours, mins)
            } else {
                format!("{}m", mins)
            }
        } else {
            "unknown".into()
        }
    }

    /// One-line status for status bar.
    pub fn status_line(&self) -> String {
        let mut parts = Vec::new();
        if let Some(r) = self.requests_remaining {
            parts.push(format!("req:{}", r));
        }
        if let Some(t) = self.tokens_remaining {
            parts.push(format!("tok:{}", t));
        }
        if self.retry_after_secs.is_some() {
            parts.push(format!("reset:{}", self.reset_display()));
        }
        if parts.is_empty() {
            self.message
                .clone()
                .unwrap_or_else(|| "rate limited".into())
        } else {
            parts.join(" | ")
        }
    }
}

/// LLM provider configuration — single config for any provider.
///
/// Two optional fields control routing:
/// - `api_key`: None → auto from env vars (OPENAI_API_KEY, ANTHROPIC_API_KEY, etc.)
/// - `base_url`: None → auto-detect provider from model name; Some → custom endpoint
///
/// ```no_run
/// use sgr_agent::LlmConfig;
///
/// let c = LlmConfig::auto("gpt-4o");                                          // env vars
/// let c = LlmConfig::with_key("sk-...", "claude-3-haiku");                    // explicit key
/// let c = LlmConfig::endpoint("sk-or-...", "https://openrouter.ai/api/v1", "gpt-4o"); // custom
/// let c = LlmConfig::auto("gpt-4o").temperature(0.9).max_tokens(2048);        // builder
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
    pub model: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
    #[serde(default = "default_temperature")]
    pub temp: f64,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// OpenAI prompt cache key — caches system prompt prefix server-side.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    /// Vertex AI project ID (enables Vertex routing when set).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub project_id: Option<String>,
    /// Vertex AI location (default: "global").
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub location: Option<String>,
    /// Force Chat Completions API instead of Responses API.
    /// Needed for OpenAI-compatible endpoints that don't support /responses
    /// (e.g. Cloudflare AI Gateway compat, OpenRouter, local models).
    #[serde(default)]
    pub use_chat_api: bool,
    /// Extra HTTP headers to include in LLM API requests.
    /// E.g. `cf-aig-request-timeout: 300000` for Cloudflare AI Gateway.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub extra_headers: Vec<(String, String)>,
    /// Reasoning effort for reasoning models. "none" disables reasoning for FC.
    /// E.g. DeepInfra Nemotron Super needs "none" for function calling.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
    /// `text.verbosity` for OpenAI Responses API: "low" | "medium" | "high".
    /// "low" suppresses narrative text emitted alongside forced tool calls —
    /// useful when only the tool args are consumed and the assistant.text is
    /// wasted output tokens. Currently honored by the oxide (Responses) backend.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub verbosity: Option<String>,
    /// Force genai backend (for providers with native API: Anthropic, Gemini).
    /// When false, oxide (OpenAI Responses API) is used by default.
    #[serde(default)]
    pub use_genai: bool,
    /// Use CLI subprocess backend (claude/gemini/codex -p).
    /// Tool calls emulated via text prompt + flexible parsing.
    /// Uses CLI's own auth (subscription credits, no API key).
    #[serde(default)]
    pub use_cli: bool,
    /// Session ID for request grouping (sticky routing, trace correlation).
    /// Set per-trial to group all LLM calls in the same session.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,

    // ── Provider capabilities (auto-detected from model name, overridable in TOML) ──
    /// Reject assistant message as last in conversation. Auto: true for Anthropic Opus/Sonnet.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub no_assistant_prefill: Option<bool>,
    /// Prompt cache TTL (e.g. "5m", "1h"). Auto: "1h" for Anthropic models.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_ttl: Option<String>,
    /// Pin to specific provider on OpenRouter (e.g. "Anthropic"). Auto-detected.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pin_provider: Option<String>,
    /// Enable WebSocket for Responses API (lower latency, persistent connection).
    /// Default: true for Responses API (Oxide backend), ignored for Chat/genai/CLI.
    #[serde(default = "default_websocket")]
    pub websocket: bool,
}

fn default_websocket() -> bool {
    true
}

fn default_temperature() -> f64 {
    0.7
}

impl Default for LlmConfig {
    fn default() -> Self {
        Self {
            model: String::new(),
            api_key: None,
            base_url: None,
            temp: default_temperature(),
            max_tokens: None,
            prompt_cache_key: None,
            project_id: None,
            location: None,
            use_chat_api: false,
            extra_headers: Vec::new(),
            reasoning_effort: None,
            verbosity: None,
            use_genai: false,
            use_cli: false,
            session_id: None,
            no_assistant_prefill: None,
            cache_ttl: None,
            pin_provider: None,
            websocket: default_websocket(),
        }
    }
}

impl LlmConfig {
    /// Auto-detect provider from model name, use env vars for auth.
    pub fn auto(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            ..Default::default()
        }
    }

    /// Explicit API key, auto-detect provider from model name.
    pub fn with_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            api_key: Some(api_key.into()),
            ..Default::default()
        }
    }

    /// Custom OpenAI-compatible endpoint (OpenRouter, Ollama, LiteLLM, etc.).
    pub fn endpoint(
        api_key: impl Into<String>,
        base_url: impl Into<String>,
        model: impl Into<String>,
    ) -> Self {
        Self {
            model: model.into(),
            api_key: Some(api_key.into()),
            base_url: Some(base_url.into()),
            ..Default::default()
        }
    }

    /// Vertex AI — uses gcloud ADC for auth (no API key needed).
    pub fn vertex(project_id: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            project_id: Some(project_id.into()),
            location: Some("global".into()),
            ..Default::default()
        }
    }

    /// Set Vertex AI location.
    pub fn location(mut self, loc: impl Into<String>) -> Self {
        self.location = Some(loc.into());
        self
    }

    /// Set temperature.
    pub fn temperature(mut self, t: f64) -> Self {
        self.temp = t;
        self
    }

    /// Set max output tokens.
    pub fn max_tokens(mut self, m: u32) -> Self {
        self.max_tokens = Some(m);
        self
    }

    /// Set `text.verbosity` for the OpenAI Responses API.
    /// Accepts "low" | "medium" | "high". "low" minimizes narrative text the
    /// model emits alongside a forced tool call.
    pub fn verbosity(mut self, v: impl Into<String>) -> Self {
        self.verbosity = Some(v.into());
        self
    }

    /// Set OpenAI prompt cache key for server-side system prompt caching.
    pub fn prompt_cache_key(mut self, key: impl Into<String>) -> Self {
        self.prompt_cache_key = Some(key.into());
        self
    }

    /// True if model targets Anthropic (auto-detect from model name).
    pub fn is_anthropic(&self) -> bool {
        self.model.starts_with("anthropic/") || self.model.starts_with("claude")
    }

    /// Whether assistant prefill is rejected. TOML override > auto-detect.
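    ///
    /// A sketch of the auto-detect path (model names here are only illustrative):
    ///
    /// ```no_run
    /// use sgr_agent::LlmConfig;
    ///
    /// assert!(LlmConfig::auto("claude-sonnet-4-20250514").rejects_prefill());
    /// assert!(!LlmConfig::auto("claude-3-haiku-20240307").rejects_prefill());
    /// ```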
    pub fn rejects_prefill(&self) -> bool {
        self.no_assistant_prefill.unwrap_or_else(|| {
            // Anthropic Opus/Sonnet reject prefill; Haiku via Bedrock allows it
            self.is_anthropic() && !self.model.contains("haiku")
        })
    }

    /// Prompt cache TTL, if any. TOML override > auto-detect.
    pub fn resolved_cache_ttl(&self) -> Option<&str> {
        if self.cache_ttl.is_some() {
            return self.cache_ttl.as_deref();
        }
        if self.is_anthropic() {
            Some("1h")
        } else {
            None
        }
    }

    /// Provider to pin on OpenRouter, if any. TOML override > auto-detect.
    pub fn resolved_pin_provider(&self) -> Option<&str> {
        if self.pin_provider.is_some() {
            return self.pin_provider.as_deref();
        }
        if self.is_anthropic() {
            Some("Anthropic")
        } else {
            None
        }
    }

    /// Apply extra_headers to an openai-oxide ClientConfig.
    /// Used by both OxideClient and OxideChatClient.
    pub fn apply_headers(&self, config: &mut openai_oxide::config::ClientConfig) {
        if !self.extra_headers.is_empty() {
            let mut hm = reqwest::header::HeaderMap::new();
            for (k, v) in &self.extra_headers {
                if let (Ok(name), Ok(val)) = (
                    reqwest::header::HeaderName::from_bytes(k.as_bytes()),
                    reqwest::header::HeaderValue::from_str(v),
                ) {
                    hm.insert(name, val);
                }
            }
            config.default_headers = Some(hm);
        }
    }

    /// CLI subprocess backend — uses `claude -p` / `gemini -p` / `codex exec`.
    /// No API key needed, uses CLI's own auth (subscription credits).
    /// `cli_model` overrides the CLI's default model via the `--model` flag.
    pub fn cli(cli_model: impl Into<String>) -> Self {
        Self {
            model: cli_model.into(),
            use_cli: true,
            ..Default::default()
        }
    }

    /// Human-readable label for display.
    pub fn label(&self) -> String {
        if self.use_cli {
            format!("CLI ({})", self.model)
        } else if self.project_id.is_some() {
            format!("Vertex ({})", self.model)
        } else if self.base_url.is_some() {
            format!("Custom ({})", self.model)
        } else {
            self.model.clone()
        }
    }

    /// Infer a cheap/fast model for compaction based on the primary model.
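    ///
    /// A sketch of the mapping:
    ///
    /// ```no_run
    /// use sgr_agent::LlmConfig;
    ///
    /// assert_eq!(LlmConfig::auto("gpt-4o").compaction_model(), "gpt-4o-mini");
    /// assert_eq!(LlmConfig::auto("gemini-2.5-pro").compaction_model(), "gemini-2.0-flash-lite");
    /// ```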
    pub fn compaction_model(&self) -> String {
        if self.model.starts_with("gemini") {
            "gemini-2.0-flash-lite".into()
        } else if self.model.starts_with("gpt") {
            "gpt-4o-mini".into()
        } else if self.model.starts_with("claude") {
            "claude-3-haiku-20240307".into()
        } else {
            // Unknown provider — use the same model
            self.model.clone()
        }
    }

    /// Create a compaction config — cheap model, low max_tokens.
    pub fn for_compaction(&self) -> Self {
        let mut cfg = self.clone();
        cfg.model = self.compaction_model();
        cfg.max_tokens = Some(2048);
        cfg
    }
}

/// Legacy provider configuration (used by OpenAIClient/GeminiClient).
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    pub api_key: String,
    pub model: String,
    pub base_url: Option<String>,
    pub project_id: Option<String>,
    pub location: Option<String>,
    pub temperature: f32,
    pub max_tokens: Option<u32>,
}

impl ProviderConfig {
    pub fn gemini(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            base_url: None,
            project_id: None,
            location: None,
            temperature: 0.3,
            max_tokens: None,
        }
    }

    pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            base_url: None,
            project_id: None,
            location: None,
            temperature: 0.3,
            max_tokens: None,
        }
    }

    pub fn openrouter(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            base_url: Some("https://openrouter.ai/api/v1".into()),
            project_id: None,
            location: None,
            temperature: 0.3,
            max_tokens: None,
        }
    }

    pub fn vertex(
        access_token: impl Into<String>,
        project_id: impl Into<String>,
        model: impl Into<String>,
    ) -> Self {
        Self {
            api_key: access_token.into(),
            model: model.into(),
            base_url: None,
            project_id: Some(project_id.into()),
            location: Some("global".to_string()),
            temperature: 0.3,
            max_tokens: None,
        }
    }

    pub fn ollama(model: impl Into<String>) -> Self {
        Self {
            api_key: String::new(),
            model: model.into(),
            base_url: Some("http://localhost:11434/v1".into()),
            project_id: None,
            location: None,
            temperature: 0.3,
            max_tokens: None,
        }
    }
}

/// Errors from SGR calls.
#[derive(Debug, thiserror::Error)]
pub enum SgrError {
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),
    #[error("API error {status}: {body}")]
    Api { status: u16, body: String },
    #[error("Rate limit: {}", info.status_line())]
    RateLimit { status: u16, info: RateLimitInfo },
    #[error("JSON parse error: {0}")]
    Json(#[from] serde_json::Error),
    #[error("Schema error: {0}")]
    Schema(String),
    #[error("No content in response")]
    EmptyResponse,
    /// Model response was truncated due to max_output_tokens limit.
    /// Contains the partial content that was generated before truncation.
    #[error("Response truncated (max_output_tokens): {partial_content}")]
    MaxOutputTokens { partial_content: String },
    /// Prompt too long — context exceeds model's input limit.
    #[error("Prompt too long: {0}")]
    PromptTooLong(String),
}

impl SgrError {
    /// Build error from HTTP status + body, auto-detecting rate limits.
    pub fn from_api_response(status: u16, body: String) -> Self {
        if (status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit"))
            && let Some(mut info) = RateLimitInfo::from_error_body(&body)
        {
            if info.message.is_none() {
                info.message = Some(body.chars().take(200).collect());
            }
            return SgrError::RateLimit { status, info };
        }
        SgrError::Api { status, body }
    }

    /// Build error from HTTP status + body + headers, auto-detecting rate limits.
    pub fn from_response_parts(
        status: u16,
        body: String,
        headers: &reqwest::header::HeaderMap,
    ) -> Self {
        if status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit") {
            let mut info = RateLimitInfo::from_error_body(&body)
                .or_else(|| RateLimitInfo::from_headers(headers))
                .unwrap_or(RateLimitInfo {
                    requests_remaining: None,
                    tokens_remaining: None,
                    retry_after_secs: None,
                    resets_at: None,
                    error_type: Some("rate_limit".into()),
                    message: Some(body.chars().take(200).collect()),
                });
            // Merge header info into body info
            if let Some(header_info) = RateLimitInfo::from_headers(headers) {
                if info.requests_remaining.is_none() {
                    info.requests_remaining = header_info.requests_remaining;
                }
                if info.tokens_remaining.is_none() {
                    info.tokens_remaining = header_info.tokens_remaining;
                }
            }
            return SgrError::RateLimit { status, info };
        }
        SgrError::Api { status, body }
    }

    /// Is this a rate limit error?
    pub fn is_rate_limit(&self) -> bool {
        matches!(self, SgrError::RateLimit { .. })
    }

    /// Get rate limit info if this is a rate limit error.
    pub fn rate_limit_info(&self) -> Option<&RateLimitInfo> {
        match self {
            SgrError::RateLimit { info, .. } => Some(info),
            _ => None,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_codex_rate_limit_error() {
        let body = r#"{"error":{"type":"usage_limit_reached","message":"The usage limit has been reached","plan_type":"plus","resets_at":1773534007,"resets_in_seconds":442787}}"#;
        let err = SgrError::from_api_response(429, body.to_string());
        assert!(err.is_rate_limit());
        let info = err.rate_limit_info().unwrap();
        assert_eq!(info.error_type.as_deref(), Some("usage_limit_reached"));
        assert_eq!(info.retry_after_secs, Some(442787));
        assert_eq!(info.resets_at, Some(1773534007));
        assert_eq!(info.reset_display(), "5d 2h");
    }

    #[test]
    fn parse_openai_rate_limit_error() {
        let body =
            r#"{"error":{"type":"rate_limit_exceeded","message":"Rate limit reached for gpt-4"}}"#;
        let err = SgrError::from_api_response(429, body.to_string());
        assert!(err.is_rate_limit());
        let info = err.rate_limit_info().unwrap();
        assert_eq!(info.error_type.as_deref(), Some("rate_limit_exceeded"));
    }

    #[test]
    fn non_rate_limit_stays_api_error() {
        let body = r#"{"error":{"type":"invalid_request","message":"Bad request"}}"#;
        let err = SgrError::from_api_response(400, body.to_string());
        assert!(!err.is_rate_limit());
        assert!(matches!(err, SgrError::Api { status: 400, .. }));
    }

    #[test]
    fn status_line_with_all_fields() {
        let info = RateLimitInfo {
            requests_remaining: Some(5),
            tokens_remaining: Some(10000),
            retry_after_secs: Some(3600),
            resets_at: None,
            error_type: None,
            message: None,
        };
        assert_eq!(info.status_line(), "req:5 | tok:10000 | reset:1h 0m");
    }

    #[test]
    fn status_line_fallback_to_message() {
        let info = RateLimitInfo {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs: None,
            resets_at: None,
            error_type: None,
            message: Some("custom message".into()),
        };
        assert_eq!(info.status_line(), "custom message");
    }

    #[test]
    fn reset_display_formats() {
        let make = |secs| RateLimitInfo {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs: Some(secs),
            resets_at: None,
            error_type: None,
            message: None,
        };
        assert_eq!(make(90).reset_display(), "1m");
        assert_eq!(make(3661).reset_display(), "1h 1m");
        assert_eq!(make(90000).reset_display(), "1d 1h");
    }
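
    // A sketch exercising the header-merge path in from_response_parts:
    // body info wins, and missing counters are backfilled from response headers.
    #[test]
    fn response_parts_merges_header_info() {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert("x-ratelimit-remaining-requests", "3".parse().unwrap());
        let body = r#"{"error":{"type":"rate_limit_exceeded","message":"slow down"}}"#;
        let err = SgrError::from_response_parts(429, body.to_string(), &headers);
        let info = err.rate_limit_info().unwrap();
        assert_eq!(info.error_type.as_deref(), Some("rate_limit_exceeded"));
        assert_eq!(info.requests_remaining, Some(3)); // backfilled from headers
    }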
}