scud/llm/client.rs

use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::env;
use std::path::PathBuf;

use crate::config::Config;
use crate::storage::Storage;

// Anthropic API structures
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<AnthropicMessage>,
}

#[derive(Debug, Serialize)]
struct AnthropicMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContent>,
}

#[derive(Debug, Deserialize)]
struct AnthropicContent {
    text: String,
}

// OpenAI-compatible API structures (used by xAI, OpenAI, OpenRouter)
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<OpenAIMessage>,
}

#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessageResponse,
}

#[derive(Debug, Deserialize)]
struct OpenAIMessageResponse {
    content: String,
}

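/// Client for the configured LLM provider. Depending on configuration it either
/// calls an HTTP API (Anthropic or an OpenAI-compatible endpoint) or shells out
/// to a local CLI (`claude` or `codex`).
///
/// # Example
///
/// A minimal sketch (assumes this type is in scope and a tokio runtime is
/// driving the future; adjust the import path to the actual crate layout):
///
/// ```ignore
/// let client = LLMClient::new()?;
/// let reply = client
///     .complete_fast("Summarize the change in one line", None)
///     .await?;
/// println!("{}", reply);
/// ```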
pub struct LLMClient {
    config: Config,
    api_key: String,
    client: reqwest::Client,
}

/// Info about which model is being used for display purposes
#[derive(Debug, Clone)]
pub struct ModelInfo {
    pub tier: &'static str, // "fast" or "smart"
    pub provider: String,
    pub model: String,
}

impl std::fmt::Display for ModelInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{} model: {}/{}", self.tier, self.provider, self.model)
    }
}

impl LLMClient {
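    /// Build a client using configuration loaded from the default storage
    /// location. Fails if the configured provider requires an API key and the
    /// corresponding environment variable is not set.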
    pub fn new() -> Result<Self> {
        let storage = Storage::new(None);
        let config = storage.load_config()?;

        let api_key = if config.requires_api_key() {
            env::var(config.api_key_env_var()).with_context(|| {
                format!("{} environment variable not set", config.api_key_env_var())
            })?
        } else {
            String::new() // Claude CLI doesn't need API key
        };

        Ok(LLMClient {
            config,
            api_key,
            client: reqwest::Client::new(),
        })
    }

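    /// Like [`LLMClient::new`], but loads configuration from the given project
    /// root instead of the default location.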
    pub fn new_with_project_root(project_root: PathBuf) -> Result<Self> {
        let storage = Storage::new(Some(project_root));
        let config = storage.load_config()?;

        let api_key = if config.requires_api_key() {
            env::var(config.api_key_env_var()).with_context(|| {
                format!("{} environment variable not set", config.api_key_env_var())
            })?
        } else {
            String::new() // Claude CLI doesn't need API key
        };

        Ok(LLMClient {
            config,
            api_key,
            client: reqwest::Client::new(),
        })
    }

    /// Get info about the smart model that will be used
    pub fn smart_model_info(&self, model_override: Option<&str>) -> ModelInfo {
        ModelInfo {
            tier: "smart",
            provider: self.config.smart_provider().to_string(),
            model: model_override
                .unwrap_or(self.config.smart_model())
                .to_string(),
        }
    }

    /// Get info about the fast model that will be used
    pub fn fast_model_info(&self, model_override: Option<&str>) -> ModelInfo {
        ModelInfo {
            tier: "fast",
            provider: self.config.fast_provider().to_string(),
            model: model_override
                .unwrap_or(self.config.fast_model())
                .to_string(),
        }
    }

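    /// Complete a prompt using the default configured provider and model.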
    pub async fn complete(&self, prompt: &str) -> Result<String> {
        self.complete_with_model(prompt, None, None).await
    }

    /// Complete using the smart model (for validation/analysis tasks with large context).
    /// Uses the user override if provided, otherwise falls back to the configured smart_model.
    pub async fn complete_smart(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(self.config.smart_model());
        let provider = self.config.smart_provider();
        self.complete_with_model(prompt, Some(model), Some(provider))
            .await
    }

    /// Complete using the fast model (for generation tasks).
    /// Uses the user override if provided, otherwise falls back to the configured fast_model.
    pub async fn complete_fast(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(self.config.fast_model());
        let provider = self.config.fast_provider();
        self.complete_with_model(prompt, Some(model), Some(provider))
            .await
    }

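    /// Complete a prompt, optionally overriding the model and/or provider.
    /// `None` falls back to the values in the loaded configuration.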
    pub async fn complete_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
        provider_override: Option<&str>,
    ) -> Result<String> {
        let provider = provider_override.unwrap_or(&self.config.llm.provider);
        match provider.as_ref() {
            "claude-cli" => self.complete_claude_cli(prompt, model_override).await,
            "codex" => self.complete_codex_cli(prompt, model_override).await,
            "anthropic" => {
                self.complete_anthropic_with_model(prompt, model_override)
                    .await
            }
            "xai" | "openai" | "openrouter" => {
                self.complete_openai_compatible_with_model(prompt, model_override)
                    .await
            }
            // Report the provider that was actually resolved (override or configured)
            _ => anyhow::bail!("Unsupported provider: {}", provider),
        }
    }

    async fn complete_anthropic_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        let request = AnthropicRequest {
            model: model.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![AnthropicMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let response = self
            .client
            .post(self.config.api_endpoint())
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&request)
            .send()
            .await
            .context("Failed to send request to Anthropic API")?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Anthropic API error ({}): {}", status, error_text);
        }

        let api_response: AnthropicResponse = response
            .json()
            .await
            .context("Failed to parse Anthropic API response")?;

        Ok(api_response
            .content
            .first()
            .map(|c| c.text.clone())
            .unwrap_or_default())
    }

    async fn complete_openai_compatible_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        let request = OpenAIRequest {
            model: model.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![OpenAIMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let mut request_builder = self
            .client
            .post(self.config.api_endpoint())
            .header("authorization", format!("Bearer {}", self.api_key))
            .header("content-type", "application/json");

        // OpenRouter requires additional headers
        if self.config.llm.provider == "openrouter" {
            request_builder = request_builder
                .header("HTTP-Referer", "https://github.com/scud-cli")
                .header("X-Title", "SCUD Task Master");
        }

        let response = request_builder
            .json(&request)
            .send()
            .await
            .with_context(|| {
                format!("Failed to send request to {} API", self.config.llm.provider)
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!(
                "{} API error ({}): {}",
                self.config.llm.provider,
                status,
                error_text
            );
        }

        let api_response: OpenAIResponse = response.json().await.with_context(|| {
            format!("Failed to parse {} API response", self.config.llm.provider)
        })?;

        Ok(api_response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .unwrap_or_default())
    }

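    /// Complete a prompt with the default model and deserialize the response as JSON.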
    pub async fn complete_json<T>(&self, prompt: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        self.complete_json_with_model(prompt, None).await
    }

    /// Complete JSON using the smart model (for validation/analysis tasks)
    pub async fn complete_json_smart<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_smart(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

    /// Complete JSON using the fast model (for generation tasks)
    pub async fn complete_json_fast<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_fast(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

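    /// Complete a prompt with an optional model override and deserialize the
    /// response as JSON.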
    pub async fn complete_json_with_model<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self
            .complete_with_model(prompt, model_override, None)
            .await?;
        Self::parse_json_response(&response_text)
    }

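    /// Extract the JSON payload from an LLM response and deserialize it,
    /// attaching a truncated preview of the payload to any parse error.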
    fn parse_json_response<T>(response_text: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        // Try to find JSON in the response (LLM might include markdown or explanations)
        let json_str = Self::extract_json(response_text);

        serde_json::from_str(json_str).with_context(|| {
            // Provide helpful error context; truncate on a char boundary so the
            // preview cannot panic on multi-byte UTF-8
            let preview: String = json_str.chars().take(500).collect();
            let preview = if preview.len() < json_str.len() {
                format!("{}...", preview)
            } else {
                preview
            };
            format!(
                "Failed to parse JSON from LLM response. Response preview:\n{}",
                preview
            )
        })
    }

    /// Extract JSON from LLM response, handling markdown code blocks and extra text
    fn extract_json(response: &str) -> &str {
        // First, try to extract from markdown code blocks
        if let Some(start) = response.find("```json") {
            let content_start = start + 7; // Skip "```json"
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

        // Try plain code blocks
        if let Some(start) = response.find("```") {
            let content_start = start + 3;
            // Skip language identifier if present (e.g., "```\n")
            let content_start = response[content_start..]
                .find('\n')
                .map(|i| content_start + i + 1)
                .unwrap_or(content_start);
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

        // Try to find array JSON
        if let Some(start) = response.find('[') {
            if let Some(end) = response.rfind(']') {
                if end > start {
                    return &response[start..=end];
                }
            }
        }

        // Try to find object JSON
        if let Some(start) = response.find('{') {
            if let Some(end) = response.rfind('}') {
                if end > start {
                    return &response[start..=end];
                }
            }
        }

        response.trim()
    }

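    /// Complete a prompt by shelling out to the Claude Code CLI (`claude`) in
    /// headless print mode and parsing its JSON output.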
    async fn complete_claude_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        // Build the claude command
        let mut cmd = Command::new("claude");
        cmd.arg("-p") // Print mode (headless)
            .arg("--output-format")
            .arg("json")
            .arg("--model")
            .arg(model)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        // Spawn the process
        let mut child = cmd.spawn().context("Failed to spawn 'claude' command. Make sure Claude Code is installed and 'claude' is in your PATH")?;

        // Write prompt to stdin
        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to claude stdin")?;
            drop(stdin); // Close stdin
        }

        // Wait for completion
        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for claude command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Claude CLI error: {}", stderr);
        }

        // Parse JSON output
        let stdout =
            String::from_utf8(output.stdout).context("Claude CLI output is not valid UTF-8")?;

        #[derive(Deserialize)]
        struct ClaudeCliResponse {
            result: String,
        }

        let response: ClaudeCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Claude CLI JSON response")?;

        Ok(response.result)
    }

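    /// Complete a prompt by shelling out to the Codex CLI (`codex`) in
    /// non-interactive prompt mode and parsing its JSON output.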
    async fn complete_codex_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        // Build the codex command
        // Codex CLI uses a headless mode similar to Claude Code's
        let mut cmd = Command::new("codex");
        cmd.arg("-p") // Prompt mode (headless/non-interactive)
            .arg("--model")
            .arg(model)
            .arg("--output-format")
            .arg("json")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        // Spawn the process
        let mut child = cmd.spawn().context("Failed to spawn 'codex' command. Make sure OpenAI Codex CLI is installed and 'codex' is in your PATH")?;

        // Write prompt to stdin
        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to codex stdin")?;
            drop(stdin); // Close stdin
        }

        // Wait for completion
        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for codex command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Codex CLI error: {}", stderr);
        }

        // Parse JSON output
        let stdout =
            String::from_utf8(output.stdout).context("Codex CLI output is not valid UTF-8")?;

        // Codex outputs JSON with a result field similar to Claude CLI
        #[derive(Deserialize)]
        struct CodexCliResponse {
            result: String,
        }

        let response: CodexCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Codex CLI JSON response")?;

        Ok(response.result)
    }
}
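
// Illustrative unit tests for the JSON extraction helper. This is a small
// sketch of the cases extract_json is written to handle (fenced blocks, bare
// objects, bare arrays), not an exhaustive suite.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extracts_json_from_markdown_fence() {
        let response = "Here you go:\n```json\n{\"id\": 1}\n```\nDone.";
        assert_eq!(LLMClient::extract_json(response), "{\"id\": 1}");
    }

    #[test]
    fn extracts_bare_object_from_surrounding_text() {
        let response = "The plan is {\"steps\": 2} as requested.";
        assert_eq!(LLMClient::extract_json(response), "{\"steps\": 2}");
    }

    #[test]
    fn extracts_bare_array() {
        let response = "Result: [1, 2, 3] -- end";
        assert_eq!(LLMClient::extract_json(response), "[1, 2, 3]");
    }
}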