1use std::path::PathBuf;
16use std::str::FromStr;
17use std::time::Duration;
18
19use futures_core::future::BoxFuture;
20use log::LevelFilter;
21use sqlx_core::connection::{ConnectOptions, LogSettings};
22use sqlx_core::error::Error;
23
24use crate::connection::SpgConnection;
25
/// Connection options for an embedded SPG database.
///
/// Built with [`SpgConnectOptions::in_memory`] / [`SpgConnectOptions::file`], or parsed from a
/// connection string via the `FromStr` and `ConnectOptions::from_url` impls.
///
/// Cloning is cheap and clones are *not* independent: they share the `shared` cell, so every
/// connection created from any clone talks to the same opened database.
#[derive(Debug, Clone)]
pub struct SpgConnectOptions {
    /// Location of the database file; `None` selects an in-memory database.
    pub path: Option<PathBuf>,
    /// Statement-logging configuration, updated by `ConnectOptions::log_statements` and
    /// `ConnectOptions::log_slow_statements`.
    pub log_settings: LogSettings,
    /// Lazily opened database handle shared by all connections made from these options.
    ///
    /// Filled by the first successful `connect`; because it sits behind an `Arc`, clones of the
    /// options observe the same cell.
    pub(crate) shared: std::sync::Arc<tokio::sync::OnceCell<spg_embedded_tokio::AsyncDatabase>>,
}
49
50impl Default for SpgConnectOptions {
51 fn default() -> Self {
52 Self {
53 path: None,
54 log_settings: LogSettings::default(),
55 shared: std::sync::Arc::new(tokio::sync::OnceCell::new()),
56 }
57 }
58}
59
60impl SpgConnectOptions {
61 #[must_use]
63 pub fn in_memory() -> Self {
64 Self::default()
65 }
66
67 #[must_use]
69 pub fn file(path: impl Into<PathBuf>) -> Self {
70 Self {
71 path: Some(path.into()),
72 log_settings: LogSettings::default(),
73 shared: std::sync::Arc::new(tokio::sync::OnceCell::new()),
74 }
75 }
76}
77
78impl FromStr for SpgConnectOptions {
79 type Err = Error;
80
81 fn from_str(s: &str) -> Result<Self, Error> {
82 let rest = s
85 .strip_prefix("spg://")
86 .or_else(|| s.strip_prefix("spg:"))
87 .unwrap_or(s);
88 if rest.is_empty() || rest.eq_ignore_ascii_case("memory") {
89 return Ok(Self::in_memory());
90 }
91 Ok(Self::file(rest))
92 }
93}
94
95impl ConnectOptions for SpgConnectOptions {
96 type Connection = SpgConnection;
97
98 fn from_url(url: &sqlx_core::Url) -> Result<Self, Error> {
99 if url.scheme() != "spg" {
103 return Err(Error::Configuration(
104 format!("expected spg:// scheme, got {:?}", url.scheme()).into(),
105 ));
106 }
107 let host = url.host_str().unwrap_or("");
108 let path = url.path();
109 let combined = match (host, path) {
110 ("", "") | ("", "/") => String::new(),
111 ("", p) => p.to_string(),
112 (h, "") | (h, "/") => h.to_string(),
113 (h, p) => format!("{h}{p}"),
114 };
115 SpgConnectOptions::from_str(&combined)
116 }
117
118 fn connect(&self) -> BoxFuture<'_, Result<SpgConnection, Error>> {
119 let path = self.path.clone();
120 let shared = std::sync::Arc::clone(&self.shared);
121 Box::pin(async move {
122 let inner = shared
123 .get_or_try_init(|| async {
124 match path {
125 None => Ok::<_, Error>(
126 spg_embedded_tokio::AsyncDatabase::open_in_memory(),
127 ),
128 Some(p) => spg_embedded_tokio::AsyncDatabase::open_path(p)
129 .await
130 .map_err(crate::error::engine_to_sqlx),
131 }
132 })
133 .await?
134 .clone();
135 Ok(SpgConnection::new(inner))
136 })
137 }
138
139 fn log_statements(mut self, level: LevelFilter) -> Self {
140 self.log_settings.log_statements(level);
141 self
142 }
143
144 fn log_slow_statements(mut self, level: LevelFilter, duration: Duration) -> Self {
145 self.log_settings.log_slow_statements(level, duration);
146 self
147 }
148}