// yrs_postgres/lib.rs

1use async_stream;
2use async_trait::async_trait;
3use chrono::NaiveDateTime;
4use futures_util::stream::BoxStream;
5use futures_util::TryStreamExt;
6use sqlx::{PgPool, Row};
7use yrs::Doc;
8use yrs_store::doc::ForStore;
9use yrs_store::errors::StoreError;
10use yrs_store::Store;
11
12pub struct PostgresStorage {
13    document_id: i64,
14    table_name: String,
15    pool: PgPool,
16    run_vacuum: bool,
17}
18
19impl PostgresStorage {
20    pub async fn new(
21        document_id: i64,
22        table_name: String,
23        pool: PgPool,
24        run_vacuum: bool,
25    ) -> Result<Self, sqlx::Error> {
26        Ok(PostgresStorage {
27            document_id,
28            table_name,
29            pool,
30            run_vacuum,
31        })
32    }
33}
34
35fn map_sqlx_err(error: sqlx::Error) -> StoreError {
36    StoreError::StorageError(Box::new(error))
37}
38
39#[async_trait]
40impl Store for PostgresStorage {
41    async fn delete(&self) -> Result<(), StoreError> {
42        sqlx::query(&format!(
43            "DELETE FROM {} WHERE document_id = $1",
44            self.table_name
45        ))
46        .bind(self.document_id)
47        .execute(&self.pool)
48        .await
49        .map(|_| ())
50        .map_err(map_sqlx_err)
51    }
52
53    async fn write(&self, update: &Vec<u8>) -> Result<(), StoreError> {
54        let document_id = self.document_id.clone();
55        let now = chrono::Utc::now().naive_utc();
56        let query_result = sqlx::query(&format!(
57            "INSERT INTO {} (document_id, payload, timestamp) VALUES ($1, $2, $3)",
58            self.table_name
59        ))
60        .bind(document_id)
61        .bind(update)
62        .bind(now)
63        .execute(&self.pool)
64        .await
65        .map_err(map_sqlx_err)?;
66
67        let rows_affected = query_result.rows_affected();
68
69        // 检查影响行数是否等于 1
70        if rows_affected != 1 {
71            return Err(StoreError::WriteError(format!(
72                "Expected 1 row affected for insert, but got {}",
73                rows_affected
74            )));
75        }
76
77        Ok(())
78    }
79
80    async fn read(&self) -> Result<BoxStream<Result<(Vec<u8>, i64), StoreError>>, StoreError> {
81        let document_id = self.document_id;
82        let table_name = self.table_name.clone();
83        let pool = self.pool.clone();
84
85        let stream = async_stream::stream! {
86            let sql = format!(
87                "SELECT payload, timestamp FROM {} WHERE document_id = $1 ORDER BY timestamp",
88                table_name
89            );
90
91            let mut rows = sqlx::query(&sql)
92                .bind(document_id)
93                .fetch(&pool);
94
95            while let Some(row) = rows.try_next().await.map_err(map_sqlx_err)? {
96                let payload: Vec<u8> = row.get("payload");
97                let timestamp_ndt: NaiveDateTime = row.get("timestamp");
98                let timestamp_ms = timestamp_ndt.and_utc().timestamp_millis();
99                yield Ok((payload, timestamp_ms));
100            }
101        };
102
103        Ok(Box::pin(stream))
104    }
105
106    async fn read_payloads(&self) -> Result<BoxStream<Result<Vec<u8>, StoreError>>, StoreError> {
107        let document_id = self.document_id;
108        let table_name = self.table_name.clone();
109        let pool = self.pool.clone();
110
111        let stream = async_stream::stream! {
112            let sql = format!(
113                "SELECT payload FROM {} WHERE document_id = $1 ORDER BY timestamp",
114                table_name
115            );
116
117            let mut rows = sqlx::query(&sql)
118                .bind(document_id)
119                .fetch(&pool);
120
121            while let Some(row) = rows.try_next().await.map_err(map_sqlx_err)? {
122                let payload: Vec<u8> = row.get("payload");
123                yield Ok(payload);
124            }
125        };
126
127        Ok(Box::pin(stream))
128    }
129
130    async fn squash(&self) -> Result<(), StoreError> {
131        let doc = Doc::new();
132        self.load(&doc).await?;
133        let tx = self.pool.begin().await.map_err(map_sqlx_err)?;
134        sqlx::query(&format!(
135            "DELETE FROM {} WHERE document_id = $1",
136            self.table_name
137        ))
138        .bind(self.document_id)
139        .execute(&self.pool)
140        .await
141        .map_err(map_sqlx_err)?;
142
143        let squashed_update = doc.get_update();
144        self.write(&squashed_update).await?;
145
146        // 如果在事务超出范围之前都未调用,rollback则自动调用
147        tx.commit().await.map_err(map_sqlx_err)?;
148
149        if self.run_vacuum {
150            // 回收死行占据的存储空间
151            sqlx::query(&format!("VACUUM {}", self.table_name))
152                .execute(&self.pool)
153                .await
154                .map_err(map_sqlx_err)?;
155        }
156        Ok(())
157    }
158}
159
#[cfg(test)]
mod tests {
    //! Integration tests for `PostgresStorage`.
    //!
    //! They need a reachable PostgreSQL server, taken from `DATABASE_URL` (falling back to a
    //! local `test` database). Each test uses its own `document_id`, and every query filters
    //! on it, so the tests do not read or delete each other's rows.

    use super::*;
    use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
    use std::env;
    use std::str::FromStr;
    use yrs::Any::{BigInt, Bool, Number};
    use yrs::{GetString, Map, Out, ReadTxn, Text, Transact, WriteTxn};

    /// Table shared by all tests; rows are separated by `document_id`.
    const TABLE_NAME: &str = "document_updates";

    /// Arbitrary advisory-lock key used only to serialize creation of the test table.
    const SCHEMA_LOCK_KEY: i64 = 0x7973_7467;

    // The helpers use runtime queries instead of `sqlx::query!`. That macro validates the SQL
    // against a live database at compile time, but `TABLE_NAME` is only created when the
    // tests run. Without prepared offline query data, a fresh database could not even compile
    // this module.

    /// Asserts that the table holds exactly `expected_count` rows for `document_id`.
    async fn assert_record_count(
        pool: &PgPool,
        expected_count: i64,
        message: &str,
        document_id: i64,
    ) -> Result<(), sqlx::Error> {
        let sql = format!(
            "SELECT COUNT(*) FROM {} WHERE document_id = $1",
            TABLE_NAME
        );
        let count: i64 = sqlx::query_scalar(&sql)
            .bind(document_id)
            .fetch_one(pool)
            .await?;
        assert_eq!(count, expected_count, "{}", message);
        Ok(())
    }

    /// Deletes every row stored for `document_id`.
    async fn cleanup_test_data(pool: &PgPool, document_id: i64) -> Result<(), sqlx::Error> {
        sqlx::query(&format!(
            "DELETE FROM {} WHERE document_id = $1",
            TABLE_NAME
        ))
        .bind(document_id)
        .execute(pool)
        .await
        .map(|_| ())
    }

    /// Connects to the test database, makes sure the table exists, and removes any rows left
    /// over for `document_id`.
    async fn create_test_pg_pool(document_id: i64) -> Result<PgPool, sqlx::Error> {
        let url = env::var("DATABASE_URL")
            .unwrap_or_else(|_| "postgres://postgres:postgres@localhost:5432/test".to_string());
        let connect_options = PgConnectOptions::from_str(&url).expect("无法解析数据库连接字符串");

        let pool = PgPoolOptions::new()
            .max_connections(5)
            .connect_with(connect_options)
            .await?;

        // Every test runs this DDL, and the tests run in parallel. `CREATE TABLE IF NOT
        // EXISTS` is not safe against concurrent creation (it can fail with a unique violation
        // in the system catalogs). Serialize it with a transaction-scoped advisory lock, which
        // is released at commit.
        let mut tx = pool.begin().await?;
        sqlx::query("SELECT pg_advisory_xact_lock($1)")
            .bind(SCHEMA_LOCK_KEY)
            .execute(&mut *tx)
            .await?;
        sqlx::query(&format!(
            "
            CREATE TABLE IF NOT EXISTS {} (
                document_id BIGINT NOT NULL,
                payload BYTEA NOT NULL,
                timestamp TIMESTAMP NOT NULL
            )
        ",
            TABLE_NAME
        ))
        .execute(&mut *tx)
        .await?;
        tx.commit().await?;

        // Remove rows left behind by an earlier (possibly failed) run.
        cleanup_test_data(&pool, document_id).await?;

        Ok(pool)
    }

    /// Builds a `PostgresStorage` over `TABLE_NAME` for `document_id` (VACUUM disabled).
    /// Returns it together with the pool the assertions use.
    async fn create_test_store(document_id: i64) -> Result<(PostgresStorage, PgPool), sqlx::Error> {
        let pool = create_test_pg_pool(document_id).await?;

        let store = PostgresStorage::new(
            document_id,
            TABLE_NAME.to_string(),
            pool.clone(),
            false,
        )
        .await?;

        Ok((store, pool))
    }

    /// Creates a doc whose root text "text" contains `text_content`, and stores its full
    /// state as one update.
    async fn create_doc_with_text(
        store: &PostgresStorage,
        text_content: &str,
    ) -> Result<Doc, StoreError> {
        let doc = Doc::new();
        {
            let mut txn = doc.transact_mut();
            let text = txn.get_or_insert_text("text");
            text.insert(&mut txn, 0, text_content);
        }
        let update = doc.get_update();
        store.write(&update).await?;
        Ok(doc)
    }

    /// Inserts `text_content` at `position` in the "text" root and stores the doc's full state.
    async fn update_doc_text(
        doc: &Doc,
        store: &PostgresStorage,
        position: u32,
        text_content: &str,
    ) -> Result<(), StoreError> {
        {
            let mut txn = doc.transact_mut();
            let text = txn.get_text("text").unwrap();
            text.insert(&mut txn, position, text_content);
        }
        let update = doc.get_update();
        store.write(&update).await
    }

    /// Removes `length` characters at `position` from the "text" root and stores the doc's
    /// full state.
    async fn remove_doc_text(
        doc: &Doc,
        store: &PostgresStorage,
        position: u32,
        length: u32,
    ) -> Result<(), StoreError> {
        {
            let mut txn = doc.transact_mut();
            let text = txn.get_text("text").unwrap();
            text.remove_range(&mut txn, position, length);
        }
        let update = doc.get_update();
        store.write(&update).await
    }

    /// Asserts that the "text" root of `doc` equals `expected_text`.
    fn assert_doc_text(doc: &Doc, expected_text: &str, message: &str) {
        let txn = doc.transact();
        let text = txn.get_text("text").unwrap();
        let content = text.get_string(&txn);
        assert_eq!(content, expected_text, "{}", message);
    }

    /// Asserts that map root `map_name` holds each expected entry.
    ///
    /// Only JSON strings, numbers and booleans are checked; other JSON types are skipped.
    fn assert_doc_map(
        doc: &Doc,
        map_name: &str,
        expected_entries: &[(&str, serde_json::Value)],
        message: &str,
    ) {
        let txn = doc.transact();
        let map = txn.get_map(map_name).unwrap();

        for (key, expected_value) in expected_entries {
            match expected_value {
                serde_json::Value::String(expected_str) => {
                    let value = map.get(&txn, *key).unwrap().to_string(&txn);
                    assert_eq!(value, *expected_str, "{} - key: {}", message, key);
                }
                serde_json::Value::Number(expected_num) => {
                    if let Some(expected_i64) = expected_num.as_i64() {
                        // Integers may be stored as BigInt or as a float Number.
                        let i64_value = match map.get(&txn, *key).unwrap() {
                            Out::Any(BigInt(v)) => v,
                            Out::Any(Number(v)) => v as i64,
                            _ => panic!("Expected Out::Any(BigInt)"),
                        };
                        assert_eq!(i64_value, expected_i64, "{} - key: {}", message, key);
                    } else if let Some(expected_f64) = expected_num.as_f64() {
                        // Accept either numeric representation and compare as f64.
                        let f64_value = match map.get(&txn, *key).unwrap() {
                            Out::Any(BigInt(v)) => v as f64,
                            Out::Any(Number(v)) => v,
                            _ => panic!("Expected Out::Any(Number)"),
                        };
                        assert!(
                            (f64_value - expected_f64).abs() < f64::EPSILON,
                            "{} - key: {}",
                            message,
                            key
                        );
                    }
                }
                serde_json::Value::Bool(expected_bool) => {
                    let bool_value = match map.get(&txn, *key).unwrap() {
                        Out::Any(Bool(v)) => v,
                        _ => panic!("Expected Out::Any(Bool)"),
                    };
                    assert_eq!(bool_value, *expected_bool, "{} - key: {}", message, key);
                }
                _ => {} // Other JSON types are not checked.
            }
        }
    }

    /// Creates an empty doc and loads every stored update of `store` into it.
    async fn create_and_apply_doc(store: &PostgresStorage) -> Result<Doc, StoreError> {
        let doc = Doc::new();
        store.load(&doc).await?;
        Ok(doc)
    }

    /// Inserts `entries` into map root `map_name` and stores the doc's full state.
    ///
    /// Only JSON strings, numbers and booleans are inserted; other JSON types are skipped.
    async fn add_map_to_doc(
        doc: &Doc,
        store: &PostgresStorage,
        map_name: &str,
        entries: &[(&str, serde_json::Value)],
    ) -> Result<(), StoreError> {
        {
            let mut txn = doc.transact_mut();
            let map = txn.get_or_insert_map(map_name);

            for (key, value) in entries {
                match value {
                    serde_json::Value::String(s) => {
                        map.insert(&mut txn, *key, s.as_str());
                    }
                    serde_json::Value::Number(n) => {
                        if let Some(i) = n.as_i64() {
                            map.insert(&mut txn, *key, i);
                        } else if let Some(f) = n.as_f64() {
                            map.insert(&mut txn, *key, f);
                        }
                    }
                    serde_json::Value::Bool(b) => {
                        map.insert(&mut txn, *key, *b);
                    }
                    _ => {} // Other JSON types are skipped.
                }
            }
        }
        let update = doc.get_update();
        store.write(&update).await
    }

    #[tokio::test]
    async fn test_squash_preserves_history() -> Result<(), Box<dyn std::error::Error>> {
        let document_id = 1;

        let (store, pool) = create_test_store(document_id).await?;

        // Write an initial text, then append more; each step stores one update row.
        let doc = create_doc_with_text(&store, "Hello").await?;
        update_doc_text(&doc, &store, 5, ", World").await?;

        // Two update rows before squashing.
        assert_record_count(&pool, 2, "数据库中应该有两条更新记录", document_id).await?;

        store.squash().await?;

        // Exactly one row after squashing.
        assert_record_count(&pool, 1, "squash后数据库中应该只有一条记录", document_id).await?;

        // A fresh doc loaded from the squashed row must contain every earlier change.
        let new_doc = create_and_apply_doc(&store).await?;
        assert_doc_text(
            &new_doc,
            "Hello, World",
            "squash后的文档应该包含所有历史更改",
        );

        cleanup_test_data(&pool, document_id).await?;

        Ok(())
    }

    #[tokio::test]
    async fn test_squash_with_multiple_updates() -> Result<(), Box<dyn std::error::Error>> {
        let document_id = 2;
        let (store, pool) = create_test_store(document_id).await?;

        let doc = Doc::new();

        // Append "Part i" five times, storing the doc's state after each append.
        for i in 0..5 {
            {
                let mut txn = doc.transact_mut();
                let text = if i == 0 {
                    txn.get_or_insert_text("text")
                } else {
                    txn.get_text("text").unwrap()
                };

                let len = text.len(&txn);
                text.insert(&mut txn, len, &format!("Part {}", i));
            }
            let update = doc.get_update();
            store.write(&update).await?;
        }

        // Five update rows before squashing.
        assert_record_count(&pool, 5, "数据库中应该有5条更新记录", document_id).await?;

        store.squash().await?;

        // Exactly one row after squashing.
        assert_record_count(&pool, 1, "squash后数据库中应该只有一条记录", document_id).await?;

        // A fresh doc loaded from the squashed row must contain every earlier change.
        let new_doc = create_and_apply_doc(&store).await?;
        assert_doc_text(
            &new_doc,
            "Part 0Part 1Part 2Part 3Part 4",
            "squash后的文档应该包含所有历史更改",
        );

        cleanup_test_data(&pool, document_id).await?;

        Ok(())
    }

    #[tokio::test]
    async fn test_squash_with_complex_operations() -> Result<(), Box<dyn std::error::Error>> {
        let document_id = 3;
        let (store, pool) = create_test_store(document_id).await?;

        // Initial text, then an append.
        let doc = create_doc_with_text(&store, "Initial content").await?;
        update_doc_text(&doc, &store, 15, " with more text").await?;

        // Third update: delete "content " (8 characters starting at offset 8).
        remove_doc_text(&doc, &store, 8, 8).await?;

        store.squash().await?;

        // The squashed state must reflect the deletion as well as the insertions.
        let new_doc = create_and_apply_doc(&store).await?;
        assert_doc_text(
            &new_doc,
            "Initial with more text",
            "squash后的文档应该正确应用所有操作,包括删除",
        );

        cleanup_test_data(&pool, document_id).await?;

        Ok(())
    }

    #[tokio::test]
    async fn test_apply_updates() -> Result<(), Box<dyn std::error::Error>> {
        let document_id = 4;
        let (store, pool) = create_test_store(document_id).await?;

        // Source doc with text content.
        let source_doc = create_doc_with_text(&store, "Hello, World!").await?;

        let map_name = "map".to_string();
        // Add a map with a string and an integer entry.
        let entries = [
            ("key1", serde_json::json!("value1")),
            ("key2", serde_json::json!(42)),
        ];
        add_map_to_doc(&source_doc, &store, &map_name, &entries).await?;

        // Target doc rebuilt purely from the stored updates.
        let target_doc = create_and_apply_doc(&store).await?;

        // The target must contain both the text and the map entries.
        assert_doc_text(&target_doc, "Hello, World!", "目标文档应该包含文本内容");
        assert_doc_map(&target_doc, &map_name, &entries, "目标文档应该包含Map的值");

        cleanup_test_data(&pool, document_id).await?;

        Ok(())
    }

    #[tokio::test]
    async fn test_encode_state_as_update() -> Result<(), Box<dyn std::error::Error>> {
        let document_id = 5;
        let (store, pool) = create_test_store(document_id).await?;

        // Doc with a text root and a map root.
        let doc = Doc::new();
        {
            let mut txn = doc.transact_mut();
            let text = txn.get_or_insert_text("text");
            text.insert(&mut txn, 0, "测试文本内容");

            let map = txn.get_or_insert_map("map");
            map.insert(&mut txn, "key", "value");
        }

        // `save` stores the doc's whole state.
        store.save(doc.clone()).await?;

        // Exactly one row was written.
        assert_record_count(&pool, 1, "数据库中应该有一条记录", document_id).await?;

        // Rebuild a fresh doc from storage.
        let new_doc = Doc::new();
        store.load(&new_doc).await?;

        let txn = new_doc.transact();

        // Text root round-trips.
        let text = txn.get_text("text").unwrap();
        let content = text.get_string(&txn);
        assert_eq!(content, "测试文本内容", "新文档应该包含文本内容");

        // Map root round-trips.
        let map = txn.get_map("map").unwrap();
        let key_value = map.get(&txn, "key").unwrap().to_string(&txn);
        assert_eq!(key_value, "value", "新文档应该包含Map的key值");

        cleanup_test_data(&pool, document_id).await?;

        Ok(())
    }
}
599}