1use std::fmt::Display;
2use std::fmt::Formatter;
3use std::hash::BuildHasherDefault;
4use std::sync::Arc;
5use std::sync::RwLock;
6
7use rustc_hash::{FxHashMap, FxHasher};
8use tracing::trace;
9use url::Url;
10
11use uv_once_map::OnceMap;
12use uv_redacted::DisplaySafeUrl;
13
14use crate::credentials::{Authentication, Username};
15use crate::{Credentials, Realm};
16
/// A [`OnceMap`] that uses [`FxHasher`] (via [`BuildHasherDefault`]) to hash its keys.
type FxOnceMap<K, V> = OnceMap<K, V, BuildHasherDefault<FxHasher>>;
18
/// The target of a credential fetch; used as part of the key of `CredentialsCache::fetches`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) enum FetchUrl {
    /// A fetch identified by a URL.
    ///
    /// NOTE(review): presumably a package index URL, going by the variant name — confirm.
    Index(DisplaySafeUrl),
    /// A fetch identified by a [`Realm`].
    Realm(Realm),
}
26
27impl Display for FetchUrl {
28 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
29 match self {
30 Self::Index(index) => Display::fmt(index, f),
31 Self::Realm(realm) => Display::fmt(realm, f),
32 }
33 }
34}
35
/// A cache of [`Authentication`] values that can be looked up by realm and username or by URL
/// prefix.
#[derive(Debug)]
pub struct CredentialsCache {
    /// Credentials keyed by realm and username. Entries stored under a "none" username are the
    /// ones returned when a lookup does not specify a username.
    realms: RwLock<FxHashMap<(Realm, Username), Arc<Authentication>>>,
    /// Results keyed by fetch target and username.
    ///
    /// NOTE(review): presumably used to de-duplicate concurrent credential fetches — confirm
    /// against the callers; nothing in this file reads or writes it.
    pub(crate) fetches: FxOnceMap<(FetchUrl, Username), Option<Arc<Authentication>>>,
    /// Credentials keyed by URL prefix (the realm followed by the path segments).
    urls: RwLock<UrlTrie<Arc<Authentication>>>,
}
45
46impl Default for CredentialsCache {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl CredentialsCache {
53 pub fn new() -> Self {
55 Self {
56 fetches: FxOnceMap::default(),
57 realms: RwLock::new(FxHashMap::default()),
58 urls: RwLock::new(UrlTrie::new()),
59 }
60 }
61
62 pub fn store_credentials_from_url(&self, url: &DisplaySafeUrl) -> bool {
66 if let Some(credentials) = Credentials::from_url(url) {
67 trace!("Caching credentials for {url}");
68 self.insert(url, Arc::new(Authentication::from(credentials)));
69 true
70 } else {
71 false
72 }
73 }
74
75 pub fn store_credentials(&self, url: &DisplaySafeUrl, credentials: Credentials) {
79 trace!("Caching credentials for {url}");
80 self.insert(url, Arc::new(Authentication::from(credentials)));
81 }
82
83 pub(crate) fn get_realm(
85 &self,
86 realm: Realm,
87 username: Username,
88 ) -> Option<Arc<Authentication>> {
89 let realms = self.realms.read().unwrap();
90 let given_username = username.is_some();
91 let key = (realm, username);
92
93 let Some(credentials) = realms.get(&key).cloned() else {
94 trace!(
95 "No credentials in cache for realm {}",
96 RealmUsername::from(key)
97 );
98 return None;
99 };
100
101 if given_username && credentials.password().is_none() {
102 trace!(
104 "No password in cache for realm {}",
105 RealmUsername::from(key)
106 );
107 return None;
108 }
109
110 trace!(
111 "Found cached credentials for realm {}",
112 RealmUsername::from(key)
113 );
114 Some(credentials)
115 }
116
117 pub(crate) fn get_url(
123 &self,
124 url: &DisplaySafeUrl,
125 username: &Username,
126 ) -> Option<Arc<Authentication>> {
127 let urls = self.urls.read().unwrap();
128 let credentials = urls.get(url);
129 if let Some(credentials) = credentials {
130 if username.is_none() || username.as_deref() == credentials.username() {
131 if username.is_some() && credentials.password().is_none() {
132 trace!("No password in cache for URL {url}");
134 return None;
135 }
136 trace!("Found cached credentials for URL {url}");
137 return Some(credentials.clone());
138 }
139 }
140 trace!("No credentials in cache for URL {url}");
141 None
142 }
143
144 pub(crate) fn insert(&self, url: &DisplaySafeUrl, credentials: Arc<Authentication>) {
146 if credentials.is_empty() {
148 return;
149 }
150
151 let username = credentials.to_username();
153 if username.is_some() {
154 let realm = (Realm::from(url), username);
155 self.insert_realm(realm, &credentials);
156 }
157
158 self.insert_realm((Realm::from(url), Username::none()), &credentials);
160
161 let mut urls = self.urls.write().unwrap();
163 urls.insert(url, credentials);
164 }
165
166 fn insert_realm(
170 &self,
171 key: (Realm, Username),
172 credentials: &Arc<Authentication>,
173 ) -> Option<Arc<Authentication>> {
174 if credentials.is_empty() {
176 return None;
177 }
178
179 let mut realms = self.realms.write().unwrap();
180
181 if credentials.is_authenticated() {
183 return realms.insert(key, credentials.clone());
184 }
185
186 let existing = realms.get(&key);
188 if existing.is_none()
189 || existing.is_some_and(|credentials| credentials.password().is_none())
190 {
191 return realms.insert(key, credentials.clone());
192 }
193
194 None
195 }
196}
197
/// A trie keyed by URL components (the realm first, then each non-empty path segment) that can
/// hold a value of type `T` at every node.
#[derive(Debug)]
struct UrlTrie<T> {
    /// All nodes of the trie. Nodes refer to one another by index into this vector; lookups and
    /// insertions start at index 0.
    states: Vec<TrieState<T>>,
}
202
/// A single node of the URL trie.
#[derive(Debug)]
struct TrieState<T> {
    /// Outgoing edges as `(label, index of the child node)` pairs, kept sorted by label so that
    /// they can be binary-searched.
    children: Vec<(String, usize)>,
    /// The value stored at this node, if any.
    value: Option<T>,
}

// Written by hand because `#[derive(Default)]` would require `T: Default`.
impl<T> Default for TrieState<T> {
    /// Create a node with no children and no value.
    fn default() -> Self {
        Self {
            value: None,
            children: Vec::new(),
        }
    }
}
217
218impl<T> UrlTrie<T> {
219 fn new() -> Self {
220 let mut trie = Self { states: vec![] };
221 trie.alloc();
222 trie
223 }
224
225 fn get(&self, url: &Url) -> Option<&T> {
226 let mut state = 0;
227 let realm = Realm::from(url).to_string();
228 for component in [realm.as_str()]
229 .into_iter()
230 .chain(url.path_segments().unwrap().filter(|item| !item.is_empty()))
231 {
232 state = self.states[state].get(component)?;
233 if let Some(ref value) = self.states[state].value {
234 return Some(value);
235 }
236 }
237 self.states[state].value.as_ref()
238 }
239
240 fn insert(&mut self, url: &Url, value: T) {
241 let mut state = 0;
242 let realm = Realm::from(url).to_string();
243 for component in [realm.as_str()]
244 .into_iter()
245 .chain(url.path_segments().unwrap().filter(|item| !item.is_empty()))
246 {
247 match self.states[state].index(component) {
248 Ok(i) => state = self.states[state].children[i].1,
249 Err(i) => {
250 let new_state = self.alloc();
251 self.states[state]
252 .children
253 .insert(i, (component.to_string(), new_state));
254 state = new_state;
255 }
256 }
257 }
258 self.states[state].value = Some(value);
259 }
260
261 fn alloc(&mut self) -> usize {
262 let id = self.states.len();
263 self.states.push(TrieState::default());
264 id
265 }
266}
267
268impl<T> TrieState<T> {
269 fn get(&self, component: &str) -> Option<usize> {
270 let i = self.index(component).ok()?;
271 Some(self.children[i].1)
272 }
273
274 fn index(&self, component: &str) -> Result<usize, usize> {
275 self.children
276 .binary_search_by(|(label, _)| label.as_str().cmp(component))
277 }
278}
279
/// A realm and username pair, used to render cache keys in log messages as `username@realm`
/// (or just `realm` when there is no username).
#[derive(Debug)]
struct RealmUsername(Realm, Username);
282
283impl std::fmt::Display for RealmUsername {
284 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
285 let Self(realm, username) = self;
286 if let Some(username) = username.as_deref() {
287 write!(f, "{username}@{realm}")
288 } else {
289 write!(f, "{realm}")
290 }
291 }
292}
293
294impl From<(Realm, Username)> for RealmUsername {
295 fn from((realm, username): (Realm, Username)) -> Self {
296 Self(realm, username)
297 }
298}
299
#[cfg(test)]
mod tests {
    use crate::Credentials;
    use crate::credentials::Password;

    use super::*;

    /// Build basic credentials that have both a username and a password.
    fn basic(username: &str, password: &str) -> Credentials {
        Credentials::basic(Some(username.to_string()), Some(password.to_string()))
    }

    /// Parse `url` and look it up in `trie`.
    fn lookup<'a>(trie: &'a UrlTrie<Credentials>, url: &str) -> Option<&'a Credentials> {
        trie.get(&Url::parse(url).unwrap())
    }

    /// Values stored for a URL are found for that URL and for everything below it, but not for
    /// unrelated paths or for paths that merely share a textual prefix.
    #[test]
    fn test_trie() {
        let credentials1 = basic("username1", "password1");
        let credentials2 = basic("username2", "password2");
        let credentials3 = basic("username3", "password3");
        let credentials4 = basic("username4", "password4");

        let mut trie = UrlTrie::new();
        for (url, credentials) in [
            ("https://burntsushi.net", &credentials1),
            ("https://astral.sh", &credentials2),
            ("https://example.com/foo", &credentials3),
            ("https://example.com/bar", &credentials4),
        ] {
            trie.insert(&Url::parse(url).unwrap(), credentials.clone());
        }

        // Realm-level entries match every path on that host.
        assert_eq!(
            lookup(&trie, "https://burntsushi.net/regex-internals"),
            Some(&credentials1)
        );
        assert_eq!(lookup(&trie, "https://burntsushi.net/"), Some(&credentials1));
        assert_eq!(lookup(&trie, "https://astral.sh/about"), Some(&credentials2));

        // Path-level entries match the path itself and everything below it.
        assert_eq!(lookup(&trie, "https://example.com/foo"), Some(&credentials3));
        assert_eq!(lookup(&trie, "https://example.com/foo/"), Some(&credentials3));
        assert_eq!(
            lookup(&trie, "https://example.com/foo/bar"),
            Some(&credentials3)
        );
        assert_eq!(lookup(&trie, "https://example.com/bar"), Some(&credentials4));
        assert_eq!(lookup(&trie, "https://example.com/bar/"), Some(&credentials4));
        assert_eq!(
            lookup(&trie, "https://example.com/bar/foo"),
            Some(&credentials4)
        );

        // Matching is per path segment, not per character.
        assert_eq!(lookup(&trie, "https://example.com/about"), None);
        assert_eq!(lookup(&trie, "https://example.com/foobar"), None);
    }

    /// Credentials inserted for a URL that itself embeds credentials can be read back by URL.
    #[test]
    fn test_url_with_credentials() {
        let username = Username::new(Some(String::from("username")));
        let password = Password::new(String::from("password"));
        let credentials = Arc::new(Authentication::from(Credentials::Basic {
            username: username.clone(),
            password: Some(password),
        }));
        let cache = CredentialsCache::default();

        for raw in [
            "https://username:password@example.com/foobar",
            "https://username:password@second-example.com/foobar",
        ] {
            let url = DisplaySafeUrl::parse(raw).unwrap();
            cache.insert(&url, credentials.clone());
            assert_eq!(cache.get_url(&url, &username), Some(credentials.clone()));
        }
    }
}