scud/llm/client.rs

use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::env;
use std::path::PathBuf;

use crate::config::Config;
use crate::storage::Storage;

// Anthropic API structures
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<AnthropicMessage>,
}

#[derive(Debug, Serialize)]
struct AnthropicMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContent>,
}

#[derive(Debug, Deserialize)]
struct AnthropicContent {
    text: String,
}

// OpenAI-compatible API structures (used by xAI, OpenAI, OpenRouter)
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<OpenAIMessage>,
}

#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessageResponse,
}

#[derive(Debug, Deserialize)]
struct OpenAIMessageResponse {
    content: String,
}

pub struct LLMClient {
    config: Config,
    api_key: String,
    client: reqwest::Client,
}

/// Info about which model is being used, for display purposes
#[derive(Debug, Clone)]
pub struct ModelInfo {
    pub tier: &'static str, // "fast" or "smart"
    pub provider: String,
    pub model: String,
}

impl std::fmt::Display for ModelInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Avoid double-prefixing if model already contains provider prefix
        let prefix = format!("{}/", self.provider);
        if self.model.starts_with(&prefix) {
            write!(f, "{} model: {}", self.tier, self.model)
        } else {
            write!(f, "{} model: {}/{}", self.tier, self.provider, self.model)
        }
    }
}
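
// Example (illustrative model ids): with tier "smart" and provider "xai",
// a model of "xai/grok-3" displays as "smart model: xai/grok-3" (the prefix
// is not doubled), and a bare "grok-3" displays identically via the else branch.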

impl LLMClient {
    pub fn new() -> Result<Self> {
        Self::from_storage(Storage::new(None))
    }

    pub fn new_with_project_root(project_root: PathBuf) -> Result<Self> {
        Self::from_storage(Storage::new(Some(project_root)))
    }

    /// Shared constructor: load config and resolve the API key if the
    /// configured provider requires one.
    fn from_storage(storage: Storage) -> Result<Self> {
        let config = storage.load_config()?;

        let api_key = if config.requires_api_key() {
            env::var(config.api_key_env_var()).with_context(|| {
                format!("{} environment variable not set", config.api_key_env_var())
            })?
        } else {
            String::new() // CLI providers don't need an API key
        };

        Ok(LLMClient {
            config,
            api_key,
            client: reqwest::Client::new(),
        })
    }
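
    // Usage sketch (assumes a loadable scud config; hosted providers also
    // need their API key env var set, e.g. ANTHROPIC_API_KEY for "anthropic"):
    //
    //     let client = LLMClient::new()?;
    //     let reply = client.complete_fast("Say hello", None).await?;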

    /// Get info about the smart model that will be used
    pub fn smart_model_info(&self, model_override: Option<&str>) -> ModelInfo {
        ModelInfo {
            tier: "smart",
            provider: self.config.smart_provider().to_string(),
            model: model_override
                .unwrap_or(self.config.smart_model())
                .to_string(),
        }
    }

    /// Get info about the fast model that will be used
    pub fn fast_model_info(&self, model_override: Option<&str>) -> ModelInfo {
        ModelInfo {
            tier: "fast",
            provider: self.config.fast_provider().to_string(),
            model: model_override
                .unwrap_or(self.config.fast_model())
                .to_string(),
        }
    }

    pub async fn complete(&self, prompt: &str) -> Result<String> {
        self.complete_with_model(prompt, None, None).await
    }

    /// Complete using the smart model (for validation/analysis tasks with large context).
    /// Uses the user override if provided, otherwise falls back to the configured smart_model.
    pub async fn complete_smart(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(self.config.smart_model());
        let provider = self.config.smart_provider();
        self.complete_with_model(prompt, Some(model), Some(provider))
            .await
    }

    /// Complete using the fast model (for generation tasks).
    /// Uses the user override if provided, otherwise falls back to the configured fast_model.
    pub async fn complete_fast(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(self.config.fast_model());
        let provider = self.config.fast_provider();
        self.complete_with_model(prompt, Some(model), Some(provider))
            .await
    }

    pub async fn complete_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
        provider_override: Option<&str>,
    ) -> Result<String> {
        let provider = provider_override.unwrap_or(&self.config.llm.provider);
        match provider {
            "claude-cli" => self.complete_claude_cli(prompt, model_override).await,
            "codex" => self.complete_codex_cli(prompt, model_override).await,
            "anthropic" => {
                self.complete_anthropic_with_model(prompt, model_override)
                    .await
            }
            "xai" | "openai" | "openrouter" => {
                self.complete_openai_compatible_with_model(prompt, model_override, provider)
                    .await
            }
            _ => anyhow::bail!("Unsupported provider: {}", provider),
        }
    }

    async fn complete_anthropic_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        let request = AnthropicRequest {
            model: model.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![AnthropicMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let response = self
            .client
            .post(self.config.api_endpoint())
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&request)
            .send()
            .await
            .context("Failed to send request to Anthropic API")?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Anthropic API error ({}): {}", status, error_text);
        }

        let api_response: AnthropicResponse = response
            .json()
            .await
            .context("Failed to parse Anthropic API response")?;

        Ok(api_response
            .content
            .first()
            .map(|c| c.text.clone())
            .unwrap_or_default())
    }

    async fn complete_openai_compatible_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
        provider: &str,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        // Strip provider prefix for native APIs (xai/, openai/)
        // OpenRouter needs the prefix, native APIs don't
        let model_for_api = if provider != "openrouter" {
            let prefix = format!("{}/", provider);
            model.strip_prefix(&prefix).unwrap_or(model)
        } else {
            model
        };
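        // e.g. (illustrative ids): with provider "xai", "xai/grok-3" is sent
        // as "grok-3"; with provider "openrouter", the model id is passed
        // through unchanged.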

        // Get the correct endpoint for this provider
        let endpoint = match provider {
            "xai" => "https://api.x.ai/v1/chat/completions",
            "openai" => "https://api.openai.com/v1/chat/completions",
            "openrouter" => "https://openrouter.ai/api/v1/chat/completions",
            // Defensive fallback; complete_with_model only routes the three
            // providers above here
            _ => "https://api.x.ai/v1/chat/completions",
        };

        let request = OpenAIRequest {
            model: model_for_api.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![OpenAIMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let mut request_builder = self
            .client
            .post(endpoint)
            .header("authorization", format!("Bearer {}", self.api_key))
            .header("content-type", "application/json");

        // OpenRouter requires additional headers
        if provider == "openrouter" {
            request_builder = request_builder
                .header("HTTP-Referer", "https://github.com/scud-cli")
                .header("X-Title", "SCUD Task Master");
        }

        let response = request_builder
            .json(&request)
            .send()
            .await
            .with_context(|| format!("Failed to send request to {} API", provider))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("{} API error ({}): {}", provider, status, error_text);
        }

        let api_response: OpenAIResponse = response
            .json()
            .await
            .with_context(|| format!("Failed to parse {} API response", provider))?;

        Ok(api_response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .unwrap_or_default())
    }

    pub async fn complete_json<T>(&self, prompt: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        self.complete_json_with_model(prompt, None).await
    }

    /// Complete JSON using the smart model (for validation/analysis tasks)
    pub async fn complete_json_smart<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_smart(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

    /// Complete JSON using the fast model (for generation tasks)
    pub async fn complete_json_fast<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_fast(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

    pub async fn complete_json_with_model<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self
            .complete_with_model(prompt, model_override, None)
            .await?;
        Self::parse_json_response(&response_text)
    }

    fn parse_json_response<T>(response_text: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        // Try to find JSON in the response (the LLM might include markdown or explanations)
        let json_str = Self::extract_json(response_text);

        serde_json::from_str(json_str).with_context(|| {
            // Provide helpful error context; truncate by characters so the
            // preview can't panic on a multi-byte UTF-8 boundary
            let preview = if json_str.chars().count() > 500 {
                let truncated: String = json_str.chars().take(500).collect();
                format!("{}...", truncated)
            } else {
                json_str.to_string()
            };
            format!(
                "Failed to parse JSON from LLM response. Response preview:\n{}",
                preview
            )
        })
    }

    /// Extract JSON from LLM response, handling markdown code blocks and extra text.
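    /// Precedence: a json-tagged code fence, then any code fence, then the
    /// outermost [...] array, then the outermost {...} object, else the
    /// trimmed input. Exercised by the tests at the bottom of this file.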
    fn extract_json(response: &str) -> &str {
        // First, try to extract from markdown code blocks
        if let Some(start) = response.find("```json") {
            let content_start = start + 7; // Skip past "```json"
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

        // Try plain code blocks
        if let Some(start) = response.find("```") {
            let content_start = start + 3;
            // Skip past the rest of the opening fence line, which may carry a
            // language tag
            let content_start = response[content_start..]
                .find('\n')
                .map(|i| content_start + i + 1)
                .unwrap_or(content_start);
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

        // Try to find array JSON
        if let Some(start) = response.find('[') {
            if let Some(end) = response.rfind(']') {
                if end > start {
                    return &response[start..=end];
                }
            }
        }

        // Try to find object JSON
        if let Some(start) = response.find('{') {
            if let Some(end) = response.rfind('}') {
                if end > start {
                    return &response[start..=end];
                }
            }
        }

        response.trim()
    }

    async fn complete_claude_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        // Build the claude command
        let mut cmd = Command::new("claude");
        cmd.arg("-p") // Print mode (headless)
            .arg("--output-format")
            .arg("json")
            .arg("--model")
            .arg(model)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        // Spawn the process
        let mut child = cmd.spawn().context(
            "Failed to spawn 'claude' command. Make sure Claude Code is installed and 'claude' is in your PATH",
        )?;

        // Write prompt to stdin
        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to claude stdin")?;
            drop(stdin); // Close stdin
        }

        // Wait for completion
        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for claude command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Claude CLI error: {}", stderr);
        }

        // Parse JSON output
        let stdout =
            String::from_utf8(output.stdout).context("Claude CLI output is not valid UTF-8")?;

        #[derive(Deserialize)]
        struct ClaudeCliResponse {
            result: String,
        }

        let response: ClaudeCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Claude CLI JSON response")?;

        Ok(response.result)
    }

    async fn complete_codex_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        // Build the codex command
        // Codex CLI uses a similar headless mode to Claude Code
        let mut cmd = Command::new("codex");
        cmd.arg("-p") // Prompt mode (headless/non-interactive)
            .arg("--model")
            .arg(model)
            .arg("--output-format")
            .arg("json")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        // Spawn the process
        let mut child = cmd.spawn().context(
            "Failed to spawn 'codex' command. Make sure OpenAI Codex CLI is installed and 'codex' is in your PATH",
        )?;

        // Write prompt to stdin
        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to codex stdin")?;
            drop(stdin); // Close stdin
        }

        // Wait for completion
        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for codex command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Codex CLI error: {}", stderr);
        }

        // Parse JSON output
        let stdout =
            String::from_utf8(output.stdout).context("Codex CLI output is not valid UTF-8")?;

        // Codex outputs JSON with a result field similar to Claude CLI
        #[derive(Deserialize)]
        struct CodexCliResponse {
            result: String,
        }

        let response: CodexCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Codex CLI JSON response")?;

        Ok(response.result)
    }
}
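
// A minimal sketch of unit tests for the JSON-extraction heuristics; the
// inputs are invented examples, not captured model output.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extracts_from_json_code_block() {
        let response = "Here you go:\n```json\n{\"ok\": true}\n```\nDone.";
        assert_eq!(LLMClient::extract_json(response), "{\"ok\": true}");
    }

    #[test]
    fn extracts_from_plain_code_block() {
        // The opening fence line (with or without a language tag) is skipped
        let response = "```\n[1, 2, 3]\n```";
        assert_eq!(LLMClient::extract_json(response), "[1, 2, 3]");
    }

    #[test]
    fn extracts_bare_object_from_surrounding_text() {
        let response = "The result is {\"a\": 1} as requested.";
        assert_eq!(LLMClient::extract_json(response), "{\"a\": 1}");
    }

    #[test]
    fn parses_extracted_json() {
        let value: serde_json::Value =
            LLMClient::parse_json_response("```json\n{\"a\": 1}\n```").unwrap();
        assert_eq!(value["a"], 1);
    }
}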