1use std::collections::VecDeque;
2use std::time::{Duration, SystemTime, UNIX_EPOCH};
3
4use runtime::{
5 load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest,
6 OAuthTokenExchangeRequest,
7};
8use serde::Deserialize;
9
10use crate::error::ApiError;
11
12use super::{Provider, ProviderFuture};
13use crate::sse::SseParser;
14use crate::types::{MessageRequest, MessageResponse, StreamEvent};
15
/// API origin used when `ANTHROPIC_BASE_URL` does not override it.
pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com";
/// Value sent in the `anthropic-version` header of every Messages API request.
const ANTHROPIC_VERSION: &str = "2023-06-01";
/// Primary response header that carries the server-assigned request id.
const REQUEST_ID_HEADER: &str = "request-id";
/// Fallback response header consulted when `REQUEST_ID_HEADER` is absent.
const ALT_REQUEST_ID_HEADER: &str = "x-request-id";

/// Delay before the first retry; later retries double it.
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
/// Upper bound applied to every computed retry delay.
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_millis(2_000);
/// Retries allowed after the initial request (three attempts in total).
const DEFAULT_MAX_RETRIES: u32 = 2;
23
/// Credentials attached to outgoing API requests.
///
/// An API key is sent as an `x-api-key` header. A bearer token is sent as an
/// `Authorization: Bearer …` header. Both may be present at once.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthSource {
    /// No credentials are attached.
    None,
    /// Only an API key.
    ApiKey(String),
    /// Only a bearer token (for example an OAuth access token).
    BearerToken(String),
    /// Both an API key and a bearer token.
    ApiKeyAndBearer { api_key: String, bearer_token: String },
}
34
35impl AuthSource {
36 pub fn from_env() -> Result<Self, ApiError> {
37 let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?;
38 let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?;
39 match (api_key, auth_token) {
40 (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer {
41 api_key,
42 bearer_token,
43 }),
44 (Some(api_key), None) => Ok(Self::ApiKey(api_key)),
45 (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)),
46 (None, None) => Err(ApiError::missing_credentials(
47 "Wraith",
48 &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
49 )),
50 }
51 }
52
53 #[must_use]
54 pub fn api_key(&self) -> Option<&str> {
55 match self {
56 Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key),
57 Self::None | Self::BearerToken(_) => None,
58 }
59 }
60
61 #[must_use]
62 pub fn bearer_token(&self) -> Option<&str> {
63 match self {
64 Self::BearerToken(token)
65 | Self::ApiKeyAndBearer {
66 bearer_token: token,
67 ..
68 } => Some(token),
69 Self::None | Self::ApiKey(_) => None,
70 }
71 }
72
73 #[must_use]
74 pub fn masked_authorization_header(&self) -> &'static str {
75 if self.bearer_token().is_some() {
76 "Bearer [REDACTED]"
77 } else {
78 "<absent>"
79 }
80 }
81
82 pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
83 if let Some(api_key) = self.api_key() {
84 request_builder = request_builder.header("x-api-key", api_key);
85 }
86 if let Some(token) = self.bearer_token() {
87 request_builder = request_builder.bearer_auth(token);
88 }
89 request_builder
90 }
91}
92
93#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
94pub struct OAuthTokenSet {
95 pub access_token: String,
96 pub refresh_token: Option<String>,
97 pub expires_at: Option<u64>,
98 #[serde(default)]
99 pub scopes: Vec<String>,
100}
101
102impl From<OAuthTokenSet> for AuthSource {
103 fn from(value: OAuthTokenSet) -> Self {
104 Self::BearerToken(value.access_token)
105 }
106}
107
108#[derive(Debug, Clone)]
109pub struct AnthropicClient {
110 http: reqwest::Client,
111 auth: AuthSource,
112 base_url: String,
113 max_retries: u32,
114 initial_backoff: Duration,
115 max_backoff: Duration,
116}
117
118impl AnthropicClient {
119 #[must_use]
120 pub fn new(api_key: impl Into<String>) -> Self {
121 Self {
122 http: reqwest::Client::new(),
123 auth: AuthSource::ApiKey(api_key.into()),
124 base_url: DEFAULT_BASE_URL.to_string(),
125 max_retries: DEFAULT_MAX_RETRIES,
126 initial_backoff: DEFAULT_INITIAL_BACKOFF,
127 max_backoff: DEFAULT_MAX_BACKOFF,
128 }
129 }
130
131 #[must_use]
132 pub fn from_auth(auth: AuthSource) -> Self {
133 Self {
134 http: reqwest::Client::new(),
135 auth,
136 base_url: DEFAULT_BASE_URL.to_string(),
137 max_retries: DEFAULT_MAX_RETRIES,
138 initial_backoff: DEFAULT_INITIAL_BACKOFF,
139 max_backoff: DEFAULT_MAX_BACKOFF,
140 }
141 }
142
143 pub fn from_env() -> Result<Self, ApiError> {
144 Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url()))
145 }
146
147 #[must_use]
148 pub fn with_auth_source(mut self, auth: AuthSource) -> Self {
149 self.auth = auth;
150 self
151 }
152
153 #[must_use]
154 pub fn with_auth_token(mut self, auth_token: Option<String>) -> Self {
155 match (
156 self.auth.api_key().map(ToOwned::to_owned),
157 auth_token.filter(|token| !token.is_empty()),
158 ) {
159 (Some(api_key), Some(bearer_token)) => {
160 self.auth = AuthSource::ApiKeyAndBearer {
161 api_key,
162 bearer_token,
163 };
164 }
165 (Some(api_key), None) => {
166 self.auth = AuthSource::ApiKey(api_key);
167 }
168 (None, Some(bearer_token)) => {
169 self.auth = AuthSource::BearerToken(bearer_token);
170 }
171 (None, None) => {
172 self.auth = AuthSource::None;
173 }
174 }
175 self
176 }
177
178 #[must_use]
179 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
180 self.base_url = base_url.into();
181 self
182 }
183
184 #[must_use]
185 pub fn with_retry_policy(
186 mut self,
187 max_retries: u32,
188 initial_backoff: Duration,
189 max_backoff: Duration,
190 ) -> Self {
191 self.max_retries = max_retries;
192 self.initial_backoff = initial_backoff;
193 self.max_backoff = max_backoff;
194 self
195 }
196
197 #[must_use]
198 pub fn auth_source(&self) -> &AuthSource {
199 &self.auth
200 }
201
202 pub async fn send_message(
203 &self,
204 request: &MessageRequest,
205 ) -> Result<MessageResponse, ApiError> {
206 let request = MessageRequest {
207 stream: false,
208 ..request.clone()
209 };
210 let response = self.send_with_retry(&request).await?;
211 let request_id = request_id_from_headers(response.headers());
212 let mut response = response
213 .json::<MessageResponse>()
214 .await
215 .map_err(ApiError::from)?;
216 if response.request_id.is_none() {
217 response.request_id = request_id;
218 }
219 Ok(response)
220 }
221
222 pub async fn stream_message(
223 &self,
224 request: &MessageRequest,
225 ) -> Result<MessageStream, ApiError> {
226 let response = self
227 .send_with_retry(&request.clone().with_streaming())
228 .await?;
229 Ok(MessageStream {
230 request_id: request_id_from_headers(response.headers()),
231 response,
232 parser: SseParser::new(),
233 pending: VecDeque::new(),
234 done: false,
235 })
236 }
237
238 pub async fn exchange_oauth_code(
239 &self,
240 config: &OAuthConfig,
241 request: &OAuthTokenExchangeRequest,
242 ) -> Result<OAuthTokenSet, ApiError> {
243 let response = self
244 .http
245 .post(&config.token_url)
246 .header("content-type", "application/x-www-form-urlencoded")
247 .form(&request.form_params())
248 .send()
249 .await
250 .map_err(ApiError::from)?;
251 let response = expect_success(response).await?;
252 response
253 .json::<OAuthTokenSet>()
254 .await
255 .map_err(ApiError::from)
256 }
257
258 pub async fn refresh_oauth_token(
259 &self,
260 config: &OAuthConfig,
261 request: &OAuthRefreshRequest,
262 ) -> Result<OAuthTokenSet, ApiError> {
263 let response = self
264 .http
265 .post(&config.token_url)
266 .header("content-type", "application/x-www-form-urlencoded")
267 .form(&request.form_params())
268 .send()
269 .await
270 .map_err(ApiError::from)?;
271 let response = expect_success(response).await?;
272 response
273 .json::<OAuthTokenSet>()
274 .await
275 .map_err(ApiError::from)
276 }
277
278 async fn send_with_retry(
279 &self,
280 request: &MessageRequest,
281 ) -> Result<reqwest::Response, ApiError> {
282 let mut attempts = 0;
283 let mut last_error: Option<ApiError>;
284
285 loop {
286 attempts += 1;
287 match self.send_raw_request(request).await {
288 Ok(response) => match expect_success(response).await {
289 Ok(response) => return Ok(response),
290 Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
291 last_error = Some(error);
292 }
293 Err(error) => return Err(error),
294 },
295 Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => {
296 last_error = Some(error);
297 }
298 Err(error) => return Err(error),
299 }
300
301 if attempts > self.max_retries {
302 break;
303 }
304
305 tokio::time::sleep(self.backoff_for_attempt(attempts)?).await;
306 }
307
308 Err(ApiError::RetriesExhausted {
309 attempts,
310 last_error: Box::new(last_error.expect("retry loop must capture an error")),
311 })
312 }
313
314 async fn send_raw_request(
315 &self,
316 request: &MessageRequest,
317 ) -> Result<reqwest::Response, ApiError> {
318 let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/'));
319 let request_builder = self
320 .http
321 .post(&request_url)
322 .header("anthropic-version", ANTHROPIC_VERSION)
323 .header("content-type", "application/json");
324 let mut request_builder = self.auth.apply(request_builder);
325
326 request_builder = request_builder.json(request);
327 request_builder.send().await.map_err(ApiError::from)
328 }
329
330 fn backoff_for_attempt(&self, attempt: u32) -> Result<Duration, ApiError> {
331 let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else {
332 return Err(ApiError::BackoffOverflow {
333 attempt,
334 base_delay: self.initial_backoff,
335 });
336 };
337 Ok(self
338 .initial_backoff
339 .checked_mul(multiplier)
340 .map_or(self.max_backoff, |delay| delay.min(self.max_backoff)))
341 }
342}
343
344impl AuthSource {
345 pub fn from_env_or_saved() -> Result<Self, ApiError> {
346 if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
347 return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
348 Some(bearer_token) => Ok(Self::ApiKeyAndBearer {
349 api_key,
350 bearer_token,
351 }),
352 None => Ok(Self::ApiKey(api_key)),
353 };
354 }
355 if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
356 return Ok(Self::BearerToken(bearer_token));
357 }
358 match load_saved_oauth_token() {
359 Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => {
360 if token_set.refresh_token.is_some() {
361 Err(ApiError::Auth(
362 "saved OAuth token is expired; load runtime OAuth config to refresh it"
363 .to_string(),
364 ))
365 } else {
366 Err(ApiError::ExpiredOAuthToken)
367 }
368 }
369 Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)),
370 Ok(None) => Err(ApiError::missing_credentials(
371 "Wraith",
372 &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
373 )),
374 Err(error) => Err(error),
375 }
376 }
377}
378
379#[must_use]
380pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool {
381 token_set
382 .expires_at
383 .is_some_and(|expires_at| expires_at <= now_unix_timestamp())
384}
385
386pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result<Option<OAuthTokenSet>, ApiError> {
387 let Some(token_set) = load_saved_oauth_token()? else {
388 return Ok(None);
389 };
390 resolve_saved_oauth_token_set(config, token_set).map(Some)
391}
392
393pub fn has_auth_from_env_or_saved() -> Result<bool, ApiError> {
394 Ok(read_env_non_empty("ANTHROPIC_API_KEY")?.is_some()
395 || read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?.is_some()
396 || load_saved_oauth_token()?.is_some())
397}
398
399pub fn resolve_startup_auth_source<F>(load_oauth_config: F) -> Result<AuthSource, ApiError>
400where
401 F: FnOnce() -> Result<Option<OAuthConfig>, ApiError>,
402{
403 if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? {
404 return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
405 Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer {
406 api_key,
407 bearer_token,
408 }),
409 None => Ok(AuthSource::ApiKey(api_key)),
410 };
411 }
412 if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? {
413 return Ok(AuthSource::BearerToken(bearer_token));
414 }
415
416 let Some(token_set) = load_saved_oauth_token()? else {
417 return Err(ApiError::missing_credentials(
418 "Wraith",
419 &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
420 ));
421 };
422 if !oauth_token_is_expired(&token_set) {
423 return Ok(AuthSource::BearerToken(token_set.access_token));
424 }
425 if token_set.refresh_token.is_none() {
426 return Err(ApiError::ExpiredOAuthToken);
427 }
428
429 let Some(config) = load_oauth_config()? else {
430 return Err(ApiError::Auth(
431 "saved OAuth token is expired; runtime OAuth config is missing".to_string(),
432 ));
433 };
434 Ok(AuthSource::from(resolve_saved_oauth_token_set(
435 &config, token_set,
436 )?))
437}
438
439fn resolve_saved_oauth_token_set(
440 config: &OAuthConfig,
441 token_set: OAuthTokenSet,
442) -> Result<OAuthTokenSet, ApiError> {
443 if !oauth_token_is_expired(&token_set) {
444 return Ok(token_set);
445 }
446 let Some(refresh_token) = token_set.refresh_token.clone() else {
447 return Err(ApiError::ExpiredOAuthToken);
448 };
449 let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(read_base_url());
450 let refreshed = client_runtime_block_on(async {
451 client
452 .refresh_oauth_token(
453 config,
454 &OAuthRefreshRequest::from_config(
455 config,
456 refresh_token,
457 Some(token_set.scopes.clone()),
458 ),
459 )
460 .await
461 })?;
462 let resolved = OAuthTokenSet {
463 access_token: refreshed.access_token,
464 refresh_token: refreshed.refresh_token.or(token_set.refresh_token),
465 expires_at: refreshed.expires_at,
466 scopes: refreshed.scopes,
467 };
468 save_oauth_credentials(&runtime::OAuthTokenSet {
469 access_token: resolved.access_token.clone(),
470 refresh_token: resolved.refresh_token.clone(),
471 expires_at: resolved.expires_at,
472 scopes: resolved.scopes.clone(),
473 })
474 .map_err(ApiError::from)?;
475 Ok(resolved)
476}
477
478fn client_runtime_block_on<F, T>(future: F) -> Result<T, ApiError>
479where
480 F: std::future::Future<Output = Result<T, ApiError>>,
481{
482 tokio::runtime::Runtime::new()
483 .map_err(ApiError::from)?
484 .block_on(future)
485}
486
487fn load_saved_oauth_token() -> Result<Option<OAuthTokenSet>, ApiError> {
488 let token_set = load_oauth_credentials().map_err(ApiError::from)?;
489 Ok(token_set.map(|token_set| OAuthTokenSet {
490 access_token: token_set.access_token,
491 refresh_token: token_set.refresh_token,
492 expires_at: token_set.expires_at,
493 scopes: token_set.scopes,
494 }))
495}
496
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// A clock set before 1970 yields `0` rather than an error.
fn now_unix_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
502
503fn read_env_non_empty(key: &str) -> Result<Option<String>, ApiError> {
504 match std::env::var(key) {
505 Ok(value) if !value.is_empty() => Ok(Some(value)),
506 Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None),
507 Err(error) => Err(ApiError::from(error)),
508 }
509}
510
/// Test helper: returns the API key, or else the bearer token, from
/// `AuthSource::from_env_or_saved`.
///
/// The fallback error is built lazily via `ok_or_else`, not on every call.
#[cfg(test)]
fn read_api_key() -> Result<String, ApiError> {
    let auth = AuthSource::from_env_or_saved()?;
    auth.api_key()
        .or_else(|| auth.bearer_token())
        .map(ToOwned::to_owned)
        .ok_or_else(|| {
            ApiError::missing_credentials(
                "Wraith",
                &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"],
            )
        })
}
522
/// Test helper: returns a non-empty `ANTHROPIC_AUTH_TOKEN`, swallowing read errors.
#[cfg(test)]
fn read_auth_token() -> Option<String> {
    read_env_non_empty("ANTHROPIC_AUTH_TOKEN").ok().flatten()
}
529
530#[must_use]
531pub fn read_base_url() -> String {
532 std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string())
533}
534
535fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option<String> {
536 headers
537 .get(REQUEST_ID_HEADER)
538 .or_else(|| headers.get(ALT_REQUEST_ID_HEADER))
539 .and_then(|value| value.to_str().ok())
540 .map(ToOwned::to_owned)
541}
542
/// Exposes [`AnthropicClient`] through the generic `Provider` trait.
///
/// Both methods box the inherent async methods of the same name.
impl Provider for AnthropicClient {
    type Stream = MessageStream;

    // `self.send_message` resolves to the inherent method (inherent methods take
    // precedence over trait methods), so this delegates rather than recursing.
    fn send_message<'a>(
        &'a self,
        request: &'a MessageRequest,
    ) -> ProviderFuture<'a, MessageResponse> {
        Box::pin(async move { self.send_message(request).await })
    }

    // Likewise delegates to the inherent `stream_message`.
    fn stream_message<'a>(
        &'a self,
        request: &'a MessageRequest,
    ) -> ProviderFuture<'a, Self::Stream> {
        Box::pin(async move { self.stream_message(request).await })
    }
}
560
561#[derive(Debug)]
562pub struct MessageStream {
563 request_id: Option<String>,
564 response: reqwest::Response,
565 parser: SseParser,
566 pending: VecDeque<StreamEvent>,
567 done: bool,
568}
569
570impl MessageStream {
571 #[must_use]
572 pub fn request_id(&self) -> Option<&str> {
573 self.request_id.as_deref()
574 }
575
576 pub async fn next_event(&mut self) -> Result<Option<StreamEvent>, ApiError> {
577 loop {
578 if let Some(event) = self.pending.pop_front() {
579 return Ok(Some(event));
580 }
581
582 if self.done {
583 let remaining = self.parser.finish()?;
584 self.pending.extend(remaining);
585 if let Some(event) = self.pending.pop_front() {
586 return Ok(Some(event));
587 }
588 return Ok(None);
589 }
590
591 match self.response.chunk().await? {
592 Some(chunk) => {
593 self.pending.extend(self.parser.push(&chunk)?);
594 }
595 None => {
596 self.done = true;
597 }
598 }
599 }
600 }
601}
602
603async fn expect_success(response: reqwest::Response) -> Result<reqwest::Response, ApiError> {
604 let status = response.status();
605 if status.is_success() {
606 return Ok(response);
607 }
608
609 let body = response.text().await.unwrap_or_else(|_| String::new());
610 let parsed_error = serde_json::from_str::<ApiErrorEnvelope>(&body).ok();
611 let retryable = is_retryable_status(status);
612
613 Err(ApiError::Api {
614 status,
615 error_type: parsed_error
616 .as_ref()
617 .map(|error| error.error.error_type.clone()),
618 message: parsed_error
619 .as_ref()
620 .map(|error| error.error.message.clone()),
621 body,
622 retryable,
623 })
624}
625
626const fn is_retryable_status(status: reqwest::StatusCode) -> bool {
627 matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504)
628}
629
630#[derive(Debug, Deserialize)]
631struct ApiErrorEnvelope {
632 error: ApiErrorBody,
633}
634
635#[derive(Debug, Deserialize)]
636struct ApiErrorBody {
637 #[serde(rename = "type")]
638 error_type: String,
639 message: String,
640}
641
#[cfg(test)]
mod tests {
    use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER};
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::sync::{Mutex, OnceLock};
    use std::thread;
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig};

    use super::{
        now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token,
        resolve_startup_auth_source, AuthSource, AnthropicClient, OAuthTokenSet,
    };
    use crate::types::{ContentBlockDelta, MessageRequest};

    /// Serializes tests that mutate process-wide environment variables.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns a unique (not yet created) temp directory path for a config home.
    fn temp_config_home() -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "api-oauth-test-{}-{}",
            std::process::id(),
            SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("time")
                .as_nanos()
        ))
    }

    /// Points `WRAITH_CONFIG_HOME` at a fresh, empty temp directory so credential
    /// lookups cannot see saved OAuth tokens from the developer's real config home.
    fn isolate_config_home() -> std::path::PathBuf {
        let config_home = temp_config_home();
        std::env::set_var("WRAITH_CONFIG_HOME", &config_home);
        config_home
    }

    /// Removes a temp config home, tolerating that it was never created.
    fn cleanup_temp_config_home(config_home: &std::path::Path) {
        match std::fs::remove_dir_all(config_home) {
            Ok(()) => {}
            Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
            Err(error) => panic!("cleanup temp dir: {error}"),
        }
    }

    fn sample_oauth_config(token_url: String) -> OAuthConfig {
        OAuthConfig {
            client_id: "runtime-client".to_string(),
            authorize_url: "https://console.test/oauth/authorize".to_string(),
            token_url,
            callback_port: Some(4545),
            manual_redirect_url: Some("https://console.test/oauth/callback".to_string()),
            scopes: vec!["org:read".to_string(), "user:write".to_string()],
        }
    }

    /// Serves exactly one HTTP request with a fixed 200 JSON body; returns its token URL.
    fn spawn_token_server(response_body: &'static str) -> String {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener");
        let address = listener.local_addr().expect("local addr");
        thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept connection");
            let mut buffer = [0_u8; 4096];
            let _ = stream.read(&mut buffer).expect("read request");
            let response = format!(
                "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}",
                response_body.len(),
                response_body
            );
            stream
                .write_all(response.as_bytes())
                .expect("write response");
        });
        format!("http://{address}/oauth/token")
    }

    #[test]
    fn read_api_key_requires_presence() {
        let _guard = env_lock();
        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
        std::env::remove_var("ANTHROPIC_API_KEY");
        let config_home = isolate_config_home();
        let error = super::read_api_key().expect_err("missing key should error");
        assert!(matches!(
            error,
            crate::error::ApiError::MissingCredentials { .. }
        ));
        std::env::remove_var("WRAITH_CONFIG_HOME");
        cleanup_temp_config_home(&config_home);
    }

    #[test]
    fn read_api_key_requires_non_empty_value() {
        let _guard = env_lock();
        std::env::set_var("ANTHROPIC_AUTH_TOKEN", "");
        std::env::remove_var("ANTHROPIC_API_KEY");
        let config_home = isolate_config_home();
        let error = super::read_api_key().expect_err("empty key should error");
        assert!(matches!(
            error,
            crate::error::ApiError::MissingCredentials { .. }
        ));
        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
        std::env::remove_var("WRAITH_CONFIG_HOME");
        cleanup_temp_config_home(&config_home);
    }

    #[test]
    fn read_api_key_prefers_api_key_env() {
        let _guard = env_lock();
        std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
        std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
        assert_eq!(
            super::read_api_key().expect("api key should load"),
            "legacy-key"
        );
        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
        std::env::remove_var("ANTHROPIC_API_KEY");
    }

    #[test]
    fn read_auth_token_reads_auth_token_env() {
        let _guard = env_lock();
        std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
        assert_eq!(super::read_auth_token().as_deref(), Some("auth-token"));
        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
    }

    #[test]
    fn oauth_token_maps_to_bearer_auth_source() {
        let auth = AuthSource::from(OAuthTokenSet {
            access_token: "access-token".to_string(),
            refresh_token: Some("refresh".to_string()),
            expires_at: Some(123),
            scopes: vec!["scope:a".to_string()],
        });
        assert_eq!(auth.bearer_token(), Some("access-token"));
        assert_eq!(auth.api_key(), None);
    }

    #[test]
    fn auth_source_from_env_combines_api_key_and_bearer_token() {
        let _guard = env_lock();
        std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token");
        std::env::set_var("ANTHROPIC_API_KEY", "legacy-key");
        let auth = AuthSource::from_env().expect("env auth");
        assert_eq!(auth.api_key(), Some("legacy-key"));
        assert_eq!(auth.bearer_token(), Some("auth-token"));
        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
        std::env::remove_var("ANTHROPIC_API_KEY");
    }

    #[test]
    fn auth_source_from_saved_oauth_when_env_absent() {
        let _guard = env_lock();
        let config_home = temp_config_home();
        std::env::set_var("WRAITH_CONFIG_HOME", &config_home);
        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
        std::env::remove_var("ANTHROPIC_API_KEY");
        save_oauth_credentials(&runtime::OAuthTokenSet {
            access_token: "saved-access-token".to_string(),
            refresh_token: Some("refresh".to_string()),
            expires_at: Some(now_unix_timestamp() + 300),
            scopes: vec!["scope:a".to_string()],
        })
        .expect("save oauth credentials");

        let auth = AuthSource::from_env_or_saved().expect("saved auth");
        assert_eq!(auth.bearer_token(), Some("saved-access-token"));

        clear_oauth_credentials().expect("clear credentials");
        std::env::remove_var("WRAITH_CONFIG_HOME");
        cleanup_temp_config_home(&config_home);
    }

    #[test]
    fn oauth_token_expiry_uses_expires_at_timestamp() {
        assert!(oauth_token_is_expired(&OAuthTokenSet {
            access_token: "access-token".to_string(),
            refresh_token: None,
            expires_at: Some(1),
            scopes: Vec::new(),
        }));
        assert!(!oauth_token_is_expired(&OAuthTokenSet {
            access_token: "access-token".to_string(),
            refresh_token: None,
            expires_at: Some(now_unix_timestamp() + 60),
            scopes: Vec::new(),
        }));
    }

    #[test]
    fn resolve_saved_oauth_token_refreshes_expired_credentials() {
        let _guard = env_lock();
        let config_home = temp_config_home();
        std::env::set_var("WRAITH_CONFIG_HOME", &config_home);
        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
        std::env::remove_var("ANTHROPIC_API_KEY");
        save_oauth_credentials(&runtime::OAuthTokenSet {
            access_token: "expired-access-token".to_string(),
            refresh_token: Some("refresh-token".to_string()),
            expires_at: Some(1),
            scopes: vec!["scope:a".to_string()],
        })
        .expect("save expired oauth credentials");

        let token_url = spawn_token_server(
            "{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
        );
        let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
            .expect("resolve refreshed token")
            .expect("token set present");
        assert_eq!(resolved.access_token, "refreshed-token");
        let stored = runtime::load_oauth_credentials()
            .expect("load stored credentials")
            .expect("stored token set");
        assert_eq!(stored.access_token, "refreshed-token");

        clear_oauth_credentials().expect("clear credentials");
        std::env::remove_var("WRAITH_CONFIG_HOME");
        cleanup_temp_config_home(&config_home);
    }

    #[test]
    fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() {
        let _guard = env_lock();
        let config_home = temp_config_home();
        std::env::set_var("WRAITH_CONFIG_HOME", &config_home);
        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
        std::env::remove_var("ANTHROPIC_API_KEY");
        save_oauth_credentials(&runtime::OAuthTokenSet {
            access_token: "saved-access-token".to_string(),
            refresh_token: Some("refresh".to_string()),
            expires_at: Some(now_unix_timestamp() + 300),
            scopes: vec!["scope:a".to_string()],
        })
        .expect("save oauth credentials");

        let auth = resolve_startup_auth_source(|| panic!("config should not be loaded"))
            .expect("startup auth");
        assert_eq!(auth.bearer_token(), Some("saved-access-token"));

        clear_oauth_credentials().expect("clear credentials");
        std::env::remove_var("WRAITH_CONFIG_HOME");
        cleanup_temp_config_home(&config_home);
    }

    #[test]
    fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() {
        let _guard = env_lock();
        let config_home = temp_config_home();
        std::env::set_var("WRAITH_CONFIG_HOME", &config_home);
        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
        std::env::remove_var("ANTHROPIC_API_KEY");
        save_oauth_credentials(&runtime::OAuthTokenSet {
            access_token: "expired-access-token".to_string(),
            refresh_token: Some("refresh-token".to_string()),
            expires_at: Some(1),
            scopes: vec!["scope:a".to_string()],
        })
        .expect("save expired oauth credentials");

        let error =
            resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error");
        assert!(
            matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing"))
        );

        let stored = runtime::load_oauth_credentials()
            .expect("load stored credentials")
            .expect("stored token set");
        assert_eq!(stored.access_token, "expired-access-token");
        assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));

        clear_oauth_credentials().expect("clear credentials");
        std::env::remove_var("WRAITH_CONFIG_HOME");
        cleanup_temp_config_home(&config_home);
    }

    #[test]
    fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() {
        let _guard = env_lock();
        let config_home = temp_config_home();
        std::env::set_var("WRAITH_CONFIG_HOME", &config_home);
        std::env::remove_var("ANTHROPIC_AUTH_TOKEN");
        std::env::remove_var("ANTHROPIC_API_KEY");
        save_oauth_credentials(&runtime::OAuthTokenSet {
            access_token: "expired-access-token".to_string(),
            refresh_token: Some("refresh-token".to_string()),
            expires_at: Some(1),
            scopes: vec!["scope:a".to_string()],
        })
        .expect("save expired oauth credentials");

        let token_url = spawn_token_server(
            "{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}",
        );
        let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url))
            .expect("resolve refreshed token")
            .expect("token set present");
        assert_eq!(resolved.access_token, "refreshed-token");
        assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token"));
        let stored = runtime::load_oauth_credentials()
            .expect("load stored credentials")
            .expect("stored token set");
        assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));

        clear_oauth_credentials().expect("clear credentials");
        std::env::remove_var("WRAITH_CONFIG_HOME");
        cleanup_temp_config_home(&config_home);
    }

    #[test]
    fn message_request_stream_helper_sets_stream_true() {
        let request = MessageRequest {
            model: "claude-opus-4-6".to_string(),
            max_tokens: 64,
            messages: vec![],
            system: None,
            tools: None,
            tool_choice: None,
            stream: false,
        };

        assert!(request.with_streaming().stream);
    }

    #[test]
    fn backoff_doubles_until_maximum() {
        let client = AnthropicClient::new("test-key").with_retry_policy(
            3,
            Duration::from_millis(10),
            Duration::from_millis(25),
        );
        assert_eq!(
            client.backoff_for_attempt(1).expect("attempt 1"),
            Duration::from_millis(10)
        );
        assert_eq!(
            client.backoff_for_attempt(2).expect("attempt 2"),
            Duration::from_millis(20)
        );
        assert_eq!(
            client.backoff_for_attempt(3).expect("attempt 3"),
            Duration::from_millis(25)
        );
    }

    #[test]
    fn retryable_statuses_are_detected() {
        assert!(super::is_retryable_status(
            reqwest::StatusCode::TOO_MANY_REQUESTS
        ));
        assert!(super::is_retryable_status(
            reqwest::StatusCode::INTERNAL_SERVER_ERROR
        ));
        assert!(!super::is_retryable_status(
            reqwest::StatusCode::UNAUTHORIZED
        ));
    }

    #[test]
    fn tool_delta_variant_round_trips() {
        let delta = ContentBlockDelta::InputJsonDelta {
            partial_json: "{\"city\":\"Paris\"}".to_string(),
        };
        let encoded = serde_json::to_string(&delta).expect("delta should serialize");
        let decoded: ContentBlockDelta =
            serde_json::from_str(&encoded).expect("delta should deserialize");
        assert_eq!(decoded, delta);
    }

    #[test]
    fn request_id_uses_primary_or_fallback_header() {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header"));
        assert_eq!(
            super::request_id_from_headers(&headers).as_deref(),
            Some("req_primary")
        );

        headers.clear();
        headers.insert(
            ALT_REQUEST_ID_HEADER,
            "req_fallback".parse().expect("header"),
        );
        assert_eq!(
            super::request_id_from_headers(&headers).as_deref(),
            Some("req_fallback")
        );
    }

    #[test]
    fn auth_source_applies_headers() {
        let auth = AuthSource::ApiKeyAndBearer {
            api_key: "test-key".to_string(),
            bearer_token: "proxy-token".to_string(),
        };
        let request = auth
            .apply(reqwest::Client::new().post("https://example.test"))
            .build()
            .expect("request build");
        let headers = request.headers();
        assert_eq!(
            headers.get("x-api-key").and_then(|v| v.to_str().ok()),
            Some("test-key")
        );
        assert_eq!(
            headers.get("authorization").and_then(|v| v.to_str().ok()),
            Some("Bearer proxy-token")
        );
    }
}