1use sea_orm::{ConnectionTrait, DatabaseConnection, Statement};
2
3use crate::base::{Record, RecordError};
4
5pub trait CounterCache: Record {
7 fn counter_cache_columns() -> &'static [(&'static str, &'static str)];
9
10 async fn increment_counter(
12 column: &str,
13 id: i64,
14 db: &DatabaseConnection,
15 ) -> Result<(), RecordError> {
16 update_counter::<Self>(column, id, 1, db).await
17 }
18
19 async fn decrement_counter(
21 column: &str,
22 id: i64,
23 db: &DatabaseConnection,
24 ) -> Result<(), RecordError> {
25 update_counter::<Self>(column, id, -1, db).await
26 }
27}
28
29async fn update_counter<T: CounterCache>(
30 column: &str,
31 id: i64,
32 delta: i64,
33 db: &DatabaseConnection,
34) -> Result<(), RecordError> {
35 validate_counter_column::<T>(column)?;
36
37 let sql = format!(
38 "UPDATE {table} SET {column} = COALESCE({column}, 0) + ? WHERE {primary_key} = ?",
39 table = T::table_name(),
40 primary_key = T::primary_key_name(),
41 );
42 let result = db
43 .execute_raw(Statement::from_sql_and_values(
44 db.get_database_backend(),
45 sql,
46 [delta.into(), id.into()],
47 ))
48 .await?;
49
50 if result.rows_affected() == 0 {
51 return Err(RecordError::NotFound);
52 }
53
54 Ok(())
55}
56
57fn validate_counter_column<T: CounterCache>(column: &str) -> Result<(), RecordError> {
58 if T::counter_cache_columns()
59 .iter()
60 .any(|(candidate, _)| *candidate == column)
61 {
62 Ok(())
63 } else {
64 Err(RecordError::Invalid(format!(
65 "unknown counter cache column: {column}"
66 )))
67 }
68}
69
#[cfg(test)]
mod tests {
    use sea_orm::{
        ActiveModelTrait, ActiveValue::NotSet, ActiveValue::Set, ConnectionTrait, Database,
        EntityTrait, Schema,
    };

    use super::CounterCache;
    use crate::base::{Record, RecordState};

    /// SeaORM entity for the `counter_records` table the tests run against.
    mod counter_record {
        use sea_orm::entity::prelude::*;

        #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
        #[sea_orm(table_name = "counter_records")]
        pub struct Model {
            #[sea_orm(primary_key)]
            pub id: i32,
            // Non-nullable counter, seeded with a concrete value by `setup_db`.
            pub comments_count: i32,
            // Nullable counter, left unset by `setup_db` to exercise COALESCE.
            pub views_count: Option<i32>,
        }

        #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
        pub enum Relation {}

        impl ActiveModelBehavior for ActiveModel {}
    }

    /// Counter columns declared by the test record, as `(column, label)` pairs.
    const COUNTER_COLUMNS: &[(&str, &str)] =
        &[("comments_count", "comments"), ("views_count", "views")];

    /// Minimal record type wired to the `counter_record` entity.
    #[derive(Clone, Debug, Default, PartialEq, Eq)]
    struct CounterRecord {
        id: Option<i64>,
        comments_count: i32,
        views_count: Option<i32>,
        state: RecordState,
    }

    impl Record for CounterRecord {
        type Entity = counter_record::Entity;

        fn table_name() -> &'static str {
            "counter_records"
        }

        fn id(&self) -> Option<i64> {
            self.id
        }

        fn record_state(&self) -> RecordState {
            self.state
        }

        fn set_record_state(&mut self, state: RecordState) {
            self.state = state;
        }

        fn from_sea_model(model: <Self::Entity as EntityTrait>::Model) -> Self {
            let counter_record::Model {
                id,
                comments_count,
                views_count,
            } = model;

            Self {
                id: Some(i64::from(id)),
                comments_count,
                views_count,
                state: RecordState::Persisted,
            }
        }

        fn to_active_model(&self) -> <Self::Entity as EntityTrait>::ActiveModel {
            // An absent id, or one that does not fit in `i32`, is left unset.
            let narrowed_id = self.id.and_then(|wide| i32::try_from(wide).ok());

            counter_record::ActiveModel {
                id: narrowed_id.map_or(NotSet, Set),
                comments_count: Set(self.comments_count),
                views_count: Set(self.views_count),
            }
        }
    }

    impl CounterCache for CounterRecord {
        fn counter_cache_columns() -> &'static [(&'static str, &'static str)] {
            COUNTER_COLUMNS
        }
    }

    /// Opens an in-memory SQLite database, creates the `counter_records` table
    /// and inserts one row with `comments_count = 2` and `views_count` unset.
    async fn setup_db() -> sea_orm::DatabaseConnection {
        let db = Database::connect("sqlite::memory:")
            .await
            .expect("in-memory sqlite connection should succeed");

        let schema = Schema::new(db.get_database_backend());
        let create_table = schema.create_table_from_entity(counter_record::Entity);
        db.execute(&create_table)
            .await
            .expect("counter_records table should be created");

        let seed = counter_record::ActiveModel {
            comments_count: Set(2),
            ..Default::default()
        };
        seed.insert(&db).await.expect("seed row should insert");

        db
    }

    /// Reads back the row with primary key `1`, panicking when it cannot be loaded.
    async fn seeded_row(db: &sea_orm::DatabaseConnection) -> counter_record::Model {
        let lookup = counter_record::Entity::find_by_id(1).one(db).await;
        lookup
            .expect("query should succeed")
            .expect("row should exist")
    }

    #[tokio::test]
    async fn increment_counter_updates_column() {
        let db = setup_db().await;

        let outcome = CounterRecord::increment_counter("comments_count", 1, &db).await;
        outcome.expect("increment should succeed");

        let stored = seeded_row(&db).await;
        assert_eq!(stored.comments_count, 3);
    }

    #[tokio::test]
    async fn decrement_counter_updates_column() {
        let db = setup_db().await;

        let outcome = CounterRecord::decrement_counter("comments_count", 1, &db).await;
        outcome.expect("decrement should succeed");

        let stored = seeded_row(&db).await;
        assert_eq!(stored.comments_count, 1);
    }

    #[tokio::test]
    async fn unknown_counter_column_is_rejected() {
        let db = setup_db().await;

        let outcome = CounterRecord::increment_counter("likes_count", 1, &db).await;
        let error = outcome.expect_err("unknown counter should fail");

        let names_column =
            matches!(error, crate::RecordError::Invalid(msg) if msg.contains("likes_count"));
        assert!(names_column);
    }

    #[tokio::test]
    async fn increment_counter_returns_not_found_for_missing_id() {
        let db = setup_db().await;

        let outcome = CounterRecord::increment_counter("comments_count", 404, &db).await;
        let error = outcome.expect_err("missing row should fail");

        assert!(matches!(error, crate::RecordError::NotFound));
    }

    #[tokio::test]
    async fn decrement_counter_returns_not_found_for_missing_id() {
        let db = setup_db().await;

        let outcome = CounterRecord::decrement_counter("comments_count", 404, &db).await;
        let error = outcome.expect_err("missing row should fail");

        assert!(matches!(error, crate::RecordError::NotFound));
    }

    #[tokio::test]
    async fn repeated_increments_accumulate_on_existing_counter() {
        let db = setup_db().await;

        for failure_note in [
            "first increment should succeed",
            "second increment should succeed",
        ] {
            CounterRecord::increment_counter("comments_count", 1, &db)
                .await
                .expect(failure_note);
        }

        let stored = seeded_row(&db).await;
        assert_eq!(stored.comments_count, 4);
    }

    #[tokio::test]
    async fn decrement_counter_can_cross_zero() {
        let db = setup_db().await;

        // Seed value is 2, so three decrements end below zero.
        for failure_note in [
            "first decrement should succeed",
            "second decrement should succeed",
            "third decrement should succeed",
        ] {
            CounterRecord::decrement_counter("comments_count", 1, &db)
                .await
                .expect(failure_note);
        }

        let stored = seeded_row(&db).await;
        assert_eq!(stored.comments_count, -1);
    }

    #[tokio::test]
    async fn increment_counter_coalesces_null_counters_to_zero() {
        let db = setup_db().await;

        let outcome = CounterRecord::increment_counter("views_count", 1, &db).await;
        outcome.expect("increment should succeed for null counter");

        let stored = seeded_row(&db).await;
        assert_eq!(stored.views_count, Some(1));
    }
}