strut_rabbitmq/
handle.rs

1use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS};
2use secure_string::SecureString;
3use serde::de::{DeserializeSeed, Error, MapAccess, Visitor};
4use serde::{Deserialize, Deserializer};
5use std::any::type_name;
6use std::borrow::Cow;
7use std::collections::HashMap;
8use std::fmt::{Debug, Display, Formatter};
9use std::sync::Arc;
10use strut_deserialize::{Slug, SlugMap};
11use strut_factory::impl_deserialize_field;
12use strut_util::BackoffConfig;
13
14const VHOST_ENCODE_SET: &AsciiSet = &CONTROLS
15    .add(b'/') // Encode '/' as %2F
16    .add(b'?') // Encode '?' as %3F
17    .add(b'#') // Encode '#' as %23
18    .add(b'%'); // Encode '%' as %25 (to avoid ambiguity)
19
20/// Represents a collection of uniquely named [`Handle`]s.
21#[derive(Debug, Default, Clone, PartialEq)]
22pub struct HandleCollection {
23    handles: SlugMap<Handle>,
24}
25
26/// Defines a connection handle for a RabbitMQ cluster, consisting primarily of
27/// a set of credentials, along with a bit of metadata for logging/debugging
28/// purposes.
29///
30/// This handle by itself does not implement any connection logic.
31#[derive(Clone, PartialEq)]
32pub struct Handle {
33    name: Arc<str>,
34    identifier: Arc<str>,
35    dsn: SecureString,
36    backoff: BackoffConfig,
37}
38
/// Groups the pieces of a RabbitMQ DSN for convenient passing into
/// [`Handle::new`].
///
/// The type parameters are deliberately unconstrained on the struct itself;
/// [`Handle::new`] states the bounds it actually needs (`AsRef<str>` for
/// `host`, `user`, and `vhost`, and [`Into<SecureString>`] for `password`).
pub struct DsnChunks<H, U, P, VH> {
    /// The `localhost` part of `amqp://user:pass@localhost:5672/%2F`.
    pub host: H,
    /// The `5672` part of `amqp://user:pass@localhost:5672/%2F`.
    pub port: u16,
    /// The `user` part of `amqp://user:pass@localhost:5672/%2F`.
    pub user: U,
    /// The `pass` part of `amqp://user:pass@localhost:5672/%2F`.
    ///
    /// To be accepted by [`Handle::new`], this has to be represented with
    /// anything that implements [`Into<SecureString>`], which includes `&str`.
    pub password: P,
    /// The `%2F` part of `amqp://user:pass@localhost:5672/%2F`.
    ///
    /// This does **not** need to be percent-encoded. [`Handle`] takes
    /// care of percent-encoding. In the example above, the equivalent
    /// human-readable string `"/"` will work just fine.
    pub vhost: VH,
}
66
67impl Handle {
68    /// Creates a new handle with the given name and composes the DSN from the
69    /// given [`chunks`](DsnChunks).
70    ///
71    /// Takes care of securing the password against _accidental_ debug-printing.
72    /// Ensures proper percent-encoding of the `vhost`; there is no need to
73    /// pre-encode it.
74    pub fn new<H, U, P, VH>(name: impl AsRef<str>, chunks: DsnChunks<H, U, P, VH>) -> Self
75    where
76        H: AsRef<str>,
77        U: AsRef<str>,
78        P: Into<SecureString>,
79        VH: AsRef<str>,
80    {
81        let name = Arc::from(name.as_ref());
82
83        let vhost = Self::ensure_encoded_vhost(chunks.vhost.as_ref());
84        let identifier = Self::compose_identifier(
85            chunks.host.as_ref(),
86            chunks.port,
87            chunks.user.as_ref(),
88            vhost.as_ref(),
89        );
90
91        let password = chunks.password.into();
92        let dsn = Self::compose_dsn(
93            chunks.host.as_ref(),
94            chunks.port,
95            chunks.user.as_ref(),
96            &password,
97            vhost.as_ref(),
98        );
99
100        let backoff = BackoffConfig::default();
101
102        Self {
103            name,
104            identifier,
105            dsn,
106            backoff,
107        }
108    }
109
110    /// Re-create this [`Handle`] with the given [`BackoffConfig`].
111    pub fn with_backoff(self, backoff: BackoffConfig) -> Self {
112        Self { backoff, ..self }
113    }
114
115    /// Ensures that the given `vhost` value is correctly percent-encoded to be
116    /// included in a DSN.
117    fn ensure_encoded_vhost(vhost: &str) -> Cow<'_, str> {
118        utf8_percent_encode(vhost, VHOST_ENCODE_SET).into()
119    }
120
121    /// Composes a non-sensitive identifier useful for debug-printing a handle.
122    fn compose_identifier(host: &str, port: u16, user: &str, vhost: &str) -> Arc<str> {
123        Arc::from(format!("{}@{}:{}/{}", user, host, port, vhost))
124    }
125
126    /// Composes a sensitive DSN to be used for connecting to the RabbitMQ cluster.
127    fn compose_dsn(
128        host: &str,
129        port: u16,
130        user: &str,
131        password: &SecureString,
132        vhost: &str,
133    ) -> SecureString {
134        SecureString::from(format!(
135            "amqp://{}:{}@{}:{}/{}",
136            user,
137            password.unsecure(),
138            host,
139            port,
140            vhost,
141        ))
142    }
143}
144
145impl HandleCollection {
146    /// Reports whether this collection contains a [`Handle`] with the
147    /// given unique name.
148    pub fn contains(&self, name: &str) -> bool {
149        self.handles.contains_key(name)
150    }
151
152    /// Retrieves `Some` reference to a [`Handle`] from this collection
153    /// under the given name, or `None`, if the name is not present in the
154    /// collection.
155    pub fn get(&self, name: &str) -> Option<&Handle> {
156        self.handles.get(name)
157    }
158
159    /// Retrieves a reference to a [`Handle`] from this collection under
160    /// the given name. Panics if the name is not present in the collection.
161    pub fn expect(&self, name: &str) -> &Handle {
162        self.get(name)
163            .unwrap_or_else(|| panic!("requested an undefined RabbitMQ handle '{}'", name))
164    }
165}
166
167impl Handle {
168    /// Reports the handle name.
169    pub fn name(&self) -> &str {
170        &self.name
171    }
172
173    /// Reports the handle identifier, which is the normal connection DSN, but
174    /// with the password obscured. This identifier is generally safe for debug
175    /// logging.
176    pub fn identifier(&self) -> &str {
177        &self.identifier
178    }
179
180    /// Reports the handle DSN.
181    pub fn dsn(&self) -> &SecureString {
182        &self.dsn
183    }
184
185    /// Exposes the exponential [`Backoff`](strut_util::Backoff) configuration
186    /// for this handle.
187    pub fn backoff(&self) -> &BackoffConfig {
188        &self.backoff
189    }
190}
191
192/// Convenience implementation for providing partially hard-coding chunks.
193impl Default for DsnChunks<&str, &str, &str, &str> {
194    fn default() -> Self {
195        Self {
196            host: Handle::default_host(),
197            port: Handle::default_port(),
198            user: Handle::default_user(),
199            password: Handle::default_password(),
200            vhost: Handle::default_vhost(),
201        }
202    }
203}
204
205impl Handle {
206    fn default_name() -> &'static str {
207        "default"
208    }
209
210    fn default_host() -> &'static str {
211        "localhost"
212    }
213
214    fn default_port() -> u16 {
215        5672
216    }
217
218    fn default_user() -> &'static str {
219        "guest"
220    }
221
222    fn default_password() -> &'static str {
223        "guest"
224    }
225
226    fn default_vhost() -> &'static str {
227        "/"
228    }
229}
230
231impl Default for Handle {
232    fn default() -> Self {
233        Self::new(Self::default_name(), DsnChunks::default())
234    }
235}
236
237/// Omits `dsn` from debug representation. DSN is largely safe (it’s a [`SecureString`]),
238/// but its inclusion adds no valuable debug information.
239impl Debug for Handle {
240    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
241        f.debug_struct(type_name::<Self>())
242            .field("name", &self.name)
243            .field("identifier", &self.identifier)
244            .field("backoff", &self.backoff)
245            .finish()
246    }
247}
248
249impl Display for Handle {
250    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
251        f.write_str(&self.identifier)
252    }
253}
254
255impl AsRef<Handle> for Handle {
256    fn as_ref(&self) -> &Handle {
257        self
258    }
259}
260
261impl AsRef<HandleCollection> for HandleCollection {
262    fn as_ref(&self) -> &HandleCollection {
263        self
264    }
265}
266
267const _: () = {
268    impl<S> FromIterator<(S, Handle)> for HandleCollection
269    where
270        S: Into<Slug>,
271    {
272        fn from_iter<T: IntoIterator<Item = (S, Handle)>>(iter: T) -> Self {
273            let handles = iter.into_iter().collect();
274
275            Self { handles }
276        }
277    }
278
279    impl<const N: usize, S> From<[(S, Handle); N]> for HandleCollection
280    where
281        S: Into<Slug>,
282    {
283        fn from(value: [(S, Handle); N]) -> Self {
284            value.into_iter().collect()
285        }
286    }
287};
288
289const _: () = {
290    impl<'de> Deserialize<'de> for HandleCollection {
291        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
292        where
293            D: Deserializer<'de>,
294        {
295            deserializer.deserialize_map(HandleCollectionVisitor)
296        }
297    }
298
299    struct HandleCollectionVisitor;
300
301    impl<'de> Visitor<'de> for HandleCollectionVisitor {
302        type Value = HandleCollection;
303
304        fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
305            formatter.write_str("a map of RabbitMQ handles")
306        }
307
308        fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
309        where
310            A: MapAccess<'de>,
311        {
312            let grouped = Slug::group_map(map)?;
313            let mut handles = HashMap::with_capacity(grouped.len());
314
315            for (key, value) in grouped {
316                let seed = HandleSeed {
317                    name: key.original(),
318                };
319                let handle = seed.deserialize(value).map_err(Error::custom)?;
320                handles.insert(key, handle);
321            }
322
323            Ok(HandleCollection {
324                handles: SlugMap::new(handles),
325            })
326        }
327    }
328
329    struct HandleSeed<'a> {
330        name: &'a str,
331    }
332
333    impl<'de> DeserializeSeed<'de> for HandleSeed<'_> {
334        type Value = Handle;
335
336        fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
337        where
338            D: Deserializer<'de>,
339        {
340            deserializer.deserialize_map(HandleSeedVisitor { name: self.name })
341        }
342    }
343
344    struct HandleSeedVisitor<'a> {
345        name: &'a str,
346    }
347
348    impl<'de> Visitor<'de> for HandleSeedVisitor<'_> {
349        type Value = Handle;
350
351        fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
352            formatter.write_str("a map of RabbitMQ handle")
353        }
354
355        fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
356        where
357            A: MapAccess<'de>,
358        {
359            visit_handle(map, Some(self.name))
360        }
361    }
362
363    impl<'de> Deserialize<'de> for Handle {
364        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
365        where
366            D: Deserializer<'de>,
367        {
368            deserializer.deserialize_map(HandleVisitor)
369        }
370    }
371
372    struct HandleVisitor;
373
374    impl<'de> Visitor<'de> for HandleVisitor {
375        type Value = Handle;
376
377        fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
378            formatter.write_str("a map of RabbitMQ handle")
379        }
380
381        fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
382        where
383            A: MapAccess<'de>,
384        {
385            visit_handle(map, None)
386        }
387    }
388
389    fn visit_handle<'de, A>(mut map: A, known_name: Option<&str>) -> Result<Handle, A::Error>
390    where
391        A: MapAccess<'de>,
392    {
393        // Type hints are needed on `String`s to avoid deserializer expecting a
394        // borrowed string, which not all deserializers support.
395        let mut name: Option<String> = None;
396        let mut host: Option<String> = None;
397        let mut port = None;
398        let mut user: Option<String> = None;
399        let mut password: Option<SecureString> = None;
400        let mut vhost: Option<String> = None;
401
402        while let Some(key) = map.next_key()? {
403            match key {
404                HandleField::name => key.poll(&mut map, &mut name)?,
405                HandleField::host => key.poll(&mut map, &mut host)?,
406                HandleField::port => key.poll(&mut map, &mut port)?,
407                HandleField::user => key.poll(&mut map, &mut user)?,
408                HandleField::password => key.poll(&mut map, &mut password)?,
409                HandleField::vhost => key.poll(&mut map, &mut vhost)?,
410                HandleField::__ignore => map.next_value()?,
411            };
412        }
413
414        let name = match known_name {
415            Some(known_name) => known_name,
416            None => name.as_deref().unwrap_or_else(|| Handle::default_name()),
417        };
418
419        // “Useless” closures are needed to avoid lifetime issues
420        let chunks = DsnChunks {
421            host: host.as_deref().unwrap_or_else(|| Handle::default_host()),
422            port: port.unwrap_or_else(Handle::default_port),
423            user: user.as_deref().unwrap_or_else(|| Handle::default_user()),
424            password: password.unwrap_or_else(|| Handle::default_password().into()),
425            vhost: vhost.as_deref().unwrap_or_else(|| Handle::default_vhost()),
426        };
427
428        Ok(Handle::new(name, chunks))
429    }
430
431    impl_deserialize_field!(
432        HandleField,
433        strut_deserialize::Slug::eq_as_slugs,
434        name,
435        host | hostname,
436        port,
437        user | username,
438        password,
439        vhost,
440    );
441};
442
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Parses the given YAML `input` into a `T`, panicking (at the caller's
    /// location) if parsing fails.
    #[track_caller]
    fn parse_yaml<T>(input: &str) -> T
    where
        T: serde::de::DeserializeOwned,
    {
        serde_yml::from_str::<T>(input).unwrap()
    }

    #[test]
    fn deserialize_from_empty() {
        // An empty document yields the default handle.
        let expected = Handle::default();
        let actual: Handle = parse_yaml("");

        assert_eq!(expected, actual);
    }

    #[test]
    fn deserialize_from_full() {
        // Every supported field is given explicitly.
        let yaml = r#"
name: test_handle
host: test_host
port: 8080
user: test_user
password: test_password
vhost: test_vhost
"#;
        let chunks = DsnChunks {
            host: "test_host",
            port: 8080,
            user: "test_user",
            password: "test_password",
            vhost: "test_vhost",
        };
        let expected = Handle::new("test_handle", chunks);
        let actual: Handle = parse_yaml(yaml);

        assert_eq!(expected, actual);
    }

    #[test]
    fn deserialize_collection_from_empty() {
        // An empty document yields an empty collection.
        let expected = HandleCollection::default();
        let actual: HandleCollection = parse_yaml("");

        assert_eq!(expected, actual);
    }

    #[test]
    fn deserialize_collection_from_full() {
        // Collection keys become handle names; missing fields use defaults.
        let yaml = r#"
test_handle_a: {}
test_handle_b:
  host: test_host
  port: 8080
"#;
        let handle_a = Handle::new("test_handle_a", DsnChunks::default());
        let handle_b = Handle::new(
            "test_handle_b",
            DsnChunks {
                host: "test_host",
                port: 8080,
                ..Default::default()
            },
        );
        let expected =
            HandleCollection::from([("test_handle_a", handle_a), ("test_handle_b", handle_b)]);
        let actual: HandleCollection = parse_yaml(yaml);

        assert_eq!(expected, actual);
    }
}