1use serde::{Deserialize, Serialize};
2
/// An image attached to a [`Message`]: a textual payload plus its MIME type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImagePart {
    /// Image payload. `data_url` writes it verbatim after the `base64,` marker,
    /// so it is presumably already base64-encoded — nothing here encodes it.
    pub data: String,
    /// MIME type written into the data URL (for example `image/png`).
    pub mime_type: String,
}
11
12impl ImagePart {
13 pub fn data_url(&self) -> String {
15 format!("data:{};base64,{}", self.mime_type, self.data)
16 }
17}
18
/// A single conversation message: a role, text content, and optional
/// tool-call and image attachments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Author of the message.
    pub role: Role,
    /// Text body of the message.
    pub content: String,
    /// Call identifier, set by `Message::tool` and `Message::tool_with_images`.
    /// Left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Tool calls carried by the message (filled by
    /// `Message::assistant_with_tool_calls`). Defaults to empty when missing on
    /// deserialization and is not serialized when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tool_calls: Vec<ToolCall>,
    /// Images attached to the message. Defaults to empty when missing on
    /// deserialization and is not serialized when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub images: Vec<ImagePart>,
    /// Flag switched on by `Message::compactable()`. Defaults to `false` and is
    /// only serialized when `true`.
    /// NOTE(review): presumably marks the message as eligible for history
    /// compaction — confirm against the code that reads it.
    #[serde(default, skip_serializing_if = "std::ops::Not::not")]
    pub compactable: bool,
}
40
/// Author of a [`Message`].
///
/// Serialized as the lowercase variant name: `system`, `user`, `assistant`,
/// `tool`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Role assigned by `Message::system`.
    System,
    /// Role assigned by `Message::user` and `Message::user_with_images`.
    User,
    /// Role assigned by `Message::assistant` and
    /// `Message::assistant_with_tool_calls`.
    Assistant,
    /// Role assigned by `Message::tool` and `Message::tool_with_images`.
    Tool,
}
49
50impl Message {
51 pub fn system(content: impl Into<String>) -> Self {
52 Self {
53 role: Role::System,
54 content: content.into(),
55 tool_call_id: None,
56 tool_calls: vec![],
57 images: vec![],
58 compactable: false,
59 }
60 }
61 pub fn user(content: impl Into<String>) -> Self {
62 Self {
63 role: Role::User,
64 content: content.into(),
65 tool_call_id: None,
66 tool_calls: vec![],
67 images: vec![],
68 compactable: false,
69 }
70 }
71 pub fn assistant(content: impl Into<String>) -> Self {
72 Self {
73 role: Role::Assistant,
74 content: content.into(),
75 tool_call_id: None,
76 tool_calls: vec![],
77 images: vec![],
78 compactable: false,
79 }
80 }
81 pub fn assistant_with_tool_calls(
83 content: impl Into<String>,
84 tool_calls: Vec<ToolCall>,
85 ) -> Self {
86 Self {
87 role: Role::Assistant,
88 content: content.into(),
89 tool_call_id: None,
90 tool_calls,
91 images: vec![],
92 compactable: false,
93 }
94 }
95 pub fn tool(call_id: impl Into<String>, content: impl Into<String>) -> Self {
96 Self {
97 role: Role::Tool,
98 content: content.into(),
99 tool_call_id: Some(call_id.into()),
100 tool_calls: vec![],
101 images: vec![],
102 compactable: false,
103 }
104 }
105 pub fn tool_with_images(
107 call_id: impl Into<String>,
108 content: impl Into<String>,
109 images: Vec<ImagePart>,
110 ) -> Self {
111 Self {
112 role: Role::Tool,
113 content: content.into(),
114 tool_call_id: Some(call_id.into()),
115 tool_calls: vec![],
116 images,
117 compactable: false,
118 }
119 }
120 pub fn user_with_images(content: impl Into<String>, images: Vec<ImagePart>) -> Self {
122 Self {
123 role: Role::User,
124 content: content.into(),
125 tool_call_id: None,
126 tool_calls: vec![],
127 images,
128 compactable: false,
129 }
130 }
131 pub fn compactable(mut self) -> Self {
133 self.compactable = true;
134 self
135 }
136}
137
/// A tool invocation: an identifier, the tool's name, and its arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of this call.
    /// NOTE(review): presumably the value later echoed in
    /// `Message::tool_call_id` — confirm against the caller.
    pub id: String,
    /// Name of the tool being called.
    pub name: String,
    /// Arguments as an arbitrary JSON value; no schema is enforced here.
    pub arguments: serde_json::Value,
}
148
/// Result of a model call, generic over the type `T` of the parsed output.
///
/// NOTE(review): nothing in this file constructs the type, so the field notes
/// below are inferred from names and types — confirm against the producer.
#[derive(Debug, Clone)]
pub struct SgrResponse<T> {
    /// Parsed output, if there is one.
    pub output: Option<T>,
    /// Tool calls returned with the response.
    pub tool_calls: Vec<ToolCall>,
    /// Response text; presumably the unparsed model output.
    pub raw_text: String,
    /// Token accounting, when available.
    pub usage: Option<Usage>,
    /// Rate-limit details, when available.
    pub rate_limit: Option<RateLimitInfo>,
}
164
/// Token counts for a request; all fields default to zero.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens counted for the prompt.
    pub prompt_tokens: u32,
    /// Tokens counted for the completion.
    pub completion_tokens: u32,
    /// Total token count. Stored as its own field; nothing here enforces that
    /// it equals the sum of the other two.
    pub total_tokens: u32,
}
171
/// Rate-limit details gathered from response headers and/or a JSON error body.
///
/// Every field is optional because each source supplies only a subset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitInfo {
    /// Value of the `x-ratelimit-remaining-requests` header.
    pub requests_remaining: Option<u32>,
    /// Value of the `x-ratelimit-remaining-tokens` header.
    pub tokens_remaining: Option<u32>,
    /// Seconds until a retry is worthwhile: the `retry-after` header (falling
    /// back to `x-ratelimit-reset-requests`) or the body's `resets_in_seconds`.
    pub retry_after_secs: Option<u64>,
    /// Reset marker: the body's `resets_at` field or the
    /// `x-ratelimit-reset-tokens` header.
    /// NOTE(review): the body value used in the tests looks like a Unix
    /// timestamp, while the header is parsed as a bare number — confirm the
    /// two sources use the same unit.
    pub resets_at: Option<u64>,
    /// The `type` string of the body's `error` object.
    pub error_type: Option<String>,
    /// The `message` string of the body's `error` object, or a body excerpt
    /// filled in by `SgrError` when no message was provided.
    pub message: Option<String>,
}
188
189impl RateLimitInfo {
190 pub fn from_headers(headers: &reqwest::header::HeaderMap) -> Option<Self> {
192 let get_u32 =
193 |name: &str| -> Option<u32> { headers.get(name)?.to_str().ok()?.parse().ok() };
194 let get_u64 =
195 |name: &str| -> Option<u64> { headers.get(name)?.to_str().ok()?.parse().ok() };
196
197 let requests_remaining = get_u32("x-ratelimit-remaining-requests");
198 let tokens_remaining = get_u32("x-ratelimit-remaining-tokens");
199 let retry_after_secs =
200 get_u64("retry-after").or_else(|| get_u64("x-ratelimit-reset-requests"));
201 let resets_at = get_u64("x-ratelimit-reset-tokens");
202
203 if requests_remaining.is_some() || tokens_remaining.is_some() || retry_after_secs.is_some()
204 {
205 Some(Self {
206 requests_remaining,
207 tokens_remaining,
208 retry_after_secs,
209 resets_at,
210 error_type: None,
211 message: None,
212 })
213 } else {
214 None
215 }
216 }
217
218 pub fn from_error_body(body: &str) -> Option<Self> {
220 let json: serde_json::Value = serde_json::from_str(body).ok()?;
221 let err = json.get("error")?;
222
223 let error_type = err.get("type").and_then(|v| v.as_str()).map(String::from);
224 let message = err
225 .get("message")
226 .and_then(|v| v.as_str())
227 .map(String::from);
228 let resets_at = err.get("resets_at").and_then(|v| v.as_u64());
229 let retry_after_secs = err.get("resets_in_seconds").and_then(|v| v.as_u64());
230
231 Some(Self {
232 requests_remaining: None,
233 tokens_remaining: None,
234 retry_after_secs,
235 resets_at,
236 error_type,
237 message,
238 })
239 }
240
241 pub fn reset_display(&self) -> String {
243 if let Some(secs) = self.retry_after_secs {
244 let hours = secs / 3600;
245 let mins = (secs % 3600) / 60;
246 if hours >= 24 {
247 format!("{}d {}h", hours / 24, hours % 24)
248 } else if hours > 0 {
249 format!("{}h {}m", hours, mins)
250 } else {
251 format!("{}m", mins)
252 }
253 } else {
254 "unknown".into()
255 }
256 }
257
258 pub fn status_line(&self) -> String {
260 let mut parts = Vec::new();
261 if let Some(r) = self.requests_remaining {
262 parts.push(format!("req:{}", r));
263 }
264 if let Some(t) = self.tokens_remaining {
265 parts.push(format!("tok:{}", t));
266 }
267 if self.retry_after_secs.is_some() {
268 parts.push(format!("reset:{}", self.reset_display()));
269 }
270 if parts.is_empty() {
271 self.message
272 .clone()
273 .unwrap_or_else(|| "rate limited".into())
274 } else {
275 parts.join(" | ")
276 }
277 }
278}
279
/// Configuration for an LLM client.
///
/// Optional fields default to `None`/empty when missing on deserialization and
/// are skipped on serialization when unset, unless noted otherwise.
/// NOTE(review): several flags are only stored by this file, never read; their
/// notes below are limited to what the code here shows.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmConfig {
    /// Model identifier. Its prefix drives `is_anthropic` and `compaction_model`.
    pub model: String,
    /// API key; set by `with_key` and `endpoint`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
    /// Custom base URL; set by `endpoint`. When present (and neither CLI nor
    /// project id is set) `label` reports the config as `Custom`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub base_url: Option<String>,
    /// Temperature; falls back to `default_temperature()` when missing on
    /// deserialization. Always serialized.
    #[serde(default = "default_temperature")]
    pub temp: f64,
    /// Token limit; `for_compaction` overrides it with 2048.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Prompt cache key; set via the `prompt_cache_key` builder.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt_cache_key: Option<String>,
    /// Project id; set by `vertex`. When present `label` reports `Vertex`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub project_id: Option<String>,
    /// Location; `vertex` initialises it to `"global"`, the `location` builder
    /// overrides it.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub location: Option<String>,
    /// Defaults to `false`; always serialized.
    /// NOTE(review): presumably selects a chat-style API — confirm with the client code.
    #[serde(default)]
    pub use_chat_api: bool,
    /// Extra header name/value pairs. `apply_headers` copies the valid ones
    /// into the client's default headers.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub extra_headers: Vec<(String, String)>,
    /// Reasoning effort setting; stored as free-form text, not interpreted here.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
    /// Defaults to `false`; always serialized.
    /// NOTE(review): presumably selects a GenAI backend — confirm with the client code.
    #[serde(default)]
    pub use_genai: bool,
    /// Set by `cli`; when `true`, `label` reports the config as `CLI`.
    /// Defaults to `false`; always serialized.
    #[serde(default)]
    pub use_cli: bool,
    /// Session identifier; not interpreted in this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub session_id: Option<String>,

    /// Explicit answer for `rejects_prefill`; when `None` the answer is derived
    /// from the model name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub no_assistant_prefill: Option<bool>,
    /// Explicit cache TTL; when `None`, `resolved_cache_ttl` yields `"1h"` for
    /// Anthropic models and nothing otherwise.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_ttl: Option<String>,
    /// Explicit provider pin; when `None`, `resolved_pin_provider` yields
    /// `"Anthropic"` for Anthropic models and nothing otherwise.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pin_provider: Option<String>,
    /// Falls back to `default_websocket()` (`true`) when missing on
    /// deserialization; always serialized.
    /// NOTE(review): presumably toggles a WebSocket transport — confirm.
    #[serde(default = "default_websocket")]
    pub websocket: bool,
}
356
/// Default for `LlmConfig::websocket`, used both by serde (when the field is
/// missing) and by `LlmConfig::default()`.
fn default_websocket() -> bool {
    const WEBSOCKET_ON_BY_DEFAULT: bool = true;
    WEBSOCKET_ON_BY_DEFAULT
}
360
/// Default for `LlmConfig::temp`, used both by serde (when the field is
/// missing) and by `LlmConfig::default()`.
fn default_temperature() -> f64 {
    const FALLBACK_TEMPERATURE: f64 = 0.7;
    FALLBACK_TEMPERATURE
}
364
365impl Default for LlmConfig {
366 fn default() -> Self {
367 Self {
368 model: String::new(),
369 api_key: None,
370 base_url: None,
371 temp: default_temperature(),
372 max_tokens: None,
373 prompt_cache_key: None,
374 project_id: None,
375 location: None,
376 use_chat_api: false,
377 extra_headers: Vec::new(),
378 reasoning_effort: None,
379 use_genai: false,
380 use_cli: false,
381 session_id: None,
382 no_assistant_prefill: None,
383 cache_ttl: None,
384 pin_provider: None,
385 websocket: default_websocket(),
386 }
387 }
388}
389
390impl LlmConfig {
391 pub fn auto(model: impl Into<String>) -> Self {
393 Self {
394 model: model.into(),
395 ..Default::default()
396 }
397 }
398
399 pub fn with_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
401 Self {
402 model: model.into(),
403 api_key: Some(api_key.into()),
404 ..Default::default()
405 }
406 }
407
408 pub fn endpoint(
410 api_key: impl Into<String>,
411 base_url: impl Into<String>,
412 model: impl Into<String>,
413 ) -> Self {
414 Self {
415 model: model.into(),
416 api_key: Some(api_key.into()),
417 base_url: Some(base_url.into()),
418 ..Default::default()
419 }
420 }
421
422 pub fn vertex(project_id: impl Into<String>, model: impl Into<String>) -> Self {
424 Self {
425 model: model.into(),
426 project_id: Some(project_id.into()),
427 location: Some("global".into()),
428 ..Default::default()
429 }
430 }
431
432 pub fn location(mut self, loc: impl Into<String>) -> Self {
434 self.location = Some(loc.into());
435 self
436 }
437
438 pub fn temperature(mut self, t: f64) -> Self {
440 self.temp = t;
441 self
442 }
443
444 pub fn max_tokens(mut self, m: u32) -> Self {
446 self.max_tokens = Some(m);
447 self
448 }
449
450 pub fn prompt_cache_key(mut self, key: impl Into<String>) -> Self {
452 self.prompt_cache_key = Some(key.into());
453 self
454 }
455
456 pub fn is_anthropic(&self) -> bool {
458 self.model.starts_with("anthropic/") || self.model.starts_with("claude")
459 }
460
461 pub fn rejects_prefill(&self) -> bool {
463 self.no_assistant_prefill.unwrap_or_else(|| {
464 self.is_anthropic() && !self.model.contains("haiku")
466 })
467 }
468
469 pub fn resolved_cache_ttl(&self) -> Option<&str> {
471 if self.cache_ttl.is_some() {
472 return self.cache_ttl.as_deref();
473 }
474 if self.is_anthropic() {
475 Some("1h")
476 } else {
477 None
478 }
479 }
480
481 pub fn resolved_pin_provider(&self) -> Option<&str> {
483 if self.pin_provider.is_some() {
484 return self.pin_provider.as_deref();
485 }
486 if self.is_anthropic() {
487 Some("Anthropic")
488 } else {
489 None
490 }
491 }
492
493 pub fn apply_headers(&self, config: &mut openai_oxide::config::ClientConfig) {
496 if !self.extra_headers.is_empty() {
497 let mut hm = reqwest::header::HeaderMap::new();
498 for (k, v) in &self.extra_headers {
499 if let (Ok(name), Ok(val)) = (
500 reqwest::header::HeaderName::from_bytes(k.as_bytes()),
501 reqwest::header::HeaderValue::from_str(v),
502 ) {
503 hm.insert(name, val);
504 }
505 }
506 config.default_headers = Some(hm);
507 }
508 }
509
510 pub fn cli(cli_model: impl Into<String>) -> Self {
514 Self {
515 model: cli_model.into(),
516 use_cli: true,
517 ..Default::default()
518 }
519 }
520
521 pub fn label(&self) -> String {
523 if self.use_cli {
524 format!("CLI ({})", self.model)
525 } else if self.project_id.is_some() {
526 format!("Vertex ({})", self.model)
527 } else if self.base_url.is_some() {
528 format!("Custom ({})", self.model)
529 } else {
530 self.model.clone()
531 }
532 }
533
534 pub fn compaction_model(&self) -> String {
536 if self.model.starts_with("gemini") {
537 "gemini-2.0-flash-lite".into()
538 } else if self.model.starts_with("gpt") {
539 "gpt-4o-mini".into()
540 } else if self.model.starts_with("claude") {
541 "claude-3-haiku-20240307".into()
542 } else {
543 self.model.clone()
545 }
546 }
547
548 pub fn for_compaction(&self) -> Self {
550 let mut cfg = self.clone();
551 cfg.model = self.compaction_model();
552 cfg.max_tokens = Some(2048);
553 cfg
554 }
555}
556
/// Connection settings for a single provider, built through the
/// provider-specific constructors on this type.
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    /// API key; `vertex` stores its access token here and `ollama` leaves it empty.
    pub api_key: String,
    /// Model identifier.
    pub model: String,
    /// Base URL override; preset by `openrouter` and `ollama`, otherwise `None`.
    pub base_url: Option<String>,
    /// Project id; only set by `vertex`.
    pub project_id: Option<String>,
    /// Location; `vertex` sets it to `"global"`, otherwise `None`.
    pub location: Option<String>,
    /// Temperature; every constructor here uses `0.3`.
    pub temperature: f32,
    /// Token limit; every constructor here leaves it unset.
    pub max_tokens: Option<u32>,
}
568
569impl ProviderConfig {
570 pub fn gemini(api_key: impl Into<String>, model: impl Into<String>) -> Self {
571 Self {
572 api_key: api_key.into(),
573 model: model.into(),
574 base_url: None,
575 project_id: None,
576 location: None,
577 temperature: 0.3,
578 max_tokens: None,
579 }
580 }
581
582 pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
583 Self {
584 api_key: api_key.into(),
585 model: model.into(),
586 base_url: None,
587 project_id: None,
588 location: None,
589 temperature: 0.3,
590 max_tokens: None,
591 }
592 }
593
594 pub fn openrouter(api_key: impl Into<String>, model: impl Into<String>) -> Self {
595 Self {
596 api_key: api_key.into(),
597 model: model.into(),
598 base_url: Some("https://openrouter.ai/api/v1".into()),
599 project_id: None,
600 location: None,
601 temperature: 0.3,
602 max_tokens: None,
603 }
604 }
605
606 pub fn vertex(
607 access_token: impl Into<String>,
608 project_id: impl Into<String>,
609 model: impl Into<String>,
610 ) -> Self {
611 Self {
612 api_key: access_token.into(),
613 model: model.into(),
614 base_url: None,
615 project_id: Some(project_id.into()),
616 location: Some("global".to_string()),
617 temperature: 0.3,
618 max_tokens: None,
619 }
620 }
621
622 pub fn ollama(model: impl Into<String>) -> Self {
623 Self {
624 api_key: String::new(),
625 model: model.into(),
626 base_url: Some("http://localhost:11434/v1".into()),
627 project_id: None,
628 location: None,
629 temperature: 0.3,
630 max_tokens: None,
631 }
632 }
633}
634
/// Errors produced by this crate's LLM client layer.
///
/// Display text comes from the `#[error(...)]` attribute on each variant.
#[derive(Debug, thiserror::Error)]
pub enum SgrError {
    /// Wraps a `reqwest::Error` (converted automatically via `From`).
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),
    /// A non-rate-limit API failure, carrying the HTTP status and raw body.
    #[error("API error {status}: {body}")]
    Api { status: u16, body: String },
    /// A rate-limit response; displayed through `RateLimitInfo::status_line`.
    #[error("Rate limit: {}", info.status_line())]
    RateLimit { status: u16, info: RateLimitInfo },
    /// Wraps a `serde_json::Error` (converted automatically via `From`).
    #[error("JSON parse error: {0}")]
    Json(#[from] serde_json::Error),
    /// A schema problem, described by the contained text.
    #[error("Schema error: {0}")]
    Schema(String),
    /// The response contained no content.
    #[error("No content in response")]
    EmptyResponse,
    /// The response was cut off; the text received so far is kept and is
    /// included in the display output.
    #[error("Response truncated (max_output_tokens): {partial_content}")]
    MaxOutputTokens { partial_content: String },
    /// The prompt was too long; the contained text is shown in the display output.
    #[error("Prompt too long: {0}")]
    PromptTooLong(String),
}
658
659impl SgrError {
660 pub fn from_api_response(status: u16, body: String) -> Self {
662 if (status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit"))
663 && let Some(mut info) = RateLimitInfo::from_error_body(&body)
664 {
665 if info.message.is_none() {
666 info.message = Some(body.chars().take(200).collect());
667 }
668 return SgrError::RateLimit { status, info };
669 }
670 SgrError::Api { status, body }
671 }
672
673 pub fn from_response_parts(
675 status: u16,
676 body: String,
677 headers: &reqwest::header::HeaderMap,
678 ) -> Self {
679 if status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit") {
680 let mut info = RateLimitInfo::from_error_body(&body)
681 .or_else(|| RateLimitInfo::from_headers(headers))
682 .unwrap_or(RateLimitInfo {
683 requests_remaining: None,
684 tokens_remaining: None,
685 retry_after_secs: None,
686 resets_at: None,
687 error_type: Some("rate_limit".into()),
688 message: Some(body.chars().take(200).collect()),
689 });
690 if let Some(header_info) = RateLimitInfo::from_headers(headers) {
692 if info.requests_remaining.is_none() {
693 info.requests_remaining = header_info.requests_remaining;
694 }
695 if info.tokens_remaining.is_none() {
696 info.tokens_remaining = header_info.tokens_remaining;
697 }
698 }
699 return SgrError::RateLimit { status, info };
700 }
701 SgrError::Api { status, body }
702 }
703
704 pub fn is_rate_limit(&self) -> bool {
706 matches!(self, SgrError::RateLimit { .. })
707 }
708
709 pub fn rate_limit_info(&self) -> Option<&RateLimitInfo> {
711 match self {
712 SgrError::RateLimit { info, .. } => Some(info),
713 _ => None,
714 }
715 }
716}
717
// Unit tests for rate-limit parsing, error classification, and display helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // A usage-limit body carrying `resets_at` and `resets_in_seconds` is
    // classified as a rate limit and all of its fields are extracted.
    #[test]
    fn parse_codex_rate_limit_error() {
        let body = r#"{"error":{"type":"usage_limit_reached","message":"The usage limit has been reached","plan_type":"plus","resets_at":1773534007,"resets_in_seconds":442787}}"#;
        let err = SgrError::from_api_response(429, body.to_string());
        assert!(err.is_rate_limit());
        let info = err.rate_limit_info().unwrap();
        assert_eq!(info.error_type.as_deref(), Some("usage_limit_reached"));
        assert_eq!(info.retry_after_secs, Some(442787));
        assert_eq!(info.resets_at, Some(1773534007));
        assert_eq!(info.reset_display(), "5d 2h");
    }

    // A 429 whose body has only `type` and `message` is still a rate limit.
    #[test]
    fn parse_openai_rate_limit_error() {
        let body =
            r#"{"error":{"type":"rate_limit_exceeded","message":"Rate limit reached for gpt-4"}}"#;
        let err = SgrError::from_api_response(429, body.to_string());
        assert!(err.is_rate_limit());
        let info = err.rate_limit_info().unwrap();
        assert_eq!(info.error_type.as_deref(), Some("rate_limit_exceeded"));
    }

    // A 400 without any rate-limit marker stays a plain API error.
    #[test]
    fn non_rate_limit_stays_api_error() {
        let body = r#"{"error":{"type":"invalid_request","message":"Bad request"}}"#;
        let err = SgrError::from_api_response(400, body.to_string());
        assert!(!err.is_rate_limit());
        assert!(matches!(err, SgrError::Api { status: 400, .. }));
    }

    // With every counter known, the status line lists them joined by " | ".
    #[test]
    fn status_line_with_all_fields() {
        let info = RateLimitInfo {
            requests_remaining: Some(5),
            tokens_remaining: Some(10000),
            retry_after_secs: Some(3600),
            resets_at: None,
            error_type: None,
            message: None,
        };
        assert_eq!(info.status_line(), "req:5 | tok:10000 | reset:1h 0m");
    }

    // With no counters known, the status line falls back to the message.
    #[test]
    fn status_line_fallback_to_message() {
        let info = RateLimitInfo {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs: None,
            resets_at: None,
            error_type: None,
            message: Some("custom message".into()),
        };
        assert_eq!(info.status_line(), "custom message");
    }

    // Covers the minutes, hours+minutes, and days+hours formats.
    #[test]
    fn reset_display_formats() {
        let make = |secs| RateLimitInfo {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs: Some(secs),
            resets_at: None,
            error_type: None,
            message: None,
        };
        assert_eq!(make(90).reset_display(), "1m");
        assert_eq!(make(3661).reset_display(), "1h 1m");
        assert_eq!(make(90000).reset_display(), "1d 1h");
    }
}