1use std::sync::Arc;
4
5use fred::clients::{Client, Pool};
6use fred::interfaces::{ClientLike, EventInterface, PubsubInterface, StreamsInterface};
7#[cfg(feature = "credential-provider")]
8use fred::types::config::CredentialProvider;
9#[cfg(any(
10 feature = "tls-rustls",
11 feature = "tls-rustls-ring",
12 feature = "tls-native-tls"
13))]
14use fred::types::config::TlsConfig;
15use fred::types::config::{Config, ServerConfig};
16use ruststream::{Broker, DescribeServer, ServerSpec, Subscribe};
17use tokio::sync::OnceCell;
18
19use crate::{
20 error::RedisError,
21 list::{RedisList, RedisListPublisher, RedisListSubscriber},
22 publisher::RedisPublisher,
23 pubsub::{PubSubMode, RedisPubSub, RedisPubSubPublisher, RedisPubSubSubscriber},
24 stream::RedisStream,
25 subscriber::RedisSubscriber,
26};
27
/// Number of connections in the pool built by `Broker::connect` unless
/// overridden with `RedisBroker::pool`.
const DEFAULT_POOL_SIZE: usize = 4;

/// How the broker reaches Redis.
///
/// `Debug` is implemented by hand (not derived) so that credentials embedded
/// in a URL or node address (`user:password@host`) are masked instead of being
/// printed whenever the broker is logged.
#[derive(Clone)]
enum Topology {
    /// A single server, described by a URL handed to fred's `Config::from_url`.
    Standalone(String),
    /// Redis Cluster seed nodes (`host[:port]`, optionally `redis://`-prefixed).
    Cluster(Vec<String>),
    /// A sentinel-managed deployment: the monitored service name and the
    /// sentinel addresses.
    Sentinel { service: String, hosts: Vec<String> },
    /// A pool supplied up front via `RedisBroker::from_pool`; there is no
    /// configuration to rebuild a connection from.
    Preconnected,
}

impl std::fmt::Debug for Topology {
    /// Mirrors the derived `Debug` layout, but with userinfo masked.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Standalone(url) => f
                .debug_tuple("Standalone")
                .field(&redact_userinfo(url))
                .finish(),
            Self::Cluster(nodes) => f
                .debug_tuple("Cluster")
                .field(&redact_all_userinfo(nodes))
                .finish(),
            Self::Sentinel { service, hosts } => f
                .debug_struct("Sentinel")
                .field("service", service)
                .field("hosts", &redact_all_userinfo(hosts))
                .finish(),
            Self::Preconnected => f.write_str("Preconnected"),
        }
    }
}

/// Applies [`redact_userinfo`] to every address in `addrs`.
fn redact_all_userinfo(addrs: &[String]) -> Vec<String> {
    addrs.iter().map(|addr| redact_userinfo(addr)).collect()
}

/// Replaces the `user:password` part of an address with `<redacted>`.
///
/// Everything before the *last* `@` is treated as userinfo (so unescaped `@`
/// in a password cannot leak a tail of it); a leading `scheme://` is kept for
/// readability. Addresses without `@` are returned unchanged.
fn redact_userinfo(addr: &str) -> String {
    let Some((userinfo, rest)) = addr.rsplit_once('@') else {
        return addr.to_owned();
    };
    // Only keep a prefix that really looks like a URL scheme; otherwise a
    // password containing "://" could expose its leading characters.
    let scheme_len = userinfo
        .find("://")
        .filter(|&end| {
            userinfo[..end]
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
        })
        .map_or(0, |end| end + "://".len());
    format!("{}<redacted>@{rest}", &userinfo[..scheme_len])
}
45
46fn parse_server(addr: &str, default_port: u16) -> Result<(String, u16), RedisError> {
50 let trimmed = addr
51 .trim()
52 .trim_start_matches("rediss://")
53 .trim_start_matches("redis://");
54 let (host, port) = match trimmed.rsplit_once(':') {
55 Some((host, port)) => {
56 let port = port.parse::<u16>().map_err(|_| {
57 RedisError::Connect(format!("invalid port in redis address `{addr}`").into())
58 })?;
59 (host, port)
60 }
61 None => (trimmed, default_port),
62 };
63 if host.is_empty() {
64 return Err(RedisError::Connect(
65 format!("missing host in redis address `{addr}`").into(),
66 ));
67 }
68 Ok((host.to_owned(), port))
69}
70
71fn parse_servers(addrs: &[String], default_port: u16) -> Result<Vec<(String, u16)>, RedisError> {
72 if addrs.is_empty() {
73 return Err(RedisError::Connect("no redis addresses provided".into()));
74 }
75 addrs
76 .iter()
77 .map(|addr| parse_server(addr, default_port))
78 .collect()
79}
80
/// Authentication and transport-security settings collected by the
/// `RedisBroker` builder methods.
///
/// Every field starts out `None` (via `Default`); only fields that were set
/// are copied onto the generated fred `Config`, so anything left `None` keeps
/// whatever the base configuration (e.g. a standalone URL) provides.
///
/// `Debug` is implemented by hand to keep secrets out of logs.
#[derive(Clone, Default)]
struct AuthConfig {
    /// ACL username for the data connections.
    username: Option<String>,
    /// Password for the data connections.
    password: Option<String>,
    /// ACL username presented to the sentinel nodes themselves.
    #[cfg(feature = "sentinel-auth")]
    sentinel_username: Option<String>,
    /// Password presented to the sentinel nodes themselves.
    #[cfg(feature = "sentinel-auth")]
    sentinel_password: Option<String>,
    /// TLS settings for the data connections.
    #[cfg(any(
        feature = "tls-rustls",
        feature = "tls-rustls-ring",
        feature = "tls-native-tls"
    ))]
    tls: Option<TlsConfig>,
    /// Dynamic credential source handed to fred.
    #[cfg(feature = "credential-provider")]
    credential_provider: Option<Arc<dyn CredentialProvider>>,
}
107
108impl std::fmt::Debug for AuthConfig {
111 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112 let mut s = f.debug_struct("AuthConfig");
113 s.field("username", &self.username);
114 s.field("password", &self.password.as_ref().map(|_| "<redacted>"));
115 #[cfg(feature = "sentinel-auth")]
116 {
117 s.field("sentinel_username", &self.sentinel_username);
118 s.field(
119 "sentinel_password",
120 &self.sentinel_password.as_ref().map(|_| "<redacted>"),
121 );
122 }
123 #[cfg(any(
124 feature = "tls-rustls",
125 feature = "tls-rustls-ring",
126 feature = "tls-native-tls"
127 ))]
128 s.field("tls", &self.tls.as_ref().map(|_| "<configured>"));
129 #[cfg(feature = "credential-provider")]
130 s.field(
131 "credential_provider",
132 &self.credential_provider.as_ref().map(|_| "<configured>"),
133 );
134 s.finish()
135 }
136}
137
138#[derive(Clone)]
168pub struct RedisBroker {
169 pool: Arc<OnceCell<Pool>>,
170 topology: Topology,
171 pool_size: usize,
172 default_group: Option<String>,
173 auth: AuthConfig,
174}
175
176impl std::fmt::Debug for RedisBroker {
177 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
178 f.debug_struct("RedisBroker")
179 .field("topology", &self.topology)
180 .field("pool_size", &self.pool_size)
181 .field("default_group", &self.default_group)
182 .field("auth", &self.auth)
183 .finish_non_exhaustive()
184 }
185}
186
187impl RedisBroker {
188 #[must_use]
193 pub fn standalone(url: impl Into<String>) -> Self {
194 Self::with_topology(Topology::Standalone(url.into()))
195 }
196
197 #[must_use]
202 pub fn cluster(nodes: impl IntoIterator<Item = impl Into<String>>) -> Self {
203 Self::with_topology(Topology::Cluster(
204 nodes.into_iter().map(Into::into).collect(),
205 ))
206 }
207
208 #[must_use]
213 pub fn sentinel(
214 service: impl Into<String>,
215 sentinels: impl IntoIterator<Item = impl Into<String>>,
216 ) -> Self {
217 Self::with_topology(Topology::Sentinel {
218 service: service.into(),
219 hosts: sentinels.into_iter().map(Into::into).collect(),
220 })
221 }
222
223 fn with_topology(topology: Topology) -> Self {
224 Self {
225 pool: Arc::new(OnceCell::new()),
226 topology,
227 pool_size: DEFAULT_POOL_SIZE,
228 default_group: None,
229 auth: AuthConfig::default(),
230 }
231 }
232
233 #[must_use]
235 pub const fn pool(mut self, size: usize) -> Self {
236 self.pool_size = size;
237 self
238 }
239
240 #[must_use]
245 pub fn default_group(mut self, group: impl Into<String>) -> Self {
246 self.default_group = Some(group.into());
247 self
248 }
249
250 #[must_use]
269 pub fn credentials(mut self, username: impl Into<String>, password: impl Into<String>) -> Self {
270 self.auth.username = Some(username.into());
271 self.auth.password = Some(password.into());
272 self
273 }
274
275 #[must_use]
287 pub fn password(mut self, password: impl Into<String>) -> Self {
288 self.auth.password = Some(password.into());
289 self
290 }
291
292 #[cfg(any(
310 feature = "tls-rustls",
311 feature = "tls-rustls-ring",
312 feature = "tls-native-tls"
313 ))]
314 #[must_use]
315 pub fn tls(mut self, tls: impl Into<TlsConfig>) -> Self {
316 self.auth.tls = Some(tls.into());
317 self
318 }
319
320 #[cfg(feature = "sentinel-auth")]
336 #[must_use]
337 pub fn sentinel_credentials(
338 mut self,
339 username: impl Into<String>,
340 password: impl Into<String>,
341 ) -> Self {
342 self.auth.sentinel_username = Some(username.into());
343 self.auth.sentinel_password = Some(password.into());
344 self
345 }
346
347 #[cfg(feature = "sentinel-auth")]
362 #[must_use]
363 pub fn sentinel_password(mut self, password: impl Into<String>) -> Self {
364 self.auth.sentinel_password = Some(password.into());
365 self
366 }
367
368 #[cfg(feature = "credential-provider")]
385 #[must_use]
386 pub fn credential_provider(mut self, provider: Arc<dyn CredentialProvider>) -> Self {
387 self.auth.credential_provider = Some(provider);
388 self
389 }
390
391 pub async fn connect(url: impl Into<String>) -> Result<Self, RedisError> {
398 let broker = Self::standalone(url);
399 Broker::connect(&broker).await?;
400 Ok(broker)
401 }
402
403 #[must_use]
406 pub fn from_pool(pool: Pool) -> Self {
407 Self {
408 pool: Arc::new(OnceCell::new_with(Some(pool))),
409 topology: Topology::Preconnected,
410 pool_size: DEFAULT_POOL_SIZE,
411 default_group: None,
412 auth: AuthConfig::default(),
413 }
414 }
415
416 fn build_config(&self) -> Result<Config, RedisError> {
418 let mut config = match &self.topology {
419 Topology::Standalone(url) => {
420 Config::from_url(url).map_err(|err| RedisError::Connect(Box::new(err)))?
421 }
422 Topology::Cluster(nodes) => {
423 let hosts = parse_servers(nodes, 6379)?;
424 Config {
425 server: ServerConfig::new_clustered(hosts),
426 ..Config::default()
427 }
428 }
429 Topology::Sentinel { service, hosts } => {
430 let hosts = parse_servers(hosts, 26379)?;
431 Config {
432 server: ServerConfig::new_sentinel(hosts, service.clone()),
433 ..Config::default()
434 }
435 }
436 Topology::Preconnected => return Err(RedisError::NotConnected),
438 };
439 self.apply_auth(&mut config);
440 Ok(config)
441 }
442
443 fn apply_auth(&self, config: &mut Config) {
446 if self.auth.username.is_some() {
447 config.username.clone_from(&self.auth.username);
448 }
449 if self.auth.password.is_some() {
450 config.password.clone_from(&self.auth.password);
451 }
452 #[cfg(any(
453 feature = "tls-rustls",
454 feature = "tls-rustls-ring",
455 feature = "tls-native-tls"
456 ))]
457 if self.auth.tls.is_some() {
458 config.tls.clone_from(&self.auth.tls);
459 }
460 #[cfg(feature = "credential-provider")]
461 if self.auth.credential_provider.is_some() {
462 config
463 .credential_provider
464 .clone_from(&self.auth.credential_provider);
465 }
466 #[cfg(feature = "sentinel-auth")]
467 if let ServerConfig::Sentinel {
468 username, password, ..
469 } = &mut config.server
470 {
471 if self.auth.sentinel_username.is_some() {
472 username.clone_from(&self.auth.sentinel_username);
473 }
474 if self.auth.sentinel_password.is_some() {
475 password.clone_from(&self.auth.sentinel_password);
476 }
477 }
478 }
479
480 fn connected(&self) -> Result<Pool, RedisError> {
482 self.pool.get().cloned().ok_or(RedisError::NotConnected)
483 }
484
485 #[must_use]
494 pub fn pool_handle(&self) -> Pool {
495 self.pool
496 .get()
497 .cloned()
498 .expect("RedisBroker::pool_handle() called before connect()")
499 }
500
501 pub async fn subscribe(&self, def: RedisStream) -> Result<RedisSubscriber, RedisError> {
512 let pool = self.connected()?;
513 let group = def.group_or_err()?.to_owned();
514 let consumer = def.consumer_or_auto();
515 ensure_group(&pool, def.key(), &group, def.start().as_id()).await?;
516 Ok(RedisSubscriber::new(
517 pool,
518 def.key().to_owned(),
519 group,
520 consumer,
521 def.count_or_default(),
522 def.block_or_default(),
523 def.mode(),
524 def.poison_policy(),
525 def.delay_config(),
526 ))
527 }
528
529 #[must_use]
534 pub fn publisher(&self) -> RedisPublisher {
535 RedisPublisher::new(Arc::clone(&self.pool), self.supports_transactions())
536 }
537
538 const fn supports_transactions(&self) -> bool {
541 !matches!(self.topology, Topology::Cluster(_))
542 }
543
544 async fn new_client(&self) -> Result<Client, RedisError> {
547 let config = self.build_config()?;
548 let client = Client::new(config, None, None, None);
549 client
550 .init()
551 .await
552 .map_err(|err| RedisError::Connect(Box::new(err)))?;
553 Ok(client)
554 }
555
556 pub async fn subscribe_pubsub(
564 &self,
565 def: RedisPubSub,
566 ) -> Result<RedisPubSubSubscriber, RedisError> {
567 def.validate()?;
568 let codec = def.codec_handle();
569 let client = self.new_client().await?;
570 let channel = def.channel().to_owned();
571 let result = match (def.delivery_mode(), def.is_pattern()) {
572 (PubSubMode::Classic, true) => client.psubscribe(channel).await,
573 (PubSubMode::Classic, false) => client.subscribe(channel).await,
574 (PubSubMode::Sharded, _) => client.ssubscribe(channel).await,
575 };
576 result.map_err(RedisError::subscribe)?;
577 let rx = client.message_rx();
578 Ok(RedisPubSubSubscriber::new(client, rx, codec))
579 }
580
581 #[allow(
588 clippy::unused_async,
589 reason = "async for parity with the other subscribe methods and the SubscriptionSource shape"
590 )]
591 pub async fn subscribe_list(&self, def: RedisList) -> Result<RedisListSubscriber, RedisError> {
592 let pool = self.connected()?;
593 let recovery = def.recovery_config()?;
594 Ok(RedisListSubscriber::new(
595 pool,
596 def.key().to_owned(),
597 def.is_reliable(),
598 def.processing_or_default(),
599 def.block_or_default(),
600 def.codec_handle(),
601 def.poison_policy(),
602 recovery,
603 ))
604 }
605
606 #[must_use]
609 pub fn pubsub_publisher(&self) -> RedisPubSubPublisher {
610 RedisPubSubPublisher::new(Arc::clone(&self.pool), PubSubMode::Classic)
611 }
612
613 #[must_use]
615 pub fn list_publisher(&self) -> RedisListPublisher {
616 RedisListPublisher::new(Arc::clone(&self.pool))
617 }
618
619 pub async fn shutdown_pool(&self) {
621 if let Some(pool) = self.pool.get() {
622 let _ = pool.quit().await;
623 }
624 }
625}
626
627async fn ensure_group(
629 pool: &Pool,
630 key: &str,
631 group: &str,
632 start_id: &str,
633) -> Result<(), RedisError> {
634 let result: Result<String, fred::error::Error> =
635 pool.xgroup_create(key, group, start_id, true).await;
636 match result {
637 Ok(_) => Ok(()),
638 Err(err) if err.details().contains("BUSYGROUP") => Ok(()),
640 Err(err) => Err(RedisError::subscribe(err)),
641 }
642}
643
644impl Broker for RedisBroker {
645 type Error = RedisError;
646
647 async fn connect(&self) -> Result<(), Self::Error> {
648 self.pool
649 .get_or_try_init(|| async {
650 let config = self.build_config()?;
651 let pool = Pool::new(config, None, None, None, self.pool_size)
652 .map_err(|err| RedisError::Connect(Box::new(err)))?;
653 pool.init()
654 .await
655 .map_err(|err| RedisError::Connect(Box::new(err)))?;
656 Ok(pool)
657 })
658 .await?;
659 Ok(())
660 }
661
662 async fn shutdown(&self) -> Result<(), Self::Error> {
663 self.shutdown_pool().await;
664 Ok(())
665 }
666}
667
668#[allow(clippy::use_self)]
671impl Subscribe for RedisBroker {
672 type Subscriber = RedisSubscriber;
673
674 async fn subscribe(&self, name: &str) -> Result<Self::Subscriber, Self::Error> {
675 let group = self.default_group.clone().ok_or_else(|| {
676 RedisError::InvalidOptions(format!(
677 "bare-string subscription on `{name}` needs a broker-wide default group: \
678 call RedisBroker::default_group(name), or subscribe with \
679 RedisStream::new(name).group(group)"
680 ))
681 })?;
682 RedisBroker::subscribe(self, RedisStream::new(name).group(group)).await
683 }
684}
685
686impl DescribeServer for RedisBroker {
688 fn describe_server(&self) -> ServerSpec {
689 let host = match &self.topology {
690 Topology::Standalone(url) => url
691 .trim_start_matches("rediss://")
692 .trim_start_matches("redis://")
693 .to_owned(),
694 Topology::Cluster(nodes) => nodes.first().cloned().unwrap_or_default(),
695 Topology::Sentinel { hosts, .. } => hosts.first().cloned().unwrap_or_default(),
696 Topology::Preconnected => String::new(),
697 };
698 ServerSpec::new(host, "redis")
699 }
700}
701
#[cfg(test)]
mod tests {
    use ruststream::{OutgoingMessage, Publisher};

    use super::*;

    /// A freshly built broker must not dial Redis: publishing and subscribing
    /// both fail fast with `NotConnected` until `connect` is called.
    #[tokio::test]
    async fn standalone_does_not_connect() {
        let broker = RedisBroker::standalone("redis://127.0.0.1:6379");

        let publisher = broker.publisher();
        let published = publisher
            .publish(OutgoingMessage::new("orders", b"{}".as_slice()))
            .await;
        assert!(
            matches!(published, Err(RedisError::NotConnected)),
            "publish before connect must report NotConnected"
        );

        let subscribed =
            RedisBroker::subscribe(&broker, RedisStream::new("orders").group("g")).await;
        assert!(
            matches!(subscribed, Err(RedisError::NotConnected)),
            "subscribe before connect must report NotConnected"
        );
    }

    /// Subscribing by bare stream name needs a broker-wide default group.
    #[tokio::test]
    async fn bare_string_subscription_needs_default_group() {
        let broker = RedisBroker::standalone("redis://127.0.0.1:6379");
        let outcome = Subscribe::subscribe(&broker, "orders").await;
        let Err(RedisError::InvalidOptions(message)) = outcome else {
            panic!("expected InvalidOptions for a bare-string subscription without a default group");
        };
        assert!(message.contains("default group"), "unexpected message: {message}");
    }

    /// The server spec drops the URL scheme and reports the redis protocol.
    #[test]
    fn describe_server_reports_redis() {
        let spec = RedisBroker::standalone("redis://localhost:6379").describe_server();
        assert_eq!(spec.host, "localhost:6379");
        assert_eq!(spec.protocol, "redis");
    }

    /// `credentials` reaches the data-connection config of every topology.
    #[test]
    fn credentials_apply_to_all_topologies() {
        let cases = [
            ("standalone", RedisBroker::standalone("redis://localhost:6379")),
            ("cluster", RedisBroker::cluster(["127.0.0.1:7000"])),
            (
                "sentinel",
                RedisBroker::sentinel("mymaster", ["127.0.0.1:26379"]),
            ),
        ];
        for (label, broker) in cases {
            let config = broker
                .credentials("alice", "s3cr3t")
                .build_config()
                .expect("config builds");
            assert_eq!(config.username.as_deref(), Some("alice"), "{label}");
            assert_eq!(config.password.as_deref(), Some("s3cr3t"), "{label}");
        }
    }

    /// `password` alone must not invent a username.
    #[test]
    fn password_only_sets_password_without_username() {
        let broker = RedisBroker::cluster(["127.0.0.1:7000"]).password("requirepass");
        let config = broker.build_config().expect("config builds");
        assert_eq!(config.password.as_deref(), Some("requirepass"));
        assert_eq!(config.username, None);
    }

    /// Programmatic credentials win over ones embedded in the URL.
    #[test]
    fn programmatic_credentials_override_standalone_url() {
        let broker = RedisBroker::standalone("redis://urluser:urlpass@localhost:6379")
            .credentials("acluser", "aclpass");
        let config = broker.build_config().expect("config builds");
        assert_eq!(config.username.as_deref(), Some("acluser"));
        assert_eq!(config.password.as_deref(), Some("aclpass"));
    }

    /// Without an override, URL credentials are left intact.
    #[test]
    fn url_credentials_preserved_without_override() {
        let broker = RedisBroker::standalone("redis://urluser:urlpass@localhost:6379");
        let config = broker.build_config().expect("config builds");
        assert_eq!(config.username.as_deref(), Some("urluser"));
        assert_eq!(config.password.as_deref(), Some("urlpass"));
    }

    /// Debug output shows the username but never the password.
    #[test]
    fn debug_redacts_password() {
        let rendered = format!(
            "{:?}",
            RedisBroker::standalone("redis://localhost:6379").credentials("alice", "s3cr3t")
        );
        assert!(
            rendered.contains("alice"),
            "expected username in: {rendered}"
        );
        assert!(
            !rendered.contains("s3cr3t"),
            "password must not appear in Debug output: {rendered}"
        );
    }

    /// Sentinel credentials land on the sentinel server config, separate from
    /// the data-connection credentials.
    #[cfg(feature = "sentinel-auth")]
    #[test]
    fn sentinel_credentials_apply_to_sentinel_server() {
        let broker = RedisBroker::sentinel("mymaster", ["127.0.0.1:26379"])
            .credentials("datauser", "datapass")
            .sentinel_credentials("sentineluser", "sentinelpass");
        let config = broker.build_config().expect("config builds");
        assert_eq!(config.username.as_deref(), Some("datauser"));
        let ServerConfig::Sentinel {
            username, password, ..
        } = &config.server
        else {
            panic!("expected a sentinel server config");
        };
        assert_eq!(username.as_deref(), Some("sentineluser"));
        assert_eq!(password.as_deref(), Some("sentinelpass"));
    }

    /// Fixed-answer provider used to check that providers are wired through.
    #[cfg(feature = "credential-provider")]
    #[derive(Debug)]
    struct StaticCredentials;

    #[cfg(feature = "credential-provider")]
    #[async_trait::async_trait]
    impl CredentialProvider for StaticCredentials {
        async fn fetch(
            &self,
            _server: Option<&fred::types::config::Server>,
        ) -> Result<(Option<String>, Option<String>), fred::error::Error> {
            Ok((Some("rotating".into()), Some("token".into())))
        }
    }

    /// A configured credential provider is copied onto the fred config.
    #[cfg(feature = "credential-provider")]
    #[test]
    fn credential_provider_is_applied() {
        let provider: Arc<dyn CredentialProvider> = Arc::new(StaticCredentials);
        let broker = RedisBroker::cluster(["127.0.0.1:7000"]).credential_provider(provider);
        let config = broker.build_config().expect("config builds");
        assert!(config.credential_provider.is_some());
    }
}