use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::env;
use std::path::PathBuf;

use crate::config::Config;
use crate::storage::Storage;

// Request/response wire types for the Anthropic Messages API.
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<AnthropicMessage>,
}

#[derive(Debug, Serialize)]
struct AnthropicMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContent>,
}

#[derive(Debug, Deserialize)]
struct AnthropicContent {
    text: String,
}

// Request/response wire types for OpenAI-compatible chat completions APIs
// (OpenAI, xAI, OpenRouter).
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<OpenAIMessage>,
}

#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessageResponse,
}

#[derive(Debug, Deserialize)]
struct OpenAIMessageResponse {
    content: String,
}

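/// Client for the configured LLM backend: the Anthropic Messages API, an
/// OpenAI-compatible chat completions API (xAI, OpenAI, OpenRouter), or the
/// `claude`/`codex` CLIs.
///
/// A minimal usage sketch (illustrative; assumes a Tokio runtime and a
/// loadable config with the relevant API key exported):
///
/// ```ignore
/// let client = LLMClient::new()?;
/// let text = client.complete("Summarize the open tasks").await?;
/// println!("{}", text);
/// ```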
pub struct LLMClient {
    config: Config,
    api_key: String,
    client: reqwest::Client,
}

#[derive(Debug, Clone)]
pub struct ModelInfo {
    pub tier: &'static str,
    pub provider: String,
    pub model: String,
}

impl std::fmt::Display for ModelInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Avoid printing the provider twice when the model id already
        // carries a "provider/" prefix.
        let prefix = format!("{}/", self.provider);
        if self.model.starts_with(&prefix) {
            write!(f, "{} model: {}", self.tier, self.model)
        } else {
            write!(f, "{} model: {}/{}", self.tier, self.provider, self.model)
        }
    }
}

impl LLMClient {
    pub fn new() -> Result<Self> {
        let storage = Storage::new(None);
        let config = storage.load_config()?;

        let api_key = if config.requires_api_key() {
            env::var(config.api_key_env_var()).with_context(|| {
                format!("{} environment variable not set", config.api_key_env_var())
            })?
        } else {
            String::new()
        };

        Ok(LLMClient {
            config,
            api_key,
            client: reqwest::Client::new(),
        })
    }

    pub fn new_with_project_root(project_root: PathBuf) -> Result<Self> {
        let storage = Storage::new(Some(project_root));
        let config = storage.load_config()?;

        let api_key = if config.requires_api_key() {
            env::var(config.api_key_env_var()).with_context(|| {
                format!("{} environment variable not set", config.api_key_env_var())
            })?
        } else {
            String::new()
        };

        Ok(LLMClient {
            config,
            api_key,
            client: reqwest::Client::new(),
        })
    }

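    /// Resolved model metadata for the "smart" tier; an explicit override
    /// takes precedence over the configured default.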
    pub fn smart_model_info(&self, model_override: Option<&str>) -> ModelInfo {
        ModelInfo {
            tier: "smart",
            provider: self.config.smart_provider().to_string(),
            model: model_override
                .unwrap_or(self.config.smart_model())
                .to_string(),
        }
    }

    /// Resolved model metadata for the "fast" tier; an explicit override
    /// takes precedence over the configured default.
    pub fn fast_model_info(&self, model_override: Option<&str>) -> ModelInfo {
        ModelInfo {
            tier: "fast",
            provider: self.config.fast_provider().to_string(),
            model: model_override
                .unwrap_or(self.config.fast_model())
                .to_string(),
        }
    }

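    /// Completes `prompt` with the default configured provider and model.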
    pub async fn complete(&self, prompt: &str) -> Result<String> {
        self.complete_with_model(prompt, None, None).await
    }

    /// Completes `prompt` with the configured "smart" tier model, unless
    /// overridden.
    pub async fn complete_smart(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(self.config.smart_model());
        let provider = self.config.smart_provider();
        self.complete_with_model(prompt, Some(model), Some(provider))
            .await
    }

    /// Completes `prompt` with the configured "fast" tier model, unless
    /// overridden.
    pub async fn complete_fast(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(self.config.fast_model());
        let provider = self.config.fast_provider();
        self.complete_with_model(prompt, Some(model), Some(provider))
            .await
    }

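    /// Dispatches a completion request to the resolved provider. Supported
    /// providers: "claude-cli", "codex", "anthropic", "xai", "openai", and
    /// "openrouter".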
    pub async fn complete_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
        provider_override: Option<&str>,
    ) -> Result<String> {
        let provider = provider_override.unwrap_or(&self.config.llm.provider);
        match provider {
            "claude-cli" => self.complete_claude_cli(prompt, model_override).await,
            "codex" => self.complete_codex_cli(prompt, model_override).await,
            "anthropic" => {
                self.complete_anthropic_with_model(prompt, model_override)
                    .await
            }
            "xai" | "openai" | "openrouter" => {
                self.complete_openai_compatible_with_model(prompt, model_override, provider)
                    .await
            }
            // Report the provider actually resolved for this call, not just
            // the configured default.
            _ => anyhow::bail!("Unsupported provider: {}", provider),
        }
    }

    async fn complete_anthropic_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        let request = AnthropicRequest {
            model: model.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![AnthropicMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let response = self
            .client
            .post(self.config.api_endpoint())
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&request)
            .send()
            .await
            .context("Failed to send request to Anthropic API")?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Anthropic API error ({}): {}", status, error_text);
        }

        let api_response: AnthropicResponse = response
            .json()
            .await
            .context("Failed to parse Anthropic API response")?;

        Ok(api_response
            .content
            .first()
            .map(|c| c.text.clone())
            .unwrap_or_default())
    }

    async fn complete_openai_compatible_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
        provider: &str,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        // OpenRouter expects fully-qualified "provider/model" ids; the other
        // APIs want the bare model name, so strip a matching provider prefix.
        let model_for_api = if provider != "openrouter" {
            let prefix = format!("{}/", provider);
            model.strip_prefix(&prefix).unwrap_or(model)
        } else {
            model
        };

        let endpoint = match provider {
            "xai" => "https://api.x.ai/v1/chat/completions",
            "openai" => "https://api.openai.com/v1/chat/completions",
            "openrouter" => "https://openrouter.ai/api/v1/chat/completions",
            // Unreachable for the providers dispatched above; default
            // defensively to xAI.
            _ => "https://api.x.ai/v1/chat/completions",
        };

        let request = OpenAIRequest {
            model: model_for_api.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![OpenAIMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let mut request_builder = self
            .client
            .post(endpoint)
            .header("authorization", format!("Bearer {}", self.api_key))
            .header("content-type", "application/json");

        // Optional OpenRouter headers used for app attribution.
        if provider == "openrouter" {
            request_builder = request_builder
                .header("HTTP-Referer", "https://github.com/scud-cli")
                .header("X-Title", "SCUD Task Master");
        }

        let response = request_builder
            .json(&request)
            .send()
            .await
            .with_context(|| format!("Failed to send request to {} API", provider))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("{} API error ({}): {}", provider, status, error_text);
        }

        let api_response: OpenAIResponse = response
            .json()
            .await
            .with_context(|| format!("Failed to parse {} API response", provider))?;

        Ok(api_response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .unwrap_or_default())
    }

    pub async fn complete_json<T>(&self, prompt: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        self.complete_json_with_model(prompt, None).await
    }

    pub async fn complete_json_smart<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_smart(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

    pub async fn complete_json_fast<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_fast(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

    pub async fn complete_json_with_model<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self
            .complete_with_model(prompt, model_override, None)
            .await?;
        Self::parse_json_response(&response_text)
    }

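    /// Deserializes a typed value from raw LLM output, tolerating prose and
    /// code fences around the JSON payload.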
    fn parse_json_response<T>(response_text: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let json_str = Self::extract_json(response_text);

        serde_json::from_str(json_str).with_context(|| {
            let preview = if json_str.len() > 500 {
                // Back up to a char boundary so slicing can't panic on
                // multi-byte UTF-8.
                let mut end = 500;
                while !json_str.is_char_boundary(end) {
                    end -= 1;
                }
                format!("{}...", &json_str[..end])
            } else {
                json_str.to_string()
            };
            format!(
                "Failed to parse JSON from LLM response. Response preview:\n{}",
                preview
            )
        })
    }

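    /// Pulls a JSON payload out of an LLM response. Tries, in order: a
    /// json-tagged code fence, any code fence, the outermost `[...]` array,
    /// the outermost `{...}` object, and finally the trimmed response as-is.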
    fn extract_json(response: &str) -> &str {
        // Prefer an explicit ```json fence.
        if let Some(start) = response.find("```json") {
            let content_start = start + "```json".len();
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

        // Fall back to any fenced block, skipping its language line.
        if let Some(start) = response.find("```") {
            let content_start = start + 3;
            let content_start = response[content_start..]
                .find('\n')
                .map(|i| content_start + i + 1)
                .unwrap_or(content_start);
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

        // Bare JSON array embedded in prose.
        if let Some(start) = response.find('[') {
            if let Some(end) = response.rfind(']') {
                if end > start {
                    return &response[start..=end];
                }
            }
        }

        // Bare JSON object embedded in prose.
        if let Some(start) = response.find('{') {
            if let Some(end) = response.rfind('}') {
                if end > start {
                    return &response[start..=end];
                }
            }
        }

        response.trim()
    }

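    /// Shells out to the `claude` CLI in print mode (`-p`) and parses its
    /// JSON output.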
    async fn complete_claude_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        let mut cmd = Command::new("claude");
        cmd.arg("-p")
            .arg("--output-format")
            .arg("json")
            .arg("--model")
            .arg(model)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        let mut child = cmd.spawn().context(
            "Failed to spawn 'claude' command. Make sure Claude Code is installed and 'claude' is in your PATH",
        )?;

        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to claude stdin")?;
            // Close stdin explicitly so the CLI sees EOF and starts processing.
            drop(stdin);
        }

        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for claude command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Claude CLI error: {}", stderr);
        }

        let stdout =
            String::from_utf8(output.stdout).context("Claude CLI output is not valid UTF-8")?;

        #[derive(Deserialize)]
        struct ClaudeCliResponse {
            result: String,
        }

        let response: ClaudeCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Claude CLI JSON response")?;

        Ok(response.result)
    }

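    /// Shells out to the `codex` CLI and parses its JSON output.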
    async fn complete_codex_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        let mut cmd = Command::new("codex");
        cmd.arg("-p")
            .arg("--model")
            .arg(model)
            .arg("--output-format")
            .arg("json")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        let mut child = cmd.spawn().context(
            "Failed to spawn 'codex' command. Make sure OpenAI Codex CLI is installed and 'codex' is in your PATH",
        )?;

        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to codex stdin")?;
            // Close stdin explicitly so the CLI sees EOF and starts processing.
            drop(stdin);
        }

        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for codex command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Codex CLI error: {}", stderr);
        }

        let stdout =
            String::from_utf8(output.stdout).context("Codex CLI output is not valid UTF-8")?;

        #[derive(Deserialize)]
        struct CodexCliResponse {
            result: String,
        }

        let response: CodexCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Codex CLI JSON response")?;

        Ok(response.result)
    }
}
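
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sketch of unit tests for the JSON-extraction fallbacks; the
    // inputs are illustrative, not captured from real LLM output.
    #[test]
    fn extracts_from_json_fence() {
        let resp = "Sure:\n```json\n{\"a\": 1}\n```";
        assert_eq!(LLMClient::extract_json(resp), "{\"a\": 1}");
    }

    #[test]
    fn extracts_bare_object_from_prose() {
        let resp = "Result: {\"ok\": true} -- done";
        assert_eq!(LLMClient::extract_json(resp), "{\"ok\": true}");
    }

    #[test]
    fn falls_back_to_trimmed_text() {
        assert_eq!(LLMClient::extract_json("  plain text  "), "plain text");
    }
}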