1use std::sync::Arc;
4
5use fred::clients::{Client, Pool};
6use fred::interfaces::{ClientLike, EventInterface, PubsubInterface, StreamsInterface};
7use fred::types::config::{Config, ServerConfig};
8use ruststream::{Broker, DescribeServer, ServerSpec, Subscribe};
9use tokio::sync::OnceCell;
10
11use crate::{
12 error::RedisError,
13 list::{RedisList, RedisListPublisher, RedisListSubscriber},
14 publisher::RedisPublisher,
15 pubsub::{PubSubMode, RedisPubSub, RedisPubSubPublisher, RedisPubSubSubscriber},
16 stream::RedisStream,
17 subscriber::RedisSubscriber,
18};
19
/// Pool size given to every new [`RedisBroker`] until it is overridden with
/// [`RedisBroker::pool`]; the value is handed to `Pool::new` when the broker connects.
const DEFAULT_POOL_SIZE: usize = 4;
22
/// Describes how the broker reaches Redis.
///
/// The variant is chosen by the `RedisBroker` constructor and is turned into a
/// `fred` configuration by `RedisBroker::build_config` when a connection is opened.
#[derive(Clone, Debug)]
enum Topology {
    /// A single server, stored as the URL string later passed to `Config::from_url`.
    Standalone(String),
    /// Cluster node addresses; each is parsed as `host[:port]` with 6379 as the
    /// default port.
    Cluster(Vec<String>),
    /// A Sentinel-managed deployment.
    Sentinel {
        /// Service name forwarded to `ServerConfig::new_sentinel`.
        service: String,
        /// Sentinel addresses; each is parsed as `host[:port]` with 26379 as the
        /// default port.
        hosts: Vec<String>,
    },
    /// The pool was supplied by the caller (`RedisBroker::from_pool`), so there is
    /// nothing to build a configuration from.
    Preconnected,
}
37
38fn parse_server(addr: &str, default_port: u16) -> Result<(String, u16), RedisError> {
42 let trimmed = addr
43 .trim()
44 .trim_start_matches("rediss://")
45 .trim_start_matches("redis://");
46 let (host, port) = match trimmed.rsplit_once(':') {
47 Some((host, port)) => {
48 let port = port.parse::<u16>().map_err(|_| {
49 RedisError::Connect(format!("invalid port in redis address `{addr}`").into())
50 })?;
51 (host, port)
52 }
53 None => (trimmed, default_port),
54 };
55 if host.is_empty() {
56 return Err(RedisError::Connect(
57 format!("missing host in redis address `{addr}`").into(),
58 ));
59 }
60 Ok((host.to_owned(), port))
61}
62
63fn parse_servers(addrs: &[String], default_port: u16) -> Result<Vec<(String, u16)>, RedisError> {
64 if addrs.is_empty() {
65 return Err(RedisError::Connect("no redis addresses provided".into()));
66 }
67 addrs
68 .iter()
69 .map(|addr| parse_server(addr, default_port))
70 .collect()
71}
72
/// Redis broker handle.
///
/// A value is created unconnected by [`RedisBroker::standalone`],
/// [`RedisBroker::cluster`] or [`RedisBroker::sentinel`] and connected later
/// through `Broker::connect`, or created around an existing pool with
/// [`RedisBroker::from_pool`]. Cloning is cheap: clones share the same pool cell.
#[derive(Clone)]
pub struct RedisBroker {
    /// Connection pool, set at most once (by `Broker::connect` or `from_pool`)
    /// and shared by every clone and by the publishers created from this broker.
    pool: Arc<OnceCell<Pool>>,
    /// How to reach Redis; the source for `build_config`.
    topology: Topology,
    /// Size passed to `Pool::new` when the pool is created.
    pool_size: usize,
    /// Consumer group used for bare-string subscriptions (`Subscribe::subscribe`);
    /// `None` makes those subscriptions fail with `RedisError::InvalidOptions`.
    default_group: Option<String>,
}
109
110impl std::fmt::Debug for RedisBroker {
111 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112 f.debug_struct("RedisBroker")
113 .field("topology", &self.topology)
114 .field("pool_size", &self.pool_size)
115 .field("default_group", &self.default_group)
116 .finish_non_exhaustive()
117 }
118}
119
120impl RedisBroker {
121 #[must_use]
126 pub fn standalone(url: impl Into<String>) -> Self {
127 Self::with_topology(Topology::Standalone(url.into()))
128 }
129
130 #[must_use]
135 pub fn cluster(nodes: impl IntoIterator<Item = impl Into<String>>) -> Self {
136 Self::with_topology(Topology::Cluster(
137 nodes.into_iter().map(Into::into).collect(),
138 ))
139 }
140
141 #[must_use]
146 pub fn sentinel(
147 service: impl Into<String>,
148 sentinels: impl IntoIterator<Item = impl Into<String>>,
149 ) -> Self {
150 Self::with_topology(Topology::Sentinel {
151 service: service.into(),
152 hosts: sentinels.into_iter().map(Into::into).collect(),
153 })
154 }
155
156 fn with_topology(topology: Topology) -> Self {
157 Self {
158 pool: Arc::new(OnceCell::new()),
159 topology,
160 pool_size: DEFAULT_POOL_SIZE,
161 default_group: None,
162 }
163 }
164
    /// Sets the size that is passed to `Pool::new` when this broker connects.
    ///
    /// Builder-style: consumes the broker and returns it with the new size.
    /// The size is stored as given; it is not validated here.
    #[must_use]
    pub const fn pool(mut self, size: usize) -> Self {
        self.pool_size = size;
        self
    }
171
172 #[must_use]
177 pub fn default_group(mut self, group: impl Into<String>) -> Self {
178 self.default_group = Some(group.into());
179 self
180 }
181
182 pub async fn connect(url: impl Into<String>) -> Result<Self, RedisError> {
189 let broker = Self::standalone(url);
190 Broker::connect(&broker).await?;
191 Ok(broker)
192 }
193
194 #[must_use]
197 pub fn from_pool(pool: Pool) -> Self {
198 Self {
199 pool: Arc::new(OnceCell::new_with(Some(pool))),
200 topology: Topology::Preconnected,
201 pool_size: DEFAULT_POOL_SIZE,
202 default_group: None,
203 }
204 }
205
206 fn build_config(&self) -> Result<Config, RedisError> {
208 match &self.topology {
209 Topology::Standalone(url) => {
210 Config::from_url(url).map_err(|err| RedisError::Connect(Box::new(err)))
211 }
212 Topology::Cluster(nodes) => {
213 let hosts = parse_servers(nodes, 6379)?;
214 Ok(Config {
215 server: ServerConfig::new_clustered(hosts),
216 ..Config::default()
217 })
218 }
219 Topology::Sentinel { service, hosts } => {
220 let hosts = parse_servers(hosts, 26379)?;
221 Ok(Config {
222 server: ServerConfig::new_sentinel(hosts, service.clone()),
223 ..Config::default()
224 })
225 }
226 Topology::Preconnected => Err(RedisError::NotConnected),
228 }
229 }
230
231 fn connected(&self) -> Result<Pool, RedisError> {
233 self.pool.get().cloned().ok_or(RedisError::NotConnected)
234 }
235
236 #[must_use]
245 pub fn pool_handle(&self) -> Pool {
246 self.pool
247 .get()
248 .cloned()
249 .expect("RedisBroker::pool_handle() called before connect()")
250 }
251
252 pub async fn subscribe(&self, def: RedisStream) -> Result<RedisSubscriber, RedisError> {
263 let pool = self.connected()?;
264 let group = def.group_or_err()?.to_owned();
265 let consumer = def.consumer_or_auto();
266 ensure_group(&pool, def.key(), &group, def.start().as_id()).await?;
267 Ok(RedisSubscriber::new(
268 pool,
269 def.key().to_owned(),
270 group,
271 consumer,
272 def.count_or_default(),
273 def.block_or_default(),
274 def.mode(),
275 ))
276 }
277
278 #[must_use]
283 pub fn publisher(&self) -> RedisPublisher {
284 RedisPublisher::new(Arc::clone(&self.pool), self.supports_transactions())
285 }
286
287 const fn supports_transactions(&self) -> bool {
290 !matches!(self.topology, Topology::Cluster(_))
291 }
292
293 async fn new_client(&self) -> Result<Client, RedisError> {
296 let config = self.build_config()?;
297 let client = Client::new(config, None, None, None);
298 client
299 .init()
300 .await
301 .map_err(|err| RedisError::Connect(Box::new(err)))?;
302 Ok(client)
303 }
304
305 pub async fn subscribe_pubsub(
313 &self,
314 def: RedisPubSub,
315 ) -> Result<RedisPubSubSubscriber, RedisError> {
316 def.validate()?;
317 let codec = def.codec_handle();
318 let client = self.new_client().await?;
319 let channel = def.channel().to_owned();
320 let result = match (def.delivery_mode(), def.is_pattern()) {
321 (PubSubMode::Classic, true) => client.psubscribe(channel).await,
322 (PubSubMode::Classic, false) => client.subscribe(channel).await,
323 (PubSubMode::Sharded, _) => client.ssubscribe(channel).await,
324 };
325 result.map_err(RedisError::subscribe)?;
326 let rx = client.message_rx();
327 Ok(RedisPubSubSubscriber::new(client, rx, codec))
328 }
329
330 #[allow(
336 clippy::unused_async,
337 reason = "async for parity with the other subscribe methods and the SubscriptionSource shape"
338 )]
339 pub async fn subscribe_list(&self, def: RedisList) -> Result<RedisListSubscriber, RedisError> {
340 let pool = self.connected()?;
341 Ok(RedisListSubscriber::new(
342 pool,
343 def.key().to_owned(),
344 def.is_reliable(),
345 def.processing_or_default(),
346 def.block_or_default(),
347 def.codec_handle(),
348 ))
349 }
350
351 #[must_use]
354 pub fn pubsub_publisher(&self) -> RedisPubSubPublisher {
355 RedisPubSubPublisher::new(Arc::clone(&self.pool), PubSubMode::Classic)
356 }
357
358 #[must_use]
360 pub fn list_publisher(&self) -> RedisListPublisher {
361 RedisListPublisher::new(Arc::clone(&self.pool))
362 }
363
364 pub async fn shutdown_pool(&self) {
366 if let Some(pool) = self.pool.get() {
367 let _ = pool.quit().await;
368 }
369 }
370}
371
372async fn ensure_group(
374 pool: &Pool,
375 key: &str,
376 group: &str,
377 start_id: &str,
378) -> Result<(), RedisError> {
379 let result: Result<String, fred::error::Error> =
380 pool.xgroup_create(key, group, start_id, true).await;
381 match result {
382 Ok(_) => Ok(()),
383 Err(err) if err.details().contains("BUSYGROUP") => Ok(()),
385 Err(err) => Err(RedisError::subscribe(err)),
386 }
387}
388
389impl Broker for RedisBroker {
390 type Error = RedisError;
391
392 async fn connect(&self) -> Result<(), Self::Error> {
393 self.pool
394 .get_or_try_init(|| async {
395 let config = self.build_config()?;
396 let pool = Pool::new(config, None, None, None, self.pool_size)
397 .map_err(|err| RedisError::Connect(Box::new(err)))?;
398 pool.init()
399 .await
400 .map_err(|err| RedisError::Connect(Box::new(err)))?;
401 Ok(pool)
402 })
403 .await?;
404 Ok(())
405 }
406
407 async fn shutdown(&self) -> Result<(), Self::Error> {
408 self.shutdown_pool().await;
409 Ok(())
410 }
411}
412
413#[allow(clippy::use_self)]
416impl Subscribe for RedisBroker {
417 type Subscriber = RedisSubscriber;
418
419 async fn subscribe(&self, name: &str) -> Result<Self::Subscriber, Self::Error> {
420 let group = self.default_group.clone().ok_or_else(|| {
421 RedisError::InvalidOptions(format!(
422 "bare-string subscription on `{name}` needs a broker-wide default group: \
423 call RedisBroker::default_group(name), or subscribe with \
424 RedisStream::new(name).group(group)"
425 ))
426 })?;
427 RedisBroker::subscribe(self, RedisStream::new(name).group(group)).await
428 }
429}
430
431impl DescribeServer for RedisBroker {
433 fn describe_server(&self) -> ServerSpec {
434 let host = match &self.topology {
435 Topology::Standalone(url) => url
436 .trim_start_matches("rediss://")
437 .trim_start_matches("redis://")
438 .to_owned(),
439 Topology::Cluster(nodes) => nodes.first().cloned().unwrap_or_default(),
440 Topology::Sentinel { hosts, .. } => hosts.first().cloned().unwrap_or_default(),
441 Topology::Preconnected => String::new(),
442 };
443 ServerSpec::new(host, "redis")
444 }
445}
446
// Unit tests that need no running Redis server: they only exercise behaviour of
// an unconnected broker.
#[cfg(test)]
mod tests {
    use ruststream::{OutgoingMessage, Publisher};

    use super::*;

    /// Building a standalone broker must not open a connection: publishing and
    /// subscribing on it both fail with `NotConnected`.
    #[tokio::test]
    async fn standalone_does_not_connect() {
        let broker = RedisBroker::standalone("redis://127.0.0.1:6379");

        let publish_err = broker
            .publisher()
            .publish(OutgoingMessage::new("orders", b"{}".as_slice()))
            .await
            .unwrap_err();
        assert!(matches!(publish_err, RedisError::NotConnected));

        let subscribe_err = broker
            .subscribe(RedisStream::new("orders").group("g"))
            .await
            .unwrap_err();
        assert!(matches!(subscribe_err, RedisError::NotConnected));
    }

    /// Subscribing by bare stream name without a default group is rejected with
    /// an `InvalidOptions` error that mentions the missing default group.
    #[tokio::test]
    async fn bare_string_subscription_needs_default_group() {
        let broker = RedisBroker::standalone("redis://127.0.0.1:6379");
        let err = Subscribe::subscribe(&broker, "orders").await.unwrap_err();
        assert!(matches!(err, RedisError::InvalidOptions(msg) if msg.contains("default group")));
    }

    /// The server description uses the `redis` protocol and the URL without its
    /// scheme as host.
    #[test]
    fn describe_server_reports_redis() {
        let broker = RedisBroker::standalone("redis://localhost:6379");
        let spec = broker.describe_server();
        assert_eq!(spec.protocol, "redis");
        assert_eq!(spec.host, "localhost:6379");
    }
}
487}