1use anyhow::{anyhow, Context};
22use serde::{Deserialize, Serialize};
23use std::time::Duration;
24
/// Token and cost accounting for one or more LLM calls.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct LlmUsage {
    /// Prompt-side tokens reported by the backend.
    pub input_tokens: u64,
    /// Completion-side tokens reported by the backend.
    pub output_tokens: u64,
    /// Cost in US dollars when the backend reports it; `None` when unknown
    /// (the HTTP backends in this module always leave it `None`).
    pub cost_usd: Option<f64>,
}

impl LlmUsage {
    /// Sum of input and output tokens.
    ///
    /// Saturates at `u64::MAX` instead of panicking (debug) or wrapping (release).
    #[must_use]
    pub fn total_tokens(&self) -> u64 {
        self.input_tokens.saturating_add(self.output_tokens)
    }

    /// Accumulates `other` into `self`.
    ///
    /// Token counts saturate at `u64::MAX`. Costs are summed when both sides know
    /// theirs; when only one side has a cost, that cost is kept, so an aggregate
    /// may undercount if some calls did not report a cost.
    pub fn add(&mut self, other: LlmUsage) {
        self.input_tokens = self.input_tokens.saturating_add(other.input_tokens);
        self.output_tokens = self.output_tokens.saturating_add(other.output_tokens);
        self.cost_usd = match (self.cost_usd, other.cost_usd) {
            (Some(a), Some(b)) => Some(a + b),
            (known, None) => known,
            (None, known) => known,
        };
    }
}
51
52pub trait LlmBackend: Send + Sync {
54 fn complete(&self, prompt: &str, max_tokens: u32) -> anyhow::Result<String>;
55 fn name(&self) -> &'static str;
57 fn complete_usage(&self, prompt: &str, max_tokens: u32) -> anyhow::Result<(String, LlmUsage)> {
61 Ok((self.complete(prompt, max_tokens)?, LlmUsage::default()))
62 }
63}
64
65pub fn backend_from_env(explicit: Option<&str>) -> anyhow::Result<Option<Box<dyn LlmBackend>>> {
72 let name = explicit
73 .map(str::to_string)
74 .or_else(|| std::env::var("TJ_BACKEND").ok())
75 .filter(|s| !s.trim().is_empty())
76 .unwrap_or_else(|| "claude-p".to_string());
77
78 match name.trim() {
79 "claude-p" | "claude" | "agent-sdk" => {
80 if crate::classifier::agent_sdk::claude_on_path() {
81 Ok(Some(Box::new(ClaudeCliBackend::from_env())))
82 } else {
83 Ok(None)
84 }
85 }
86 "anthropic" | "api" => match std::env::var("ANTHROPIC_API_KEY") {
87 Ok(key) if !key.is_empty() => Ok(Some(Box::new(AnthropicBackend::new(key)))),
88 _ => Ok(None),
89 },
90 "openai" | "codex" => match std::env::var("OPENAI_API_KEY") {
91 Ok(key) if !key.is_empty() => Ok(Some(Box::new(OpenAiBackend::openai(key)))),
92 _ => Ok(None),
93 },
94 "ollama" => Ok(Some(Box::new(OpenAiBackend::ollama()))),
95 other => Err(anyhow!(
96 "unknown backend '{other}' (expected: claude-p, anthropic, openai, ollama)"
97 )),
98 }
99}
100
101pub struct ClaudeCliBackend {
106 model: String,
107}
108
109impl ClaudeCliBackend {
110 pub fn from_env() -> Self {
111 let model = std::env::var("TJ_CONSOLIDATE_MODEL")
112 .unwrap_or_else(|_| crate::classifier::agent_sdk::DEFAULT_MODEL.to_string());
113 Self { model }
114 }
115}
116
117impl LlmBackend for ClaudeCliBackend {
118 fn complete(&self, prompt: &str, max_tokens: u32) -> anyhow::Result<String> {
119 self.complete_usage(prompt, max_tokens).map(|(t, _)| t)
120 }
121 fn name(&self) -> &'static str {
122 "claude-p"
123 }
124 fn complete_usage(&self, prompt: &str, _max_tokens: u32) -> anyhow::Result<(String, LlmUsage)> {
125 crate::classifier::agent_sdk::run_claude_json_usage(
126 &crate::classifier::agent_sdk::ClaudeBinaryStdinRunner,
127 &self.model,
128 prompt,
129 )
130 }
131}
132
/// Backend that calls the Anthropic Messages API (`POST {base_url}/v1/messages`).
pub struct AnthropicBackend {
    /// Sent as the `x-api-key` header.
    api_key: String,
    /// Model identifier placed in the request body.
    model: String,
    /// API origin; `new` strips trailing slashes so `/v1/messages` joins cleanly.
    base_url: String,
    /// Per-request timeout.
    timeout: Duration,
}

impl AnthropicBackend {
    /// Creates a backend authenticated with `api_key`, configured from:
    ///
    /// - `TJ_CONSOLIDATE_MODEL`: model id (default `claude-haiku-4-5-20251001`)
    /// - `TJ_CONSOLIDATE_BASE_URL`: API origin (default `https://api.anthropic.com`)
    ///
    /// Blank variables are treated as unset, and trailing `/` is removed from the
    /// base URL so requests never hit `//v1/messages`. The timeout is 60 seconds.
    pub fn new(api_key: String) -> Self {
        let env_or = |var: &str, default: &str| -> String {
            std::env::var(var)
                .ok()
                .map(|value| value.trim().to_string())
                .filter(|value| !value.is_empty())
                .unwrap_or_else(|| default.to_string())
        };
        let model = env_or("TJ_CONSOLIDATE_MODEL", "claude-haiku-4-5-20251001");
        let base_url = env_or("TJ_CONSOLIDATE_BASE_URL", "https://api.anthropic.com")
            .trim_end_matches('/')
            .to_string();
        Self {
            api_key,
            model,
            base_url,
            timeout: Duration::from_secs(60),
        }
    }
}
158
159#[derive(Serialize)]
160struct AnthropicReq<'a> {
161 model: &'a str,
162 max_tokens: u32,
163 messages: Vec<AnthropicMsg<'a>>,
164}
165#[derive(Serialize)]
166struct AnthropicMsg<'a> {
167 role: &'a str,
168 content: &'a str,
169}
170#[derive(Deserialize)]
171struct AnthropicResp {
172 content: Vec<AnthropicBlock>,
173 #[serde(default)]
174 usage: AnthropicUsage,
175}
176#[derive(Deserialize, Default)]
177struct AnthropicUsage {
178 #[serde(default)]
179 input_tokens: u64,
180 #[serde(default)]
181 output_tokens: u64,
182}
183#[derive(Deserialize)]
184struct AnthropicBlock {
185 #[serde(rename = "type")]
186 kind: String,
187 #[serde(default)]
188 text: String,
189}
190
191impl LlmBackend for AnthropicBackend {
192 fn complete(&self, prompt: &str, max_tokens: u32) -> anyhow::Result<String> {
193 self.complete_usage(prompt, max_tokens).map(|(t, _)| t)
194 }
195 fn name(&self) -> &'static str {
196 "anthropic"
197 }
198 fn complete_usage(&self, prompt: &str, max_tokens: u32) -> anyhow::Result<(String, LlmUsage)> {
199 let body = AnthropicReq {
200 model: &self.model,
201 max_tokens,
202 messages: vec![AnthropicMsg {
203 role: "user",
204 content: prompt,
205 }],
206 };
207 let resp: AnthropicResp = ureq::post(&format!("{}/v1/messages", self.base_url))
208 .timeout(self.timeout)
209 .set("x-api-key", &self.api_key)
210 .set("anthropic-version", "2023-06-01")
211 .set("content-type", "application/json")
212 .send_json(serde_json::to_value(&body)?)
213 .context("Anthropic API request failed")?
214 .into_json()
215 .context("decode Anthropic response")?;
216 let usage = LlmUsage {
217 input_tokens: resp.usage.input_tokens,
218 output_tokens: resp.usage.output_tokens,
219 cost_usd: None,
220 };
221 let text = resp
222 .content
223 .iter()
224 .find(|b| b.kind == "text")
225 .map(|b| b.text.clone())
226 .ok_or_else(|| anyhow!("no text content in Anthropic response"))?;
227 Ok((text, usage))
228 }
229}
230
/// Backend for OpenAI's chat-completions API and compatible servers (used for Ollama).
pub struct OpenAiBackend {
    /// Bearer token; `None` sends no `authorization` header (the Ollama case).
    api_key: Option<String>,
    /// Model identifier placed in the request body.
    model: String,
    /// Server origin; `/v1/chat/completions` is appended per request, so this
    /// should not itself end in `/v1`. Constructors strip trailing slashes.
    base_url: String,
    /// Name reported by `LlmBackend::name` and used in error messages.
    label: &'static str,
    /// Per-request timeout.
    timeout: Duration,
}

impl OpenAiBackend {
    /// OpenAI backend authenticated with `api_key`.
    ///
    /// Reads `TJ_OPENAI_MODEL` (default `gpt-4o-mini`) and `TJ_OPENAI_BASE_URL`
    /// (default `https://api.openai.com`). Blank variables count as unset and
    /// trailing `/` is stripped from the base URL. Timeout: 60 seconds.
    pub fn openai(api_key: String) -> Self {
        Self {
            api_key: Some(api_key),
            model: Self::env_or("TJ_OPENAI_MODEL", "gpt-4o-mini"),
            base_url: Self::env_or("TJ_OPENAI_BASE_URL", "https://api.openai.com")
                .trim_end_matches('/')
                .to_string(),
            label: "openai",
            timeout: Duration::from_secs(60),
        }
    }

    /// Ollama backend (no API key) using its OpenAI-compatible endpoint.
    ///
    /// Reads `TJ_OLLAMA_MODEL` (default `llama3.1`) and `TJ_OLLAMA_URL` (default
    /// `http://localhost:11434`). Blank variables count as unset and trailing `/`
    /// is stripped from the URL. Timeout: 120 seconds, double the hosted default
    /// (presumably to allow for slower local inference).
    pub fn ollama() -> Self {
        Self {
            api_key: None,
            model: Self::env_or("TJ_OLLAMA_MODEL", "llama3.1"),
            base_url: Self::env_or("TJ_OLLAMA_URL", "http://localhost:11434")
                .trim_end_matches('/')
                .to_string(),
            label: "ollama",
            timeout: Duration::from_secs(120),
        }
    }

    /// Value of `var`, trimmed; `default` when unset, not valid Unicode, or blank.
    fn env_or(var: &str, default: &str) -> String {
        std::env::var(var)
            .ok()
            .map(|value| value.trim().to_string())
            .filter(|value| !value.is_empty())
            .unwrap_or_else(|| default.to_string())
    }
}
266
267#[derive(Serialize)]
268struct OpenAiReq<'a> {
269 model: &'a str,
270 max_tokens: u32,
271 messages: Vec<AnthropicMsg<'a>>,
272}
273#[derive(Deserialize)]
274struct OpenAiResp {
275 choices: Vec<OpenAiChoice>,
276 #[serde(default)]
277 usage: OpenAiUsage,
278}
279#[derive(Deserialize, Default)]
280struct OpenAiUsage {
281 #[serde(default)]
282 prompt_tokens: u64,
283 #[serde(default)]
284 completion_tokens: u64,
285}
286#[derive(Deserialize)]
287struct OpenAiChoice {
288 message: OpenAiMsg,
289}
290#[derive(Deserialize)]
291struct OpenAiMsg {
292 #[serde(default)]
293 content: String,
294}
295
296impl LlmBackend for OpenAiBackend {
297 fn complete(&self, prompt: &str, max_tokens: u32) -> anyhow::Result<String> {
298 self.complete_usage(prompt, max_tokens).map(|(t, _)| t)
299 }
300 fn complete_usage(&self, prompt: &str, max_tokens: u32) -> anyhow::Result<(String, LlmUsage)> {
301 let body = OpenAiReq {
302 model: &self.model,
303 max_tokens,
304 messages: vec![AnthropicMsg {
305 role: "user",
306 content: prompt,
307 }],
308 };
309 let mut req = ureq::post(&format!("{}/v1/chat/completions", self.base_url))
310 .timeout(self.timeout)
311 .set("content-type", "application/json");
312 if let Some(key) = &self.api_key {
313 req = req.set("authorization", &format!("Bearer {key}"));
314 }
315 let resp: OpenAiResp = req
316 .send_json(serde_json::to_value(&body)?)
317 .with_context(|| format!("{} request failed", self.label))?
318 .into_json()
319 .context("decode OpenAI-compatible response")?;
320 let usage = LlmUsage {
321 input_tokens: resp.usage.prompt_tokens,
322 output_tokens: resp.usage.completion_tokens,
323 cost_usd: None,
324 };
325 let text = resp
326 .choices
327 .into_iter()
328 .next()
329 .map(|c| c.message.content)
330 .ok_or_else(|| anyhow!("no choices in {} response", self.label))?;
331 Ok((text, usage))
332 }
333 fn name(&self) -> &'static str {
334 self.label
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Sets or unsets an environment variable, restoring its previous value on drop.
    struct EnvGuard(&'static str, Option<String>);

    impl EnvGuard {
        fn set(k: &'static str, v: &str) -> Self {
            let prev = std::env::var(k).ok();
            std::env::set_var(k, v);
            Self(k, prev)
        }

        fn unset(k: &'static str) -> Self {
            let prev = std::env::var(k).ok();
            std::env::remove_var(k);
            Self(k, prev)
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            match &self.1 {
                Some(v) => std::env::set_var(self.0, v),
                None => std::env::remove_var(self.0),
            }
        }
    }

    /// Serialises tests that read or mutate process-wide environment variables.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires `ENV_LOCK`, recovering from poisoning: the lock guards `()`, so a
    /// panic in one env test leaves nothing inconsistent and must not cascade into
    /// `PoisonError` failures in every later test.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// OpenAI-compatible backend pointed at a mock server.
    fn openai_compat(base_url: String, api_key: Option<&str>, label: &'static str) -> OpenAiBackend {
        OpenAiBackend {
            api_key: api_key.map(str::to_string),
            model: "gpt-4o-mini".into(),
            base_url,
            label,
            timeout: Duration::from_secs(5),
        }
    }

    #[test]
    fn unknown_backend_errors() {
        let _l = env_lock();
        assert!(backend_from_env(Some("nonsense")).is_err());
    }

    #[test]
    fn anthropic_unavailable_without_key_is_none() {
        let _l = env_lock();
        let _g = EnvGuard::unset("ANTHROPIC_API_KEY");
        assert!(backend_from_env(Some("anthropic")).unwrap().is_none());
    }

    #[test]
    fn anthropic_with_key_resolves() {
        let _l = env_lock();
        let _g = EnvGuard::set("ANTHROPIC_API_KEY", "k");
        let b = backend_from_env(Some("anthropic")).unwrap().unwrap();
        assert_eq!(b.name(), "anthropic");
    }

    #[test]
    fn ollama_always_resolves_no_key() {
        let _l = env_lock();
        let b = backend_from_env(Some("ollama")).unwrap().unwrap();
        assert_eq!(b.name(), "ollama");
    }

    #[test]
    fn usage_add_sums_tokens_and_merges_cost() {
        let mut total = LlmUsage::default();
        total.add(LlmUsage { input_tokens: 10, output_tokens: 5, cost_usd: None });
        total.add(LlmUsage { input_tokens: 1, output_tokens: 2, cost_usd: Some(0.25) });
        assert_eq!(total.total_tokens(), 18);
        assert_eq!(total.cost_usd, Some(0.25));
        total.add(LlmUsage { input_tokens: 0, output_tokens: 0, cost_usd: Some(0.5) });
        assert_eq!(total.cost_usd, Some(0.75));
    }

    #[test]
    fn openai_calls_chat_completions_and_parses() {
        let mut server = mockito::Server::new();
        let m = server
            .mock("POST", "/v1/chat/completions")
            .match_header("authorization", "Bearer k")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                serde_json::json!({
                    "choices": [{"message": {"role": "assistant", "content": "hello from openai"}}]
                })
                .to_string(),
            )
            .create();
        let b = openai_compat(server.url(), Some("k"), "openai");
        let out = b.complete("hi", 64).unwrap();
        m.assert();
        assert_eq!(out, "hello from openai");
    }

    #[test]
    fn openai_reports_usage() {
        let mut server = mockito::Server::new();
        let m = server
            .mock("POST", "/v1/chat/completions")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                serde_json::json!({
                    "choices": [{"message": {"role": "assistant", "content": "ok"}}],
                    "usage": {"prompt_tokens": 7, "completion_tokens": 3, "total_tokens": 10}
                })
                .to_string(),
            )
            .create();
        let b = openai_compat(server.url(), Some("k"), "openai");
        let (text, usage) = b.complete_usage("hi", 64).unwrap();
        m.assert();
        assert_eq!(text, "ok");
        assert_eq!(
            usage,
            LlmUsage { input_tokens: 7, output_tokens: 3, cost_usd: None }
        );
    }

    #[test]
    fn ollama_sends_no_authorization_header() {
        let mut server = mockito::Server::new();
        let m = server
            .mock("POST", "/v1/chat/completions")
            .match_header("authorization", mockito::Matcher::Missing)
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                serde_json::json!({
                    "choices": [{"message": {"role": "assistant", "content": "local"}}]
                })
                .to_string(),
            )
            .create();
        let b = openai_compat(server.url(), None, "ollama");
        let out = b.complete("hi", 64).unwrap();
        m.assert();
        assert_eq!(out, "local");
    }

    #[test]
    fn anthropic_calls_messages_and_parses() {
        let mut server = mockito::Server::new();
        let m = server
            .mock("POST", "/v1/messages")
            .match_header("x-api-key", "k")
            .match_header("anthropic-version", "2023-06-01")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                serde_json::json!({
                    "content": [{"type": "text", "text": "hello from anthropic"}]
                })
                .to_string(),
            )
            .create();
        let b = AnthropicBackend {
            api_key: "k".into(),
            model: "claude-haiku-4-5-20251001".into(),
            base_url: server.url(),
            timeout: Duration::from_secs(5),
        };
        let out = b.complete("hi", 64).unwrap();
        m.assert();
        assert_eq!(out, "hello from anthropic");
    }

    #[test]
    fn anthropic_skips_non_text_blocks_and_reports_usage() {
        let mut server = mockito::Server::new();
        let m = server
            .mock("POST", "/v1/messages")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(
                serde_json::json!({
                    "content": [
                        {"type": "thinking", "thinking": "pondering"},
                        {"type": "text", "text": "answer"}
                    ],
                    "usage": {"input_tokens": 11, "output_tokens": 4}
                })
                .to_string(),
            )
            .create();
        let b = AnthropicBackend {
            api_key: "k".into(),
            model: "claude-haiku-4-5-20251001".into(),
            base_url: server.url(),
            timeout: Duration::from_secs(5),
        };
        let (text, usage) = b.complete_usage("hi", 64).unwrap();
        m.assert();
        assert_eq!(text, "answer");
        assert_eq!(
            usage,
            LlmUsage { input_tokens: 11, output_tokens: 4, cost_usd: None }
        );
    }
}