Skip to main content

simple_groq_rs/
lib.rs

1use reqwest::{Client, header};
2use serde::{Deserialize, Serialize};
3use std::env;
4use thiserror::Error;
5
6/// Custom error type for the crate
7#[derive(Error, Debug)]
8pub enum GroqError {
9    #[error("Reqwest error: {0}")]
10    Reqwest(#[from] reqwest::Error),
11
12    #[error("API error: {code} - {message}")]
13    Api { code: String, message: String },
14
15    #[error("Environment error: {0}")]
16    Env(String),
17
18    #[error("No response choices returned")]
19    NoChoices,
20}
21
22/// A single message in the conversation
23#[derive(Serialize, Deserialize, Clone, Debug)]
24pub struct Message {
25    pub role: String,
26    pub content: String,
27}
28
29impl Message {
30    pub fn new(role: &str, content: &str) -> Self {
31        Self {
32            role: role.to_string(),
33            content: content.to_string(),
34        }
35    }
36
37    pub fn system(content: &str) -> Self {
38        Self::new("system", content)
39    }
40
41    pub fn user(content: &str) -> Self {
42        Self::new("user", content)
43    }
44
45    pub fn assistant(content: &str) -> Self {
46        Self::new("assistant", content)
47    }
48}
49
50/// Request body for chat completions
51#[derive(Serialize)]
52struct ChatRequest {
53    model: String,
54    messages: Vec<Message>,
55    temperature: Option<f32>,
56    max_tokens: Option<u32>,
57    stream: Option<bool>,
58}
59
60/// Part of the response
61#[derive(Deserialize)]
62pub struct Choice {
63    pub message: Message,
64    pub finish_reason: Option<String>,
65}
66
67/// Main response from the API
68#[derive(Deserialize)]
69pub struct ChatResponse {
70    pub choices: Vec<Choice>,
71    // You can add usage, id, created, etc. later
72}
73
74/// The main client
75#[derive(Clone)]
76pub struct GroqClient {
77    client: Client,
78    api_key: String,
79    base_url: String,
80}
81
82impl GroqClient {
83    /// Create a new client with an API key
84    pub fn new(api_key: impl Into<String>) -> Self {
85        Self {
86            client: Client::new(),
87            api_key: api_key.into(),
88            base_url: "https://api.groq.com/openai/v1".to_string(),
89        }
90    }
91
92    /// Create from Groq_API_KEY environment variable
93    pub fn from_env() -> Result<Self, GroqError> {
94        let api_key = env::var("GROQ_API_KEY")
95            .map_err(|_| GroqError::Env("GROQ_API_KEY not set".to_string()))?;
96        Ok(Self::new(api_key))
97    }
98
99    /// Send a chat completion request
100    pub async fn chat_completion(
101        &self,
102        model: &str,
103        messages: Vec<Message>,
104        temperature: Option<f32>,
105        max_tokens: Option<u32>,
106    ) -> Result<String, GroqError> {
107        let request = ChatRequest {
108            model: model.to_string(),
109            messages,
110            temperature,
111            max_tokens,
112            stream: Some(false),
113        };
114
115        let response = self
116            .client
117            .post(format!("{}/chat/completions", self.base_url))
118            .header(header::AUTHORIZATION, format!("Bearer {}", self.api_key))
119            .header(header::CONTENT_TYPE, "application/json")
120            .json(&request)
121            .send()
122            .await?;
123
124        // Better error handling for API errors
125        if !response.status().is_success() {
126            let status = response.status();
127            let text = response.text().await.unwrap_or_default();
128            return Err(GroqError::Api {
129                code: status.to_string(),
130                message: text,
131            });
132        }
133
134        let chat_response: ChatResponse = response.json().await?;
135
136        let content = chat_response
137            .choices
138            .into_iter()
139            .next()
140            .ok_or(GroqError::NoChoices)?
141            .message
142            .content;
143
144        Ok(content)
145    }
146}