1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3use std::env;
4use std::path::PathBuf;
5
6use crate::config::Config;
7use crate::llm::oauth;
8use crate::storage::Storage;
9
/// Request body for the Anthropic Messages API (`POST /v1/messages`).
#[derive(Debug, Serialize)]
struct AnthropicRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<AnthropicMessage>,
}
17
/// One chat turn in an Anthropic request; `role` is e.g. "user".
#[derive(Debug, Serialize)]
struct AnthropicMessage {
    role: String,
    content: String,
}
23
/// Subset of the Anthropic Messages API response we deserialize.
#[derive(Debug, Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContent>,
}
28
/// A single content part from an Anthropic response; only `text` is read.
#[derive(Debug, Deserialize)]
struct AnthropicContent {
    text: String,
}
33
/// Request body for OpenAI-compatible chat-completions endpoints
/// (used for openai, xai, and openrouter).
#[derive(Debug, Serialize)]
struct OpenAIRequest {
    model: String,
    max_tokens: u32,
    messages: Vec<OpenAIMessage>,
}
41
/// One chat turn in an OpenAI-compatible request; `role` is e.g. "user".
#[derive(Debug, Serialize)]
struct OpenAIMessage {
    role: String,
    content: String,
}
47
/// Subset of the OpenAI-compatible chat-completions response we deserialize.
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
    choices: Vec<OpenAIChoice>,
}
52
/// One completion choice; only the assistant message is read.
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    message: OpenAIMessageResponse,
}
57
/// The assistant message inside a choice; only `content` is read.
#[derive(Debug, Deserialize)]
struct OpenAIMessageResponse {
    content: String,
}
62
/// Client for LLM completions, dispatching to HTTP APIs (Anthropic,
/// OpenAI-compatible) or local CLI tools (claude, codex, cursor agent)
/// based on the configured provider.
pub struct LLMClient {
    /// Configuration loaded from storage (providers, models, max_tokens).
    config: Config,
    /// Shared HTTP client, reused across all API requests.
    client: reqwest::Client,
}
67
/// Describes which provider/model a tier ("smart" or "fast") resolves to.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Tier label, e.g. "smart" or "fast".
    pub tier: &'static str,
    /// Provider id, e.g. "anthropic" or "openrouter".
    pub provider: String,
    /// Model id; may or may not already carry a "provider/" prefix.
    pub model: String,
}
75
76impl std::fmt::Display for ModelInfo {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 let prefix = format!("{}/", self.provider);
80 if self.model.starts_with(&prefix) {
81 write!(f, "{} model: {}", self.tier, self.model)
82 } else {
83 write!(f, "{} model: {}/{}", self.tier, self.provider, self.model)
84 }
85 }
86}
87
88impl LLMClient {
89 pub fn new() -> Result<Self> {
90 let storage = Storage::new(None);
91 let config = storage.load_config()?;
92 Ok(LLMClient {
93 config,
94 client: reqwest::Client::new(),
95 })
96 }
97
98 pub fn new_with_project_root(project_root: PathBuf) -> Result<Self> {
99 let storage = Storage::new(Some(project_root));
100 let config = storage.load_config()?;
101 Ok(LLMClient {
102 config,
103 client: reqwest::Client::new(),
104 })
105 }
106
107 pub fn smart_model_info(&self, model_override: Option<&str>) -> ModelInfo {
109 ModelInfo {
110 tier: "smart",
111 provider: self.config.smart_provider().to_string(),
112 model: model_override
113 .unwrap_or(self.config.smart_model())
114 .to_string(),
115 }
116 }
117
118 pub fn fast_model_info(&self, model_override: Option<&str>) -> ModelInfo {
120 ModelInfo {
121 tier: "fast",
122 provider: self.config.fast_provider().to_string(),
123 model: model_override
124 .unwrap_or(self.config.fast_model())
125 .to_string(),
126 }
127 }
128
129 pub async fn complete(&self, prompt: &str) -> Result<String> {
130 self.complete_with_model(prompt, None, None).await
131 }
132
133 pub async fn complete_smart(
136 &self,
137 prompt: &str,
138 model_override: Option<&str>,
139 ) -> Result<String> {
140 let model = model_override.unwrap_or(self.config.smart_model());
141 let provider = self.config.smart_provider();
142 self.complete_with_model(prompt, Some(model), Some(provider))
143 .await
144 }
145
146 pub async fn complete_fast(
149 &self,
150 prompt: &str,
151 model_override: Option<&str>,
152 ) -> Result<String> {
153 let model = model_override.unwrap_or(self.config.fast_model());
154 let provider = self.config.fast_provider();
155 self.complete_with_model(prompt, Some(model), Some(provider))
156 .await
157 }
158
159 pub async fn complete_with_model(
160 &self,
161 prompt: &str,
162 model_override: Option<&str>,
163 provider_override: Option<&str>,
164 ) -> Result<String> {
165 let provider = provider_override.unwrap_or(&self.config.llm.provider);
166 match provider {
167 "claude-cli" => self.complete_claude_cli(prompt, model_override).await,
168 "codex" => self.complete_codex_cli(prompt, model_override).await,
169 "cursor" => self.complete_cursor_cli(prompt, model_override).await,
170 "anthropic" => {
171 self.complete_anthropic_with_model(prompt, model_override)
172 .await
173 }
174 "xai" | "openai" | "openrouter" => {
175 self.complete_openai_compatible_with_model(prompt, model_override, provider)
176 .await
177 }
178 _ => anyhow::bail!("Unsupported provider: {}", self.config.llm.provider),
179 }
180 }
181
182 async fn complete_anthropic_with_model(
183 &self,
184 prompt: &str,
185 model_override: Option<&str>,
186 ) -> Result<String> {
187 let model = model_override.unwrap_or(&self.config.llm.model);
188 let credential = oauth::resolve_anthropic_credential()?;
189
190 let request = AnthropicRequest {
191 model: model.to_string(),
192 max_tokens: self.config.llm.max_tokens,
193 messages: vec![AnthropicMessage {
194 role: "user".to_string(),
195 content: prompt.to_string(),
196 }],
197 };
198
199 let mut request_builder = self
200 .client
201 .post("https://api.anthropic.com/v1/messages")
202 .header("anthropic-version", "2023-06-01")
203 .header("content-type", "application/json");
204
205 match &credential {
206 oauth::ApiCredential::OAuth(token) => {
207 request_builder = request_builder
208 .header("authorization", format!("Bearer {}", token))
209 .header("anthropic-beta", "oauth-2025-04-20")
210 .header("user-agent", "SCUD-CLI/1.0");
211 }
212 oauth::ApiCredential::ApiKey(key) => {
213 request_builder = request_builder.header("x-api-key", key);
214 }
215 }
216
217 let response = request_builder
218 .json(&request)
219 .send()
220 .await
221 .context("Failed to send request to Anthropic API")?;
222
223 if !response.status().is_success() {
224 let status = response.status();
225 let error_text = response.text().await.unwrap_or_default();
226 anyhow::bail!("Anthropic API error ({}): {}", status, error_text);
227 }
228
229 let api_response: AnthropicResponse = response
230 .json()
231 .await
232 .context("Failed to parse Anthropic API response")?;
233
234 Ok(api_response
235 .content
236 .first()
237 .map(|c| c.text.clone())
238 .unwrap_or_default())
239 }
240
241 async fn complete_openai_compatible_with_model(
242 &self,
243 prompt: &str,
244 model_override: Option<&str>,
245 provider: &str,
246 ) -> Result<String> {
247 let model = model_override.unwrap_or(&self.config.llm.model);
248 let model_for_api = if provider != "openrouter" {
251 let prefix = format!("{}/", provider);
252 model.strip_prefix(&prefix).unwrap_or(model)
253 } else {
254 model
255 };
256
257 let endpoint = match provider {
259 "xai" => "https://api.x.ai/v1/chat/completions",
260 "openai" => "https://api.openai.com/v1/chat/completions",
261 "openrouter" => "https://openrouter.ai/api/v1/chat/completions",
262 _ => "https://api.x.ai/v1/chat/completions",
263 };
264
265 let env_var = Config::api_key_env_var_for_provider(provider);
267 let api_key = env::var(env_var)
268 .with_context(|| format!("{} environment variable not set", env_var))?;
269
270 let request = OpenAIRequest {
271 model: model_for_api.to_string(),
272 max_tokens: self.config.llm.max_tokens,
273 messages: vec![OpenAIMessage {
274 role: "user".to_string(),
275 content: prompt.to_string(),
276 }],
277 };
278
279 let mut request_builder = self
280 .client
281 .post(endpoint)
282 .header("authorization", format!("Bearer {}", api_key))
283 .header("content-type", "application/json");
284
285 if provider == "openrouter" {
287 request_builder = request_builder
288 .header("HTTP-Referer", "https://github.com/scud-cli")
289 .header("X-Title", "SCUD Task Master");
290 }
291
292 let response = request_builder
293 .json(&request)
294 .send()
295 .await
296 .with_context(|| format!("Failed to send request to {} API", provider))?;
297
298 if !response.status().is_success() {
299 let status = response.status();
300 let error_text = response.text().await.unwrap_or_default();
301 anyhow::bail!("{} API error ({}): {}", provider, status, error_text);
302 }
303
304 let api_response: OpenAIResponse = response
305 .json()
306 .await
307 .with_context(|| format!("Failed to parse {} API response", provider))?;
308
309 Ok(api_response
310 .choices
311 .first()
312 .map(|c| c.message.content.clone())
313 .unwrap_or_default())
314 }
315
316 pub async fn complete_json<T>(&self, prompt: &str) -> Result<T>
317 where
318 T: serde::de::DeserializeOwned,
319 {
320 self.complete_json_with_model(prompt, None).await
321 }
322
323 pub async fn complete_json_smart<T>(
325 &self,
326 prompt: &str,
327 model_override: Option<&str>,
328 ) -> Result<T>
329 where
330 T: serde::de::DeserializeOwned,
331 {
332 let response_text = self.complete_smart(prompt, model_override).await?;
333 Self::parse_json_response(&response_text)
334 }
335
336 pub async fn complete_json_fast<T>(
338 &self,
339 prompt: &str,
340 model_override: Option<&str>,
341 ) -> Result<T>
342 where
343 T: serde::de::DeserializeOwned,
344 {
345 let response_text = self.complete_fast(prompt, model_override).await?;
346 Self::parse_json_response(&response_text)
347 }
348
349 pub async fn complete_json_with_model<T>(
350 &self,
351 prompt: &str,
352 model_override: Option<&str>,
353 ) -> Result<T>
354 where
355 T: serde::de::DeserializeOwned,
356 {
357 let response_text = self
358 .complete_with_model(prompt, model_override, None)
359 .await?;
360 Self::parse_json_response(&response_text)
361 }
362
363 fn parse_json_response<T>(response_text: &str) -> Result<T>
364 where
365 T: serde::de::DeserializeOwned,
366 {
367 let json_str = Self::extract_json(response_text);
369
370 serde_json::from_str(json_str).with_context(|| {
371 let preview = if json_str.len() > 500 {
373 format!("{}...", &json_str[..500])
374 } else {
375 json_str.to_string()
376 };
377 format!(
378 "Failed to parse JSON from LLM response. Response preview:\n{}",
379 preview
380 )
381 })
382 }
383
384 fn extract_json(response: &str) -> &str {
386 if let Some(start) = response.find("```json") {
388 let content_start = start + 7; if let Some(end) = response[content_start..].find("```") {
390 return response[content_start..content_start + end].trim();
391 }
392 }
393
394 if let Some(start) = response.find("```") {
396 let content_start = start + 3;
397 let content_start = response[content_start..]
399 .find('\n')
400 .map(|i| content_start + i + 1)
401 .unwrap_or(content_start);
402 if let Some(end) = response[content_start..].find("```") {
403 return response[content_start..content_start + end].trim();
404 }
405 }
406
407 if let Some(start) = response.find('[') {
409 if let Some(end) = response.rfind(']') {
410 if end > start {
411 return &response[start..=end];
412 }
413 }
414 }
415
416 if let Some(start) = response.find('{') {
418 if let Some(end) = response.rfind('}') {
419 if end > start {
420 return &response[start..=end];
421 }
422 }
423 }
424
425 response.trim()
426 }
427
428 async fn complete_claude_cli(
429 &self,
430 prompt: &str,
431 model_override: Option<&str>,
432 ) -> Result<String> {
433 use std::process::Stdio;
434 use tokio::io::AsyncWriteExt;
435 use tokio::process::Command;
436
437 let model = model_override.unwrap_or(&self.config.llm.model);
438
439 let mut cmd = Command::new("claude");
441 cmd.arg("-p") .arg("--output-format")
443 .arg("json")
444 .arg("--model")
445 .arg(model)
446 .stdin(Stdio::piped())
447 .stdout(Stdio::piped())
448 .stderr(Stdio::piped());
449
450 let mut child = cmd.spawn().context("Failed to spawn 'claude' command. Make sure Claude Code is installed and 'claude' is in your PATH")?;
452
453 if let Some(mut stdin) = child.stdin.take() {
455 stdin
456 .write_all(prompt.as_bytes())
457 .await
458 .context("Failed to write prompt to claude stdin")?;
459 drop(stdin); }
461
462 let output = child
464 .wait_with_output()
465 .await
466 .context("Failed to wait for claude command")?;
467
468 if !output.status.success() {
469 let stderr = String::from_utf8_lossy(&output.stderr);
470 anyhow::bail!("Claude CLI error: {}", stderr);
471 }
472
473 let stdout =
475 String::from_utf8(output.stdout).context("Claude CLI output is not valid UTF-8")?;
476
477 #[derive(Deserialize)]
478 struct ClaudeCliResponse {
479 result: String,
480 }
481
482 let response: ClaudeCliResponse =
483 serde_json::from_str(&stdout).context("Failed to parse Claude CLI JSON response")?;
484
485 Ok(response.result)
486 }
487
488 async fn complete_codex_cli(
489 &self,
490 prompt: &str,
491 model_override: Option<&str>,
492 ) -> Result<String> {
493 use std::process::Stdio;
494 use tokio::io::AsyncWriteExt;
495 use tokio::process::Command;
496
497 let model = model_override.unwrap_or(&self.config.llm.model);
498
499 let mut cmd = Command::new("codex");
502 cmd.arg("-p") .arg("--model")
504 .arg(model)
505 .arg("--output-format")
506 .arg("json")
507 .stdin(Stdio::piped())
508 .stdout(Stdio::piped())
509 .stderr(Stdio::piped());
510
511 let mut child = cmd.spawn().context("Failed to spawn 'codex' command. Make sure OpenAI Codex CLI is installed and 'codex' is in your PATH")?;
513
514 if let Some(mut stdin) = child.stdin.take() {
516 stdin
517 .write_all(prompt.as_bytes())
518 .await
519 .context("Failed to write prompt to codex stdin")?;
520 drop(stdin); }
522
523 let output = child
525 .wait_with_output()
526 .await
527 .context("Failed to wait for codex command")?;
528
529 if !output.status.success() {
530 let stderr = String::from_utf8_lossy(&output.stderr);
531 anyhow::bail!("Codex CLI error: {}", stderr);
532 }
533
534 let stdout =
536 String::from_utf8(output.stdout).context("Codex CLI output is not valid UTF-8")?;
537
538 #[derive(Deserialize)]
540 struct CodexCliResponse {
541 result: String,
542 }
543
544 let response: CodexCliResponse =
545 serde_json::from_str(&stdout).context("Failed to parse Codex CLI JSON response")?;
546
547 Ok(response.result)
548 }
549
550 async fn complete_cursor_cli(
551 &self,
552 prompt: &str,
553 model_override: Option<&str>,
554 ) -> Result<String> {
555 use std::process::Stdio;
556 use tokio::io::AsyncWriteExt;
557 use tokio::process::Command;
558
559 let model = model_override.unwrap_or(&self.config.llm.model);
560
561 let mut cmd = Command::new("agent");
563 cmd.arg("-p") .arg("--model")
565 .arg(model)
566 .arg("--output-format")
567 .arg("json")
568 .stdin(Stdio::piped())
569 .stdout(Stdio::piped())
570 .stderr(Stdio::piped());
571
572 let mut child = cmd.spawn().context("Failed to spawn 'agent' command. Make sure Cursor Agent CLI is installed (curl https://cursor.com/install -fsSL | bash)")?;
574
575 if let Some(mut stdin) = child.stdin.take() {
577 stdin
578 .write_all(prompt.as_bytes())
579 .await
580 .context("Failed to write prompt to cursor agent stdin")?;
581 drop(stdin); }
583
584 let output = child
586 .wait_with_output()
587 .await
588 .context("Failed to wait for cursor agent command")?;
589
590 if !output.status.success() {
591 let stderr = String::from_utf8_lossy(&output.stderr);
592 anyhow::bail!("Cursor Agent CLI error: {}", stderr);
593 }
594
595 let stdout = String::from_utf8(output.stdout)
597 .context("Cursor Agent CLI output is not valid UTF-8")?;
598
599 #[derive(Deserialize)]
600 struct CursorCliResponse {
601 result: String,
602 }
603
604 if let Ok(response) = serde_json::from_str::<CursorCliResponse>(&stdout) {
606 return Ok(response.result);
607 }
608
609 Ok(stdout.trim().to_string())
611 }
612}