//! Core types for `sgr_agent`: chat messages, tool calls, provider
//! configuration, rate-limit info, and the SGR error type.

use serde::{Deserialize, Serialize};
2
/// Inline image data for multimodal messages.
///
/// Carried as a base64 payload so the whole message history stays
/// serializable as plain JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImagePart {
    /// Base64-encoded image data.
    pub data: String,
    /// MIME type (e.g. "image/jpeg", "image/png").
    pub mime_type: String,
}
11
/// A chat message in the conversation history.
///
/// Optional fields are skipped during serialization so that serialized
/// messages only carry the parts relevant to their role.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who produced this message.
    pub role: Role,
    /// Text content of the message.
    pub content: String,
    /// Tool call results (only for Role::Tool): the ID of the call being answered.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Tool calls made by the assistant (only for Role::Assistant with function calling).
    /// Gemini API requires model turns to include functionCall parts.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Inline images (for multimodal VLM input).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub images: Vec<ImagePart>,
}
28
/// Message author role.
///
/// Serialized in lowercase (e.g. `Role::User` → `"user"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
}
37
38impl Message {
39    pub fn system(content: impl Into<String>) -> Self {
40        Self {
41            role: Role::System,
42            content: content.into(),
43            tool_call_id: None,
44            tool_calls: vec![],
45            images: vec![],
46        }
47    }
48    pub fn user(content: impl Into<String>) -> Self {
49        Self {
50            role: Role::User,
51            content: content.into(),
52            tool_call_id: None,
53            tool_calls: vec![],
54            images: vec![],
55        }
56    }
57    pub fn assistant(content: impl Into<String>) -> Self {
58        Self {
59            role: Role::Assistant,
60            content: content.into(),
61            tool_call_id: None,
62            tool_calls: vec![],
63            images: vec![],
64        }
65    }
66    /// Create an assistant message that includes function calls (for Gemini FC protocol).
67    pub fn assistant_with_tool_calls(
68        content: impl Into<String>,
69        tool_calls: Vec<ToolCall>,
70    ) -> Self {
71        Self {
72            role: Role::Assistant,
73            content: content.into(),
74            tool_call_id: None,
75            tool_calls,
76            images: vec![],
77        }
78    }
79    pub fn tool(call_id: impl Into<String>, content: impl Into<String>) -> Self {
80        Self {
81            role: Role::Tool,
82            content: content.into(),
83            tool_call_id: Some(call_id.into()),
84            tool_calls: vec![],
85            images: vec![],
86        }
87    }
88    /// Tool result with inline images (for VLM — Gemini sees the images).
89    pub fn tool_with_images(
90        call_id: impl Into<String>,
91        content: impl Into<String>,
92        images: Vec<ImagePart>,
93    ) -> Self {
94        Self {
95            role: Role::Tool,
96            content: content.into(),
97            tool_call_id: Some(call_id.into()),
98            tool_calls: vec![],
99            images,
100        }
101    }
102    /// User message with inline images.
103    pub fn user_with_images(content: impl Into<String>, images: Vec<ImagePart>) -> Self {
104        Self {
105            role: Role::User,
106            content: content.into(),
107            tool_call_id: None,
108            tool_calls: vec![],
109            images,
110        }
111    }
112}
113
/// A tool call returned by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Unique ID for matching with tool results.
    pub id: String,
    /// Tool/function name.
    pub name: String,
    /// Call arguments as a decoded JSON value (not a JSON-encoded string).
    pub arguments: serde_json::Value,
}
124
/// Response from an SGR call — structured output + optional tool calls.
///
/// Not serializable on purpose: `T` is the caller's envelope type and this
/// struct is consumed immediately after the call.
#[derive(Debug, Clone)]
pub struct SgrResponse<T> {
    /// Parsed structured output (SGR envelope).
    /// `None` if the model only returned tool calls without structured content.
    pub output: Option<T>,
    /// Tool calls the model wants to execute.
    pub tool_calls: Vec<ToolCall>,
    /// Raw text (for streaming / debugging).
    pub raw_text: String,
    /// Token usage, when the provider reports it.
    pub usage: Option<Usage>,
    /// Rate limit info from response headers (if provider sends them).
    pub rate_limit: Option<RateLimitInfo>,
}
140
/// Token usage reported by the provider.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the prompt/input.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion/output.
    pub completion_tokens: u32,
    /// Total tokens (typically prompt + completion).
    pub total_tokens: u32,
}
147
/// Rate limit info extracted from response headers and/or error body.
///
/// All fields are optional: providers expose different subsets, and the
/// parsers below fill in whatever they can find.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitInfo {
    /// Requests remaining in current window.
    pub requests_remaining: Option<u32>,
    /// Tokens remaining in current window.
    pub tokens_remaining: Option<u32>,
    /// Seconds until limit resets.
    pub retry_after_secs: Option<u64>,
    /// Unix timestamp when limit resets.
    pub resets_at: Option<u64>,
    /// Provider error type (e.g. "usage_limit_reached", "rate_limit_exceeded").
    pub error_type: Option<String>,
    /// Human-readable message from provider.
    pub message: Option<String>,
}
164
165impl RateLimitInfo {
166    /// Parse from HTTP response headers (OpenAI/Gemini/OpenRouter standard).
167    pub fn from_headers(headers: &reqwest::header::HeaderMap) -> Option<Self> {
168        let get_u32 =
169            |name: &str| -> Option<u32> { headers.get(name)?.to_str().ok()?.parse().ok() };
170        let get_u64 =
171            |name: &str| -> Option<u64> { headers.get(name)?.to_str().ok()?.parse().ok() };
172
173        let requests_remaining = get_u32("x-ratelimit-remaining-requests");
174        let tokens_remaining = get_u32("x-ratelimit-remaining-tokens");
175        let retry_after_secs =
176            get_u64("retry-after").or_else(|| get_u64("x-ratelimit-reset-requests"));
177        let resets_at = get_u64("x-ratelimit-reset-tokens");
178
179        if requests_remaining.is_some() || tokens_remaining.is_some() || retry_after_secs.is_some()
180        {
181            Some(Self {
182                requests_remaining,
183                tokens_remaining,
184                retry_after_secs,
185                resets_at,
186                error_type: None,
187                message: None,
188            })
189        } else {
190            None
191        }
192    }
193
194    /// Parse from JSON error body (OpenAI, Codex, Gemini error responses).
195    pub fn from_error_body(body: &str) -> Option<Self> {
196        let json: serde_json::Value = serde_json::from_str(body).ok()?;
197        let err = json.get("error")?;
198
199        let error_type = err.get("type").and_then(|v| v.as_str()).map(String::from);
200        let message = err
201            .get("message")
202            .and_then(|v| v.as_str())
203            .map(String::from);
204        let resets_at = err.get("resets_at").and_then(|v| v.as_u64());
205        let retry_after_secs = err.get("resets_in_seconds").and_then(|v| v.as_u64());
206
207        Some(Self {
208            requests_remaining: None,
209            tokens_remaining: None,
210            retry_after_secs,
211            resets_at,
212            error_type,
213            message,
214        })
215    }
216
217    /// Human-readable description of when limit resets.
218    pub fn reset_display(&self) -> String {
219        if let Some(secs) = self.retry_after_secs {
220            let hours = secs / 3600;
221            let mins = (secs % 3600) / 60;
222            if hours >= 24 {
223                format!("{}d {}h", hours / 24, hours % 24)
224            } else if hours > 0 {
225                format!("{}h {}m", hours, mins)
226            } else {
227                format!("{}m", mins)
228            }
229        } else {
230            "unknown".into()
231        }
232    }
233
234    /// One-line status for status bar.
235    pub fn status_line(&self) -> String {
236        let mut parts = Vec::new();
237        if let Some(r) = self.requests_remaining {
238            parts.push(format!("req:{}", r));
239        }
240        if let Some(t) = self.tokens_remaining {
241            parts.push(format!("tok:{}", t));
242        }
243        if self.retry_after_secs.is_some() {
244            parts.push(format!("reset:{}", self.reset_display()));
245        }
246        if parts.is_empty() {
247            self.message
248                .clone()
249                .unwrap_or_else(|| "rate limited".into())
250        } else {
251            parts.join(" | ")
252        }
253    }
254}
255
/// LLM provider configuration — single config for any provider.
///
/// Two optional fields control routing:
/// - `api_key`: None → auto from env vars (OPENAI_API_KEY, ANTHROPIC_API_KEY, etc.)
/// - `base_url`: None → auto-detect provider from model name; Some → custom endpoint
///
/// ```no_run
/// use sgr_agent::LlmConfig;
///
/// let c = LlmConfig::auto("gpt-4o");                                          // env vars
/// let c = LlmConfig::with_key("sk-...", "claude-3-haiku");                    // explicit key
/// let c = LlmConfig::endpoint("sk-or-...", "https://openrouter.ai/api/v1", "gpt-4o"); // custom
/// let c = LlmConfig::auto("gpt-4o").temperature(0.9).max_tokens(2048);        // builder
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
    /// Model name, e.g. "gpt-4o".
    pub model: String,
    /// Explicit API key; None → resolve from environment.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// Custom OpenAI-compatible endpoint; None → auto-detect from model name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
    /// Sampling temperature (defaults to 0.7 when absent from serialized config).
    #[serde(default = "default_temperature")]
    pub temp: f64,
    /// Max output tokens; None → provider default.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
}
282
/// Serde default for [`LlmConfig::temp`].
fn default_temperature() -> f64 {
    const DEFAULT_TEMP: f64 = 0.7;
    DEFAULT_TEMP
}
286
287impl LlmConfig {
288    /// Auto-detect provider from model name, use env vars for auth.
289    pub fn auto(model: impl Into<String>) -> Self {
290        Self {
291            model: model.into(),
292            api_key: None,
293            base_url: None,
294            temp: default_temperature(),
295            max_tokens: None,
296        }
297    }
298
299    /// Explicit API key, auto-detect provider from model name.
300    pub fn with_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
301        Self {
302            model: model.into(),
303            api_key: Some(api_key.into()),
304            base_url: None,
305            temp: default_temperature(),
306            max_tokens: None,
307        }
308    }
309
310    /// Custom OpenAI-compatible endpoint (OpenRouter, Ollama, LiteLLM, etc.).
311    pub fn endpoint(
312        api_key: impl Into<String>,
313        base_url: impl Into<String>,
314        model: impl Into<String>,
315    ) -> Self {
316        Self {
317            model: model.into(),
318            api_key: Some(api_key.into()),
319            base_url: Some(base_url.into()),
320            temp: default_temperature(),
321            max_tokens: None,
322        }
323    }
324
325    /// Set temperature.
326    pub fn temperature(mut self, t: f64) -> Self {
327        self.temp = t;
328        self
329    }
330
331    /// Set max output tokens.
332    pub fn max_tokens(mut self, m: u32) -> Self {
333        self.max_tokens = Some(m);
334        self
335    }
336}
337
/// Legacy provider configuration (used by OpenAIClient/GeminiClient).
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    /// API key or access token (empty string for keyless providers like Ollama).
    pub api_key: String,
    /// Model name.
    pub model: String,
    /// Custom endpoint; None → provider default.
    pub base_url: Option<String>,
    /// GCP project ID (Vertex AI only).
    pub project_id: Option<String>,
    /// GCP location (Vertex AI only).
    pub location: Option<String>,
    /// Sampling temperature.
    pub temperature: f32,
    /// Max output tokens; None → provider default.
    pub max_tokens: Option<u32>,
}
349
350impl ProviderConfig {
351    pub fn gemini(api_key: impl Into<String>, model: impl Into<String>) -> Self {
352        Self {
353            api_key: api_key.into(),
354            model: model.into(),
355            base_url: None,
356            project_id: None,
357            location: None,
358            temperature: 0.3,
359            max_tokens: None,
360        }
361    }
362
363    pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
364        Self {
365            api_key: api_key.into(),
366            model: model.into(),
367            base_url: None,
368            project_id: None,
369            location: None,
370            temperature: 0.3,
371            max_tokens: None,
372        }
373    }
374
375    pub fn openrouter(api_key: impl Into<String>, model: impl Into<String>) -> Self {
376        Self {
377            api_key: api_key.into(),
378            model: model.into(),
379            base_url: Some("https://openrouter.ai/api/v1".into()),
380            project_id: None,
381            location: None,
382            temperature: 0.3,
383            max_tokens: None,
384        }
385    }
386
387    pub fn vertex(
388        access_token: impl Into<String>,
389        project_id: impl Into<String>,
390        model: impl Into<String>,
391    ) -> Self {
392        Self {
393            api_key: access_token.into(),
394            model: model.into(),
395            base_url: None,
396            project_id: Some(project_id.into()),
397            location: Some("global".to_string()),
398            temperature: 0.3,
399            max_tokens: None,
400        }
401    }
402
403    pub fn ollama(model: impl Into<String>) -> Self {
404        Self {
405            api_key: String::new(),
406            model: model.into(),
407            base_url: Some("http://localhost:11434/v1".into()),
408            project_id: None,
409            location: None,
410            temperature: 0.3,
411            max_tokens: None,
412        }
413    }
414}
415
/// Errors from SGR calls.
#[derive(Debug, thiserror::Error)]
pub enum SgrError {
    /// Transport-level failure from reqwest.
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),
    /// Non-rate-limit API error with raw status and body.
    #[error("API error {status}: {body}")]
    Api { status: u16, body: String },
    /// Rate limit detected from status code, body, or headers.
    #[error("Rate limit: {}", info.status_line())]
    RateLimit { status: u16, info: RateLimitInfo },
    /// JSON (de)serialization failure.
    #[error("JSON parse error: {0}")]
    Json(#[from] serde_json::Error),
    /// SGR schema construction/validation problem.
    #[error("Schema error: {0}")]
    Schema(String),
    /// The provider returned a response with no usable content.
    #[error("No content in response")]
    EmptyResponse,
}
432
433impl SgrError {
434    /// Build error from HTTP status + body, auto-detecting rate limits.
435    pub fn from_api_response(status: u16, body: String) -> Self {
436        if status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit") {
437            if let Some(mut info) = RateLimitInfo::from_error_body(&body) {
438                if info.message.is_none() {
439                    info.message = Some(body.chars().take(200).collect());
440                }
441                return SgrError::RateLimit { status, info };
442            }
443        }
444        SgrError::Api { status, body }
445    }
446
447    /// Build error from HTTP status + body + headers, auto-detecting rate limits.
448    pub fn from_response_parts(
449        status: u16,
450        body: String,
451        headers: &reqwest::header::HeaderMap,
452    ) -> Self {
453        if status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit") {
454            let mut info = RateLimitInfo::from_error_body(&body)
455                .or_else(|| RateLimitInfo::from_headers(headers))
456                .unwrap_or(RateLimitInfo {
457                    requests_remaining: None,
458                    tokens_remaining: None,
459                    retry_after_secs: None,
460                    resets_at: None,
461                    error_type: Some("rate_limit".into()),
462                    message: Some(body.chars().take(200).collect()),
463                });
464            // Merge header info into body info
465            if let Some(header_info) = RateLimitInfo::from_headers(headers) {
466                if info.requests_remaining.is_none() {
467                    info.requests_remaining = header_info.requests_remaining;
468                }
469                if info.tokens_remaining.is_none() {
470                    info.tokens_remaining = header_info.tokens_remaining;
471                }
472            }
473            return SgrError::RateLimit { status, info };
474        }
475        SgrError::Api { status, body }
476    }
477
478    /// Is this a rate limit error?
479    pub fn is_rate_limit(&self) -> bool {
480        matches!(self, SgrError::RateLimit { .. })
481    }
482
483    /// Get rate limit info if this is a rate limit error.
484    pub fn rate_limit_info(&self) -> Option<&RateLimitInfo> {
485        match self {
486            SgrError::RateLimit { info, .. } => Some(info),
487            _ => None,
488        }
489    }
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    // Codex-style body: carries both `resets_at` (absolute) and
    // `resets_in_seconds` (countdown); both must be extracted.
    #[test]
    fn parse_codex_rate_limit_error() {
        let body = r#"{"error":{"type":"usage_limit_reached","message":"The usage limit has been reached","plan_type":"plus","resets_at":1773534007,"resets_in_seconds":442787}}"#;
        let err = SgrError::from_api_response(429, body.to_string());
        assert!(err.is_rate_limit());
        let info = err.rate_limit_info().unwrap();
        assert_eq!(info.error_type.as_deref(), Some("usage_limit_reached"));
        assert_eq!(info.retry_after_secs, Some(442787));
        assert_eq!(info.resets_at, Some(1773534007));
        assert_eq!(info.reset_display(), "5d 2h");
    }

    // OpenAI-style body: type + message only, no reset timing.
    #[test]
    fn parse_openai_rate_limit_error() {
        let body =
            r#"{"error":{"type":"rate_limit_exceeded","message":"Rate limit reached for gpt-4"}}"#;
        let err = SgrError::from_api_response(429, body.to_string());
        assert!(err.is_rate_limit());
        let info = err.rate_limit_info().unwrap();
        assert_eq!(info.error_type.as_deref(), Some("rate_limit_exceeded"));
    }

    // A 400 with no rate-limit markers must stay a plain Api error.
    #[test]
    fn non_rate_limit_stays_api_error() {
        let body = r#"{"error":{"type":"invalid_request","message":"Bad request"}}"#;
        let err = SgrError::from_api_response(400, body.to_string());
        assert!(!err.is_rate_limit());
        assert!(matches!(err, SgrError::Api { status: 400, .. }));
    }

    // Counters + reset present → pipe-joined status line, message ignored.
    #[test]
    fn status_line_with_all_fields() {
        let info = RateLimitInfo {
            requests_remaining: Some(5),
            tokens_remaining: Some(10000),
            retry_after_secs: Some(3600),
            resets_at: None,
            error_type: None,
            message: None,
        };
        assert_eq!(info.status_line(), "req:5 | tok:10000 | reset:1h 0m");
    }

    // No counters at all → fall back to the provider message.
    #[test]
    fn status_line_fallback_to_message() {
        let info = RateLimitInfo {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs: None,
            resets_at: None,
            error_type: None,
            message: Some("custom message".into()),
        };
        assert_eq!(info.status_line(), "custom message");
    }

    // Display granularity: minutes / hours+minutes / days+hours.
    #[test]
    fn reset_display_formats() {
        let make = |secs| RateLimitInfo {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs: Some(secs),
            resets_at: None,
            error_type: None,
            message: None,
        };
        assert_eq!(make(90).reset_display(), "1m");
        assert_eq!(make(3661).reset_display(), "1h 1m");
        assert_eq!(make(90000).reset_display(), "1d 1h");
    }
}