1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct ImagePart {
6 pub data: String,
8 pub mime_type: String,
10}
11
12impl ImagePart {
13 pub fn data_url(&self) -> String {
15 format!("data:{};base64,{}", self.mime_type, self.data)
16 }
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct Message {
22 pub role: Role,
23 pub content: String,
24 #[serde(skip_serializing_if = "Option::is_none")]
26 pub tool_call_id: Option<String>,
27 #[serde(default, skip_serializing_if = "Vec::is_empty")]
30 pub tool_calls: Vec<ToolCall>,
31 #[serde(default, skip_serializing_if = "Vec::is_empty")]
33 pub images: Vec<ImagePart>,
34 #[serde(default, skip_serializing_if = "std::ops::Not::not")]
38 pub compactable: bool,
39}
40
41#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
42#[serde(rename_all = "lowercase")]
43pub enum Role {
44 System,
45 User,
46 Assistant,
47 Tool,
48}
49
50impl Message {
51 pub fn system(content: impl Into<String>) -> Self {
52 Self {
53 role: Role::System,
54 content: content.into(),
55 tool_call_id: None,
56 tool_calls: vec![],
57 images: vec![],
58 compactable: false,
59 }
60 }
61 pub fn user(content: impl Into<String>) -> Self {
62 Self {
63 role: Role::User,
64 content: content.into(),
65 tool_call_id: None,
66 tool_calls: vec![],
67 images: vec![],
68 compactable: false,
69 }
70 }
71 pub fn assistant(content: impl Into<String>) -> Self {
72 Self {
73 role: Role::Assistant,
74 content: content.into(),
75 tool_call_id: None,
76 tool_calls: vec![],
77 images: vec![],
78 compactable: false,
79 }
80 }
81 pub fn assistant_with_tool_calls(
83 content: impl Into<String>,
84 tool_calls: Vec<ToolCall>,
85 ) -> Self {
86 Self {
87 role: Role::Assistant,
88 content: content.into(),
89 tool_call_id: None,
90 tool_calls,
91 images: vec![],
92 compactable: false,
93 }
94 }
95 pub fn tool(call_id: impl Into<String>, content: impl Into<String>) -> Self {
96 Self {
97 role: Role::Tool,
98 content: content.into(),
99 tool_call_id: Some(call_id.into()),
100 tool_calls: vec![],
101 images: vec![],
102 compactable: false,
103 }
104 }
105 pub fn tool_with_images(
107 call_id: impl Into<String>,
108 content: impl Into<String>,
109 images: Vec<ImagePart>,
110 ) -> Self {
111 Self {
112 role: Role::Tool,
113 content: content.into(),
114 tool_call_id: Some(call_id.into()),
115 tool_calls: vec![],
116 images,
117 compactable: false,
118 }
119 }
120 pub fn user_with_images(content: impl Into<String>, images: Vec<ImagePart>) -> Self {
122 Self {
123 role: Role::User,
124 content: content.into(),
125 tool_call_id: None,
126 tool_calls: vec![],
127 images,
128 compactable: false,
129 }
130 }
131 pub fn compactable(mut self) -> Self {
133 self.compactable = true;
134 self
135 }
136}
137
138#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct ToolCall {
141 pub id: String,
143 pub name: String,
145 pub arguments: serde_json::Value,
147}
148
149#[derive(Debug, Clone)]
151pub struct SgrResponse<T> {
152 pub output: Option<T>,
155 pub tool_calls: Vec<ToolCall>,
157 pub raw_text: String,
159 pub usage: Option<Usage>,
161 pub rate_limit: Option<RateLimitInfo>,
163}
164
165#[derive(Debug, Clone, Default, Serialize, Deserialize)]
166pub struct Usage {
167 pub prompt_tokens: u32,
168 pub completion_tokens: u32,
169 pub total_tokens: u32,
170}
171
172#[derive(Debug, Clone, Serialize, Deserialize)]
174pub struct RateLimitInfo {
175 pub requests_remaining: Option<u32>,
177 pub tokens_remaining: Option<u32>,
179 pub retry_after_secs: Option<u64>,
181 pub resets_at: Option<u64>,
183 pub error_type: Option<String>,
185 pub message: Option<String>,
187}
188
189impl RateLimitInfo {
190 pub fn from_headers(headers: &reqwest::header::HeaderMap) -> Option<Self> {
192 let get_u32 =
193 |name: &str| -> Option<u32> { headers.get(name)?.to_str().ok()?.parse().ok() };
194 let get_u64 =
195 |name: &str| -> Option<u64> { headers.get(name)?.to_str().ok()?.parse().ok() };
196
197 let requests_remaining = get_u32("x-ratelimit-remaining-requests");
198 let tokens_remaining = get_u32("x-ratelimit-remaining-tokens");
199 let retry_after_secs =
200 get_u64("retry-after").or_else(|| get_u64("x-ratelimit-reset-requests"));
201 let resets_at = get_u64("x-ratelimit-reset-tokens");
202
203 if requests_remaining.is_some() || tokens_remaining.is_some() || retry_after_secs.is_some()
204 {
205 Some(Self {
206 requests_remaining,
207 tokens_remaining,
208 retry_after_secs,
209 resets_at,
210 error_type: None,
211 message: None,
212 })
213 } else {
214 None
215 }
216 }
217
218 pub fn from_error_body(body: &str) -> Option<Self> {
220 let json: serde_json::Value = serde_json::from_str(body).ok()?;
221 let err = json.get("error")?;
222
223 let error_type = err.get("type").and_then(|v| v.as_str()).map(String::from);
224 let message = err
225 .get("message")
226 .and_then(|v| v.as_str())
227 .map(String::from);
228 let resets_at = err.get("resets_at").and_then(|v| v.as_u64());
229 let retry_after_secs = err.get("resets_in_seconds").and_then(|v| v.as_u64());
230
231 Some(Self {
232 requests_remaining: None,
233 tokens_remaining: None,
234 retry_after_secs,
235 resets_at,
236 error_type,
237 message,
238 })
239 }
240
241 pub fn reset_display(&self) -> String {
243 if let Some(secs) = self.retry_after_secs {
244 let hours = secs / 3600;
245 let mins = (secs % 3600) / 60;
246 if hours >= 24 {
247 format!("{}d {}h", hours / 24, hours % 24)
248 } else if hours > 0 {
249 format!("{}h {}m", hours, mins)
250 } else {
251 format!("{}m", mins)
252 }
253 } else {
254 "unknown".into()
255 }
256 }
257
258 pub fn status_line(&self) -> String {
260 let mut parts = Vec::new();
261 if let Some(r) = self.requests_remaining {
262 parts.push(format!("req:{}", r));
263 }
264 if let Some(t) = self.tokens_remaining {
265 parts.push(format!("tok:{}", t));
266 }
267 if self.retry_after_secs.is_some() {
268 parts.push(format!("reset:{}", self.reset_display()));
269 }
270 if parts.is_empty() {
271 self.message
272 .clone()
273 .unwrap_or_else(|| "rate limited".into())
274 } else {
275 parts.join(" | ")
276 }
277 }
278}
279
280#[derive(Debug, Clone, Serialize, Deserialize)]
295pub struct LlmConfig {
296 pub model: String,
297 #[serde(default, skip_serializing_if = "Option::is_none")]
298 pub api_key: Option<String>,
299 #[serde(default, skip_serializing_if = "Option::is_none")]
300 pub base_url: Option<String>,
301 #[serde(default = "default_temperature")]
302 pub temp: f64,
303 #[serde(default, skip_serializing_if = "Option::is_none")]
304 pub max_tokens: Option<u32>,
305 #[serde(default, skip_serializing_if = "Option::is_none")]
307 pub prompt_cache_key: Option<String>,
308 #[serde(default, skip_serializing_if = "Option::is_none")]
310 pub project_id: Option<String>,
311 #[serde(default, skip_serializing_if = "Option::is_none")]
313 pub location: Option<String>,
314 #[serde(default)]
318 pub use_chat_api: bool,
319 #[serde(default, skip_serializing_if = "Vec::is_empty")]
322 pub extra_headers: Vec<(String, String)>,
323 #[serde(default, skip_serializing_if = "Option::is_none")]
326 pub reasoning_effort: Option<String>,
327 #[serde(default, skip_serializing_if = "Option::is_none")]
332 pub verbosity: Option<String>,
333 #[serde(default)]
336 pub use_genai: bool,
337 #[serde(default)]
341 pub use_cli: bool,
342 #[serde(default, skip_serializing_if = "Option::is_none")]
345 pub session_id: Option<String>,
346
347 #[serde(default, skip_serializing_if = "Option::is_none")]
350 pub no_assistant_prefill: Option<bool>,
351 #[serde(default, skip_serializing_if = "Option::is_none")]
353 pub cache_ttl: Option<String>,
354 #[serde(default, skip_serializing_if = "Option::is_none")]
356 pub pin_provider: Option<String>,
357 #[serde(default = "default_websocket")]
360 pub websocket: bool,
361}
362
/// Serde default for `LlmConfig::websocket`: enabled unless configured otherwise.
fn default_websocket() -> bool {
    true
}
366
/// Serde default for `LlmConfig::temp`: a sampling temperature of 0.7.
fn default_temperature() -> f64 {
    0.7
}
370
371impl Default for LlmConfig {
372 fn default() -> Self {
373 Self {
374 model: String::new(),
375 api_key: None,
376 base_url: None,
377 temp: default_temperature(),
378 max_tokens: None,
379 prompt_cache_key: None,
380 project_id: None,
381 location: None,
382 use_chat_api: false,
383 extra_headers: Vec::new(),
384 reasoning_effort: None,
385 verbosity: None,
386 use_genai: false,
387 use_cli: false,
388 session_id: None,
389 no_assistant_prefill: None,
390 cache_ttl: None,
391 pin_provider: None,
392 websocket: default_websocket(),
393 }
394 }
395}
396
397impl LlmConfig {
398 pub fn auto(model: impl Into<String>) -> Self {
400 Self {
401 model: model.into(),
402 ..Default::default()
403 }
404 }
405
406 pub fn with_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
408 Self {
409 model: model.into(),
410 api_key: Some(api_key.into()),
411 ..Default::default()
412 }
413 }
414
415 pub fn endpoint(
417 api_key: impl Into<String>,
418 base_url: impl Into<String>,
419 model: impl Into<String>,
420 ) -> Self {
421 Self {
422 model: model.into(),
423 api_key: Some(api_key.into()),
424 base_url: Some(base_url.into()),
425 ..Default::default()
426 }
427 }
428
429 pub fn vertex(project_id: impl Into<String>, model: impl Into<String>) -> Self {
431 Self {
432 model: model.into(),
433 project_id: Some(project_id.into()),
434 location: Some("global".into()),
435 ..Default::default()
436 }
437 }
438
439 pub fn location(mut self, loc: impl Into<String>) -> Self {
441 self.location = Some(loc.into());
442 self
443 }
444
445 pub fn temperature(mut self, t: f64) -> Self {
447 self.temp = t;
448 self
449 }
450
451 pub fn max_tokens(mut self, m: u32) -> Self {
453 self.max_tokens = Some(m);
454 self
455 }
456
457 pub fn verbosity(mut self, v: impl Into<String>) -> Self {
461 self.verbosity = Some(v.into());
462 self
463 }
464
465 pub fn prompt_cache_key(mut self, key: impl Into<String>) -> Self {
467 self.prompt_cache_key = Some(key.into());
468 self
469 }
470
471 pub fn is_anthropic(&self) -> bool {
473 self.model.starts_with("anthropic/") || self.model.starts_with("claude")
474 }
475
476 pub fn rejects_prefill(&self) -> bool {
478 self.no_assistant_prefill.unwrap_or_else(|| {
479 self.is_anthropic() && !self.model.contains("haiku")
481 })
482 }
483
484 pub fn resolved_cache_ttl(&self) -> Option<&str> {
486 if self.cache_ttl.is_some() {
487 return self.cache_ttl.as_deref();
488 }
489 if self.is_anthropic() {
490 Some("1h")
491 } else {
492 None
493 }
494 }
495
496 pub fn resolved_pin_provider(&self) -> Option<&str> {
498 if self.pin_provider.is_some() {
499 return self.pin_provider.as_deref();
500 }
501 if self.is_anthropic() {
502 Some("Anthropic")
503 } else {
504 None
505 }
506 }
507
508 pub fn apply_headers(&self, config: &mut openai_oxide::config::ClientConfig) {
511 if !self.extra_headers.is_empty() {
512 let mut hm = reqwest::header::HeaderMap::new();
513 for (k, v) in &self.extra_headers {
514 if let (Ok(name), Ok(val)) = (
515 reqwest::header::HeaderName::from_bytes(k.as_bytes()),
516 reqwest::header::HeaderValue::from_str(v),
517 ) {
518 hm.insert(name, val);
519 }
520 }
521 config.default_headers = Some(hm);
522 }
523 }
524
525 pub fn cli(cli_model: impl Into<String>) -> Self {
529 Self {
530 model: cli_model.into(),
531 use_cli: true,
532 ..Default::default()
533 }
534 }
535
536 pub fn label(&self) -> String {
538 if self.use_cli {
539 format!("CLI ({})", self.model)
540 } else if self.project_id.is_some() {
541 format!("Vertex ({})", self.model)
542 } else if self.base_url.is_some() {
543 format!("Custom ({})", self.model)
544 } else {
545 self.model.clone()
546 }
547 }
548
549 pub fn compaction_model(&self) -> String {
551 if self.model.starts_with("gemini") {
552 "gemini-2.0-flash-lite".into()
553 } else if self.model.starts_with("gpt") {
554 "gpt-4o-mini".into()
555 } else if self.model.starts_with("claude") {
556 "claude-3-haiku-20240307".into()
557 } else {
558 self.model.clone()
560 }
561 }
562
563 pub fn for_compaction(&self) -> Self {
565 let mut cfg = self.clone();
566 cfg.model = self.compaction_model();
567 cfg.max_tokens = Some(2048);
568 cfg
569 }
570}
571
/// Provider connection settings built by the preset constructors on this type.
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    /// API key. For [`ProviderConfig::vertex`] this holds an access token; for
    /// [`ProviderConfig::ollama`] it is empty.
    pub api_key: String,
    /// Model identifier.
    pub model: String,
    /// Custom API base URL, if the provider needs one.
    pub base_url: Option<String>,
    /// Vertex AI project id.
    pub project_id: Option<String>,
    /// Vertex AI location.
    pub location: Option<String>,
    /// Sampling temperature (the presets use 0.3).
    pub temperature: f32,
    /// Cap on output tokens.
    pub max_tokens: Option<u32>,
}
583
584impl ProviderConfig {
585 pub fn gemini(api_key: impl Into<String>, model: impl Into<String>) -> Self {
586 Self {
587 api_key: api_key.into(),
588 model: model.into(),
589 base_url: None,
590 project_id: None,
591 location: None,
592 temperature: 0.3,
593 max_tokens: None,
594 }
595 }
596
597 pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
598 Self {
599 api_key: api_key.into(),
600 model: model.into(),
601 base_url: None,
602 project_id: None,
603 location: None,
604 temperature: 0.3,
605 max_tokens: None,
606 }
607 }
608
609 pub fn openrouter(api_key: impl Into<String>, model: impl Into<String>) -> Self {
610 Self {
611 api_key: api_key.into(),
612 model: model.into(),
613 base_url: Some("https://openrouter.ai/api/v1".into()),
614 project_id: None,
615 location: None,
616 temperature: 0.3,
617 max_tokens: None,
618 }
619 }
620
621 pub fn vertex(
622 access_token: impl Into<String>,
623 project_id: impl Into<String>,
624 model: impl Into<String>,
625 ) -> Self {
626 Self {
627 api_key: access_token.into(),
628 model: model.into(),
629 base_url: None,
630 project_id: Some(project_id.into()),
631 location: Some("global".to_string()),
632 temperature: 0.3,
633 max_tokens: None,
634 }
635 }
636
637 pub fn ollama(model: impl Into<String>) -> Self {
638 Self {
639 api_key: String::new(),
640 model: model.into(),
641 base_url: Some("http://localhost:11434/v1".into()),
642 project_id: None,
643 location: None,
644 temperature: 0.3,
645 max_tokens: None,
646 }
647 }
648}
649
650#[derive(Debug, thiserror::Error)]
652pub enum SgrError {
653 #[error("HTTP error: {0}")]
654 Http(#[from] reqwest::Error),
655 #[error("API error {status}: {body}")]
656 Api { status: u16, body: String },
657 #[error("Rate limit: {}", info.status_line())]
658 RateLimit { status: u16, info: RateLimitInfo },
659 #[error("JSON parse error: {0}")]
660 Json(#[from] serde_json::Error),
661 #[error("Schema error: {0}")]
662 Schema(String),
663 #[error("No content in response")]
664 EmptyResponse,
665 #[error("Response truncated (max_output_tokens): {partial_content}")]
668 MaxOutputTokens { partial_content: String },
669 #[error("Prompt too long: {0}")]
671 PromptTooLong(String),
672}
673
674impl SgrError {
675 pub fn from_api_response(status: u16, body: String) -> Self {
677 if (status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit"))
678 && let Some(mut info) = RateLimitInfo::from_error_body(&body)
679 {
680 if info.message.is_none() {
681 info.message = Some(body.chars().take(200).collect());
682 }
683 return SgrError::RateLimit { status, info };
684 }
685 SgrError::Api { status, body }
686 }
687
688 pub fn from_response_parts(
690 status: u16,
691 body: String,
692 headers: &reqwest::header::HeaderMap,
693 ) -> Self {
694 if status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit") {
695 let mut info = RateLimitInfo::from_error_body(&body)
696 .or_else(|| RateLimitInfo::from_headers(headers))
697 .unwrap_or(RateLimitInfo {
698 requests_remaining: None,
699 tokens_remaining: None,
700 retry_after_secs: None,
701 resets_at: None,
702 error_type: Some("rate_limit".into()),
703 message: Some(body.chars().take(200).collect()),
704 });
705 if let Some(header_info) = RateLimitInfo::from_headers(headers) {
707 if info.requests_remaining.is_none() {
708 info.requests_remaining = header_info.requests_remaining;
709 }
710 if info.tokens_remaining.is_none() {
711 info.tokens_remaining = header_info.tokens_remaining;
712 }
713 }
714 return SgrError::RateLimit { status, info };
715 }
716 SgrError::Api { status, body }
717 }
718
719 pub fn is_rate_limit(&self) -> bool {
721 matches!(self, SgrError::RateLimit { .. })
722 }
723
724 pub fn rate_limit_info(&self) -> Option<&RateLimitInfo> {
726 match self {
727 SgrError::RateLimit { info, .. } => Some(info),
728 _ => None,
729 }
730 }
731}
732
#[cfg(test)]
mod tests {
    use super::*;

    /// `RateLimitInfo` with every field unset, for use with struct-update syntax.
    fn blank_info() -> RateLimitInfo {
        RateLimitInfo {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs: None,
            resets_at: None,
            error_type: None,
            message: None,
        }
    }

    #[test]
    fn parse_codex_rate_limit_error() {
        let body = r#"{"error":{"type":"usage_limit_reached","message":"The usage limit has been reached","plan_type":"plus","resets_at":1773534007,"resets_in_seconds":442787}}"#;
        let err = SgrError::from_api_response(429, body.to_string());
        assert!(err.is_rate_limit());

        let info = err
            .rate_limit_info()
            .expect("rate-limit error must carry info");
        assert_eq!(info.error_type.as_deref(), Some("usage_limit_reached"));
        assert_eq!(info.retry_after_secs, Some(442787));
        assert_eq!(info.resets_at, Some(1773534007));
        assert_eq!(info.reset_display(), "5d 2h");
    }

    #[test]
    fn parse_openai_rate_limit_error() {
        let body =
            r#"{"error":{"type":"rate_limit_exceeded","message":"Rate limit reached for gpt-4"}}"#;
        let err = SgrError::from_api_response(429, body.to_string());
        assert!(err.is_rate_limit());

        let info = err
            .rate_limit_info()
            .expect("rate-limit error must carry info");
        assert_eq!(info.error_type.as_deref(), Some("rate_limit_exceeded"));
    }

    #[test]
    fn non_rate_limit_stays_api_error() {
        let body = r#"{"error":{"type":"invalid_request","message":"Bad request"}}"#;
        let err = SgrError::from_api_response(400, body.to_string());
        assert!(!err.is_rate_limit());
        assert!(matches!(err, SgrError::Api { status: 400, .. }));
    }

    #[test]
    fn status_line_with_all_fields() {
        let info = RateLimitInfo {
            requests_remaining: Some(5),
            tokens_remaining: Some(10000),
            retry_after_secs: Some(3600),
            ..blank_info()
        };
        assert_eq!(info.status_line(), "req:5 | tok:10000 | reset:1h 0m");
    }

    #[test]
    fn status_line_fallback_to_message() {
        let info = RateLimitInfo {
            message: Some("custom message".into()),
            ..blank_info()
        };
        assert_eq!(info.status_line(), "custom message");
    }

    #[test]
    fn reset_display_formats() {
        let cases: [(u64, &str); 3] = [(90, "1m"), (3661, "1h 1m"), (90000, "1d 1h")];
        for (secs, expected) in cases {
            let info = RateLimitInfo {
                retry_after_secs: Some(secs),
                ..blank_info()
            };
            assert_eq!(info.reset_display(), expected, "secs = {secs}");
        }
    }
}