shunt/credential.rs
//! Credential abstraction — supports OAuth (with refresh) and static API keys.
//!
//! All provider-specific auth is gated behind the [`Credential`] enum so the
//! rest of the codebase stays agnostic to the credential type.
5
6use serde::{Deserialize, Serialize};
7
8use crate::oauth::OAuthCredential;
9
10// ---------------------------------------------------------------------------
11// Credential enum
12// ---------------------------------------------------------------------------
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15#[serde(tag = "type", rename_all = "lowercase")]
16pub enum Credential {
17 /// OAuth credential with access + refresh tokens and an expiry.
18 /// Used by Anthropic (claude.ai) and OpenAI (chatgpt.com) accounts.
19 Oauth(OAuthCredential),
20 /// Static API key — no expiry, no refresh.
21 /// Used by Groq, Mistral, OpenRouter, Gemini, Ollama Cloud, etc.
22 Apikey { key: String },
23}
24
25impl Credential {
26 /// The bearer token to send in `Authorization: Bearer <token>`.
27 ///
28 /// For OAuth accounts: prefers `id_token` over `access_token` when
29 /// present (required by chatgpt.com / Codex). Falls back to
30 /// `access_token` for standard Anthropic OAuth.
31 ///
32 /// For API-key accounts: returns the raw key directly.
33 pub fn bearer_token(&self) -> &str {
34 match self {
35 Credential::Oauth(c) => c.id_token.as_deref().unwrap_or(&c.access_token),
36 Credential::Apikey { key } => key,
37 }
38 }
39
40 /// The raw `access_token` string.
41 ///
42 /// Used when you need the access_token specifically (e.g. token-rotation
43 /// comparison in the 401 handler, Anthropic auth headers).
44 ///
45 /// For API-key accounts returns the key (same as `bearer_token`).
46 pub fn access_token(&self) -> &str {
47 match self {
48 Credential::Oauth(c) => &c.access_token,
49 Credential::Apikey { key } => key,
50 }
51 }
52
53 /// True if the credential should be refreshed before use.
54 /// Always false for API-key credentials.
55 pub fn needs_refresh(&self) -> bool {
56 match self {
57 Credential::Oauth(c) => c.needs_refresh(),
58 Credential::Apikey { .. } => false,
59 }
60 }
61
62 /// Account email, if known. None for API-key credentials.
63 pub fn email(&self) -> Option<&str> {
64 match self {
65 Credential::Oauth(c) => c.email.as_deref(),
66 Credential::Apikey { .. } => None,
67 }
68 }
69
70 /// True when a refresh_token is available to attempt recovery.
71 /// Always false for API-key credentials.
72 pub fn has_refresh_token(&self) -> bool {
73 match self {
74 Credential::Oauth(c) => !c.refresh_token.is_empty(),
75 Credential::Apikey { .. } => false,
76 }
77 }
78
79 /// Borrow the inner OAuthCredential, if this is an OAuth credential.
80 pub fn as_oauth(&self) -> Option<&OAuthCredential> {
81 match self {
82 Credential::Oauth(c) => Some(c),
83 Credential::Apikey { .. } => None,
84 }
85 }
86
87 /// Mutably borrow the inner OAuthCredential.
88 pub fn as_oauth_mut(&mut self) -> Option<&mut OAuthCredential> {
89 match self {
90 Credential::Oauth(c) => Some(c),
91 Credential::Apikey { .. } => None,
92 }
93 }
94
95 /// Display string for status/monitor output.
96 /// Shows email for OAuth accounts, masked key for API-key accounts.
97 pub fn masked_display(&self) -> String {
98 match self {
99 Credential::Oauth(c) => c.email.clone().unwrap_or_else(|| "oauth".to_owned()),
100 Credential::Apikey { key } => {
101 let suffix = &key[key.len().saturating_sub(4)..];
102 format!("···{suffix}")
103 }
104 }
105 }
106}
107
108impl From<OAuthCredential> for Credential {
109 fn from(c: OAuthCredential) -> Self {
110 Credential::Oauth(c)
111 }
112}
113
114// ---------------------------------------------------------------------------
115// Backwards-compatible deserialization for CredentialsStore
116// ---------------------------------------------------------------------------
117
118/// Deserialize a `HashMap<String, Credential>` that may contain old-format
119/// entries (written before the `"type"` tag was introduced).
120///
121/// Old format: `{ "access_token": "...", "refresh_token": "...", ... }`
122/// New format: `{ "type": "oauth", "access_token": "...", ... }`
123/// `{ "type": "apikey", "key": "..." }`
124pub fn deserialize_credential_map<'de, D>(
125 deserializer: D,
126) -> Result<std::collections::HashMap<String, Credential>, D::Error>
127where
128 D: serde::Deserializer<'de>,
129{
130 use std::collections::HashMap;
131 let raw: HashMap<String, serde_json::Value> = HashMap::deserialize(deserializer)?;
132 let mut out = HashMap::with_capacity(raw.len());
133 for (k, v) in raw {
134 let cred = if v.get("type").is_some() {
135 // New tagged format — deserialize directly.
136 serde_json::from_value::<Credential>(v).map_err(serde::de::Error::custom)?
137 } else {
138 // Legacy format — treat as OAuth.
139 serde_json::from_value::<OAuthCredential>(v)
140 .map(Credential::Oauth)
141 .map_err(serde::de::Error::custom)?
142 };
143 out.insert(k, cred);
144 }
145 Ok(out)
146}