// rustrails_record/counter_cache.rs

1use sea_orm::{ConnectionTrait, DatabaseConnection, Statement};
2
3use crate::base::{Record, RecordError};
4
5/// Counter cache helpers for record types.
6pub trait CounterCache: Record {
7    /// Returns the available counter cache columns as `(column, association)` pairs.
8    fn counter_cache_columns() -> &'static [(&'static str, &'static str)];
9
10    /// Increments a counter cache column for the record identified by `id`.
11    async fn increment_counter(
12        column: &str,
13        id: i64,
14        db: &DatabaseConnection,
15    ) -> Result<(), RecordError> {
16        update_counter::<Self>(column, id, 1, db).await
17    }
18
19    /// Decrements a counter cache column for the record identified by `id`.
20    async fn decrement_counter(
21        column: &str,
22        id: i64,
23        db: &DatabaseConnection,
24    ) -> Result<(), RecordError> {
25        update_counter::<Self>(column, id, -1, db).await
26    }
27}
28
29async fn update_counter<T: CounterCache>(
30    column: &str,
31    id: i64,
32    delta: i64,
33    db: &DatabaseConnection,
34) -> Result<(), RecordError> {
35    validate_counter_column::<T>(column)?;
36
37    let sql = format!(
38        "UPDATE {table} SET {column} = COALESCE({column}, 0) + ? WHERE {primary_key} = ?",
39        table = T::table_name(),
40        primary_key = T::primary_key_name(),
41    );
42    let result = db
43        .execute_raw(Statement::from_sql_and_values(
44            db.get_database_backend(),
45            sql,
46            [delta.into(), id.into()],
47        ))
48        .await?;
49
50    if result.rows_affected() == 0 {
51        return Err(RecordError::NotFound);
52    }
53
54    Ok(())
55}
56
57fn validate_counter_column<T: CounterCache>(column: &str) -> Result<(), RecordError> {
58    if T::counter_cache_columns()
59        .iter()
60        .any(|(candidate, _)| *candidate == column)
61    {
62        Ok(())
63    } else {
64        Err(RecordError::Invalid(format!(
65            "unknown counter cache column: {column}"
66        )))
67    }
68}
69
#[cfg(test)]
mod tests {
    use sea_orm::{
        ActiveModelTrait,
        ActiveValue::{NotSet, Set},
        ConnectionTrait, Database, DatabaseConnection, EntityTrait, Schema,
    };

    use super::CounterCache;
    use crate::base::{Record, RecordState};

    /// SeaORM entity backing the `counter_records` test table.
    mod counter_record {
        use sea_orm::entity::prelude::*;

        #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
        #[sea_orm(table_name = "counter_records")]
        pub struct Model {
            #[sea_orm(primary_key)]
            pub id: i32,
            pub comments_count: i32,
            pub views_count: Option<i32>,
        }

        #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
        pub enum Relation {}

        impl ActiveModelBehavior for ActiveModel {}
    }

    /// `(column, association)` pairs registered for `CounterRecord`.
    const COUNTER_COLUMNS: [(&str, &str); 2] =
        [("comments_count", "comments"), ("views_count", "views")];

    /// Primary key of the single row inserted by `setup_db`.
    const SEEDED_ID: i64 = 1;

    /// A primary key that no row in the test table uses.
    const MISSING_ID: i64 = 404;

    #[derive(Clone, Debug, Default, PartialEq, Eq)]
    struct CounterRecord {
        id: Option<i64>,
        comments_count: i32,
        views_count: Option<i32>,
        state: RecordState,
    }

    impl Record for CounterRecord {
        type Entity = counter_record::Entity;

        fn table_name() -> &'static str {
            "counter_records"
        }

        fn id(&self) -> Option<i64> {
            self.id
        }

        fn record_state(&self) -> RecordState {
            self.state
        }

        fn set_record_state(&mut self, state: RecordState) {
            self.state = state;
        }

        fn from_sea_model(model: <Self::Entity as EntityTrait>::Model) -> Self {
            let counter_record::Model {
                id,
                comments_count,
                views_count,
            } = model;
            Self {
                id: Some(i64::from(id)),
                comments_count,
                views_count,
                state: RecordState::Persisted,
            }
        }

        fn to_active_model(&self) -> <Self::Entity as EntityTrait>::ActiveModel {
            // Missing ids, and ids that do not fit the i32 column, stay unset.
            let id = self
                .id
                .and_then(|raw| i32::try_from(raw).ok())
                .map_or(NotSet, Set);
            counter_record::ActiveModel {
                id,
                comments_count: Set(self.comments_count),
                views_count: Set(self.views_count),
            }
        }
    }

    impl CounterCache for CounterRecord {
        fn counter_cache_columns() -> &'static [(&'static str, &'static str)] {
            &COUNTER_COLUMNS
        }
    }

    /// Opens a fresh in-memory SQLite database with the `counter_records`
    /// table and one seeded row (`comments_count = 2`, `views_count` unset).
    async fn setup_db() -> DatabaseConnection {
        let db = Database::connect("sqlite::memory:")
            .await
            .expect("in-memory sqlite connection should succeed");
        let schema = Schema::new(db.get_database_backend());
        db.execute(&schema.create_table_from_entity(counter_record::Entity))
            .await
            .expect("counter_records table should be created");

        let seed = counter_record::ActiveModel {
            comments_count: Set(2),
            ..Default::default()
        };
        seed.insert(&db).await.expect("seed row should insert");
        db
    }

    /// Reloads the seeded row so assertions observe the persisted values.
    async fn reload_seeded_row(db: &DatabaseConnection) -> counter_record::Model {
        counter_record::Entity::find_by_id(1)
            .one(db)
            .await
            .expect("query should succeed")
            .expect("row should exist")
    }

    #[tokio::test]
    async fn increment_counter_updates_column() {
        let db = setup_db().await;

        CounterRecord::increment_counter("comments_count", SEEDED_ID, &db)
            .await
            .expect("increment should succeed");

        assert_eq!(reload_seeded_row(&db).await.comments_count, 3);
    }

    #[tokio::test]
    async fn decrement_counter_updates_column() {
        let db = setup_db().await;

        CounterRecord::decrement_counter("comments_count", SEEDED_ID, &db)
            .await
            .expect("decrement should succeed");

        assert_eq!(reload_seeded_row(&db).await.comments_count, 1);
    }

    #[tokio::test]
    async fn unknown_counter_column_is_rejected() {
        let db = setup_db().await;

        let error = CounterRecord::increment_counter("likes_count", SEEDED_ID, &db)
            .await
            .expect_err("unknown counter should fail");

        assert!(
            matches!(error, crate::RecordError::Invalid(message) if message.contains("likes_count"))
        );
    }

    #[tokio::test]
    async fn increment_counter_returns_not_found_for_missing_id() {
        let db = setup_db().await;

        let error = CounterRecord::increment_counter("comments_count", MISSING_ID, &db)
            .await
            .expect_err("missing row should fail");

        assert!(matches!(error, crate::RecordError::NotFound));
    }

    #[tokio::test]
    async fn decrement_counter_returns_not_found_for_missing_id() {
        let db = setup_db().await;

        let error = CounterRecord::decrement_counter("comments_count", MISSING_ID, &db)
            .await
            .expect_err("missing row should fail");

        assert!(matches!(error, crate::RecordError::NotFound));
    }

    #[tokio::test]
    async fn repeated_increments_accumulate_on_existing_counter() {
        let db = setup_db().await;

        for message in [
            "first increment should succeed",
            "second increment should succeed",
        ] {
            CounterRecord::increment_counter("comments_count", SEEDED_ID, &db)
                .await
                .expect(message);
        }

        assert_eq!(reload_seeded_row(&db).await.comments_count, 4);
    }

    #[tokio::test]
    async fn decrement_counter_can_cross_zero() {
        let db = setup_db().await;

        for message in [
            "first decrement should succeed",
            "second decrement should succeed",
            "third decrement should succeed",
        ] {
            CounterRecord::decrement_counter("comments_count", SEEDED_ID, &db)
                .await
                .expect(message);
        }

        assert_eq!(reload_seeded_row(&db).await.comments_count, -1);
    }

    #[tokio::test]
    async fn increment_counter_coalesces_null_counters_to_zero() {
        let db = setup_db().await;

        CounterRecord::increment_counter("views_count", SEEDED_ID, &db)
            .await
            .expect("increment should succeed for null counter");

        assert_eq!(reload_seeded_row(&db).await.views_count, Some(1));
    }
}