Skip to main content

uv_auth/
cache.rs

1use std::fmt::Display;
2use std::fmt::Formatter;
3use std::hash::BuildHasherDefault;
4use std::sync::Arc;
5use std::sync::RwLock;
6
7use rustc_hash::{FxHashMap, FxHasher};
8use tracing::trace;
9use url::Url;
10
11use uv_once_map::OnceMap;
12use uv_redacted::DisplaySafeUrl;
13
14use crate::credentials::{Authentication, Username};
15use crate::{Credentials, Realm};
16
/// A [`OnceMap`] whose hasher is built from [`FxHasher`] via [`BuildHasherDefault`].
type FxOnceMap<K, V> = OnceMap<K, V, BuildHasherDefault<FxHasher>>;
18
/// The target of a credential fetch from an external service.
///
/// Paired with a [`Username`], this forms the key of [`CredentialsCache::fetches`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) enum FetchUrl {
    /// A full index URL
    Index(DisplaySafeUrl),
    /// A realm URL
    Realm(Realm),
}
26
27impl Display for FetchUrl {
28    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
29        match self {
30            Self::Index(index) => Display::fmt(index, f),
31            Self::Realm(realm) => Display::fmt(realm, f),
32        }
33    }
34}
35
/// An in-memory credentials cache with three indexes: by realm and username, by URL prefix,
/// and by the outcome of fetches from external credential services.
///
/// `realms` and `urls` each sit behind their own `RwLock`, which lets the cache be updated
/// through `&self`.
#[derive(Debug)] // All internal types are redacted.
pub struct CredentialsCache {
    /// A cache per realm and username
    realms: RwLock<FxHashMap<(Realm, Username), Arc<Authentication>>>,
    /// A cache tracking the result of realm or index URL fetches from external services.
    /// NOTE(review): a `None` value presumably records a fetch that found nothing — confirm
    /// at the call sites that populate this map.
    pub(crate) fetches: FxOnceMap<(FetchUrl, Username), Option<Arc<Authentication>>>,
    /// A cache per URL, uses a trie for efficient prefix queries.
    urls: RwLock<UrlTrie<Arc<Authentication>>>,
}
45
46impl Default for CredentialsCache {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl CredentialsCache {
53    /// Create a new cache.
54    pub fn new() -> Self {
55        Self {
56            fetches: FxOnceMap::default(),
57            realms: RwLock::new(FxHashMap::default()),
58            urls: RwLock::new(UrlTrie::new()),
59        }
60    }
61
62    /// Populate the global authentication store with credentials on a URL, if there are any.
63    ///
64    /// Returns `true` if the store was updated.
65    pub fn store_credentials_from_url(&self, url: &DisplaySafeUrl) -> bool {
66        if let Some(credentials) = Credentials::from_url(url) {
67            trace!("Caching credentials for {url}");
68            self.insert(url, Arc::new(Authentication::from(credentials)));
69            true
70        } else {
71            false
72        }
73    }
74
75    /// Populate the global authentication store with credentials on a URL, if there are any.
76    ///
77    /// Returns `true` if the store was updated.
78    pub fn store_credentials(&self, url: &DisplaySafeUrl, credentials: Credentials) {
79        trace!("Caching credentials for {url}");
80        self.insert(url, Arc::new(Authentication::from(credentials)));
81    }
82
83    /// Return the credentials that should be used for a realm and username, if any.
84    pub(crate) fn get_realm(
85        &self,
86        realm: Realm,
87        username: Username,
88    ) -> Option<Arc<Authentication>> {
89        let realms = self.realms.read().unwrap();
90        let given_username = username.is_some();
91        let key = (realm, username);
92
93        let Some(credentials) = realms.get(&key).cloned() else {
94            trace!(
95                "No credentials in cache for realm {}",
96                RealmUsername::from(key)
97            );
98            return None;
99        };
100
101        if given_username && credentials.password().is_none() {
102            // If given a username, don't return password-less credentials
103            trace!(
104                "No password in cache for realm {}",
105                RealmUsername::from(key)
106            );
107            return None;
108        }
109
110        trace!(
111            "Found cached credentials for realm {}",
112            RealmUsername::from(key)
113        );
114        Some(credentials)
115    }
116
117    /// Return the cached credentials for a URL and username, if any.
118    ///
119    /// Note we do not cache per username, but if a username is passed we will confirm that the
120    /// cached credentials have a username equal to the provided one — otherwise `None` is returned.
121    /// If multiple usernames are used per URL, the realm cache should be queried instead.
122    pub(crate) fn get_url(
123        &self,
124        url: &DisplaySafeUrl,
125        username: &Username,
126    ) -> Option<Arc<Authentication>> {
127        let urls = self.urls.read().unwrap();
128        let credentials = urls.get(url);
129        if let Some(credentials) = credentials {
130            if username.is_none() || username.as_deref() == credentials.username() {
131                if username.is_some() && credentials.password().is_none() {
132                    // If given a username, don't return password-less credentials
133                    trace!("No password in cache for URL {url}");
134                    return None;
135                }
136                trace!("Found cached credentials for URL {url}");
137                return Some(credentials.clone());
138            }
139        }
140        trace!("No credentials in cache for URL {url}");
141        None
142    }
143
144    /// Update the cache with the given credentials.
145    pub(crate) fn insert(&self, url: &DisplaySafeUrl, credentials: Arc<Authentication>) {
146        // Do not cache empty credentials
147        if credentials.is_empty() {
148            return;
149        }
150
151        // Insert an entry for requests including the username
152        let username = credentials.to_username();
153        if username.is_some() {
154            let realm = (Realm::from(url), username);
155            self.insert_realm(realm, &credentials);
156        }
157
158        // Insert an entry for requests with no username
159        self.insert_realm((Realm::from(url), Username::none()), &credentials);
160
161        // Insert an entry for the URL
162        let mut urls = self.urls.write().unwrap();
163        urls.insert(url, credentials);
164    }
165
166    /// Private interface to update a realm cache entry.
167    ///
168    /// Returns replaced credentials, if any.
169    fn insert_realm(
170        &self,
171        key: (Realm, Username),
172        credentials: &Arc<Authentication>,
173    ) -> Option<Arc<Authentication>> {
174        // Do not cache empty credentials
175        if credentials.is_empty() {
176            return None;
177        }
178
179        let mut realms = self.realms.write().unwrap();
180
181        // Always replace existing entries if we have a password or token
182        if credentials.is_authenticated() {
183            return realms.insert(key, credentials.clone());
184        }
185
186        // If we only have a username, add a new entry or replace an existing entry if it doesn't have a password
187        let existing = realms.get(&key);
188        if existing.is_none()
189            || existing.is_some_and(|credentials| credentials.password().is_none())
190        {
191            return realms.insert(key, credentials.clone());
192        }
193
194        None
195    }
196}
197
/// A prefix trie keyed by URL components: the URL's realm first, then each non-empty path
/// segment.
///
/// Nodes live in a flat arena (`states`). Index `0` is the root, and each node's children
/// refer to other nodes by their index in `states`.
#[derive(Debug)]
struct UrlTrie<T> {
    states: Vec<TrieState<T>>,
}
202
/// A single node of a [`UrlTrie`].
#[derive(Debug)]
struct TrieState<T> {
    /// Outgoing edges as `(component, child index)` pairs, kept sorted by component so they
    /// can be binary-searched.
    children: Vec<(String, usize)>,
    /// The value stored at this node, if a URL whose components end here was inserted.
    value: Option<T>,
}
208
209impl<T> Default for TrieState<T> {
210    fn default() -> Self {
211        Self {
212            children: vec![],
213            value: None,
214        }
215    }
216}
217
218impl<T> UrlTrie<T> {
219    fn new() -> Self {
220        let mut trie = Self { states: vec![] };
221        trie.alloc();
222        trie
223    }
224
225    fn get(&self, url: &Url) -> Option<&T> {
226        let mut state = 0;
227        let realm = Realm::from(url).to_string();
228        for component in [realm.as_str()]
229            .into_iter()
230            .chain(url.path_segments().unwrap().filter(|item| !item.is_empty()))
231        {
232            state = self.states[state].get(component)?;
233            if let Some(ref value) = self.states[state].value {
234                return Some(value);
235            }
236        }
237        self.states[state].value.as_ref()
238    }
239
240    fn insert(&mut self, url: &Url, value: T) {
241        let mut state = 0;
242        let realm = Realm::from(url).to_string();
243        for component in [realm.as_str()]
244            .into_iter()
245            .chain(url.path_segments().unwrap().filter(|item| !item.is_empty()))
246        {
247            match self.states[state].index(component) {
248                Ok(i) => state = self.states[state].children[i].1,
249                Err(i) => {
250                    let new_state = self.alloc();
251                    self.states[state]
252                        .children
253                        .insert(i, (component.to_string(), new_state));
254                    state = new_state;
255                }
256            }
257        }
258        self.states[state].value = Some(value);
259    }
260
261    fn alloc(&mut self) -> usize {
262        let id = self.states.len();
263        self.states.push(TrieState::default());
264        id
265    }
266}
267
268impl<T> TrieState<T> {
269    fn get(&self, component: &str) -> Option<usize> {
270        let i = self.index(component).ok()?;
271        Some(self.children[i].1)
272    }
273
274    fn index(&self, component: &str) -> Result<usize, usize> {
275        self.children
276            .binary_search_by(|(label, _)| label.as_str().cmp(component))
277    }
278}
279
/// A realm paired with a (possibly absent) username, used to render realm cache keys in
/// trace messages.
#[derive(Debug)]
struct RealmUsername(Realm, Username);
282
283impl std::fmt::Display for RealmUsername {
284    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
285        let Self(realm, username) = self;
286        if let Some(username) = username.as_deref() {
287            write!(f, "{username}@{realm}")
288        } else {
289            write!(f, "{realm}")
290        }
291    }
292}
293
294impl From<(Realm, Username)> for RealmUsername {
295    fn from((realm, username): (Realm, Username)) -> Self {
296        Self(realm, username)
297    }
298}
299
#[cfg(test)]
mod tests {
    use crate::Credentials;
    use crate::credentials::Password;

    use super::*;

    #[test]
    fn test_trie() {
        let credentials1 =
            Credentials::basic(Some("username1".to_string()), Some("password1".to_string()));
        let credentials2 =
            Credentials::basic(Some("username2".to_string()), Some("password2".to_string()));
        let credentials3 =
            Credentials::basic(Some("username3".to_string()), Some("password3".to_string()));
        let credentials4 =
            Credentials::basic(Some("username4".to_string()), Some("password4".to_string()));

        let mut trie = UrlTrie::new();
        trie.insert(
            &Url::parse("https://burntsushi.net").unwrap(),
            credentials1.clone(),
        );
        trie.insert(
            &Url::parse("https://astral.sh").unwrap(),
            credentials2.clone(),
        );
        trie.insert(
            &Url::parse("https://example.com/foo").unwrap(),
            credentials3.clone(),
        );
        trie.insert(
            &Url::parse("https://example.com/bar").unwrap(),
            credentials4.clone(),
        );

        let url = Url::parse("https://burntsushi.net/regex-internals").unwrap();
        assert_eq!(trie.get(&url), Some(&credentials1));

        let url = Url::parse("https://burntsushi.net/").unwrap();
        assert_eq!(trie.get(&url), Some(&credentials1));

        let url = Url::parse("https://astral.sh/about").unwrap();
        assert_eq!(trie.get(&url), Some(&credentials2));

        let url = Url::parse("https://example.com/foo").unwrap();
        assert_eq!(trie.get(&url), Some(&credentials3));

        let url = Url::parse("https://example.com/foo/").unwrap();
        assert_eq!(trie.get(&url), Some(&credentials3));

        let url = Url::parse("https://example.com/foo/bar").unwrap();
        assert_eq!(trie.get(&url), Some(&credentials3));

        let url = Url::parse("https://example.com/bar").unwrap();
        assert_eq!(trie.get(&url), Some(&credentials4));

        let url = Url::parse("https://example.com/bar/").unwrap();
        assert_eq!(trie.get(&url), Some(&credentials4));

        let url = Url::parse("https://example.com/bar/foo").unwrap();
        assert_eq!(trie.get(&url), Some(&credentials4));

        let url = Url::parse("https://example.com/about").unwrap();
        assert_eq!(trie.get(&url), None);

        let url = Url::parse("https://example.com/foobar").unwrap();
        assert_eq!(trie.get(&url), None);
    }

    #[test]
    fn test_url_with_credentials() {
        let username = Username::new(Some(String::from("username")));
        let password = Password::new(String::from("password"));
        let credentials = Arc::new(Authentication::from(Credentials::Basic {
            username: username.clone(),
            password: Some(password),
        }));
        let cache = CredentialsCache::default();

        // Insert with a URL that embeds credentials; lookups must succeed both with that URL
        // and with the same URL without credentials.
        let url = DisplaySafeUrl::parse("https://username:password@example.com/foobar").unwrap();
        let plain_url = DisplaySafeUrl::parse("https://example.com/foobar").unwrap();
        cache.insert(&url, credentials.clone());
        assert_eq!(cache.get_url(&url, &username), Some(credentials.clone()));
        assert_eq!(
            cache.get_url(&plain_url, &username),
            Some(credentials.clone())
        );

        // Insert with a URL without credentials; lookups must succeed both with that URL and
        // with the same URL carrying credentials.
        let url =
            DisplaySafeUrl::parse("https://username:password@second-example.com/foobar").unwrap();
        let plain_url = DisplaySafeUrl::parse("https://second-example.com/foobar").unwrap();
        cache.insert(&plain_url, credentials.clone());
        assert_eq!(
            cache.get_url(&plain_url, &username),
            Some(credentials.clone())
        );
        assert_eq!(cache.get_url(&url, &username), Some(credentials.clone()));
    }
}