sqlrite_ask/provider/
anthropic.rs1use serde::{Deserialize, Serialize};
33
34use super::{Provider, Request, Response, Usage};
35use crate::AskError;
36use crate::prompt::{SystemBlock, UserMessage};
37
/// Path of the Messages endpoint; appended verbatim to the provider's base URL.
const MESSAGES_PATH: &str = "/v1/messages";
/// API version pinned through the `anthropic-version` request header.
const ANTHROPIC_VERSION: &str = "2023-06-01";
/// Base URL used by `AnthropicProvider::new`.
const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
41
42pub struct AnthropicProvider {
47 api_key: String,
48 base_url: String,
49 agent: ureq::Agent,
50}
51
52impl AnthropicProvider {
53 pub fn new(api_key: impl Into<String>) -> Self {
55 Self::with_base_url(api_key, DEFAULT_BASE_URL)
56 }
57
58 pub fn with_base_url(api_key: impl Into<String>, base_url: impl Into<String>) -> Self {
63 let agent = ureq::AgentBuilder::new()
64 .timeout_connect(std::time::Duration::from_secs(10))
65 .timeout(std::time::Duration::from_secs(90))
71 .build();
72 Self {
73 api_key: api_key.into(),
74 base_url: base_url.into(),
75 agent,
76 }
77 }
78}
79
80#[derive(Serialize)]
81struct MessagesRequestBody<'a> {
82 model: &'a str,
83 max_tokens: u32,
84 system: &'a [SystemBlock],
85 messages: &'a [UserMessage],
86}
87
88#[derive(Deserialize)]
89struct MessagesResponseBody {
90 content: Vec<ContentBlock>,
91 #[serde(default)]
92 usage: ResponseUsage,
93}
94
95#[derive(Deserialize)]
96struct ContentBlock {
97 #[serde(rename = "type")]
98 kind: String,
99 #[serde(default)]
100 text: String,
101}
102
103#[derive(Deserialize, Default)]
104struct ResponseUsage {
105 #[serde(default)]
106 input_tokens: u64,
107 #[serde(default)]
108 output_tokens: u64,
109 #[serde(default)]
110 cache_creation_input_tokens: u64,
111 #[serde(default)]
112 cache_read_input_tokens: u64,
113}
114
115#[derive(Deserialize)]
116struct ApiErrorBody {
117 error: ApiErrorInner,
118}
119
120#[derive(Deserialize)]
121struct ApiErrorInner {
122 #[serde(rename = "type")]
123 kind: String,
124 message: String,
125}
126
127impl Provider for AnthropicProvider {
128 fn complete(&self, req: Request<'_>) -> Result<Response, AskError> {
129 let body = MessagesRequestBody {
130 model: req.model,
131 max_tokens: req.max_tokens,
132 system: req.system,
133 messages: req.messages,
134 };
135
136 let url = format!("{}{}", self.base_url, MESSAGES_PATH);
137
138 let result = self
143 .agent
144 .post(&url)
145 .set("x-api-key", &self.api_key)
146 .set("anthropic-version", ANTHROPIC_VERSION)
147 .set("content-type", "application/json")
148 .send_json(serde_json::to_value(&body).map_err(AskError::Json)?);
149
150 let resp = match result {
151 Ok(r) => r,
152 Err(ureq::Error::Status(code, response)) => {
153 let body_text = response
154 .into_string()
155 .unwrap_or_else(|_| "<unreadable response body>".to_string());
156 let detail = serde_json::from_str::<ApiErrorBody>(&body_text)
157 .map(|e| format!("{}: {}", e.error.kind, e.error.message))
158 .unwrap_or_else(|_| body_text);
159 return Err(AskError::ApiStatus {
160 status: code,
161 detail,
162 });
163 }
164 Err(ureq::Error::Transport(t)) => {
165 return Err(AskError::Http(t.to_string()));
166 }
167 };
168
169 let parsed: MessagesResponseBody = resp
170 .into_json()
171 .map_err(|e| AskError::Http(e.to_string()))?;
172
173 let text = parsed
180 .content
181 .iter()
182 .filter(|b| b.kind == "text")
183 .map(|b| b.text.as_str())
184 .collect::<Vec<_>>()
185 .join("");
186
187 if text.is_empty() {
188 return Err(AskError::EmptyResponse);
189 }
190
191 Ok(Response {
192 text,
193 usage: Usage {
194 input_tokens: parsed.usage.input_tokens,
195 output_tokens: parsed.usage.output_tokens,
196 cache_creation_input_tokens: parsed.usage.cache_creation_input_tokens,
197 cache_read_input_tokens: parsed.usage.cache_read_input_tokens,
198 },
199 })
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prompt::{CacheControl, UserMessage, build_system};

    /// Pins the JSON shape of the Messages request body: top-level model and
    /// token limit, the system blocks (cache control on the second one), and a
    /// single user message whose content is a plain string.
    #[test]
    fn request_body_serializes_to_expected_shape() {
        let schema_ddl = "CREATE TABLE users (id INTEGER PRIMARY KEY);\n";
        let system_blocks = build_system(schema_ddl, Some(CacheControl::ephemeral()));
        let user_turns = vec![UserMessage::new("count users")];

        let value = serde_json::to_value(&MessagesRequestBody {
            model: "claude-sonnet-4-6",
            max_tokens: 1024,
            system: &system_blocks,
            messages: &user_turns,
        })
        .expect("request body should serialize to JSON");

        // Top-level scalar fields.
        assert_eq!(value["model"], "claude-sonnet-4-6");
        assert_eq!(value["max_tokens"], 1024);

        // System prompt blocks.
        let system = &value["system"];
        assert_eq!(system[0]["type"], "text");
        assert_eq!(system[1]["cache_control"]["type"], "ephemeral");

        // The single user turn.
        let first_message = &value["messages"][0];
        assert_eq!(first_message["role"], "user");
        assert_eq!(first_message["content"], "count users");
    }
}