// scud/llm/client.rs — LLM client routing prompts to CLI tools and HTTP APIs.
use std::env;
use std::path::PathBuf;

use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};

use crate::config::Config;
use crate::llm::oauth;
use crate::storage::Storage;
9
// Anthropic API structures

/// Request body for the Anthropic Messages API (`POST /v1/messages`).
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    model: String,
    // Hard cap on tokens the model may generate, forwarded from config.
    max_tokens: u32,
    messages: Vec<AnthropicMessage>,
}

/// One chat turn; this client only ever sends a single "user" message.
#[derive(Debug, Serialize)]
struct AnthropicMessage {
    role: String,
    content: String,
}

/// Subset of the Messages API response that this client consumes.
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContent>,
}

/// A single response content block; only the `text` field is deserialized.
#[derive(Debug, Deserialize)]
struct AnthropicContent {
    text: String,
}
33
// OpenAI-compatible API structures (used by xAI, OpenAI, OpenRouter)

/// Request body for an OpenAI-style `/chat/completions` endpoint.
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    // Generation cap forwarded from config.
    max_tokens: u32,
    messages: Vec<OpenAIMessage>,
}

/// One chat turn; this client only ever sends a single "user" message.
#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

/// Subset of the chat-completions response that this client consumes.
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
}

/// One completion choice; only its message payload is read.
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessageResponse,
}

/// The assistant message inside a choice.
#[derive(Debug, Deserialize)]
struct OpenAIMessageResponse {
    content: String,
}
62
/// LLM client that routes prompts to local CLI tools or HTTP APIs based on
/// the loaded configuration.
pub struct LLMClient {
    // Project/user configuration: providers, models, max_tokens.
    config: Config,
    // Shared reqwest client, reused across all HTTP requests.
    client: reqwest::Client,
}
67
/// Info about which model is being used for display purposes
#[derive(Debug, Clone)]
pub struct ModelInfo {
    pub tier: &'static str, // "fast" or "smart"
    pub provider: String,
    pub model: String,
}

impl std::fmt::Display for ModelInfo {
    /// Render as `<tier> model: <provider>/<model>`, collapsing the provider
    /// segment when `model` already starts with `"<provider>/"` so the
    /// provider never appears twice.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let prefix = format!("{}/", self.provider);
        match self.model.starts_with(&prefix) {
            true => write!(f, "{} model: {}", self.tier, self.model),
            false => write!(f, "{} model: {}/{}", self.tier, self.provider, self.model),
        }
    }
}
87
88impl LLMClient {
89    pub fn new() -> Result<Self> {
90        let storage = Storage::new(None);
91        let config = storage.load_config()?;
92        Ok(LLMClient {
93            config,
94            client: reqwest::Client::new(),
95        })
96    }
97
98    pub fn new_with_project_root(project_root: PathBuf) -> Result<Self> {
99        let storage = Storage::new(Some(project_root));
100        let config = storage.load_config()?;
101        Ok(LLMClient {
102            config,
103            client: reqwest::Client::new(),
104        })
105    }
106
107    /// Get info about the smart model that will be used
108    pub fn smart_model_info(&self, model_override: Option<&str>) -> ModelInfo {
109        ModelInfo {
110            tier: "smart",
111            provider: self.config.smart_provider().to_string(),
112            model: model_override
113                .unwrap_or(self.config.smart_model())
114                .to_string(),
115        }
116    }
117
118    /// Get info about the fast model that will be used
119    pub fn fast_model_info(&self, model_override: Option<&str>) -> ModelInfo {
120        ModelInfo {
121            tier: "fast",
122            provider: self.config.fast_provider().to_string(),
123            model: model_override
124                .unwrap_or(self.config.fast_model())
125                .to_string(),
126        }
127    }
128
    /// Complete `prompt` using the configured default provider and model.
    pub async fn complete(&self, prompt: &str) -> Result<String> {
        self.complete_with_model(prompt, None, None).await
    }
132
133    /// Complete using the smart model (for validation/analysis tasks with large context)
134    /// Use user override if provided, otherwise fall back to configured smart_model
135    pub async fn complete_smart(
136        &self,
137        prompt: &str,
138        model_override: Option<&str>,
139    ) -> Result<String> {
140        let model = model_override.unwrap_or(self.config.smart_model());
141        let provider = self.config.smart_provider();
142        self.complete_with_model(prompt, Some(model), Some(provider))
143            .await
144    }
145
146    /// Complete using the fast model (for generation tasks)
147    /// Use user override if provided, otherwise fall back to configured fast_model
148    pub async fn complete_fast(
149        &self,
150        prompt: &str,
151        model_override: Option<&str>,
152    ) -> Result<String> {
153        let model = model_override.unwrap_or(self.config.fast_model());
154        let provider = self.config.fast_provider();
155        self.complete_with_model(prompt, Some(model), Some(provider))
156            .await
157    }
158
159    pub async fn complete_with_model(
160        &self,
161        prompt: &str,
162        model_override: Option<&str>,
163        provider_override: Option<&str>,
164    ) -> Result<String> {
165        let provider = provider_override.unwrap_or(&self.config.llm.provider);
166        match provider {
167            "claude-cli" => self.complete_claude_cli(prompt, model_override).await,
168            "codex" => self.complete_codex_cli(prompt, model_override).await,
169            "cursor" => self.complete_cursor_cli(prompt, model_override).await,
170            "anthropic" => {
171                self.complete_anthropic_api_key(prompt, model_override)
172                    .await
173            }
174            "anthropic-oauth" => self.complete_anthropic_oauth(prompt, model_override).await,
175            "xai" | "openai" | "openrouter" => {
176                self.complete_openai_compatible_with_model(prompt, model_override, provider)
177                    .await
178            }
179            _ => anyhow::bail!("Unsupported provider: {}", self.config.llm.provider),
180        }
181    }
182
183    /// Anthropic API with standard API key (ANTHROPIC_API_KEY)
184    async fn complete_anthropic_api_key(
185        &self,
186        prompt: &str,
187        model_override: Option<&str>,
188    ) -> Result<String> {
189        let model = model_override.unwrap_or(&self.config.llm.model);
190        let api_key = env::var("ANTHROPIC_API_KEY")
191            .context("ANTHROPIC_API_KEY environment variable not set")?;
192
193        let request = AnthropicRequest {
194            model: model.to_string(),
195            max_tokens: self.config.llm.max_tokens,
196            messages: vec![AnthropicMessage {
197                role: "user".to_string(),
198                content: prompt.to_string(),
199            }],
200        };
201
202        let response = self
203            .client
204            .post("https://api.anthropic.com/v1/messages")
205            .header("x-api-key", &api_key)
206            .header("anthropic-version", "2023-06-01")
207            .header("content-type", "application/json")
208            .json(&request)
209            .send()
210            .await
211            .context("Failed to send request to Anthropic API")?;
212
213        if !response.status().is_success() {
214            let status = response.status();
215            let error_text = response.text().await.unwrap_or_default();
216            anyhow::bail!("Anthropic API error ({}): {}", status, error_text);
217        }
218
219        let api_response: AnthropicResponse = response
220            .json()
221            .await
222            .context("Failed to parse Anthropic API response")?;
223
224        Ok(api_response
225            .content
226            .first()
227            .map(|c| c.text.clone())
228            .unwrap_or_default())
229    }
230
231    /// Anthropic API with Claude Code OAuth token from macOS Keychain
232    async fn complete_anthropic_oauth(
233        &self,
234        prompt: &str,
235        model_override: Option<&str>,
236    ) -> Result<String> {
237        let model = model_override.unwrap_or(&self.config.llm.model);
238        let creds = oauth::read_claude_oauth()?.context(
239            "No Claude Code OAuth credentials found in Keychain. Log in with `claude` CLI first.",
240        )?;
241
242        if !oauth::is_token_valid(&creds) {
243            anyhow::bail!("Claude Code OAuth token expired. Re-login with `claude` CLI.");
244        }
245
246        let request = AnthropicRequest {
247            model: model.to_string(),
248            max_tokens: self.config.llm.max_tokens,
249            messages: vec![AnthropicMessage {
250                role: "user".to_string(),
251                content: prompt.to_string(),
252            }],
253        };
254
255        let response = self
256            .client
257            .post("https://api.anthropic.com/v1/messages")
258            .header("authorization", format!("Bearer {}", creds.access_token))
259            .header("anthropic-version", "2023-06-01")
260            .header("anthropic-beta", "oauth-2025-04-20")
261            .header("content-type", "application/json")
262            .header("user-agent", "SCUD-CLI/1.0")
263            .json(&request)
264            .send()
265            .await
266            .context("Failed to send request to Anthropic API")?;
267
268        if !response.status().is_success() {
269            let status = response.status();
270            let error_text = response.text().await.unwrap_or_default();
271            anyhow::bail!("Anthropic API error ({}): {}", status, error_text);
272        }
273
274        let api_response: AnthropicResponse = response
275            .json()
276            .await
277            .context("Failed to parse Anthropic API response")?;
278
279        Ok(api_response
280            .content
281            .first()
282            .map(|c| c.text.clone())
283            .unwrap_or_default())
284    }
285
    /// Chat-completions call against an OpenAI-compatible endpoint
    /// (xAI, OpenAI, or OpenRouter, selected by `provider`).
    ///
    /// The API key is read from the provider-specific environment variable
    /// resolved by `Config::api_key_env_var_for_provider`.
    async fn complete_openai_compatible_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
        provider: &str,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        // Strip provider prefix for native APIs (xai/, openai/)
        // OpenRouter needs the prefix, native APIs don't
        let model_for_api = if provider != "openrouter" {
            let prefix = format!("{}/", provider);
            model.strip_prefix(&prefix).unwrap_or(model)
        } else {
            model
        };

        // Get the correct endpoint for this provider
        // NOTE(review): the `_` arm silently falls back to the xAI endpoint;
        // callers currently only pass xai/openai/openrouter — confirm before
        // routing any new provider through here.
        let endpoint = match provider {
            "xai" => "https://api.x.ai/v1/chat/completions",
            "openai" => "https://api.openai.com/v1/chat/completions",
            "openrouter" => "https://openrouter.ai/api/v1/chat/completions",
            _ => "https://api.x.ai/v1/chat/completions",
        };

        // Resolve API key for this specific provider
        let env_var = Config::api_key_env_var_for_provider(provider);
        let api_key = env::var(env_var)
            .with_context(|| format!("{} environment variable not set", env_var))?;

        let request = OpenAIRequest {
            model: model_for_api.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![OpenAIMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let mut request_builder = self
            .client
            .post(endpoint)
            .header("authorization", format!("Bearer {}", api_key))
            .header("content-type", "application/json");

        // OpenRouter requires additional headers
        if provider == "openrouter" {
            request_builder = request_builder
                .header("HTTP-Referer", "https://github.com/scud-cli")
                .header("X-Title", "SCUD Task Master");
        }

        let response = request_builder
            .json(&request)
            .send()
            .await
            .with_context(|| format!("Failed to send request to {} API", provider))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("{} API error ({}): {}", provider, status, error_text);
        }

        let api_response: OpenAIResponse = response
            .json()
            .await
            .with_context(|| format!("Failed to parse {} API response", provider))?;

        // Only the first choice is used; requests do not set `n`, so the API's
        // default of one choice applies.
        Ok(api_response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .unwrap_or_default())
    }
360
    /// Complete `prompt` with the default model and parse the reply as JSON
    /// into `T`.
    pub async fn complete_json<T>(&self, prompt: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        self.complete_json_with_model(prompt, None).await
    }
367
368    /// Complete JSON using the smart model (for validation/analysis tasks)
369    pub async fn complete_json_smart<T>(
370        &self,
371        prompt: &str,
372        model_override: Option<&str>,
373    ) -> Result<T>
374    where
375        T: serde::de::DeserializeOwned,
376    {
377        let response_text = self.complete_smart(prompt, model_override).await?;
378        Self::parse_json_response(&response_text)
379    }
380
381    /// Complete JSON using the fast model (for generation tasks)
382    pub async fn complete_json_fast<T>(
383        &self,
384        prompt: &str,
385        model_override: Option<&str>,
386    ) -> Result<T>
387    where
388        T: serde::de::DeserializeOwned,
389    {
390        let response_text = self.complete_fast(prompt, model_override).await?;
391        Self::parse_json_response(&response_text)
392    }
393
394    pub async fn complete_json_with_model<T>(
395        &self,
396        prompt: &str,
397        model_override: Option<&str>,
398    ) -> Result<T>
399    where
400        T: serde::de::DeserializeOwned,
401    {
402        let response_text = self
403            .complete_with_model(prompt, model_override, None)
404            .await?;
405        Self::parse_json_response(&response_text)
406    }
407
408    fn parse_json_response<T>(response_text: &str) -> Result<T>
409    where
410        T: serde::de::DeserializeOwned,
411    {
412        // Try to find JSON in the response (LLM might include markdown or explanations)
413        let json_str = Self::extract_json(response_text);
414
415        serde_json::from_str(json_str).with_context(|| {
416            // Provide helpful error context
417            let preview = if json_str.len() > 500 {
418                format!("{}...", &json_str[..500])
419            } else {
420                json_str.to_string()
421            };
422            format!(
423                "Failed to parse JSON from LLM response. Response preview:\n{}",
424                preview
425            )
426        })
427    }
428
429    /// Extract JSON from LLM response, handling markdown code blocks and extra text
430    fn extract_json(response: &str) -> &str {
431        // First, try to extract from markdown code blocks
432        if let Some(start) = response.find("```json") {
433            let content_start = start + 7; // Skip "```json"
434            if let Some(end) = response[content_start..].find("```") {
435                return response[content_start..content_start + end].trim();
436            }
437        }
438
439        // Try plain code blocks
440        if let Some(start) = response.find("```") {
441            let content_start = start + 3;
442            // Skip language identifier if present (e.g., "```\n")
443            let content_start = response[content_start..]
444                .find('\n')
445                .map(|i| content_start + i + 1)
446                .unwrap_or(content_start);
447            if let Some(end) = response[content_start..].find("```") {
448                return response[content_start..content_start + end].trim();
449            }
450        }
451
452        // Try to find array JSON
453        if let Some(start) = response.find('[') {
454            if let Some(end) = response.rfind(']') {
455                if end > start {
456                    return &response[start..=end];
457                }
458            }
459        }
460
461        // Try to find object JSON
462        if let Some(start) = response.find('{') {
463            if let Some(end) = response.rfind('}') {
464                if end > start {
465                    return &response[start..=end];
466                }
467            }
468        }
469
470        response.trim()
471    }
472
473    async fn complete_claude_cli(
474        &self,
475        prompt: &str,
476        model_override: Option<&str>,
477    ) -> Result<String> {
478        use std::process::Stdio;
479        use tokio::io::AsyncWriteExt;
480        use tokio::process::Command;
481
482        let model = model_override.unwrap_or(&self.config.llm.model);
483
484        // Build the claude command
485        let mut cmd = Command::new("claude");
486        cmd.arg("-p") // Print mode (headless)
487            .arg("--output-format")
488            .arg("json")
489            .arg("--model")
490            .arg(model)
491            .stdin(Stdio::piped())
492            .stdout(Stdio::piped())
493            .stderr(Stdio::piped());
494
495        // Spawn the process
496        let mut child = cmd.spawn().context("Failed to spawn 'claude' command. Make sure Claude Code is installed and 'claude' is in your PATH")?;
497
498        // Write prompt to stdin
499        if let Some(mut stdin) = child.stdin.take() {
500            stdin
501                .write_all(prompt.as_bytes())
502                .await
503                .context("Failed to write prompt to claude stdin")?;
504            drop(stdin); // Close stdin
505        }
506
507        // Wait for completion
508        let output = child
509            .wait_with_output()
510            .await
511            .context("Failed to wait for claude command")?;
512
513        if !output.status.success() {
514            let stderr = String::from_utf8_lossy(&output.stderr);
515            anyhow::bail!("Claude CLI error: {}", stderr);
516        }
517
518        // Parse JSON output
519        let stdout =
520            String::from_utf8(output.stdout).context("Claude CLI output is not valid UTF-8")?;
521
522        #[derive(Deserialize)]
523        struct ClaudeCliResponse {
524            result: String,
525        }
526
527        let response: ClaudeCliResponse =
528            serde_json::from_str(&stdout).context("Failed to parse Claude CLI JSON response")?;
529
530        Ok(response.result)
531    }
532
533    async fn complete_codex_cli(
534        &self,
535        prompt: &str,
536        model_override: Option<&str>,
537    ) -> Result<String> {
538        use std::process::Stdio;
539        use tokio::io::AsyncWriteExt;
540        use tokio::process::Command;
541
542        let model = model_override.unwrap_or(&self.config.llm.model);
543
544        // Build the codex command
545        // Codex CLI uses similar headless mode to Claude Code
546        let mut cmd = Command::new("codex");
547        cmd.arg("-p") // Prompt mode (headless/non-interactive)
548            .arg("--model")
549            .arg(model)
550            .arg("--output-format")
551            .arg("json")
552            .stdin(Stdio::piped())
553            .stdout(Stdio::piped())
554            .stderr(Stdio::piped());
555
556        // Spawn the process
557        let mut child = cmd.spawn().context("Failed to spawn 'codex' command. Make sure OpenAI Codex CLI is installed and 'codex' is in your PATH")?;
558
559        // Write prompt to stdin
560        if let Some(mut stdin) = child.stdin.take() {
561            stdin
562                .write_all(prompt.as_bytes())
563                .await
564                .context("Failed to write prompt to codex stdin")?;
565            drop(stdin); // Close stdin
566        }
567
568        // Wait for completion
569        let output = child
570            .wait_with_output()
571            .await
572            .context("Failed to wait for codex command")?;
573
574        if !output.status.success() {
575            let stderr = String::from_utf8_lossy(&output.stderr);
576            anyhow::bail!("Codex CLI error: {}", stderr);
577        }
578
579        // Parse JSON output
580        let stdout =
581            String::from_utf8(output.stdout).context("Codex CLI output is not valid UTF-8")?;
582
583        // Codex outputs JSON with a result field similar to Claude CLI
584        #[derive(Deserialize)]
585        struct CodexCliResponse {
586            result: String,
587        }
588
589        let response: CodexCliResponse =
590            serde_json::from_str(&stdout).context("Failed to parse Codex CLI JSON response")?;
591
592        Ok(response.result)
593    }
594
595    async fn complete_cursor_cli(
596        &self,
597        prompt: &str,
598        model_override: Option<&str>,
599    ) -> Result<String> {
600        use std::process::Stdio;
601        use tokio::io::AsyncWriteExt;
602        use tokio::process::Command;
603
604        let model = model_override.unwrap_or(&self.config.llm.model);
605
606        // Build the cursor agent command
607        let mut cmd = Command::new("agent");
608        cmd.arg("-p") // Print mode (headless/non-interactive)
609            .arg("--model")
610            .arg(model)
611            .arg("--output-format")
612            .arg("json")
613            .stdin(Stdio::piped())
614            .stdout(Stdio::piped())
615            .stderr(Stdio::piped());
616
617        // Spawn the process
618        let mut child = cmd.spawn().context("Failed to spawn 'agent' command. Make sure Cursor Agent CLI is installed (curl https://cursor.com/install -fsSL | bash)")?;
619
620        // Write prompt to stdin
621        if let Some(mut stdin) = child.stdin.take() {
622            stdin
623                .write_all(prompt.as_bytes())
624                .await
625                .context("Failed to write prompt to cursor agent stdin")?;
626            drop(stdin); // Close stdin
627        }
628
629        // Wait for completion
630        let output = child
631            .wait_with_output()
632            .await
633            .context("Failed to wait for cursor agent command")?;
634
635        if !output.status.success() {
636            let stderr = String::from_utf8_lossy(&output.stderr);
637            anyhow::bail!("Cursor Agent CLI error: {}", stderr);
638        }
639
640        // Parse output - try JSON first, fall back to plain text
641        let stdout = String::from_utf8(output.stdout)
642            .context("Cursor Agent CLI output is not valid UTF-8")?;
643
644        #[derive(Deserialize)]
645        struct CursorCliResponse {
646            result: String,
647        }
648
649        // Try JSON parse first
650        if let Ok(response) = serde_json::from_str::<CursorCliResponse>(&stdout) {
651            return Ok(response.result);
652        }
653
654        // Fall back to raw text output
655        Ok(stdout.trim().to_string())
656    }
657}