//! Shared types for the `sgr_agent` crate: chat messages, tool calls,
//! provider configuration, rate-limit metadata, and error types.
use serde::{Deserialize, Serialize};
/// A chat message in the conversation history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who authored this message (system / user / assistant / tool).
    pub role: Role,
    /// Message text content.
    pub content: String,
    /// Tool call results (only for Role::Tool).
    // Missing Option fields deserialize to None under serde's default handling,
    // so no explicit #[serde(default)] is needed here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Tool calls made by the assistant (only for Role::Assistant with function calling).
    /// Gemini API requires model turns to include functionCall parts.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
}
16
/// The author of a chat message; serialized in lowercase ("system", "user", ...).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// System / instruction prompt.
    System,
    /// End-user input.
    User,
    /// Model output.
    Assistant,
    /// Tool execution result fed back to the model.
    Tool,
}
25
26impl Message {
27    pub fn system(content: impl Into<String>) -> Self {
28        Self {
29            role: Role::System,
30            content: content.into(),
31            tool_call_id: None,
32            tool_calls: vec![],
33        }
34    }
35    pub fn user(content: impl Into<String>) -> Self {
36        Self {
37            role: Role::User,
38            content: content.into(),
39            tool_call_id: None,
40            tool_calls: vec![],
41        }
42    }
43    pub fn assistant(content: impl Into<String>) -> Self {
44        Self {
45            role: Role::Assistant,
46            content: content.into(),
47            tool_call_id: None,
48            tool_calls: vec![],
49        }
50    }
51    /// Create an assistant message that includes function calls (for Gemini FC protocol).
52    pub fn assistant_with_tool_calls(
53        content: impl Into<String>,
54        tool_calls: Vec<ToolCall>,
55    ) -> Self {
56        Self {
57            role: Role::Assistant,
58            content: content.into(),
59            tool_call_id: None,
60            tool_calls,
61        }
62    }
63    pub fn tool(call_id: impl Into<String>, content: impl Into<String>) -> Self {
64        Self {
65            role: Role::Tool,
66            content: content.into(),
67            tool_call_id: Some(call_id.into()),
68            tool_calls: vec![],
69        }
70    }
71}
72
/// A tool call returned by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Unique ID for matching with tool results.
    pub id: String,
    /// Tool/function name.
    pub name: String,
    /// Parsed JSON arguments for the call (a `serde_json::Value`, not an
    /// encoded string).
    pub arguments: serde_json::Value,
}
83
/// Response from an SGR call — structured output + optional tool calls.
/// `T` is the caller-provided structured-output type the raw text parses into.
#[derive(Debug, Clone)]
pub struct SgrResponse<T> {
    /// Parsed structured output (SGR envelope).
    /// `None` if the model only returned tool calls without structured content.
    pub output: Option<T>,
    /// Tool calls the model wants to execute.
    pub tool_calls: Vec<ToolCall>,
    /// Raw text (for streaming / debugging).
    pub raw_text: String,
    /// Token usage, if the provider reported it.
    pub usage: Option<Usage>,
    /// Rate limit info from response headers (if provider sends them).
    pub rate_limit: Option<RateLimitInfo>,
}
99
/// Token usage reported by the provider for a single call.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the prompt / input.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion / output.
    pub completion_tokens: u32,
    /// Total tokens as reported by the provider (typically prompt + completion).
    pub total_tokens: u32,
}
106
/// Rate limit info extracted from response headers and/or error body.
/// All fields are optional because providers expose different subsets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitInfo {
    /// Requests remaining in current window.
    pub requests_remaining: Option<u32>,
    /// Tokens remaining in current window.
    pub tokens_remaining: Option<u32>,
    /// Seconds until limit resets.
    pub retry_after_secs: Option<u64>,
    /// Unix timestamp when limit resets.
    pub resets_at: Option<u64>,
    /// Provider error type (e.g. "usage_limit_reached", "rate_limit_exceeded").
    pub error_type: Option<String>,
    /// Human-readable message from provider.
    pub message: Option<String>,
}
123
124impl RateLimitInfo {
125    /// Parse from HTTP response headers (OpenAI/Gemini/OpenRouter standard).
126    pub fn from_headers(headers: &reqwest::header::HeaderMap) -> Option<Self> {
127        let get_u32 =
128            |name: &str| -> Option<u32> { headers.get(name)?.to_str().ok()?.parse().ok() };
129        let get_u64 =
130            |name: &str| -> Option<u64> { headers.get(name)?.to_str().ok()?.parse().ok() };
131
132        let requests_remaining = get_u32("x-ratelimit-remaining-requests");
133        let tokens_remaining = get_u32("x-ratelimit-remaining-tokens");
134        let retry_after_secs =
135            get_u64("retry-after").or_else(|| get_u64("x-ratelimit-reset-requests"));
136        let resets_at = get_u64("x-ratelimit-reset-tokens");
137
138        if requests_remaining.is_some() || tokens_remaining.is_some() || retry_after_secs.is_some()
139        {
140            Some(Self {
141                requests_remaining,
142                tokens_remaining,
143                retry_after_secs,
144                resets_at,
145                error_type: None,
146                message: None,
147            })
148        } else {
149            None
150        }
151    }
152
153    /// Parse from JSON error body (OpenAI, Codex, Gemini error responses).
154    pub fn from_error_body(body: &str) -> Option<Self> {
155        let json: serde_json::Value = serde_json::from_str(body).ok()?;
156        let err = json.get("error")?;
157
158        let error_type = err.get("type").and_then(|v| v.as_str()).map(String::from);
159        let message = err
160            .get("message")
161            .and_then(|v| v.as_str())
162            .map(String::from);
163        let resets_at = err.get("resets_at").and_then(|v| v.as_u64());
164        let retry_after_secs = err.get("resets_in_seconds").and_then(|v| v.as_u64());
165
166        Some(Self {
167            requests_remaining: None,
168            tokens_remaining: None,
169            retry_after_secs,
170            resets_at,
171            error_type,
172            message,
173        })
174    }
175
176    /// Human-readable description of when limit resets.
177    pub fn reset_display(&self) -> String {
178        if let Some(secs) = self.retry_after_secs {
179            let hours = secs / 3600;
180            let mins = (secs % 3600) / 60;
181            if hours >= 24 {
182                format!("{}d {}h", hours / 24, hours % 24)
183            } else if hours > 0 {
184                format!("{}h {}m", hours, mins)
185            } else {
186                format!("{}m", mins)
187            }
188        } else {
189            "unknown".into()
190        }
191    }
192
193    /// One-line status for status bar.
194    pub fn status_line(&self) -> String {
195        let mut parts = Vec::new();
196        if let Some(r) = self.requests_remaining {
197            parts.push(format!("req:{}", r));
198        }
199        if let Some(t) = self.tokens_remaining {
200            parts.push(format!("tok:{}", t));
201        }
202        if self.retry_after_secs.is_some() {
203            parts.push(format!("reset:{}", self.reset_display()));
204        }
205        if parts.is_empty() {
206            self.message
207                .clone()
208                .unwrap_or_else(|| "rate limited".into())
209        } else {
210            parts.join(" | ")
211        }
212    }
213}
214
/// LLM provider configuration — single config for any provider.
///
/// Two optional fields control routing:
/// - `api_key`: None → auto from env vars (OPENAI_API_KEY, ANTHROPIC_API_KEY, etc.)
/// - `base_url`: None → auto-detect provider from model name; Some → custom endpoint
///
/// ```no_run
/// use sgr_agent::LlmConfig;
///
/// let c = LlmConfig::auto("gpt-4o");                                          // env vars
/// let c = LlmConfig::with_key("sk-...", "claude-3-haiku");                    // explicit key
/// let c = LlmConfig::endpoint("sk-or-...", "https://openrouter.ai/api/v1", "gpt-4o"); // custom
/// let c = LlmConfig::auto("gpt-4o").temperature(0.9).max_tokens(2048);        // builder
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
    /// Model name, e.g. "gpt-4o".
    pub model: String,
    /// Explicit API key; `None` → resolved from environment variables.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// Custom OpenAI-compatible endpoint; `None` → provider auto-detected from model.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
    /// Sampling temperature (0.7 when absent from serialized config).
    #[serde(default = "default_temperature")]
    pub temp: f64,
    /// Max output tokens; `None` → provider default.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
}
241
/// Serde default for `LlmConfig::temp`.
fn default_temperature() -> f64 {
    const DEFAULT_TEMP: f64 = 0.7;
    DEFAULT_TEMP
}
245
246impl LlmConfig {
247    /// Auto-detect provider from model name, use env vars for auth.
248    pub fn auto(model: impl Into<String>) -> Self {
249        Self {
250            model: model.into(),
251            api_key: None,
252            base_url: None,
253            temp: default_temperature(),
254            max_tokens: None,
255        }
256    }
257
258    /// Explicit API key, auto-detect provider from model name.
259    pub fn with_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
260        Self {
261            model: model.into(),
262            api_key: Some(api_key.into()),
263            base_url: None,
264            temp: default_temperature(),
265            max_tokens: None,
266        }
267    }
268
269    /// Custom OpenAI-compatible endpoint (OpenRouter, Ollama, LiteLLM, etc.).
270    pub fn endpoint(
271        api_key: impl Into<String>,
272        base_url: impl Into<String>,
273        model: impl Into<String>,
274    ) -> Self {
275        Self {
276            model: model.into(),
277            api_key: Some(api_key.into()),
278            base_url: Some(base_url.into()),
279            temp: default_temperature(),
280            max_tokens: None,
281        }
282    }
283
284    /// Set temperature.
285    pub fn temperature(mut self, t: f64) -> Self {
286        self.temp = t;
287        self
288    }
289
290    /// Set max output tokens.
291    pub fn max_tokens(mut self, m: u32) -> Self {
292        self.max_tokens = Some(m);
293        self
294    }
295}
296
/// Legacy provider configuration (used by OpenAIClient/GeminiClient).
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    /// API key or access token (empty string for providers that need none, e.g. Ollama).
    pub api_key: String,
    /// Model name.
    pub model: String,
    /// Custom endpoint; `None` → provider default.
    pub base_url: Option<String>,
    /// GCP project id (Vertex AI only).
    pub project_id: Option<String>,
    /// GCP location (Vertex AI only).
    pub location: Option<String>,
    /// Sampling temperature (constructors default this to 0.3).
    pub temperature: f32,
    /// Max output tokens; `None` → provider default.
    pub max_tokens: Option<u32>,
}
308
309impl ProviderConfig {
310    pub fn gemini(api_key: impl Into<String>, model: impl Into<String>) -> Self {
311        Self {
312            api_key: api_key.into(),
313            model: model.into(),
314            base_url: None,
315            project_id: None,
316            location: None,
317            temperature: 0.3,
318            max_tokens: None,
319        }
320    }
321
322    pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
323        Self {
324            api_key: api_key.into(),
325            model: model.into(),
326            base_url: None,
327            project_id: None,
328            location: None,
329            temperature: 0.3,
330            max_tokens: None,
331        }
332    }
333
334    pub fn openrouter(api_key: impl Into<String>, model: impl Into<String>) -> Self {
335        Self {
336            api_key: api_key.into(),
337            model: model.into(),
338            base_url: Some("https://openrouter.ai/api/v1".into()),
339            project_id: None,
340            location: None,
341            temperature: 0.3,
342            max_tokens: None,
343        }
344    }
345
346    pub fn vertex(
347        access_token: impl Into<String>,
348        project_id: impl Into<String>,
349        model: impl Into<String>,
350    ) -> Self {
351        Self {
352            api_key: access_token.into(),
353            model: model.into(),
354            base_url: None,
355            project_id: Some(project_id.into()),
356            location: Some("global".to_string()),
357            temperature: 0.3,
358            max_tokens: None,
359        }
360    }
361
362    pub fn ollama(model: impl Into<String>) -> Self {
363        Self {
364            api_key: String::new(),
365            model: model.into(),
366            base_url: Some("http://localhost:11434/v1".into()),
367            project_id: None,
368            location: None,
369            temperature: 0.3,
370            max_tokens: None,
371        }
372    }
373}
374
/// Errors from SGR calls.
#[derive(Debug, thiserror::Error)]
pub enum SgrError {
    /// Transport-level failure surfaced by reqwest.
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),
    /// Non-success API response that is not a recognized rate limit.
    #[error("API error {status}: {body}")]
    Api { status: u16, body: String },
    /// Provider rate/usage limit, with whatever reset info could be parsed.
    #[error("Rate limit: {}", info.status_line())]
    RateLimit { status: u16, info: RateLimitInfo },
    /// JSON (de)serialization failure.
    #[error("JSON parse error: {0}")]
    Json(#[from] serde_json::Error),
    /// Structured-output schema problem.
    #[error("Schema error: {0}")]
    Schema(String),
    /// Response contained no usable content.
    #[error("No content in response")]
    EmptyResponse,
}
391
392impl SgrError {
393    /// Build error from HTTP status + body, auto-detecting rate limits.
394    pub fn from_api_response(status: u16, body: String) -> Self {
395        if status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit") {
396            if let Some(mut info) = RateLimitInfo::from_error_body(&body) {
397                if info.message.is_none() {
398                    info.message = Some(body.chars().take(200).collect());
399                }
400                return SgrError::RateLimit { status, info };
401            }
402        }
403        SgrError::Api { status, body }
404    }
405
406    /// Build error from HTTP status + body + headers, auto-detecting rate limits.
407    pub fn from_response_parts(
408        status: u16,
409        body: String,
410        headers: &reqwest::header::HeaderMap,
411    ) -> Self {
412        if status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit") {
413            let mut info = RateLimitInfo::from_error_body(&body)
414                .or_else(|| RateLimitInfo::from_headers(headers))
415                .unwrap_or(RateLimitInfo {
416                    requests_remaining: None,
417                    tokens_remaining: None,
418                    retry_after_secs: None,
419                    resets_at: None,
420                    error_type: Some("rate_limit".into()),
421                    message: Some(body.chars().take(200).collect()),
422                });
423            // Merge header info into body info
424            if let Some(header_info) = RateLimitInfo::from_headers(headers) {
425                if info.requests_remaining.is_none() {
426                    info.requests_remaining = header_info.requests_remaining;
427                }
428                if info.tokens_remaining.is_none() {
429                    info.tokens_remaining = header_info.tokens_remaining;
430                }
431            }
432            return SgrError::RateLimit { status, info };
433        }
434        SgrError::Api { status, body }
435    }
436
437    /// Is this a rate limit error?
438    pub fn is_rate_limit(&self) -> bool {
439        matches!(self, SgrError::RateLimit { .. })
440    }
441
442    /// Get rate limit info if this is a rate limit error.
443    pub fn rate_limit_info(&self) -> Option<&RateLimitInfo> {
444        match self {
445            SgrError::RateLimit { info, .. } => Some(info),
446            _ => None,
447        }
448    }
449}
450
#[cfg(test)]
mod tests {
    use super::*;

    // Codex-style 429: body carries both resets_at (unix ts) and
    // resets_in_seconds; 442787s = 5 days 2 hours.
    #[test]
    fn parse_codex_rate_limit_error() {
        let body = r#"{"error":{"type":"usage_limit_reached","message":"The usage limit has been reached","plan_type":"plus","resets_at":1773534007,"resets_in_seconds":442787}}"#;
        let err = SgrError::from_api_response(429, body.to_string());
        assert!(err.is_rate_limit());
        let info = err.rate_limit_info().unwrap();
        assert_eq!(info.error_type.as_deref(), Some("usage_limit_reached"));
        assert_eq!(info.retry_after_secs, Some(442787));
        assert_eq!(info.resets_at, Some(1773534007));
        assert_eq!(info.reset_display(), "5d 2h");
    }

    // OpenAI-style 429: type/message only, no reset fields.
    #[test]
    fn parse_openai_rate_limit_error() {
        let body =
            r#"{"error":{"type":"rate_limit_exceeded","message":"Rate limit reached for gpt-4"}}"#;
        let err = SgrError::from_api_response(429, body.to_string());
        assert!(err.is_rate_limit());
        let info = err.rate_limit_info().unwrap();
        assert_eq!(info.error_type.as_deref(), Some("rate_limit_exceeded"));
    }

    // A 400 with an unrelated error type must stay SgrError::Api.
    #[test]
    fn non_rate_limit_stays_api_error() {
        let body = r#"{"error":{"type":"invalid_request","message":"Bad request"}}"#;
        let err = SgrError::from_api_response(400, body.to_string());
        assert!(!err.is_rate_limit());
        assert!(matches!(err, SgrError::Api { status: 400, .. }));
    }

    // All numeric fields present → "req | tok | reset" joined line.
    #[test]
    fn status_line_with_all_fields() {
        let info = RateLimitInfo {
            requests_remaining: Some(5),
            tokens_remaining: Some(10000),
            retry_after_secs: Some(3600),
            resets_at: None,
            error_type: None,
            message: None,
        };
        assert_eq!(info.status_line(), "req:5 | tok:10000 | reset:1h 0m");
    }

    // No numeric fields → status line falls back to the provider message.
    #[test]
    fn status_line_fallback_to_message() {
        let info = RateLimitInfo {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs: None,
            resets_at: None,
            error_type: None,
            message: Some("custom message".into()),
        };
        assert_eq!(info.status_line(), "custom message");
    }

    // reset_display picks the coarsest applicable unit pair.
    #[test]
    fn reset_display_formats() {
        let make = |secs| RateLimitInfo {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs: Some(secs),
            resets_at: None,
            error_type: None,
            message: None,
        };
        assert_eq!(make(90).reset_display(), "1m");
        assert_eq!(make(3661).reset_display(), "1h 1m");
        assert_eq!(make(90000).reset_display(), "1d 1h");
    }
}