1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct ImagePart {
6 pub data: String,
8 pub mime_type: String,
10}
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct Message {
15 pub role: Role,
16 pub content: String,
17 #[serde(skip_serializing_if = "Option::is_none")]
19 pub tool_call_id: Option<String>,
20 #[serde(default, skip_serializing_if = "Vec::is_empty")]
23 pub tool_calls: Vec<ToolCall>,
24 #[serde(default, skip_serializing_if = "Vec::is_empty")]
26 pub images: Vec<ImagePart>,
27}
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
30#[serde(rename_all = "lowercase")]
31pub enum Role {
32 System,
33 User,
34 Assistant,
35 Tool,
36}
37
38impl Message {
39 pub fn system(content: impl Into<String>) -> Self {
40 Self {
41 role: Role::System,
42 content: content.into(),
43 tool_call_id: None,
44 tool_calls: vec![],
45 images: vec![],
46 }
47 }
48 pub fn user(content: impl Into<String>) -> Self {
49 Self {
50 role: Role::User,
51 content: content.into(),
52 tool_call_id: None,
53 tool_calls: vec![],
54 images: vec![],
55 }
56 }
57 pub fn assistant(content: impl Into<String>) -> Self {
58 Self {
59 role: Role::Assistant,
60 content: content.into(),
61 tool_call_id: None,
62 tool_calls: vec![],
63 images: vec![],
64 }
65 }
66 pub fn assistant_with_tool_calls(
68 content: impl Into<String>,
69 tool_calls: Vec<ToolCall>,
70 ) -> Self {
71 Self {
72 role: Role::Assistant,
73 content: content.into(),
74 tool_call_id: None,
75 tool_calls,
76 images: vec![],
77 }
78 }
79 pub fn tool(call_id: impl Into<String>, content: impl Into<String>) -> Self {
80 Self {
81 role: Role::Tool,
82 content: content.into(),
83 tool_call_id: Some(call_id.into()),
84 tool_calls: vec![],
85 images: vec![],
86 }
87 }
88 pub fn tool_with_images(
90 call_id: impl Into<String>,
91 content: impl Into<String>,
92 images: Vec<ImagePart>,
93 ) -> Self {
94 Self {
95 role: Role::Tool,
96 content: content.into(),
97 tool_call_id: Some(call_id.into()),
98 tool_calls: vec![],
99 images,
100 }
101 }
102 pub fn user_with_images(content: impl Into<String>, images: Vec<ImagePart>) -> Self {
104 Self {
105 role: Role::User,
106 content: content.into(),
107 tool_call_id: None,
108 tool_calls: vec![],
109 images,
110 }
111 }
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct ToolCall {
117 pub id: String,
119 pub name: String,
121 pub arguments: serde_json::Value,
123}
124
125#[derive(Debug, Clone)]
127pub struct SgrResponse<T> {
128 pub output: Option<T>,
131 pub tool_calls: Vec<ToolCall>,
133 pub raw_text: String,
135 pub usage: Option<Usage>,
137 pub rate_limit: Option<RateLimitInfo>,
139}
140
141#[derive(Debug, Clone, Default, Serialize, Deserialize)]
142pub struct Usage {
143 pub prompt_tokens: u32,
144 pub completion_tokens: u32,
145 pub total_tokens: u32,
146}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct RateLimitInfo {
151 pub requests_remaining: Option<u32>,
153 pub tokens_remaining: Option<u32>,
155 pub retry_after_secs: Option<u64>,
157 pub resets_at: Option<u64>,
159 pub error_type: Option<String>,
161 pub message: Option<String>,
163}
164
165impl RateLimitInfo {
166 pub fn from_headers(headers: &reqwest::header::HeaderMap) -> Option<Self> {
168 let get_u32 =
169 |name: &str| -> Option<u32> { headers.get(name)?.to_str().ok()?.parse().ok() };
170 let get_u64 =
171 |name: &str| -> Option<u64> { headers.get(name)?.to_str().ok()?.parse().ok() };
172
173 let requests_remaining = get_u32("x-ratelimit-remaining-requests");
174 let tokens_remaining = get_u32("x-ratelimit-remaining-tokens");
175 let retry_after_secs =
176 get_u64("retry-after").or_else(|| get_u64("x-ratelimit-reset-requests"));
177 let resets_at = get_u64("x-ratelimit-reset-tokens");
178
179 if requests_remaining.is_some() || tokens_remaining.is_some() || retry_after_secs.is_some()
180 {
181 Some(Self {
182 requests_remaining,
183 tokens_remaining,
184 retry_after_secs,
185 resets_at,
186 error_type: None,
187 message: None,
188 })
189 } else {
190 None
191 }
192 }
193
194 pub fn from_error_body(body: &str) -> Option<Self> {
196 let json: serde_json::Value = serde_json::from_str(body).ok()?;
197 let err = json.get("error")?;
198
199 let error_type = err.get("type").and_then(|v| v.as_str()).map(String::from);
200 let message = err
201 .get("message")
202 .and_then(|v| v.as_str())
203 .map(String::from);
204 let resets_at = err.get("resets_at").and_then(|v| v.as_u64());
205 let retry_after_secs = err.get("resets_in_seconds").and_then(|v| v.as_u64());
206
207 Some(Self {
208 requests_remaining: None,
209 tokens_remaining: None,
210 retry_after_secs,
211 resets_at,
212 error_type,
213 message,
214 })
215 }
216
217 pub fn reset_display(&self) -> String {
219 if let Some(secs) = self.retry_after_secs {
220 let hours = secs / 3600;
221 let mins = (secs % 3600) / 60;
222 if hours >= 24 {
223 format!("{}d {}h", hours / 24, hours % 24)
224 } else if hours > 0 {
225 format!("{}h {}m", hours, mins)
226 } else {
227 format!("{}m", mins)
228 }
229 } else {
230 "unknown".into()
231 }
232 }
233
234 pub fn status_line(&self) -> String {
236 let mut parts = Vec::new();
237 if let Some(r) = self.requests_remaining {
238 parts.push(format!("req:{}", r));
239 }
240 if let Some(t) = self.tokens_remaining {
241 parts.push(format!("tok:{}", t));
242 }
243 if self.retry_after_secs.is_some() {
244 parts.push(format!("reset:{}", self.reset_display()));
245 }
246 if parts.is_empty() {
247 self.message
248 .clone()
249 .unwrap_or_else(|| "rate limited".into())
250 } else {
251 parts.join(" | ")
252 }
253 }
254}
255
256#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct LlmConfig {
272 pub model: String,
273 #[serde(default, skip_serializing_if = "Option::is_none")]
274 pub api_key: Option<String>,
275 #[serde(default, skip_serializing_if = "Option::is_none")]
276 pub base_url: Option<String>,
277 #[serde(default = "default_temperature")]
278 pub temp: f64,
279 #[serde(default, skip_serializing_if = "Option::is_none")]
280 pub max_tokens: Option<u32>,
281}
282
/// Default sampling temperature for `LlmConfig`, used both as the serde default
/// for `temp` and by its constructors.
fn default_temperature() -> f64 {
    const DEFAULT_TEMPERATURE: f64 = 0.7;
    DEFAULT_TEMPERATURE
}
286
287impl LlmConfig {
288 pub fn auto(model: impl Into<String>) -> Self {
290 Self {
291 model: model.into(),
292 api_key: None,
293 base_url: None,
294 temp: default_temperature(),
295 max_tokens: None,
296 }
297 }
298
299 pub fn with_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
301 Self {
302 model: model.into(),
303 api_key: Some(api_key.into()),
304 base_url: None,
305 temp: default_temperature(),
306 max_tokens: None,
307 }
308 }
309
310 pub fn endpoint(
312 api_key: impl Into<String>,
313 base_url: impl Into<String>,
314 model: impl Into<String>,
315 ) -> Self {
316 Self {
317 model: model.into(),
318 api_key: Some(api_key.into()),
319 base_url: Some(base_url.into()),
320 temp: default_temperature(),
321 max_tokens: None,
322 }
323 }
324
325 pub fn temperature(mut self, t: f64) -> Self {
327 self.temp = t;
328 self
329 }
330
331 pub fn max_tokens(mut self, m: u32) -> Self {
333 self.max_tokens = Some(m);
334 self
335 }
336}
337
/// Provider-specific connection settings built by the `ProviderConfig` presets.
#[derive(Clone, Debug)]
pub struct ProviderConfig {
    /// API key or access token (empty for presets that need none).
    pub api_key: String,
    /// Model identifier.
    pub model: String,
    /// Endpoint override; `None` means the provider default.
    pub base_url: Option<String>,
    /// Cloud project id (set by the Vertex preset).
    pub project_id: Option<String>,
    /// Cloud location (set by the Vertex preset).
    pub location: Option<String>,
    /// Sampling temperature.
    pub temperature: f32,
    /// Optional cap on generated tokens.
    pub max_tokens: Option<u32>,
}
349
350impl ProviderConfig {
351 pub fn gemini(api_key: impl Into<String>, model: impl Into<String>) -> Self {
352 Self {
353 api_key: api_key.into(),
354 model: model.into(),
355 base_url: None,
356 project_id: None,
357 location: None,
358 temperature: 0.3,
359 max_tokens: None,
360 }
361 }
362
363 pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
364 Self {
365 api_key: api_key.into(),
366 model: model.into(),
367 base_url: None,
368 project_id: None,
369 location: None,
370 temperature: 0.3,
371 max_tokens: None,
372 }
373 }
374
375 pub fn openrouter(api_key: impl Into<String>, model: impl Into<String>) -> Self {
376 Self {
377 api_key: api_key.into(),
378 model: model.into(),
379 base_url: Some("https://openrouter.ai/api/v1".into()),
380 project_id: None,
381 location: None,
382 temperature: 0.3,
383 max_tokens: None,
384 }
385 }
386
387 pub fn vertex(
388 access_token: impl Into<String>,
389 project_id: impl Into<String>,
390 model: impl Into<String>,
391 ) -> Self {
392 Self {
393 api_key: access_token.into(),
394 model: model.into(),
395 base_url: None,
396 project_id: Some(project_id.into()),
397 location: Some("global".to_string()),
398 temperature: 0.3,
399 max_tokens: None,
400 }
401 }
402
403 pub fn ollama(model: impl Into<String>) -> Self {
404 Self {
405 api_key: String::new(),
406 model: model.into(),
407 base_url: Some("http://localhost:11434/v1".into()),
408 project_id: None,
409 location: None,
410 temperature: 0.3,
411 max_tokens: None,
412 }
413 }
414}
415
/// Errors produced by the LLM client layer.
#[derive(Debug, thiserror::Error)]
pub enum SgrError {
    /// Transport-level failure from `reqwest`; `#[from]` provides the `From` conversion.
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),
    /// Failed response not classified as a rate limit; keeps the raw status and body.
    #[error("API error {status}: {body}")]
    Api { status: u16, body: String },
    /// Rate-limit / quota response; displayed via `RateLimitInfo::status_line()`.
    #[error("Rate limit: {}", info.status_line())]
    RateLimit { status: u16, info: RateLimitInfo },
    /// JSON (de)serialization failure; `#[from]` provides the `From` conversion.
    #[error("JSON parse error: {0}")]
    Json(#[from] serde_json::Error),
    /// Schema-related failure carrying a free-form description (constructed by callers).
    #[error("Schema error: {0}")]
    Schema(String),
    /// Displayed as "No content in response" (constructed by callers).
    #[error("No content in response")]
    EmptyResponse,
}
432
433impl SgrError {
434 pub fn from_api_response(status: u16, body: String) -> Self {
436 if status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit") {
437 if let Some(mut info) = RateLimitInfo::from_error_body(&body) {
438 if info.message.is_none() {
439 info.message = Some(body.chars().take(200).collect());
440 }
441 return SgrError::RateLimit { status, info };
442 }
443 }
444 SgrError::Api { status, body }
445 }
446
447 pub fn from_response_parts(
449 status: u16,
450 body: String,
451 headers: &reqwest::header::HeaderMap,
452 ) -> Self {
453 if status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit") {
454 let mut info = RateLimitInfo::from_error_body(&body)
455 .or_else(|| RateLimitInfo::from_headers(headers))
456 .unwrap_or(RateLimitInfo {
457 requests_remaining: None,
458 tokens_remaining: None,
459 retry_after_secs: None,
460 resets_at: None,
461 error_type: Some("rate_limit".into()),
462 message: Some(body.chars().take(200).collect()),
463 });
464 if let Some(header_info) = RateLimitInfo::from_headers(headers) {
466 if info.requests_remaining.is_none() {
467 info.requests_remaining = header_info.requests_remaining;
468 }
469 if info.tokens_remaining.is_none() {
470 info.tokens_remaining = header_info.tokens_remaining;
471 }
472 }
473 return SgrError::RateLimit { status, info };
474 }
475 SgrError::Api { status, body }
476 }
477
478 pub fn is_rate_limit(&self) -> bool {
480 matches!(self, SgrError::RateLimit { .. })
481 }
482
483 pub fn rate_limit_info(&self) -> Option<&RateLimitInfo> {
485 match self {
486 SgrError::RateLimit { info, .. } => Some(info),
487 _ => None,
488 }
489 }
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    /// `RateLimitInfo` with every field unset; tests override fields via struct-update syntax.
    fn blank_info() -> RateLimitInfo {
        RateLimitInfo {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs: None,
            resets_at: None,
            error_type: None,
            message: None,
        }
    }

    #[test]
    fn parse_codex_rate_limit_error() {
        let body = r#"{"error":{"type":"usage_limit_reached","message":"The usage limit has been reached","plan_type":"plus","resets_at":1773534007,"resets_in_seconds":442787}}"#;
        let err = SgrError::from_api_response(429, String::from(body));
        assert!(err.is_rate_limit());

        let info = err
            .rate_limit_info()
            .expect("rate-limit error must carry RateLimitInfo");
        assert_eq!(info.error_type.as_deref(), Some("usage_limit_reached"));
        assert_eq!(info.retry_after_secs, Some(442787));
        assert_eq!(info.resets_at, Some(1773534007));
        assert_eq!(info.reset_display(), "5d 2h");
    }

    #[test]
    fn parse_openai_rate_limit_error() {
        let body =
            r#"{"error":{"type":"rate_limit_exceeded","message":"Rate limit reached for gpt-4"}}"#;
        let err = SgrError::from_api_response(429, String::from(body));
        assert!(err.is_rate_limit());

        let info = err
            .rate_limit_info()
            .expect("rate-limit error must carry RateLimitInfo");
        assert_eq!(info.error_type.as_deref(), Some("rate_limit_exceeded"));
    }

    #[test]
    fn non_rate_limit_stays_api_error() {
        let body = r#"{"error":{"type":"invalid_request","message":"Bad request"}}"#;
        let err = SgrError::from_api_response(400, String::from(body));
        assert!(!err.is_rate_limit());
        assert!(matches!(err, SgrError::Api { status: 400, .. }));
    }

    #[test]
    fn status_line_with_all_fields() {
        let info = RateLimitInfo {
            requests_remaining: Some(5),
            tokens_remaining: Some(10000),
            retry_after_secs: Some(3600),
            ..blank_info()
        };
        assert_eq!(info.status_line(), "req:5 | tok:10000 | reset:1h 0m");
    }

    #[test]
    fn status_line_fallback_to_message() {
        let info = RateLimitInfo {
            message: Some("custom message".into()),
            ..blank_info()
        };
        assert_eq!(info.status_line(), "custom message");
    }

    #[test]
    fn reset_display_formats() {
        let cases: [(u64, &str); 3] = [(90, "1m"), (3661, "1h 1m"), (90000, "1d 1h")];
        for &(secs, expected) in cases.iter() {
            let info = RateLimitInfo {
                retry_after_secs: Some(secs),
                ..blank_info()
            };
            assert_eq!(info.reset_display(), expected, "retry_after_secs = {}", secs);
        }
    }
}