scud/llm/client.rs

use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::env;
use std::path::PathBuf;

use crate::config::Config;
use crate::storage::Storage;

// Anthropic API structures
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<AnthropicMessage>,
}

#[derive(Debug, Serialize)]
struct AnthropicMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContent>,
}

#[derive(Debug, Deserialize)]
struct AnthropicContent {
    text: String,
}

// OpenAI-compatible API structures (used by xAI, OpenAI, OpenRouter)
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<OpenAIMessage>,
}

#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessageResponse,
}

#[derive(Debug, Deserialize)]
struct OpenAIMessageResponse {
    content: String,
}

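/// Client for talking to the configured LLM provider (API-based or CLI-based).
///
/// Illustrative usage sketch (assumes a valid scud config and, for API
/// providers, the matching API key environment variable):
/// ```ignore
/// let client = LLMClient::new()?;
/// let reply = client.complete("Summarize these changes").await?;
/// ```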
pub struct LLMClient {
    config: Config,
    api_key: String,
    client: reqwest::Client,
}

impl LLMClient {
    pub fn new() -> Result<Self> {
        Self::from_storage(Storage::new(None))
    }

    pub fn new_with_project_root(project_root: PathBuf) -> Result<Self> {
        Self::from_storage(Storage::new(Some(project_root)))
    }

    // Shared constructor body: load the config and resolve the API key if the
    // provider needs one.
    fn from_storage(storage: Storage) -> Result<Self> {
        let config = storage.load_config()?;

        let api_key = if config.requires_api_key() {
            env::var(config.api_key_env_var()).with_context(|| {
                format!("{} environment variable not set", config.api_key_env_var())
            })?
        } else {
            String::new() // Claude CLI doesn't need API key
        };

        Ok(LLMClient {
            config,
            api_key,
            client: reqwest::Client::new(),
        })
    }

    pub async fn complete(&self, prompt: &str) -> Result<String> {
        self.complete_with_model(prompt, None).await
    }

    /// Complete using the smart model (for validation/analysis tasks with large context).
    /// Uses the caller's model override if provided, otherwise falls back to the configured `smart_model`.
    pub async fn complete_smart(&self, prompt: &str, model_override: Option<&str>) -> Result<String> {
        let model = model_override.unwrap_or(self.config.smart_model());
        self.complete_with_model(prompt, Some(model)).await
    }

    /// Complete using the fast model (for generation tasks).
    /// Uses the caller's model override if provided, otherwise falls back to the configured `fast_model`.
    pub async fn complete_fast(&self, prompt: &str, model_override: Option<&str>) -> Result<String> {
        let model = model_override.unwrap_or(self.config.fast_model());
        self.complete_with_model(prompt, Some(model)).await
    }

    pub async fn complete_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        match self.config.llm.provider.as_str() {
            "claude-cli" => self.complete_claude_cli(prompt, model_override).await,
            "codex" => self.complete_codex_cli(prompt, model_override).await,
            "anthropic" => {
                self.complete_anthropic_with_model(prompt, model_override)
                    .await
            }
            "xai" | "openai" | "openrouter" => {
                self.complete_openai_compatible_with_model(prompt, model_override)
                    .await
            }
            _ => anyhow::bail!("Unsupported provider: {}", self.config.llm.provider),
        }
    }

    async fn complete_anthropic_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        let request = AnthropicRequest {
            model: model.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![AnthropicMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let response = self
            .client
            .post(self.config.api_endpoint())
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&request)
            .send()
            .await
            .context("Failed to send request to Anthropic API")?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Anthropic API error ({}): {}", status, error_text);
        }

        let api_response: AnthropicResponse = response
            .json()
            .await
            .context("Failed to parse Anthropic API response")?;

        Ok(api_response
            .content
            .first()
            .map(|c| c.text.clone())
            .unwrap_or_default())
    }

    async fn complete_openai_compatible_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        let request = OpenAIRequest {
            model: model.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![OpenAIMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let mut request_builder = self
            .client
            .post(self.config.api_endpoint())
            .header("authorization", format!("Bearer {}", self.api_key))
            .header("content-type", "application/json");

        // OpenRouter requires additional headers
        if self.config.llm.provider == "openrouter" {
            request_builder = request_builder
                .header("HTTP-Referer", "https://github.com/scud-cli")
                .header("X-Title", "SCUD Task Master");
        }

        let response = request_builder
            .json(&request)
            .send()
            .await
            .with_context(|| {
                format!("Failed to send request to {} API", self.config.llm.provider)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!(
                "{} API error ({}): {}",
                self.config.llm.provider,
                status,
                error_text
            );
        }

        let api_response: OpenAIResponse = response.json().await.with_context(|| {
            format!("Failed to parse {} API response", self.config.llm.provider)
        })?;

        Ok(api_response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .unwrap_or_default())
    }

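    /// Illustrative usage sketch (the `Subtask` type and the prompt below are
    /// hypothetical, not part of this crate):
    /// ```ignore
    /// #[derive(serde::Deserialize)]
    /// struct Subtask { title: String }
    ///
    /// let client = LLMClient::new()?;
    /// let subtasks: Vec<Subtask> = client
    ///     .complete_json("Break this task into subtasks as a JSON array")
    ///     .await?;
    /// ```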
    pub async fn complete_json<T>(&self, prompt: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        self.complete_json_with_model(prompt, None).await
    }

    /// Complete JSON using the smart model (for validation/analysis tasks)
    pub async fn complete_json_smart<T>(&self, prompt: &str, model_override: Option<&str>) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_smart(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

    /// Complete JSON using the fast model (for generation tasks)
    pub async fn complete_json_fast<T>(&self, prompt: &str, model_override: Option<&str>) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_fast(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

    pub async fn complete_json_with_model<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_with_model(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

    fn parse_json_response<T>(response_text: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        // Try to find JSON in the response (LLM might include markdown or explanations)
        let json_str = Self::extract_json(response_text);

        serde_json::from_str(json_str).with_context(|| {
            // Provide helpful error context; truncate on a character boundary so
            // multi-byte UTF-8 in the response can't cause a slicing panic
            let preview = if json_str.len() > 500 {
                let truncated: String = json_str.chars().take(500).collect();
                format!("{}...", truncated)
            } else {
                json_str.to_string()
            };
            format!(
                "Failed to parse JSON from LLM response. Response preview:\n{}",
                preview
            )
        })
    }

    /// Extract JSON from LLM response, handling markdown code blocks and extra text
    fn extract_json(response: &str) -> &str {
        // First, try to extract from markdown code blocks
        if let Some(start) = response.find("```json") {
            let content_start = start + 7; // Skip "```json"
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

        // Try plain code blocks
        if let Some(start) = response.find("```") {
            let content_start = start + 3;
            // Skip language identifier if present (e.g., "```\n")
            let content_start = response[content_start..]
                .find('\n')
                .map(|i| content_start + i + 1)
                .unwrap_or(content_start);
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

        // Try to find array JSON
        if let Some(start) = response.find('[') {
            if let Some(end) = response.rfind(']') {
                if end > start {
                    return &response[start..=end];
                }
            }
        }

        // Try to find object JSON
        if let Some(start) = response.find('{') {
            if let Some(end) = response.rfind('}') {
                if end > start {
                    return &response[start..=end];
                }
            }
        }

        response.trim()
    }

    async fn complete_claude_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        // Build the claude command
        let mut cmd = Command::new("claude");
        cmd.arg("-p") // Print mode (headless)
            .arg("--output-format")
            .arg("json")
            .arg("--model")
            .arg(model)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        // Spawn the process
        let mut child = cmd.spawn().context("Failed to spawn 'claude' command. Make sure Claude Code is installed and 'claude' is in your PATH")?;

        // Write prompt to stdin
        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to claude stdin")?;
            drop(stdin); // Close stdin
        }

        // Wait for completion
        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for claude command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Claude CLI error: {}", stderr);
        }

        // Parse JSON output
        let stdout =
            String::from_utf8(output.stdout).context("Claude CLI output is not valid UTF-8")?;

        #[derive(Deserialize)]
        struct ClaudeCliResponse {
            result: String,
        }

        let response: ClaudeCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Claude CLI JSON response")?;

        Ok(response.result)
    }

    async fn complete_codex_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        // Build the codex command
        // Codex CLI uses similar headless mode to Claude Code
        let mut cmd = Command::new("codex");
        cmd.arg("-p") // Prompt mode (headless/non-interactive)
            .arg("--model")
            .arg(model)
            .arg("--output-format")
            .arg("json")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        // Spawn the process
        let mut child = cmd.spawn().context("Failed to spawn 'codex' command. Make sure OpenAI Codex CLI is installed and 'codex' is in your PATH")?;

        // Write prompt to stdin
        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to codex stdin")?;
            drop(stdin); // Close stdin
        }

        // Wait for completion
        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for codex command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Codex CLI error: {}", stderr);
        }

        // Parse JSON output
        let stdout =
            String::from_utf8(output.stdout).context("Codex CLI output is not valid UTF-8")?;

        // Codex outputs JSON with a result field similar to Claude CLI
        #[derive(Deserialize)]
        struct CodexCliResponse {
            result: String,
        }

        let response: CodexCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Codex CLI JSON response")?;

        Ok(response.result)
    }
}
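
// A minimal illustrative test sketch for the JSON-extraction heuristic; the
// inputs below are made-up examples, not captured LLM output.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extracts_json_from_markdown_code_block() {
        let response = "Here is the plan:\n```json\n{\"id\": 1}\n```\nDone.";
        assert_eq!(LLMClient::extract_json(response), "{\"id\": 1}");
    }

    #[test]
    fn extracts_bare_json_object_from_surrounding_text() {
        let response = "Sure! {\"count\": 2} Hope that helps.";
        assert_eq!(LLMClient::extract_json(response), "{\"count\": 2}");
    }
}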