1use reqwest::{Client, header};
2use serde::{Deserialize, Serialize};
3use std::env;
4use thiserror::Error;
5
6#[derive(Error, Debug)]
8pub enum GroqError {
9 #[error("Reqwest error: {0}")]
10 Reqwest(#[from] reqwest::Error),
11
12 #[error("API error: {code} - {message}")]
13 Api { code: String, message: String },
14
15 #[error("Environment error: {0}")]
16 Env(String),
17
18 #[error("No response choices returned")]
19 NoChoices,
20}
21
22#[derive(Serialize, Deserialize, Clone, Debug)]
24pub struct Message {
25 pub role: String,
26 pub content: String,
27}
28
29impl Message {
30 pub fn new(role: &str, content: &str) -> Self {
31 Self {
32 role: role.to_string(),
33 content: content.to_string(),
34 }
35 }
36
37 pub fn system(content: &str) -> Self {
38 Self::new("system", content)
39 }
40
41 pub fn user(content: &str) -> Self {
42 Self::new("user", content)
43 }
44
45 pub fn assistant(content: &str) -> Self {
46 Self::new("assistant", content)
47 }
48}
49
50#[derive(Serialize)]
52struct ChatRequest {
53 model: String,
54 messages: Vec<Message>,
55 temperature: Option<f32>,
56 max_tokens: Option<u32>,
57 stream: Option<bool>,
58}
59
60#[derive(Deserialize)]
62pub struct Choice {
63 pub message: Message,
64 pub finish_reason: Option<String>,
65}
66
67#[derive(Deserialize)]
69pub struct ChatResponse {
70 pub choices: Vec<Choice>,
71 }
73
74#[derive(Clone)]
76pub struct GroqClient {
77 client: Client,
78 api_key: String,
79 base_url: String,
80}
81
82impl GroqClient {
83 pub fn new(api_key: impl Into<String>) -> Self {
85 Self {
86 client: Client::new(),
87 api_key: api_key.into(),
88 base_url: "https://api.groq.com/openai/v1".to_string(),
89 }
90 }
91
92 pub fn from_env() -> Result<Self, GroqError> {
94 let api_key = env::var("GROQ_API_KEY")
95 .map_err(|_| GroqError::Env("GROQ_API_KEY not set".to_string()))?;
96 Ok(Self::new(api_key))
97 }
98
99 pub async fn chat_completion(
101 &self,
102 model: &str,
103 messages: Vec<Message>,
104 temperature: Option<f32>,
105 max_tokens: Option<u32>,
106 ) -> Result<String, GroqError> {
107 let request = ChatRequest {
108 model: model.to_string(),
109 messages,
110 temperature,
111 max_tokens,
112 stream: Some(false),
113 };
114
115 let response = self
116 .client
117 .post(format!("{}/chat/completions", self.base_url))
118 .header(header::AUTHORIZATION, format!("Bearer {}", self.api_key))
119 .header(header::CONTENT_TYPE, "application/json")
120 .json(&request)
121 .send()
122 .await?;
123
124 if !response.status().is_success() {
126 let status = response.status();
127 let text = response.text().await.unwrap_or_default();
128 return Err(GroqError::Api {
129 code: status.to_string(),
130 message: text,
131 });
132 }
133
134 let chat_response: ChatResponse = response.json().await?;
135
136 let content = chat_response
137 .choices
138 .into_iter()
139 .next()
140 .ok_or(GroqError::NoChoices)?
141 .message
142 .content;
143
144 Ok(content)
145 }
146}