use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::env;
use std::path::PathBuf;

use crate::config::Config;
use crate::storage::Storage;

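/// Request/response wire types for the Anthropic Messages API; only the
/// fields this client reads or writes are modeled.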
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<AnthropicMessage>,
}

#[derive(Debug, Serialize)]
struct AnthropicMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContent>,
}

#[derive(Debug, Deserialize)]
struct AnthropicContent {
    text: String,
}

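/// Wire types for OpenAI-compatible chat-completion APIs (OpenAI, xAI,
/// OpenRouter); only the fields this client reads or writes are modeled.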
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<OpenAIMessage>,
}

#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessageResponse,
}

#[derive(Debug, Deserialize)]
struct OpenAIMessageResponse {
    content: String,
}

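/// Client for the configured LLM provider, covering both HTTP APIs and
/// local CLI tools.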
pub struct LLMClient {
    config: Config,
    api_key: String,
    client: reqwest::Client,
}

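/// A resolved (tier, provider, model) triple, mainly for display purposes.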
#[derive(Debug, Clone)]
pub struct ModelInfo {
    pub tier: &'static str,
    pub provider: String,
    pub model: String,
}

impl std::fmt::Display for ModelInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} model: {}/{}", self.tier, self.provider, self.model)
    }
}

impl LLMClient {
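    /// Builds a client from the default storage location, reading the
    /// provider's API key from the environment when one is required.
    ///
    /// A sketch of typical usage (assuming a configured project and, for
    /// API-backed providers, the key exported in the environment):
    ///
    /// ```text
    /// let client = LLMClient::new()?;
    /// let reply = client.complete("Say hello").await?;
    /// ```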
    pub fn new() -> Result<Self> {
        let storage = Storage::new(None);
        let config = storage.load_config()?;

        let api_key = if config.requires_api_key() {
            env::var(config.api_key_env_var()).with_context(|| {
                format!("{} environment variable not set", config.api_key_env_var())
            })?
        } else {
            // CLI-backed providers handle their own authentication.
            String::new()
        };

        Ok(LLMClient {
            config,
            api_key,
            client: reqwest::Client::new(),
        })
    }

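    /// Like `new`, but loads configuration relative to the given project root.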
    pub fn new_with_project_root(project_root: PathBuf) -> Result<Self> {
        let storage = Storage::new(Some(project_root));
        let config = storage.load_config()?;

        let api_key = if config.requires_api_key() {
            env::var(config.api_key_env_var()).with_context(|| {
                format!("{} environment variable not set", config.api_key_env_var())
            })?
        } else {
            String::new()
        };

        Ok(LLMClient {
            config,
            api_key,
            client: reqwest::Client::new(),
        })
    }

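    /// Resolves the "smart" (higher-capability) tier, honoring an optional
    /// per-call model override.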
    pub fn smart_model_info(&self, model_override: Option<&str>) -> ModelInfo {
        ModelInfo {
            tier: "smart",
            provider: self.config.smart_provider().to_string(),
            model: model_override
                .unwrap_or(self.config.smart_model())
                .to_string(),
        }
    }

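    /// Resolves the "fast" (lower-latency) tier, honoring an optional
    /// per-call model override.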
    pub fn fast_model_info(&self, model_override: Option<&str>) -> ModelInfo {
        ModelInfo {
            tier: "fast",
            provider: self.config.fast_provider().to_string(),
            model: model_override
                .unwrap_or(self.config.fast_model())
                .to_string(),
        }
    }

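    /// Completes a prompt with the default provider and model from config.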
    pub async fn complete(&self, prompt: &str) -> Result<String> {
        self.complete_with_model(prompt, None, None).await
    }

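    /// Completes a prompt with the smart-tier provider and model.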
    pub async fn complete_smart(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(self.config.smart_model());
        let provider = self.config.smart_provider();
        self.complete_with_model(prompt, Some(model), Some(provider))
            .await
    }

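    /// Completes a prompt with the fast-tier provider and model.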
    pub async fn complete_fast(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(self.config.fast_model());
        let provider = self.config.fast_provider();
        self.complete_with_model(prompt, Some(model), Some(provider))
            .await
    }

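    /// Dispatches a completion to the configured provider, with optional
    /// per-call model and provider overrides.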
    pub async fn complete_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
        provider_override: Option<&str>,
    ) -> Result<String> {
        let provider = provider_override.unwrap_or(&self.config.llm.provider);
        match provider.as_ref() {
            "claude-cli" => self.complete_claude_cli(prompt, model_override).await,
            "codex" => self.complete_codex_cli(prompt, model_override).await,
            "anthropic" => {
                self.complete_anthropic_with_model(prompt, model_override)
                    .await
            }
            "xai" | "openai" | "openrouter" => {
                self.complete_openai_compatible_with_model(prompt, model_override)
                    .await
            }
            // Report the provider that was actually resolved, not the config default.
            _ => anyhow::bail!("Unsupported provider: {}", provider),
        }
    }

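    /// Direct HTTP call to the Anthropic Messages API.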
    async fn complete_anthropic_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        let request = AnthropicRequest {
            model: model.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![AnthropicMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let response = self
            .client
            .post(self.config.api_endpoint())
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&request)
            .send()
            .await
            .context("Failed to send request to Anthropic API")?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Anthropic API error ({}): {}", status, error_text);
        }

        let api_response: AnthropicResponse = response
            .json()
            .await
            .context("Failed to parse Anthropic API response")?;

        // The Messages API returns a list of content blocks; take the first.
        Ok(api_response
            .content
            .first()
            .map(|c| c.text.clone())
            .unwrap_or_default())
    }

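    /// HTTP call to an OpenAI-compatible chat-completions endpoint
    /// (OpenAI, xAI, OpenRouter).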
    async fn complete_openai_compatible_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        let request = OpenAIRequest {
            model: model.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![OpenAIMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let mut request_builder = self
            .client
            .post(self.config.api_endpoint())
            .header("authorization", format!("Bearer {}", self.api_key))
            .header("content-type", "application/json");

        // OpenRouter accepts optional attribution headers.
        if self.config.llm.provider == "openrouter" {
            request_builder = request_builder
                .header("HTTP-Referer", "https://github.com/scud-cli")
                .header("X-Title", "SCUD Task Master");
        }

        let response = request_builder
            .json(&request)
            .send()
            .await
            .with_context(|| {
                format!("Failed to send request to {} API", self.config.llm.provider)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!(
                "{} API error ({}): {}",
                self.config.llm.provider,
                status,
                error_text
            );
        }

        let api_response: OpenAIResponse = response.json().await.with_context(|| {
            format!("Failed to parse {} API response", self.config.llm.provider)
        })?;

        Ok(api_response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .unwrap_or_default())
    }

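    /// Completes a prompt and deserializes the reply as JSON into `T`.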
    pub async fn complete_json<T>(&self, prompt: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        self.complete_json_with_model(prompt, None).await
    }

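    /// JSON-deserializing variant of `complete_smart`.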
    pub async fn complete_json_smart<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_smart(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

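    /// JSON-deserializing variant of `complete_fast`.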
    pub async fn complete_json_fast<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_fast(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

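    /// JSON-deserializing variant of `complete_with_model`, using the
    /// default provider.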
    pub async fn complete_json_with_model<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self
            .complete_with_model(prompt, model_override, None)
            .await?;
        Self::parse_json_response(&response_text)
    }

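    /// Parses `T` out of a raw LLM reply, stripping any surrounding prose or
    /// code fences first.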
    fn parse_json_response<T>(response_text: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let json_str = Self::extract_json(response_text);

        serde_json::from_str(json_str).with_context(|| {
            // Truncate on character boundaries so multi-byte UTF-8 in the
            // response cannot panic the byte-indexed slice.
            let preview = if json_str.chars().count() > 500 {
                format!("{}...", json_str.chars().take(500).collect::<String>())
            } else {
                json_str.to_string()
            };
            format!(
                "Failed to parse JSON from LLM response. Response preview:\n{}",
                preview
            )
        })
    }

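    // Best-effort extraction of a JSON payload from an LLM reply. Tries, in
    // order: a json-labeled code fence, any code fence, the outermost array,
    // the outermost object, and finally the trimmed response itself.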
    fn extract_json(response: &str) -> &str {
        // Prefer an explicit ```json fenced block.
        if let Some(start) = response.find("```json") {
            let content_start = start + "```json".len();
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

        // Fall back to any fenced block, skipping a possible language tag on
        // the opening fence line.
        if let Some(start) = response.find("```") {
            let content_start = start + "```".len();
            let content_start = response[content_start..]
                .find('\n')
                .map(|i| content_start + i + 1)
                .unwrap_or(content_start);
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

        if let Some(start) = response.find('[') {
            if let Some(end) = response.rfind(']') {
                if end > start {
                    return &response[start..=end];
                }
            }
        }

        if let Some(start) = response.find('{') {
            if let Some(end) = response.rfind('}') {
                if end > start {
                    return &response[start..=end];
                }
            }
        }

        response.trim()
    }

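    /// Shells out to the `claude` CLI in non-interactive print mode (`-p`),
    /// piping the prompt via stdin and parsing its JSON envelope.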
    async fn complete_claude_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        let mut cmd = Command::new("claude");
        cmd.arg("-p")
            .arg("--output-format")
            .arg("json")
            .arg("--model")
            .arg(model)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        let mut child = cmd.spawn().context(
            "Failed to spawn 'claude' command. Make sure Claude Code is installed and 'claude' is in your PATH",
        )?;

        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to claude stdin")?;
            // Close stdin so the CLI sees EOF and starts processing.
            drop(stdin);
        }

        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for claude command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Claude CLI error: {}", stderr);
        }

        let stdout =
            String::from_utf8(output.stdout).context("Claude CLI output is not valid UTF-8")?;

        #[derive(Deserialize)]
        struct ClaudeCliResponse {
            result: String,
        }

        let response: ClaudeCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Claude CLI JSON response")?;

        Ok(response.result)
    }

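    /// Shells out to the `codex` CLI, piping the prompt via stdin and parsing
    /// its JSON envelope; mirrors the claude path above.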
    async fn complete_codex_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        let mut cmd = Command::new("codex");
        cmd.arg("-p")
            .arg("--model")
            .arg(model)
            .arg("--output-format")
            .arg("json")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        let mut child = cmd.spawn().context(
            "Failed to spawn 'codex' command. Make sure OpenAI Codex CLI is installed and 'codex' is in your PATH",
        )?;

        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to codex stdin")?;
            // Close stdin so the CLI sees EOF and starts processing.
            drop(stdin);
        }

        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for codex command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Codex CLI error: {}", stderr);
        }

        let stdout =
            String::from_utf8(output.stdout).context("Codex CLI output is not valid UTF-8")?;

        #[derive(Deserialize)]
        struct CodexCliResponse {
            result: String,
        }

        let response: CodexCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Codex CLI JSON response")?;

        Ok(response.result)
    }
}