1use crate::proxy::{LlmMessage, LlmProvider, LlmRequest, LlmResponse, LlmRole, LlmUsage};
8use anyhow::{Context, Result};
9use async_trait::async_trait;
10use reqwest::Client;
11use serde::{Deserialize, Serialize};
12
13pub struct GrokProvider {
14 client: Client,
15 api_key: String,
16 base_url: String,
17}
18
19impl GrokProvider {
20 pub fn new(api_key: String) -> Self {
21 Self {
22 client: Client::new(),
23 api_key,
24 base_url: "https://api.x.ai/v1".to_string(),
25 }
26 }
27
28 pub fn with_base_url(api_key: String, base_url: String) -> Self {
30 Self {
31 client: Client::new(),
32 api_key,
33 base_url,
34 }
35 }
36}
37
38impl Default for GrokProvider {
39 fn default() -> Self {
40 let api_key = std::env::var("XAI_API_KEY")
41 .or_else(|_| std::env::var("GROK_API_KEY"))
42 .unwrap_or_default();
43 Self::new(api_key)
44 }
45}
46
47#[async_trait]
48impl LlmProvider for GrokProvider {
49 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
50 let url = format!("{}/chat/completions", self.base_url);
51
52 let model = if request.model.is_empty() || request.model == "default" {
54 "grok-beta".to_string()
55 } else {
56 request.model.clone()
57 };
58
59 let grok_request = GrokChatRequest {
60 model,
61 messages: request.messages.into_iter().map(Into::into).collect(),
62 temperature: request.temperature,
63 max_tokens: request.max_tokens,
64 stream: request.stream,
65 };
66
67 let response = self
68 .client
69 .post(&url)
70 .header("Authorization", format!("Bearer {}", self.api_key))
71 .header("Content-Type", "application/json")
72 .json(&grok_request)
73 .send()
74 .await
75 .context("Failed to send request to Grok API")?;
76
77 if !response.status().is_success() {
78 let error_text = response.text().await?;
79 return Err(anyhow::anyhow!("Grok API error: {}", error_text));
80 }
81
82 let grok_response: GrokChatResponse = response.json().await?;
83
84 let content = grok_response
85 .choices
86 .first()
87 .map(|c| c.message.content.clone())
88 .unwrap_or_default();
89
90 Ok(LlmResponse {
91 content,
92 model: grok_response.model,
93 usage: grok_response.usage.map(Into::into),
94 })
95 }
96
97 fn name(&self) -> &'static str {
98 "Grok"
99 }
100}
101
102#[derive(Debug, Serialize)]
103struct GrokChatRequest {
104 model: String,
105 messages: Vec<GrokMessage>,
106 #[serde(skip_serializing_if = "Option::is_none")]
107 temperature: Option<f32>,
108 #[serde(skip_serializing_if = "Option::is_none")]
109 max_tokens: Option<usize>,
110 stream: bool,
111}
112
113#[derive(Debug, Serialize, Deserialize)]
114struct GrokMessage {
115 role: String,
116 content: String,
117}
118
119impl From<LlmMessage> for GrokMessage {
120 fn from(msg: LlmMessage) -> Self {
121 Self {
122 role: match msg.role {
123 LlmRole::System => "system".to_string(),
124 LlmRole::User => "user".to_string(),
125 LlmRole::Assistant => "assistant".to_string(),
126 },
127 content: msg.content,
128 }
129 }
130}
131
132#[derive(Debug, Deserialize)]
133struct GrokChatResponse {
134 model: String,
135 choices: Vec<GrokChoice>,
136 usage: Option<GrokUsage>,
137}
138
139#[derive(Debug, Deserialize)]
140struct GrokChoice {
141 message: GrokMessage,
142}
143
144#[derive(Debug, Deserialize)]
145struct GrokUsage {
146 prompt_tokens: usize,
147 completion_tokens: usize,
148 total_tokens: usize,
149}
150
151impl From<GrokUsage> for LlmUsage {
152 fn from(usage: GrokUsage) -> Self {
153 Self {
154 prompt_tokens: usage.prompt_tokens,
155 completion_tokens: usage.completion_tokens,
156 total_tokens: usage.total_tokens,
157 }
158 }
159}