yrs_postgres/
lib.rs

1use async_stream;
2use async_trait::async_trait;
3use bytes::Bytes;
4use chrono::NaiveDateTime;
5use futures_util::stream::BoxStream;
6use futures_util::TryStreamExt;
7use sqlx::{PgPool, Row};
8use yrs::Doc;
9use yrs_store::doc::ForStore;
10use yrs_store::errors::StoreError;
11use yrs_store::Store;
12
13pub struct PostgresStorage {
14    document_id: i64,
15    table_name: String,
16    pool: PgPool,
17    run_vacuum: bool,
18}
19
20impl PostgresStorage {
21    pub async fn new(
22        document_id: i64,
23        table_name: String,
24        pool: PgPool,
25        run_vacuum: bool,
26    ) -> Result<Self, sqlx::Error> {
27        Ok(PostgresStorage {
28            document_id,
29            table_name,
30            pool,
31            run_vacuum,
32        })
33    }
34}
35
36#[async_trait]
37impl Store for PostgresStorage {
38    async fn delete(&self) -> Result<(), StoreError> {
39        sqlx::query(&format!(
40            "DELETE FROM {} WHERE document_id = $1",
41            self.table_name
42        ))
43        .bind(self.document_id)
44        .execute(&self.pool)
45        .await
46        .map(|_| ())
47        .map_err(StoreError::SqlxError)
48    }
49
50    async fn write(&self, update: &Bytes) -> Result<(), StoreError> {
51        let document_id = self.document_id.clone();
52        let now = chrono::Utc::now().naive_utc();
53        let query_result = sqlx::query(&format!(
54            "INSERT INTO {} (document_id, payload, timestamp) VALUES ($1, $2, $3)",
55            self.table_name
56        ))
57        .bind(document_id)
58        .bind(update.as_ref())
59        .bind(now)
60        .execute(&self.pool)
61        .await
62        .map_err(StoreError::SqlxError)?;
63
64        let rows_affected = query_result.rows_affected();
65
66        // 检查影响行数是否等于 1
67        if rows_affected != 1 {
68            return Err(StoreError::WriteError(format!(
69                "Expected 1 row affected for insert, but got {}",
70                rows_affected
71            )));
72        }
73
74        Ok(())
75    }
76
77    async fn read(&self) -> Result<BoxStream<Result<(Bytes, i64), StoreError>>, StoreError> {
78        let document_id = self.document_id;
79        let table_name = self.table_name.clone();
80        let pool = self.pool.clone();
81
82        let stream = async_stream::stream! {
83            let sql = format!(
84                "SELECT payload, timestamp FROM {} WHERE document_id = $1 ORDER BY timestamp",
85                table_name
86            );
87
88            let mut rows = sqlx::query(&sql)
89                .bind(document_id)
90                .fetch(&pool);
91
92            while let Some(row) = rows.try_next().await.map_err(StoreError::SqlxError)? {
93                let payload_vec: Vec<u8> = row.get("payload");
94                let payload = payload_vec.into();
95                let timestamp_ndt: NaiveDateTime = row.get("timestamp");
96                let timestamp_ms = timestamp_ndt.and_utc().timestamp_millis();
97                yield Ok((payload, timestamp_ms));
98            }
99        };
100
101        Ok(Box::pin(stream))
102    }
103
104    async fn read_payloads(&self) -> Result<BoxStream<Result<Bytes, StoreError>>, StoreError> {
105        let document_id = self.document_id;
106        let table_name = self.table_name.clone();
107        let pool = self.pool.clone();
108
109        let stream = async_stream::stream! {
110            let sql = format!(
111                "SELECT payload FROM {} WHERE document_id = $1 ORDER BY timestamp",
112                table_name
113            );
114
115            let mut rows = sqlx::query(&sql)
116                .bind(document_id)
117                .fetch(&pool);
118
119            while let Some(row) = rows.try_next().await.map_err(StoreError::SqlxError)? {
120                let payload_vec: Vec<u8> = row.get("payload");
121                let payload = payload_vec.into();
122                yield Ok(payload);
123            }
124        };
125
126        Ok(Box::pin(stream))
127    }
128
129    async fn squash(&self) -> Result<(), StoreError> {
130        let doc = Doc::new();
131        self.load(&doc).await?;
132        let tx = self.pool.begin().await.map_err(StoreError::SqlxError)?;
133        sqlx::query(&format!(
134            "DELETE FROM {} WHERE document_id = $1",
135            self.table_name
136        ))
137        .bind(self.document_id)
138        .execute(&self.pool)
139        .await
140        .map_err(StoreError::SqlxError)?;
141
142        let squashed_update = doc.get_update();
143        self.write(&squashed_update).await?;
144
145        // 如果在事务超出范围之前都未调用,rollback则自动调用
146        tx.commit().await.map_err(StoreError::SqlxError)?;
147
148        if self.run_vacuum {
149            // 回收死行占据的存储空间
150            sqlx::query(&format!("VACUUM {}", self.table_name))
151                .execute(&self.pool)
152                .await
153                .map_err(StoreError::SqlxError)?;
154        }
155        Ok(())
156    }
157}
158
159#[cfg(test)]
160mod tests {
161    use super::*;
162    use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
163    use std::env;
164    use std::str::FromStr;
165    use yrs::Any::{BigInt, Bool, Number};
166    use yrs::{GetString, Map, Out, ReadTxn, Text, Transact, WriteTxn};
167
168    // 验证数据库中的记录数量
169    async fn assert_record_count(
170        pool: &PgPool,
171        expected_count: i64,
172        message: &str,
173        document_id: i64,
174    ) -> Result<(), sqlx::Error> {
175        let count = sqlx::query!(
176            "SELECT COUNT(*) as count FROM document_updates WHERE document_id = $1",
177            document_id
178        )
179        .fetch_one(pool)
180        .await?
181        .count
182        .unwrap_or(0);
183        assert_eq!(count, expected_count, "{}", message);
184        Ok(())
185    }
186
187    // 清理测试数据
188    async fn cleanup_test_data(pool: &PgPool, document_id: i64) -> Result<(), sqlx::Error> {
189        sqlx::query!(
190            "DELETE FROM document_updates WHERE document_id = $1",
191            document_id
192        )
193        .execute(pool)
194        .await
195        .map(|_| ())
196    }
197
198    async fn create_test_pg_poll(document_id: i64) -> Result<PgPool, sqlx::Error> {
199        let url = env::var("DATABASE_URL")
200            .unwrap_or_else(|_| "postgres://postgres:postgres@localhost:5432/test".to_string());
201        // 连接到测试数据库
202        let connect_options = PgConnectOptions::from_str(&url).expect("无法解析数据库连接字符串");
203
204        let pool = PgPoolOptions::new()
205            .max_connections(5)
206            .connect_with(connect_options)
207            .await?;
208
209        // 确保测试表存在
210        sqlx::query(
211            "
212            CREATE TABLE IF NOT EXISTS document_updates (
213                document_id BIGINT NOT NULL,
214                payload BYTEA NOT NULL,
215                timestamp TIMESTAMP NOT NULL
216            )
217        ",
218        )
219        .execute(&pool)
220        .await?;
221
222        // 清理可能存在的测试数据
223        cleanup_test_data(&pool, document_id).await?;
224
225        Ok(pool)
226    }
227
228    // 创建一个测试用的PostgresStore
229    async fn create_test_store(document_id: i64) -> Result<(PostgresStorage, PgPool), sqlx::Error> {
230        let pool = create_test_pg_poll(document_id).await?;
231
232        // 创建测试用的PostgresStore
233        let store = PostgresStorage::new(
234            document_id,
235            "document_updates".to_string(),
236            pool.clone(),
237            false,
238        )
239        .await?;
240
241        Ok((store, pool))
242    }
243
244    // 创建文档并添加文本内容
245    async fn create_doc_with_text(
246        store: &PostgresStorage,
247        text_content: &str,
248    ) -> Result<Doc, StoreError> {
249        let doc = Doc::new();
250        {
251            let mut txn = doc.transact_mut();
252            let text = txn.get_or_insert_text("text");
253            text.insert(&mut txn, 0, text_content);
254        }
255        let update = doc.get_update();
256        store.write(&update).await?;
257        Ok(doc)
258    }
259
260    // 向文档添加更新
261    async fn update_doc_text(
262        doc: &Doc,
263        store: &PostgresStorage,
264        position: u32,
265        text_content: &str,
266    ) -> Result<(), StoreError> {
267        {
268            let mut txn = doc.transact_mut();
269            let text = txn.get_text("text").unwrap();
270            text.insert(&mut txn, position, text_content);
271        }
272        let update = doc.get_update();
273        store.write(&update).await
274    }
275
276    // 从文档中删除文本
277    async fn remove_doc_text(
278        doc: &Doc,
279        store: &PostgresStorage,
280        position: u32,
281        length: u32,
282    ) -> Result<(), StoreError> {
283        {
284            let mut txn = doc.transact_mut();
285            let text = txn.get_text("text").unwrap();
286            text.remove_range(&mut txn, position, length);
287        }
288        let update = doc.get_update();
289        store.write(&update).await
290    }
291
292    // 验证文档文本内容
293    fn assert_doc_text(doc: &Doc, expected_text: &str, message: &str) {
294        let txn = doc.transact();
295        let text = txn.get_text("text").unwrap();
296        let content = text.get_string(&txn);
297        assert_eq!(content, expected_text, "{}", message);
298    }
299
300    // 验证文档Map内容
301    fn assert_doc_map(
302        doc: &Doc,
303        map_name: &str,
304        expected_entries: &[(&str, serde_json::Value)],
305        message: &str,
306    ) {
307        let txn = doc.transact();
308        let map = txn.get_map(map_name).unwrap();
309
310        for (key, expected_value) in expected_entries {
311            match expected_value {
312                serde_json::Value::String(expected_str) => {
313                    let value = map.get(&txn, *key).unwrap().to_string(&txn);
314                    assert_eq!(value, *expected_str, "{} - key: {}", message, key);
315                }
316                serde_json::Value::Number(expected_num) => {
317                    if let Some(expected_i64) = expected_num.as_i64() {
318                        let i64_value = match map.get(&txn, *key).unwrap() {
319                            Out::Any(BigInt(v)) => v,
320                            Out::Any(Number(v)) => v as i64,
321                            _ => panic!("Expected Out::Any(BigInt)"),
322                        };
323                        assert_eq!(i64_value, expected_i64, "{} - key: {}", message, key);
324                    } else if let Some(expected_f64) = expected_num.as_f64() {
325                        // 将值转换为字符串并解析为f64进行比较
326                        let f64_value = match map.get(&txn, *key).unwrap() {
327                            Out::Any(BigInt(v)) => v as f64,
328                            Out::Any(Number(v)) => v,
329                            _ => panic!("Expected Out::Any(Number)"),
330                        };
331                        assert!(
332                            (f64_value - expected_f64).abs() < f64::EPSILON,
333                            "{} - key: {}",
334                            message,
335                            key
336                        );
337                    }
338                }
339                serde_json::Value::Bool(expected_bool) => {
340                    // 将值转换为字符串并解析为布尔值进行比较
341                    let bool_value = match map.get(&txn, *key).unwrap() {
342                        Out::Any(Bool(v)) => v,
343                        _ => panic!("Expected Out::Any(Bool)"),
344                    };
345                    assert_eq!(bool_value, *expected_bool, "{} - key: {}", message, key);
346                }
347                _ => {} // 忽略其他类型
348            }
349        }
350    }
351
352    // 创建新文档并应用更新
353    async fn create_and_apply_doc(store: &PostgresStorage) -> Result<Doc, StoreError> {
354        let doc = Doc::new();
355        store.load(&doc).await?;
356        Ok(doc)
357    }
358
359    // 向文档添加Map
360    async fn add_map_to_doc(
361        doc: &Doc,
362        store: &PostgresStorage,
363        map_name: &str,
364        entries: &[(&str, serde_json::Value)],
365    ) -> Result<(), StoreError> {
366        {
367            let mut txn = doc.transact_mut();
368            let map = txn.get_or_insert_map(map_name);
369
370            for (key, value) in entries {
371                match value {
372                    serde_json::Value::String(s) => {
373                        map.insert(&mut txn, *key, s.as_str());
374                    }
375                    serde_json::Value::Number(n) => {
376                        if let Some(i) = n.as_i64() {
377                            map.insert(&mut txn, *key, i);
378                        } else if let Some(f) = n.as_f64() {
379                            map.insert(&mut txn, *key, f);
380                        }
381                    }
382                    serde_json::Value::Bool(b) => {
383                        map.insert(&mut txn, *key, *b);
384                    }
385                    _ => {} // 忽略其他类型
386                }
387            }
388        }
389        let update = doc.get_update();
390        store.write(&update).await
391    }
392
393    #[tokio::test]
394    async fn test_squash_preserves_history() -> Result<(), Box<dyn std::error::Error>> {
395        let document_id = 1;
396
397        // 创建测试存储
398        let (store, pool) = create_test_store(document_id).await?;
399
400        // 创建一个YDoc并添加初始文本
401        let doc = create_doc_with_text(&store, "Hello").await?;
402
403        // 添加更多文本
404        update_doc_text(&doc, &store, 5, ", World").await?;
405
406        // 验证数据库中有两条记录
407        assert_record_count(&pool, 2, "数据库中应该有两条更新记录", document_id).await?;
408
409        // 执行squash操作
410        store.squash().await?;
411
412        // 验证数据库中现在只有一条记录
413        assert_record_count(&pool, 1, "squash后数据库中应该只有一条记录", document_id).await?;
414
415        // 创建一个新的YDoc并应用squash后的更新
416        let new_doc = create_and_apply_doc(&store).await?;
417
418        // 验证新文档包含所有历史更改
419        assert_doc_text(
420            &new_doc,
421            "Hello, World",
422            "squash后的文档应该包含所有历史更改",
423        );
424
425        // 清理测试数据
426        cleanup_test_data(&pool, document_id).await?;
427
428        Ok(())
429    }
430
431    #[tokio::test]
432    async fn test_squash_with_multiple_updates() -> Result<(), Box<dyn std::error::Error>> {
433        let document_id = 2;
434        // 创建测试存储
435        let (store, pool) = create_test_store(document_id).await?;
436
437        // 创建一个YDoc并添加多次更新
438        let doc = Doc::new();
439
440        // 进行多次更新
441        for i in 0..5 {
442            {
443                let mut txn = doc.transact_mut();
444                let text = if i == 0 {
445                    txn.get_or_insert_text("text")
446                } else {
447                    txn.get_text("text").unwrap()
448                };
449
450                let len = text.len(&txn);
451                text.insert(&mut txn, len, &format!("Part {}", i));
452            }
453            // 存储更新
454            let update = doc.get_update();
455            store.write(&update).await?;
456        }
457
458        // 验证数据库中有5条记录
459        assert_record_count(&pool, 5, "数据库中应该有5条更新记录", document_id).await?;
460
461        // 执行squash操作
462        store.squash().await?;
463
464        // 验证数据库中现在只有一条记录
465        assert_record_count(&pool, 1, "squash后数据库中应该只有一条记录", document_id).await?;
466
467        // 创建一个新的YDoc并应用squash后的更新
468        let new_doc = create_and_apply_doc(&store).await?;
469
470        // 验证新文档包含所有历史更改
471        assert_doc_text(
472            &new_doc,
473            "Part 0Part 1Part 2Part 3Part 4",
474            "squash后的文档应该包含所有历史更改",
475        );
476
477        // 清理测试数据
478        cleanup_test_data(&pool, document_id).await?;
479
480        Ok(())
481    }
482
483    #[tokio::test]
484    async fn test_squash_with_complex_operations() -> Result<(), Box<dyn std::error::Error>> {
485        let document_id = 3;
486        // 创建测试存储
487        let (store, pool) = create_test_store(document_id).await?;
488
489        // 创建一个YDoc并添加初始文本
490        let doc = create_doc_with_text(&store, "Initial content").await?;
491
492        // 添加更多文本
493        update_doc_text(&doc, &store, 15, " with more text").await?;
494
495        // 第三次更新:删除部分内容
496        remove_doc_text(&doc, &store, 8, 8).await?; // 删除"content "
497
498        // 执行squash操作
499        store.squash().await?;
500
501        // 创建一个新的YDoc并应用squash后的更新
502        let new_doc = create_and_apply_doc(&store).await?;
503
504        // 验证新文档包含所有历史更改,包括删除操作
505        assert_doc_text(
506            &new_doc,
507            "Initial with more text",
508            "squash后的文档应该正确应用所有操作,包括删除",
509        );
510
511        // 清理测试数据
512        cleanup_test_data(&pool, document_id).await?;
513
514        Ok(())
515    }
516
517    #[tokio::test]
518    async fn test_apply_updates() -> Result<(), Box<dyn std::error::Error>> {
519        let document_id = 4;
520        // 创建测试存储
521        let (store, pool) = create_test_store(document_id).await?;
522
523        // 创建源文档并添加文本内容
524        let source_doc = create_doc_with_text(&store, "Hello, World!").await?;
525
526        let map_name = "map".to_string();
527        // 添加 Map
528        let entries = [
529            ("key1", serde_json::json!("value1")),
530            ("key2", serde_json::json!(42)),
531        ];
532        add_map_to_doc(&source_doc, &store, &map_name, &entries).await?;
533
534        // 创建目标文档并应用更新
535        let target_doc = create_and_apply_doc(&store).await?;
536
537        // 验证目标文档包含所有更新
538
539        // 验证文本
540        assert_doc_text(&target_doc, "Hello, World!", "目标文档应该包含文本内容");
541
542        // 验证 Map
543        assert_doc_map(&target_doc, &map_name, &entries, "目标文档应该包含Map的值");
544
545        // 清理测试数据
546        cleanup_test_data(&pool, document_id).await?;
547
548        Ok(())
549    }
550
551    #[tokio::test]
552    async fn test_encode_state_as_update() -> Result<(), Box<dyn std::error::Error>> {
553        let document_id = 5;
554        // 创建测试存储
555        let (store, pool) = create_test_store(document_id).await?;
556
557        // 创建文档并添加内容
558        let doc = Doc::new();
559
560        // 添加内容
561        {
562            let mut txn = doc.transact_mut();
563            let text = txn.get_or_insert_text("text");
564            text.insert(&mut txn, 0, "测试文本内容");
565
566            let map = txn.get_or_insert_map("map");
567            map.insert(&mut txn, "key", "value");
568        }
569
570        // 使用encode_state_as_update存储状态
571        store.save(doc.clone()).await?;
572
573        // 验证数据已写入数据库
574        assert_record_count(&pool, 1, "数据库中应该有一条记录", document_id).await?;
575
576        // 创建新文档并应用更新
577        let new_doc = Doc::new();
578        store.load(&new_doc).await?;
579
580        // 验证新文档包含所有内容
581        let txn = new_doc.transact();
582
583        // 验证文本
584        let text = txn.get_text("text").unwrap();
585        let content = text.get_string(&txn);
586        assert_eq!(content, "测试文本内容", "新文档应该包含文本内容");
587
588        // 验证map
589        let map = txn.get_map("map").unwrap();
590        let key_value = map.get(&txn, "key").unwrap().to_string(&txn);
591        assert_eq!(key_value, "value", "新文档应该包含Map的key值");
592
593        // 清理测试数据
594        cleanup_test_data(&pool, document_id).await?;
595
596        Ok(())
597    }
598}