1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct Message {
6 pub role: Role,
7 pub content: String,
8 #[serde(skip_serializing_if = "Option::is_none")]
10 pub tool_call_id: Option<String>,
11 #[serde(default, skip_serializing_if = "Vec::is_empty")]
14 pub tool_calls: Vec<ToolCall>,
15}
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
18#[serde(rename_all = "lowercase")]
19pub enum Role {
20 System,
21 User,
22 Assistant,
23 Tool,
24}
25
26impl Message {
27 pub fn system(content: impl Into<String>) -> Self {
28 Self {
29 role: Role::System,
30 content: content.into(),
31 tool_call_id: None,
32 tool_calls: vec![],
33 }
34 }
35 pub fn user(content: impl Into<String>) -> Self {
36 Self {
37 role: Role::User,
38 content: content.into(),
39 tool_call_id: None,
40 tool_calls: vec![],
41 }
42 }
43 pub fn assistant(content: impl Into<String>) -> Self {
44 Self {
45 role: Role::Assistant,
46 content: content.into(),
47 tool_call_id: None,
48 tool_calls: vec![],
49 }
50 }
51 pub fn assistant_with_tool_calls(
53 content: impl Into<String>,
54 tool_calls: Vec<ToolCall>,
55 ) -> Self {
56 Self {
57 role: Role::Assistant,
58 content: content.into(),
59 tool_call_id: None,
60 tool_calls,
61 }
62 }
63 pub fn tool(call_id: impl Into<String>, content: impl Into<String>) -> Self {
64 Self {
65 role: Role::Tool,
66 content: content.into(),
67 tool_call_id: Some(call_id.into()),
68 tool_calls: vec![],
69 }
70 }
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ToolCall {
76 pub id: String,
78 pub name: String,
80 pub arguments: serde_json::Value,
82}
83
84#[derive(Debug, Clone)]
86pub struct SgrResponse<T> {
87 pub output: Option<T>,
90 pub tool_calls: Vec<ToolCall>,
92 pub raw_text: String,
94 pub usage: Option<Usage>,
96 pub rate_limit: Option<RateLimitInfo>,
98}
99
100#[derive(Debug, Clone, Default, Serialize, Deserialize)]
101pub struct Usage {
102 pub prompt_tokens: u32,
103 pub completion_tokens: u32,
104 pub total_tokens: u32,
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct RateLimitInfo {
110 pub requests_remaining: Option<u32>,
112 pub tokens_remaining: Option<u32>,
114 pub retry_after_secs: Option<u64>,
116 pub resets_at: Option<u64>,
118 pub error_type: Option<String>,
120 pub message: Option<String>,
122}
123
124impl RateLimitInfo {
125 pub fn from_headers(headers: &reqwest::header::HeaderMap) -> Option<Self> {
127 let get_u32 =
128 |name: &str| -> Option<u32> { headers.get(name)?.to_str().ok()?.parse().ok() };
129 let get_u64 =
130 |name: &str| -> Option<u64> { headers.get(name)?.to_str().ok()?.parse().ok() };
131
132 let requests_remaining = get_u32("x-ratelimit-remaining-requests");
133 let tokens_remaining = get_u32("x-ratelimit-remaining-tokens");
134 let retry_after_secs =
135 get_u64("retry-after").or_else(|| get_u64("x-ratelimit-reset-requests"));
136 let resets_at = get_u64("x-ratelimit-reset-tokens");
137
138 if requests_remaining.is_some() || tokens_remaining.is_some() || retry_after_secs.is_some()
139 {
140 Some(Self {
141 requests_remaining,
142 tokens_remaining,
143 retry_after_secs,
144 resets_at,
145 error_type: None,
146 message: None,
147 })
148 } else {
149 None
150 }
151 }
152
153 pub fn from_error_body(body: &str) -> Option<Self> {
155 let json: serde_json::Value = serde_json::from_str(body).ok()?;
156 let err = json.get("error")?;
157
158 let error_type = err.get("type").and_then(|v| v.as_str()).map(String::from);
159 let message = err
160 .get("message")
161 .and_then(|v| v.as_str())
162 .map(String::from);
163 let resets_at = err.get("resets_at").and_then(|v| v.as_u64());
164 let retry_after_secs = err.get("resets_in_seconds").and_then(|v| v.as_u64());
165
166 Some(Self {
167 requests_remaining: None,
168 tokens_remaining: None,
169 retry_after_secs,
170 resets_at,
171 error_type,
172 message,
173 })
174 }
175
176 pub fn reset_display(&self) -> String {
178 if let Some(secs) = self.retry_after_secs {
179 let hours = secs / 3600;
180 let mins = (secs % 3600) / 60;
181 if hours >= 24 {
182 format!("{}d {}h", hours / 24, hours % 24)
183 } else if hours > 0 {
184 format!("{}h {}m", hours, mins)
185 } else {
186 format!("{}m", mins)
187 }
188 } else {
189 "unknown".into()
190 }
191 }
192
193 pub fn status_line(&self) -> String {
195 let mut parts = Vec::new();
196 if let Some(r) = self.requests_remaining {
197 parts.push(format!("req:{}", r));
198 }
199 if let Some(t) = self.tokens_remaining {
200 parts.push(format!("tok:{}", t));
201 }
202 if self.retry_after_secs.is_some() {
203 parts.push(format!("reset:{}", self.reset_display()));
204 }
205 if parts.is_empty() {
206 self.message
207 .clone()
208 .unwrap_or_else(|| "rate limited".into())
209 } else {
210 parts.join(" | ")
211 }
212 }
213}
214
215#[derive(Debug, Clone, Serialize, Deserialize)]
230pub struct LlmConfig {
231 pub model: String,
232 #[serde(default, skip_serializing_if = "Option::is_none")]
233 pub api_key: Option<String>,
234 #[serde(default, skip_serializing_if = "Option::is_none")]
235 pub base_url: Option<String>,
236 #[serde(default = "default_temperature")]
237 pub temp: f64,
238 #[serde(default, skip_serializing_if = "Option::is_none")]
239 pub max_tokens: Option<u32>,
240}
241
/// Temperature used by `LlmConfig` constructors and by serde when `temp` is absent.
fn default_temperature() -> f64 {
    const DEFAULT_TEMPERATURE: f64 = 0.7;
    DEFAULT_TEMPERATURE
}
245
246impl LlmConfig {
247 pub fn auto(model: impl Into<String>) -> Self {
249 Self {
250 model: model.into(),
251 api_key: None,
252 base_url: None,
253 temp: default_temperature(),
254 max_tokens: None,
255 }
256 }
257
258 pub fn with_key(api_key: impl Into<String>, model: impl Into<String>) -> Self {
260 Self {
261 model: model.into(),
262 api_key: Some(api_key.into()),
263 base_url: None,
264 temp: default_temperature(),
265 max_tokens: None,
266 }
267 }
268
269 pub fn endpoint(
271 api_key: impl Into<String>,
272 base_url: impl Into<String>,
273 model: impl Into<String>,
274 ) -> Self {
275 Self {
276 model: model.into(),
277 api_key: Some(api_key.into()),
278 base_url: Some(base_url.into()),
279 temp: default_temperature(),
280 max_tokens: None,
281 }
282 }
283
284 pub fn temperature(mut self, t: f64) -> Self {
286 self.temp = t;
287 self
288 }
289
290 pub fn max_tokens(mut self, m: u32) -> Self {
292 self.max_tokens = Some(m);
293 self
294 }
295}
296
/// Provider-specific connection settings, built via the presets on `ProviderConfig`.
#[derive(Debug, Clone)]
pub struct ProviderConfig {
    /// API key; holds the access token for the `vertex` preset and is empty for `ollama`.
    pub api_key: String,
    /// Model identifier.
    pub model: String,
    /// Endpoint base URL; `None` means the provider default.
    pub base_url: Option<String>,
    /// Cloud project ID (set by the `vertex` preset).
    pub project_id: Option<String>,
    /// Cloud location (the `vertex` preset uses `"global"`).
    pub location: Option<String>,
    /// Sampling temperature.
    pub temperature: f32,
    /// Optional cap on generated tokens.
    pub max_tokens: Option<u32>,
}
308
309impl ProviderConfig {
310 pub fn gemini(api_key: impl Into<String>, model: impl Into<String>) -> Self {
311 Self {
312 api_key: api_key.into(),
313 model: model.into(),
314 base_url: None,
315 project_id: None,
316 location: None,
317 temperature: 0.3,
318 max_tokens: None,
319 }
320 }
321
322 pub fn openai(api_key: impl Into<String>, model: impl Into<String>) -> Self {
323 Self {
324 api_key: api_key.into(),
325 model: model.into(),
326 base_url: None,
327 project_id: None,
328 location: None,
329 temperature: 0.3,
330 max_tokens: None,
331 }
332 }
333
334 pub fn openrouter(api_key: impl Into<String>, model: impl Into<String>) -> Self {
335 Self {
336 api_key: api_key.into(),
337 model: model.into(),
338 base_url: Some("https://openrouter.ai/api/v1".into()),
339 project_id: None,
340 location: None,
341 temperature: 0.3,
342 max_tokens: None,
343 }
344 }
345
346 pub fn vertex(
347 access_token: impl Into<String>,
348 project_id: impl Into<String>,
349 model: impl Into<String>,
350 ) -> Self {
351 Self {
352 api_key: access_token.into(),
353 model: model.into(),
354 base_url: None,
355 project_id: Some(project_id.into()),
356 location: Some("global".to_string()),
357 temperature: 0.3,
358 max_tokens: None,
359 }
360 }
361
362 pub fn ollama(model: impl Into<String>) -> Self {
363 Self {
364 api_key: String::new(),
365 model: model.into(),
366 base_url: Some("http://localhost:11434/v1".into()),
367 project_id: None,
368 location: None,
369 temperature: 0.3,
370 max_tokens: None,
371 }
372 }
373}
374
375#[derive(Debug, thiserror::Error)]
377pub enum SgrError {
378 #[error("HTTP error: {0}")]
379 Http(#[from] reqwest::Error),
380 #[error("API error {status}: {body}")]
381 Api { status: u16, body: String },
382 #[error("Rate limit: {}", info.status_line())]
383 RateLimit { status: u16, info: RateLimitInfo },
384 #[error("JSON parse error: {0}")]
385 Json(#[from] serde_json::Error),
386 #[error("Schema error: {0}")]
387 Schema(String),
388 #[error("No content in response")]
389 EmptyResponse,
390}
391
392impl SgrError {
393 pub fn from_api_response(status: u16, body: String) -> Self {
395 if status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit") {
396 if let Some(mut info) = RateLimitInfo::from_error_body(&body) {
397 if info.message.is_none() {
398 info.message = Some(body.chars().take(200).collect());
399 }
400 return SgrError::RateLimit { status, info };
401 }
402 }
403 SgrError::Api { status, body }
404 }
405
406 pub fn from_response_parts(
408 status: u16,
409 body: String,
410 headers: &reqwest::header::HeaderMap,
411 ) -> Self {
412 if status == 429 || body.contains("usage_limit_reached") || body.contains("rate_limit") {
413 let mut info = RateLimitInfo::from_error_body(&body)
414 .or_else(|| RateLimitInfo::from_headers(headers))
415 .unwrap_or(RateLimitInfo {
416 requests_remaining: None,
417 tokens_remaining: None,
418 retry_after_secs: None,
419 resets_at: None,
420 error_type: Some("rate_limit".into()),
421 message: Some(body.chars().take(200).collect()),
422 });
423 if let Some(header_info) = RateLimitInfo::from_headers(headers) {
425 if info.requests_remaining.is_none() {
426 info.requests_remaining = header_info.requests_remaining;
427 }
428 if info.tokens_remaining.is_none() {
429 info.tokens_remaining = header_info.tokens_remaining;
430 }
431 }
432 return SgrError::RateLimit { status, info };
433 }
434 SgrError::Api { status, body }
435 }
436
437 pub fn is_rate_limit(&self) -> bool {
439 matches!(self, SgrError::RateLimit { .. })
440 }
441
442 pub fn rate_limit_info(&self) -> Option<&RateLimitInfo> {
444 match self {
445 SgrError::RateLimit { info, .. } => Some(info),
446 _ => None,
447 }
448 }
449}
450
#[cfg(test)]
mod tests {
    use super::*;

    /// A `RateLimitInfo` with every field unset; tests override only what they need.
    fn blank() -> RateLimitInfo {
        RateLimitInfo {
            requests_remaining: None,
            tokens_remaining: None,
            retry_after_secs: None,
            resets_at: None,
            error_type: None,
            message: None,
        }
    }

    /// Classifies `body` as a 429 response, checks it is a rate limit, and returns its info.
    fn rate_limited(body: &str) -> RateLimitInfo {
        let err = SgrError::from_api_response(429, body.to_string());
        assert!(err.is_rate_limit());
        err.rate_limit_info()
            .cloned()
            .expect("rate-limit error must carry info")
    }

    #[test]
    fn parse_codex_rate_limit_error() {
        let info = rate_limited(
            r#"{"error":{"type":"usage_limit_reached","message":"The usage limit has been reached","plan_type":"plus","resets_at":1773534007,"resets_in_seconds":442787}}"#,
        );
        assert_eq!(info.error_type.as_deref(), Some("usage_limit_reached"));
        assert_eq!(info.retry_after_secs, Some(442787));
        assert_eq!(info.resets_at, Some(1773534007));
        assert_eq!(info.reset_display(), "5d 2h");
    }

    #[test]
    fn parse_openai_rate_limit_error() {
        let info = rate_limited(
            r#"{"error":{"type":"rate_limit_exceeded","message":"Rate limit reached for gpt-4"}}"#,
        );
        assert_eq!(info.error_type.as_deref(), Some("rate_limit_exceeded"));
    }

    #[test]
    fn non_rate_limit_stays_api_error() {
        let body = r#"{"error":{"type":"invalid_request","message":"Bad request"}}"#;
        let err = SgrError::from_api_response(400, body.to_string());
        assert!(!err.is_rate_limit());
        assert!(matches!(err, SgrError::Api { status: 400, .. }));
    }

    #[test]
    fn status_line_with_all_fields() {
        let info = RateLimitInfo {
            requests_remaining: Some(5),
            tokens_remaining: Some(10000),
            retry_after_secs: Some(3600),
            ..blank()
        };
        assert_eq!(info.status_line(), "req:5 | tok:10000 | reset:1h 0m");
    }

    #[test]
    fn status_line_fallback_to_message() {
        let info = RateLimitInfo {
            message: Some("custom message".into()),
            ..blank()
        };
        assert_eq!(info.status_line(), "custom message");
    }

    #[test]
    fn reset_display_formats() {
        let cases = [(90, "1m"), (3661, "1h 1m"), (90000, "1d 1h")];
        for (secs, expected) in cases {
            let info = RateLimitInfo {
                retry_after_secs: Some(secs),
                ..blank()
            };
            assert_eq!(info.reset_display(), expected, "secs = {}", secs);
        }
    }
}