// scud/llm/client.rs

1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3use std::env;
4use std::path::PathBuf;
5
6use crate::config::Config;
7use crate::storage::Storage;
8
// Anthropic API structures.
// These mirror the subset of the Anthropic Messages API request/response
// shapes that this client actually uses; field names are the wire names,
// so they must not be renamed without a serde rename attribute.

/// Request body sent to the Anthropic API (see `complete_anthropic_with_model`).
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    /// Model identifier (config default or per-call override).
    model: String,
    /// Maximum number of tokens the model may generate.
    max_tokens: u32,
    /// Conversation turns; this client always sends a single user message.
    messages: Vec<AnthropicMessage>,
}

/// One chat message in an Anthropic request.
#[derive(Debug, Serialize)]
struct AnthropicMessage {
    /// Message role; this client only sends "user".
    role: String,
    /// The prompt text.
    content: String,
}

/// Response body returned by the Anthropic API.
/// Only the content blocks are deserialized; other fields are ignored.
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    /// Content blocks; the first block's text is returned to the caller.
    content: Vec<AnthropicContent>,
}

/// A single content block from an Anthropic response.
#[derive(Debug, Deserialize)]
struct AnthropicContent {
    /// Generated text of this block.
    text: String,
}
32
// OpenAI-compatible API structures (used by xAI, OpenAI, OpenRouter).
// Field names are the wire names of the chat-completions protocol,
// so they must not be renamed without a serde rename attribute.

/// Request body sent to an OpenAI-compatible chat completions endpoint.
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    /// Model identifier (config default or per-call override).
    model: String,
    /// Maximum number of tokens the model may generate.
    max_tokens: u32,
    /// Conversation turns; this client always sends a single user message.
    messages: Vec<OpenAIMessage>,
}

/// One chat message in an OpenAI-compatible request.
#[derive(Debug, Serialize)]
struct OpenAIMessage {
    /// Message role; this client only sends "user".
    role: String,
    /// The prompt text.
    content: String,
}

/// Response body from an OpenAI-compatible endpoint.
/// Only the choices are deserialized; other fields are ignored.
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    /// Completion candidates; the first choice's message is used.
    choices: Vec<OpenAIChoice>,
}

/// A single completion choice.
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessageResponse,
}

/// The assistant message inside a choice.
#[derive(Debug, Deserialize)]
struct OpenAIMessageResponse {
    /// Generated text returned to the caller.
    content: String,
}
61
/// Configuration for Claude Code tool access in headless mode.
///
/// Passed to `complete_with_tools` / `complete_json_with_tools`; translated
/// into `--allowedTools` and `--max-turns` flags for the `claude` CLI.
#[derive(Debug, Clone, Default)]
pub struct ToolConfig {
    /// Tools to auto-approve (e.g., "Read", "Write", "Grep", "Glob")
    pub allowed_tools: Vec<String>,
    /// Maximum agentic turns before stopping
    pub max_turns: Option<u32>,
}

impl ToolConfig {
    /// Construct a config from an explicit tool allow-list and turn limit.
    pub fn new(allowed_tools: Vec<String>, max_turns: Option<u32>) -> Self {
        ToolConfig { allowed_tools, max_turns }
    }
}
79
/// Client that dispatches completions to the configured LLM provider:
/// the `claude` CLI, the Anthropic API, or an OpenAI-compatible HTTP API
/// (xAI, OpenAI, OpenRouter).
pub struct LLMClient {
    /// Provider/model settings loaded from `Storage`.
    config: Config,
    /// API key from the provider's env var; empty for the claude-cli provider.
    api_key: String,
    /// Reused HTTP client for the API-based providers.
    client: reqwest::Client,
}
85
86impl LLMClient {
87    pub fn new() -> Result<Self> {
88        let storage = Storage::new(None);
89        let config = storage.load_config()?;
90
91        let api_key = if config.requires_api_key() {
92            env::var(config.api_key_env_var()).with_context(|| {
93                format!("{} environment variable not set", config.api_key_env_var())
94            })?
95        } else {
96            String::new() // Claude CLI doesn't need API key
97        };
98
99        Ok(LLMClient {
100            config,
101            api_key,
102            client: reqwest::Client::new(),
103        })
104    }
105
106    pub fn new_with_project_root(project_root: PathBuf) -> Result<Self> {
107        let storage = Storage::new(Some(project_root));
108        let config = storage.load_config()?;
109
110        let api_key = if config.requires_api_key() {
111            env::var(config.api_key_env_var()).with_context(|| {
112                format!("{} environment variable not set", config.api_key_env_var())
113            })?
114        } else {
115            String::new() // Claude CLI doesn't need API key
116        };
117
118        Ok(LLMClient {
119            config,
120            api_key,
121            client: reqwest::Client::new(),
122        })
123    }
124
125    pub async fn complete(&self, prompt: &str) -> Result<String> {
126        self.complete_with_model(prompt, None).await
127    }
128
129    pub async fn complete_with_model(
130        &self,
131        prompt: &str,
132        model_override: Option<&str>,
133    ) -> Result<String> {
134        match self.config.llm.provider.as_str() {
135            "claude-cli" => self.complete_claude_cli(prompt, model_override).await,
136            "anthropic" => {
137                self.complete_anthropic_with_model(prompt, model_override)
138                    .await
139            }
140            "xai" | "openai" | "openrouter" => {
141                self.complete_openai_compatible_with_model(prompt, model_override)
142                    .await
143            }
144            _ => anyhow::bail!("Unsupported provider: {}", self.config.llm.provider),
145        }
146    }
147
148    async fn complete_anthropic_with_model(
149        &self,
150        prompt: &str,
151        model_override: Option<&str>,
152    ) -> Result<String> {
153        let model = model_override.unwrap_or(&self.config.llm.model);
154        let request = AnthropicRequest {
155            model: model.to_string(),
156            max_tokens: self.config.llm.max_tokens,
157            messages: vec![AnthropicMessage {
158                role: "user".to_string(),
159                content: prompt.to_string(),
160            }],
161        };
162
163        let response = self
164            .client
165            .post(self.config.api_endpoint())
166            .header("x-api-key", &self.api_key)
167            .header("anthropic-version", "2023-06-01")
168            .header("content-type", "application/json")
169            .json(&request)
170            .send()
171            .await
172            .context("Failed to send request to Anthropic API")?;
173
174        if !response.status().is_success() {
175            let status = response.status();
176            let error_text = response.text().await.unwrap_or_default();
177            anyhow::bail!("Anthropic API error ({}): {}", status, error_text);
178        }
179
180        let api_response: AnthropicResponse = response
181            .json()
182            .await
183            .context("Failed to parse Anthropic API response")?;
184
185        Ok(api_response
186            .content
187            .first()
188            .map(|c| c.text.clone())
189            .unwrap_or_default())
190    }
191
192    async fn complete_openai_compatible_with_model(
193        &self,
194        prompt: &str,
195        model_override: Option<&str>,
196    ) -> Result<String> {
197        let model = model_override.unwrap_or(&self.config.llm.model);
198        let request = OpenAIRequest {
199            model: model.to_string(),
200            max_tokens: self.config.llm.max_tokens,
201            messages: vec![OpenAIMessage {
202                role: "user".to_string(),
203                content: prompt.to_string(),
204            }],
205        };
206
207        let mut request_builder = self
208            .client
209            .post(self.config.api_endpoint())
210            .header("authorization", format!("Bearer {}", self.api_key))
211            .header("content-type", "application/json");
212
213        // OpenRouter requires additional headers
214        if self.config.llm.provider == "openrouter" {
215            request_builder = request_builder
216                .header("HTTP-Referer", "https://github.com/scud-cli")
217                .header("X-Title", "SCUD Task Master");
218        }
219
220        let response = request_builder
221            .json(&request)
222            .send()
223            .await
224            .with_context(|| {
225                format!("Failed to send request to {} API", self.config.llm.provider)
226            })?;
227
228        if !response.status().is_success() {
229            let status = response.status();
230            let error_text = response.text().await.unwrap_or_default();
231            anyhow::bail!(
232                "{} API error ({}): {}",
233                self.config.llm.provider,
234                status,
235                error_text
236            );
237        }
238
239        let api_response: OpenAIResponse = response.json().await.with_context(|| {
240            format!("Failed to parse {} API response", self.config.llm.provider)
241        })?;
242
243        Ok(api_response
244            .choices
245            .first()
246            .map(|c| c.message.content.clone())
247            .unwrap_or_default())
248    }
249
250    pub async fn complete_json<T>(&self, prompt: &str) -> Result<T>
251    where
252        T: serde::de::DeserializeOwned,
253    {
254        let response_text = self.complete(prompt).await?;
255
256        // Try to find JSON in the response (LLM might include markdown or explanations)
257        let json_str = Self::extract_json(&response_text);
258
259        serde_json::from_str(json_str).with_context(|| {
260            // Provide helpful error context
261            let preview = if json_str.len() > 500 {
262                format!("{}...", &json_str[..500])
263            } else {
264                json_str.to_string()
265            };
266            format!(
267                "Failed to parse JSON from LLM response. Response preview:\n{}",
268                preview
269            )
270        })
271    }
272
273    /// Check if the current provider is claude-cli
274    pub fn is_claude_cli(&self) -> bool {
275        self.config.llm.provider == "claude-cli"
276    }
277
278    /// Complete with Claude Code tool access (only works with claude-cli provider)
279    pub async fn complete_with_tools(&self, prompt: &str, tools: &ToolConfig) -> Result<String> {
280        if !self.is_claude_cli() {
281            // Fall back to regular completion for non-claude-cli providers
282            return self.complete(prompt).await;
283        }
284        self.complete_claude_cli_with_tools(prompt, None, Some(tools))
285            .await
286    }
287
288    /// Complete and parse JSON with Claude Code tool access
289    pub async fn complete_json_with_tools<T>(&self, prompt: &str, tools: &ToolConfig) -> Result<T>
290    where
291        T: serde::de::DeserializeOwned,
292    {
293        let response_text = self.complete_with_tools(prompt, tools).await?;
294
295        // Try to find JSON in the response
296        let json_str = Self::extract_json(&response_text);
297
298        serde_json::from_str(json_str).with_context(|| {
299            let preview = if json_str.len() > 500 {
300                format!("{}...", &json_str[..500])
301            } else {
302                json_str.to_string()
303            };
304            format!(
305                "Failed to parse JSON from LLM response. Response preview:\n{}",
306                preview
307            )
308        })
309    }
310
311    /// Extract JSON from LLM response, handling markdown code blocks and extra text
312    fn extract_json(response: &str) -> &str {
313        // First, try to extract from markdown code blocks
314        if let Some(start) = response.find("```json") {
315            let content_start = start + 7; // Skip "```json"
316            if let Some(end) = response[content_start..].find("```") {
317                return response[content_start..content_start + end].trim();
318            }
319        }
320
321        // Try plain code blocks
322        if let Some(start) = response.find("```") {
323            let content_start = start + 3;
324            // Skip language identifier if present (e.g., "```\n")
325            let content_start = response[content_start..]
326                .find('\n')
327                .map(|i| content_start + i + 1)
328                .unwrap_or(content_start);
329            if let Some(end) = response[content_start..].find("```") {
330                return response[content_start..content_start + end].trim();
331            }
332        }
333
334        // Try to find array JSON
335        if let Some(start) = response.find('[') {
336            if let Some(end) = response.rfind(']') {
337                if end > start {
338                    return &response[start..=end];
339                }
340            }
341        }
342
343        // Try to find object JSON
344        if let Some(start) = response.find('{') {
345            if let Some(end) = response.rfind('}') {
346                if end > start {
347                    return &response[start..=end];
348                }
349            }
350        }
351
352        response.trim()
353    }
354
355    async fn complete_claude_cli(
356        &self,
357        prompt: &str,
358        model_override: Option<&str>,
359    ) -> Result<String> {
360        self.complete_claude_cli_with_tools(prompt, model_override, None)
361            .await
362    }
363
364    async fn complete_claude_cli_with_tools(
365        &self,
366        prompt: &str,
367        model_override: Option<&str>,
368        tool_config: Option<&ToolConfig>,
369    ) -> Result<String> {
370        use std::process::Stdio;
371        use tokio::io::AsyncWriteExt;
372        use tokio::process::Command;
373
374        let model = model_override.unwrap_or(&self.config.llm.model);
375
376        // Build the claude command
377        let mut cmd = Command::new("claude");
378        cmd.arg("-p") // Print mode (headless)
379            .arg("--output-format")
380            .arg("json")
381            .arg("--model")
382            .arg(model);
383
384        // Add tool configuration if provided
385        if let Some(tools) = tool_config {
386            if !tools.allowed_tools.is_empty() {
387                cmd.arg("--allowedTools")
388                    .arg(tools.allowed_tools.join(","));
389            }
390            if let Some(max_turns) = tools.max_turns {
391                cmd.arg("--max-turns").arg(max_turns.to_string());
392            }
393        }
394
395        cmd.stdin(Stdio::piped())
396            .stdout(Stdio::piped())
397            .stderr(Stdio::piped());
398
399        // Spawn the process
400        let mut child = cmd.spawn().context("Failed to spawn 'claude' command. Make sure Claude Code is installed and 'claude' is in your PATH")?;
401
402        // Write prompt to stdin
403        if let Some(mut stdin) = child.stdin.take() {
404            stdin
405                .write_all(prompt.as_bytes())
406                .await
407                .context("Failed to write prompt to claude stdin")?;
408            drop(stdin); // Close stdin
409        }
410
411        // Wait for completion
412        let output = child
413            .wait_with_output()
414            .await
415            .context("Failed to wait for claude command")?;
416
417        if !output.status.success() {
418            let stderr = String::from_utf8_lossy(&output.stderr);
419            anyhow::bail!("Claude CLI error: {}", stderr);
420        }
421
422        // Parse JSON output
423        let stdout =
424            String::from_utf8(output.stdout).context("Claude CLI output is not valid UTF-8")?;
425
426        #[derive(Deserialize)]
427        struct ClaudeCliResponse {
428            result: String,
429        }
430
431        let response: ClaudeCliResponse =
432            serde_json::from_str(&stdout).context("Failed to parse Claude CLI JSON response")?;
433
434        Ok(response.result)
435    }
436}