// watermelon_proto/headers/map.rs

1use alloc::{
2    collections::{BTreeMap, btree_map::Entry},
3    vec,
4    vec::Vec,
5};
6use core::{
7    fmt::{self, Debug},
8    mem,
9};
10
11use super::{HeaderName, HeaderValue};
12
13static EMPTY_HEADERS: OneOrMany = OneOrMany::Many(Vec::new());
14
15/// A set of NATS headers
16///
17/// [`HeaderMap`] is a multimap of [`HeaderName`].
18#[derive(Clone, PartialEq, Eq)]
19pub struct HeaderMap {
20    headers: BTreeMap<HeaderName, OneOrMany>,
21    len: usize,
22}
23
24#[derive(Clone, PartialEq, Eq)]
25enum OneOrMany {
26    One(HeaderValue),
27    Many(Vec<HeaderValue>),
28}
29
30impl HeaderMap {
31    /// Create an empty `HeaderMap`
32    ///
33    /// The map will be created without any capacity. This function will not allocate.
34    ///
35    /// Consider using the [`FromIterator`], [`Extend`] implementations if the final
36    /// length is known upfront.
37    #[must_use]
38    pub const fn new() -> Self {
39        Self {
40            headers: BTreeMap::new(),
41            len: 0,
42        }
43    }
44
45    pub fn get(&self, name: &HeaderName) -> Option<&HeaderValue> {
46        self.get_all(name).next()
47    }
48
49    pub fn get_all<'a>(
50        &'a self,
51        name: &HeaderName,
52    ) -> impl DoubleEndedIterator<Item = &'a HeaderValue> + use<'a> {
53        self.headers.get(name).unwrap_or(&EMPTY_HEADERS).iter()
54    }
55
56    pub fn insert(&mut self, name: HeaderName, value: HeaderValue) {
57        if let Some(prev) = self.headers.insert(name, OneOrMany::One(value)) {
58            self.len -= prev.len();
59        }
60        self.len += 1;
61    }
62
63    pub fn append(&mut self, name: HeaderName, value: HeaderValue) {
64        match self.headers.entry(name) {
65            Entry::Vacant(vacant) => {
66                vacant.insert(OneOrMany::One(value));
67            }
68            Entry::Occupied(mut occupied) => {
69                occupied.get_mut().push(value);
70            }
71        }
72        self.len += 1;
73    }
74
75    pub fn remove(&mut self, name: &HeaderName) {
76        if let Some(prev) = self.headers.remove(name) {
77            self.len -= prev.len();
78        }
79    }
80
81    /// Returns the number of keys stored in the map
82    ///
83    /// This number will be less than or equal to [`HeaderMap::len`].
84    #[must_use]
85    pub fn keys_len(&self) -> usize {
86        self.headers.len()
87    }
88
89    /// Returns the number of headers stored in the map
90    ///
91    /// This number represents the total number of **values** stored in the map.
92    /// This number can be greater than or equal to the number of **keys** stored.
93    #[must_use]
94    pub fn len(&self) -> usize {
95        self.len
96    }
97
98    /// Returns true if the map contains no elements
99    #[must_use]
100    pub fn is_empty(&self) -> bool {
101        self.headers.is_empty()
102    }
103
104    /// Clear the map, removing all key-value pairs. Keeps the allocated memory for reuse
105    pub fn clear(&mut self) {
106        self.headers.clear();
107        self.len = 0;
108    }
109
110    #[cfg(test)]
111    fn keys(&self) -> impl Iterator<Item = &'_ HeaderName> {
112        self.headers.keys()
113    }
114
115    pub(crate) fn iter(
116        &self,
117    ) -> impl DoubleEndedIterator<Item = (&'_ HeaderName, impl Iterator<Item = &'_ HeaderValue>)>
118    {
119        self.headers
120            .iter()
121            .map(|(name, value)| (name, value.iter()))
122    }
123}
124
125impl Debug for HeaderMap {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        f.debug_tuple("HeaderMap")
128            .field(&self.headers)
129            // FIXME: switch to `finish_non_exhaustive`
130            .finish()
131    }
132}
133
134impl FromIterator<(HeaderName, HeaderValue)> for HeaderMap {
135    fn from_iter<I: IntoIterator<Item = (HeaderName, HeaderValue)>>(iter: I) -> Self {
136        let mut this = Self::new();
137        this.extend(iter);
138        this
139    }
140}
141
142impl Extend<(HeaderName, HeaderValue)> for HeaderMap {
143    fn extend<T: IntoIterator<Item = (HeaderName, HeaderValue)>>(&mut self, iter: T) {
144        iter.into_iter().for_each(|(name, value)| {
145            self.append(name, value);
146        });
147    }
148}
149
150impl Default for HeaderMap {
151    fn default() -> Self {
152        Self::new()
153    }
154}
155
156impl Debug for OneOrMany {
157    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158        f.debug_set().entries(self.iter()).finish()
159    }
160}
161
162impl OneOrMany {
163    fn len(&self) -> usize {
164        match self {
165            Self::One(_) => 1,
166            Self::Many(vec) => vec.len(),
167        }
168    }
169
170    fn push(&mut self, item: HeaderValue) {
171        match self {
172            Self::One(current_item) => {
173                let current_item =
174                    mem::replace(current_item, HeaderValue::from_static("replacing"));
175                *self = Self::Many(vec![current_item, item]);
176            }
177            Self::Many(vec) => {
178                debug_assert!(!vec.is_empty(), "OneOrMany can't be empty");
179                vec.push(item);
180            }
181        }
182    }
183
184    fn iter(&self) -> impl DoubleEndedIterator<Item = &'_ HeaderValue> {
185        // This implementation may look odd, but it implements `TrustedLen`,
186        // so the Iterator is efficient to collect.
187        match self {
188            Self::One(one) => Iterator::chain(Some(one).into_iter(), &[]),
189            Self::Many(many) => Iterator::chain(None.into_iter(), many),
190        }
191    }
192}
193
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};

    use crate::headers::{HeaderName, HeaderValue};

    use super::HeaderMap;

    #[test]
    fn manual() {
        let mut headers = HeaderMap::new();
        headers.append(
            HeaderName::from_static("Nats-Message-Id"),
            HeaderValue::from_static("abcd"),
        );
        headers.append(
            HeaderName::from_static("Nats-Sequence"),
            HeaderValue::from_static("1"),
        );
        headers.append(
            HeaderName::from_static("Nats-Message-Id"),
            HeaderValue::from_static("1234"),
        );
        headers.append(
            HeaderName::from_static("Nats-Time-Stamp"),
            HeaderValue::from_static("0"),
        );
        headers.remove(&HeaderName::from_static("Nats-Time-Stamp"));

        verify_header_map(&headers);
    }

    #[test]
    fn collect() {
        let headers = [
            (
                HeaderName::from_static("Nats-Message-Id"),
                HeaderValue::from_static("abcd"),
            ),
            (
                HeaderName::from_static("Nats-Sequence"),
                HeaderValue::from_static("1"),
            ),
            (
                HeaderName::from_static("Nats-Message-Id"),
                HeaderValue::from_static("1234"),
            ),
        ]
        .into_iter()
        .collect::<HeaderMap>();

        verify_header_map(&headers);
    }

    /// `insert` must drop every previous value of the name and fix up `len` accordingly.
    #[test]
    fn insert_replaces_all_values() {
        let name = HeaderName::from_static("Nats-Message-Id");
        let mut headers = HeaderMap::new();
        headers.append(name.clone(), HeaderValue::from_static("abcd"));
        headers.append(name.clone(), HeaderValue::from_static("1234"));
        assert_eq!(headers.len(), 2);
        assert_eq!(headers.keys_len(), 1);

        headers.insert(name.clone(), HeaderValue::from_static("5678"));
        assert_eq!(headers.len(), 1);
        assert_eq!(headers.keys_len(), 1);
        assert_eq!(
            headers.get_all(&name).cloned().collect::<Vec<_>>(),
            vec![HeaderValue::from_static("5678")]
        );
    }

    /// Lookups on absent names are empty, and `clear` resets both counters.
    #[test]
    fn missing_and_cleared() {
        let name = HeaderName::from_static("Nats-Sequence");
        let mut headers = HeaderMap::new();
        assert!(headers.is_empty());
        assert_eq!(headers.get(&name), None);
        assert_eq!(headers.get_all(&name).count(), 0);

        headers.append(name.clone(), HeaderValue::from_static("1"));
        assert!(!headers.is_empty());

        headers.clear();
        assert!(headers.is_empty());
        assert_eq!(headers.len(), 0);
        assert_eq!(headers.keys_len(), 0);
        assert_eq!(headers.get(&name), None);
    }

    fn verify_header_map(headers: &HeaderMap) {
        // Three values spread over two names.
        assert_eq!(headers.len(), 3);
        assert_eq!(headers.keys_len(), 2);
        assert!(!headers.is_empty());

        // `get` returns the first value in insertion order.
        assert_eq!(
            headers.get(&HeaderName::from_static("Nats-Message-Id")),
            Some(&HeaderValue::from_static("abcd"))
        );

        assert_eq!(
            [
                HeaderName::from_static("Nats-Message-Id"),
                HeaderName::from_static("Nats-Sequence")
            ]
            .as_slice(),
            headers.keys().cloned().collect::<Vec<_>>().as_slice()
        );

        let raw_headers = headers
            .iter()
            .map(|(name, values)| (name.clone(), values.cloned().collect::<Vec<_>>()))
            .collect::<Vec<_>>();
        assert_eq!(
            [
                (
                    HeaderName::from_static("Nats-Message-Id"),
                    vec![
                        HeaderValue::from_static("abcd"),
                        HeaderValue::from_static("1234")
                    ]
                ),
                (
                    HeaderName::from_static("Nats-Sequence"),
                    vec![HeaderValue::from_static("1")]
                ),
            ]
            .as_slice(),
            raw_headers.as_slice(),
        );
    }
}