// uv_auth/cache.rs

1use std::fmt::Display;
2use std::fmt::Formatter;
3use std::hash::BuildHasherDefault;
4use std::sync::Arc;
5use std::sync::RwLock;
6
7use rustc_hash::{FxHashMap, FxHasher};
8use tracing::trace;
9use url::Url;
10
11use uv_once_map::OnceMap;
12use uv_redacted::DisplaySafeUrl;
13
14use crate::credentials::{Authentication, CredentialsFromUrlError, Username};
15use crate::{Credentials, Realm};
16
/// A [`OnceMap`] whose keys are hashed with [`FxHasher`].
type FxOnceMap<K, V> = OnceMap<K, V, BuildHasherDefault<FxHasher>>;
18
/// The target of an external credential fetch, used (together with a username) as the key of
/// [`CredentialsCache::fetches`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) enum FetchUrl {
    /// A full index URL
    Index(DisplaySafeUrl),
    /// A realm URL
    Realm(Realm),
}
26
27impl Display for FetchUrl {
28    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
29        match self {
30            Self::Index(index) => Display::fmt(index, f),
31            Self::Realm(realm) => Display::fmt(realm, f),
32        }
33    }
34}
35
/// In-memory cache of credentials, indexed by realm (plus username) and by URL prefix.
///
/// The realm and URL indexes sit behind [`RwLock`]s, so entries are added through `&self`.
#[derive(Debug)] // All internal types are redacted.
pub struct CredentialsCache {
    /// A cache per realm and username
    realms: RwLock<FxHashMap<(Realm, Username), Arc<Authentication>>>,
    /// A cache tracking the result of realm or index URL fetches from external services
    // NOTE(review): a `None` value presumably records a fetch that found no credentials — confirm.
    pub(crate) fetches: FxOnceMap<(FetchUrl, Username), Option<Arc<Authentication>>>,
    /// A cache per URL, uses a trie for efficient prefix queries.
    urls: RwLock<UrlTrie<Arc<Authentication>>>,
}
45
46impl Default for CredentialsCache {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl CredentialsCache {
53    /// Create a new cache.
54    pub fn new() -> Self {
55        Self {
56            fetches: FxOnceMap::default(),
57            realms: RwLock::new(FxHashMap::default()),
58            urls: RwLock::new(UrlTrie::new()),
59        }
60    }
61
62    /// Populate the global authentication store with credentials on a URL, if there are any.
63    ///
64    /// Returns `true` if the store was updated.
65    pub fn store_credentials_from_url(
66        &self,
67        url: &DisplaySafeUrl,
68    ) -> Result<bool, CredentialsFromUrlError> {
69        if let Some(credentials) = Credentials::from_url(url)? {
70            trace!("Caching credentials for {url}");
71            self.insert(url, Arc::new(Authentication::from(credentials)));
72            Ok(true)
73        } else {
74            Ok(false)
75        }
76    }
77
78    /// Populate the global authentication store with credentials on a URL, if there are any.
79    ///
80    /// Returns `true` if the store was updated.
81    pub fn store_credentials(&self, url: &DisplaySafeUrl, credentials: Credentials) {
82        trace!("Caching credentials for {url}");
83        self.insert(url, Arc::new(Authentication::from(credentials)));
84    }
85
86    /// Return the credentials that should be used for a realm and username, if any.
87    pub(crate) fn get_realm(
88        &self,
89        realm: Realm,
90        username: Username,
91    ) -> Option<Arc<Authentication>> {
92        let realms = self.realms.read().unwrap();
93        let given_username = username.is_some();
94        let key = (realm, username);
95
96        let Some(credentials) = realms.get(&key).cloned() else {
97            trace!(
98                "No credentials in cache for realm {}",
99                RealmUsername::from(key)
100            );
101            return None;
102        };
103
104        if given_username && credentials.password().is_none() {
105            // If given a username, don't return password-less credentials
106            trace!(
107                "No password in cache for realm {}",
108                RealmUsername::from(key)
109            );
110            return None;
111        }
112
113        trace!(
114            "Found cached credentials for realm {}",
115            RealmUsername::from(key)
116        );
117        Some(credentials)
118    }
119
120    /// Return the cached credentials for a URL and username, if any.
121    ///
122    /// Note we do not cache per username, but if a username is passed we will confirm that the
123    /// cached credentials have a username equal to the provided one — otherwise `None` is returned.
124    /// If multiple usernames are used per URL, the realm cache should be queried instead.
125    pub(crate) fn get_url(
126        &self,
127        url: &DisplaySafeUrl,
128        username: &Username,
129    ) -> Option<Arc<Authentication>> {
130        let urls = self.urls.read().unwrap();
131        let credentials = urls.get(url);
132        if let Some(credentials) = credentials {
133            if username.is_none() || username.as_deref() == credentials.username() {
134                if username.is_some() && credentials.password().is_none() {
135                    // If given a username, don't return password-less credentials
136                    trace!("No password in cache for URL {url}");
137                    return None;
138                }
139                trace!("Found cached credentials for URL {url}");
140                return Some(credentials.clone());
141            }
142        }
143        trace!("No credentials in cache for URL {url}");
144        None
145    }
146
147    /// Update the cache with the given credentials.
148    pub(crate) fn insert(&self, url: &DisplaySafeUrl, credentials: Arc<Authentication>) {
149        // Do not cache empty credentials
150        if credentials.is_empty() {
151            return;
152        }
153
154        // Insert an entry for requests including the username
155        let username = credentials.to_username();
156        if username.is_some() {
157            let realm = (Realm::from(url), username);
158            self.insert_realm(realm, &credentials);
159        }
160
161        // Insert an entry for requests with no username
162        self.insert_realm((Realm::from(url), Username::none()), &credentials);
163
164        // Insert an entry for the URL
165        let mut urls = self.urls.write().unwrap();
166        urls.insert(url, credentials);
167    }
168
169    /// Private interface to update a realm cache entry.
170    ///
171    /// Returns replaced credentials, if any.
172    fn insert_realm(
173        &self,
174        key: (Realm, Username),
175        credentials: &Arc<Authentication>,
176    ) -> Option<Arc<Authentication>> {
177        // Do not cache empty credentials
178        if credentials.is_empty() {
179            return None;
180        }
181
182        let mut realms = self.realms.write().unwrap();
183
184        // Always replace existing entries if we have a password or token
185        if credentials.is_authenticated() {
186            return realms.insert(key, credentials.clone());
187        }
188
189        // If we only have a username, add a new entry or replace an existing entry if it doesn't have a password
190        let existing = realms.get(&key);
191        if existing.is_none_or(|credentials| credentials.password().is_none()) {
192            return realms.insert(key, credentials.clone());
193        }
194
195        None
196    }
197}
198
/// A trie mapping URL prefixes to values.
///
/// States live in a flat arena and refer to each other by index; index `0` is the root. A URL's
/// key is its realm (as a string) followed by its non-empty path segments.
#[derive(Debug)]
struct UrlTrie<T> {
    /// Arena of trie states, addressed by index.
    states: Vec<TrieState<T>>,
}
203
/// A single node of a [`UrlTrie`].
#[derive(Debug)]
struct TrieState<T> {
    /// Outgoing edges as `(component, state index)` pairs, kept sorted by component so they can
    /// be binary-searched (new edges are inserted at their search position).
    children: Vec<(String, usize)>,
    /// The value stored at this node, if a URL whose key ends here was inserted.
    value: Option<T>,
}
209
210impl<T> Default for TrieState<T> {
211    fn default() -> Self {
212        Self {
213            children: vec![],
214            value: None,
215        }
216    }
217}
218
219impl<T> UrlTrie<T> {
220    fn new() -> Self {
221        let mut trie = Self { states: vec![] };
222        trie.alloc();
223        trie
224    }
225
226    fn get(&self, url: &Url) -> Option<&T> {
227        let mut state = 0;
228        let realm = Realm::from(url).to_string();
229        for component in [realm.as_str()]
230            .into_iter()
231            .chain(url.path_segments().unwrap().filter(|item| !item.is_empty()))
232        {
233            state = self.states[state].get(component)?;
234            if let Some(ref value) = self.states[state].value {
235                return Some(value);
236            }
237        }
238        self.states[state].value.as_ref()
239    }
240
241    fn insert(&mut self, url: &Url, value: T) {
242        let mut state = 0;
243        let realm = Realm::from(url).to_string();
244        for component in [realm.as_str()]
245            .into_iter()
246            .chain(url.path_segments().unwrap().filter(|item| !item.is_empty()))
247        {
248            match self.states[state].index(component) {
249                Ok(i) => state = self.states[state].children[i].1,
250                Err(i) => {
251                    let new_state = self.alloc();
252                    self.states[state]
253                        .children
254                        .insert(i, (component.to_string(), new_state));
255                    state = new_state;
256                }
257            }
258        }
259        self.states[state].value = Some(value);
260    }
261
262    fn alloc(&mut self) -> usize {
263        let id = self.states.len();
264        self.states.push(TrieState::default());
265        id
266    }
267}
268
269impl<T> TrieState<T> {
270    fn get(&self, component: &str) -> Option<usize> {
271        let i = self.index(component).ok()?;
272        Some(self.children[i].1)
273    }
274
275    fn index(&self, component: &str) -> Result<usize, usize> {
276        self.children
277            .binary_search_by(|(label, _)| label.as_str().cmp(component))
278    }
279}
280
/// Display adapter for a `(Realm, Username)` realm-cache key, used in trace messages.
#[derive(Debug)]
struct RealmUsername(Realm, Username);
283
284impl std::fmt::Display for RealmUsername {
285    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
286        let Self(realm, username) = self;
287        if let Some(username) = username.as_deref() {
288            write!(f, "{username}@{realm}")
289        } else {
290            write!(f, "{realm}")
291        }
292    }
293}
294
295impl From<(Realm, Username)> for RealmUsername {
296    fn from((realm, username): (Realm, Username)) -> Self {
297        Self(realm, username)
298    }
299}
300
#[cfg(test)]
mod tests {
    use crate::Credentials;
    use crate::credentials::Password;

    use super::*;

    /// Prefix lookups in [`UrlTrie`]: a value stored for a realm covers every path under it, and
    /// a value stored for a path covers that path and its sub-paths (with or without a trailing
    /// slash), but not sibling paths or paths that merely share a string prefix (`/foobar` is
    /// not under `/foo`).
    #[test]
    fn test_trie() {
        let credentials1 =
            Credentials::basic(Some("username1".to_string()), Some("password1".to_string()));
        let credentials2 =
            Credentials::basic(Some("username2".to_string()), Some("password2".to_string()));
        let credentials3 =
            Credentials::basic(Some("username3".to_string()), Some("password3".to_string()));
        let credentials4 =
            Credentials::basic(Some("username4".to_string()), Some("password4".to_string()));

        let mut trie = UrlTrie::new();
        trie.insert(
            &Url::parse("https://burntsushi.net").unwrap(),
            credentials1.clone(),
        );
        trie.insert(
            &Url::parse("https://astral.sh").unwrap(),
            credentials2.clone(),
        );
        trie.insert(
            &Url::parse("https://example.com/foo").unwrap(),
            credentials3.clone(),
        );
        trie.insert(
            &Url::parse("https://example.com/bar").unwrap(),
            credentials4.clone(),
        );

        let url = Url::parse("https://burntsushi.net/regex-internals").unwrap();
        assert_eq!(trie.get(&url), Some(&credentials1));

        let url = Url::parse("https://burntsushi.net/").unwrap();
        assert_eq!(trie.get(&url), Some(&credentials1));

        let url = Url::parse("https://astral.sh/about").unwrap();
        assert_eq!(trie.get(&url), Some(&credentials2));

        let url = Url::parse("https://example.com/foo").unwrap();
        assert_eq!(trie.get(&url), Some(&credentials3));

        let url = Url::parse("https://example.com/foo/").unwrap();
        assert_eq!(trie.get(&url), Some(&credentials3));

        let url = Url::parse("https://example.com/foo/bar").unwrap();
        assert_eq!(trie.get(&url), Some(&credentials3));

        let url = Url::parse("https://example.com/bar").unwrap();
        assert_eq!(trie.get(&url), Some(&credentials4));

        let url = Url::parse("https://example.com/bar/").unwrap();
        assert_eq!(trie.get(&url), Some(&credentials4));

        let url = Url::parse("https://example.com/bar/foo").unwrap();
        assert_eq!(trie.get(&url), Some(&credentials4));

        let url = Url::parse("https://example.com/about").unwrap();
        assert_eq!(trie.get(&url), None);

        let url = Url::parse("https://example.com/foobar").unwrap();
        assert_eq!(trie.get(&url), None);
    }

    /// Credentials inserted for a URL are returned by `get_url` for that URL and username.
    // NOTE(review): both cases below insert and look up with the same `DisplaySafeUrl`, which
    // still carries `username:password`; neither lookup uses a credential-free URL, despite what
    // the per-case comments say. Confirm the intended inputs.
    #[test]
    fn test_url_with_credentials() {
        let username = Username::new(Some(String::from("username")));
        let password = Password::new(String::from("password"));
        let credentials = Arc::new(Authentication::from(Credentials::Basic {
            username: username.clone(),
            password: Some(password),
        }));
        let cache = CredentialsCache::default();
        // Insert with URL with credentials and get with redacted URL.
        let url = DisplaySafeUrl::parse("https://username:password@example.com/foobar").unwrap();
        cache.insert(&url, credentials.clone());
        assert_eq!(cache.get_url(&url, &username), Some(credentials.clone()));
        // Insert with redacted URL and get with URL with credentials.
        let url =
            DisplaySafeUrl::parse("https://username:password@second-example.com/foobar").unwrap();
        cache.insert(&url, credentials.clone());
        assert_eq!(cache.get_url(&url, &username), Some(credentials.clone()));
    }
}