1use async_stream;
2use async_trait::async_trait;
3use chrono::NaiveDateTime;
4use futures_util::stream::BoxStream;
5use futures_util::TryStreamExt;
6use sqlx::{PgPool, Row};
7use yrs::Doc;
8use yrs_store::doc::ForStore;
9use yrs_store::errors::StoreError;
10use yrs_store::Store;
11
12pub struct PostgresStorage {
13 document_id: i64,
14 table_name: String,
15 pool: PgPool,
16 run_vacuum: bool,
17}
18
19impl PostgresStorage {
20 pub async fn new(
21 document_id: i64,
22 table_name: String,
23 pool: PgPool,
24 run_vacuum: bool,
25 ) -> Result<Self, sqlx::Error> {
26 Ok(PostgresStorage {
27 document_id,
28 table_name,
29 pool,
30 run_vacuum,
31 })
32 }
33}
34
35fn map_sqlx_err(error: sqlx::Error) -> StoreError {
36 StoreError::StorageError(Box::new(error))
37}
38
39#[async_trait]
40impl Store for PostgresStorage {
41 async fn delete(&self) -> Result<(), StoreError> {
42 sqlx::query(&format!(
43 "DELETE FROM {} WHERE document_id = $1",
44 self.table_name
45 ))
46 .bind(self.document_id)
47 .execute(&self.pool)
48 .await
49 .map(|_| ())
50 .map_err(map_sqlx_err)
51 }
52
53 async fn write(&self, update: &Vec<u8>) -> Result<(), StoreError> {
54 let document_id = self.document_id.clone();
55 let now = chrono::Utc::now().naive_utc();
56 let query_result = sqlx::query(&format!(
57 "INSERT INTO {} (document_id, payload, timestamp) VALUES ($1, $2, $3)",
58 self.table_name
59 ))
60 .bind(document_id)
61 .bind(update)
62 .bind(now)
63 .execute(&self.pool)
64 .await
65 .map_err(map_sqlx_err)?;
66
67 let rows_affected = query_result.rows_affected();
68
69 if rows_affected != 1 {
71 return Err(StoreError::WriteError(format!(
72 "Expected 1 row affected for insert, but got {}",
73 rows_affected
74 )));
75 }
76
77 Ok(())
78 }
79
80 async fn read(&self) -> Result<BoxStream<Result<(Vec<u8>, i64), StoreError>>, StoreError> {
81 let document_id = self.document_id;
82 let table_name = self.table_name.clone();
83 let pool = self.pool.clone();
84
85 let stream = async_stream::stream! {
86 let sql = format!(
87 "SELECT payload, timestamp FROM {} WHERE document_id = $1 ORDER BY timestamp",
88 table_name
89 );
90
91 let mut rows = sqlx::query(&sql)
92 .bind(document_id)
93 .fetch(&pool);
94
95 while let Some(row) = rows.try_next().await.map_err(map_sqlx_err)? {
96 let payload: Vec<u8> = row.get("payload");
97 let timestamp_ndt: NaiveDateTime = row.get("timestamp");
98 let timestamp_ms = timestamp_ndt.and_utc().timestamp_millis();
99 yield Ok((payload, timestamp_ms));
100 }
101 };
102
103 Ok(Box::pin(stream))
104 }
105
106 async fn read_payloads(&self) -> Result<BoxStream<Result<Vec<u8>, StoreError>>, StoreError> {
107 let document_id = self.document_id;
108 let table_name = self.table_name.clone();
109 let pool = self.pool.clone();
110
111 let stream = async_stream::stream! {
112 let sql = format!(
113 "SELECT payload FROM {} WHERE document_id = $1 ORDER BY timestamp",
114 table_name
115 );
116
117 let mut rows = sqlx::query(&sql)
118 .bind(document_id)
119 .fetch(&pool);
120
121 while let Some(row) = rows.try_next().await.map_err(map_sqlx_err)? {
122 let payload: Vec<u8> = row.get("payload");
123 yield Ok(payload);
124 }
125 };
126
127 Ok(Box::pin(stream))
128 }
129
130 async fn squash(&self) -> Result<(), StoreError> {
131 let doc = Doc::new();
132 self.load(&doc).await?;
133 let tx = self.pool.begin().await.map_err(map_sqlx_err)?;
134 sqlx::query(&format!(
135 "DELETE FROM {} WHERE document_id = $1",
136 self.table_name
137 ))
138 .bind(self.document_id)
139 .execute(&self.pool)
140 .await
141 .map_err(map_sqlx_err)?;
142
143 let squashed_update = doc.get_update();
144 self.write(&squashed_update).await?;
145
146 tx.commit().await.map_err(map_sqlx_err)?;
148
149 if self.run_vacuum {
150 sqlx::query(&format!("VACUUM {}", self.table_name))
152 .execute(&self.pool)
153 .await
154 .map_err(map_sqlx_err)?;
155 }
156 Ok(())
157 }
158}
159
#[cfg(test)]
mod tests {
    //! Integration tests for `PostgresStorage`.
    //!
    //! They connect to the database named by `DATABASE_URL`, falling back to
    //! `postgres://postgres:postgres@localhost:5432/test`, and use the
    //! `document_updates` table. Each test uses its own `document_id`.

    use super::*;
    use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
    use std::env;
    use std::str::FromStr;
    use yrs::Any::{BigInt, Bool, Number};
    use yrs::{GetString, Map, Out, ReadTxn, Text, Transact, WriteTxn};

    /// Asserts that `document_updates` holds exactly `expected_count` rows for
    /// `document_id`; `message` is the assertion failure text.
    async fn assert_record_count(
        pool: &PgPool,
        expected_count: i64,
        message: &str,
        document_id: i64,
    ) -> Result<(), sqlx::Error> {
        let count = sqlx::query!(
            "SELECT COUNT(*) as count FROM document_updates WHERE document_id = $1",
            document_id
        )
        .fetch_one(pool)
        .await?
        .count
        .unwrap_or(0);
        assert_eq!(count, expected_count, "{}", message);
        Ok(())
    }

    /// Removes every `document_updates` row that belongs to `document_id`.
    async fn cleanup_test_data(pool: &PgPool, document_id: i64) -> Result<(), sqlx::Error> {
        sqlx::query!(
            "DELETE FROM document_updates WHERE document_id = $1",
            document_id
        )
        .execute(pool)
        .await
        .map(|_| ())
    }

    /// Connects to the test database, creates `document_updates` if it does
    /// not exist, and clears any leftover rows for `document_id`.
    async fn create_test_pool(document_id: i64) -> Result<PgPool, sqlx::Error> {
        let url = env::var("DATABASE_URL")
            .unwrap_or_else(|_| "postgres://postgres:postgres@localhost:5432/test".to_string());
        let connect_options = PgConnectOptions::from_str(&url).expect("无法解析数据库连接字符串");

        let pool = PgPoolOptions::new()
            .max_connections(5)
            .connect_with(connect_options)
            .await?;

        sqlx::query(
            "
            CREATE TABLE IF NOT EXISTS document_updates (
                document_id BIGINT NOT NULL,
                payload BYTEA NOT NULL,
                timestamp TIMESTAMP NOT NULL
            )
            ",
        )
        .execute(&pool)
        .await?;

        // Start from a clean slate even if an earlier run failed before cleanup.
        cleanup_test_data(&pool, document_id).await?;

        Ok(pool)
    }

    /// Builds a `PostgresStorage` on `document_updates` with vacuum disabled,
    /// and returns it with the pool it uses.
    async fn create_test_store(document_id: i64) -> Result<(PostgresStorage, PgPool), sqlx::Error> {
        let pool = create_test_pool(document_id).await?;

        let store = PostgresStorage::new(
            document_id,
            "document_updates".to_string(),
            pool.clone(),
            false,
        )
        .await?;

        Ok((store, pool))
    }

    /// Creates a document whose `"text"` holds `text_content` and writes its
    /// update to `store`.
    async fn create_doc_with_text(
        store: &PostgresStorage,
        text_content: &str,
    ) -> Result<Doc, StoreError> {
        let doc = Doc::new();
        {
            // Scope the transaction so it is dropped before the update is taken.
            let mut txn = doc.transact_mut();
            let text = txn.get_or_insert_text("text");
            text.insert(&mut txn, 0, text_content);
        }
        let update = doc.get_update();
        store.write(&update).await?;
        Ok(doc)
    }

    /// Inserts `text_content` at `position` in the document's `"text"` and
    /// writes the resulting update to `store`.
    async fn update_doc_text(
        doc: &Doc,
        store: &PostgresStorage,
        position: u32,
        text_content: &str,
    ) -> Result<(), StoreError> {
        {
            let mut txn = doc.transact_mut();
            let text = txn.get_text("text").unwrap();
            text.insert(&mut txn, position, text_content);
        }
        let update = doc.get_update();
        store.write(&update).await
    }

    /// Removes `length` characters starting at `position` from the document's
    /// `"text"` and writes the resulting update to `store`.
    async fn remove_doc_text(
        doc: &Doc,
        store: &PostgresStorage,
        position: u32,
        length: u32,
    ) -> Result<(), StoreError> {
        {
            let mut txn = doc.transact_mut();
            let text = txn.get_text("text").unwrap();
            text.remove_range(&mut txn, position, length);
        }
        let update = doc.get_update();
        store.write(&update).await
    }

    /// Asserts that the document's `"text"` equals `expected_text`.
    fn assert_doc_text(doc: &Doc, expected_text: &str, message: &str) {
        let txn = doc.transact();
        let text = txn.get_text("text").unwrap();
        let content = text.get_string(&txn);
        assert_eq!(content, expected_text, "{}", message);
    }

    /// Asserts that the map `map_name` holds every `(key, value)` pair in
    /// `expected_entries`.
    ///
    /// Strings, numbers and booleans are supported. Any other JSON kind
    /// panics, so an entry can never be skipped silently.
    fn assert_doc_map(
        doc: &Doc,
        map_name: &str,
        expected_entries: &[(&str, serde_json::Value)],
        message: &str,
    ) {
        let txn = doc.transact();
        let map = txn.get_map(map_name).unwrap();

        for (key, expected_value) in expected_entries {
            match expected_value {
                serde_json::Value::String(expected_str) => {
                    let value = map.get(&txn, *key).unwrap().to_string(&txn);
                    assert_eq!(value, *expected_str, "{} - key: {}", message, key);
                }
                serde_json::Value::Number(expected_num) => {
                    // Integers may be stored as BigInt or Number, so both are accepted.
                    if let Some(expected_i64) = expected_num.as_i64() {
                        let i64_value = match map.get(&txn, *key).unwrap() {
                            Out::Any(BigInt(v)) => v,
                            Out::Any(Number(v)) => v as i64,
                            _ => panic!("Expected Out::Any(BigInt) or Out::Any(Number)"),
                        };
                        assert_eq!(i64_value, expected_i64, "{} - key: {}", message, key);
                    } else if let Some(expected_f64) = expected_num.as_f64() {
                        let f64_value = match map.get(&txn, *key).unwrap() {
                            Out::Any(BigInt(v)) => v as f64,
                            Out::Any(Number(v)) => v,
                            _ => panic!("Expected Out::Any(BigInt) or Out::Any(Number)"),
                        };
                        assert!(
                            (f64_value - expected_f64).abs() < f64::EPSILON,
                            "{} - key: {}",
                            message,
                            key
                        );
                    }
                }
                serde_json::Value::Bool(expected_bool) => {
                    let bool_value = match map.get(&txn, *key).unwrap() {
                        Out::Any(Bool(v)) => v,
                        _ => panic!("Expected Out::Any(Bool)"),
                    };
                    assert_eq!(bool_value, *expected_bool, "{} - key: {}", message, key);
                }
                other => panic!(
                    "{} - key: {}: unsupported expected value {}",
                    message, key, other
                ),
            }
        }
    }

    /// Creates an empty document and loads the stored updates into it.
    async fn create_and_apply_doc(store: &PostgresStorage) -> Result<Doc, StoreError> {
        let doc = Doc::new();
        store.load(&doc).await?;
        Ok(doc)
    }

    /// Inserts `entries` into the map `map_name` and writes the resulting
    /// update to `store`.
    ///
    /// Strings, numbers and booleans are supported. Any other JSON kind
    /// panics, so an entry can never be dropped silently.
    async fn add_map_to_doc(
        doc: &Doc,
        store: &PostgresStorage,
        map_name: &str,
        entries: &[(&str, serde_json::Value)],
    ) -> Result<(), StoreError> {
        {
            let mut txn = doc.transact_mut();
            let map = txn.get_or_insert_map(map_name);

            for (key, value) in entries {
                match value {
                    serde_json::Value::String(s) => {
                        map.insert(&mut txn, *key, s.as_str());
                    }
                    serde_json::Value::Number(n) => {
                        if let Some(i) = n.as_i64() {
                            map.insert(&mut txn, *key, i);
                        } else if let Some(f) = n.as_f64() {
                            map.insert(&mut txn, *key, f);
                        }
                    }
                    serde_json::Value::Bool(b) => {
                        map.insert(&mut txn, *key, *b);
                    }
                    other => panic!("unsupported map value for key {}: {}", key, other),
                }
            }
        }
        let update = doc.get_update();
        store.write(&update).await
    }

    /// Two writes collapse into one row after `squash`, and a fresh document
    /// loaded from the store still has the full text.
    #[tokio::test]
    async fn test_squash_preserves_history() -> Result<(), Box<dyn std::error::Error>> {
        let document_id = 1;

        let (store, pool) = create_test_store(document_id).await?;

        let doc = create_doc_with_text(&store, "Hello").await?;

        update_doc_text(&doc, &store, 5, ", World").await?;

        assert_record_count(&pool, 2, "数据库中应该有两条更新记录", document_id).await?;

        store.squash().await?;

        assert_record_count(&pool, 1, "squash后数据库中应该只有一条记录", document_id).await?;

        let new_doc = create_and_apply_doc(&store).await?;

        assert_doc_text(
            &new_doc,
            "Hello, World",
            "squash后的文档应该包含所有历史更改",
        );

        cleanup_test_data(&pool, document_id).await?;

        Ok(())
    }

    /// Five appended writes collapse into one row after `squash`, and the
    /// concatenated text is preserved.
    #[tokio::test]
    async fn test_squash_with_multiple_updates() -> Result<(), Box<dyn std::error::Error>> {
        let document_id = 2;
        let (store, pool) = create_test_store(document_id).await?;

        let doc = Doc::new();

        for i in 0..5 {
            {
                let mut txn = doc.transact_mut();
                // Creates the text on the first pass and returns it afterwards.
                let text = txn.get_or_insert_text("text");

                let len = text.len(&txn);
                text.insert(&mut txn, len, &format!("Part {}", i));
            }
            let update = doc.get_update();
            store.write(&update).await?;
        }

        assert_record_count(&pool, 5, "数据库中应该有5条更新记录", document_id).await?;

        store.squash().await?;

        assert_record_count(&pool, 1, "squash后数据库中应该只有一条记录", document_id).await?;

        let new_doc = create_and_apply_doc(&store).await?;

        assert_doc_text(
            &new_doc,
            "Part 0Part 1Part 2Part 3Part 4",
            "squash后的文档应该包含所有历史更改",
        );

        cleanup_test_data(&pool, document_id).await?;

        Ok(())
    }

    /// Inserts followed by a removal are all reflected in the squashed state.
    #[tokio::test]
    async fn test_squash_with_complex_operations() -> Result<(), Box<dyn std::error::Error>> {
        let document_id = 3;
        let (store, pool) = create_test_store(document_id).await?;

        let doc = create_doc_with_text(&store, "Initial content").await?;

        update_doc_text(&doc, &store, 15, " with more text").await?;

        // Removes "content " (8 characters starting at index 8).
        remove_doc_text(&doc, &store, 8, 8).await?;

        store.squash().await?;

        let new_doc = create_and_apply_doc(&store).await?;

        assert_doc_text(
            &new_doc,
            "Initial with more text",
            "squash后的文档应该正确应用所有操作,包括删除",
        );

        cleanup_test_data(&pool, document_id).await?;

        Ok(())
    }

    /// A fresh document loaded from the store holds both the text and the map
    /// entries that were written.
    #[tokio::test]
    async fn test_apply_updates() -> Result<(), Box<dyn std::error::Error>> {
        let document_id = 4;
        let (store, pool) = create_test_store(document_id).await?;

        let source_doc = create_doc_with_text(&store, "Hello, World!").await?;

        let map_name = "map";
        let entries = [
            ("key1", serde_json::json!("value1")),
            ("key2", serde_json::json!(42)),
        ];
        add_map_to_doc(&source_doc, &store, map_name, &entries).await?;

        let target_doc = create_and_apply_doc(&store).await?;

        assert_doc_text(&target_doc, "Hello, World!", "目标文档应该包含文本内容");

        assert_doc_map(&target_doc, map_name, &entries, "目标文档应该包含Map的值");

        cleanup_test_data(&pool, document_id).await?;

        Ok(())
    }

    /// `save` stores a whole document as one row, and `load` restores its
    /// text and map.
    #[tokio::test]
    async fn test_encode_state_as_update() -> Result<(), Box<dyn std::error::Error>> {
        let document_id = 5;
        let (store, pool) = create_test_store(document_id).await?;

        let doc = Doc::new();

        {
            let mut txn = doc.transact_mut();
            let text = txn.get_or_insert_text("text");
            text.insert(&mut txn, 0, "测试文本内容");

            let map = txn.get_or_insert_map("map");
            map.insert(&mut txn, "key", "value");
        }

        store.save(doc.clone()).await?;

        assert_record_count(&pool, 1, "数据库中应该有一条记录", document_id).await?;

        let new_doc = Doc::new();
        store.load(&new_doc).await?;

        let txn = new_doc.transact();

        let text = txn.get_text("text").unwrap();
        let content = text.get_string(&txn);
        assert_eq!(content, "测试文本内容", "新文档应该包含文本内容");

        let map = txn.get_map("map").unwrap();
        let key_value = map.get(&txn, "key").unwrap().to_string(&txn);
        assert_eq!(key_value, "value", "新文档应该包含Map的key值");

        cleanup_test_data(&pool, document_id).await?;

        Ok(())
    }
}