1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct ImagePart {
6 pub data: String,
8 pub mime_type: String,
10}
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Message {
15 pub role: Role,
16 pub content: String,
17 #[serde(skip_serializing_if = "Option::is_none")]
19 pub tool_call_id: Option<String>,
20 #[serde(default, skip_serializing_if = "Vec::is_empty")]
23 pub tool_calls: Vec<ToolCall>,
24 #[serde(default, skip_serializing_if = "Vec::is_empty")]
26 pub images: Vec<ImagePart>,
27}
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
30#[serde(rename_all = "lowercase")]
31pub enum Role {
32 System,
33 User,
34 Assistant,
35 Tool,
36}
37
38impl Message {
39 pub fn system(content: impl Into<String>) -> Self {
40 Self {
41 role: Role::System,
42 content: content.into(),
43 tool_call_id: None,
44 tool_calls: vec![],
45 images: vec![],
46 }
47 }
48 pub fn user(content: impl Into<String>) -> Self {
49 Self {
50 role: Role::User,
51 content: content.into(),
52 tool_call_id: None,
53 tool_calls: vec![],
54 images: vec![],
55 }
56 }
57 pub fn assistant(content: impl Into<String>) -> Self {
58 Self {
59 role: Role::Assistant,
60 content: content.into(),
61 tool_call_id: None,
62 tool_calls: vec![],
63 images: vec![],
64 }
65 }
66 pub fn assistant_with_tool_calls(
68 content: impl Into<String>,
69 tool_calls: Vec<ToolCall>,
70 ) -> Self {
71 Self {
72 role: Role::Assistant,
73 content: content.into(),
74 tool_call_id: None,
75 tool_calls,
76 images: vec![],
77 }
78 }
79 pub fn tool(call_id: impl Into<String>, content: impl Into<String>) -> Self {
80 Self {
81 role: Role::Tool,
82 content: content.into(),
83 tool_call_id: Some(call_id.into()),
84 tool_calls: vec![],
85 images: vec![],
86 }
87 }
88 pub fn tool_with_images(
90 call_id: impl Into<String>,
91 content: impl Into<String>,
92 images: Vec<ImagePart>,
93 ) -> Self {
94 Self {
95 role: Role::Tool,
96 content: content.into(),
97 tool_call_id: Some(call_id.into()),
98 tool_calls: vec![],
99 images,
100 }
101 }
102 pub fn user_with_images(content: impl Into<String>, images: Vec<ImagePart>) -> Self {
104 Self {
105 role: Role::User,
106 content: content.into(),
107 tool_call_id: None,
108 tool_calls: vec![],
109 images,
110 }
111 }
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct ToolCall {
117 pub id: String,
119 pub name: String,
121 pub arguments: serde_json::Value,
123}
124
125#[derive(Debug, Clone)]
127pub struct SgrResponse<T> {
128 pub output: Option<T>,
131 pub tool_calls: Vec<ToolCall>,
133 pub raw_text: String,
135 pub usage: Option<Usage>,
137 pub rate_limit: Option<RateLimitInfo>,
139}
140
141#[derive(Debug, Clone, Default, Serialize, Deserialize)]
142pub struct Usage {
143 pub prompt_tokens: u32,
144 pub completion_tokens: u32,
145 pub total_tokens: u32,
146}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct RateLimitInfo {
151 pub requests_remaining: Option<u32>,
153 pub tokens_remaining: Option<u32>,
155 pub retry_after_secs: Option<u64>,
157 pub resets_at: Option<u64>,
159 pub error_type: Option<String>,
161 pub message: Option<String>,
163}
164
165impl RateLimitInfo {
166 pub fn from_headers(headers: &reqwest::header::HeaderMap) -> Option<Self> {
168 let get_u32 =
169 |name: &str| -> Option<u32> { headers.get(name)?.to_str().ok()?.parse().ok() };
170 let get_u64 =
171 |name: &str| -> Option<u64> { headers.get(name)?.to_str().ok()?.parse().ok() };
172
173 let requests_remaining = get_u32("x-ratelimit-remaining-requests");
174 let tokens_remaining = get_u32("x-ratelimit-remaining-tokens");
175 let retry_after_secs =
176 get_u64("retry-after").or_else(|| get_u64("x-ratelimit-reset-requests"));
177 let resets_at = get_u64("x-ratelimit-reset-tokens");
178
179 if requests_remaining.is_some() || tokens_remaining.is_some() || retry_after_secs.is_some()
180 {
181 Some(Self {
182 requests_remaining,
183 tokens_remaining,
184 retry_after_secs,
185 resets_at,
186 error_type: None,
187 message: None,
188 })
189 } else {
190 None
191 }
192 }
193
194 pub fn from_error_body(body: &str) -> Option<Self> {
196 let json: serde_json::Value = serde_json::from_str(body).ok()?;
197 let err = json.get("error")?;
198
199 let error_type = err.get("type").and_then(|v| v.as_str()).map(String::from);
200 let message = err
201 .get("message")
202 .and_then(|v| v.as_str())
203 .map(String::from);
204 let resets_at = err.get("resets_at").and_then(|v| v.as_u64());
205 let retry_after_secs = err.get("resets_in_seconds").and_then(|v| v.as_u64());
206
207 Some(Self {
208 requests_remaining: None,
209 tokens_remaining: None,
210 retry_after_secs,
211 resets_at,
212 error_type,
213 message,
214 })
215 }
216
217 pub fn reset_display(&self) -> String {
219 if let Some(secs) = self.retry_after_secs {
220 let hours = secs / 3600;
221 let mins = (secs % 3600) / 60;
222 if hours >= 24 {
223 format!("{}d {}h", hours / 24, hours % 24)
224 } else if hours > 0 {
225 format!("{}h {}m", hours, mins)
226 } else {
227 format!("{}m", mins)
228 }
229 } else {
230 "unknown".into()
231 }
232 }
233
234 pub fn status_line(&self) -> String {
236 let mut parts = Vec::new();
237 if let Some(r) = self.requests_remaining {
238 parts.push(format!("req:{}", r));
239 }
240 if let Some(t) = self.tokens_remaining {
241 parts.push(format!("tok:{}", t));
242 }
243 if self.retry_after_secs.is_some() {
244 parts.push(format!("reset:{}", self.reset_display()));
245 }
246 if parts.is_empty() {
247 self.message
248 .clone()
249 .unwrap_or_else(|| "rate limited".into())
250 } else {
251 parts.join(" | ")
252 }
253 }
254}
255
256#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct LlmConfig {
272 pub model: String,
273 #[serde(default, skip_serializing_if = "Option::is_none")]
274 pub api_key: Option<String>,
275 #[serde(default, skip_serializing_if = "Option::is_none")]
276 pub base_url: Option<String>,
277 #[serde(default = "default_temperature")]
278 pub temp: f64,
279 #[serde(default, skip_serializing_if = "Option::is_none")]
280 pub max_tokens: Option<u32>,
281 #[serde(default, skip_serializing_if = "Option::is_none")]
283 pub prompt_cache_key: Option<String>,
284 #[serde(default, skip_serializing_if = "Option::is_none")]
286 pub project_id: Option<String>,
287 #[serde(default, skip_serializing_if = "Option::is_none")]
289 pub location: Option<String>,
290 #[serde(default)]
294 pub use_chat_api: bool,
295}
296
/// Temperature used when a config does not specify one (serde default for
/// `LlmConfig::temp` and the value used by `LlmConfig::default()`).
fn default_temperature() -> f64 {
    const DEFAULT_TEMPERATURE: f64 = 0.7;
    DEFAULT_TEMPERATURE
}
300
301impl Default for LlmConfig {
302 fn default() -> Self {
303 Self {
304 model: String::new(),
305 api_key: None,
306 base_url: None,
307 temp: default_temperature(),
308 max_tokens: None,
309 prompt_cache_key: None,
310 project_id: None,
311 location: None,
312 use_chat_api: false,
313 }
314 }
315}
316
317impl LlmConfig {
318 pub fn auto(model: impl Into<String>) -> Self {
320 Self {
321 model: model.into(),
322 ..Default::default()
323 }
324 }
325
326 pub fn with_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
328 Self {
329 model: model.into(),
330 api_key: Some(api_key.into()),
331 ..Default::default()
332 }
333 }
334
335 pub fn endpoint(
337 api_key: impl Into<String>,
338 base_url: impl Into<String>,
339 model: impl Into<String>,
340 ) -> Self {
341 Self {
342 model: model.into(),
343 api_key: Some(api_key.into()),
344 base_url: Some(base_url.into()),
345 ..Default::default()
346 }
347 }
348
349 pub fn vertex(project_id: impl Into<String>, model: impl Into<String>) -> Self {
351 Self {
352 model: model.into(),
353 project_id: Some(project_id.into()),
354 location: Some("global".into()),
355 ..Default::default()
356 }
357 }
358
359 pub fn location(mut self, loc: impl Into<String>) -> Self {
361 self.location = Some(loc.into());
362 self
363 }
364
365 pub fn temperature(mut self, t: f64) -> Self {
367 self.temp = t;
368 self
369 }
370
371 pub fn max_tokens(mut self, m: u32) -> Self {
373 self.max_tokens = Some(m);
374 self
375 }
376
377 pub fn prompt_cache_key(mut self, key: impl Into<String>) -> Self {
379 self.prompt_cache_key = Some(key.into());
380 self
381 }
382
383 pub fn label(&self) -> String {
385 if self.project_id.is_some() {
386 format!("Vertex ({})", self.model)
387 } else if self.base_url.is_some() {
388 format!("Custom ({})", self.model)
389 } else {
390 self.model.clone()
391 }
392 }
393
394 pub fn compaction_model(&self) -> String {
396 if self.model.starts_with("gemini") {
397 "gemini-2.0-flash-lite".into()
398 } else if self.model.starts_with("gpt") {
399 "gpt-4o-mini".into()
400 } else if self.model.starts_with("claude") {
401 "claude-3-haiku-20240307".into()
402 } else {
403 self.model.clone()
405 }
406 }
407
408 pub fn for_compaction(&self) -> Self {
410 let mut cfg = self.clone();
411 cfg.model = self.compaction_model();
412 cfg.max_tokens = Some(2048);
413 cfg
414 }
415}
416
/// Resolved connection settings for a concrete provider backend.
#[derive(Clone, Debug)]
pub struct ProviderConfig {
    /// API key; for `vertex` an access token, for `ollama` empty.
    pub api_key: String,
    /// Model identifier.
    pub model: String,
    /// Endpoint override (set by `openrouter` and `ollama`).
    pub base_url: Option<String>,
    /// Vertex project id.
    pub project_id: Option<String>,
    /// Vertex location.
    pub location: Option<String>,
    /// Sampling temperature.
    pub temperature: f32,
    /// Completion token cap.
    pub max_tokens: Option<u32>,
}

impl ProviderConfig {
    /// Temperature every preset starts with.
    const PRESET_TEMPERATURE: f32 = 0.3;

    /// Shared preset skeleton: key and model, no endpoint, no Vertex fields.
    fn preset(api_key: String, model: String) -> Self {
        ProviderConfig {
            api_key,
            model,
            base_url: None,
            project_id: None,
            location: None,
            temperature: Self::PRESET_TEMPERATURE,
            max_tokens: None,
        }
    }

    /// Gemini preset (default endpoint).
    pub fn gemini(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self::preset(api_key.into(), model.into())
    }

    /// OpenAI preset (default endpoint).
    pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self::preset(api_key.into(), model.into())
    }

    /// OpenRouter preset, pointed at the OpenRouter API base URL.
    pub fn openrouter(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        let mut cfg = Self::preset(api_key.into(), model.into());
        cfg.base_url = Some("https://openrouter.ai/api/v1".into());
        cfg
    }

    /// Vertex preset: `access_token` is stored as the key, location is `"global"`.
    pub fn vertex(
        access_token: impl Into<String>,
        project_id: impl Into<String>,
        model: impl Into<String>,
    ) -> Self {
        let token = access_token.into();
        let model = model.into();
        let mut cfg = Self::preset(token, model);
        cfg.project_id = Some(project_id.into());
        cfg.location = Some("global".to_string());
        cfg
    }

    /// Local Ollama preset: no key, local OpenAI-compatible base URL.
    pub fn ollama(model: impl Into<String>) -> Self {
        Self {
            base_url: Some("http://localhost:11434/v1".into()),
            ..Self::preset(String::new(), model.into())
        }
    }
}
494
495#[derive(Debug, thiserror::Error)]
497pub enum SgrError {
498 #[error("HTTP error: {0}")]
499 Http(#[from] reqwest::Error),
500 #[error("API error {status}: {body}")]
501 Api { status: u16, body: String },
502 #[error("Rate limit: {}", info.status_line())]
503 RateLimit { status: u16, info: RateLimitInfo },
504 #[error("JSON parse error: {0}")]
505 Json(#[from] serde_json::Error),
506 #[error("Schema error: {0}")]
507 Schema(String),
508 #[error("No content in response")]
509 EmptyResponse,
510}
511
512impl SgrError {
513 pub fn from_api_response(status: u16, body: String) -> Self {
515 if (status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit"))
516 && let Some(mut info) = RateLimitInfo::from_error_body(&body)
517 {
518 if info.message.is_none() {
519 info.message = Some(body.chars().take(200).collect());
520 }
521 return SgrError::RateLimit { status, info };
522 }
523 SgrError::Api { status, body }
524 }
525
526 pub fn from_response_parts(
528 status: u16,
529 body: String,
530 headers: &reqwest::header::HeaderMap,
531 ) -> Self {
532 if status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit") {
533 let mut info = RateLimitInfo::from_error_body(&body)
534 .or_else(|| RateLimitInfo::from_headers(headers))
535 .unwrap_or(RateLimitInfo {
536 requests_remaining: None,
537 tokens_remaining: None,
538 retry_after_secs: None,
539 resets_at: None,
540 error_type: Some("rate_limit".into()),
541 message: Some(body.chars().take(200).collect()),
542 });
543 if let Some(header_info) = RateLimitInfo::from_headers(headers) {
545 if info.requests_remaining.is_none() {
546 info.requests_remaining = header_info.requests_remaining;
547 }
548 if info.tokens_remaining.is_none() {
549 info.tokens_remaining = header_info.tokens_remaining;
550 }
551 }
552 return SgrError::RateLimit { status, info };
553 }
554 SgrError::Api { status, body }
555 }
556
557 pub fn is_rate_limit(&self) -> bool {
559 matches!(self, SgrError::RateLimit { .. })
560 }
561
562 pub fn rate_limit_info(&self) -> Option<&RateLimitInfo> {
564 match self {
565 SgrError::RateLimit { info, .. } => Some(info),
566 _ => None,
567 }
568 }
569}
570
#[cfg(test)]
mod tests {
    use super::*;

    /// A `RateLimitInfo` with every field unset; tests override what they need.
    fn blank_info() -> RateLimitInfo {
        RateLimitInfo {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs: None,
            resets_at: None,
            error_type: None,
            message: None,
        }
    }

    /// Classifies `body` and returns its rate-limit details, failing the test
    /// if it was not recognised as a rate limit.
    fn classify_rate_limit(status: u16, body: &str) -> RateLimitInfo {
        let err = SgrError::from_api_response(status, body.to_string());
        assert!(err.is_rate_limit());
        err.rate_limit_info()
            .cloned()
            .expect("RateLimit error must carry info")
    }

    #[test]
    fn parse_codex_rate_limit_error() {
        let body = r#"{"error":{"type":"usage_limit_reached","message":"The usage limit has been reached","plan_type":"plus","resets_at":1773534007,"resets_in_seconds":442787}}"#;
        let info = classify_rate_limit(429, body);
        assert_eq!(info.error_type.as_deref(), Some("usage_limit_reached"));
        assert_eq!(info.retry_after_secs, Some(442787));
        assert_eq!(info.resets_at, Some(1773534007));
        assert_eq!(info.reset_display(), "5d 2h");
    }

    #[test]
    fn parse_openai_rate_limit_error() {
        let body =
            r#"{"error":{"type":"rate_limit_exceeded","message":"Rate limit reached for gpt-4"}}"#;
        let info = classify_rate_limit(429, body);
        assert_eq!(info.error_type.as_deref(), Some("rate_limit_exceeded"));
    }

    #[test]
    fn non_rate_limit_stays_api_error() {
        let body = r#"{"error":{"type":"invalid_request","message":"Bad request"}}"#;
        let err = SgrError::from_api_response(400, body.to_string());
        assert!(!err.is_rate_limit());
        assert!(matches!(err, SgrError::Api { status: 400, .. }));
    }

    #[test]
    fn status_line_with_all_fields() {
        let info = RateLimitInfo {
            requests_remaining: Some(5),
            tokens_remaining: Some(10000),
            retry_after_secs: Some(3600),
            ..blank_info()
        };
        assert_eq!(info.status_line(), "req:5 | tok:10000 | reset:1h 0m");
    }

    #[test]
    fn status_line_fallback_to_message() {
        let info = RateLimitInfo {
            message: Some("custom message".into()),
            ..blank_info()
        };
        assert_eq!(info.status_line(), "custom message");
    }

    #[test]
    fn reset_display_formats() {
        let cases: [(u64, &str); 3] = [(90, "1m"), (3661, "1h 1m"), (90000, "1d 1h")];
        for (secs, expected) in cases {
            let info = RateLimitInfo {
                retry_after_secs: Some(secs),
                ..blank_info()
            };
            assert_eq!(info.reset_display(), expected);
        }
    }
}