1use std::path::Path;
42use std::str::FromStr;
43
44use chrono::{DateTime, NaiveDate, NaiveTime, Utc};
45use ipnetwork::IpNetwork;
46use mac_address::MacAddress;
47use rust_decimal::Decimal;
48use serde::{Deserialize, Serialize};
49use serde_json::{Map, Value};
50use sqlx::Row;
51use uuid::Uuid;
52
53use crate::db::DbPool;
54use crate::migrate::{Column, ModelMeta};
55use crate::orm::{ArrayElement, SqlType, TsVector};
56
/// Format version written into every dump by `dump` and required verbatim by
/// `load`; any other value is rejected with `BackupError::UnsupportedVersion`.
const DUMP_VERSION: &str = "1";
58
/// Top-level backup document: produced by [`dump`], consumed by [`load`], and
/// (de)serialized as JSON by [`dump_to_path`] / [`load_from_path`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Dump {
    /// Format version; `load` only accepts [`DUMP_VERSION`].
    pub umbral_dump_version: String,
    /// RFC 3339 UTC timestamp taken by `dump` after all tables were read.
    pub exported_at: String,
    /// One entry per table; `dump` emits them sorted by table name.
    pub models: Vec<ModelDump>,
}
68
/// The contents of a single table inside a [`Dump`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDump {
    /// Table name, matched against registered models on load.
    pub table: String,
    /// Rows as JSON objects keyed by column name.
    pub rows: Vec<Map<String, Value>>,
}
78
/// Errors raised while exporting ([`dump`]) or restoring ([`load`]) a backup.
#[derive(Debug)]
pub enum BackupError {
    /// Reading or writing the dump file failed.
    Io(std::io::Error),
    /// The dump could not be serialized to, or parsed from, JSON.
    Json(serde_json::Error),
    /// A database query failed or a stored value could not be decoded.
    Sqlx(sqlx::Error),
    /// The dump's `umbral_dump_version` differs from [`DUMP_VERSION`].
    UnsupportedVersion(String),
    /// A row in the dump carries a key that is not a column of the model.
    UnknownColumn {
        /// Table the offending row belongs to.
        table: String,
        /// The unrecognized row key.
        column: String,
    },
    /// A dump value has the wrong JSON type for its column, is out of range,
    /// or fails to parse (dates, UUIDs, network addresses, ...).
    TypeMismatch {
        /// Table the offending row belongs to.
        table: String,
        /// Column whose value was rejected.
        column: String,
        /// The column's declared type.
        expected: SqlType,
        /// Short description of what the dump contained instead.
        got: String,
    },
}
106
107impl std::fmt::Display for BackupError {
108 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109 match self {
110 BackupError::Io(e) => write!(f, "umbral backup: io: {e}"),
111 BackupError::Json(e) => write!(f, "umbral backup: json: {e}"),
112 BackupError::Sqlx(e) => write!(f, "umbral backup: sqlx: {e}"),
113 BackupError::UnsupportedVersion(v) => write!(
114 f,
115 "umbral backup: dump version `{v}` is not supported by this build \
116 (this build knows version `{DUMP_VERSION}`)"
117 ),
118 BackupError::UnknownColumn { table, column } => write!(
119 f,
120 "umbral backup: column `{table}.{column}` in the dump isn't in the \
121 current schema; run `umbral-cli migrate` first or update the dump"
122 ),
123 BackupError::TypeMismatch {
124 table,
125 column,
126 expected,
127 got,
128 } => write!(
129 f,
130 "umbral backup: column `{table}.{column}` expects {expected:?} but the \
131 dump has {got}"
132 ),
133 }
134 }
135}
136
// `source()` is left at its default (`None`). The `Display` impl already embeds
// the wrapped io/json/sqlx error's message. Also returning it from `source()`
// would make chain-walking error reporters print that message twice.
impl std::error::Error for BackupError {}
138
139impl From<std::io::Error> for BackupError {
140 fn from(e: std::io::Error) -> Self {
141 Self::Io(e)
142 }
143}
144
145impl From<serde_json::Error> for BackupError {
146 fn from(e: serde_json::Error) -> Self {
147 Self::Json(e)
148 }
149}
150
151impl From<sqlx::Error> for BackupError {
152 fn from(e: sqlx::Error) -> Self {
153 Self::Sqlx(e)
154 }
155}
156
157pub async fn dump() -> Result<Dump, BackupError> {
160 let pool = crate::db::pool_dispatched();
161 let mut models = crate::migrate::registered_models();
162 models.sort_by(|a, b| a.table.cmp(&b.table));
163
164 let mut out: Vec<ModelDump> = Vec::with_capacity(models.len());
165 for model in models {
166 out.push(dump_one(pool, &model).await?);
167 }
168 Ok(Dump {
169 umbral_dump_version: DUMP_VERSION.to_string(),
170 exported_at: Utc::now().to_rfc3339(),
171 models: out,
172 })
173}
174
175pub async fn dump_to_path(path: &Path) -> Result<(), BackupError> {
177 let dump = dump().await?;
178 let json = serde_json::to_string_pretty(&dump)?;
179 std::fs::write(path, json)?;
180 Ok(())
181}
182
183pub async fn load(dump: &Dump) -> Result<LoadReport, BackupError> {
187 if dump.umbral_dump_version != DUMP_VERSION {
188 return Err(BackupError::UnsupportedVersion(
189 dump.umbral_dump_version.clone(),
190 ));
191 }
192 let pool = crate::db::pool_dispatched();
193 let registered = crate::migrate::registered_models();
194 let mut by_table: std::collections::HashMap<String, ModelMeta> = registered
195 .into_iter()
196 .map(|m| (m.table.clone(), m))
197 .collect();
198
199 let mut report = LoadReport::default();
200 for model in &dump.models {
201 let Some(meta) = by_table.remove(&model.table) else {
202 report.skipped_tables.push(model.table.clone());
206 continue;
207 };
208 let inserted = load_one(pool, &meta, &model.rows).await?;
209 report.rows_loaded += inserted;
210 report.tables_loaded.push(meta.table);
211 }
212 Ok(report)
213}
214
215pub async fn load_from_path(path: &Path) -> Result<LoadReport, BackupError> {
217 let text = std::fs::read_to_string(path)?;
218 let dump: Dump = serde_json::from_str(&text)?;
219 load(&dump).await
220}
221
/// Summary returned by [`load`].
#[derive(Debug, Default, Clone)]
pub struct LoadReport {
    /// Tables whose rows were inserted, in dump order.
    pub tables_loaded: Vec<String>,
    /// Dump table names with no registered model, or repeated after their
    /// first occurrence was loaded.
    pub skipped_tables: Vec<String>,
    /// Total number of rows inserted across all loaded tables.
    pub rows_loaded: u64,
}
231
232async fn dump_one(pool: &DbPool, model: &ModelMeta) -> Result<ModelDump, BackupError> {
237 match pool {
238 DbPool::Sqlite(pool) => dump_one_sqlite(pool, model).await,
239 DbPool::Postgres(pool) => dump_one_postgres(pool, model).await,
240 }
241}
242
243async fn dump_one_sqlite(
244 pool: &sqlx::SqlitePool,
245 model: &ModelMeta,
246) -> Result<ModelDump, BackupError> {
247 let sql = format!(
248 "SELECT {} FROM {}",
249 column_list(model),
250 quoted_ident(&model.table)
251 );
252 let rows = sqlx::query(&sql).fetch_all(pool).await?;
253
254 let mut out: Vec<Map<String, Value>> = Vec::with_capacity(rows.len());
255 for row in rows {
256 let mut obj = Map::new();
257 for col in &model.fields {
258 obj.insert(col.name.clone(), column_to_json(&row, col)?);
259 }
260 out.push(obj);
261 }
262 Ok(ModelDump {
263 table: model.table.clone(),
264 rows: out,
265 })
266}
267
268async fn load_one(
269 pool: &DbPool,
270 model: &ModelMeta,
271 rows: &[Map<String, Value>],
272) -> Result<u64, BackupError> {
273 if rows.is_empty() {
274 return Ok(0);
275 }
276 match pool {
277 DbPool::Sqlite(pool) => load_one_sqlite(pool, model, rows).await,
278 DbPool::Postgres(pool) => load_one_postgres(pool, model, rows).await,
279 }
280}
281
282async fn load_one_sqlite(
283 pool: &sqlx::SqlitePool,
284 model: &ModelMeta,
285 rows: &[Map<String, Value>],
286) -> Result<u64, BackupError> {
287 let sql = format!(
288 "INSERT INTO {} ({}) VALUES ({})",
289 quoted_ident(&model.table),
290 column_list(model),
291 sqlite_placeholders(model.fields.len())
292 );
293
294 let mut count: u64 = 0;
295 for row in rows {
296 for k in row.keys() {
299 if !model.fields.iter().any(|c| &c.name == k) {
300 return Err(BackupError::UnknownColumn {
301 table: model.table.clone(),
302 column: k.clone(),
303 });
304 }
305 }
306 let mut q = sqlx::query(&sql);
307 for col in &model.fields {
308 let val = row.get(&col.name).cloned().unwrap_or(Value::Null);
309 q = bind_value(q, &model.table, col, val)?;
310 }
311 q.execute(pool).await?;
312 count += 1;
313 }
314 Ok(count)
315}
316
317async fn dump_one_postgres(
318 pool: &sqlx::PgPool,
319 model: &ModelMeta,
320) -> Result<ModelDump, BackupError> {
321 let sql = format!(
322 "SELECT {} FROM {}",
323 column_list_pg_select(model),
324 quoted_ident(&model.table)
325 );
326 let rows = sqlx::query(&sql).fetch_all(pool).await?;
327
328 let mut out: Vec<Map<String, Value>> = Vec::with_capacity(rows.len());
329 for row in rows {
330 let mut obj = Map::new();
331 for col in &model.fields {
332 obj.insert(col.name.clone(), column_to_json_pg(&row, col)?);
333 }
334 out.push(obj);
335 }
336 Ok(ModelDump {
337 table: model.table.clone(),
338 rows: out,
339 })
340}
341
342async fn load_one_postgres(
343 pool: &sqlx::PgPool,
344 model: &ModelMeta,
345 rows: &[Map<String, Value>],
346) -> Result<u64, BackupError> {
347 let sql = format!(
348 "INSERT INTO {} ({}) VALUES ({})",
349 quoted_ident(&model.table),
350 column_list(model),
351 postgres_placeholders(model.fields.len())
352 );
353
354 let mut count: u64 = 0;
355 for row in rows {
356 for k in row.keys() {
357 if !model.fields.iter().any(|c| &c.name == k) {
358 return Err(BackupError::UnknownColumn {
359 table: model.table.clone(),
360 column: k.clone(),
361 });
362 }
363 }
364 let mut q = sqlx::query(&sql);
365 for col in &model.fields {
366 let val = row.get(&col.name).cloned().unwrap_or(Value::Null);
367 q = bind_value_pg(q, &model.table, col, val)?;
368 }
369 q.execute(pool).await?;
370 count += 1;
371 }
372 Ok(count)
373}
374
/// Wraps `name` in double quotes as a SQL identifier, doubling any embedded `"`.
fn quoted_ident(name: &str) -> String {
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
378
379fn column_list(model: &ModelMeta) -> String {
380 model
381 .fields
382 .iter()
383 .map(|c| quoted_ident(&c.name))
384 .collect::<Vec<_>>()
385 .join(", ")
386}
387
388fn column_list_pg_select(model: &ModelMeta) -> String {
395 model
396 .fields
397 .iter()
398 .map(|c| {
399 if matches!(c.ty, SqlType::Xml | SqlType::Ltree | SqlType::Bit) {
400 let q = quoted_ident(&c.name);
401 format!("{q}::text AS {q}")
402 } else {
403 quoted_ident(&c.name)
404 }
405 })
406 .collect::<Vec<_>>()
407 .join(", ")
408}
409
/// `count` SQLite positional placeholders joined by `", "` (empty for 0).
fn sqlite_placeholders(count: usize) -> String {
    vec!["?"; count].join(", ")
}
413
/// Postgres numbered placeholders `$1, $2, ..., $count` (empty for 0).
fn postgres_placeholders(count: usize) -> String {
    let mut out = String::new();
    for idx in 1..=count {
        if idx > 1 {
            out.push_str(", ");
        }
        out.push('$');
        out.push_str(&idx.to_string());
    }
    out
}
420
/// Converts one column of a SQLite row into its JSON dump representation.
///
/// The branch is chosen by `crate::migrate::fk_effective_type(col)`.
/// - Nullable columns are decoded as `Option<T>`, so SQL `NULL` becomes
///   `Value::Null`.
/// - Non-nullable columns are decoded as `T` directly. Presumably a stray
///   `NULL` there surfaces as a sqlx decode error; confirm.
/// - Dates, times and UUIDs are written via `to_string`, timestamps via
///   `to_rfc3339`.
/// - `Real` values are widened to `f64`; blobs become arrays of byte numbers.
///
/// NOTE(review): `Value::from(f64)` maps NaN/infinity to JSON `null`, so
/// such values would not round-trip. Confirm that is acceptable.
///
/// # Errors
/// Returns [`BackupError::Sqlx`] when a value cannot be decoded as the
/// expected Rust type.
///
/// # Panics
/// Panics for Postgres-only types (Array, Inet/Cidr/MacAddr, FullText, Xml,
/// Ltree, Bit, Decimal). The field.backend system check is expected to have
/// rejected those on SQLite.
fn column_to_json(row: &sqlx::sqlite::SqliteRow, col: &Column) -> Result<Value, BackupError> {
    let name = col.name.as_str();
    if col.nullable {
        // Decode through `Option<T>` so SQL NULL maps to JSON null.
        return Ok(match crate::migrate::fk_effective_type(col) {
            SqlType::SmallInt | SqlType::Integer => row
                .try_get::<Option<i32>, _>(name)?
                .map_or(Value::Null, Value::from),
            SqlType::BigInt => row
                .try_get::<Option<i64>, _>(name)?
                .map_or(Value::Null, Value::from),
            SqlType::Real => row
                .try_get::<Option<f32>, _>(name)?
                .map_or(Value::Null, |v| Value::from(v as f64)),
            SqlType::Double => row
                .try_get::<Option<f64>, _>(name)?
                .map_or(Value::Null, Value::from),
            SqlType::Boolean => row
                .try_get::<Option<bool>, _>(name)?
                .map_or(Value::Null, Value::from),
            SqlType::Text => row
                .try_get::<Option<String>, _>(name)?
                .map_or(Value::Null, Value::from),
            SqlType::Date => row
                .try_get::<Option<NaiveDate>, _>(name)?
                .map_or(Value::Null, |v| Value::from(v.to_string())),
            SqlType::Time => row
                .try_get::<Option<NaiveTime>, _>(name)?
                .map_or(Value::Null, |v| Value::from(v.to_string())),
            SqlType::Timestamptz => row
                .try_get::<Option<DateTime<Utc>>, _>(name)?
                .map_or(Value::Null, |v| Value::from(v.to_rfc3339())),
            SqlType::Uuid => row
                .try_get::<Option<Uuid>, _>(name)?
                .map_or(Value::Null, |v| Value::from(v.to_string())),
            // JSON columns are copied through as-is.
            SqlType::Json => row
                .try_get::<Option<Value>, _>(name)?
                .unwrap_or(Value::Null),
            SqlType::Array(_) => unreachable_array(&col.name),
            SqlType::Inet | SqlType::Cidr | SqlType::MacAddr => unreachable_network(&col.name),
            SqlType::FullText => unreachable_pg_only(&col.name, "FullText (tsvector)"),
            SqlType::Xml => unreachable_pg_only(&col.name, "Xml"),
            SqlType::Ltree => unreachable_pg_only(&col.name, "Ltree"),
            SqlType::Bit => unreachable_pg_only(&col.name, "Bit"),
            // Kept for exhaustiveness; presumably `fk_effective_type` already
            // resolves foreign keys to a concrete type — confirm.
            SqlType::ForeignKey => row
                .try_get::<Option<i64>, _>(name)?
                .map_or(Value::Null, Value::from),
            // Blobs are dumped as a JSON array of byte values (0..=255).
            SqlType::Bytes => row
                .try_get::<Option<Vec<u8>>, _>(name)?
                .map_or(Value::Null, |b| {
                    Value::Array(b.into_iter().map(Value::from).collect())
                }),
            SqlType::Decimal => unreachable_pg_only(&col.name, "Decimal"),
        });
    }
    // Non-nullable: decode the bare type.
    Ok(match crate::migrate::fk_effective_type(col) {
        SqlType::SmallInt | SqlType::Integer => Value::from(row.try_get::<i32, _>(name)?),
        SqlType::BigInt => Value::from(row.try_get::<i64, _>(name)?),
        SqlType::Real => Value::from(row.try_get::<f32, _>(name)? as f64),
        SqlType::Double => Value::from(row.try_get::<f64, _>(name)?),
        SqlType::Boolean => Value::from(row.try_get::<bool, _>(name)?),
        SqlType::Text => Value::from(row.try_get::<String, _>(name)?),
        SqlType::Date => Value::from(row.try_get::<NaiveDate, _>(name)?.to_string()),
        SqlType::Time => Value::from(row.try_get::<NaiveTime, _>(name)?.to_string()),
        SqlType::Timestamptz => Value::from(row.try_get::<DateTime<Utc>, _>(name)?.to_rfc3339()),
        SqlType::Uuid => Value::from(row.try_get::<Uuid, _>(name)?.to_string()),
        SqlType::Json => row.try_get::<Value, _>(name)?,
        SqlType::Array(_) => unreachable_array(&col.name),
        SqlType::Inet | SqlType::Cidr | SqlType::MacAddr => unreachable_network(&col.name),
        SqlType::FullText => unreachable_pg_only(&col.name, "FullText (tsvector)"),
        SqlType::Xml => unreachable_pg_only(&col.name, "Xml"),
        SqlType::Ltree => unreachable_pg_only(&col.name, "Ltree"),
        SqlType::Bit => unreachable_pg_only(&col.name, "Bit"),
        SqlType::ForeignKey => Value::from(row.try_get::<i64, _>(name)?),
        SqlType::Bytes => {
            let bytes: Vec<u8> = row.try_get(name)?;
            Value::Array(bytes.into_iter().map(Value::from).collect())
        }
        SqlType::Decimal => unreachable_pg_only(&col.name, "Decimal"),
    })
}
525
/// Converts one column of a Postgres row into its JSON dump representation.
///
/// The branch is chosen by `crate::migrate::fk_effective_type(col)`.
/// - Nullable columns are decoded as `Option<T>` (`NULL` becomes `Value::Null`);
///   non-nullable columns are decoded as `T`.
/// - Strings are produced for dates, times, UUIDs, network addresses,
///   tsvectors and decimals (via `to_string`/`into_inner`), and for
///   Xml/Ltree/Bit (which `column_list_pg_select` selects as `::text`).
/// - Timestamps use `to_rfc3339`; arrays and byte strings become JSON arrays.
///
/// NOTE(review): `Value::from(f64)` maps NaN/infinity to JSON `null`, so
/// such values would not round-trip. Confirm that is acceptable.
///
/// # Errors
/// Returns [`BackupError::Sqlx`] when a value cannot be decoded as the
/// expected Rust type.
fn column_to_json_pg(row: &sqlx::postgres::PgRow, col: &Column) -> Result<Value, BackupError> {
    let name = col.name.as_str();
    if col.nullable {
        // Decode through `Option<T>` so SQL NULL maps to JSON null.
        return Ok(match crate::migrate::fk_effective_type(col) {
            SqlType::SmallInt => row
                .try_get::<Option<i16>, _>(name)?
                .map_or(Value::Null, Value::from),
            SqlType::Integer => row
                .try_get::<Option<i32>, _>(name)?
                .map_or(Value::Null, Value::from),
            SqlType::BigInt | SqlType::ForeignKey => row
                .try_get::<Option<i64>, _>(name)?
                .map_or(Value::Null, Value::from),
            SqlType::Real => row
                .try_get::<Option<f32>, _>(name)?
                .map_or(Value::Null, |v| Value::from(v as f64)),
            SqlType::Double => row
                .try_get::<Option<f64>, _>(name)?
                .map_or(Value::Null, Value::from),
            SqlType::Boolean => row
                .try_get::<Option<bool>, _>(name)?
                .map_or(Value::Null, Value::from),
            SqlType::Text => row
                .try_get::<Option<String>, _>(name)?
                .map_or(Value::Null, Value::from),
            SqlType::Date => row
                .try_get::<Option<NaiveDate>, _>(name)?
                .map_or(Value::Null, |v| Value::from(v.to_string())),
            SqlType::Time => row
                .try_get::<Option<NaiveTime>, _>(name)?
                .map_or(Value::Null, |v| Value::from(v.to_string())),
            SqlType::Timestamptz => row
                .try_get::<Option<DateTime<Utc>>, _>(name)?
                .map_or(Value::Null, |v| Value::from(v.to_rfc3339())),
            SqlType::Uuid => row
                .try_get::<Option<Uuid>, _>(name)?
                .map_or(Value::Null, |v| Value::from(v.to_string())),
            SqlType::Json => row
                .try_get::<Option<Value>, _>(name)?
                .unwrap_or(Value::Null),
            SqlType::Array(elem) => pg_array_column_to_json_nullable(row, name, elem)?,
            SqlType::Inet | SqlType::Cidr => row
                .try_get::<Option<IpNetwork>, _>(name)?
                .map_or(Value::Null, |v| Value::from(v.to_string())),
            SqlType::MacAddr => row
                .try_get::<Option<MacAddress>, _>(name)?
                .map_or(Value::Null, |v| Value::from(v.to_string())),
            SqlType::FullText => row
                .try_get::<Option<TsVector>, _>(name)?
                .map_or(Value::Null, |v| Value::from(v.into_inner())),
            // Selected as `::text` by `column_list_pg_select`, so read as String.
            SqlType::Xml | SqlType::Ltree | SqlType::Bit => row
                .try_get::<Option<String>, _>(name)?
                .map_or(Value::Null, Value::from),
            SqlType::Bytes => row
                .try_get::<Option<Vec<u8>>, _>(name)?
                .map_or(Value::Null, bytes_to_json),
            // Decimals are dumped as strings to avoid float rounding.
            SqlType::Decimal => row
                .try_get::<Option<Decimal>, _>(name)?
                .map_or(Value::Null, |v| Value::from(v.to_string())),
        });
    }
    // Non-nullable: decode the bare type.
    Ok(match crate::migrate::fk_effective_type(col) {
        SqlType::SmallInt => Value::from(row.try_get::<i16, _>(name)?),
        SqlType::Integer => Value::from(row.try_get::<i32, _>(name)?),
        SqlType::BigInt | SqlType::ForeignKey => Value::from(row.try_get::<i64, _>(name)?),
        SqlType::Real => Value::from(row.try_get::<f32, _>(name)? as f64),
        SqlType::Double => Value::from(row.try_get::<f64, _>(name)?),
        SqlType::Boolean => Value::from(row.try_get::<bool, _>(name)?),
        SqlType::Text => Value::from(row.try_get::<String, _>(name)?),
        SqlType::Date => Value::from(row.try_get::<NaiveDate, _>(name)?.to_string()),
        SqlType::Time => Value::from(row.try_get::<NaiveTime, _>(name)?.to_string()),
        SqlType::Timestamptz => Value::from(row.try_get::<DateTime<Utc>, _>(name)?.to_rfc3339()),
        SqlType::Uuid => Value::from(row.try_get::<Uuid, _>(name)?.to_string()),
        SqlType::Json => row.try_get::<Value, _>(name)?,
        SqlType::Array(elem) => pg_array_column_to_json(row, name, elem)?,
        SqlType::Inet | SqlType::Cidr => {
            Value::from(row.try_get::<IpNetwork, _>(name)?.to_string())
        }
        SqlType::MacAddr => Value::from(row.try_get::<MacAddress, _>(name)?.to_string()),
        SqlType::FullText => Value::from(row.try_get::<TsVector, _>(name)?.into_inner()),
        // Selected as `::text` by `column_list_pg_select`, so read as String.
        SqlType::Xml | SqlType::Ltree | SqlType::Bit => {
            Value::from(row.try_get::<String, _>(name)?)
        }
        SqlType::Bytes => bytes_to_json(row.try_get::<Vec<u8>, _>(name)?),
        SqlType::Decimal => Value::from(row.try_get::<Decimal, _>(name)?.to_string()),
    })
}
617
618fn pg_array_column_to_json_nullable(
619 row: &sqlx::postgres::PgRow,
620 name: &str,
621 elem: ArrayElement,
622) -> Result<Value, BackupError> {
623 Ok(match elem {
624 ArrayElement::SmallInt => row
625 .try_get::<Option<Vec<i16>>, _>(name)?
626 .map_or(Value::Null, |values| array_to_json(values, Value::from)),
627 ArrayElement::Integer => row
628 .try_get::<Option<Vec<i32>>, _>(name)?
629 .map_or(Value::Null, |values| array_to_json(values, Value::from)),
630 ArrayElement::BigInt => row
631 .try_get::<Option<Vec<i64>>, _>(name)?
632 .map_or(Value::Null, |values| array_to_json(values, Value::from)),
633 ArrayElement::Real => row
634 .try_get::<Option<Vec<f32>>, _>(name)?
635 .map_or(Value::Null, |values| {
636 array_to_json(values, |v| Value::from(v as f64))
637 }),
638 ArrayElement::Double => row
639 .try_get::<Option<Vec<f64>>, _>(name)?
640 .map_or(Value::Null, |values| array_to_json(values, Value::from)),
641 ArrayElement::Boolean => row
642 .try_get::<Option<Vec<bool>>, _>(name)?
643 .map_or(Value::Null, |values| array_to_json(values, Value::from)),
644 ArrayElement::Text => row
645 .try_get::<Option<Vec<String>>, _>(name)?
646 .map_or(Value::Null, |values| array_to_json(values, Value::from)),
647 ArrayElement::Uuid => row
648 .try_get::<Option<Vec<Uuid>>, _>(name)?
649 .map_or(Value::Null, |values| {
650 array_to_json(values, |v| Value::from(v.to_string()))
651 }),
652 })
653}
654
655fn pg_array_column_to_json(
656 row: &sqlx::postgres::PgRow,
657 name: &str,
658 elem: ArrayElement,
659) -> Result<Value, BackupError> {
660 Ok(match elem {
661 ArrayElement::SmallInt => array_to_json(row.try_get::<Vec<i16>, _>(name)?, Value::from),
662 ArrayElement::Integer => array_to_json(row.try_get::<Vec<i32>, _>(name)?, Value::from),
663 ArrayElement::BigInt => array_to_json(row.try_get::<Vec<i64>, _>(name)?, Value::from),
664 ArrayElement::Real => {
665 array_to_json(row.try_get::<Vec<f32>, _>(name)?, |v| Value::from(v as f64))
666 }
667 ArrayElement::Double => array_to_json(row.try_get::<Vec<f64>, _>(name)?, Value::from),
668 ArrayElement::Boolean => array_to_json(row.try_get::<Vec<bool>, _>(name)?, Value::from),
669 ArrayElement::Text => array_to_json(row.try_get::<Vec<String>, _>(name)?, Value::from),
670 ArrayElement::Uuid => array_to_json(row.try_get::<Vec<Uuid>, _>(name)?, |v| {
671 Value::from(v.to_string())
672 }),
673 })
674}
675
676fn array_to_json<T>(values: Vec<T>, mut item: impl FnMut(T) -> Value) -> Value {
677 Value::Array(values.into_iter().map(&mut item).collect())
678}
679
680fn bytes_to_json(bytes: Vec<u8>) -> Value {
681 Value::Array(bytes.into_iter().map(Value::from).collect())
682}
683
/// Panics for a Postgres-only Array column reached on SQLite.
///
/// Only called from the SQLite code paths, where such columns should already
/// have been rejected at boot by the field.backend system check.
fn unreachable_array(column: &str) -> ! {
    let message = format!(
        "umbral backup: column `{column}` is a Postgres-only Array; \
         the field.backend system check should have failed boot. \
         For portable list storage use SqlType::Json instead."
    );
    panic!("{message}")
}

/// Panics for a Postgres-only Inet/Cidr/MacAddr column reached on SQLite.
fn unreachable_network(column: &str) -> ! {
    let message = format!(
        "umbral backup: column `{column}` is a Postgres-only network \
         address type (Inet/Cidr/MacAddr); the field.backend system \
         check should have failed boot."
    );
    panic!("{message}")
}

/// Panics for any other Postgres-only column type (named by `type_name`)
/// reached on SQLite.
fn unreachable_pg_only(column: &str, type_name: &str) -> ! {
    let message = format!(
        "umbral backup: column `{column}` is a Postgres-only {type_name} \
         type; the field.backend system check should have failed boot."
    );
    panic!("{message}")
}
711}
712
713fn bind_value<'q>(
714 q: SqliteQuery<'q>,
715 table: &str,
716 col: &Column,
717 val: Value,
718) -> Result<SqliteQuery<'q>, BackupError> {
719 if matches!(val, Value::Null) {
722 return Ok(match crate::migrate::fk_effective_type(col) {
723 SqlType::SmallInt | SqlType::Integer => q.bind(None::<i32>),
724 SqlType::BigInt => q.bind(None::<i64>),
725 SqlType::Real => q.bind(None::<f32>),
726 SqlType::Double => q.bind(None::<f64>),
727 SqlType::Boolean => q.bind(None::<bool>),
728 SqlType::Text => q.bind(None::<String>),
729 SqlType::Date => q.bind(None::<NaiveDate>),
730 SqlType::Time => q.bind(None::<NaiveTime>),
731 SqlType::Timestamptz => q.bind(None::<DateTime<Utc>>),
732 SqlType::Uuid => q.bind(None::<Uuid>),
733 SqlType::Json => q.bind(None::<Value>),
734 SqlType::Array(_) => unreachable_array(&col.name),
735 SqlType::Inet | SqlType::Cidr | SqlType::MacAddr => unreachable_network(&col.name),
736 SqlType::FullText => unreachable_pg_only(&col.name, "FullText (tsvector)"),
737 SqlType::Xml => unreachable_pg_only(&col.name, "Xml"),
740 SqlType::Ltree => unreachable_pg_only(&col.name, "Ltree"),
741 SqlType::Bit => unreachable_pg_only(&col.name, "Bit"),
742 SqlType::ForeignKey => q.bind(None::<i64>),
744 SqlType::Bytes => q.bind(None::<Vec<u8>>),
745 SqlType::Decimal => unreachable_pg_only(&col.name, "Decimal"),
746 });
747 }
748 let mismatch = |got: &str| BackupError::TypeMismatch {
749 table: table.to_string(),
750 column: col.name.clone(),
751 expected: col.ty,
752 got: got.to_string(),
753 };
754 Ok(match crate::migrate::fk_effective_type(col) {
755 SqlType::SmallInt | SqlType::Integer => {
756 q.bind(val.as_i64().ok_or_else(|| mismatch(json_type_name(&val)))? as i32)
757 }
758 SqlType::BigInt => q.bind(val.as_i64().ok_or_else(|| mismatch(json_type_name(&val)))?),
759 SqlType::Real => q.bind(val.as_f64().ok_or_else(|| mismatch(json_type_name(&val)))? as f32),
760 SqlType::Double => q.bind(val.as_f64().ok_or_else(|| mismatch(json_type_name(&val)))?),
761 SqlType::Boolean => q.bind(
762 val.as_bool()
763 .ok_or_else(|| mismatch(json_type_name(&val)))?,
764 ),
765 SqlType::Text => q.bind(
766 val.as_str()
767 .ok_or_else(|| mismatch(json_type_name(&val)))?
768 .to_string(),
769 ),
770 SqlType::Date => {
771 let s = val.as_str().ok_or_else(|| mismatch(json_type_name(&val)))?;
772 q.bind(
773 s.parse::<NaiveDate>()
774 .map_err(|_| mismatch("invalid date string"))?,
775 )
776 }
777 SqlType::Time => {
778 let s = val.as_str().ok_or_else(|| mismatch(json_type_name(&val)))?;
779 q.bind(
780 s.parse::<NaiveTime>()
781 .map_err(|_| mismatch("invalid time string"))?,
782 )
783 }
784 SqlType::Timestamptz => {
785 let s = val.as_str().ok_or_else(|| mismatch(json_type_name(&val)))?;
786 q.bind(
787 DateTime::parse_from_rfc3339(s)
788 .map_err(|_| mismatch("invalid rfc3339 timestamp"))?
789 .with_timezone(&Utc),
790 )
791 }
792 SqlType::Uuid => {
793 let s = val.as_str().ok_or_else(|| mismatch(json_type_name(&val)))?;
794 q.bind(Uuid::parse_str(s).map_err(|_| mismatch("invalid uuid string"))?)
795 }
796 SqlType::Json => q.bind(val),
801 SqlType::Array(_) => unreachable_array(&col.name),
802 SqlType::Inet | SqlType::Cidr | SqlType::MacAddr => unreachable_network(&col.name),
803 SqlType::FullText => unreachable_pg_only(&col.name, "FullText (tsvector)"),
804 SqlType::Xml => unreachable_pg_only(&col.name, "Xml"),
805 SqlType::Ltree => unreachable_pg_only(&col.name, "Ltree"),
806 SqlType::Bit => unreachable_pg_only(&col.name, "Bit"),
807 SqlType::ForeignKey => q.bind(val.as_i64().ok_or_else(|| mismatch(json_type_name(&val)))?),
809 SqlType::Bytes => q.bind(bytes_from_json(table, col, &val)?),
812 SqlType::Decimal => unreachable_pg_only(&col.name, "Decimal"),
813 })
814}
815
/// The query type returned by `sqlx::query` for SQLite, threaded through `bind_value`.
type SqliteQuery<'q> = sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>>;
/// The query type returned by `sqlx::query` for Postgres, threaded through
/// `bind_value_pg` and its array helpers.
type PgQuery<'q> = sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>;
818
/// Binds one dump value for `col` onto a Postgres `INSERT` query.
///
/// JSON `null` is bound as a `NULL` typed like the column's Rust decode type.
/// Other values are validated against the column's effective type:
/// - `SmallInt`/`Integer` are range-checked.
/// - Temporal, UUID and network values are parsed from strings.
/// - Arrays are delegated to `bind_array_pg`.
/// - `Decimal` accepts a JSON string or number.
///
/// # Errors
/// Returns [`BackupError::TypeMismatch`] when the value has the wrong JSON
/// type, is out of range, or fails to parse.
fn bind_value_pg<'q>(
    q: PgQuery<'q>,
    table: &str,
    col: &Column,
    val: Value,
) -> Result<PgQuery<'q>, BackupError> {
    if matches!(val, Value::Null) {
        // Bind a NULL typed like the column's Rust decode type.
        return Ok(match crate::migrate::fk_effective_type(col) {
            SqlType::SmallInt => q.bind(None::<i16>),
            SqlType::Integer => q.bind(None::<i32>),
            SqlType::BigInt | SqlType::ForeignKey => q.bind(None::<i64>),
            SqlType::Real => q.bind(None::<f32>),
            SqlType::Double => q.bind(None::<f64>),
            SqlType::Boolean => q.bind(None::<bool>),
            SqlType::Text => q.bind(None::<String>),
            SqlType::Date => q.bind(None::<NaiveDate>),
            SqlType::Time => q.bind(None::<NaiveTime>),
            SqlType::Timestamptz => q.bind(None::<DateTime<Utc>>),
            SqlType::Uuid => q.bind(None::<Uuid>),
            SqlType::Json => q.bind(None::<Value>),
            SqlType::Array(elem) => bind_null_array_pg(q, elem),
            SqlType::Inet | SqlType::Cidr => q.bind(None::<IpNetwork>),
            SqlType::MacAddr => q.bind(None::<MacAddress>),
            SqlType::FullText => q.bind(None::<TsVector>),
            SqlType::Xml | SqlType::Ltree | SqlType::Bit => q.bind(None::<String>),
            SqlType::Bytes => q.bind(None::<Vec<u8>>),
            SqlType::Decimal => q.bind(None::<Decimal>),
        });
    }
    let mismatch = |got: &str| BackupError::TypeMismatch {
        table: table.to_string(),
        column: col.name.clone(),
        expected: col.ty,
        got: got.to_string(),
    };
    Ok(match crate::migrate::fk_effective_type(col) {
        SqlType::SmallInt => q.bind(
            i16::try_from(val.as_i64().ok_or_else(|| mismatch(json_type_name(&val)))?)
                .map_err(|_| mismatch("number out of i16 range"))?,
        ),
        SqlType::Integer => q.bind(
            i32::try_from(val.as_i64().ok_or_else(|| mismatch(json_type_name(&val)))?)
                .map_err(|_| mismatch("number out of i32 range"))?,
        ),
        SqlType::BigInt | SqlType::ForeignKey => {
            q.bind(val.as_i64().ok_or_else(|| mismatch(json_type_name(&val)))?)
        }
        SqlType::Real => q.bind(val.as_f64().ok_or_else(|| mismatch(json_type_name(&val)))? as f32),
        SqlType::Double => q.bind(val.as_f64().ok_or_else(|| mismatch(json_type_name(&val)))?),
        SqlType::Boolean => q.bind(
            val.as_bool()
                .ok_or_else(|| mismatch(json_type_name(&val)))?,
        ),
        SqlType::Text => q.bind(
            val.as_str()
                .ok_or_else(|| mismatch(json_type_name(&val)))?
                .to_string(),
        ),
        SqlType::Date => {
            let s = val.as_str().ok_or_else(|| mismatch(json_type_name(&val)))?;
            q.bind(
                s.parse::<NaiveDate>()
                    .map_err(|_| mismatch("invalid date string"))?,
            )
        }
        SqlType::Time => {
            let s = val.as_str().ok_or_else(|| mismatch(json_type_name(&val)))?;
            q.bind(
                s.parse::<NaiveTime>()
                    .map_err(|_| mismatch("invalid time string"))?,
            )
        }
        SqlType::Timestamptz => {
            let s = val.as_str().ok_or_else(|| mismatch(json_type_name(&val)))?;
            q.bind(
                DateTime::parse_from_rfc3339(s)
                    .map_err(|_| mismatch("invalid rfc3339 timestamp"))?
                    .with_timezone(&Utc),
            )
        }
        SqlType::Uuid => {
            let s = val.as_str().ok_or_else(|| mismatch(json_type_name(&val)))?;
            q.bind(Uuid::parse_str(s).map_err(|_| mismatch("invalid uuid string"))?)
        }
        // JSON columns take the dumped value verbatim.
        SqlType::Json => q.bind(val),
        SqlType::Array(elem) => bind_array_pg(q, table, col, elem, &val)?,
        SqlType::Inet | SqlType::Cidr => {
            let s = val.as_str().ok_or_else(|| mismatch(json_type_name(&val)))?;
            q.bind(IpNetwork::from_str(s).map_err(|_| mismatch("invalid network string"))?)
        }
        SqlType::MacAddr => {
            let s = val.as_str().ok_or_else(|| mismatch(json_type_name(&val)))?;
            q.bind(MacAddress::from_str(s).map_err(|_| mismatch("invalid macaddr string"))?)
        }
        SqlType::FullText => {
            let s = val.as_str().ok_or_else(|| mismatch(json_type_name(&val)))?;
            q.bind(TsVector::from(s))
        }
        SqlType::Xml | SqlType::Ltree | SqlType::Bit => {
            // Bound as a plain String: the dump holds their `::text` form
            // (see `column_list_pg_select`). NOTE(review): the parameter is
            // therefore text-typed; presumably the INSERT must cast it to the
            // column type — confirm the SQL built by the caller does.
            let s = val.as_str().ok_or_else(|| mismatch(json_type_name(&val)))?;
            q.bind(s.to_string())
        }
        SqlType::Bytes => q.bind(bytes_from_json(table, col, &val)?),
        SqlType::Decimal => {
            // Dumps write decimals as strings; bare JSON numbers are accepted
            // too. NOTE(review): a number rendered in exponent form (e.g.
            // `1e21`) presumably won't parse via `Decimal::from_str` — confirm.
            let parsed = match &val {
                Value::String(s) => Decimal::from_str(s).ok(),
                Value::Number(n) => Decimal::from_str(&n.to_string()).ok(),
                _ => None,
            };
            q.bind(parsed.ok_or_else(|| mismatch(json_type_name(&val)))?)
        }
    })
}
937
938fn bind_null_array_pg<'q>(q: PgQuery<'q>, elem: ArrayElement) -> PgQuery<'q> {
939 match elem {
940 ArrayElement::SmallInt => q.bind(None::<Vec<i16>>),
941 ArrayElement::Integer => q.bind(None::<Vec<i32>>),
942 ArrayElement::BigInt => q.bind(None::<Vec<i64>>),
943 ArrayElement::Real => q.bind(None::<Vec<f32>>),
944 ArrayElement::Double => q.bind(None::<Vec<f64>>),
945 ArrayElement::Boolean => q.bind(None::<Vec<bool>>),
946 ArrayElement::Text => q.bind(None::<Vec<String>>),
947 ArrayElement::Uuid => q.bind(None::<Vec<Uuid>>),
948 }
949}
950
/// Binds a non-null JSON array `val` as a Postgres array of `elem` values.
///
/// Integer elements are range-checked for `SmallInt`/`Integer`, `Real`
/// elements are narrowed from `f64` with `as f32`, and UUID elements are
/// parsed from strings.
///
/// # Errors
/// Returns [`BackupError::TypeMismatch`] if `val` is not an array or any
/// element has the wrong JSON type, is out of range, or fails to parse. The
/// first offending element in order wins.
fn bind_array_pg<'q>(
    q: PgQuery<'q>,
    table: &str,
    col: &Column,
    elem: ArrayElement,
    val: &Value,
) -> Result<PgQuery<'q>, BackupError> {
    Ok(match elem {
        ArrayElement::SmallInt => q.bind(
            int_array_from_json(table, col, val)?
                .into_iter()
                .map(|n| {
                    i16::try_from(n)
                        .map_err(|_| type_mismatch(table, col, "element out of i16 range"))
                })
                .collect::<Result<Vec<_>, _>>()?,
        ),
        ArrayElement::Integer => q.bind(
            int_array_from_json(table, col, val)?
                .into_iter()
                .map(|n| {
                    i32::try_from(n)
                        .map_err(|_| type_mismatch(table, col, "element out of i32 range"))
                })
                .collect::<Result<Vec<_>, _>>()?,
        ),
        ArrayElement::BigInt => q.bind(int_array_from_json(table, col, val)?),
        // Narrowing cast; dumps of `Real` arrays come from widened f32 values.
        ArrayElement::Real => q.bind(
            float_array_from_json(table, col, val)?
                .into_iter()
                .map(|n| n as f32)
                .collect::<Vec<_>>(),
        ),
        ArrayElement::Double => q.bind(float_array_from_json(table, col, val)?),
        ArrayElement::Boolean => q.bind(
            array_values(table, col, val)?
                .iter()
                .map(|v| {
                    v.as_bool()
                        .ok_or_else(|| type_mismatch(table, col, "non-boolean in array"))
                })
                .collect::<Result<Vec<_>, _>>()?,
        ),
        ArrayElement::Text => q.bind(
            array_values(table, col, val)?
                .iter()
                .map(|v| {
                    v.as_str()
                        .map(ToString::to_string)
                        .ok_or_else(|| type_mismatch(table, col, "non-string in array"))
                })
                .collect::<Result<Vec<_>, _>>()?,
        ),
        ArrayElement::Uuid => q.bind(
            array_values(table, col, val)?
                .iter()
                .map(|v| {
                    let s = v
                        .as_str()
                        .ok_or_else(|| type_mismatch(table, col, "non-string uuid in array"))?;
                    Uuid::parse_str(s)
                        .map_err(|_| type_mismatch(table, col, "invalid uuid string in array"))
                })
                .collect::<Result<Vec<_>, _>>()?,
        ),
    })
}
1018
1019fn array_values<'a>(
1020 table: &str,
1021 col: &Column,
1022 val: &'a Value,
1023) -> Result<&'a Vec<Value>, BackupError> {
1024 val.as_array()
1025 .ok_or_else(|| type_mismatch(table, col, json_type_name(val)))
1026}
1027
1028fn int_array_from_json(table: &str, col: &Column, val: &Value) -> Result<Vec<i64>, BackupError> {
1029 array_values(table, col, val)?
1030 .iter()
1031 .map(|v| {
1032 v.as_i64()
1033 .ok_or_else(|| type_mismatch(table, col, "non-integer in array"))
1034 })
1035 .collect()
1036}
1037
1038fn float_array_from_json(table: &str, col: &Column, val: &Value) -> Result<Vec<f64>, BackupError> {
1039 array_values(table, col, val)?
1040 .iter()
1041 .map(|v| {
1042 v.as_f64()
1043 .ok_or_else(|| type_mismatch(table, col, "non-number in array"))
1044 })
1045 .collect()
1046}
1047
1048fn bytes_from_json(table: &str, col: &Column, val: &Value) -> Result<Vec<u8>, BackupError> {
1049 let arr = val
1050 .as_array()
1051 .ok_or_else(|| type_mismatch(table, col, json_type_name(val)))?;
1052 let mut bytes: Vec<u8> = Vec::with_capacity(arr.len());
1053 for v in arr {
1054 let n = v
1055 .as_u64()
1056 .ok_or_else(|| type_mismatch(table, col, "non-number in bytes array"))?;
1057 if n > 255 {
1058 return Err(type_mismatch(table, col, "element out of u8 range"));
1059 }
1060 bytes.push(n as u8);
1061 }
1062 Ok(bytes)
1063}
1064
1065fn type_mismatch(table: &str, col: &Column, got: impl Into<String>) -> BackupError {
1066 BackupError::TypeMismatch {
1067 table: table.to_string(),
1068 column: col.name.clone(),
1069 expected: col.ty,
1070 got: got.into(),
1071 }
1072}
1073
1074fn json_type_name(v: &Value) -> &'static str {
1075 match v {
1076 Value::Null => "null",
1077 Value::Bool(_) => "boolean",
1078 Value::Number(_) => "number",
1079 Value::String(_) => "string",
1080 Value::Array(_) => "array",
1081 Value::Object(_) => "object",
1082 }
1083}
1084
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn placeholder_generation_matches_backend_syntax() {
        assert_eq!(sqlite_placeholders(3), "?, ?, ?");
        assert_eq!(postgres_placeholders(3), "$1, $2, $3");
    }

    /// Zero and one columns must not produce dangling separators.
    #[test]
    fn placeholder_generation_handles_small_counts() {
        assert_eq!(sqlite_placeholders(0), "");
        assert_eq!(sqlite_placeholders(1), "?");
        assert_eq!(postgres_placeholders(0), "");
        assert_eq!(postgres_placeholders(1), "$1");
        assert!(postgres_placeholders(11).ends_with("$9, $10, $11"));
    }

    #[test]
    fn quoted_ident_escapes_double_quotes() {
        assert_eq!(quoted_ident("plain"), "\"plain\"");
        assert_eq!(quoted_ident("weird\"name"), "\"weird\"\"name\"");
    }

    #[test]
    fn quoted_ident_handles_edge_cases() {
        assert_eq!(quoted_ident(""), "\"\"");
        assert_eq!(quoted_ident("\""), "\"\"\"\"");
    }

    /// JSON encoding helpers used by the dump side.
    #[test]
    fn json_helpers_encode_values_and_name_types() {
        assert_eq!(bytes_to_json(vec![0, 255]), serde_json::json!([0, 255]));
        assert_eq!(bytes_to_json(Vec::new()), serde_json::json!([]));
        assert_eq!(
            array_to_json(vec![1.5_f32], |v| Value::from(v as f64)),
            serde_json::json!([1.5])
        );
        assert_eq!(json_type_name(&Value::Null), "null");
        assert_eq!(json_type_name(&serde_json::json!(true)), "boolean");
        assert_eq!(json_type_name(&serde_json::json!({})), "object");
    }
}