//! Core types for the SGR agent: chat messages, tool calls, structured
//! responses, LLM/provider configuration, rate-limit parsing, and errors.
1use serde::{Deserialize, Serialize};
2
/// Inline image data for multimodal messages.
///
/// Carried in `Message::images`; serialized via serde for provider payloads.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImagePart {
    /// Base64-encoded image data.
    pub data: String,
    /// MIME type (e.g. "image/jpeg", "image/png").
    pub mime_type: String,
}
11
/// A chat message in the conversation history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who produced this message.
    pub role: Role,
    /// Text content of the message.
    pub content: String,
    /// ID linking a tool result back to its originating call (only for Role::Tool).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Tool calls made by the assistant (only for Role::Assistant with function calling).
    /// Gemini API requires model turns to include functionCall parts.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Inline images (for multimodal VLM input); omitted from JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub images: Vec<ImagePart>,
}
28
/// Speaker role of a `Message`; serialized lowercase ("system", "user", …).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
}
37
38impl Message {
39    pub fn system(content: impl Into<String>) -> Self {
40        Self {
41            role: Role::System,
42            content: content.into(),
43            tool_call_id: None,
44            tool_calls: vec![],
45            images: vec![],
46        }
47    }
48    pub fn user(content: impl Into<String>) -> Self {
49        Self {
50            role: Role::User,
51            content: content.into(),
52            tool_call_id: None,
53            tool_calls: vec![],
54            images: vec![],
55        }
56    }
57    pub fn assistant(content: impl Into<String>) -> Self {
58        Self {
59            role: Role::Assistant,
60            content: content.into(),
61            tool_call_id: None,
62            tool_calls: vec![],
63            images: vec![],
64        }
65    }
66    /// Create an assistant message that includes function calls (for Gemini FC protocol).
67    pub fn assistant_with_tool_calls(
68        content: impl Into<String>,
69        tool_calls: Vec<ToolCall>,
70    ) -> Self {
71        Self {
72            role: Role::Assistant,
73            content: content.into(),
74            tool_call_id: None,
75            tool_calls,
76            images: vec![],
77        }
78    }
79    pub fn tool(call_id: impl Into<String>, content: impl Into<String>) -> Self {
80        Self {
81            role: Role::Tool,
82            content: content.into(),
83            tool_call_id: Some(call_id.into()),
84            tool_calls: vec![],
85            images: vec![],
86        }
87    }
88    /// Tool result with inline images (for VLM — Gemini sees the images).
89    pub fn tool_with_images(
90        call_id: impl Into<String>,
91        content: impl Into<String>,
92        images: Vec<ImagePart>,
93    ) -> Self {
94        Self {
95            role: Role::Tool,
96            content: content.into(),
97            tool_call_id: Some(call_id.into()),
98            tool_calls: vec![],
99            images,
100        }
101    }
102    /// User message with inline images.
103    pub fn user_with_images(content: impl Into<String>, images: Vec<ImagePart>) -> Self {
104        Self {
105            role: Role::User,
106            content: content.into(),
107            tool_call_id: None,
108            tool_calls: vec![],
109            images,
110        }
111    }
112}
113
/// A tool call returned by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Unique ID for matching with tool results (see `Message::tool_call_id`).
    pub id: String,
    /// Tool/function name.
    pub name: String,
    /// Arguments as a JSON value.
    pub arguments: serde_json::Value,
}
124
/// Response from an SGR call — structured output + optional tool calls.
#[derive(Debug, Clone)]
pub struct SgrResponse<T> {
    /// Parsed structured output (SGR envelope).
    /// `None` if the model only returned tool calls without structured content.
    pub output: Option<T>,
    /// Tool calls the model wants to execute.
    pub tool_calls: Vec<ToolCall>,
    /// Raw response text (for streaming / debugging).
    pub raw_text: String,
    /// Token usage, when the provider reports it.
    pub usage: Option<Usage>,
    /// Rate limit info from response headers (if provider sends them).
    pub rate_limit: Option<RateLimitInfo>,
}
140
/// Token usage counters as reported by the provider.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the input/prompt.
    pub prompt_tokens: u32,
    /// Tokens generated in the output/completion.
    pub completion_tokens: u32,
    /// Total tokens (prompt + completion, as reported).
    pub total_tokens: u32,
}
147
/// Rate limit info extracted from response headers and/or error body.
///
/// All fields are optional because different providers expose different
/// subsets; see `from_headers` / `from_error_body` for the sources.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitInfo {
    /// Requests remaining in current window.
    pub requests_remaining: Option<u32>,
    /// Tokens remaining in current window.
    pub tokens_remaining: Option<u32>,
    /// Seconds until limit resets.
    pub retry_after_secs: Option<u64>,
    /// Unix timestamp when limit resets.
    pub resets_at: Option<u64>,
    /// Provider error type (e.g. "usage_limit_reached", "rate_limit_exceeded").
    pub error_type: Option<String>,
    /// Human-readable message from provider.
    pub message: Option<String>,
}
164
165impl RateLimitInfo {
166    /// Parse from HTTP response headers (OpenAI/Gemini/OpenRouter standard).
167    pub fn from_headers(headers: &reqwest::header::HeaderMap) -> Option<Self> {
168        let get_u32 =
169            |name: &str| -> Option<u32> { headers.get(name)?.to_str().ok()?.parse().ok() };
170        let get_u64 =
171            |name: &str| -> Option<u64> { headers.get(name)?.to_str().ok()?.parse().ok() };
172
173        let requests_remaining = get_u32("x-ratelimit-remaining-requests");
174        let tokens_remaining = get_u32("x-ratelimit-remaining-tokens");
175        let retry_after_secs =
176            get_u64("retry-after").or_else(|| get_u64("x-ratelimit-reset-requests"));
177        let resets_at = get_u64("x-ratelimit-reset-tokens");
178
179        if requests_remaining.is_some() || tokens_remaining.is_some() || retry_after_secs.is_some()
180        {
181            Some(Self {
182                requests_remaining,
183                tokens_remaining,
184                retry_after_secs,
185                resets_at,
186                error_type: None,
187                message: None,
188            })
189        } else {
190            None
191        }
192    }
193
194    /// Parse from JSON error body (OpenAI, Codex, Gemini error responses).
195    pub fn from_error_body(body: &str) -> Option<Self> {
196        let json: serde_json::Value = serde_json::from_str(body).ok()?;
197        let err = json.get("error")?;
198
199        let error_type = err.get("type").and_then(|v| v.as_str()).map(String::from);
200        let message = err
201            .get("message")
202            .and_then(|v| v.as_str())
203            .map(String::from);
204        let resets_at = err.get("resets_at").and_then(|v| v.as_u64());
205        let retry_after_secs = err.get("resets_in_seconds").and_then(|v| v.as_u64());
206
207        Some(Self {
208            requests_remaining: None,
209            tokens_remaining: None,
210            retry_after_secs,
211            resets_at,
212            error_type,
213            message,
214        })
215    }
216
217    /// Human-readable description of when limit resets.
218    pub fn reset_display(&self) -> String {
219        if let Some(secs) = self.retry_after_secs {
220            let hours = secs / 3600;
221            let mins = (secs % 3600) / 60;
222            if hours >= 24 {
223                format!("{}d {}h", hours / 24, hours % 24)
224            } else if hours > 0 {
225                format!("{}h {}m", hours, mins)
226            } else {
227                format!("{}m", mins)
228            }
229        } else {
230            "unknown".into()
231        }
232    }
233
234    /// One-line status for status bar.
235    pub fn status_line(&self) -> String {
236        let mut parts = Vec::new();
237        if let Some(r) = self.requests_remaining {
238            parts.push(format!("req:{}", r));
239        }
240        if let Some(t) = self.tokens_remaining {
241            parts.push(format!("tok:{}", t));
242        }
243        if self.retry_after_secs.is_some() {
244            parts.push(format!("reset:{}", self.reset_display()));
245        }
246        if parts.is_empty() {
247            self.message
248                .clone()
249                .unwrap_or_else(|| "rate limited".into())
250        } else {
251            parts.join(" | ")
252        }
253    }
254}
255
/// LLM provider configuration — single config for any provider.
///
/// Two optional fields control routing:
/// - `api_key`: None → auto from env vars (OPENAI_API_KEY, ANTHROPIC_API_KEY, etc.)
/// - `base_url`: None → auto-detect provider from model name; Some → custom endpoint
///
/// ```no_run
/// use sgr_agent::LlmConfig;
///
/// let c = LlmConfig::auto("gpt-4o");                                          // env vars
/// let c = LlmConfig::with_key("sk-...", "claude-3-haiku");                    // explicit key
/// let c = LlmConfig::endpoint("sk-or-...", "https://openrouter.ai/api/v1", "gpt-4o"); // custom
/// let c = LlmConfig::auto("gpt-4o").temperature(0.9).max_tokens(2048);        // builder
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
    /// Model name (also used to auto-detect the provider).
    pub model: String,
    /// Explicit API key; `None` → resolved from environment variables.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// Custom endpoint; `None` → provider inferred from `model`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
    /// Sampling temperature (serde default: 0.7).
    #[serde(default = "default_temperature")]
    pub temp: f64,
    /// Max output tokens; `None` → provider default.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// OpenAI prompt cache key — caches system prompt prefix server-side.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    /// Vertex AI project ID (enables Vertex routing when set).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub project_id: Option<String>,
    /// Vertex AI location (default: "global").
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub location: Option<String>,
}
291
/// Serde default for `LlmConfig::temp`.
fn default_temperature() -> f64 {
    0.7
}
295
impl Default for LlmConfig {
    /// Empty model and credentials; temperature matches the serde default (0.7).
    fn default() -> Self {
        Self {
            model: String::new(),
            api_key: None,
            base_url: None,
            temp: default_temperature(),
            max_tokens: None,
            prompt_cache_key: None,
            project_id: None,
            location: None,
        }
    }
}
310
311impl LlmConfig {
312    /// Auto-detect provider from model name, use env vars for auth.
313    pub fn auto(model: impl Into<String>) -> Self {
314        Self {
315            model: model.into(),
316            ..Default::default()
317        }
318    }
319
320    /// Explicit API key, auto-detect provider from model name.
321    pub fn with_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
322        Self {
323            model: model.into(),
324            api_key: Some(api_key.into()),
325            ..Default::default()
326        }
327    }
328
329    /// Custom OpenAI-compatible endpoint (OpenRouter, Ollama, LiteLLM, etc.).
330    pub fn endpoint(
331        api_key: impl Into<String>,
332        base_url: impl Into<String>,
333        model: impl Into<String>,
334    ) -> Self {
335        Self {
336            model: model.into(),
337            api_key: Some(api_key.into()),
338            base_url: Some(base_url.into()),
339            ..Default::default()
340        }
341    }
342
343    /// Vertex AI — uses gcloud ADC for auth (no API key needed).
344    pub fn vertex(project_id: impl Into<String>, model: impl Into<String>) -> Self {
345        Self {
346            model: model.into(),
347            project_id: Some(project_id.into()),
348            location: Some("global".into()),
349            ..Default::default()
350        }
351    }
352
353    /// Set Vertex AI location.
354    pub fn location(mut self, loc: impl Into<String>) -> Self {
355        self.location = Some(loc.into());
356        self
357    }
358
359    /// Set temperature.
360    pub fn temperature(mut self, t: f64) -> Self {
361        self.temp = t;
362        self
363    }
364
365    /// Set max output tokens.
366    pub fn max_tokens(mut self, m: u32) -> Self {
367        self.max_tokens = Some(m);
368        self
369    }
370
371    /// Set OpenAI prompt cache key for server-side system prompt caching.
372    pub fn prompt_cache_key(mut self, key: impl Into<String>) -> Self {
373        self.prompt_cache_key = Some(key.into());
374        self
375    }
376
377    /// Human-readable label for display.
378    pub fn label(&self) -> String {
379        if self.project_id.is_some() {
380            format!("Vertex ({})", self.model)
381        } else if self.base_url.is_some() {
382            format!("Custom ({})", self.model)
383        } else {
384            self.model.clone()
385        }
386    }
387
388    /// Infer a cheap/fast model for compaction based on the primary model.
389    pub fn compaction_model(&self) -> String {
390        if self.model.starts_with("gemini") {
391            "gemini-2.0-flash-lite".into()
392        } else if self.model.starts_with("gpt") {
393            "gpt-4o-mini".into()
394        } else if self.model.starts_with("claude") {
395            "claude-3-haiku-20240307".into()
396        } else {
397            // Unknown provider — use the same model
398            self.model.clone()
399        }
400    }
401
402    /// Create a compaction config — cheap model, low max_tokens.
403    pub fn for_compaction(&self) -> Self {
404        let mut cfg = self.clone();
405        cfg.model = self.compaction_model();
406        cfg.max_tokens = Some(2048);
407        cfg
408    }
409}
410
/// Legacy provider configuration (used by OpenAIClient/GeminiClient).
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    /// API key or (for Vertex) an access token.
    pub api_key: String,
    /// Model name.
    pub model: String,
    /// Custom endpoint; `None` → provider default.
    pub base_url: Option<String>,
    /// Vertex AI project ID.
    pub project_id: Option<String>,
    /// Vertex AI location.
    pub location: Option<String>,
    /// Sampling temperature.
    pub temperature: f32,
    /// Max output tokens; `None` → provider default.
    pub max_tokens: Option<u32>,
}
422
423impl ProviderConfig {
424    pub fn gemini(api_key: impl Into<String>, model: impl Into<String>) -> Self {
425        Self {
426            api_key: api_key.into(),
427            model: model.into(),
428            base_url: None,
429            project_id: None,
430            location: None,
431            temperature: 0.3,
432            max_tokens: None,
433        }
434    }
435
436    pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
437        Self {
438            api_key: api_key.into(),
439            model: model.into(),
440            base_url: None,
441            project_id: None,
442            location: None,
443            temperature: 0.3,
444            max_tokens: None,
445        }
446    }
447
448    pub fn openrouter(api_key: impl Into<String>, model: impl Into<String>) -> Self {
449        Self {
450            api_key: api_key.into(),
451            model: model.into(),
452            base_url: Some("https://openrouter.ai/api/v1".into()),
453            project_id: None,
454            location: None,
455            temperature: 0.3,
456            max_tokens: None,
457        }
458    }
459
460    pub fn vertex(
461        access_token: impl Into<String>,
462        project_id: impl Into<String>,
463        model: impl Into<String>,
464    ) -> Self {
465        Self {
466            api_key: access_token.into(),
467            model: model.into(),
468            base_url: None,
469            project_id: Some(project_id.into()),
470            location: Some("global".to_string()),
471            temperature: 0.3,
472            max_tokens: None,
473        }
474    }
475
476    pub fn ollama(model: impl Into<String>) -> Self {
477        Self {
478            api_key: String::new(),
479            model: model.into(),
480            base_url: Some("http://localhost:11434/v1".into()),
481            project_id: None,
482            location: None,
483            temperature: 0.3,
484            max_tokens: None,
485        }
486    }
487}
488
/// Errors from SGR calls.
#[derive(Debug, thiserror::Error)]
pub enum SgrError {
    /// Transport-level failure from reqwest.
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),
    /// Non-rate-limit API error (HTTP status + raw body).
    #[error("API error {status}: {body}")]
    Api { status: u16, body: String },
    /// Rate-limit error with parsed limit details.
    #[error("Rate limit: {}", info.status_line())]
    RateLimit { status: u16, info: RateLimitInfo },
    /// JSON (de)serialization failure.
    #[error("JSON parse error: {0}")]
    Json(#[from] serde_json::Error),
    /// Structured-output schema problem.
    #[error("Schema error: {0}")]
    Schema(String),
    /// Provider returned a response with no content.
    #[error("No content in response")]
    EmptyResponse,
}
505
506impl SgrError {
507    /// Build error from HTTP status + body, auto-detecting rate limits.
508    pub fn from_api_response(status: u16, body: String) -> Self {
509        if (status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit"))
510            && let Some(mut info) = RateLimitInfo::from_error_body(&body)
511        {
512            if info.message.is_none() {
513                info.message = Some(body.chars().take(200).collect());
514            }
515            return SgrError::RateLimit { status, info };
516        }
517        SgrError::Api { status, body }
518    }
519
520    /// Build error from HTTP status + body + headers, auto-detecting rate limits.
521    pub fn from_response_parts(
522        status: u16,
523        body: String,
524        headers: &reqwest::header::HeaderMap,
525    ) -> Self {
526        if status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit") {
527            let mut info = RateLimitInfo::from_error_body(&body)
528                .or_else(|| RateLimitInfo::from_headers(headers))
529                .unwrap_or(RateLimitInfo {
530                    requests_remaining: None,
531                    tokens_remaining: None,
532                    retry_after_secs: None,
533                    resets_at: None,
534                    error_type: Some("rate_limit".into()),
535                    message: Some(body.chars().take(200).collect()),
536                });
537            // Merge header info into body info
538            if let Some(header_info) = RateLimitInfo::from_headers(headers) {
539                if info.requests_remaining.is_none() {
540                    info.requests_remaining = header_info.requests_remaining;
541                }
542                if info.tokens_remaining.is_none() {
543                    info.tokens_remaining = header_info.tokens_remaining;
544                }
545            }
546            return SgrError::RateLimit { status, info };
547        }
548        SgrError::Api { status, body }
549    }
550
551    /// Is this a rate limit error?
552    pub fn is_rate_limit(&self) -> bool {
553        matches!(self, SgrError::RateLimit { .. })
554    }
555
556    /// Get rate limit info if this is a rate limit error.
557    pub fn rate_limit_info(&self) -> Option<&RateLimitInfo> {
558        match self {
559            SgrError::RateLimit { info, .. } => Some(info),
560            _ => None,
561        }
562    }
563}
564
565#[cfg(test)]
566mod tests {
567    use super::*;
568
569    #[test]
570    fn parse_codex_rate_limit_error() {
571        let body = r#"{"error":{"type":"usage_limit_reached","message":"The usage limit has been reached","plan_type":"plus","resets_at":1773534007,"resets_in_seconds":442787}}"#;
572        let err = SgrError::from_api_response(429, body.to_string());
573        assert!(err.is_rate_limit());
574        let info = err.rate_limit_info().unwrap();
575        assert_eq!(info.error_type.as_deref(), Some("usage_limit_reached"));
576        assert_eq!(info.retry_after_secs, Some(442787));
577        assert_eq!(info.resets_at, Some(1773534007));
578        assert_eq!(info.reset_display(), "5d 2h");
579    }
580
581    #[test]
582    fn parse_openai_rate_limit_error() {
583        let body =
584            r#"{"error":{"type":"rate_limit_exceeded","message":"Rate limit reached for gpt-4"}}"#;
585        let err = SgrError::from_api_response(429, body.to_string());
586        assert!(err.is_rate_limit());
587        let info = err.rate_limit_info().unwrap();
588        assert_eq!(info.error_type.as_deref(), Some("rate_limit_exceeded"));
589    }
590
591    #[test]
592    fn non_rate_limit_stays_api_error() {
593        let body = r#"{"error":{"type":"invalid_request","message":"Bad request"}}"#;
594        let err = SgrError::from_api_response(400, body.to_string());
595        assert!(!err.is_rate_limit());
596        assert!(matches!(err, SgrError::Api { status: 400, .. }));
597    }
598
    /// All three counters present → pipe-joined status line.
    #[test]
    fn status_line_with_all_fields() {
        let info = RateLimitInfo {
            requests_remaining: Some(5),
            tokens_remaining: Some(10000),
            retry_after_secs: Some(3600),
            resets_at: None,
            error_type: None,
            message: None,
        };
        assert_eq!(info.status_line(), "req:5 | tok:10000 | reset:1h 0m");
    }
611
    /// No counters at all → status line falls back to the provider message.
    #[test]
    fn status_line_fallback_to_message() {
        let info = RateLimitInfo {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs: None,
            resets_at: None,
            error_type: None,
            message: Some("custom message".into()),
        };
        assert_eq!(info.status_line(), "custom message");
    }
624
    /// reset_display picks minutes / hours+minutes / days+hours by magnitude
    /// (seconds are truncated, not rounded).
    #[test]
    fn reset_display_formats() {
        let make = |secs| RateLimitInfo {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs: Some(secs),
            resets_at: None,
            error_type: None,
            message: None,
        };
        assert_eq!(make(90).reset_display(), "1m");
        assert_eq!(make(3661).reset_display(), "1h 1m");
        assert_eq!(make(90000).reset_display(), "1d 1h");
    }
639}