use serde::{Deserialize, Serialize};

3#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct ImagePart {
6 pub data: String,
8 pub mime_type: String,
10}
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Message {
15 pub role: Role,
16 pub content: String,
17 #[serde(skip_serializing_if = "Option::is_none")]
19 pub tool_call_id: Option<String>,
20 #[serde(default, skip_serializing_if = "Vec::is_empty")]
23 pub tool_calls: Vec<ToolCall>,
24 #[serde(default, skip_serializing_if = "Vec::is_empty")]
26 pub images: Vec<ImagePart>,
27}
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
30#[serde(rename_all = "lowercase")]
31pub enum Role {
32 System,
33 User,
34 Assistant,
35 Tool,
36}
37
38impl Message {
39 pub fn system(content: impl Into<String>) -> Self {
40 Self {
41 role: Role::System,
42 content: content.into(),
43 tool_call_id: None,
44 tool_calls: vec![],
45 images: vec![],
46 }
47 }
48 pub fn user(content: impl Into<String>) -> Self {
49 Self {
50 role: Role::User,
51 content: content.into(),
52 tool_call_id: None,
53 tool_calls: vec![],
54 images: vec![],
55 }
56 }
57 pub fn assistant(content: impl Into<String>) -> Self {
58 Self {
59 role: Role::Assistant,
60 content: content.into(),
61 tool_call_id: None,
62 tool_calls: vec![],
63 images: vec![],
64 }
65 }
66 pub fn assistant_with_tool_calls(
68 content: impl Into<String>,
69 tool_calls: Vec<ToolCall>,
70 ) -> Self {
71 Self {
72 role: Role::Assistant,
73 content: content.into(),
74 tool_call_id: None,
75 tool_calls,
76 images: vec![],
77 }
78 }
79 pub fn tool(call_id: impl Into<String>, content: impl Into<String>) -> Self {
80 Self {
81 role: Role::Tool,
82 content: content.into(),
83 tool_call_id: Some(call_id.into()),
84 tool_calls: vec![],
85 images: vec![],
86 }
87 }
88 pub fn tool_with_images(
90 call_id: impl Into<String>,
91 content: impl Into<String>,
92 images: Vec<ImagePart>,
93 ) -> Self {
94 Self {
95 role: Role::Tool,
96 content: content.into(),
97 tool_call_id: Some(call_id.into()),
98 tool_calls: vec![],
99 images,
100 }
101 }
102 pub fn user_with_images(content: impl Into<String>, images: Vec<ImagePart>) -> Self {
104 Self {
105 role: Role::User,
106 content: content.into(),
107 tool_call_id: None,
108 tool_calls: vec![],
109 images,
110 }
111 }
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct ToolCall {
117 pub id: String,
119 pub name: String,
121 pub arguments: serde_json::Value,
123}
124
125#[derive(Debug, Clone)]
127pub struct SgrResponse<T> {
128 pub output: Option<T>,
131 pub tool_calls: Vec<ToolCall>,
133 pub raw_text: String,
135 pub usage: Option<Usage>,
137 pub rate_limit: Option<RateLimitInfo>,
139}
140
141#[derive(Debug, Clone, Default, Serialize, Deserialize)]
142pub struct Usage {
143 pub prompt_tokens: u32,
144 pub completion_tokens: u32,
145 pub total_tokens: u32,
146}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct RateLimitInfo {
151 pub requests_remaining: Option<u32>,
153 pub tokens_remaining: Option<u32>,
155 pub retry_after_secs: Option<u64>,
157 pub resets_at: Option<u64>,
159 pub error_type: Option<String>,
161 pub message: Option<String>,
163}
164
165impl RateLimitInfo {
166 pub fn from_headers(headers: &reqwest::header::HeaderMap) -> Option<Self> {
168 let get_u32 =
169 |name: &str| -> Option<u32> { headers.get(name)?.to_str().ok()?.parse().ok() };
170 let get_u64 =
171 |name: &str| -> Option<u64> { headers.get(name)?.to_str().ok()?.parse().ok() };
172
173 let requests_remaining = get_u32("x-ratelimit-remaining-requests");
174 let tokens_remaining = get_u32("x-ratelimit-remaining-tokens");
175 let retry_after_secs =
176 get_u64("retry-after").or_else(|| get_u64("x-ratelimit-reset-requests"));
177 let resets_at = get_u64("x-ratelimit-reset-tokens");
178
179 if requests_remaining.is_some() || tokens_remaining.is_some() || retry_after_secs.is_some()
180 {
181 Some(Self {
182 requests_remaining,
183 tokens_remaining,
184 retry_after_secs,
185 resets_at,
186 error_type: None,
187 message: None,
188 })
189 } else {
190 None
191 }
192 }
193
194 pub fn from_error_body(body: &str) -> Option<Self> {
196 let json: serde_json::Value = serde_json::from_str(body).ok()?;
197 let err = json.get("error")?;
198
199 let error_type = err.get("type").and_then(|v| v.as_str()).map(String::from);
200 let message = err
201 .get("message")
202 .and_then(|v| v.as_str())
203 .map(String::from);
204 let resets_at = err.get("resets_at").and_then(|v| v.as_u64());
205 let retry_after_secs = err.get("resets_in_seconds").and_then(|v| v.as_u64());
206
207 Some(Self {
208 requests_remaining: None,
209 tokens_remaining: None,
210 retry_after_secs,
211 resets_at,
212 error_type,
213 message,
214 })
215 }
216
217 pub fn reset_display(&self) -> String {
219 if let Some(secs) = self.retry_after_secs {
220 let hours = secs / 3600;
221 let mins = (secs % 3600) / 60;
222 if hours >= 24 {
223 format!("{}d {}h", hours / 24, hours % 24)
224 } else if hours > 0 {
225 format!("{}h {}m", hours, mins)
226 } else {
227 format!("{}m", mins)
228 }
229 } else {
230 "unknown".into()
231 }
232 }
233
234 pub fn status_line(&self) -> String {
236 let mut parts = Vec::new();
237 if let Some(r) = self.requests_remaining {
238 parts.push(format!("req:{}", r));
239 }
240 if let Some(t) = self.tokens_remaining {
241 parts.push(format!("tok:{}", t));
242 }
243 if self.retry_after_secs.is_some() {
244 parts.push(format!("reset:{}", self.reset_display()));
245 }
246 if parts.is_empty() {
247 self.message
248 .clone()
249 .unwrap_or_else(|| "rate limited".into())
250 } else {
251 parts.join(" | ")
252 }
253 }
254}
255
256#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct LlmConfig {
272 pub model: String,
273 #[serde(default, skip_serializing_if = "Option::is_none")]
274 pub api_key: Option<String>,
275 #[serde(default, skip_serializing_if = "Option::is_none")]
276 pub base_url: Option<String>,
277 #[serde(default = "default_temperature")]
278 pub temp: f64,
279 #[serde(default, skip_serializing_if = "Option::is_none")]
280 pub max_tokens: Option<u32>,
281 #[serde(default, skip_serializing_if = "Option::is_none")]
283 pub prompt_cache_key: Option<String>,
284 #[serde(default, skip_serializing_if = "Option::is_none")]
286 pub project_id: Option<String>,
287 #[serde(default, skip_serializing_if = "Option::is_none")]
289 pub location: Option<String>,
290}
291
/// Temperature used when a config does not specify one (serde default and `Default` impl).
fn default_temperature() -> f64 {
    const DEFAULT_TEMPERATURE: f64 = 0.7;
    DEFAULT_TEMPERATURE
}

296impl Default for LlmConfig {
297 fn default() -> Self {
298 Self {
299 model: String::new(),
300 api_key: None,
301 base_url: None,
302 temp: default_temperature(),
303 max_tokens: None,
304 prompt_cache_key: None,
305 project_id: None,
306 location: None,
307 }
308 }
309}
310
311impl LlmConfig {
312 pub fn auto(model: impl Into<String>) -> Self {
314 Self {
315 model: model.into(),
316 ..Default::default()
317 }
318 }
319
320 pub fn with_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
322 Self {
323 model: model.into(),
324 api_key: Some(api_key.into()),
325 ..Default::default()
326 }
327 }
328
329 pub fn endpoint(
331 api_key: impl Into<String>,
332 base_url: impl Into<String>,
333 model: impl Into<String>,
334 ) -> Self {
335 Self {
336 model: model.into(),
337 api_key: Some(api_key.into()),
338 base_url: Some(base_url.into()),
339 ..Default::default()
340 }
341 }
342
343 pub fn vertex(project_id: impl Into<String>, model: impl Into<String>) -> Self {
345 Self {
346 model: model.into(),
347 project_id: Some(project_id.into()),
348 location: Some("global".into()),
349 ..Default::default()
350 }
351 }
352
353 pub fn location(mut self, loc: impl Into<String>) -> Self {
355 self.location = Some(loc.into());
356 self
357 }
358
359 pub fn temperature(mut self, t: f64) -> Self {
361 self.temp = t;
362 self
363 }
364
365 pub fn max_tokens(mut self, m: u32) -> Self {
367 self.max_tokens = Some(m);
368 self
369 }
370
371 pub fn prompt_cache_key(mut self, key: impl Into<String>) -> Self {
373 self.prompt_cache_key = Some(key.into());
374 self
375 }
376
377 pub fn label(&self) -> String {
379 if self.project_id.is_some() {
380 format!("Vertex ({})", self.model)
381 } else if self.base_url.is_some() {
382 format!("Custom ({})", self.model)
383 } else {
384 self.model.clone()
385 }
386 }
387
388 pub fn compaction_model(&self) -> String {
390 if self.model.starts_with("gemini") {
391 "gemini-2.0-flash-lite".into()
392 } else if self.model.starts_with("gpt") {
393 "gpt-4o-mini".into()
394 } else if self.model.starts_with("claude") {
395 "claude-3-haiku-20240307".into()
396 } else {
397 self.model.clone()
399 }
400 }
401
402 pub fn for_compaction(&self) -> Self {
404 let mut cfg = self.clone();
405 cfg.model = self.compaction_model();
406 cfg.max_tokens = Some(2048);
407 cfg
408 }
409}
410
/// Concrete connection settings for one provider backend.
#[derive(Clone, Debug)]
pub struct ProviderConfig {
    /// API key or, for Vertex, the access token.
    pub api_key: String,
    /// Model identifier.
    pub model: String,
    /// Endpoint override, if any.
    pub base_url: Option<String>,
    /// Vertex project id.
    pub project_id: Option<String>,
    /// Vertex location.
    pub location: Option<String>,
    /// Sampling temperature.
    pub temperature: f32,
    /// Maximum completion tokens.
    pub max_tokens: Option<u32>,
}

impl ProviderConfig {
    /// Temperature every preset constructor starts with.
    const PRESET_TEMPERATURE: f32 = 0.3;

    /// Shared starting point for all presets.
    ///
    /// Sets credentials and model, leaves endpoint and Vertex settings unset, uses the preset
    /// temperature, and applies no token cap.
    fn preset(api_key: String, model: String) -> Self {
        Self {
            api_key,
            model,
            base_url: None,
            project_id: None,
            location: None,
            temperature: Self::PRESET_TEMPERATURE,
            max_tokens: None,
        }
    }

    /// Gemini preset (default endpoint).
    pub fn gemini(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self::preset(api_key.into(), model.into())
    }

    /// OpenAI preset (default endpoint).
    pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self::preset(api_key.into(), model.into())
    }

    /// OpenRouter preset, pointing at `https://openrouter.ai/api/v1`.
    pub fn openrouter(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        let base = Self::preset(api_key.into(), model.into());
        Self {
            base_url: Some("https://openrouter.ai/api/v1".into()),
            ..base
        }
    }

    /// Vertex preset: `access_token` goes in `api_key`, and location is `"global"`.
    pub fn vertex(
        access_token: impl Into<String>,
        project_id: impl Into<String>,
        model: impl Into<String>,
    ) -> Self {
        let base = Self::preset(access_token.into(), model.into());
        Self {
            project_id: Some(project_id.into()),
            location: Some("global".to_string()),
            ..base
        }
    }

    /// Local Ollama preset: empty key, endpoint `http://localhost:11434/v1`.
    pub fn ollama(model: impl Into<String>) -> Self {
        let base = Self::preset(String::new(), model.into());
        Self {
            base_url: Some("http://localhost:11434/v1".into()),
            ..base
        }
    }
}

489#[derive(Debug, thiserror::Error)]
491pub enum SgrError {
492 #[error("HTTP error: {0}")]
493 Http(#[from] reqwest::Error),
494 #[error("API error {status}: {body}")]
495 Api { status: u16, body: String },
496 #[error("Rate limit: {}", info.status_line())]
497 RateLimit { status: u16, info: RateLimitInfo },
498 #[error("JSON parse error: {0}")]
499 Json(#[from] serde_json::Error),
500 #[error("Schema error: {0}")]
501 Schema(String),
502 #[error("No content in response")]
503 EmptyResponse,
504}
505
506impl SgrError {
507 pub fn from_api_response(status: u16, body: String) -> Self {
509 if (status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit"))
510 && let Some(mut info) = RateLimitInfo::from_error_body(&body)
511 {
512 if info.message.is_none() {
513 info.message = Some(body.chars().take(200).collect());
514 }
515 return SgrError::RateLimit { status, info };
516 }
517 SgrError::Api { status, body }
518 }
519
520 pub fn from_response_parts(
522 status: u16,
523 body: String,
524 headers: &reqwest::header::HeaderMap,
525 ) -> Self {
526 if status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit") {
527 let mut info = RateLimitInfo::from_error_body(&body)
528 .or_else(|| RateLimitInfo::from_headers(headers))
529 .unwrap_or(RateLimitInfo {
530 requests_remaining: None,
531 tokens_remaining: None,
532 retry_after_secs: None,
533 resets_at: None,
534 error_type: Some("rate_limit".into()),
535 message: Some(body.chars().take(200).collect()),
536 });
537 if let Some(header_info) = RateLimitInfo::from_headers(headers) {
539 if info.requests_remaining.is_none() {
540 info.requests_remaining = header_info.requests_remaining;
541 }
542 if info.tokens_remaining.is_none() {
543 info.tokens_remaining = header_info.tokens_remaining;
544 }
545 }
546 return SgrError::RateLimit { status, info };
547 }
548 SgrError::Api { status, body }
549 }
550
551 pub fn is_rate_limit(&self) -> bool {
553 matches!(self, SgrError::RateLimit { .. })
554 }
555
556 pub fn rate_limit_info(&self) -> Option<&RateLimitInfo> {
558 match self {
559 SgrError::RateLimit { info, .. } => Some(info),
560 _ => None,
561 }
562 }
563}
564
#[cfg(test)]
mod tests {
    use super::*;

    /// `RateLimitInfo` with every field unset; tests override fields via struct-update syntax.
    fn blank_info() -> RateLimitInfo {
        RateLimitInfo {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs: None,
            resets_at: None,
            error_type: None,
            message: None,
        }
    }

    /// A `usage_limit_reached` body carries both `resets_at` and `resets_in_seconds`.
    #[test]
    fn parse_codex_rate_limit_error() {
        let payload = r#"{"error":{"type":"usage_limit_reached","message":"The usage limit has been reached","plan_type":"plus","resets_at":1773534007,"resets_in_seconds":442787}}"#;
        let error = SgrError::from_api_response(429, payload.to_string());
        assert!(error.is_rate_limit());
        let details = error.rate_limit_info().unwrap();
        assert_eq!(details.error_type.as_deref(), Some("usage_limit_reached"));
        assert_eq!(details.retry_after_secs, Some(442787));
        assert_eq!(details.resets_at, Some(1773534007));
        assert_eq!(details.reset_display(), "5d 2h");
    }

    /// An OpenAI-style body only provides type and message.
    #[test]
    fn parse_openai_rate_limit_error() {
        let payload =
            r#"{"error":{"type":"rate_limit_exceeded","message":"Rate limit reached for gpt-4"}}"#;
        let error = SgrError::from_api_response(429, payload.to_string());
        assert!(error.is_rate_limit());
        let details = error.rate_limit_info().unwrap();
        assert_eq!(details.error_type.as_deref(), Some("rate_limit_exceeded"));
    }

    /// Ordinary client errors must not be mistaken for rate limits.
    #[test]
    fn non_rate_limit_stays_api_error() {
        let payload = r#"{"error":{"type":"invalid_request","message":"Bad request"}}"#;
        let error = SgrError::from_api_response(400, payload.to_string());
        assert!(!error.is_rate_limit());
        assert!(matches!(error, SgrError::Api { status: 400, .. }));
    }

    #[test]
    fn status_line_with_all_fields() {
        let details = RateLimitInfo {
            requests_remaining: Some(5),
            tokens_remaining: Some(10000),
            retry_after_secs: Some(3600),
            ..blank_info()
        };
        assert_eq!(details.status_line(), "req:5 | tok:10000 | reset:1h 0m");
    }

    #[test]
    fn status_line_fallback_to_message() {
        let details = RateLimitInfo {
            message: Some("custom message".into()),
            ..blank_info()
        };
        assert_eq!(details.status_line(), "custom message");
    }

    #[test]
    fn reset_display_formats() {
        let with_retry = |secs: u64| RateLimitInfo {
            retry_after_secs: Some(secs),
            ..blank_info()
        };
        assert_eq!(with_retry(90).reset_display(), "1m");
        assert_eq!(with_retry(3661).reset_display(), "1h 1m");
        assert_eq!(with_retry(90000).reset_display(), "1d 1h");
    }
}