// scud/llm/client.rs

1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3use std::env;
4use std::path::PathBuf;
5
6use crate::config::Config;
7use crate::storage::Storage;
8
// Anthropic API structures
// Wire format for the Anthropic Messages API (`/v1/messages`).

/// Request body sent to the Anthropic Messages endpoint.
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    /// Model identifier (e.g. a Claude model name).
    model: String,
    /// Upper bound on tokens the model may generate.
    max_tokens: u32,
    /// Conversation turns; this client always sends a single user message.
    messages: Vec<AnthropicMessage>,
}

/// One conversation turn in an Anthropic request.
#[derive(Debug, Serialize)]
struct AnthropicMessage {
    /// Message role; this client only uses "user".
    role: String,
    /// The message text (the prompt).
    content: String,
}

/// Minimal view of an Anthropic response; only the content blocks are read.
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContent>,
}

/// One content block of an Anthropic response; only the text is read.
#[derive(Debug, Deserialize)]
struct AnthropicContent {
    text: String,
}
32
// OpenAI-compatible API structures (used by xAI, OpenAI, OpenRouter)
// Wire format for OpenAI-style chat-completions endpoints.

/// Request body for an OpenAI-compatible chat-completions endpoint.
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    /// Model identifier as understood by the target provider.
    model: String,
    /// Upper bound on tokens the model may generate.
    max_tokens: u32,
    /// Conversation turns; this client always sends a single user message.
    messages: Vec<OpenAIMessage>,
}

/// One conversation turn in an OpenAI-compatible request.
#[derive(Debug, Serialize)]
struct OpenAIMessage {
    /// Message role; this client only uses "user".
    role: String,
    /// The message text (the prompt).
    content: String,
}

/// Minimal view of an OpenAI-compatible response; only choices are read.
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
}

/// One completion choice; only the assistant message is read.
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessageResponse,
}

/// The assistant message inside a choice; only the text content is read.
#[derive(Debug, Deserialize)]
struct OpenAIMessageResponse {
    content: String,
}
61
/// Client that sends prompts to the configured LLM provider
/// (Anthropic API, OpenAI-compatible APIs, or the local Claude CLI).
pub struct LLMClient {
    /// Loaded project configuration (provider, model, max_tokens, endpoint).
    config: Config,
    /// Provider API key from the environment; empty for the Claude CLI provider.
    api_key: String,
    /// Reused HTTP client for all API requests.
    client: reqwest::Client,
}
67
68impl LLMClient {
69    pub fn new() -> Result<Self> {
70        let storage = Storage::new(None);
71        let config = storage.load_config()?;
72
73        let api_key = if config.requires_api_key() {
74            env::var(config.api_key_env_var()).with_context(|| {
75                format!("{} environment variable not set", config.api_key_env_var())
76            })?
77        } else {
78            String::new() // Claude CLI doesn't need API key
79        };
80
81        Ok(LLMClient {
82            config,
83            api_key,
84            client: reqwest::Client::new(),
85        })
86    }
87
88    pub fn new_with_project_root(project_root: PathBuf) -> Result<Self> {
89        let storage = Storage::new(Some(project_root));
90        let config = storage.load_config()?;
91
92        let api_key = if config.requires_api_key() {
93            env::var(config.api_key_env_var()).with_context(|| {
94                format!("{} environment variable not set", config.api_key_env_var())
95            })?
96        } else {
97            String::new() // Claude CLI doesn't need API key
98        };
99
100        Ok(LLMClient {
101            config,
102            api_key,
103            client: reqwest::Client::new(),
104        })
105    }
106
107    pub async fn complete(&self, prompt: &str) -> Result<String> {
108        self.complete_with_model(prompt, None).await
109    }
110
111    pub async fn complete_with_model(
112        &self,
113        prompt: &str,
114        model_override: Option<&str>,
115    ) -> Result<String> {
116        match self.config.llm.provider.as_str() {
117            "claude-cli" => self.complete_claude_cli(prompt, model_override).await,
118            "anthropic" => {
119                self.complete_anthropic_with_model(prompt, model_override)
120                    .await
121            }
122            "xai" | "openai" | "openrouter" => {
123                self.complete_openai_compatible_with_model(prompt, model_override)
124                    .await
125            }
126            _ => anyhow::bail!("Unsupported provider: {}", self.config.llm.provider),
127        }
128    }
129
130    async fn complete_anthropic_with_model(
131        &self,
132        prompt: &str,
133        model_override: Option<&str>,
134    ) -> Result<String> {
135        let model = model_override.unwrap_or(&self.config.llm.model);
136        let request = AnthropicRequest {
137            model: model.to_string(),
138            max_tokens: self.config.llm.max_tokens,
139            messages: vec![AnthropicMessage {
140                role: "user".to_string(),
141                content: prompt.to_string(),
142            }],
143        };
144
145        let response = self
146            .client
147            .post(self.config.api_endpoint())
148            .header("x-api-key", &self.api_key)
149            .header("anthropic-version", "2023-06-01")
150            .header("content-type", "application/json")
151            .json(&request)
152            .send()
153            .await
154            .context("Failed to send request to Anthropic API")?;
155
156        if !response.status().is_success() {
157            let status = response.status();
158            let error_text = response.text().await.unwrap_or_default();
159            anyhow::bail!("Anthropic API error ({}): {}", status, error_text);
160        }
161
162        let api_response: AnthropicResponse = response
163            .json()
164            .await
165            .context("Failed to parse Anthropic API response")?;
166
167        Ok(api_response
168            .content
169            .first()
170            .map(|c| c.text.clone())
171            .unwrap_or_default())
172    }
173
174    async fn complete_openai_compatible_with_model(
175        &self,
176        prompt: &str,
177        model_override: Option<&str>,
178    ) -> Result<String> {
179        let model = model_override.unwrap_or(&self.config.llm.model);
180        let request = OpenAIRequest {
181            model: model.to_string(),
182            max_tokens: self.config.llm.max_tokens,
183            messages: vec![OpenAIMessage {
184                role: "user".to_string(),
185                content: prompt.to_string(),
186            }],
187        };
188
189        let mut request_builder = self
190            .client
191            .post(self.config.api_endpoint())
192            .header("authorization", format!("Bearer {}", self.api_key))
193            .header("content-type", "application/json");
194
195        // OpenRouter requires additional headers
196        if self.config.llm.provider == "openrouter" {
197            request_builder = request_builder
198                .header("HTTP-Referer", "https://github.com/scud-cli")
199                .header("X-Title", "SCUD Task Master");
200        }
201
202        let response = request_builder
203            .json(&request)
204            .send()
205            .await
206            .with_context(|| {
207                format!("Failed to send request to {} API", self.config.llm.provider)
208            })?;
209
210        if !response.status().is_success() {
211            let status = response.status();
212            let error_text = response.text().await.unwrap_or_default();
213            anyhow::bail!(
214                "{} API error ({}): {}",
215                self.config.llm.provider,
216                status,
217                error_text
218            );
219        }
220
221        let api_response: OpenAIResponse = response.json().await.with_context(|| {
222            format!("Failed to parse {} API response", self.config.llm.provider)
223        })?;
224
225        Ok(api_response
226            .choices
227            .first()
228            .map(|c| c.message.content.clone())
229            .unwrap_or_default())
230    }
231
232    pub async fn complete_json<T>(&self, prompt: &str) -> Result<T>
233    where
234        T: serde::de::DeserializeOwned,
235    {
236        let response_text = self.complete(prompt).await?;
237
238        // Try to find JSON in the response (LLM might include markdown or explanations)
239        let json_str = Self::extract_json(&response_text);
240
241        serde_json::from_str(json_str).with_context(|| {
242            // Provide helpful error context
243            let preview = if json_str.len() > 500 {
244                format!("{}...", &json_str[..500])
245            } else {
246                json_str.to_string()
247            };
248            format!(
249                "Failed to parse JSON from LLM response. Response preview:\n{}",
250                preview
251            )
252        })
253    }
254
255    /// Extract JSON from LLM response, handling markdown code blocks and extra text
256    fn extract_json(response: &str) -> &str {
257        // First, try to extract from markdown code blocks
258        if let Some(start) = response.find("```json") {
259            let content_start = start + 7; // Skip "```json"
260            if let Some(end) = response[content_start..].find("```") {
261                return response[content_start..content_start + end].trim();
262            }
263        }
264
265        // Try plain code blocks
266        if let Some(start) = response.find("```") {
267            let content_start = start + 3;
268            // Skip language identifier if present (e.g., "```\n")
269            let content_start = response[content_start..]
270                .find('\n')
271                .map(|i| content_start + i + 1)
272                .unwrap_or(content_start);
273            if let Some(end) = response[content_start..].find("```") {
274                return response[content_start..content_start + end].trim();
275            }
276        }
277
278        // Try to find array JSON
279        if let Some(start) = response.find('[') {
280            if let Some(end) = response.rfind(']') {
281                if end > start {
282                    return &response[start..=end];
283                }
284            }
285        }
286
287        // Try to find object JSON
288        if let Some(start) = response.find('{') {
289            if let Some(end) = response.rfind('}') {
290                if end > start {
291                    return &response[start..=end];
292                }
293            }
294        }
295
296        response.trim()
297    }
298
299    async fn complete_claude_cli(
300        &self,
301        prompt: &str,
302        model_override: Option<&str>,
303    ) -> Result<String> {
304        use std::process::Stdio;
305        use tokio::io::AsyncWriteExt;
306        use tokio::process::Command;
307
308        let model = model_override.unwrap_or(&self.config.llm.model);
309
310        // Build the claude command
311        let mut cmd = Command::new("claude");
312        cmd.arg("-p") // Print mode (headless)
313            .arg("--output-format")
314            .arg("json")
315            .arg("--model")
316            .arg(model)
317            .stdin(Stdio::piped())
318            .stdout(Stdio::piped())
319            .stderr(Stdio::piped());
320
321        // Spawn the process
322        let mut child = cmd.spawn().context("Failed to spawn 'claude' command. Make sure Claude Code is installed and 'claude' is in your PATH")?;
323
324        // Write prompt to stdin
325        if let Some(mut stdin) = child.stdin.take() {
326            stdin
327                .write_all(prompt.as_bytes())
328                .await
329                .context("Failed to write prompt to claude stdin")?;
330            drop(stdin); // Close stdin
331        }
332
333        // Wait for completion
334        let output = child
335            .wait_with_output()
336            .await
337            .context("Failed to wait for claude command")?;
338
339        if !output.status.success() {
340            let stderr = String::from_utf8_lossy(&output.stderr);
341            anyhow::bail!("Claude CLI error: {}", stderr);
342        }
343
344        // Parse JSON output
345        let stdout =
346            String::from_utf8(output.stdout).context("Claude CLI output is not valid UTF-8")?;
347
348        #[derive(Deserialize)]
349        struct ClaudeCliResponse {
350            result: String,
351        }
352
353        let response: ClaudeCliResponse =
354            serde_json::from_str(&stdout).context("Failed to parse Claude CLI JSON response")?;
355
356        Ok(response.result)
357    }
358}