1use serde::{Deserialize, Serialize};
2
/// An image attachment carried alongside a message's text content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImagePart {
    /// Raw image payload — presumably base64-encoded; TODO confirm against producer.
    pub data: String,
    /// MIME type of the payload (e.g. "image/png").
    pub mime_type: String,
}
11
/// One turn in a chat conversation, in an OpenAI-style wire shape.
///
/// Optional fields are skipped during serialization when empty so the JSON
/// stays minimal for providers that reject unknown/null fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Who produced this message.
    pub role: Role,
    /// Text content of the message.
    pub content: String,
    /// ID of the tool call this message answers; set by the `tool*`
    /// constructors, `None` for other roles.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Tool calls requested by an assistant turn.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Image attachments accompanying the text content.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub images: Vec<ImagePart>,
    /// When true, this message is presumably eligible for history
    /// compaction (see `Message::compactable`) — TODO confirm consumer.
    /// Serialized only when true.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub compactable: bool,
}
33
/// The author of a [`Message`], serialized lowercase ("system", "user", …).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// System/instruction prompt.
    System,
    /// End-user input.
    User,
    /// Model output.
    Assistant,
    /// Tool-execution result fed back to the model.
    Tool,
}
42
43impl Message {
44 pub fn system(content: impl Into<String>) -> Self {
45 Self {
46 role: Role::System,
47 content: content.into(),
48 tool_call_id: None,
49 tool_calls: vec![],
50 images: vec![],
51 compactable: false,
52 }
53 }
54 pub fn user(content: impl Into<String>) -> Self {
55 Self {
56 role: Role::User,
57 content: content.into(),
58 tool_call_id: None,
59 tool_calls: vec![],
60 images: vec![],
61 compactable: false,
62 }
63 }
64 pub fn assistant(content: impl Into<String>) -> Self {
65 Self {
66 role: Role::Assistant,
67 content: content.into(),
68 tool_call_id: None,
69 tool_calls: vec![],
70 images: vec![],
71 compactable: false,
72 }
73 }
74 pub fn assistant_with_tool_calls(
76 content: impl Into<String>,
77 tool_calls: Vec<ToolCall>,
78 ) -> Self {
79 Self {
80 role: Role::Assistant,
81 content: content.into(),
82 tool_call_id: None,
83 tool_calls,
84 images: vec![],
85 compactable: false,
86 }
87 }
88 pub fn tool(call_id: impl Into<String>, content: impl Into<String>) -> Self {
89 Self {
90 role: Role::Tool,
91 content: content.into(),
92 tool_call_id: Some(call_id.into()),
93 tool_calls: vec![],
94 images: vec![],
95 compactable: false,
96 }
97 }
98 pub fn tool_with_images(
100 call_id: impl Into<String>,
101 content: impl Into<String>,
102 images: Vec<ImagePart>,
103 ) -> Self {
104 Self {
105 role: Role::Tool,
106 content: content.into(),
107 tool_call_id: Some(call_id.into()),
108 tool_calls: vec![],
109 images,
110 compactable: false,
111 }
112 }
113 pub fn user_with_images(content: impl Into<String>, images: Vec<ImagePart>) -> Self {
115 Self {
116 role: Role::User,
117 content: content.into(),
118 tool_call_id: None,
119 tool_calls: vec![],
120 images,
121 compactable: false,
122 }
123 }
124 pub fn compactable(mut self) -> Self {
126 self.compactable = true;
127 self
128 }
129}
130
/// A tool invocation requested by the model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned call ID; echoed back in the matching tool message.
    pub id: String,
    /// Name of the tool/function to invoke.
    pub name: String,
    /// Arguments for the call as arbitrary JSON.
    pub arguments: serde_json::Value,
}
141
/// Result of one model call: the parsed structured output (when parsing
/// succeeded), any tool calls, the raw text, and response metadata.
#[derive(Debug, Clone)]
pub struct SgrResponse<T> {
    /// Structured output parsed from the response, if available.
    pub output: Option<T>,
    /// Tool calls returned by the model.
    pub tool_calls: Vec<ToolCall>,
    /// Raw response text as received from the provider.
    pub raw_text: String,
    /// Token usage reported by the provider, if any.
    pub usage: Option<Usage>,
    /// Rate-limit information extracted from the response, if present.
    pub rate_limit: Option<RateLimitInfo>,
}
157
/// Token accounting for a single request, as reported by the provider.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens consumed by the prompt/input.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion/output.
    pub completion_tokens: u32,
    /// Total tokens (prompt + completion, per the provider's report).
    pub total_tokens: u32,
}
164
/// Rate-limit details assembled from response headers and/or an error body.
/// All fields are optional because providers expose different subsets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitInfo {
    /// Requests left in the window (`x-ratelimit-remaining-requests`).
    pub requests_remaining: Option<u32>,
    /// Tokens left in the window (`x-ratelimit-remaining-tokens`).
    pub tokens_remaining: Option<u32>,
    /// Seconds until the limit resets (`retry-after` / body `resets_in_seconds`).
    pub retry_after_secs: Option<u64>,
    /// Reset moment — from `x-ratelimit-reset-tokens` or body `resets_at`;
    /// presumably a Unix timestamp in the body case — TODO confirm.
    pub resets_at: Option<u64>,
    /// Error type string from the error body (e.g. "usage_limit_reached").
    pub error_type: Option<String>,
    /// Human-readable message from the error body.
    pub message: Option<String>,
}
181
182impl RateLimitInfo {
183 pub fn from_headers(headers: &reqwest::header::HeaderMap) -> Option<Self> {
185 let get_u32 =
186 |name: &str| -> Option<u32> { headers.get(name)?.to_str().ok()?.parse().ok() };
187 let get_u64 =
188 |name: &str| -> Option<u64> { headers.get(name)?.to_str().ok()?.parse().ok() };
189
190 let requests_remaining = get_u32("x-ratelimit-remaining-requests");
191 let tokens_remaining = get_u32("x-ratelimit-remaining-tokens");
192 let retry_after_secs =
193 get_u64("retry-after").or_else(|| get_u64("x-ratelimit-reset-requests"));
194 let resets_at = get_u64("x-ratelimit-reset-tokens");
195
196 if requests_remaining.is_some() || tokens_remaining.is_some() || retry_after_secs.is_some()
197 {
198 Some(Self {
199 requests_remaining,
200 tokens_remaining,
201 retry_after_secs,
202 resets_at,
203 error_type: None,
204 message: None,
205 })
206 } else {
207 None
208 }
209 }
210
211 pub fn from_error_body(body: &str) -> Option<Self> {
213 let json: serde_json::Value = serde_json::from_str(body).ok()?;
214 let err = json.get("error")?;
215
216 let error_type = err.get("type").and_then(|v| v.as_str()).map(String::from);
217 let message = err
218 .get("message")
219 .and_then(|v| v.as_str())
220 .map(String::from);
221 let resets_at = err.get("resets_at").and_then(|v| v.as_u64());
222 let retry_after_secs = err.get("resets_in_seconds").and_then(|v| v.as_u64());
223
224 Some(Self {
225 requests_remaining: None,
226 tokens_remaining: None,
227 retry_after_secs,
228 resets_at,
229 error_type,
230 message,
231 })
232 }
233
234 pub fn reset_display(&self) -> String {
236 if let Some(secs) = self.retry_after_secs {
237 let hours = secs / 3600;
238 let mins = (secs % 3600) / 60;
239 if hours >= 24 {
240 format!("{}d {}h", hours / 24, hours % 24)
241 } else if hours > 0 {
242 format!("{}h {}m", hours, mins)
243 } else {
244 format!("{}m", mins)
245 }
246 } else {
247 "unknown".into()
248 }
249 }
250
251 pub fn status_line(&self) -> String {
253 let mut parts = Vec::new();
254 if let Some(r) = self.requests_remaining {
255 parts.push(format!("req:{}", r));
256 }
257 if let Some(t) = self.tokens_remaining {
258 parts.push(format!("tok:{}", t));
259 }
260 if self.retry_after_secs.is_some() {
261 parts.push(format!("reset:{}", self.reset_display()));
262 }
263 if parts.is_empty() {
264 self.message
265 .clone()
266 .unwrap_or_else(|| "rate limited".into())
267 } else {
268 parts.join(" | ")
269 }
270 }
271}
272
/// Configuration for one LLM endpoint: model, credentials, sampling
/// parameters, and provider-selection switches.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
    /// Model identifier sent to the provider (e.g. "gpt-…", "gemini-…").
    pub model: String,
    /// API key; `None` when the backend needs none (e.g. CLI mode).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// Override base URL for custom / OpenAI-compatible endpoints.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
    /// Sampling temperature; defaults to 0.7 when absent.
    #[serde(default = "default_temperature")]
    pub temp: f64,
    /// Optional cap on generated tokens.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Key for provider-side prompt caching.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    /// GCP project for Vertex-style configs (see `LlmConfig::vertex`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub project_id: Option<String>,
    /// Vertex location; `vertex()` defaults this to "global".
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub location: Option<String>,
    /// Presumably selects the chat-completions API over another API
    /// surface — TODO confirm against the client that reads this.
    #[serde(default)]
    pub use_chat_api: bool,
    /// Extra HTTP headers applied to requests (see `apply_headers`).
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub extra_headers: Vec<(String, String)>,
    /// Reasoning-effort hint, presumably forwarded to providers that
    /// support it — TODO confirm accepted values.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
    /// NOTE(review): switch for a GenAI backend; semantics not visible
    /// here — confirm against the dispatching code.
    #[serde(default)]
    pub use_genai: bool,
    /// Route requests through an external CLI instead of HTTP
    /// (see `LlmConfig::cli` and `label`).
    #[serde(default)]
    pub use_cli: bool,
}
330
/// Serde default for [`LlmConfig::temp`].
fn default_temperature() -> f64 {
    const DEFAULT_TEMP: f64 = 0.7;
    DEFAULT_TEMP
}
334
335impl Default for LlmConfig {
336 fn default() -> Self {
337 Self {
338 model: String::new(),
339 api_key: None,
340 base_url: None,
341 temp: default_temperature(),
342 max_tokens: None,
343 prompt_cache_key: None,
344 project_id: None,
345 location: None,
346 use_chat_api: false,
347 extra_headers: Vec::new(),
348 reasoning_effort: None,
349 use_genai: false,
350 use_cli: false,
351 }
352 }
353}
354
355impl LlmConfig {
356 pub fn auto(model: impl Into<String>) -> Self {
358 Self {
359 model: model.into(),
360 ..Default::default()
361 }
362 }
363
364 pub fn with_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
366 Self {
367 model: model.into(),
368 api_key: Some(api_key.into()),
369 ..Default::default()
370 }
371 }
372
373 pub fn endpoint(
375 api_key: impl Into<String>,
376 base_url: impl Into<String>,
377 model: impl Into<String>,
378 ) -> Self {
379 Self {
380 model: model.into(),
381 api_key: Some(api_key.into()),
382 base_url: Some(base_url.into()),
383 ..Default::default()
384 }
385 }
386
387 pub fn vertex(project_id: impl Into<String>, model: impl Into<String>) -> Self {
389 Self {
390 model: model.into(),
391 project_id: Some(project_id.into()),
392 location: Some("global".into()),
393 ..Default::default()
394 }
395 }
396
397 pub fn location(mut self, loc: impl Into<String>) -> Self {
399 self.location = Some(loc.into());
400 self
401 }
402
403 pub fn temperature(mut self, t: f64) -> Self {
405 self.temp = t;
406 self
407 }
408
409 pub fn max_tokens(mut self, m: u32) -> Self {
411 self.max_tokens = Some(m);
412 self
413 }
414
415 pub fn prompt_cache_key(mut self, key: impl Into<String>) -> Self {
417 self.prompt_cache_key = Some(key.into());
418 self
419 }
420
421 pub fn apply_headers(&self, config: &mut openai_oxide::config::ClientConfig) {
424 if !self.extra_headers.is_empty() {
425 let mut hm = reqwest::header::HeaderMap::new();
426 for (k, v) in &self.extra_headers {
427 if let (Ok(name), Ok(val)) = (
428 reqwest::header::HeaderName::from_bytes(k.as_bytes()),
429 reqwest::header::HeaderValue::from_str(v),
430 ) {
431 hm.insert(name, val);
432 }
433 }
434 config.default_headers = Some(hm);
435 }
436 }
437
438 pub fn cli(cli_model: impl Into<String>) -> Self {
442 Self {
443 model: cli_model.into(),
444 use_cli: true,
445 ..Default::default()
446 }
447 }
448
449 pub fn label(&self) -> String {
451 if self.use_cli {
452 format!("CLI ({})", self.model)
453 } else if self.project_id.is_some() {
454 format!("Vertex ({})", self.model)
455 } else if self.base_url.is_some() {
456 format!("Custom ({})", self.model)
457 } else {
458 self.model.clone()
459 }
460 }
461
462 pub fn compaction_model(&self) -> String {
464 if self.model.starts_with("gemini") {
465 "gemini-2.0-flash-lite".into()
466 } else if self.model.starts_with("gpt") {
467 "gpt-4o-mini".into()
468 } else if self.model.starts_with("claude") {
469 "claude-3-haiku-20240307".into()
470 } else {
471 self.model.clone()
473 }
474 }
475
476 pub fn for_compaction(&self) -> Self {
478 let mut cfg = self.clone();
479 cfg.model = self.compaction_model();
480 cfg.max_tokens = Some(2048);
481 cfg
482 }
483}
484
/// Low-level per-provider connection settings (see the constructors in the
/// `impl` for the supported providers).
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    /// API key or access token; empty string for keyless backends (Ollama).
    pub api_key: String,
    /// Model identifier.
    pub model: String,
    /// Base URL override for OpenAI-compatible endpoints.
    pub base_url: Option<String>,
    /// GCP project, set only for Vertex.
    pub project_id: Option<String>,
    /// Vertex location, defaulted to "global" by `vertex()`.
    pub location: Option<String>,
    /// Sampling temperature; constructors default this to 0.3.
    pub temperature: f32,
    /// Optional cap on generated tokens.
    pub max_tokens: Option<u32>,
}
496
497impl ProviderConfig {
498 pub fn gemini(api_key: impl Into<String>, model: impl Into<String>) -> Self {
499 Self {
500 api_key: api_key.into(),
501 model: model.into(),
502 base_url: None,
503 project_id: None,
504 location: None,
505 temperature: 0.3,
506 max_tokens: None,
507 }
508 }
509
510 pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
511 Self {
512 api_key: api_key.into(),
513 model: model.into(),
514 base_url: None,
515 project_id: None,
516 location: None,
517 temperature: 0.3,
518 max_tokens: None,
519 }
520 }
521
522 pub fn openrouter(api_key: impl Into<String>, model: impl Into<String>) -> Self {
523 Self {
524 api_key: api_key.into(),
525 model: model.into(),
526 base_url: Some("https://openrouter.ai/api/v1".into()),
527 project_id: None,
528 location: None,
529 temperature: 0.3,
530 max_tokens: None,
531 }
532 }
533
534 pub fn vertex(
535 access_token: impl Into<String>,
536 project_id: impl Into<String>,
537 model: impl Into<String>,
538 ) -> Self {
539 Self {
540 api_key: access_token.into(),
541 model: model.into(),
542 base_url: None,
543 project_id: Some(project_id.into()),
544 location: Some("global".to_string()),
545 temperature: 0.3,
546 max_tokens: None,
547 }
548 }
549
550 pub fn ollama(model: impl Into<String>) -> Self {
551 Self {
552 api_key: String::new(),
553 model: model.into(),
554 base_url: Some("http://localhost:11434/v1".into()),
555 project_id: None,
556 location: None,
557 temperature: 0.3,
558 max_tokens: None,
559 }
560 }
561}
562
/// Errors produced while calling a provider and interpreting its response.
#[derive(Debug, thiserror::Error)]
pub enum SgrError {
    /// Transport-level failure from the HTTP client.
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),
    /// Non-success HTTP response that is not a rate limit.
    #[error("API error {status}: {body}")]
    Api { status: u16, body: String },
    /// Rate-limited response with parsed limit details.
    #[error("Rate limit: {}", info.status_line())]
    RateLimit { status: u16, info: RateLimitInfo },
    /// Response body failed JSON (de)serialization.
    #[error("JSON parse error: {0}")]
    Json(#[from] serde_json::Error),
    /// Problem with the structured-output schema.
    #[error("Schema error: {0}")]
    Schema(String),
    /// The response carried no content at all.
    #[error("No content in response")]
    EmptyResponse,
    /// Generation stopped at the output-token cap; carries whatever
    /// partial content was produced.
    #[error("Response truncated (max_output_tokens): {partial_content}")]
    MaxOutputTokens { partial_content: String },
    /// The prompt exceeded the model's context limit.
    #[error("Prompt too long: {0}")]
    PromptTooLong(String),
}
586
587impl SgrError {
588 pub fn from_api_response(status: u16, body: String) -> Self {
590 if (status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit"))
591 && let Some(mut info) = RateLimitInfo::from_error_body(&body)
592 {
593 if info.message.is_none() {
594 info.message = Some(body.chars().take(200).collect());
595 }
596 return SgrError::RateLimit { status, info };
597 }
598 SgrError::Api { status, body }
599 }
600
601 pub fn from_response_parts(
603 status: u16,
604 body: String,
605 headers: &reqwest::header::HeaderMap,
606 ) -> Self {
607 if status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit") {
608 let mut info = RateLimitInfo::from_error_body(&body)
609 .or_else(|| RateLimitInfo::from_headers(headers))
610 .unwrap_or(RateLimitInfo {
611 requests_remaining: None,
612 tokens_remaining: None,
613 retry_after_secs: None,
614 resets_at: None,
615 error_type: Some("rate_limit".into()),
616 message: Some(body.chars().take(200).collect()),
617 });
618 if let Some(header_info) = RateLimitInfo::from_headers(headers) {
620 if info.requests_remaining.is_none() {
621 info.requests_remaining = header_info.requests_remaining;
622 }
623 if info.tokens_remaining.is_none() {
624 info.tokens_remaining = header_info.tokens_remaining;
625 }
626 }
627 return SgrError::RateLimit { status, info };
628 }
629 SgrError::Api { status, body }
630 }
631
632 pub fn is_rate_limit(&self) -> bool {
634 matches!(self, SgrError::RateLimit { .. })
635 }
636
637 pub fn rate_limit_info(&self) -> Option<&RateLimitInfo> {
639 match self {
640 SgrError::RateLimit { info, .. } => Some(info),
641 _ => None,
642 }
643 }
644}
645
#[cfg(test)]
mod tests {
    use super::*;

    /// A `RateLimitInfo` with every field empty; tests override what they need.
    fn blank_info() -> RateLimitInfo {
        RateLimitInfo {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs: None,
            resets_at: None,
            error_type: None,
            message: None,
        }
    }

    #[test]
    fn parse_codex_rate_limit_error() {
        let body = r#"{"error":{"type":"usage_limit_reached","message":"The usage limit has been reached","plan_type":"plus","resets_at":1773534007,"resets_in_seconds":442787}}"#;
        let err = SgrError::from_api_response(429, body.to_string());
        assert!(err.is_rate_limit());

        let info = err.rate_limit_info().unwrap();
        assert_eq!(info.error_type.as_deref(), Some("usage_limit_reached"));
        assert_eq!(info.retry_after_secs, Some(442787));
        assert_eq!(info.resets_at, Some(1773534007));
        assert_eq!(info.reset_display(), "5d 2h");
    }

    #[test]
    fn parse_openai_rate_limit_error() {
        let body =
            r#"{"error":{"type":"rate_limit_exceeded","message":"Rate limit reached for gpt-4"}}"#;
        let err = SgrError::from_api_response(429, body.to_string());
        assert!(err.is_rate_limit());
        assert_eq!(
            err.rate_limit_info().unwrap().error_type.as_deref(),
            Some("rate_limit_exceeded")
        );
    }

    #[test]
    fn non_rate_limit_stays_api_error() {
        let body = r#"{"error":{"type":"invalid_request","message":"Bad request"}}"#;
        let err = SgrError::from_api_response(400, body.to_string());
        assert!(!err.is_rate_limit());
        assert!(matches!(err, SgrError::Api { status: 400, .. }));
    }

    #[test]
    fn status_line_with_all_fields() {
        let info = RateLimitInfo {
            requests_remaining: Some(5),
            tokens_remaining: Some(10000),
            retry_after_secs: Some(3600),
            ..blank_info()
        };
        assert_eq!(info.status_line(), "req:5 | tok:10000 | reset:1h 0m");
    }

    #[test]
    fn status_line_fallback_to_message() {
        let info = RateLimitInfo {
            message: Some("custom message".into()),
            ..blank_info()
        };
        assert_eq!(info.status_line(), "custom message");
    }

    #[test]
    fn reset_display_formats() {
        let with_retry = |secs| RateLimitInfo {
            retry_after_secs: Some(secs),
            ..blank_info()
        };
        assert_eq!(with_retry(90).reset_display(), "1m");
        assert_eq!(with_retry(3661).reset_display(), "1h 1m");
        assert_eq!(with_retry(90000).reset_display(), "1d 1h");
    }
}