use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::env;
use std::path::PathBuf;

use crate::config::Config;
use crate::storage::Storage;

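/// Request body for the Anthropic Messages API.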
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<AnthropicMessage>,
}

#[derive(Debug, Serialize)]
struct AnthropicMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContent>,
}

#[derive(Debug, Deserialize)]
struct AnthropicContent {
    text: String,
}

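/// Request body for OpenAI-compatible chat completions APIs (OpenAI, xAI, OpenRouter).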
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<OpenAIMessage>,
}

#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessageResponse,
}

#[derive(Debug, Deserialize)]
struct OpenAIMessageResponse {
    content: String,
}

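/// Client that sends prompts to the configured LLM provider, either over HTTP
/// (Anthropic and OpenAI-compatible APIs) or through a local CLI tool.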
pub struct LLMClient {
    config: Config,
    api_key: String,
    client: reqwest::Client,
}

impl LLMClient {
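    /// Builds a client from config discovered relative to the current directory.
    /// Reads the provider's API key from the environment when one is required.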
    pub fn new() -> Result<Self> {
        let storage = Storage::new(None);
        let config = storage.load_config()?;

        let api_key = if config.requires_api_key() {
            env::var(config.api_key_env_var()).with_context(|| {
                format!("{} environment variable not set", config.api_key_env_var())
            })?
        } else {
            String::new()
        };

        Ok(LLMClient {
            config,
            api_key,
            client: reqwest::Client::new(),
        })
    }

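    /// Builds a client from config stored under an explicit project root.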
    pub fn new_with_project_root(project_root: PathBuf) -> Result<Self> {
        let storage = Storage::new(Some(project_root));
        let config = storage.load_config()?;

        let api_key = if config.requires_api_key() {
            env::var(config.api_key_env_var()).with_context(|| {
                format!("{} environment variable not set", config.api_key_env_var())
            })?
        } else {
            String::new()
        };

        Ok(LLMClient {
            config,
            api_key,
            client: reqwest::Client::new(),
        })
    }

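    /// Completes a prompt with the default model from config.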
    pub async fn complete(&self, prompt: &str) -> Result<String> {
        self.complete_with_model(prompt, None).await
    }

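    /// Completes a prompt with the configured "smart" model unless an explicit
    /// override is given.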
    pub async fn complete_smart(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(self.config.smart_model());
        self.complete_with_model(prompt, Some(model)).await
    }

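    /// Completes a prompt with the configured "fast" model unless an explicit
    /// override is given.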
    pub async fn complete_fast(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(self.config.fast_model());
        self.complete_with_model(prompt, Some(model)).await
    }

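    /// Completes a prompt, dispatching on the configured provider. An explicit
    /// model override takes precedence over the model from config.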
    pub async fn complete_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        match self.config.llm.provider.as_str() {
            "claude-cli" => self.complete_claude_cli(prompt, model_override).await,
            "codex" => self.complete_codex_cli(prompt, model_override).await,
            "anthropic" => {
                self.complete_anthropic_with_model(prompt, model_override)
                    .await
            }
            "xai" | "openai" | "openrouter" => {
                self.complete_openai_compatible_with_model(prompt, model_override)
                    .await
            }
            _ => anyhow::bail!("Unsupported provider: {}", self.config.llm.provider),
        }
    }

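    /// Sends a single-turn completion request to the Anthropic Messages API.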
    async fn complete_anthropic_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        let request = AnthropicRequest {
            model: model.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![AnthropicMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let response = self
            .client
            .post(self.config.api_endpoint())
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&request)
            .send()
            .await
            .context("Failed to send request to Anthropic API")?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Anthropic API error ({}): {}", status, error_text);
        }

        let api_response: AnthropicResponse = response
            .json()
            .await
            .context("Failed to parse Anthropic API response")?;

        Ok(api_response
            .content
            .first()
            .map(|c| c.text.clone())
            .unwrap_or_default())
    }

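    /// Sends a single-turn chat completion request to an OpenAI-compatible API.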
    async fn complete_openai_compatible_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        let request = OpenAIRequest {
            model: model.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![OpenAIMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let mut request_builder = self
            .client
            .post(self.config.api_endpoint())
            .header("authorization", format!("Bearer {}", self.api_key))
            .header("content-type", "application/json");

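        // OpenRouter recommends these optional headers for app attribution.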
        if self.config.llm.provider == "openrouter" {
            request_builder = request_builder
                .header("HTTP-Referer", "https://github.com/scud-cli")
                .header("X-Title", "SCUD Task Master");
        }

        let response = request_builder
            .json(&request)
            .send()
            .await
            .with_context(|| {
                format!("Failed to send request to {} API", self.config.llm.provider)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!(
                "{} API error ({}): {}",
                self.config.llm.provider,
                status,
                error_text
            );
        }

        let api_response: OpenAIResponse = response.json().await.with_context(|| {
            format!("Failed to parse {} API response", self.config.llm.provider)
        })?;

        Ok(api_response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .unwrap_or_default())
    }

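    /// Completes a prompt with the default model and deserializes the response
    /// as JSON.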
    pub async fn complete_json<T>(&self, prompt: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        self.complete_json_with_model(prompt, None).await
    }

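    /// Completes a prompt with the smart model and deserializes the response as JSON.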
    pub async fn complete_json_smart<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_smart(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

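    /// Completes a prompt with the fast model and deserializes the response as JSON.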
    pub async fn complete_json_fast<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_fast(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

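    /// Completes a prompt with the given model and deserializes the response as JSON.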
    pub async fn complete_json_with_model<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_with_model(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

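    /// Extracts JSON from an LLM response and deserializes it, attaching a
    /// truncated preview of the response to the error on failure.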
    fn parse_json_response<T>(response_text: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let json_str = Self::extract_json(response_text);

        serde_json::from_str(json_str).with_context(|| {
            let preview = if json_str.len() > 500 {
                // Back up to a char boundary so slicing can't panic on
                // multi-byte UTF-8 sequences.
                let mut end = 500;
                while !json_str.is_char_boundary(end) {
                    end -= 1;
                }
                format!("{}...", &json_str[..end])
            } else {
                json_str.to_string()
            };
            format!(
                "Failed to parse JSON from LLM response. Response preview:\n{}",
                preview
            )
        })
    }

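    /// Pulls a JSON payload out of an LLM response, trying progressively looser
    /// heuristics: a ```json fence, any fenced block, the outermost array or
    /// object delimiters, and finally the trimmed response itself.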
    fn extract_json(response: &str) -> &str {
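        // Prefer an explicit ```json fence (7 = length of the opening marker).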
        if let Some(start) = response.find("```json") {
            let content_start = start + 7;
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

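        // Next, any fenced block; skip past the info string on the opening line.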
        if let Some(start) = response.find("```") {
            let content_start = start + 3;
            let content_start = response[content_start..]
                .find('\n')
                .map(|i| content_start + i + 1)
                .unwrap_or(content_start);
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

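        // Fall back to the outermost JSON array or object delimiters.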
        if let Some(start) = response.find('[') {
            if let Some(end) = response.rfind(']') {
                if end > start {
                    return &response[start..=end];
                }
            }
        }

        if let Some(start) = response.find('{') {
            if let Some(end) = response.rfind('}') {
                if end > start {
                    return &response[start..=end];
                }
            }
        }

        response.trim()
    }

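    /// Runs the Claude Code CLI (`claude -p`) with the prompt on stdin and
    /// parses the JSON it prints to stdout.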
    async fn complete_claude_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        let mut cmd = Command::new("claude");
        cmd.arg("-p")
            .arg("--output-format")
            .arg("json")
            .arg("--model")
            .arg(model)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        let mut child = cmd.spawn().context(
            "Failed to spawn 'claude' command. Make sure Claude Code is installed and 'claude' is in your PATH",
        )?;

        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to claude stdin")?;
            // Close stdin so the CLI sees EOF and starts processing.
            drop(stdin);
        }

        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for claude command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Claude CLI error: {}", stderr);
        }

        let stdout =
            String::from_utf8(output.stdout).context("Claude CLI output is not valid UTF-8")?;

        #[derive(Deserialize)]
        struct ClaudeCliResponse {
            result: String,
        }

        let response: ClaudeCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Claude CLI JSON response")?;

        Ok(response.result)
    }

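    /// Runs the Codex CLI with the prompt on stdin and parses the JSON it
    /// prints to stdout.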
    async fn complete_codex_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        let mut cmd = Command::new("codex");
        cmd.arg("-p")
            .arg("--model")
            .arg(model)
            .arg("--output-format")
            .arg("json")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        let mut child = cmd.spawn().context(
            "Failed to spawn 'codex' command. Make sure OpenAI Codex CLI is installed and 'codex' is in your PATH",
        )?;

        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to codex stdin")?;
            // Close stdin so the CLI sees EOF and starts processing.
            drop(stdin);
        }

        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for codex command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Codex CLI error: {}", stderr);
        }

        let stdout =
            String::from_utf8(output.stdout).context("Codex CLI output is not valid UTF-8")?;

        #[derive(Deserialize)]
        struct CodexCliResponse {
            result: String,
        }

        let response: CodexCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Codex CLI JSON response")?;

        Ok(response.result)
    }
}