use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::env;
use std::path::PathBuf;

use crate::config::Config;
use crate::storage::Storage;

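/// Request body for the Anthropic Messages API.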
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<AnthropicMessage>,
}

/// A single message in an Anthropic conversation.
#[derive(Debug, Serialize)]
struct AnthropicMessage {
    role: String,
    content: String,
}

/// Response body from the Anthropic Messages API.
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContent>,
}

/// A single content block in an Anthropic response.
#[derive(Debug, Deserialize)]
struct AnthropicContent {
    text: String,
}

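/// Request body for OpenAI-compatible chat completion APIs (OpenAI, xAI, OpenRouter).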
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<OpenAIMessage>,
}

/// A single message in an OpenAI-style conversation.
#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

/// Response body from an OpenAI-compatible chat completions endpoint.
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
}

/// One completion choice in an OpenAI-style response.
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessageResponse,
}

/// The message payload inside a completion choice.
#[derive(Debug, Deserialize)]
struct OpenAIMessageResponse {
    content: String,
}

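/// Client that dispatches prompts to whichever LLM provider is configured.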
pub struct LLMClient {
    config: Config,
    api_key: String,
    client: reqwest::Client,
}

impl LLMClient {
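    /// Creates a client using the default storage location, reading the API key
    /// from the provider's environment variable when the provider requires one.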
    pub fn new() -> Result<Self> {
        let storage = Storage::new(None);
        let config = storage.load_config()?;

        let api_key = if config.requires_api_key() {
            env::var(config.api_key_env_var()).with_context(|| {
                format!("{} environment variable not set", config.api_key_env_var())
            })?
        } else {
            // Providers such as claude-cli handle auth themselves; no key required.
            String::new()
        };

        Ok(LLMClient {
            config,
            api_key,
            client: reqwest::Client::new(),
        })
    }

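    /// Same as [`new`](Self::new), but loads configuration relative to the
    /// given project root.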
    pub fn new_with_project_root(project_root: PathBuf) -> Result<Self> {
        let storage = Storage::new(Some(project_root));
        let config = storage.load_config()?;

        let api_key = if config.requires_api_key() {
            env::var(config.api_key_env_var()).with_context(|| {
                format!("{} environment variable not set", config.api_key_env_var())
            })?
        } else {
            // Providers such as claude-cli handle auth themselves; no key required.
            String::new()
        };

        Ok(LLMClient {
            config,
            api_key,
            client: reqwest::Client::new(),
        })
    }

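    /// Sends a prompt to the configured provider using the configured model.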
    pub async fn complete(&self, prompt: &str) -> Result<String> {
        self.complete_with_model(prompt, None).await
    }

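    /// Sends a prompt, optionally overriding the configured model, routing the
    /// request to the provider named in the config.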
    pub async fn complete_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        match self.config.llm.provider.as_str() {
            "claude-cli" => self.complete_claude_cli(prompt, model_override).await,
            "anthropic" => {
                self.complete_anthropic_with_model(prompt, model_override)
                    .await
            }
            "xai" | "openai" | "openrouter" => {
                self.complete_openai_compatible_with_model(prompt, model_override)
                    .await
            }
            _ => anyhow::bail!("Unsupported provider: {}", self.config.llm.provider),
        }
    }

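    /// Calls the Anthropic Messages API and returns the text of the first
    /// content block.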
    async fn complete_anthropic_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        let request = AnthropicRequest {
            model: model.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![AnthropicMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let response = self
            .client
            .post(self.config.api_endpoint())
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&request)
            .send()
            .await
            .context("Failed to send request to Anthropic API")?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Anthropic API error ({}): {}", status, error_text);
        }

        let api_response: AnthropicResponse = response
            .json()
            .await
            .context("Failed to parse Anthropic API response")?;

        Ok(api_response
            .content
            .first()
            .map(|c| c.text.clone())
            .unwrap_or_default())
    }

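    /// Calls an OpenAI-compatible chat completions endpoint and returns the
    /// content of the first choice.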
    async fn complete_openai_compatible_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        let request = OpenAIRequest {
            model: model.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![OpenAIMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let mut request_builder = self
            .client
            .post(self.config.api_endpoint())
            .header("authorization", format!("Bearer {}", self.api_key))
            .header("content-type", "application/json");

        // OpenRouter accepts these optional headers for app attribution.
        if self.config.llm.provider == "openrouter" {
            request_builder = request_builder
                .header("HTTP-Referer", "https://github.com/scud-cli")
                .header("X-Title", "SCUD Task Master");
        }

        let response = request_builder
            .json(&request)
            .send()
            .await
            .with_context(|| {
                format!("Failed to send request to {} API", self.config.llm.provider)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!(
                "{} API error ({}): {}",
                self.config.llm.provider,
                status,
                error_text
            );
        }

        let api_response: OpenAIResponse = response.json().await.with_context(|| {
            format!("Failed to parse {} API response", self.config.llm.provider)
        })?;

        Ok(api_response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .unwrap_or_default())
    }

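    /// Completes a prompt and deserializes the response as JSON, tolerating
    /// markdown code fences and surrounding prose in the model output.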
    pub async fn complete_json<T>(&self, prompt: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete(prompt).await?;

        let json_str = Self::extract_json(&response_text);

        serde_json::from_str(json_str).with_context(|| {
            // Truncate on character boundaries so multi-byte UTF-8 cannot cause a panic.
            let preview = if json_str.chars().count() > 500 {
                let truncated: String = json_str.chars().take(500).collect();
                format!("{}...", truncated)
            } else {
                json_str.to_string()
            };
            format!(
                "Failed to parse JSON from LLM response. Response preview:\n{}",
                preview
            )
        })
    }

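    /// Extracts the most likely JSON payload from an LLM response: a json-tagged
    /// code fence first, then any code fence, then the outermost array or object,
    /// falling back to the trimmed response.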
    fn extract_json(response: &str) -> &str {
        // Prefer an explicit ```json fence.
        if let Some(start) = response.find("```json") {
            let content_start = start + "```json".len();
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

        // Fall back to any fenced code block, skipping the language-tag line.
        if let Some(start) = response.find("```") {
            let content_start = start + 3;
            let content_start = response[content_start..]
                .find('\n')
                .map(|i| content_start + i + 1)
                .unwrap_or(content_start);
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

        // Otherwise take the outermost array or object literal.
        if let Some(start) = response.find('[') {
            if let Some(end) = response.rfind(']') {
                if end > start {
                    return &response[start..=end];
                }
            }
        }

        if let Some(start) = response.find('{') {
            if let Some(end) = response.rfind('}') {
                if end > start {
                    return &response[start..=end];
                }
            }
        }

        response.trim()
    }

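    /// Shells out to the `claude` CLI with JSON output and returns the
    /// `result` field of its response.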
    async fn complete_claude_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        // Run claude in non-interactive print mode (-p), piping the prompt via stdin.
        let mut cmd = Command::new("claude");
        cmd.arg("-p")
            .arg("--output-format")
            .arg("json")
            .arg("--model")
            .arg(model)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        let mut child = cmd.spawn().context(
            "Failed to spawn 'claude' command. Make sure Claude Code is installed and 'claude' is in your PATH",
        )?;

        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to claude stdin")?;
            // Close stdin so the CLI knows the prompt is complete.
            drop(stdin);
        }

        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for claude command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Claude CLI error: {}", stderr);
        }

        let stdout =
            String::from_utf8(output.stdout).context("Claude CLI output is not valid UTF-8")?;

        // The CLI's JSON output wraps the completion text in a `result` field.
        #[derive(Deserialize)]
        struct ClaudeCliResponse {
            result: String,
        }

        let response: ClaudeCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Claude CLI JSON response")?;

        Ok(response.result)
    }
}