//! Core types for `sgr_agent`: chat messages, tool calls, LLM/provider
//! configuration, rate-limit info, and error types.

use serde::{Deserialize, Serialize};

/// Inline image data for multimodal messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImagePart {
    /// Base64-encoded image data.
    /// NOTE(review): assumed to be raw base64 without a `data:` URI prefix — confirm at call sites.
    pub data: String,
    /// MIME type (e.g. "image/jpeg", "image/png").
    pub mime_type: String,
}

/// A chat message in the conversation history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who produced this message.
    pub role: Role,
    /// Text content of the message.
    pub content: String,
    /// ID of the tool call this message answers (only for Role::Tool).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Tool calls made by the assistant (only for Role::Assistant with function calling).
    /// Gemini API requires model turns to include functionCall parts.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Inline images (for multimodal VLM input).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub images: Vec<ImagePart>,
}

/// Speaker role; serialized lowercase ("system", "user", "assistant", "tool").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
}

38impl Message {
39    pub fn system(content: impl Into<String>) -> Self {
40        Self {
41            role: Role::System,
42            content: content.into(),
43            tool_call_id: None,
44            tool_calls: vec![],
45            images: vec![],
46        }
47    }
48    pub fn user(content: impl Into<String>) -> Self {
49        Self {
50            role: Role::User,
51            content: content.into(),
52            tool_call_id: None,
53            tool_calls: vec![],
54            images: vec![],
55        }
56    }
57    pub fn assistant(content: impl Into<String>) -> Self {
58        Self {
59            role: Role::Assistant,
60            content: content.into(),
61            tool_call_id: None,
62            tool_calls: vec![],
63            images: vec![],
64        }
65    }
66    /// Create an assistant message that includes function calls (for Gemini FC protocol).
67    pub fn assistant_with_tool_calls(
68        content: impl Into<String>,
69        tool_calls: Vec<ToolCall>,
70    ) -> Self {
71        Self {
72            role: Role::Assistant,
73            content: content.into(),
74            tool_call_id: None,
75            tool_calls,
76            images: vec![],
77        }
78    }
79    pub fn tool(call_id: impl Into<String>, content: impl Into<String>) -> Self {
80        Self {
81            role: Role::Tool,
82            content: content.into(),
83            tool_call_id: Some(call_id.into()),
84            tool_calls: vec![],
85            images: vec![],
86        }
87    }
88    /// Tool result with inline images (for VLM — Gemini sees the images).
89    pub fn tool_with_images(
90        call_id: impl Into<String>,
91        content: impl Into<String>,
92        images: Vec<ImagePart>,
93    ) -> Self {
94        Self {
95            role: Role::Tool,
96            content: content.into(),
97            tool_call_id: Some(call_id.into()),
98            tool_calls: vec![],
99            images,
100        }
101    }
102    /// User message with inline images.
103    pub fn user_with_images(content: impl Into<String>, images: Vec<ImagePart>) -> Self {
104        Self {
105            role: Role::User,
106            content: content.into(),
107            tool_call_id: None,
108            tool_calls: vec![],
109            images,
110        }
111    }
112}
113
/// A tool call returned by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Unique ID for matching with tool results.
    pub id: String,
    /// Tool/function name.
    pub name: String,
    /// Call arguments as a parsed JSON value.
    pub arguments: serde_json::Value,
}

/// Response from an SGR call — structured output + optional tool calls.
#[derive(Debug, Clone)]
pub struct SgrResponse<T> {
    /// Parsed structured output (SGR envelope).
    /// `None` if the model only returned tool calls without structured content.
    pub output: Option<T>,
    /// Tool calls the model wants to execute.
    pub tool_calls: Vec<ToolCall>,
    /// Raw text (for streaming / debugging).
    pub raw_text: String,
    /// Token usage, if the provider reported it.
    pub usage: Option<Usage>,
    /// Rate limit info from response headers (if provider sends them).
    pub rate_limit: Option<RateLimitInfo>,
}

/// Token usage reported by the provider.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the prompt/input.
    pub prompt_tokens: u32,
    /// Tokens produced in the completion/output.
    pub completion_tokens: u32,
    /// Prompt + completion total.
    pub total_tokens: u32,
}

/// Rate limit info extracted from response headers and/or error body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitInfo {
    /// Requests remaining in current window.
    pub requests_remaining: Option<u32>,
    /// Tokens remaining in current window.
    pub tokens_remaining: Option<u32>,
    /// Seconds until limit resets.
    pub retry_after_secs: Option<u64>,
    /// Unix timestamp when limit resets.
    /// NOTE(review): `from_headers` fills this from `x-ratelimit-reset-tokens`,
    /// which some providers send as a duration rather than a timestamp — confirm per provider.
    pub resets_at: Option<u64>,
    /// Provider error type (e.g. "usage_limit_reached", "rate_limit_exceeded").
    pub error_type: Option<String>,
    /// Human-readable message from provider.
    pub message: Option<String>,
}

165impl RateLimitInfo {
166    /// Parse from HTTP response headers (OpenAI/Gemini/OpenRouter standard).
167    pub fn from_headers(headers: &reqwest::header::HeaderMap) -> Option<Self> {
168        let get_u32 =
169            |name: &str| -> Option<u32> { headers.get(name)?.to_str().ok()?.parse().ok() };
170        let get_u64 =
171            |name: &str| -> Option<u64> { headers.get(name)?.to_str().ok()?.parse().ok() };
172
173        let requests_remaining = get_u32("x-ratelimit-remaining-requests");
174        let tokens_remaining = get_u32("x-ratelimit-remaining-tokens");
175        let retry_after_secs =
176            get_u64("retry-after").or_else(|| get_u64("x-ratelimit-reset-requests"));
177        let resets_at = get_u64("x-ratelimit-reset-tokens");
178
179        if requests_remaining.is_some() || tokens_remaining.is_some() || retry_after_secs.is_some()
180        {
181            Some(Self {
182                requests_remaining,
183                tokens_remaining,
184                retry_after_secs,
185                resets_at,
186                error_type: None,
187                message: None,
188            })
189        } else {
190            None
191        }
192    }
193
194    /// Parse from JSON error body (OpenAI, Codex, Gemini error responses).
195    pub fn from_error_body(body: &str) -> Option<Self> {
196        let json: serde_json::Value = serde_json::from_str(body).ok()?;
197        let err = json.get("error")?;
198
199        let error_type = err.get("type").and_then(|v| v.as_str()).map(String::from);
200        let message = err
201            .get("message")
202            .and_then(|v| v.as_str())
203            .map(String::from);
204        let resets_at = err.get("resets_at").and_then(|v| v.as_u64());
205        let retry_after_secs = err.get("resets_in_seconds").and_then(|v| v.as_u64());
206
207        Some(Self {
208            requests_remaining: None,
209            tokens_remaining: None,
210            retry_after_secs,
211            resets_at,
212            error_type,
213            message,
214        })
215    }
216
217    /// Human-readable description of when limit resets.
218    pub fn reset_display(&self) -> String {
219        if let Some(secs) = self.retry_after_secs {
220            let hours = secs / 3600;
221            let mins = (secs % 3600) / 60;
222            if hours >= 24 {
223                format!("{}d {}h", hours / 24, hours % 24)
224            } else if hours > 0 {
225                format!("{}h {}m", hours, mins)
226            } else {
227                format!("{}m", mins)
228            }
229        } else {
230            "unknown".into()
231        }
232    }
233
234    /// One-line status for status bar.
235    pub fn status_line(&self) -> String {
236        let mut parts = Vec::new();
237        if let Some(r) = self.requests_remaining {
238            parts.push(format!("req:{}", r));
239        }
240        if let Some(t) = self.tokens_remaining {
241            parts.push(format!("tok:{}", t));
242        }
243        if self.retry_after_secs.is_some() {
244            parts.push(format!("reset:{}", self.reset_display()));
245        }
246        if parts.is_empty() {
247            self.message
248                .clone()
249                .unwrap_or_else(|| "rate limited".into())
250        } else {
251            parts.join(" | ")
252        }
253    }
254}
255
/// LLM provider configuration — single config for any provider.
///
/// Two optional fields control routing:
/// - `api_key`: None → auto from env vars (OPENAI_API_KEY, ANTHROPIC_API_KEY, etc.)
/// - `base_url`: None → auto-detect provider from model name; Some → custom endpoint
///
/// ```no_run
/// use sgr_agent::LlmConfig;
///
/// let c = LlmConfig::auto("gpt-4o");                                          // env vars
/// let c = LlmConfig::with_key("sk-...", "claude-3-haiku");                    // explicit key
/// let c = LlmConfig::endpoint("sk-or-...", "https://openrouter.ai/api/v1", "gpt-4o"); // custom
/// let c = LlmConfig::auto("gpt-4o").temperature(0.9).max_tokens(2048);        // builder
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
    /// Model name; also drives provider auto-detection (see doc above).
    pub model: String,
    /// Explicit API key; None → resolved from provider env vars.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// Custom endpoint; None → provider inferred from model name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
    /// Sampling temperature (serde default: 0.7).
    #[serde(default = "default_temperature")]
    pub temp: f64,
    /// Max output tokens; None → provider default.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// OpenAI prompt cache key — caches system prompt prefix server-side.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    /// Vertex AI project ID (enables Vertex routing when set).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub project_id: Option<String>,
    /// Vertex AI location (default: "global").
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub location: Option<String>,
    /// Force Chat Completions API instead of Responses API.
    /// Needed for OpenAI-compatible endpoints that don't support /responses
    /// (e.g. Cloudflare AI Gateway compat, OpenRouter, local models).
    #[serde(default)]
    pub use_chat_api: bool,
}

/// Serde default for `LlmConfig::temp`.
fn default_temperature() -> f64 {
    0.7
}

301impl Default for LlmConfig {
302    fn default() -> Self {
303        Self {
304            model: String::new(),
305            api_key: None,
306            base_url: None,
307            temp: default_temperature(),
308            max_tokens: None,
309            prompt_cache_key: None,
310            project_id: None,
311            location: None,
312            use_chat_api: false,
313        }
314    }
315}
316
317impl LlmConfig {
318    /// Auto-detect provider from model name, use env vars for auth.
319    pub fn auto(model: impl Into<String>) -> Self {
320        Self {
321            model: model.into(),
322            ..Default::default()
323        }
324    }
325
326    /// Explicit API key, auto-detect provider from model name.
327    pub fn with_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
328        Self {
329            model: model.into(),
330            api_key: Some(api_key.into()),
331            ..Default::default()
332        }
333    }
334
335    /// Custom OpenAI-compatible endpoint (OpenRouter, Ollama, LiteLLM, etc.).
336    pub fn endpoint(
337        api_key: impl Into<String>,
338        base_url: impl Into<String>,
339        model: impl Into<String>,
340    ) -> Self {
341        Self {
342            model: model.into(),
343            api_key: Some(api_key.into()),
344            base_url: Some(base_url.into()),
345            ..Default::default()
346        }
347    }
348
349    /// Vertex AI — uses gcloud ADC for auth (no API key needed).
350    pub fn vertex(project_id: impl Into<String>, model: impl Into<String>) -> Self {
351        Self {
352            model: model.into(),
353            project_id: Some(project_id.into()),
354            location: Some("global".into()),
355            ..Default::default()
356        }
357    }
358
359    /// Set Vertex AI location.
360    pub fn location(mut self, loc: impl Into<String>) -> Self {
361        self.location = Some(loc.into());
362        self
363    }
364
365    /// Set temperature.
366    pub fn temperature(mut self, t: f64) -> Self {
367        self.temp = t;
368        self
369    }
370
371    /// Set max output tokens.
372    pub fn max_tokens(mut self, m: u32) -> Self {
373        self.max_tokens = Some(m);
374        self
375    }
376
377    /// Set OpenAI prompt cache key for server-side system prompt caching.
378    pub fn prompt_cache_key(mut self, key: impl Into<String>) -> Self {
379        self.prompt_cache_key = Some(key.into());
380        self
381    }
382
383    /// Human-readable label for display.
384    pub fn label(&self) -> String {
385        if self.project_id.is_some() {
386            format!("Vertex ({})", self.model)
387        } else if self.base_url.is_some() {
388            format!("Custom ({})", self.model)
389        } else {
390            self.model.clone()
391        }
392    }
393
394    /// Infer a cheap/fast model for compaction based on the primary model.
395    pub fn compaction_model(&self) -> String {
396        if self.model.starts_with("gemini") {
397            "gemini-2.0-flash-lite".into()
398        } else if self.model.starts_with("gpt") {
399            "gpt-4o-mini".into()
400        } else if self.model.starts_with("claude") {
401            "claude-3-haiku-20240307".into()
402        } else {
403            // Unknown provider — use the same model
404            self.model.clone()
405        }
406    }
407
408    /// Create a compaction config — cheap model, low max_tokens.
409    pub fn for_compaction(&self) -> Self {
410        let mut cfg = self.clone();
411        cfg.model = self.compaction_model();
412        cfg.max_tokens = Some(2048);
413        cfg
414    }
415}
416
/// Legacy provider configuration (used by OpenAIClient/GeminiClient).
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    /// API key or OAuth access token (empty string for local Ollama).
    pub api_key: String,
    /// Model identifier.
    pub model: String,
    /// Custom endpoint base URL; None → provider default.
    pub base_url: Option<String>,
    /// Vertex AI project ID.
    pub project_id: Option<String>,
    /// Vertex AI location.
    pub location: Option<String>,
    /// Sampling temperature.
    pub temperature: f32,
    /// Max output tokens; None → provider default.
    pub max_tokens: Option<u32>,
}

429impl ProviderConfig {
430    pub fn gemini(api_key: impl Into<String>, model: impl Into<String>) -> Self {
431        Self {
432            api_key: api_key.into(),
433            model: model.into(),
434            base_url: None,
435            project_id: None,
436            location: None,
437            temperature: 0.3,
438            max_tokens: None,
439        }
440    }
441
442    pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
443        Self {
444            api_key: api_key.into(),
445            model: model.into(),
446            base_url: None,
447            project_id: None,
448            location: None,
449            temperature: 0.3,
450            max_tokens: None,
451        }
452    }
453
454    pub fn openrouter(api_key: impl Into<String>, model: impl Into<String>) -> Self {
455        Self {
456            api_key: api_key.into(),
457            model: model.into(),
458            base_url: Some("https://openrouter.ai/api/v1".into()),
459            project_id: None,
460            location: None,
461            temperature: 0.3,
462            max_tokens: None,
463        }
464    }
465
466    pub fn vertex(
467        access_token: impl Into<String>,
468        project_id: impl Into<String>,
469        model: impl Into<String>,
470    ) -> Self {
471        Self {
472            api_key: access_token.into(),
473            model: model.into(),
474            base_url: None,
475            project_id: Some(project_id.into()),
476            location: Some("global".to_string()),
477            temperature: 0.3,
478            max_tokens: None,
479        }
480    }
481
482    pub fn ollama(model: impl Into<String>) -> Self {
483        Self {
484            api_key: String::new(),
485            model: model.into(),
486            base_url: Some("http://localhost:11434/v1".into()),
487            project_id: None,
488            location: None,
489            temperature: 0.3,
490            max_tokens: None,
491        }
492    }
493}
494
/// Errors from SGR calls.
#[derive(Debug, thiserror::Error)]
pub enum SgrError {
    /// Transport-level failure from reqwest.
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),
    /// Non-rate-limit API error, carrying the raw response body.
    #[error("API error {status}: {body}")]
    Api { status: u16, body: String },
    /// Rate limit hit; `info` carries parsed remaining/reset data.
    #[error("Rate limit: {}", info.status_line())]
    RateLimit { status: u16, info: RateLimitInfo },
    /// JSON (de)serialization failure.
    #[error("JSON parse error: {0}")]
    Json(#[from] serde_json::Error),
    /// Structured-output schema problem.
    #[error("Schema error: {0}")]
    Schema(String),
    /// The model returned no content at all.
    #[error("No content in response")]
    EmptyResponse,
}

512impl SgrError {
513    /// Build error from HTTP status + body, auto-detecting rate limits.
514    pub fn from_api_response(status: u16, body: String) -> Self {
515        if (status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit"))
516            && let Some(mut info) = RateLimitInfo::from_error_body(&body)
517        {
518            if info.message.is_none() {
519                info.message = Some(body.chars().take(200).collect());
520            }
521            return SgrError::RateLimit { status, info };
522        }
523        SgrError::Api { status, body }
524    }
525
526    /// Build error from HTTP status + body + headers, auto-detecting rate limits.
527    pub fn from_response_parts(
528        status: u16,
529        body: String,
530        headers: &reqwest::header::HeaderMap,
531    ) -> Self {
532        if status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit") {
533            let mut info = RateLimitInfo::from_error_body(&body)
534                .or_else(|| RateLimitInfo::from_headers(headers))
535                .unwrap_or(RateLimitInfo {
536                    requests_remaining: None,
537                    tokens_remaining: None,
538                    retry_after_secs: None,
539                    resets_at: None,
540                    error_type: Some("rate_limit".into()),
541                    message: Some(body.chars().take(200).collect()),
542                });
543            // Merge header info into body info
544            if let Some(header_info) = RateLimitInfo::from_headers(headers) {
545                if info.requests_remaining.is_none() {
546                    info.requests_remaining = header_info.requests_remaining;
547                }
548                if info.tokens_remaining.is_none() {
549                    info.tokens_remaining = header_info.tokens_remaining;
550                }
551            }
552            return SgrError::RateLimit { status, info };
553        }
554        SgrError::Api { status, body }
555    }
556
557    /// Is this a rate limit error?
558    pub fn is_rate_limit(&self) -> bool {
559        matches!(self, SgrError::RateLimit { .. })
560    }
561
562    /// Get rate limit info if this is a rate limit error.
563    pub fn rate_limit_info(&self) -> Option<&RateLimitInfo> {
564        match self {
565            SgrError::RateLimit { info, .. } => Some(info),
566            _ => None,
567        }
568    }
569}
570
#[cfg(test)]
mod tests {
    use super::*;

    // Codex-style body: carries both an absolute reset timestamp and a
    // relative seconds-until-reset field.
    #[test]
    fn parse_codex_rate_limit_error() {
        let body = r#"{"error":{"type":"usage_limit_reached","message":"The usage limit has been reached","plan_type":"plus","resets_at":1773534007,"resets_in_seconds":442787}}"#;
        let err = SgrError::from_api_response(429, body.to_string());
        assert!(err.is_rate_limit());
        let info = err.rate_limit_info().unwrap();
        assert_eq!(info.error_type.as_deref(), Some("usage_limit_reached"));
        assert_eq!(info.retry_after_secs, Some(442787));
        assert_eq!(info.resets_at, Some(1773534007));
        // 442787 s = 122 whole hours -> 5 days 2 hours.
        assert_eq!(info.reset_display(), "5d 2h");
    }

    // OpenAI-style body: type + message only, no reset timing.
    #[test]
    fn parse_openai_rate_limit_error() {
        let body =
            r#"{"error":{"type":"rate_limit_exceeded","message":"Rate limit reached for gpt-4"}}"#;
        let err = SgrError::from_api_response(429, body.to_string());
        assert!(err.is_rate_limit());
        let info = err.rate_limit_info().unwrap();
        assert_eq!(info.error_type.as_deref(), Some("rate_limit_exceeded"));
    }

    // A non-429 status with a non-rate-limit body stays a plain Api error.
    #[test]
    fn non_rate_limit_stays_api_error() {
        let body = r#"{"error":{"type":"invalid_request","message":"Bad request"}}"#;
        let err = SgrError::from_api_response(400, body.to_string());
        assert!(!err.is_rate_limit());
        assert!(matches!(err, SgrError::Api { status: 400, .. }));
    }

    // All three numeric fields present -> pipe-joined summary.
    #[test]
    fn status_line_with_all_fields() {
        let info = RateLimitInfo {
            requests_remaining: Some(5),
            tokens_remaining: Some(10000),
            retry_after_secs: Some(3600),
            resets_at: None,
            error_type: None,
            message: None,
        };
        assert_eq!(info.status_line(), "req:5 | tok:10000 | reset:1h 0m");
    }

    // No numeric fields -> status_line falls back to the provider message.
    #[test]
    fn status_line_fallback_to_message() {
        let info = RateLimitInfo {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs: None,
            resets_at: None,
            error_type: None,
            message: Some("custom message".into()),
        };
        assert_eq!(info.status_line(), "custom message");
    }

    // Minutes-only, hours+minutes, and days+hours formatting branches.
    #[test]
    fn reset_display_formats() {
        let make = |secs| RateLimitInfo {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs: Some(secs),
            resets_at: None,
            error_type: None,
            message: None,
        };
        assert_eq!(make(90).reset_display(), "1m");
        assert_eq!(make(3661).reset_display(), "1h 1m");
        assert_eq!(make(90000).reset_display(), "1d 1h");
    }
645}