use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::env;
use std::path::PathBuf;

use crate::config::Config;
use crate::storage::Storage;

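/// Request body for the Anthropic Messages API.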
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<AnthropicMessage>,
}

#[derive(Debug, Serialize)]
struct AnthropicMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContent>,
}

#[derive(Debug, Deserialize)]
struct AnthropicContent {
    text: String,
}

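/// Request body for OpenAI-compatible chat completion endpoints
/// (the openai, xai, and openrouter providers).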
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<OpenAIMessage>,
}

#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
}

#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessageResponse,
}

#[derive(Debug, Deserialize)]
struct OpenAIMessageResponse {
    content: String,
}

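/// Client for sending prompts to the configured LLM provider, either over
/// HTTP (the Anthropic and OpenAI-compatible APIs) or by shelling out to a
/// local CLI (`claude`, `codex`, or Cursor's `agent`).
///
/// A minimal usage sketch (assumes a valid config on disk and, for HTTP
/// providers, the matching API key in the environment):
///
/// ```ignore
/// let client = LLMClient::new()?;
/// let answer = client.complete("Summarize this repository in one line.").await?;
/// println!("{answer}");
/// ```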
pub struct LLMClient {
    config: Config,
    api_key: String,
    client: reqwest::Client,
}

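/// Identifies which model a request will use: the tier ("smart" or "fast"),
/// the provider, and the model name.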
#[derive(Debug, Clone)]
pub struct ModelInfo {
    pub tier: &'static str,
    pub provider: String,
    pub model: String,
}

impl std::fmt::Display for ModelInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Avoid printing the provider twice when the model name already
        // carries a "provider/" prefix.
        let prefix = format!("{}/", self.provider);
        if self.model.starts_with(&prefix) {
            write!(f, "{} model: {}", self.tier, self.model)
        } else {
            write!(f, "{} model: {}/{}", self.tier, self.provider, self.model)
        }
    }
}

impl LLMClient {
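    /// Creates a client from the configuration found relative to the current
    /// directory. Fails if the provider requires an API key and the matching
    /// environment variable is unset.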
    pub fn new() -> Result<Self> {
        let storage = Storage::new(None);
        let config = storage.load_config()?;

        let api_key = if config.requires_api_key() {
            env::var(config.api_key_env_var()).with_context(|| {
                format!("{} environment variable not set", config.api_key_env_var())
            })?
        } else {
            // CLI-backed providers authenticate on their own, so no key is needed.
            String::new()
        };

        Ok(LLMClient {
            config,
            api_key,
            client: reqwest::Client::new(),
        })
    }

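    /// Like [`Self::new`], but loads configuration relative to an explicit
    /// project root.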
    pub fn new_with_project_root(project_root: PathBuf) -> Result<Self> {
        let storage = Storage::new(Some(project_root));
        let config = storage.load_config()?;

        let api_key = if config.requires_api_key() {
            env::var(config.api_key_env_var()).with_context(|| {
                format!("{} environment variable not set", config.api_key_env_var())
            })?
        } else {
            String::new()
        };

        Ok(LLMClient {
            config,
            api_key,
            client: reqwest::Client::new(),
        })
    }

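    /// Describes the "smart" tier model that would be used, honoring an
    /// optional override.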
    pub fn smart_model_info(&self, model_override: Option<&str>) -> ModelInfo {
        ModelInfo {
            tier: "smart",
            provider: self.config.smart_provider().to_string(),
            model: model_override
                .unwrap_or(self.config.smart_model())
                .to_string(),
        }
    }

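    /// Describes the "fast" tier model that would be used, honoring an
    /// optional override.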
    pub fn fast_model_info(&self, model_override: Option<&str>) -> ModelInfo {
        ModelInfo {
            tier: "fast",
            provider: self.config.fast_provider().to_string(),
            model: model_override
                .unwrap_or(self.config.fast_model())
                .to_string(),
        }
    }

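    /// Sends a prompt using the default provider and model from the configuration.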
    pub async fn complete(&self, prompt: &str) -> Result<String> {
        self.complete_with_model(prompt, None, None).await
    }

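    /// Sends a prompt using the configured "smart" tier model and provider.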
    pub async fn complete_smart(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(self.config.smart_model());
        let provider = self.config.smart_provider();
        self.complete_with_model(prompt, Some(model), Some(provider))
            .await
    }

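    /// Sends a prompt using the configured "fast" tier model and provider.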
    pub async fn complete_fast(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(self.config.fast_model());
        let provider = self.config.fast_provider();
        self.complete_with_model(prompt, Some(model), Some(provider))
            .await
    }

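    /// Sends a prompt, dispatching on the provider: CLI-backed providers
    /// shell out to a local binary, while the rest go over HTTP.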
    pub async fn complete_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
        provider_override: Option<&str>,
    ) -> Result<String> {
        let provider = provider_override.unwrap_or(&self.config.llm.provider);
        match provider {
            "claude-cli" => self.complete_claude_cli(prompt, model_override).await,
            "codex" => self.complete_codex_cli(prompt, model_override).await,
            "cursor" => self.complete_cursor_cli(prompt, model_override).await,
            "anthropic" => {
                self.complete_anthropic_with_model(prompt, model_override)
                    .await
            }
            "xai" | "openai" | "openrouter" => {
                self.complete_openai_compatible_with_model(prompt, model_override, provider)
                    .await
            }
            // Report the provider that was actually matched on, which may
            // have come from the override rather than the config.
            _ => anyhow::bail!("Unsupported provider: {}", provider),
        }
    }

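    /// Calls the Anthropic Messages API and returns the text of the first
    /// content block (or an empty string if there is none).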
    async fn complete_anthropic_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        let request = AnthropicRequest {
            model: model.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![AnthropicMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let response = self
            .client
            .post(self.config.api_endpoint())
            .header("x-api-key", &self.api_key)
            .header("anthropic-version", "2023-06-01")
            .header("content-type", "application/json")
            .json(&request)
            .send()
            .await
            .context("Failed to send request to Anthropic API")?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Anthropic API error ({}): {}", status, error_text);
        }

        let api_response: AnthropicResponse = response
            .json()
            .await
            .context("Failed to parse Anthropic API response")?;

        Ok(api_response
            .content
            .first()
            .map(|c| c.text.clone())
            .unwrap_or_default())
    }

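    /// Calls an OpenAI-compatible chat completions endpoint (OpenAI, xAI, or
    /// OpenRouter) and returns the first choice's message content.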
    async fn complete_openai_compatible_with_model(
        &self,
        prompt: &str,
        model_override: Option<&str>,
        provider: &str,
    ) -> Result<String> {
        let model = model_override.unwrap_or(&self.config.llm.model);
        // OpenRouter expects the full "provider/model" id; the other APIs
        // take a bare model name, so strip a matching provider prefix.
        let model_for_api = if provider != "openrouter" {
            let prefix = format!("{}/", provider);
            model.strip_prefix(&prefix).unwrap_or(model)
        } else {
            model
        };

        let endpoint = match provider {
            "xai" => "https://api.x.ai/v1/chat/completions",
            "openai" => "https://api.openai.com/v1/chat/completions",
            "openrouter" => "https://openrouter.ai/api/v1/chat/completions",
            _ => "https://api.x.ai/v1/chat/completions",
        };

        let request = OpenAIRequest {
            model: model_for_api.to_string(),
            max_tokens: self.config.llm.max_tokens,
            messages: vec![OpenAIMessage {
                role: "user".to_string(),
                content: prompt.to_string(),
            }],
        };

        let mut request_builder = self
            .client
            .post(endpoint)
            .header("authorization", format!("Bearer {}", self.api_key))
            .header("content-type", "application/json");

        if provider == "openrouter" {
            // Optional OpenRouter headers used for app attribution.
            request_builder = request_builder
                .header("HTTP-Referer", "https://github.com/scud-cli")
                .header("X-Title", "SCUD Task Master");
        }

        let response = request_builder
            .json(&request)
            .send()
            .await
            .with_context(|| format!("Failed to send request to {} API", provider))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("{} API error ({}): {}", provider, status, error_text);
        }

        let api_response: OpenAIResponse = response
            .json()
            .await
            .with_context(|| format!("Failed to parse {} API response", provider))?;

        Ok(api_response
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .unwrap_or_default())
    }

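    /// Sends a prompt and deserializes the response as JSON into `T`.
    ///
    /// A usage sketch (the `Plan` type here is purely illustrative):
    ///
    /// ```ignore
    /// #[derive(serde::Deserialize)]
    /// struct Plan { steps: Vec<String> }
    ///
    /// let plan: Plan = client.complete_json("Return JSON: ...").await?;
    /// ```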
    pub async fn complete_json<T>(&self, prompt: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        self.complete_json_with_model(prompt, None).await
    }

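    /// JSON-deserializing variant of [`Self::complete_smart`].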
    pub async fn complete_json_smart<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_smart(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

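    /// JSON-deserializing variant of [`Self::complete_fast`].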
    pub async fn complete_json_fast<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self.complete_fast(prompt, model_override).await?;
        Self::parse_json_response(&response_text)
    }

    pub async fn complete_json_with_model<T>(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let response_text = self
            .complete_with_model(prompt, model_override, None)
            .await?;
        Self::parse_json_response(&response_text)
    }

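    /// Extracts and deserializes JSON from raw LLM output, attaching a short
    /// preview of the offending text to the error on failure.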
    fn parse_json_response<T>(response_text: &str) -> Result<T>
    where
        T: serde::de::DeserializeOwned,
    {
        let json_str = Self::extract_json(response_text);

        serde_json::from_str(json_str).with_context(|| {
            let preview = if json_str.len() > 500 {
                // Back the cut-off up to a char boundary so slicing cannot
                // panic on multi-byte UTF-8 characters.
                let mut end = 500;
                while !json_str.is_char_boundary(end) {
                    end -= 1;
                }
                format!("{}...", &json_str[..end])
            } else {
                json_str.to_string()
            };
            format!(
                "Failed to parse JSON from LLM response. Response preview:\n{}",
                preview
            )
        })
    }

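    /// Pulls the JSON payload out of an LLM response, trying progressively
    /// looser strategies: a fenced block tagged `json`, any fenced block,
    /// the outermost `[...]` array, the outermost `{...}` object, and
    /// finally the trimmed raw text.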
    fn extract_json(response: &str) -> &str {
        // Prefer an explicit ```json fenced block.
        if let Some(start) = response.find("```json") {
            let content_start = start + 7; // skip past "```json"
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

        // Fall back to any fenced block, skipping the rest of the opening
        // fence line (which may carry a language tag).
        if let Some(start) = response.find("```") {
            let content_start = start + 3;
            let content_start = response[content_start..]
                .find('\n')
                .map(|i| content_start + i + 1)
                .unwrap_or(content_start);
            if let Some(end) = response[content_start..].find("```") {
                return response[content_start..content_start + end].trim();
            }
        }

        // Otherwise take the outermost JSON array or object, if one exists.
        if let Some(start) = response.find('[') {
            if let Some(end) = response.rfind(']') {
                if end > start {
                    return &response[start..=end];
                }
            }
        }

        if let Some(start) = response.find('{') {
            if let Some(end) = response.rfind('}') {
                if end > start {
                    return &response[start..=end];
                }
            }
        }

        response.trim()
    }

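    /// Runs the Claude Code CLI (`claude -p --output-format json`) with the
    /// prompt on stdin and returns the `result` field of its JSON output.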
    async fn complete_claude_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        let mut cmd = Command::new("claude");
        cmd.arg("-p") // non-interactive "print" mode
            .arg("--output-format")
            .arg("json")
            .arg("--model")
            .arg(model)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        let mut child = cmd.spawn().context("Failed to spawn 'claude' command. Make sure Claude Code is installed and 'claude' is in your PATH")?;

        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to claude stdin")?;
            drop(stdin); // close stdin so the CLI sees EOF and starts processing
        }

        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for claude command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Claude CLI error: {}", stderr);
        }

        let stdout =
            String::from_utf8(output.stdout).context("Claude CLI output is not valid UTF-8")?;

        #[derive(Deserialize)]
        struct ClaudeCliResponse {
            result: String,
        }

        let response: ClaudeCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Claude CLI JSON response")?;

        Ok(response.result)
    }

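    /// Runs the OpenAI Codex CLI (`codex`) with the prompt on stdin and
    /// returns the `result` field of its JSON output.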
    async fn complete_codex_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        let mut cmd = Command::new("codex");
        cmd.arg("-p")
            .arg("--model")
            .arg(model)
            .arg("--output-format")
            .arg("json")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        let mut child = cmd.spawn().context("Failed to spawn 'codex' command. Make sure OpenAI Codex CLI is installed and 'codex' is in your PATH")?;

        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to codex stdin")?;
            drop(stdin); // close stdin so the CLI sees EOF and starts processing
        }

        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for codex command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Codex CLI error: {}", stderr);
        }

        let stdout =
            String::from_utf8(output.stdout).context("Codex CLI output is not valid UTF-8")?;

        #[derive(Deserialize)]
        struct CodexCliResponse {
            result: String,
        }

        let response: CodexCliResponse =
            serde_json::from_str(&stdout).context("Failed to parse Codex CLI JSON response")?;

        Ok(response.result)
    }

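    /// Runs the Cursor Agent CLI (`agent`) with the prompt on stdin. Returns
    /// the `result` field of its JSON output, falling back to raw stdout if
    /// the output isn't the expected JSON shape.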
    async fn complete_cursor_cli(
        &self,
        prompt: &str,
        model_override: Option<&str>,
    ) -> Result<String> {
        use std::process::Stdio;
        use tokio::io::AsyncWriteExt;
        use tokio::process::Command;

        let model = model_override.unwrap_or(&self.config.llm.model);

        let mut cmd = Command::new("agent");
        cmd.arg("-p")
            .arg("--model")
            .arg(model)
            .arg("--output-format")
            .arg("json")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());

        let mut child = cmd.spawn().context("Failed to spawn 'agent' command. Make sure Cursor Agent CLI is installed (curl https://cursor.com/install -fsSL | bash)")?;

        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(prompt.as_bytes())
                .await
                .context("Failed to write prompt to cursor agent stdin")?;
            drop(stdin); // close stdin so the CLI sees EOF and starts processing
        }

        let output = child
            .wait_with_output()
            .await
            .context("Failed to wait for cursor agent command")?;

        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            anyhow::bail!("Cursor Agent CLI error: {}", stderr);
        }

        let stdout = String::from_utf8(output.stdout)
            .context("Cursor Agent CLI output is not valid UTF-8")?;

        #[derive(Deserialize)]
        struct CursorCliResponse {
            result: String,
        }

        // Prefer structured output, but tolerate plain-text responses.
        if let Ok(response) = serde_json::from_str::<CursorCliResponse>(&stdout) {
            return Ok(response.result);
        }

        Ok(stdout.trim().to_string())
    }
}