1use async_stream;
2use async_trait::async_trait;
3use bytes::Bytes;
4use chrono::NaiveDateTime;
5use futures_util::stream::BoxStream;
6use futures_util::TryStreamExt;
7use sqlx::{PgPool, Row};
8use yrs::Doc;
9use yrs_store::doc::ForStore;
10use yrs_store::errors::StoreError;
11use yrs_store::Store;
12
13pub struct PostgresStorage {
14 document_id: i64,
15 table_name: String,
16 pool: PgPool,
17 run_vacuum: bool,
18}
19
20impl PostgresStorage {
21 pub async fn new(
22 document_id: i64,
23 table_name: String,
24 pool: PgPool,
25 run_vacuum: bool,
26 ) -> Result<Self, sqlx::Error> {
27 Ok(PostgresStorage {
28 document_id,
29 table_name,
30 pool,
31 run_vacuum,
32 })
33 }
34}
35
36#[async_trait]
37impl Store for PostgresStorage {
38 async fn delete(&self) -> Result<(), StoreError> {
39 sqlx::query(&format!(
40 "DELETE FROM {} WHERE document_id = $1",
41 self.table_name
42 ))
43 .bind(self.document_id)
44 .execute(&self.pool)
45 .await
46 .map(|_| ())
47 .map_err(StoreError::SqlxError)
48 }
49
50 async fn write(&self, update: &Bytes) -> Result<(), StoreError> {
51 let document_id = self.document_id.clone();
52 let now = chrono::Utc::now().naive_utc();
53 let query_result = sqlx::query(&format!(
54 "INSERT INTO {} (document_id, payload, timestamp) VALUES ($1, $2, $3)",
55 self.table_name
56 ))
57 .bind(document_id)
58 .bind(update.as_ref())
59 .bind(now)
60 .execute(&self.pool)
61 .await
62 .map_err(StoreError::SqlxError)?;
63
64 let rows_affected = query_result.rows_affected();
65
66 if rows_affected != 1 {
68 return Err(StoreError::WriteError(format!(
69 "Expected 1 row affected for insert, but got {}",
70 rows_affected
71 )));
72 }
73
74 Ok(())
75 }
76
77 async fn read(&self) -> Result<BoxStream<Result<(Bytes, i64), StoreError>>, StoreError> {
78 let document_id = self.document_id;
79 let table_name = self.table_name.clone();
80 let pool = self.pool.clone();
81
82 let stream = async_stream::stream! {
83 let sql = format!(
84 "SELECT payload, timestamp FROM {} WHERE document_id = $1 ORDER BY timestamp",
85 table_name
86 );
87
88 let mut rows = sqlx::query(&sql)
89 .bind(document_id)
90 .fetch(&pool);
91
92 while let Some(row) = rows.try_next().await.map_err(StoreError::SqlxError)? {
93 let payload_vec: Vec<u8> = row.get("payload");
94 let payload = payload_vec.into();
95 let timestamp_ndt: NaiveDateTime = row.get("timestamp");
96 let timestamp_ms = timestamp_ndt.and_utc().timestamp_millis();
97 yield Ok((payload, timestamp_ms));
98 }
99 };
100
101 Ok(Box::pin(stream))
102 }
103
104 async fn read_payloads(&self) -> Result<BoxStream<Result<Bytes, StoreError>>, StoreError> {
105 let document_id = self.document_id;
106 let table_name = self.table_name.clone();
107 let pool = self.pool.clone();
108
109 let stream = async_stream::stream! {
110 let sql = format!(
111 "SELECT payload FROM {} WHERE document_id = $1 ORDER BY timestamp",
112 table_name
113 );
114
115 let mut rows = sqlx::query(&sql)
116 .bind(document_id)
117 .fetch(&pool);
118
119 while let Some(row) = rows.try_next().await.map_err(StoreError::SqlxError)? {
120 let payload_vec: Vec<u8> = row.get("payload");
121 let payload = payload_vec.into();
122 yield Ok(payload);
123 }
124 };
125
126 Ok(Box::pin(stream))
127 }
128
129 async fn squash(&self) -> Result<(), StoreError> {
130 let doc = Doc::new();
131 self.load(&doc).await?;
132 let tx = self.pool.begin().await.map_err(StoreError::SqlxError)?;
133 sqlx::query(&format!(
134 "DELETE FROM {} WHERE document_id = $1",
135 self.table_name
136 ))
137 .bind(self.document_id)
138 .execute(&self.pool)
139 .await
140 .map_err(StoreError::SqlxError)?;
141
142 let squashed_update = doc.get_update();
143 self.write(&squashed_update).await?;
144
145 tx.commit().await.map_err(StoreError::SqlxError)?;
147
148 if self.run_vacuum {
149 sqlx::query(&format!("VACUUM {}", self.table_name))
151 .execute(&self.pool)
152 .await
153 .map_err(StoreError::SqlxError)?;
154 }
155 Ok(())
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
    use std::env;
    use std::str::FromStr;
    use yrs::Any::{BigInt, Bool, Number};
    use yrs::{GetString, Map, Out, ReadTxn, Text, Transact, WriteTxn};

    /// Arbitrary advisory-lock key used to serialize schema creation across parallel tests.
    const SCHEMA_LOCK_KEY: i64 = 720_531_001;

    /// Asserts that `document_updates` holds exactly `expected_count` rows for `document_id`.
    ///
    /// Uses the runtime-checked `query_scalar` instead of `query!`, so the tests compile
    /// without a live database or offline query metadata.
    async fn assert_record_count(
        pool: &PgPool,
        expected_count: i64,
        message: &str,
        document_id: i64,
    ) -> Result<(), sqlx::Error> {
        let count: i64 =
            sqlx::query_scalar("SELECT COUNT(*) FROM document_updates WHERE document_id = $1")
                .bind(document_id)
                .fetch_one(pool)
                .await?;
        assert_eq!(count, expected_count, "{}", message);
        Ok(())
    }

    /// Removes every row stored for `document_id`.
    async fn cleanup_test_data(pool: &PgPool, document_id: i64) -> Result<(), sqlx::Error> {
        sqlx::query("DELETE FROM document_updates WHERE document_id = $1")
            .bind(document_id)
            .execute(pool)
            .await
            .map(|_| ())
    }

    /// Connects to `DATABASE_URL` (or a local default), ensures the `document_updates` table
    /// exists, and clears any leftover rows for `document_id`.
    async fn create_test_pg_pool(document_id: i64) -> Result<PgPool, sqlx::Error> {
        let url = env::var("DATABASE_URL")
            .unwrap_or_else(|_| "postgres://postgres:postgres@localhost:5432/test".to_string());
        let connect_options = PgConnectOptions::from_str(&url).expect("无法解析数据库连接字符串");

        let pool = PgPoolOptions::new()
            .max_connections(5)
            .connect_with(connect_options)
            .await?;

        // Tests run in parallel, and concurrent `CREATE TABLE IF NOT EXISTS` statements can
        // fail in PostgreSQL with a duplicate-key error on the system catalogs. Serialize them
        // with a transaction-scoped advisory lock.
        let mut tx = pool.begin().await?;
        sqlx::query("SELECT pg_advisory_xact_lock($1)")
            .bind(SCHEMA_LOCK_KEY)
            .execute(&mut *tx)
            .await?;
        sqlx::query(
            "
            CREATE TABLE IF NOT EXISTS document_updates (
                document_id BIGINT NOT NULL,
                payload BYTEA NOT NULL,
                timestamp TIMESTAMP NOT NULL
            )
            ",
        )
        .execute(&mut *tx)
        .await?;
        tx.commit().await?;

        cleanup_test_data(&pool, document_id).await?;

        Ok(pool)
    }

    /// Builds a store for `document_id` over `document_updates` (without vacuum) and returns
    /// it together with the pool, for direct assertions.
    async fn create_test_store(document_id: i64) -> Result<(PostgresStorage, PgPool), sqlx::Error> {
        let pool = create_test_pg_pool(document_id).await?;

        let store = PostgresStorage::new(
            document_id,
            "document_updates".to_string(),
            pool.clone(),
            false,
        )
        .await?;

        Ok((store, pool))
    }

    /// Creates a document whose `text` root holds `text_content` and writes its state.
    async fn create_doc_with_text(
        store: &PostgresStorage,
        text_content: &str,
    ) -> Result<Doc, StoreError> {
        let doc = Doc::new();
        {
            let mut txn = doc.transact_mut();
            let text = txn.get_or_insert_text("text");
            text.insert(&mut txn, 0, text_content);
        }
        let update = doc.get_update();
        store.write(&update).await?;
        Ok(doc)
    }

    /// Inserts `text_content` at `position` in the `text` root and writes the document's
    /// state as a new update.
    async fn update_doc_text(
        doc: &Doc,
        store: &PostgresStorage,
        position: u32,
        text_content: &str,
    ) -> Result<(), StoreError> {
        {
            let mut txn = doc.transact_mut();
            let text = txn.get_text("text").unwrap();
            text.insert(&mut txn, position, text_content);
        }
        let update = doc.get_update();
        store.write(&update).await
    }

    /// Removes `length` characters starting at `position` from the `text` root and writes
    /// the document's state as a new update.
    async fn remove_doc_text(
        doc: &Doc,
        store: &PostgresStorage,
        position: u32,
        length: u32,
    ) -> Result<(), StoreError> {
        {
            let mut txn = doc.transact_mut();
            let text = txn.get_text("text").unwrap();
            text.remove_range(&mut txn, position, length);
        }
        let update = doc.get_update();
        store.write(&update).await
    }

    /// Asserts that the `text` root of `doc` equals `expected_text`.
    fn assert_doc_text(doc: &Doc, expected_text: &str, message: &str) {
        let txn = doc.transact();
        let text = txn.get_text("text").unwrap();
        let content = text.get_string(&txn);
        assert_eq!(content, expected_text, "{}", message);
    }

    /// Asserts that map `map_name` in `doc` holds each expected entry.
    ///
    /// Strings, integers, floats and booleans are compared; other JSON values are ignored.
    fn assert_doc_map(
        doc: &Doc,
        map_name: &str,
        expected_entries: &[(&str, serde_json::Value)],
        message: &str,
    ) {
        let txn = doc.transact();
        let map = txn.get_map(map_name).unwrap();

        for (key, expected_value) in expected_entries {
            match expected_value {
                serde_json::Value::String(expected_str) => {
                    let value = map.get(&txn, *key).unwrap().to_string(&txn);
                    assert_eq!(value, *expected_str, "{} - key: {}", message, key);
                }
                serde_json::Value::Number(expected_num) => {
                    if let Some(expected_i64) = expected_num.as_i64() {
                        let i64_value = match map.get(&txn, *key).unwrap() {
                            Out::Any(BigInt(v)) => v,
                            Out::Any(Number(v)) => v as i64,
                            _ => panic!("Expected Out::Any(BigInt)"),
                        };
                        assert_eq!(i64_value, expected_i64, "{} - key: {}", message, key);
                    } else if let Some(expected_f64) = expected_num.as_f64() {
                        let f64_value = match map.get(&txn, *key).unwrap() {
                            Out::Any(BigInt(v)) => v as f64,
                            Out::Any(Number(v)) => v,
                            _ => panic!("Expected Out::Any(Number)"),
                        };
                        assert!(
                            (f64_value - expected_f64).abs() < f64::EPSILON,
                            "{} - key: {}",
                            message,
                            key
                        );
                    }
                }
                serde_json::Value::Bool(expected_bool) => {
                    let bool_value = match map.get(&txn, *key).unwrap() {
                        Out::Any(Bool(v)) => v,
                        _ => panic!("Expected Out::Any(Bool)"),
                    };
                    assert_eq!(bool_value, *expected_bool, "{} - key: {}", message, key);
                }
                _ => {}
            }
        }
    }

    /// Creates a fresh document populated from everything stored for `store`.
    async fn create_and_apply_doc(store: &PostgresStorage) -> Result<Doc, StoreError> {
        let doc = Doc::new();
        store.load(&doc).await?;
        Ok(doc)
    }

    /// Inserts `entries` into map `map_name` of `doc` and writes the document's state.
    ///
    /// Strings, integers, floats and booleans are inserted; other JSON values are skipped.
    async fn add_map_to_doc(
        doc: &Doc,
        store: &PostgresStorage,
        map_name: &str,
        entries: &[(&str, serde_json::Value)],
    ) -> Result<(), StoreError> {
        {
            let mut txn = doc.transact_mut();
            let map = txn.get_or_insert_map(map_name);

            for (key, value) in entries {
                match value {
                    serde_json::Value::String(s) => {
                        map.insert(&mut txn, *key, s.as_str());
                    }
                    serde_json::Value::Number(n) => {
                        if let Some(i) = n.as_i64() {
                            map.insert(&mut txn, *key, i);
                        } else if let Some(f) = n.as_f64() {
                            map.insert(&mut txn, *key, f);
                        }
                    }
                    serde_json::Value::Bool(b) => {
                        map.insert(&mut txn, *key, *b);
                    }
                    _ => {}
                }
            }
        }
        let update = doc.get_update();
        store.write(&update).await
    }

    #[tokio::test]
    async fn test_squash_preserves_history() -> Result<(), Box<dyn std::error::Error>> {
        let document_id = 1;
        let (store, pool) = create_test_store(document_id).await?;

        let doc = create_doc_with_text(&store, "Hello").await?;
        update_doc_text(&doc, &store, 5, ", World").await?;
        assert_record_count(&pool, 2, "数据库中应该有两条更新记录", document_id).await?;

        store.squash().await?;
        assert_record_count(&pool, 1, "squash后数据库中应该只有一条记录", document_id).await?;

        let new_doc = create_and_apply_doc(&store).await?;
        assert_doc_text(
            &new_doc,
            "Hello, World",
            "squash后的文档应该包含所有历史更改",
        );

        cleanup_test_data(&pool, document_id).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_squash_with_multiple_updates() -> Result<(), Box<dyn std::error::Error>> {
        let document_id = 2;
        let (store, pool) = create_test_store(document_id).await?;

        let doc = Doc::new();
        for i in 0..5 {
            {
                let mut txn = doc.transact_mut();
                let text = if i == 0 {
                    txn.get_or_insert_text("text")
                } else {
                    txn.get_text("text").unwrap()
                };
                let len = text.len(&txn);
                text.insert(&mut txn, len, &format!("Part {}", i));
            }
            let update = doc.get_update();
            store.write(&update).await?;
        }
        assert_record_count(&pool, 5, "数据库中应该有5条更新记录", document_id).await?;

        store.squash().await?;
        assert_record_count(&pool, 1, "squash后数据库中应该只有一条记录", document_id).await?;

        let new_doc = create_and_apply_doc(&store).await?;
        assert_doc_text(
            &new_doc,
            "Part 0Part 1Part 2Part 3Part 4",
            "squash后的文档应该包含所有历史更改",
        );

        cleanup_test_data(&pool, document_id).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_squash_with_complex_operations() -> Result<(), Box<dyn std::error::Error>> {
        let document_id = 3;
        let (store, pool) = create_test_store(document_id).await?;

        let doc = create_doc_with_text(&store, "Initial content").await?;
        update_doc_text(&doc, &store, 15, " with more text").await?;
        // Removes "content " (8 chars at offset 8) from "Initial content with more text".
        remove_doc_text(&doc, &store, 8, 8).await?;
        assert_record_count(&pool, 3, "数据库中应该有3条更新记录", document_id).await?;

        store.squash().await?;
        assert_record_count(&pool, 1, "squash后数据库中应该只有一条记录", document_id).await?;

        let new_doc = create_and_apply_doc(&store).await?;
        assert_doc_text(
            &new_doc,
            "Initial with more text",
            "squash后的文档应该正确应用所有操作,包括删除",
        );

        cleanup_test_data(&pool, document_id).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_apply_updates() -> Result<(), Box<dyn std::error::Error>> {
        let document_id = 4;
        let (store, pool) = create_test_store(document_id).await?;

        let source_doc = create_doc_with_text(&store, "Hello, World!").await?;

        let map_name = "map";
        let entries = [
            ("key1", serde_json::json!("value1")),
            ("key2", serde_json::json!(42)),
        ];
        add_map_to_doc(&source_doc, &store, map_name, &entries).await?;

        let target_doc = create_and_apply_doc(&store).await?;
        assert_doc_text(&target_doc, "Hello, World!", "目标文档应该包含文本内容");
        assert_doc_map(&target_doc, map_name, &entries, "目标文档应该包含Map的值");

        cleanup_test_data(&pool, document_id).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_encode_state_as_update() -> Result<(), Box<dyn std::error::Error>> {
        let document_id = 5;
        let (store, pool) = create_test_store(document_id).await?;

        let doc = Doc::new();
        {
            let mut txn = doc.transact_mut();
            let text = txn.get_or_insert_text("text");
            text.insert(&mut txn, 0, "测试文本内容");

            let map = txn.get_or_insert_map("map");
            map.insert(&mut txn, "key", "value");
        }

        store.save(doc.clone()).await?;
        assert_record_count(&pool, 1, "数据库中应该有一条记录", document_id).await?;

        let new_doc = Doc::new();
        store.load(&new_doc).await?;

        // Scope the read transaction so it is released before the cleanup await.
        {
            let txn = new_doc.transact();

            let text = txn.get_text("text").unwrap();
            let content = text.get_string(&txn);
            assert_eq!(content, "测试文本内容", "新文档应该包含文本内容");

            let map = txn.get_map("map").unwrap();
            let key_value = map.get(&txn, "key").unwrap().to_string(&txn);
            assert_eq!(key_value, "value", "新文档应该包含Map的key值");
        }

        cleanup_test_data(&pool, document_id).await?;
        Ok(())
    }
}