1use anyhow::{Context, Result};
2use serde::{Deserialize, Serialize};
3use std::env;
4use std::path::PathBuf;
5
6use crate::config::Config;
7use crate::storage::Storage;
8
9#[derive(Debug, Serialize)]
11struct AnthropicRequest {
12 model: String,
13 max_tokens: u32,
14 messages: Vec<AnthropicMessage>,
15}
16
17#[derive(Debug, Serialize)]
18struct AnthropicMessage {
19 role: String,
20 content: String,
21}
22
23#[derive(Debug, Deserialize)]
24struct AnthropicResponse {
25 content: Vec<AnthropicContent>,
26}
27
28#[derive(Debug, Deserialize)]
29struct AnthropicContent {
30 text: String,
31}
32
33#[derive(Debug, Serialize)]
35struct OpenAIRequest {
36 model: String,
37 max_tokens: u32,
38 messages: Vec<OpenAIMessage>,
39}
40
41#[derive(Debug, Serialize)]
42struct OpenAIMessage {
43 role: String,
44 content: String,
45}
46
47#[derive(Debug, Deserialize)]
48struct OpenAIResponse {
49 choices: Vec<OpenAIChoice>,
50}
51
52#[derive(Debug, Deserialize)]
53struct OpenAIChoice {
54 message: OpenAIMessageResponse,
55}
56
57#[derive(Debug, Deserialize)]
58struct OpenAIMessageResponse {
59 content: String,
60}
61
/// Tool-usage options forwarded to the `claude` CLI.
///
/// The default value has no allowed tools and no turn limit.
#[derive(Debug, Clone, Default)]
pub struct ToolConfig {
    /// Tool names, joined with `,` for the CLI's `--allowedTools` flag; an
    /// empty list omits the flag.
    pub allowed_tools: Vec<String>,
    /// Value for the CLI's `--max-turns` flag; `None` omits the flag.
    pub max_turns: Option<u32>,
}

impl ToolConfig {
    /// Builds a `ToolConfig` from an explicit tool list and optional turn limit.
    pub fn new(allowed_tools: Vec<String>, max_turns: Option<u32>) -> Self {
        ToolConfig {
            max_turns,
            allowed_tools,
        }
    }
}
79
80pub struct LLMClient {
81 config: Config,
82 api_key: String,
83 client: reqwest::Client,
84}
85
86impl LLMClient {
87 pub fn new() -> Result<Self> {
88 let storage = Storage::new(None);
89 let config = storage.load_config()?;
90
91 let api_key = if config.requires_api_key() {
92 env::var(config.api_key_env_var()).with_context(|| {
93 format!("{} environment variable not set", config.api_key_env_var())
94 })?
95 } else {
96 String::new() };
98
99 Ok(LLMClient {
100 config,
101 api_key,
102 client: reqwest::Client::new(),
103 })
104 }
105
106 pub fn new_with_project_root(project_root: PathBuf) -> Result<Self> {
107 let storage = Storage::new(Some(project_root));
108 let config = storage.load_config()?;
109
110 let api_key = if config.requires_api_key() {
111 env::var(config.api_key_env_var()).with_context(|| {
112 format!("{} environment variable not set", config.api_key_env_var())
113 })?
114 } else {
115 String::new() };
117
118 Ok(LLMClient {
119 config,
120 api_key,
121 client: reqwest::Client::new(),
122 })
123 }
124
125 pub async fn complete(&self, prompt: &str) -> Result<String> {
126 self.complete_with_model(prompt, None).await
127 }
128
129 pub async fn complete_with_model(
130 &self,
131 prompt: &str,
132 model_override: Option<&str>,
133 ) -> Result<String> {
134 match self.config.llm.provider.as_str() {
135 "claude-cli" => self.complete_claude_cli(prompt, model_override).await,
136 "anthropic" => {
137 self.complete_anthropic_with_model(prompt, model_override)
138 .await
139 }
140 "xai" | "openai" | "openrouter" => {
141 self.complete_openai_compatible_with_model(prompt, model_override)
142 .await
143 }
144 _ => anyhow::bail!("Unsupported provider: {}", self.config.llm.provider),
145 }
146 }
147
148 async fn complete_anthropic_with_model(
149 &self,
150 prompt: &str,
151 model_override: Option<&str>,
152 ) -> Result<String> {
153 let model = model_override.unwrap_or(&self.config.llm.model);
154 let request = AnthropicRequest {
155 model: model.to_string(),
156 max_tokens: self.config.llm.max_tokens,
157 messages: vec![AnthropicMessage {
158 role: "user".to_string(),
159 content: prompt.to_string(),
160 }],
161 };
162
163 let response = self
164 .client
165 .post(self.config.api_endpoint())
166 .header("x-api-key", &self.api_key)
167 .header("anthropic-version", "2023-06-01")
168 .header("content-type", "application/json")
169 .json(&request)
170 .send()
171 .await
172 .context("Failed to send request to Anthropic API")?;
173
174 if !response.status().is_success() {
175 let status = response.status();
176 let error_text = response.text().await.unwrap_or_default();
177 anyhow::bail!("Anthropic API error ({}): {}", status, error_text);
178 }
179
180 let api_response: AnthropicResponse = response
181 .json()
182 .await
183 .context("Failed to parse Anthropic API response")?;
184
185 Ok(api_response
186 .content
187 .first()
188 .map(|c| c.text.clone())
189 .unwrap_or_default())
190 }
191
192 async fn complete_openai_compatible_with_model(
193 &self,
194 prompt: &str,
195 model_override: Option<&str>,
196 ) -> Result<String> {
197 let model = model_override.unwrap_or(&self.config.llm.model);
198 let request = OpenAIRequest {
199 model: model.to_string(),
200 max_tokens: self.config.llm.max_tokens,
201 messages: vec![OpenAIMessage {
202 role: "user".to_string(),
203 content: prompt.to_string(),
204 }],
205 };
206
207 let mut request_builder = self
208 .client
209 .post(self.config.api_endpoint())
210 .header("authorization", format!("Bearer {}", self.api_key))
211 .header("content-type", "application/json");
212
213 if self.config.llm.provider == "openrouter" {
215 request_builder = request_builder
216 .header("HTTP-Referer", "https://github.com/scud-cli")
217 .header("X-Title", "SCUD Task Master");
218 }
219
220 let response = request_builder
221 .json(&request)
222 .send()
223 .await
224 .with_context(|| {
225 format!("Failed to send request to {} API", self.config.llm.provider)
226 })?;
227
228 if !response.status().is_success() {
229 let status = response.status();
230 let error_text = response.text().await.unwrap_or_default();
231 anyhow::bail!(
232 "{} API error ({}): {}",
233 self.config.llm.provider,
234 status,
235 error_text
236 );
237 }
238
239 let api_response: OpenAIResponse = response.json().await.with_context(|| {
240 format!("Failed to parse {} API response", self.config.llm.provider)
241 })?;
242
243 Ok(api_response
244 .choices
245 .first()
246 .map(|c| c.message.content.clone())
247 .unwrap_or_default())
248 }
249
250 pub async fn complete_json<T>(&self, prompt: &str) -> Result<T>
251 where
252 T: serde::de::DeserializeOwned,
253 {
254 let response_text = self.complete(prompt).await?;
255
256 let json_str = Self::extract_json(&response_text);
258
259 serde_json::from_str(json_str).with_context(|| {
260 let preview = if json_str.len() > 500 {
262 format!("{}...", &json_str[..500])
263 } else {
264 json_str.to_string()
265 };
266 format!(
267 "Failed to parse JSON from LLM response. Response preview:\n{}",
268 preview
269 )
270 })
271 }
272
273 pub fn is_claude_cli(&self) -> bool {
275 self.config.llm.provider == "claude-cli"
276 }
277
278 pub async fn complete_with_tools(&self, prompt: &str, tools: &ToolConfig) -> Result<String> {
280 if !self.is_claude_cli() {
281 return self.complete(prompt).await;
283 }
284 self.complete_claude_cli_with_tools(prompt, None, Some(tools))
285 .await
286 }
287
288 pub async fn complete_json_with_tools<T>(&self, prompt: &str, tools: &ToolConfig) -> Result<T>
290 where
291 T: serde::de::DeserializeOwned,
292 {
293 let response_text = self.complete_with_tools(prompt, tools).await?;
294
295 let json_str = Self::extract_json(&response_text);
297
298 serde_json::from_str(json_str).with_context(|| {
299 let preview = if json_str.len() > 500 {
300 format!("{}...", &json_str[..500])
301 } else {
302 json_str.to_string()
303 };
304 format!(
305 "Failed to parse JSON from LLM response. Response preview:\n{}",
306 preview
307 )
308 })
309 }
310
311 fn extract_json(response: &str) -> &str {
313 if let Some(start) = response.find("```json") {
315 let content_start = start + 7; if let Some(end) = response[content_start..].find("```") {
317 return response[content_start..content_start + end].trim();
318 }
319 }
320
321 if let Some(start) = response.find("```") {
323 let content_start = start + 3;
324 let content_start = response[content_start..]
326 .find('\n')
327 .map(|i| content_start + i + 1)
328 .unwrap_or(content_start);
329 if let Some(end) = response[content_start..].find("```") {
330 return response[content_start..content_start + end].trim();
331 }
332 }
333
334 if let Some(start) = response.find('[') {
336 if let Some(end) = response.rfind(']') {
337 if end > start {
338 return &response[start..=end];
339 }
340 }
341 }
342
343 if let Some(start) = response.find('{') {
345 if let Some(end) = response.rfind('}') {
346 if end > start {
347 return &response[start..=end];
348 }
349 }
350 }
351
352 response.trim()
353 }
354
355 async fn complete_claude_cli(
356 &self,
357 prompt: &str,
358 model_override: Option<&str>,
359 ) -> Result<String> {
360 self.complete_claude_cli_with_tools(prompt, model_override, None)
361 .await
362 }
363
364 async fn complete_claude_cli_with_tools(
365 &self,
366 prompt: &str,
367 model_override: Option<&str>,
368 tool_config: Option<&ToolConfig>,
369 ) -> Result<String> {
370 use std::process::Stdio;
371 use tokio::io::AsyncWriteExt;
372 use tokio::process::Command;
373
374 let model = model_override.unwrap_or(&self.config.llm.model);
375
376 let mut cmd = Command::new("claude");
378 cmd.arg("-p") .arg("--output-format")
380 .arg("json")
381 .arg("--model")
382 .arg(model);
383
384 if let Some(tools) = tool_config {
386 if !tools.allowed_tools.is_empty() {
387 cmd.arg("--allowedTools")
388 .arg(tools.allowed_tools.join(","));
389 }
390 if let Some(max_turns) = tools.max_turns {
391 cmd.arg("--max-turns").arg(max_turns.to_string());
392 }
393 }
394
395 cmd.stdin(Stdio::piped())
396 .stdout(Stdio::piped())
397 .stderr(Stdio::piped());
398
399 let mut child = cmd.spawn().context("Failed to spawn 'claude' command. Make sure Claude Code is installed and 'claude' is in your PATH")?;
401
402 if let Some(mut stdin) = child.stdin.take() {
404 stdin
405 .write_all(prompt.as_bytes())
406 .await
407 .context("Failed to write prompt to claude stdin")?;
408 drop(stdin); }
410
411 let output = child
413 .wait_with_output()
414 .await
415 .context("Failed to wait for claude command")?;
416
417 if !output.status.success() {
418 let stderr = String::from_utf8_lossy(&output.stderr);
419 anyhow::bail!("Claude CLI error: {}", stderr);
420 }
421
422 let stdout =
424 String::from_utf8(output.stdout).context("Claude CLI output is not valid UTF-8")?;
425
426 #[derive(Deserialize)]
427 struct ClaudeCliResponse {
428 result: String,
429 }
430
431 let response: ClaudeCliResponse =
432 serde_json::from_str(&stdout).context("Failed to parse Claude CLI JSON response")?;
433
434 Ok(response.result)
435 }
436}