
scud/llm/client.rs

use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::env;
use std::path::PathBuf;

use crate::config::Config;
use crate::storage::Storage;

// Anthropic API structures
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<AnthropicMessage>,
}

#[derive(Debug, Serialize)]
struct AnthropicMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContent>,
}

#[derive(Debug, Deserialize)]
struct AnthropicContent {
    text: String,
}

// OpenAI-compatible API structures (used by xAI, OpenAI, OpenRouter)
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<OpenAIMessage>,
}

#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessageResponse,
}

#[derive(Debug, Deserialize)]
struct OpenAIMessageResponse {
    content: String,
}

pub struct LLMClient {
    config: Config,
    api_key: String,
    client: reqwest::Client,
}

/// Info about which model is being used for display purposes
#[derive(Debug, Clone)]
pub struct ModelInfo {
    pub tier: &'static str, // "fast" or "smart"
    pub provider: String,
    pub model: String,
}

impl std::fmt::Display for ModelInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Avoid double-prefixing if model already contains provider prefix
        let prefix = format!("{}/", self.provider);
        if self.model.starts_with(&prefix) {
            write!(f, "{} model: {}", self.tier, self.model)
        } else {
            write!(f, "{} model: {}/{}", self.tier, self.provider, self.model)
        }
    }
}
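
// Illustrative Display behavior (values hypothetical): tier "smart",
// provider "anthropic", model "claude-sonnet-4" renders as
// "smart model: anthropic/claude-sonnet-4", while a model that already
// carries the prefix, e.g. "anthropic/claude-sonnet-4", renders without
// doubling it.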

impl LLMClient {
    pub fn new() -> Result<Self> {
        Self::from_storage(Storage::new(None))
    }

    pub fn new_with_project_root(project_root: PathBuf) -> Result<Self> {
        Self::from_storage(Storage::new(Some(project_root)))
    }

    /// Shared constructor logic: load the config and resolve the API key
    fn from_storage(storage: Storage) -> Result<Self> {
        let config = storage.load_config()?;

        let api_key = if config.requires_api_key() {
            env::var(config.api_key_env_var()).with_context(|| {
                format!("{} environment variable not set", config.api_key_env_var())
            })?
        } else {
            String::new() // CLI providers don't need an API key
        };

        Ok(LLMClient {
            config,
            api_key,
            client: reqwest::Client::new(),
        })
    }

    /// Get info about the smart model that will be used
    pub fn smart_model_info(&self, model_override: Option<&str>) -> ModelInfo {
        ModelInfo {
            tier: "smart",
            provider: self.config.smart_provider().to_string(),
            model: model_override
                .unwrap_or(self.config.smart_model())
                .to_string(),
        }
    }

    /// Get info about the fast model that will be used
    pub fn fast_model_info(&self, model_override: Option<&str>) -> ModelInfo {
        ModelInfo {
            tier: "fast",
            provider: self.config.fast_provider().to_string(),
            model: model_override
                .unwrap_or(self.config.fast_model())
                .to_string(),
        }
    }

    pub async fn complete(&self, prompt: &str) -> Result<String> {
        self.complete_with_model(prompt, None, None).await
    }

    /// Complete using the smart model (for validation/analysis tasks with large context).
    /// Uses the user override if provided, otherwise falls back to the configured smart_model.
    pub async fn complete_smart(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(self.config.smart_model());
        let provider = self.config.smart_provider();
        self.complete_with_model(prompt, Some(model), Some(provider))
            .await
    }

    /// Complete using the fast model (for generation tasks).
    /// Uses the user override if provided, otherwise falls back to the configured fast_model.
    pub async fn complete_fast(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(self.config.fast_model());
        let provider = self.config.fast_provider();
        self.complete_with_model(prompt, Some(model), Some(provider))
            .await
    }
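
    // Illustrative usage of the tiered helpers (prompts are hypothetical):
    //
    //     let client = LLMClient::new()?;
    //     let draft = client.complete_fast("Draft subtasks for ...", None).await?;
    //     let audit = client.complete_smart("Review this plan ...", None).await?;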

    pub async fn complete_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
        provider_override: Option<&str>,
    ) -> Result<String> {
        let provider = provider_override.unwrap_or(&self.config.llm.provider);
        match provider {
            "claude-cli" => self.complete_claude_cli(prompt, model_override).await,
            "codex" => self.complete_codex_cli(prompt, model_override).await,
            "cursor" => self.complete_cursor_cli(prompt, model_override).await,
            "anthropic" => {
                self.complete_anthropic_with_model(prompt, model_override)
                    .await
            }
            "xai" | "openai" | "openrouter" => {
                self.complete_openai_compatible_with_model(prompt, model_override, provider)
                    .await
            }
            // Report the provider actually selected (which may be an override),
            // not just the configured default
            _ => anyhow::bail!("Unsupported provider: {}", provider),
        }
    }

    async fn complete_anthropic_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        let request = AnthropicRequest {
            model: model.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![AnthropicMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let response = self
            .client
            .post(self.config.api_endpoint())
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&request)
            .send()
            .await
            .context("Failed to send request to Anthropic API")?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Anthropic API error ({}): {}", status, error_text);
        }

        let api_response: AnthropicResponse = response
            .json()
            .await
            .context("Failed to parse Anthropic API response")?;

        Ok(api_response
            .content
            .first()
            .map(|c| c.text.clone())
            .unwrap_or_default())
    }
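
    // For reference, the request above serializes to the Messages API shape
    // (model name and max_tokens illustrative):
    //
    //     {"model":"claude-sonnet-4-20250514","max_tokens":4096,
    //      "messages":[{"role":"user","content":"..."}]}
    //
    // and the reply's first content block carries the generated text.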

    async fn complete_openai_compatible_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
        provider: &str,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        // Strip provider prefix for native APIs (xai/, openai/)
        // OpenRouter needs the prefix, native APIs don't
        let model_for_api = if provider != "openrouter" {
            let prefix = format!("{}/", provider);
            model.strip_prefix(&prefix).unwrap_or(model)
        } else {
            model
        };
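        // e.g. a configured model "xai/grok-3" goes to the native xAI API as
        // "grok-3" but to OpenRouter as "xai/grok-3" (model name illustrative)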

        // Get the correct endpoint for this provider; refuse anything else
        // rather than silently defaulting to one vendor's API
        let endpoint = match provider {
            "xai" => "https://api.x.ai/v1/chat/completions",
            "openai" => "https://api.openai.com/v1/chat/completions",
            "openrouter" => "https://openrouter.ai/api/v1/chat/completions",
            _ => anyhow::bail!("Unsupported OpenAI-compatible provider: {}", provider),
        };

        let request = OpenAIRequest {
            model: model_for_api.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![OpenAIMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let mut request_builder = self
            .client
            .post(endpoint)
            .header("authorization", format!("Bearer {}", self.api_key))
            .header("content-type", "application/json");

        // OpenRouter accepts optional attribution headers identifying the app
        if provider == "openrouter" {
            request_builder = request_builder
                .header("HTTP-Referer", "https://github.com/scud-cli")
                .header("X-Title", "SCUD Task Master");
        }

        let response = request_builder
            .json(&request)
            .send()
            .await
            .with_context(|| format!("Failed to send request to {} API", provider))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("{} API error ({}): {}", provider, status, error_text);
        }

        let api_response: OpenAIResponse = response
            .json()
            .await
            .with_context(|| format!("Failed to parse {} API response", provider))?;

        Ok(api_response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .unwrap_or_default())
    }

    pub async fn complete_json<T>(&self, prompt: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        self.complete_json_with_model(prompt, None).await
    }

    /// Complete JSON using the smart model (for validation/analysis tasks)
    pub async fn complete_json_smart<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_smart(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

    /// Complete JSON using the fast model (for generation tasks)
    pub async fn complete_json_fast<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_fast(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

    pub async fn complete_json_with_model<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self
            .complete_with_model(prompt, model_override, None)
            .await?;
        Self::parse_json_response(&response_text)
    }

    fn parse_json_response<T>(response_text: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        // Try to find JSON in the response (LLMs often wrap it in markdown or explanations)
        let json_str = Self::extract_json(response_text);

        serde_json::from_str(json_str).with_context(|| {
            // Provide helpful error context; truncate on a char count rather
            // than a byte index so multibyte UTF-8 can't cause a slice panic
            let preview = if json_str.chars().count() > 500 {
                format!("{}...", json_str.chars().take(500).collect::<String>())
            } else {
                json_str.to_string()
            };
            format!(
                "Failed to parse JSON from LLM response. Response preview:\n{}",
                preview
            )
        })
    }
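
    // Illustrative call site (the Subtask type is hypothetical):
    //
    //     #[derive(serde::Deserialize)]
    //     struct Subtask { title: String, estimate_hours: f32 }
    //
    //     let subtasks: Vec<Subtask> = client
    //         .complete_json_fast("Return a JSON array of subtasks ...", None)
    //         .await?;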

    /// Extract JSON from LLM response, handling markdown code blocks and extra text
    fn extract_json(response: &str) -> &str {
        // First, try to extract from a ```json markdown code block
        if let Some(start) = response.find("```json") {
            let content_start = start + 7; // Skip "```json"
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

        // Try plain code blocks
        if let Some(start) = response.find("```") {
            let content_start = start + 3;
            // Skip the rest of the opening fence line (e.g. a language identifier)
            let content_start = response[content_start..]
                .find('\n')
                .map(|i| content_start + i + 1)
                .unwrap_or(content_start);
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

        // Try to find array JSON
        if let Some(start) = response.find('[') {
            if let Some(end) = response.rfind(']') {
                if end > start {
                    return &response[start..=end];
                }
            }
        }

        // Try to find object JSON
        if let Some(start) = response.find('{') {
            if let Some(end) = response.rfind('}') {
                if end > start {
                    return &response[start..=end];
                }
            }
        }

        response.trim()
    }

    async fn complete_claude_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        // Build the claude command
        let mut cmd = Command::new("claude");
        cmd.arg("-p") // Print mode (headless)
            .arg("--output-format")
            .arg("json")
            .arg("--model")
            .arg(model)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        // Spawn the process
        let mut child = cmd.spawn().context(
            "Failed to spawn 'claude' command. Make sure Claude Code is installed and 'claude' is in your PATH",
        )?;

        // Write prompt to stdin
        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to claude stdin")?;
            drop(stdin); // Close stdin
        }

        // Wait for completion
        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for claude command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Claude CLI error: {}", stderr);
        }

        // Parse JSON output
        let stdout =
            String::from_utf8(output.stdout).context("Claude CLI output is not valid UTF-8")?;

        #[derive(Deserialize)]
        struct ClaudeCliResponse {
            result: String,
        }

        let response: ClaudeCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Claude CLI JSON response")?;

        Ok(response.result)
    }

    async fn complete_codex_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        // Build the codex command
        // Codex CLI uses a similar headless mode to Claude Code
        let mut cmd = Command::new("codex");
        cmd.arg("-p") // Prompt mode (headless/non-interactive)
            .arg("--model")
            .arg(model)
            .arg("--output-format")
            .arg("json")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        // Spawn the process
        let mut child = cmd.spawn().context(
            "Failed to spawn 'codex' command. Make sure OpenAI Codex CLI is installed and 'codex' is in your PATH",
        )?;

        // Write prompt to stdin
        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to codex stdin")?;
            drop(stdin); // Close stdin
        }

        // Wait for completion
        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for codex command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Codex CLI error: {}", stderr);
        }

        // Parse JSON output
        let stdout =
            String::from_utf8(output.stdout).context("Codex CLI output is not valid UTF-8")?;

        // Codex outputs JSON with a result field, similar to Claude CLI
        #[derive(Deserialize)]
        struct CodexCliResponse {
            result: String,
        }

        let response: CodexCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Codex CLI JSON response")?;

        Ok(response.result)
    }

    async fn complete_cursor_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        // Build the cursor agent command
        let mut cmd = Command::new("agent");
        cmd.arg("-p") // Print mode (headless/non-interactive)
            .arg("--model")
            .arg(model)
            .arg("--output-format")
            .arg("json")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        // Spawn the process
        let mut child = cmd.spawn().context(
            "Failed to spawn 'agent' command. Make sure Cursor Agent CLI is installed (curl https://cursor.com/install -fsSL | bash)",
        )?;

        // Write prompt to stdin
        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to cursor agent stdin")?;
            drop(stdin); // Close stdin
        }

        // Wait for completion
        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for cursor agent command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Cursor Agent CLI error: {}", stderr);
        }

        // Parse output - try JSON first, fall back to plain text
        let stdout = String::from_utf8(output.stdout)
            .context("Cursor Agent CLI output is not valid UTF-8")?;

        #[derive(Deserialize)]
        struct CursorCliResponse {
            result: String,
        }

        // Try JSON parse first
        if let Ok(response) = serde_json::from_str::<CursorCliResponse>(&stdout) {
            return Ok(response.result);
        }

        // Fall back to raw text output
        Ok(stdout.trim().to_string())
    }
}
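
// A minimal test sketch for extract_json's fallback order (```json fence,
// then any fence, then bracket/brace scan, then raw text). Cases follow the
// logic above; they are illustrative rather than exhaustive.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extracts_from_json_code_block() {
        let resp = "Here you go:\n```json\n{\"ok\": true}\n```";
        assert_eq!(LLMClient::extract_json(resp), "{\"ok\": true}");
    }

    #[test]
    fn extracts_bare_array() {
        let resp = "Result: [1, 2, 3].";
        assert_eq!(LLMClient::extract_json(resp), "[1, 2, 3]");
    }

    #[test]
    fn extracts_bare_object() {
        let resp = "Sure! {\"a\": 1} is the answer.";
        assert_eq!(LLMClient::extract_json(resp), "{\"a\": 1}");
    }

    #[test]
    fn falls_back_to_trimmed_text() {
        assert_eq!(LLMClient::extract_json("  plain text  "), "plain text");
    }
}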