sgr_agent/types.rs

use serde::{Deserialize, Serialize};

/// Inline image data for multimodal messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImagePart {
    /// Base64-encoded image data.
    pub data: String,
    /// MIME type (e.g. "image/jpeg", "image/png").
    pub mime_type: String,
}

impl ImagePart {
    /// `data:<mime>;base64,<data>` URL form consumed by OpenAI vision APIs.
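    ///
    /// A minimal sketch of the output shape (assumes `ImagePart` is re-exported
    /// at the crate root, like `LlmConfig` below; the base64 payload is a dummy):
    /// ```no_run
    /// use sgr_agent::ImagePart;
    ///
    /// let img = ImagePart { data: "AAAA".into(), mime_type: "image/png".into() };
    /// assert_eq!(img.data_url(), "data:image/png;base64,AAAA");
    /// ```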
    pub fn data_url(&self) -> String {
        format!("data:{};base64,{}", self.mime_type, self.data)
    }
}

/// A chat message in the conversation history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    pub role: Role,
    pub content: String,
    /// ID of the tool call this message responds to (only for Role::Tool).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Tool calls made by the assistant (only for Role::Assistant with function calling).
    /// Gemini API requires model turns to include functionCall parts.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Inline images (for multimodal VLM input).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub images: Vec<ImagePart>,
    /// Whether this message can be dropped during context compaction.
    /// false (default) = critical — never remove (inbox, instruction, system).
    /// true = compactable — can be summarized or dropped when context overflows.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub compactable: bool,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
}

impl Message {
    pub fn system(content: impl Into<String>) -> Self {
        Self {
            role: Role::System,
            content: content.into(),
            tool_call_id: None,
            tool_calls: vec![],
            images: vec![],
            compactable: false,
        }
    }
    pub fn user(content: impl Into<String>) -> Self {
        Self {
            role: Role::User,
            content: content.into(),
            tool_call_id: None,
            tool_calls: vec![],
            images: vec![],
            compactable: false,
        }
    }
    pub fn assistant(content: impl Into<String>) -> Self {
        Self {
            role: Role::Assistant,
            content: content.into(),
            tool_call_id: None,
            tool_calls: vec![],
            images: vec![],
            compactable: false,
        }
    }
    /// Create an assistant message that includes function calls (for Gemini FC protocol).
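    ///
    /// Illustrative sketch (the `ToolCall` id, name, and arguments are made up;
    /// assumes `Message` and `ToolCall` are re-exported at the crate root):
    /// ```no_run
    /// use sgr_agent::{Message, ToolCall};
    ///
    /// let call = ToolCall {
    ///     id: "call_1".into(),
    ///     name: "search".into(),
    ///     arguments: serde_json::json!({"query": "rust serde"}),
    /// };
    /// let turn = Message::assistant_with_tool_calls("", vec![call]);
    /// // The paired result is reported back under the same id:
    /// let result = Message::tool("call_1", "3 matches found");
    /// ```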
    pub fn assistant_with_tool_calls(
        content: impl Into<String>,
        tool_calls: Vec<ToolCall>,
    ) -> Self {
        Self {
            role: Role::Assistant,
            content: content.into(),
            tool_call_id: None,
            tool_calls,
            images: vec![],
            compactable: false,
        }
    }
    pub fn tool(call_id: impl Into<String>, content: impl Into<String>) -> Self {
        Self {
            role: Role::Tool,
            content: content.into(),
            tool_call_id: Some(call_id.into()),
            tool_calls: vec![],
            images: vec![],
            compactable: false,
        }
    }
    /// Tool result with inline images (for VLM — Gemini sees the images).
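    ///
    /// Sketch of a screenshot-style tool result (dummy payload; assumes
    /// crate-root re-exports):
    /// ```no_run
    /// use sgr_agent::{ImagePart, Message};
    ///
    /// let shot = ImagePart { data: "AAAA".into(), mime_type: "image/png".into() };
    /// let msg = Message::tool_with_images("call_1", "screenshot attached", vec![shot]);
    /// ```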
    pub fn tool_with_images(
        call_id: impl Into<String>,
        content: impl Into<String>,
        images: Vec<ImagePart>,
    ) -> Self {
        Self {
            role: Role::Tool,
            content: content.into(),
            tool_call_id: Some(call_id.into()),
            tool_calls: vec![],
            images,
            compactable: false,
        }
    }
    /// User message with inline images.
    pub fn user_with_images(content: impl Into<String>, images: Vec<ImagePart>) -> Self {
        Self {
            role: Role::User,
            content: content.into(),
            tool_call_id: None,
            tool_calls: vec![],
            images,
            compactable: false,
        }
    }
    /// Mark this message as compactable (safe to drop during context overflow).
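    ///
    /// Builder-style; a sketch of typical use on a bulky tool result:
    /// ```no_run
    /// use sgr_agent::Message;
    ///
    /// // Large tool output that may be summarized away under context pressure:
    /// let msg = Message::tool("call_1", "long output...").compactable();
    /// assert!(msg.compactable);
    /// ```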
    pub fn compactable(mut self) -> Self {
        self.compactable = true;
        self
    }
}

/// A tool call returned by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Unique ID for matching with tool results.
    pub id: String,
    /// Tool/function name.
    pub name: String,
    /// JSON-encoded arguments.
    pub arguments: serde_json::Value,
}

/// Response from an SGR call — structured output + optional tool calls.
#[derive(Debug, Clone)]
pub struct SgrResponse<T> {
    /// Parsed structured output (SGR envelope).
    /// `None` if the model only returned tool calls without structured content.
    pub output: Option<T>,
    /// Tool calls the model wants to execute.
    pub tool_calls: Vec<ToolCall>,
    /// Raw text (for streaming / debugging).
    pub raw_text: String,
    /// Token usage.
    pub usage: Option<Usage>,
    /// Rate limit info from response headers (if provider sends them).
    pub rate_limit: Option<RateLimitInfo>,
}

#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// Rate limit info extracted from response headers and/or error body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitInfo {
    /// Requests remaining in current window.
    pub requests_remaining: Option<u32>,
    /// Tokens remaining in current window.
    pub tokens_remaining: Option<u32>,
    /// Seconds until limit resets.
    pub retry_after_secs: Option<u64>,
    /// Unix timestamp when limit resets.
    pub resets_at: Option<u64>,
    /// Provider error type (e.g. "usage_limit_reached", "rate_limit_exceeded").
    pub error_type: Option<String>,
    /// Human-readable message from provider.
    pub message: Option<String>,
}

impl RateLimitInfo {
    /// Parse from HTTP response headers (OpenAI/Gemini/OpenRouter standard).
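    ///
    /// Sketch with hand-built headers (the values are invented; assumes
    /// `RateLimitInfo` is re-exported at the crate root):
    /// ```no_run
    /// use reqwest::header::HeaderMap;
    /// use sgr_agent::RateLimitInfo;
    ///
    /// let mut headers = HeaderMap::new();
    /// headers.insert("x-ratelimit-remaining-requests", "5".parse().unwrap());
    /// headers.insert("retry-after", "30".parse().unwrap());
    /// let info = RateLimitInfo::from_headers(&headers).unwrap();
    /// assert_eq!(info.requests_remaining, Some(5));
    /// assert_eq!(info.retry_after_secs, Some(30));
    /// ```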
    pub fn from_headers(headers: &reqwest::header::HeaderMap) -> Option<Self> {
        let get_u32 =
            |name: &str| -> Option<u32> { headers.get(name)?.to_str().ok()?.parse().ok() };
        let get_u64 =
            |name: &str| -> Option<u64> { headers.get(name)?.to_str().ok()?.parse().ok() };

        let requests_remaining = get_u32("x-ratelimit-remaining-requests");
        let tokens_remaining = get_u32("x-ratelimit-remaining-tokens");
        let retry_after_secs =
            get_u64("retry-after").or_else(|| get_u64("x-ratelimit-reset-requests"));
        let resets_at = get_u64("x-ratelimit-reset-tokens");

        if requests_remaining.is_some() || tokens_remaining.is_some() || retry_after_secs.is_some()
        {
            Some(Self {
                requests_remaining,
                tokens_remaining,
                retry_after_secs,
                resets_at,
                error_type: None,
                message: None,
            })
        } else {
            None
        }
    }

    /// Parse from JSON error body (OpenAI, Codex, Gemini error responses).
    pub fn from_error_body(body: &str) -> Option<Self> {
        let json: serde_json::Value = serde_json::from_str(body).ok()?;
        let err = json.get("error")?;

        let error_type = err.get("type").and_then(|v| v.as_str()).map(String::from);
        let message = err
            .get("message")
            .and_then(|v| v.as_str())
            .map(String::from);
        let resets_at = err.get("resets_at").and_then(|v| v.as_u64());
        let retry_after_secs = err.get("resets_in_seconds").and_then(|v| v.as_u64());

        Some(Self {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs,
            resets_at,
            error_type,
            message,
        })
    }

    /// Human-readable description of when limit resets.
    pub fn reset_display(&self) -> String {
        if let Some(secs) = self.retry_after_secs {
            let hours = secs / 3600;
            let mins = (secs % 3600) / 60;
            if hours >= 24 {
                format!("{}d {}h", hours / 24, hours % 24)
            } else if hours > 0 {
                format!("{}h {}m", hours, mins)
            } else {
                format!("{}m", mins)
            }
        } else {
            "unknown".into()
        }
    }

    /// One-line status for status bar.
    pub fn status_line(&self) -> String {
        let mut parts = Vec::new();
        if let Some(r) = self.requests_remaining {
            parts.push(format!("req:{}", r));
        }
        if let Some(t) = self.tokens_remaining {
            parts.push(format!("tok:{}", t));
        }
        if self.retry_after_secs.is_some() {
            parts.push(format!("reset:{}", self.reset_display()));
        }
        if parts.is_empty() {
            self.message
                .clone()
                .unwrap_or_else(|| "rate limited".into())
        } else {
            parts.join(" | ")
        }
    }
}

/// LLM provider configuration — single config for any provider.
///
/// Two optional fields control routing:
/// - `api_key`: None → auto from env vars (OPENAI_API_KEY, ANTHROPIC_API_KEY, etc.)
/// - `base_url`: None → auto-detect provider from model name; Some → custom endpoint
///
/// ```no_run
/// use sgr_agent::LlmConfig;
///
/// let c = LlmConfig::auto("gpt-4o");                                          // env vars
/// let c = LlmConfig::with_key("sk-...", "claude-3-haiku");                    // explicit key
/// let c = LlmConfig::endpoint("sk-or-...", "https://openrouter.ai/api/v1", "gpt-4o"); // custom
/// let c = LlmConfig::auto("gpt-4o").temperature(0.9).max_tokens(2048);        // builder
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
    pub model: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
    #[serde(default = "default_temperature")]
    pub temp: f64,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// OpenAI prompt cache key — caches system prompt prefix server-side.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    /// Vertex AI project ID (enables Vertex routing when set).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub project_id: Option<String>,
    /// Vertex AI location (default: "global").
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub location: Option<String>,
    /// Force Chat Completions API instead of Responses API.
    /// Needed for OpenAI-compatible endpoints that don't support /responses
    /// (e.g. Cloudflare AI Gateway compat, OpenRouter, local models).
    #[serde(default)]
    pub use_chat_api: bool,
    /// Extra HTTP headers to include in LLM API requests.
    /// E.g. `cf-aig-request-timeout: 300000` for Cloudflare AI Gateway.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub extra_headers: Vec<(String, String)>,
    /// Reasoning effort for reasoning models. "none" disables reasoning for FC.
    /// E.g. DeepInfra Nemotron Super needs "none" for function calling.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
    /// Force genai backend (for providers with native API: Anthropic, Gemini).
    /// When false, oxide (OpenAI Responses API) is used by default.
    #[serde(default)]
    pub use_genai: bool,
    /// Use CLI subprocess backend (claude/gemini/codex -p).
    /// Tool calls emulated via text prompt + flexible parsing.
    /// Uses CLI's own auth (subscription credits, no API key).
    #[serde(default)]
    pub use_cli: bool,
    /// Session ID for request grouping (sticky routing, trace correlation).
    /// Set per-trial to group all LLM calls in the same session.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,

    // ── Provider capabilities (auto-detected from model name, overridable in TOML) ──
    /// Whether the provider rejects an assistant message as the last turn. Auto: true for Anthropic Opus/Sonnet.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub no_assistant_prefill: Option<bool>,
    /// Prompt cache TTL (e.g. "5m", "1h"). Auto: "1h" for Anthropic models.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_ttl: Option<String>,
    /// Pin to specific provider on OpenRouter (e.g. "Anthropic"). Auto-detected.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pin_provider: Option<String>,
    /// Enable WebSocket for Responses API (lower latency, persistent connection).
    /// Default: true for Responses API (Oxide backend), ignored for Chat/genai/CLI.
    #[serde(default = "default_websocket")]
    pub websocket: bool,
}

fn default_websocket() -> bool {
    true
}

fn default_temperature() -> f64 {
    0.7
}

impl Default for LlmConfig {
    fn default() -> Self {
        Self {
            model: String::new(),
            api_key: None,
            base_url: None,
            temp: default_temperature(),
            max_tokens: None,
            prompt_cache_key: None,
            project_id: None,
            location: None,
            use_chat_api: false,
            extra_headers: Vec::new(),
            reasoning_effort: None,
            use_genai: false,
            use_cli: false,
            session_id: None,
            no_assistant_prefill: None,
            cache_ttl: None,
            pin_provider: None,
            websocket: default_websocket(),
        }
    }
}

impl LlmConfig {
    /// Auto-detect provider from model name, use env vars for auth.
    pub fn auto(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            ..Default::default()
        }
    }

    /// Explicit API key, auto-detect provider from model name.
    pub fn with_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            api_key: Some(api_key.into()),
            ..Default::default()
        }
    }

    /// Custom OpenAI-compatible endpoint (OpenRouter, Ollama, LiteLLM, etc.).
    pub fn endpoint(
        api_key: impl Into<String>,
        base_url: impl Into<String>,
        model: impl Into<String>,
    ) -> Self {
        Self {
            model: model.into(),
            api_key: Some(api_key.into()),
            base_url: Some(base_url.into()),
            ..Default::default()
        }
    }

    /// Vertex AI — uses gcloud ADC for auth (no API key needed).
    pub fn vertex(project_id: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            project_id: Some(project_id.into()),
            location: Some("global".into()),
            ..Default::default()
        }
    }

    /// Set Vertex AI location.
    pub fn location(mut self, loc: impl Into<String>) -> Self {
        self.location = Some(loc.into());
        self
    }

    /// Set temperature.
    pub fn temperature(mut self, t: f64) -> Self {
        self.temp = t;
        self
    }

    /// Set max output tokens.
    pub fn max_tokens(mut self, m: u32) -> Self {
        self.max_tokens = Some(m);
        self
    }

    /// Set OpenAI prompt cache key for server-side system prompt caching.
    pub fn prompt_cache_key(mut self, key: impl Into<String>) -> Self {
        self.prompt_cache_key = Some(key.into());
        self
    }

    /// True if model targets Anthropic (auto-detect from model name).
    pub fn is_anthropic(&self) -> bool {
        self.model.starts_with("anthropic/") || self.model.starts_with("claude")
    }

    /// Whether assistant prefill is rejected. TOML override > auto-detect.
    pub fn rejects_prefill(&self) -> bool {
        self.no_assistant_prefill.unwrap_or_else(|| {
            // Anthropic Opus/Sonnet reject prefill; Haiku via Bedrock allows it
            self.is_anthropic() && !self.model.contains("haiku")
        })
    }

    /// Prompt cache TTL, if any. TOML override > auto-detect.
    pub fn resolved_cache_ttl(&self) -> Option<&str> {
        if self.cache_ttl.is_some() {
            return self.cache_ttl.as_deref();
        }
        if self.is_anthropic() {
            Some("1h")
        } else {
            None
        }
    }

    /// Provider to pin on OpenRouter, if any. TOML override > auto-detect.
    pub fn resolved_pin_provider(&self) -> Option<&str> {
        if self.pin_provider.is_some() {
            return self.pin_provider.as_deref();
        }
        if self.is_anthropic() {
            Some("Anthropic")
        } else {
            None
        }
    }

    /// Apply extra_headers to an openai-oxide ClientConfig.
    /// Used by both OxideClient and OxideChatClient.
    pub fn apply_headers(&self, config: &mut openai_oxide::config::ClientConfig) {
        if !self.extra_headers.is_empty() {
            let mut hm = reqwest::header::HeaderMap::new();
            for (k, v) in &self.extra_headers {
                if let (Ok(name), Ok(val)) = (
                    reqwest::header::HeaderName::from_bytes(k.as_bytes()),
                    reqwest::header::HeaderValue::from_str(v),
                ) {
                    hm.insert(name, val);
                }
            }
            config.default_headers = Some(hm);
        }
    }

    /// CLI subprocess backend — uses `claude -p` / `gemini -p` / `codex exec`.
    /// No API key needed; uses the CLI's own auth (subscription credits).
    /// An optional `cli_model` overrides the CLI's default model via the `--model` flag.
    pub fn cli(cli_model: impl Into<String>) -> Self {
        Self {
            model: cli_model.into(),
            use_cli: true,
            ..Default::default()
        }
    }

    /// Human-readable label for display.
    pub fn label(&self) -> String {
        if self.use_cli {
            format!("CLI ({})", self.model)
        } else if self.project_id.is_some() {
            format!("Vertex ({})", self.model)
        } else if self.base_url.is_some() {
            format!("Custom ({})", self.model)
        } else {
            self.model.clone()
        }
    }

    /// Infer a cheap/fast model for compaction based on the primary model.
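    ///
    /// Sketch of the prefix mapping implemented below (model names in the
    /// assertions are arbitrary examples):
    /// ```no_run
    /// use sgr_agent::LlmConfig;
    ///
    /// assert_eq!(LlmConfig::auto("gemini-1.5-pro").compaction_model(), "gemini-2.0-flash-lite");
    /// assert_eq!(LlmConfig::auto("gpt-4o").compaction_model(), "gpt-4o-mini");
    /// // Unknown prefixes fall back to the primary model itself:
    /// assert_eq!(LlmConfig::auto("mistral-large").compaction_model(), "mistral-large");
    /// ```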
    pub fn compaction_model(&self) -> String {
        if self.model.starts_with("gemini") {
            "gemini-2.0-flash-lite".into()
        } else if self.model.starts_with("gpt") {
            "gpt-4o-mini".into()
        } else if self.model.starts_with("claude") {
            "claude-3-haiku-20240307".into()
        } else {
            // Unknown provider — use the same model
            self.model.clone()
        }
    }

    /// Create a compaction config — cheap model, low max_tokens.
    pub fn for_compaction(&self) -> Self {
        let mut cfg = self.clone();
        cfg.model = self.compaction_model();
        cfg.max_tokens = Some(2048);
        cfg
    }
}

/// Legacy provider configuration (used by OpenAIClient/GeminiClient).
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    pub api_key: String,
    pub model: String,
    pub base_url: Option<String>,
    pub project_id: Option<String>,
    pub location: Option<String>,
    pub temperature: f32,
    pub max_tokens: Option<u32>,
}

impl ProviderConfig {
    pub fn gemini(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            base_url: None,
            project_id: None,
            location: None,
            temperature: 0.3,
            max_tokens: None,
        }
    }

    pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            base_url: None,
            project_id: None,
            location: None,
            temperature: 0.3,
            max_tokens: None,
        }
    }

    pub fn openrouter(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            base_url: Some("https://openrouter.ai/api/v1".into()),
            project_id: None,
            location: None,
            temperature: 0.3,
            max_tokens: None,
        }
    }

    pub fn vertex(
        access_token: impl Into<String>,
        project_id: impl Into<String>,
        model: impl Into<String>,
    ) -> Self {
        Self {
            api_key: access_token.into(),
            model: model.into(),
            base_url: None,
            project_id: Some(project_id.into()),
            location: Some("global".to_string()),
            temperature: 0.3,
            max_tokens: None,
        }
    }

    pub fn ollama(model: impl Into<String>) -> Self {
        Self {
            api_key: String::new(),
            model: model.into(),
            base_url: Some("http://localhost:11434/v1".into()),
            project_id: None,
            location: None,
            temperature: 0.3,
            max_tokens: None,
        }
    }
}

/// Errors from SGR calls.
#[derive(Debug, thiserror::Error)]
pub enum SgrError {
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),
    #[error("API error {status}: {body}")]
    Api { status: u16, body: String },
    #[error("Rate limit: {}", info.status_line())]
    RateLimit { status: u16, info: RateLimitInfo },
    #[error("JSON parse error: {0}")]
    Json(#[from] serde_json::Error),
    #[error("Schema error: {0}")]
    Schema(String),
    #[error("No content in response")]
    EmptyResponse,
    /// Model response was truncated due to max_output_tokens limit.
    /// Contains the partial content that was generated before truncation.
    #[error("Response truncated (max_output_tokens): {partial_content}")]
    MaxOutputTokens { partial_content: String },
    /// Prompt too long — context exceeds model's input limit.
    #[error("Prompt too long: {0}")]
    PromptTooLong(String),
}

impl SgrError {
    /// Build error from HTTP status + body, auto-detecting rate limits.
    pub fn from_api_response(status: u16, body: String) -> Self {
        if (status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit"))
            && let Some(mut info) = RateLimitInfo::from_error_body(&body)
        {
            if info.message.is_none() {
                info.message = Some(body.chars().take(200).collect());
            }
            return SgrError::RateLimit { status, info };
        }
        SgrError::Api { status, body }
    }

    /// Build error from HTTP status + body + headers, auto-detecting rate limits.
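    ///
    /// Sketch combining an uninformative body with rate-limit headers (the
    /// header value is invented; assumes `SgrError` is re-exported at the crate root):
    /// ```no_run
    /// use reqwest::header::HeaderMap;
    /// use sgr_agent::SgrError;
    ///
    /// let mut headers = HeaderMap::new();
    /// headers.insert("retry-after", "60".parse().unwrap());
    /// let err = SgrError::from_response_parts(429, "{}".into(), &headers);
    /// assert!(err.is_rate_limit());
    /// ```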
    pub fn from_response_parts(
        status: u16,
        body: String,
        headers: &reqwest::header::HeaderMap,
    ) -> Self {
        if status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit") {
            let mut info = RateLimitInfo::from_error_body(&body)
                .or_else(|| RateLimitInfo::from_headers(headers))
                .unwrap_or(RateLimitInfo {
                    requests_remaining: None,
                    tokens_remaining: None,
                    retry_after_secs: None,
                    resets_at: None,
                    error_type: Some("rate_limit".into()),
                    message: Some(body.chars().take(200).collect()),
                });
            // Merge header info into body info
            if let Some(header_info) = RateLimitInfo::from_headers(headers) {
                if info.requests_remaining.is_none() {
                    info.requests_remaining = header_info.requests_remaining;
                }
                if info.tokens_remaining.is_none() {
                    info.tokens_remaining = header_info.tokens_remaining;
                }
            }
            return SgrError::RateLimit { status, info };
        }
        SgrError::Api { status, body }
    }

    /// Is this a rate limit error?
    pub fn is_rate_limit(&self) -> bool {
        matches!(self, SgrError::RateLimit { .. })
    }

    /// Get rate limit info if this is a rate limit error.
    pub fn rate_limit_info(&self) -> Option<&RateLimitInfo> {
        match self {
            SgrError::RateLimit { info, .. } => Some(info),
            _ => None,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_codex_rate_limit_error() {
        let body = r#"{"error":{"type":"usage_limit_reached","message":"The usage limit has been reached","plan_type":"plus","resets_at":1773534007,"resets_in_seconds":442787}}"#;
        let err = SgrError::from_api_response(429, body.to_string());
        assert!(err.is_rate_limit());
        let info = err.rate_limit_info().unwrap();
        assert_eq!(info.error_type.as_deref(), Some("usage_limit_reached"));
        assert_eq!(info.retry_after_secs, Some(442787));
        assert_eq!(info.resets_at, Some(1773534007));
        assert_eq!(info.reset_display(), "5d 2h");
    }

    #[test]
    fn parse_openai_rate_limit_error() {
        let body =
            r#"{"error":{"type":"rate_limit_exceeded","message":"Rate limit reached for gpt-4"}}"#;
        let err = SgrError::from_api_response(429, body.to_string());
        assert!(err.is_rate_limit());
        let info = err.rate_limit_info().unwrap();
        assert_eq!(info.error_type.as_deref(), Some("rate_limit_exceeded"));
    }

    #[test]
    fn non_rate_limit_stays_api_error() {
        let body = r#"{"error":{"type":"invalid_request","message":"Bad request"}}"#;
        let err = SgrError::from_api_response(400, body.to_string());
        assert!(!err.is_rate_limit());
        assert!(matches!(err, SgrError::Api { status: 400, .. }));
    }

    #[test]
    fn status_line_with_all_fields() {
        let info = RateLimitInfo {
            requests_remaining: Some(5),
            tokens_remaining: Some(10000),
            retry_after_secs: Some(3600),
            resets_at: None,
            error_type: None,
            message: None,
        };
        assert_eq!(info.status_line(), "req:5 | tok:10000 | reset:1h 0m");
    }

    #[test]
    fn status_line_fallback_to_message() {
        let info = RateLimitInfo {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs: None,
            resets_at: None,
            error_type: None,
            message: Some("custom message".into()),
        };
        assert_eq!(info.status_line(), "custom message");
    }

    #[test]
    fn reset_display_formats() {
        let make = |secs| RateLimitInfo {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs: Some(secs),
            resets_at: None,
            error_type: None,
            message: None,
        };
        assert_eq!(make(90).reset_display(), "1m");
        assert_eq!(make(3661).reset_display(), "1h 1m");
        assert_eq!(make(90000).reset_display(), "1d 1h");
    }
}