strev_postgres/
subscriber.rs1use std::sync::Arc;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use bytes::Bytes;
6use serde_json::Value;
7use sqlx::{PgPool, Row};
8use strev::{CloseError, Message, MessageStream, Metadata, SubscribeError, Topic};
9use tokio::sync::mpsc::Sender;
10
11use crate::schema::ensure_schema;
12
13pub struct PostgresSubscriberConfig {
14 pub pool: PgPool,
15 pub consumer_group: String,
16 pub poll_interval: Duration,
17 pub batch_size: i64,
18 pub buffer_size: usize,
19}
20
21impl PostgresSubscriberConfig {
22 pub fn new(pool: PgPool, consumer_group: impl Into<String>) -> Self {
23 Self {
24 pool,
25 consumer_group: consumer_group.into(),
26 poll_interval: Duration::from_millis(200),
27 batch_size: 100,
28 buffer_size: 64,
29 }
30 }
31}
32
33pub struct PostgresSubscriber {
34 config: Arc<PostgresSubscriberConfig>,
35}
36
37impl PostgresSubscriber {
38 pub fn new(config: PostgresSubscriberConfig) -> Self {
39 Self {
40 config: Arc::new(config),
41 }
42 }
43}
44
45#[async_trait]
46impl strev::Subscriber for PostgresSubscriber {
47 async fn subscribe(&self, topic: &Topic) -> Result<MessageStream, SubscribeError> {
48 let config = self.config.clone();
49 let topic = topic.as_str().to_string();
50
51 ensure_schema(&config.pool)
52 .await
53 .map_err(|e| SubscribeError::Backend(Box::new(e)))?;
54
55 sqlx::query(
56 "INSERT INTO strev_offsets (consumer_group, topic, last_id) VALUES ($1, $2, 0) ON CONFLICT DO NOTHING",
57 )
58 .bind(&config.consumer_group)
59 .bind(&topic)
60 .execute(&config.pool)
61 .await
62 .map_err(|e| SubscribeError::Backend(Box::new(e)))?;
63
64 let (sender, stream) = MessageStream::channel(config.buffer_size);
65
66 tokio::spawn(async move {
67 loop {
68 if sender.is_closed() {
69 break;
70 }
71
72 match poll_once(&config, &topic, &sender).await {
73 Ok(count) if count > 0 => continue,
74 Ok(_) => tokio::time::sleep(config.poll_interval).await,
75 Err(_) => tokio::time::sleep(config.poll_interval).await,
76 }
77 }
78 });
79
80 Ok(stream)
81 }
82
83 async fn close(&mut self) -> Result<(), CloseError> {
84 Ok(())
85 }
86}
87
88async fn poll_once(
89 config: &PostgresSubscriberConfig,
90 topic: &str,
91 sender: &Sender<Message>,
92) -> Result<usize, sqlx::Error> {
93 let mut tx = config.pool.begin().await?;
94
95 let locked = sqlx::query(
96 "SELECT last_id FROM strev_offsets WHERE consumer_group = $1 AND topic = $2 FOR UPDATE SKIP LOCKED",
97 )
98 .bind(&config.consumer_group)
99 .bind(topic)
100 .fetch_optional(&mut *tx)
101 .await?;
102
103 let last_id: i64 = match locked {
104 Some(row) => row.try_get("last_id")?,
105 None => {
106 tx.rollback().await?;
107 return Ok(0);
108 }
109 };
110
111 let rows = sqlx::query(
112 "SELECT id, payload, metadata FROM strev_messages WHERE topic = $1 AND id > $2 ORDER BY id ASC LIMIT $3",
113 )
114 .bind(topic)
115 .bind(last_id)
116 .bind(config.batch_size)
117 .fetch_all(&mut *tx)
118 .await?;
119
120 if rows.is_empty() {
121 tx.rollback().await?;
122 return Ok(0);
123 }
124
125 let mut max_id = last_id;
126 for row in &rows {
127 let id: i64 = row.try_get("id")?;
128 let payload: Vec<u8> = row.try_get("payload")?;
129 let metadata_json: Value = row.try_get("metadata")?;
130
131 let mut metadata = Metadata::new();
132 if let Value::Object(map) = metadata_json {
133 for (key, value) in map {
134 if let Value::String(text) = value {
135 metadata.set(key, text);
136 }
137 }
138 }
139
140 let message = Message::with_metadata(Bytes::from(payload), metadata);
141 if sender.send(message).await.is_err() {
142 tx.rollback().await?;
143 return Ok(0);
144 }
145
146 max_id = id;
147 }
148
149 sqlx::query("UPDATE strev_offsets SET last_id = $1 WHERE consumer_group = $2 AND topic = $3")
150 .bind(max_id)
151 .bind(&config.consumer_group)
152 .bind(topic)
153 .execute(&mut *tx)
154 .await?;
155
156 tx.commit().await?;
157 Ok(rows.len())
158}