Skip to main content

tj_core/classifier/
http.rs

1//! Anthropic API HTTP client implementing Classifier.
2
3use super::*;
4use anyhow::{anyhow, Context};
5use serde::{Deserialize, Serialize};
6use std::time::Duration;
7
/// Upper bound on one classification round-trip (request plus response).
///
/// Hooks invoke the classifier with `|| true`, so a timeout can never break
/// Claude Code; this bound exists so that a stalled API call cannot hang the
/// chat turn either.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(15_000);

/// Model used when `TJ_CLASSIFIER_MODEL` does not name one.
pub const DEFAULT_MODEL: &str = "claude-haiku-4-5-20251001";
15
/// [`Classifier`] backed by the Anthropic Messages API (`POST /v1/messages`).
///
/// Build one with [`AnthropicClassifier::from_env`], or fill the fields
/// directly (the tests point `base_url` at a local mock server).
#[derive(Clone)]
pub struct AnthropicClassifier {
    /// Secret sent in the `x-api-key` header. Never shown by `Debug`.
    pub api_key: String,
    /// Model identifier sent with every request.
    pub model: String,
    /// Scheme and host (plus optional port) that `/v1/messages` is appended
    /// to; overridable for tests.
    pub base_url: String,
    /// Timeout applied to each request.
    pub timeout: Duration,
}

// Hand-written instead of derived so the API key cannot leak into logs,
// `dbg!` output, or panic messages.
impl std::fmt::Debug for AnthropicClassifier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AnthropicClassifier")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .field("timeout", &self.timeout)
            .finish()
    }
}
22
23impl AnthropicClassifier {
24    pub fn from_env() -> anyhow::Result<Self> {
25        let api_key =
26            std::env::var("ANTHROPIC_API_KEY").context("ANTHROPIC_API_KEY env var not set")?;
27        let model = std::env::var("TJ_CLASSIFIER_MODEL").unwrap_or_else(|_| DEFAULT_MODEL.into());
28        Ok(Self {
29            api_key,
30            model,
31            base_url: "https://api.anthropic.com".into(),
32            timeout: DEFAULT_TIMEOUT,
33        })
34    }
35}
36
37#[derive(Serialize)]
38struct MessagesRequest<'a> {
39    model: &'a str,
40    max_tokens: u32,
41    messages: Vec<MessageIn<'a>>,
42}
43#[derive(Serialize)]
44struct MessageIn<'a> {
45    role: &'a str,
46    content: &'a str,
47}
48#[derive(Deserialize)]
49struct MessagesResponse {
50    content: Vec<ContentBlock>,
51}
52#[derive(Deserialize)]
53struct ContentBlock {
54    #[serde(rename = "type")]
55    kind: String,
56    #[serde(default)]
57    text: String,
58}
59
60impl Classifier for AnthropicClassifier {
61    fn classify(&self, input: &ClassifyInput) -> anyhow::Result<ClassifyOutput> {
62        let prompt = crate::classifier::prompt::build(input);
63        let body = MessagesRequest {
64            model: &self.model,
65            max_tokens: 256,
66            messages: vec![MessageIn {
67                role: "user",
68                content: &prompt,
69            }],
70        };
71
72        let url = format!("{}/v1/messages", self.base_url);
73        let resp: MessagesResponse = ureq::post(&url)
74            .timeout(self.timeout)
75            .set("x-api-key", &self.api_key)
76            .set("anthropic-version", "2023-06-01")
77            .set("content-type", "application/json")
78            .send_json(serde_json::to_value(&body)?)
79            .context("Anthropic API request failed")?
80            .into_json()
81            .context("decode Anthropic response")?;
82
83        let text = resp
84            .content
85            .iter()
86            .find(|b| b.kind == "text")
87            .map(|b| b.text.clone())
88            .ok_or_else(|| anyhow!("no text content in response"))?;
89
90        let json_str = text
91            .trim()
92            .trim_start_matches("```json")
93            .trim_start_matches("```")
94            .trim_end_matches("```")
95            .trim();
96        let out: ClassifyOutput = serde_json::from_str(json_str)
97            .with_context(|| format!("classifier JSON parse failed; got: {json_str}"))?;
98        Ok(out)
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::EventType;

    /// Model reply text for a confident `decision` classification.
    const DECISION_JSON: &str = "{\"event_type\":\"decision\",\"task_id_guess\":\"tj-x\",\"confidence\":0.93,\"evidence_strength\":null,\"suggested_text\":\"Adopt Rust.\"}";

    /// Builds a classifier that talks to `base_url` with the dummy key "test".
    fn classifier(base_url: String, timeout: Duration) -> AnthropicClassifier {
        AnthropicClassifier {
            api_key: "test".into(),
            model: "claude-haiku-4-5-20251001".into(),
            base_url,
            timeout,
        }
    }

    /// Builds a classifier input with no recent tasks.
    fn input(text: &str, author_hint: &str) -> ClassifyInput {
        ClassifyInput {
            text: text.into(),
            author_hint: author_hint.into(),
            recent_tasks: vec![],
        }
    }

    /// A successful Messages API response whose single text block is `text`.
    fn messages_body(text: &str) -> String {
        serde_json::json!({
            "id": "msg_test",
            "type": "message",
            "role": "assistant",
            "model": "claude-haiku-4-5-20251001",
            "content": [
                { "type": "text", "text": text }
            ],
            "stop_reason": "end_turn"
        })
        .to_string()
    }

    /// Mocks `POST /v1/messages` with `status`/`body`, and only matches
    /// requests that carry the auth and API-version headers.
    fn mock_messages(server: &mut mockito::Server, status: usize, body: String) -> mockito::Mock {
        server
            .mock("POST", "/v1/messages")
            .match_header("x-api-key", "test")
            .match_header("anthropic-version", "2023-06-01")
            .with_status(status)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create()
    }

    #[test]
    fn classifier_parses_anthropic_response() {
        let mut server = mockito::Server::new();
        let mock = mock_messages(&mut server, 200, messages_body(DECISION_JSON));

        let out = classifier(server.url(), DEFAULT_TIMEOUT)
            .classify(&input("We adopted Rust.", "assistant"))
            .unwrap();

        assert_eq!(out.event_type, EventType::Decision);
        assert_eq!(out.task_id_guess.as_deref(), Some("tj-x"));
        assert!((out.confidence - 0.93).abs() < 1e-6);
        mock.assert();
    }

    #[test]
    fn classifier_accepts_fenced_json() {
        let mut server = mockito::Server::new();
        let fenced = format!("```json\n{}\n```", DECISION_JSON);
        let mock = mock_messages(&mut server, 200, messages_body(&fenced));

        let out = classifier(server.url(), DEFAULT_TIMEOUT)
            .classify(&input("We adopted Rust.", "assistant"))
            .unwrap();

        assert_eq!(out.event_type, EventType::Decision);
        assert_eq!(out.task_id_guess.as_deref(), Some("tj-x"));
        mock.assert();
    }

    #[test]
    fn classifier_fails_on_http_error_status() {
        let mut server = mockito::Server::new();
        let mock = mock_messages(
            &mut server,
            500,
            r#"{"type":"error","error":{"type":"api_error","message":"boom"}}"#.to_string(),
        );

        let err = classifier(server.url(), DEFAULT_TIMEOUT)
            .classify(&input("x", "user"))
            .err()
            .expect("HTTP 500 must surface as an error");

        let chain = format!("{err:#}");
        assert!(
            chain.contains("Anthropic API request failed"),
            "unexpected error: {chain}"
        );
        mock.assert();
    }

    #[test]
    fn classifier_times_out_on_unresponsive_server() {
        use std::net::TcpListener;
        use std::time::Instant;

        // Bind a TCP socket but never accept — the kernel completes the
        // 3-way handshake from the backlog so connect() succeeds, but no
        // bytes are ever read or written. Read timeout must fire.
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let url = format!("http://{}", listener.local_addr().unwrap());
        let c = classifier(url, Duration::from_millis(300));

        let start = Instant::now();
        let res = c.classify(&input("x", "user"));
        let elapsed = start.elapsed();

        assert!(res.is_err(), "expected a timeout error, got Ok");
        assert!(
            elapsed < Duration::from_secs(3),
            "expected timeout near 300ms, got {elapsed:?}"
        );

        // Dropped only now so the socket stays bound for the whole request.
        drop(listener);
    }
}