//! scud/llm/client.rs — LLM client: routes completions to the Anthropic API,
//! OpenAI-compatible APIs (xAI / OpenAI / OpenRouter), or local CLI agents
//! (Claude Code, Codex, Cursor Agent).
1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3use std::env;
4use std::path::PathBuf;
5
6use crate::config::Config;
7use crate::llm::oauth;
8use crate::storage::Storage;
9
// Anthropic API structures

/// Request body for POST /v1/messages (only the fields this client sends).
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<AnthropicMessage>,
}

/// A single chat turn; this client always sends `role` = "user".
#[derive(Debug, Serialize)]
struct AnthropicMessage {
    role: String,
    content: String,
}

/// Minimal slice of the Messages API response: just the content blocks.
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContent>,
}

/// One response content block; only the `text` field is deserialized.
#[derive(Debug, Deserialize)]
struct AnthropicContent {
    text: String,
}
33
// OpenAI-compatible API structures (used by xAI, OpenAI, OpenRouter)

/// Request body for POST .../chat/completions (only fields this client sends).
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<OpenAIMessage>,
}

/// A single chat message; this client always sends `role` = "user".
#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

/// Minimal response shape: only the `choices` list is read.
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
}

/// A single completion choice; only its message is read.
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessageResponse,
}

/// The assistant message; only its text content is read.
#[derive(Debug, Deserialize)]
struct OpenAIMessageResponse {
    content: String,
}
62
/// LLM client: owns the loaded config and a reusable reqwest HTTP client.
pub struct LLMClient {
    config: Config,
    client: reqwest::Client,
}
67
/// Info about which model is being used for display purposes
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Model tier label shown to the user.
    pub tier: &'static str, // "fast" or "smart"
    /// Provider id (e.g. "anthropic", "xai", "openrouter").
    pub provider: String,
    /// Model name; may already carry a "provider/" prefix.
    pub model: String,
}
75
76impl std::fmt::Display for ModelInfo {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        // Avoid double-prefixing if model already contains provider prefix
79        let prefix = format!("{}/", self.provider);
80        if self.model.starts_with(&prefix) {
81            write!(f, "{} model: {}", self.tier, self.model)
82        } else {
83            write!(f, "{} model: {}/{}", self.tier, self.provider, self.model)
84        }
85    }
86}
87
88impl LLMClient {
89    pub fn new() -> Result<Self> {
90        let storage = Storage::new(None);
91        let config = storage.load_config()?;
92        Ok(LLMClient {
93            config,
94            client: reqwest::Client::new(),
95        })
96    }
97
98    pub fn new_with_project_root(project_root: PathBuf) -> Result<Self> {
99        let storage = Storage::new(Some(project_root));
100        let config = storage.load_config()?;
101        Ok(LLMClient {
102            config,
103            client: reqwest::Client::new(),
104        })
105    }
106
107    /// Get info about the smart model that will be used
108    pub fn smart_model_info(&self, model_override: Option<&str>) -> ModelInfo {
109        ModelInfo {
110            tier: "smart",
111            provider: self.config.smart_provider().to_string(),
112            model: model_override
113                .unwrap_or(self.config.smart_model())
114                .to_string(),
115        }
116    }
117
118    /// Get info about the fast model that will be used
119    pub fn fast_model_info(&self, model_override: Option<&str>) -> ModelInfo {
120        ModelInfo {
121            tier: "fast",
122            provider: self.config.fast_provider().to_string(),
123            model: model_override
124                .unwrap_or(self.config.fast_model())
125                .to_string(),
126        }
127    }
128
    /// Complete a prompt using the default configured provider and model.
    pub async fn complete(&self, prompt: &str) -> Result<String> {
        self.complete_with_model(prompt, None, None).await
    }
132
133    /// Complete using the smart model (for validation/analysis tasks with large context)
134    /// Use user override if provided, otherwise fall back to configured smart_model
135    pub async fn complete_smart(
136        &self,
137        prompt: &str,
138        model_override: Option<&str>,
139    ) -> Result<String> {
140        let model = model_override.unwrap_or(self.config.smart_model());
141        let provider = self.config.smart_provider();
142        self.complete_with_model(prompt, Some(model), Some(provider))
143            .await
144    }
145
146    /// Complete using the fast model (for generation tasks)
147    /// Use user override if provided, otherwise fall back to configured fast_model
148    pub async fn complete_fast(
149        &self,
150        prompt: &str,
151        model_override: Option<&str>,
152    ) -> Result<String> {
153        let model = model_override.unwrap_or(self.config.fast_model());
154        let provider = self.config.fast_provider();
155        self.complete_with_model(prompt, Some(model), Some(provider))
156            .await
157    }
158
159    pub async fn complete_with_model(
160        &self,
161        prompt: &str,
162        model_override: Option<&str>,
163        provider_override: Option<&str>,
164    ) -> Result<String> {
165        let provider = provider_override.unwrap_or(&self.config.llm.provider);
166        match provider {
167            "claude-cli" => self.complete_claude_cli(prompt, model_override).await,
168            "codex" => self.complete_codex_cli(prompt, model_override).await,
169            "cursor" => self.complete_cursor_cli(prompt, model_override).await,
170            "anthropic" => {
171                self.complete_anthropic_api_key(prompt, model_override)
172                    .await
173            }
174            "anthropic-oauth" => {
175                self.complete_anthropic_oauth(prompt, model_override)
176                    .await
177            }
178            "xai" | "openai" | "openrouter" => {
179                self.complete_openai_compatible_with_model(prompt, model_override, provider)
180                    .await
181            }
182            _ => anyhow::bail!("Unsupported provider: {}", self.config.llm.provider),
183        }
184    }
185
186    /// Anthropic API with standard API key (ANTHROPIC_API_KEY)
187    async fn complete_anthropic_api_key(
188        &self,
189        prompt: &str,
190        model_override: Option<&str>,
191    ) -> Result<String> {
192        let model = model_override.unwrap_or(&self.config.llm.model);
193        let api_key = env::var("ANTHROPIC_API_KEY")
194            .context("ANTHROPIC_API_KEY environment variable not set")?;
195
196        let request = AnthropicRequest {
197            model: model.to_string(),
198            max_tokens: self.config.llm.max_tokens,
199            messages: vec![AnthropicMessage {
200                role: "user".to_string(),
201                content: prompt.to_string(),
202            }],
203        };
204
205        let response = self
206            .client
207            .post("https://api.anthropic.com/v1/messages")
208            .header("x-api-key", &api_key)
209            .header("anthropic-version", "2023-06-01")
210            .header("content-type", "application/json")
211            .json(&request)
212            .send()
213            .await
214            .context("Failed to send request to Anthropic API")?;
215
216        if !response.status().is_success() {
217            let status = response.status();
218            let error_text = response.text().await.unwrap_or_default();
219            anyhow::bail!("Anthropic API error ({}): {}", status, error_text);
220        }
221
222        let api_response: AnthropicResponse = response
223            .json()
224            .await
225            .context("Failed to parse Anthropic API response")?;
226
227        Ok(api_response
228            .content
229            .first()
230            .map(|c| c.text.clone())
231            .unwrap_or_default())
232    }
233
234    /// Anthropic API with Claude Code OAuth token from macOS Keychain
235    async fn complete_anthropic_oauth(
236        &self,
237        prompt: &str,
238        model_override: Option<&str>,
239    ) -> Result<String> {
240        let model = model_override.unwrap_or(&self.config.llm.model);
241        let creds = oauth::read_claude_oauth()?
242            .context("No Claude Code OAuth credentials found in Keychain. Log in with `claude` CLI first.")?;
243
244        if !oauth::is_token_valid(&creds) {
245            anyhow::bail!("Claude Code OAuth token expired. Re-login with `claude` CLI.");
246        }
247
248        let request = AnthropicRequest {
249            model: model.to_string(),
250            max_tokens: self.config.llm.max_tokens,
251            messages: vec![AnthropicMessage {
252                role: "user".to_string(),
253                content: prompt.to_string(),
254            }],
255        };
256
257        let response = self
258            .client
259            .post("https://api.anthropic.com/v1/messages")
260            .header("authorization", format!("Bearer {}", creds.access_token))
261            .header("anthropic-version", "2023-06-01")
262            .header("anthropic-beta", "oauth-2025-04-20")
263            .header("content-type", "application/json")
264            .header("user-agent", "SCUD-CLI/1.0")
265            .json(&request)
266            .send()
267            .await
268            .context("Failed to send request to Anthropic API")?;
269
270        if !response.status().is_success() {
271            let status = response.status();
272            let error_text = response.text().await.unwrap_or_default();
273            anyhow::bail!("Anthropic API error ({}): {}", status, error_text);
274        }
275
276        let api_response: AnthropicResponse = response
277            .json()
278            .await
279            .context("Failed to parse Anthropic API response")?;
280
281        Ok(api_response
282            .content
283            .first()
284            .map(|c| c.text.clone())
285            .unwrap_or_default())
286    }
287
288    async fn complete_openai_compatible_with_model(
289        &self,
290        prompt: &str,
291        model_override: Option<&str>,
292        provider: &str,
293    ) -> Result<String> {
294        let model = model_override.unwrap_or(&self.config.llm.model);
295        // Strip provider prefix for native APIs (xai/, openai/)
296        // OpenRouter needs the prefix, native APIs don't
297        let model_for_api = if provider != "openrouter" {
298            let prefix = format!("{}/", provider);
299            model.strip_prefix(&prefix).unwrap_or(model)
300        } else {
301            model
302        };
303
304        // Get the correct endpoint for this provider
305        let endpoint = match provider {
306            "xai" => "https://api.x.ai/v1/chat/completions",
307            "openai" => "https://api.openai.com/v1/chat/completions",
308            "openrouter" => "https://openrouter.ai/api/v1/chat/completions",
309            _ => "https://api.x.ai/v1/chat/completions",
310        };
311
312        // Resolve API key for this specific provider
313        let env_var = Config::api_key_env_var_for_provider(provider);
314        let api_key = env::var(env_var)
315            .with_context(|| format!("{} environment variable not set", env_var))?;
316
317        let request = OpenAIRequest {
318            model: model_for_api.to_string(),
319            max_tokens: self.config.llm.max_tokens,
320            messages: vec![OpenAIMessage {
321                role: "user".to_string(),
322                content: prompt.to_string(),
323            }],
324        };
325
326        let mut request_builder = self
327            .client
328            .post(endpoint)
329            .header("authorization", format!("Bearer {}", api_key))
330            .header("content-type", "application/json");
331
332        // OpenRouter requires additional headers
333        if provider == "openrouter" {
334            request_builder = request_builder
335                .header("HTTP-Referer", "https://github.com/scud-cli")
336                .header("X-Title", "SCUD Task Master");
337        }
338
339        let response = request_builder
340            .json(&request)
341            .send()
342            .await
343            .with_context(|| format!("Failed to send request to {} API", provider))?;
344
345        if !response.status().is_success() {
346            let status = response.status();
347            let error_text = response.text().await.unwrap_or_default();
348            anyhow::bail!("{} API error ({}): {}", provider, status, error_text);
349        }
350
351        let api_response: OpenAIResponse = response
352            .json()
353            .await
354            .with_context(|| format!("Failed to parse {} API response", provider))?;
355
356        Ok(api_response
357            .choices
358            .first()
359            .map(|c| c.message.content.clone())
360            .unwrap_or_default())
361    }
362
    /// Complete a prompt and parse the response as JSON, using the default
    /// configured provider and model.
    pub async fn complete_json<T>(&self, prompt: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        self.complete_json_with_model(prompt, None).await
    }
369
370    /// Complete JSON using the smart model (for validation/analysis tasks)
371    pub async fn complete_json_smart<T>(
372        &self,
373        prompt: &str,
374        model_override: Option<&str>,
375    ) -> Result<T>
376    where
377        T: serde::de::DeserializeOwned,
378    {
379        let response_text = self.complete_smart(prompt, model_override).await?;
380        Self::parse_json_response(&response_text)
381    }
382
383    /// Complete JSON using the fast model (for generation tasks)
384    pub async fn complete_json_fast<T>(
385        &self,
386        prompt: &str,
387        model_override: Option<&str>,
388    ) -> Result<T>
389    where
390        T: serde::de::DeserializeOwned,
391    {
392        let response_text = self.complete_fast(prompt, model_override).await?;
393        Self::parse_json_response(&response_text)
394    }
395
396    pub async fn complete_json_with_model<T>(
397        &self,
398        prompt: &str,
399        model_override: Option<&str>,
400    ) -> Result<T>
401    where
402        T: serde::de::DeserializeOwned,
403    {
404        let response_text = self
405            .complete_with_model(prompt, model_override, None)
406            .await?;
407        Self::parse_json_response(&response_text)
408    }
409
410    fn parse_json_response<T>(response_text: &str) -> Result<T>
411    where
412        T: serde::de::DeserializeOwned,
413    {
414        // Try to find JSON in the response (LLM might include markdown or explanations)
415        let json_str = Self::extract_json(response_text);
416
417        serde_json::from_str(json_str).with_context(|| {
418            // Provide helpful error context
419            let preview = if json_str.len() > 500 {
420                format!("{}...", &json_str[..500])
421            } else {
422                json_str.to_string()
423            };
424            format!(
425                "Failed to parse JSON from LLM response. Response preview:\n{}",
426                preview
427            )
428        })
429    }
430
431    /// Extract JSON from LLM response, handling markdown code blocks and extra text
432    fn extract_json(response: &str) -> &str {
433        // First, try to extract from markdown code blocks
434        if let Some(start) = response.find("```json") {
435            let content_start = start + 7; // Skip "```json"
436            if let Some(end) = response[content_start..].find("```") {
437                return response[content_start..content_start + end].trim();
438            }
439        }
440
441        // Try plain code blocks
442        if let Some(start) = response.find("```") {
443            let content_start = start + 3;
444            // Skip language identifier if present (e.g., "```\n")
445            let content_start = response[content_start..]
446                .find('\n')
447                .map(|i| content_start + i + 1)
448                .unwrap_or(content_start);
449            if let Some(end) = response[content_start..].find("```") {
450                return response[content_start..content_start + end].trim();
451            }
452        }
453
454        // Try to find array JSON
455        if let Some(start) = response.find('[') {
456            if let Some(end) = response.rfind(']') {
457                if end > start {
458                    return &response[start..=end];
459                }
460            }
461        }
462
463        // Try to find object JSON
464        if let Some(start) = response.find('{') {
465            if let Some(end) = response.rfind('}') {
466                if end > start {
467                    return &response[start..=end];
468                }
469            }
470        }
471
472        response.trim()
473    }
474
475    async fn complete_claude_cli(
476        &self,
477        prompt: &str,
478        model_override: Option<&str>,
479    ) -> Result<String> {
480        use std::process::Stdio;
481        use tokio::io::AsyncWriteExt;
482        use tokio::process::Command;
483
484        let model = model_override.unwrap_or(&self.config.llm.model);
485
486        // Build the claude command
487        let mut cmd = Command::new("claude");
488        cmd.arg("-p") // Print mode (headless)
489            .arg("--output-format")
490            .arg("json")
491            .arg("--model")
492            .arg(model)
493            .stdin(Stdio::piped())
494            .stdout(Stdio::piped())
495            .stderr(Stdio::piped());
496
497        // Spawn the process
498        let mut child = cmd.spawn().context("Failed to spawn 'claude' command. Make sure Claude Code is installed and 'claude' is in your PATH")?;
499
500        // Write prompt to stdin
501        if let Some(mut stdin) = child.stdin.take() {
502            stdin
503                .write_all(prompt.as_bytes())
504                .await
505                .context("Failed to write prompt to claude stdin")?;
506            drop(stdin); // Close stdin
507        }
508
509        // Wait for completion
510        let output = child
511            .wait_with_output()
512            .await
513            .context("Failed to wait for claude command")?;
514
515        if !output.status.success() {
516            let stderr = String::from_utf8_lossy(&output.stderr);
517            anyhow::bail!("Claude CLI error: {}", stderr);
518        }
519
520        // Parse JSON output
521        let stdout =
522            String::from_utf8(output.stdout).context("Claude CLI output is not valid UTF-8")?;
523
524        #[derive(Deserialize)]
525        struct ClaudeCliResponse {
526            result: String,
527        }
528
529        let response: ClaudeCliResponse =
530            serde_json::from_str(&stdout).context("Failed to parse Claude CLI JSON response")?;
531
532        Ok(response.result)
533    }
534
    /// Run a prompt through the OpenAI Codex CLI in non-interactive mode and
    /// return the `result` field of its JSON output.
    ///
    /// NOTE(review): the flag set (`-p`, `--model`, `--output-format json`)
    /// mirrors the Claude CLI invocation — confirm it matches the actual
    /// `codex` CLI interface.
    async fn complete_codex_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        // Explicit override wins; otherwise use the globally configured model.
        let model = model_override.unwrap_or(&self.config.llm.model);

        // Build the codex command
        // Codex CLI uses similar headless mode to Claude Code
        let mut cmd = Command::new("codex");
        cmd.arg("-p") // Prompt mode (headless/non-interactive)
            .arg("--model")
            .arg(model)
            .arg("--output-format")
            .arg("json")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        // Spawn the process
        let mut child = cmd.spawn().context("Failed to spawn 'codex' command. Make sure OpenAI Codex CLI is installed and 'codex' is in your PATH")?;

        // Write prompt to stdin
        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to codex stdin")?;
            drop(stdin); // Close stdin so the CLI knows input is complete
        }

        // Wait for completion
        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for codex command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Codex CLI error: {}", stderr);
        }

        // Parse JSON output
        let stdout =
            String::from_utf8(output.stdout).context("Codex CLI output is not valid UTF-8")?;

        // Codex outputs JSON with a result field similar to Claude CLI
        #[derive(Deserialize)]
        struct CodexCliResponse {
            result: String,
        }

        let response: CodexCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Codex CLI JSON response")?;

        Ok(response.result)
    }
596
597    async fn complete_cursor_cli(
598        &self,
599        prompt: &str,
600        model_override: Option<&str>,
601    ) -> Result<String> {
602        use std::process::Stdio;
603        use tokio::io::AsyncWriteExt;
604        use tokio::process::Command;
605
606        let model = model_override.unwrap_or(&self.config.llm.model);
607
608        // Build the cursor agent command
609        let mut cmd = Command::new("agent");
610        cmd.arg("-p") // Print mode (headless/non-interactive)
611            .arg("--model")
612            .arg(model)
613            .arg("--output-format")
614            .arg("json")
615            .stdin(Stdio::piped())
616            .stdout(Stdio::piped())
617            .stderr(Stdio::piped());
618
619        // Spawn the process
620        let mut child = cmd.spawn().context("Failed to spawn 'agent' command. Make sure Cursor Agent CLI is installed (curl https://cursor.com/install -fsSL | bash)")?;
621
622        // Write prompt to stdin
623        if let Some(mut stdin) = child.stdin.take() {
624            stdin
625                .write_all(prompt.as_bytes())
626                .await
627                .context("Failed to write prompt to cursor agent stdin")?;
628            drop(stdin); // Close stdin
629        }
630
631        // Wait for completion
632        let output = child
633            .wait_with_output()
634            .await
635            .context("Failed to wait for cursor agent command")?;
636
637        if !output.status.success() {
638            let stderr = String::from_utf8_lossy(&output.stderr);
639            anyhow::bail!("Cursor Agent CLI error: {}", stderr);
640        }
641
642        // Parse output - try JSON first, fall back to plain text
643        let stdout = String::from_utf8(output.stdout)
644            .context("Cursor Agent CLI output is not valid UTF-8")?;
645
646        #[derive(Deserialize)]
647        struct CursorCliResponse {
648            result: String,
649        }
650
651        // Try JSON parse first
652        if let Ok(response) = serde_json::from_str::<CursorCliResponse>(&stdout) {
653            return Ok(response.result);
654        }
655
656        // Fall back to raw text output
657        Ok(stdout.trim().to_string())
658    }
659}