1use std::fmt::Display;
2use std::fmt::Formatter;
3use std::hash::BuildHasherDefault;
4use std::sync::Arc;
5use std::sync::RwLock;
6
7use rustc_hash::{FxHashMap, FxHasher};
8use tracing::trace;
9use url::Url;
10
11use uv_once_map::OnceMap;
12use uv_redacted::DisplaySafeUrl;
13
14use crate::credentials::{Authentication, CredentialsFromUrlError, Username};
15use crate::{Credentials, Realm};
16
/// A [`OnceMap`] whose keys are hashed with `FxHasher` (from `rustc_hash`) instead of the
/// standard library's default hasher.
type FxOnceMap<K, V> = OnceMap<K, V, BuildHasherDefault<FxHasher>>;
18
/// The target of a credential fetch.
///
/// Combined with a [`Username`], this forms the key of `CredentialsCache::fetches`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) enum FetchUrl {
    /// A fetch keyed by an index URL.
    Index(DisplaySafeUrl),
    /// A fetch keyed by a realm.
    Realm(Realm),
}
26
27impl Display for FetchUrl {
28 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
29 match self {
30 Self::Index(index) => Display::fmt(index, f),
31 Self::Realm(realm) => Display::fmt(realm, f),
32 }
33 }
34}
35
/// In-memory cache of credentials.
///
/// Lookups can go by `(realm, username)` or by URL prefix. A separate once-map holds the
/// results of credential fetches.
#[derive(Debug)]
pub struct CredentialsCache {
    /// Credentials keyed by `(realm, username)`.
    ///
    /// `insert` writes a username-less entry for each realm it sees. It also writes a
    /// username-keyed entry when the credentials carry a username.
    realms: RwLock<FxHashMap<(Realm, Username), Arc<Authentication>>>,
    /// Results of credential fetches, keyed by `(FetchUrl, Username)`.
    ///
    /// NOTE(review): a `None` value presumably records a fetch that found no credentials.
    /// TODO confirm against callers.
    pub(crate) fetches: FxOnceMap<(FetchUrl, Username), Option<Arc<Authentication>>>,
    /// Credentials keyed by URL prefix: the realm followed by non-empty path segments.
    urls: RwLock<UrlTrie<Arc<Authentication>>>,
}
45
46impl Default for CredentialsCache {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl CredentialsCache {
53 pub fn new() -> Self {
55 Self {
56 fetches: FxOnceMap::default(),
57 realms: RwLock::new(FxHashMap::default()),
58 urls: RwLock::new(UrlTrie::new()),
59 }
60 }
61
62 pub fn store_credentials_from_url(
66 &self,
67 url: &DisplaySafeUrl,
68 ) -> Result<bool, CredentialsFromUrlError> {
69 if let Some(credentials) = Credentials::from_url(url)? {
70 trace!("Caching credentials for {url}");
71 self.insert(url, Arc::new(Authentication::from(credentials)));
72 Ok(true)
73 } else {
74 Ok(false)
75 }
76 }
77
78 pub fn store_credentials(&self, url: &DisplaySafeUrl, credentials: Credentials) {
82 trace!("Caching credentials for {url}");
83 self.insert(url, Arc::new(Authentication::from(credentials)));
84 }
85
86 pub(crate) fn get_realm(
88 &self,
89 realm: Realm,
90 username: Username,
91 ) -> Option<Arc<Authentication>> {
92 let realms = self.realms.read().unwrap();
93 let given_username = username.is_some();
94 let key = (realm, username);
95
96 let Some(credentials) = realms.get(&key).cloned() else {
97 trace!(
98 "No credentials in cache for realm {}",
99 RealmUsername::from(key)
100 );
101 return None;
102 };
103
104 if given_username && credentials.password().is_none() {
105 trace!(
107 "No password in cache for realm {}",
108 RealmUsername::from(key)
109 );
110 return None;
111 }
112
113 trace!(
114 "Found cached credentials for realm {}",
115 RealmUsername::from(key)
116 );
117 Some(credentials)
118 }
119
120 pub(crate) fn get_url(
126 &self,
127 url: &DisplaySafeUrl,
128 username: &Username,
129 ) -> Option<Arc<Authentication>> {
130 let urls = self.urls.read().unwrap();
131 let credentials = urls.get(url);
132 if let Some(credentials) = credentials {
133 if username.is_none() || username.as_deref() == credentials.username() {
134 if username.is_some() && credentials.password().is_none() {
135 trace!("No password in cache for URL {url}");
137 return None;
138 }
139 trace!("Found cached credentials for URL {url}");
140 return Some(credentials.clone());
141 }
142 }
143 trace!("No credentials in cache for URL {url}");
144 None
145 }
146
147 pub(crate) fn insert(&self, url: &DisplaySafeUrl, credentials: Arc<Authentication>) {
149 if credentials.is_empty() {
151 return;
152 }
153
154 let username = credentials.to_username();
156 if username.is_some() {
157 let realm = (Realm::from(url), username);
158 self.insert_realm(realm, &credentials);
159 }
160
161 self.insert_realm((Realm::from(url), Username::none()), &credentials);
163
164 let mut urls = self.urls.write().unwrap();
166 urls.insert(url, credentials);
167 }
168
169 fn insert_realm(
173 &self,
174 key: (Realm, Username),
175 credentials: &Arc<Authentication>,
176 ) -> Option<Arc<Authentication>> {
177 if credentials.is_empty() {
179 return None;
180 }
181
182 let mut realms = self.realms.write().unwrap();
183
184 if credentials.is_authenticated() {
186 return realms.insert(key, credentials.clone());
187 }
188
189 let existing = realms.get(&key);
191 if existing.is_none_or(|credentials| credentials.password().is_none()) {
192 return realms.insert(key, credentials.clone());
193 }
194
195 None
196 }
197}
198
/// A prefix trie keyed by a URL's realm followed by its non-empty path segments.
///
/// States live in a flat arena. Index `0` is the root, which `UrlTrie::new` allocates.
#[derive(Debug)]
struct UrlTrie<T> {
    /// Arena of trie states, addressed by index.
    states: Vec<TrieState<T>>,
}
203
/// A single node of a [`UrlTrie`].
#[derive(Debug)]
struct TrieState<T> {
    /// Outgoing edges as `(component, child state index)`.
    ///
    /// Kept sorted by component so they can be binary-searched.
    children: Vec<(String, usize)>,
    /// The value stored for the key sequence ending at this state, if any.
    value: Option<T>,
}
209
210impl<T> Default for TrieState<T> {
211 fn default() -> Self {
212 Self {
213 children: vec![],
214 value: None,
215 }
216 }
217}
218
219impl<T> UrlTrie<T> {
220 fn new() -> Self {
221 let mut trie = Self { states: vec![] };
222 trie.alloc();
223 trie
224 }
225
226 fn get(&self, url: &Url) -> Option<&T> {
227 let mut state = 0;
228 let realm = Realm::from(url).to_string();
229 for component in [realm.as_str()]
230 .into_iter()
231 .chain(url.path_segments().unwrap().filter(|item| !item.is_empty()))
232 {
233 state = self.states[state].get(component)?;
234 if let Some(ref value) = self.states[state].value {
235 return Some(value);
236 }
237 }
238 self.states[state].value.as_ref()
239 }
240
241 fn insert(&mut self, url: &Url, value: T) {
242 let mut state = 0;
243 let realm = Realm::from(url).to_string();
244 for component in [realm.as_str()]
245 .into_iter()
246 .chain(url.path_segments().unwrap().filter(|item| !item.is_empty()))
247 {
248 match self.states[state].index(component) {
249 Ok(i) => state = self.states[state].children[i].1,
250 Err(i) => {
251 let new_state = self.alloc();
252 self.states[state]
253 .children
254 .insert(i, (component.to_string(), new_state));
255 state = new_state;
256 }
257 }
258 }
259 self.states[state].value = Some(value);
260 }
261
262 fn alloc(&mut self) -> usize {
263 let id = self.states.len();
264 self.states.push(TrieState::default());
265 id
266 }
267}
268
269impl<T> TrieState<T> {
270 fn get(&self, component: &str) -> Option<usize> {
271 let i = self.index(component).ok()?;
272 Some(self.children[i].1)
273 }
274
275 fn index(&self, component: &str) -> Result<usize, usize> {
276 self.children
277 .binary_search_by(|(label, _)| label.as_str().cmp(component))
278 }
279}
280
/// A `(realm, username)` pair used for trace logging.
///
/// It displays as `username@realm`, or as just `realm` when there is no username.
#[derive(Debug)]
struct RealmUsername(Realm, Username);
283
284impl std::fmt::Display for RealmUsername {
285 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
286 let Self(realm, username) = self;
287 if let Some(username) = username.as_deref() {
288 write!(f, "{username}@{realm}")
289 } else {
290 write!(f, "{realm}")
291 }
292 }
293}
294
295impl From<(Realm, Username)> for RealmUsername {
296 fn from((realm, username): (Realm, Username)) -> Self {
297 Self(realm, username)
298 }
299}
300
#[cfg(test)]
mod tests {
    use crate::Credentials;
    use crate::credentials::Password;

    use super::*;

    /// Basic credentials of the form `username{n}` / `password{n}`.
    fn numbered(n: u8) -> Credentials {
        Credentials::basic(Some(format!("username{n}")), Some(format!("password{n}")))
    }

    #[test]
    fn test_trie() {
        let (credentials1, credentials2, credentials3, credentials4) =
            (numbered(1), numbered(2), numbered(3), numbered(4));

        let mut trie = UrlTrie::new();
        for (url, credentials) in [
            ("https://burntsushi.net", &credentials1),
            ("https://astral.sh", &credentials2),
            ("https://example.com/foo", &credentials3),
            ("https://example.com/bar", &credentials4),
        ] {
            trie.insert(&Url::parse(url).unwrap(), credentials.clone());
        }

        // Lookups match whole path segments and return the entry for the first stored
        // prefix of the URL.
        let expectations = [
            ("https://burntsushi.net/regex-internals", Some(&credentials1)),
            ("https://burntsushi.net/", Some(&credentials1)),
            ("https://astral.sh/about", Some(&credentials2)),
            ("https://example.com/foo", Some(&credentials3)),
            ("https://example.com/foo/", Some(&credentials3)),
            ("https://example.com/foo/bar", Some(&credentials3)),
            ("https://example.com/bar", Some(&credentials4)),
            ("https://example.com/bar/", Some(&credentials4)),
            ("https://example.com/bar/foo", Some(&credentials4)),
            ("https://example.com/about", None),
            ("https://example.com/foobar", None),
        ];
        for (url, expected) in expectations {
            assert_eq!(
                trie.get(&Url::parse(url).unwrap()),
                expected,
                "lookup of {url}"
            );
        }
    }

    #[test]
    fn test_url_with_credentials() {
        let username = Username::new(Some(String::from("username")));
        let credentials = Arc::new(Authentication::from(Credentials::Basic {
            username: username.clone(),
            password: Some(Password::new(String::from("password"))),
        }));
        let cache = CredentialsCache::default();

        // Each URL embeds the same credentials. After insertion, the URL cache must return
        // them for the matching username.
        for raw in [
            "https://username:password@example.com/foobar",
            "https://username:password@second-example.com/foobar",
        ] {
            let url = DisplaySafeUrl::parse(raw).unwrap();
            cache.insert(&url, credentials.clone());
            assert_eq!(cache.get_url(&url, &username), Some(credentials.clone()));
        }
    }
}