1use std::path::PathBuf;
16use std::str::FromStr;
17use std::time::Duration;
18
19use futures_core::future::BoxFuture;
20use log::LevelFilter;
21use sqlx_core::connection::{ConnectOptions, LogSettings};
22use sqlx_core::error::Error;
23
24use crate::connection::SpgConnection;
25
26#[derive(Debug, Clone)]
35pub struct SpgConnectOptions {
36 pub path: Option<PathBuf>,
38 pub log_settings: LogSettings,
42 pub(crate) shared: std::sync::Arc<tokio::sync::OnceCell<spg_embedded_tokio::AsyncDatabase>>,
48}
49
50impl Default for SpgConnectOptions {
51 fn default() -> Self {
52 Self {
53 path: None,
54 log_settings: LogSettings::default(),
55 shared: std::sync::Arc::new(tokio::sync::OnceCell::new()),
56 }
57 }
58}
59
60impl SpgConnectOptions {
61 #[must_use]
63 pub fn in_memory() -> Self {
64 Self::default()
65 }
66
67 #[must_use]
69 pub fn file(path: impl Into<PathBuf>) -> Self {
70 Self {
71 path: Some(path.into()),
72 log_settings: LogSettings::default(),
73 shared: std::sync::Arc::new(tokio::sync::OnceCell::new()),
74 }
75 }
76}
77
78impl FromStr for SpgConnectOptions {
79 type Err = Error;
80
81 fn from_str(s: &str) -> Result<Self, Error> {
82 let rest = s
85 .strip_prefix("spg://")
86 .or_else(|| s.strip_prefix("spg:"))
87 .unwrap_or(s);
88 if rest.is_empty() || rest.eq_ignore_ascii_case("memory") {
89 return Ok(Self::in_memory());
90 }
91 Ok(Self::file(rest))
92 }
93}
94
95impl ConnectOptions for SpgConnectOptions {
96 type Connection = SpgConnection;
97
98 fn from_url(url: &sqlx_core::Url) -> Result<Self, Error> {
99 if url.scheme() != "spg" {
103 return Err(Error::Configuration(
104 format!("expected spg:// scheme, got {:?}", url.scheme()).into(),
105 ));
106 }
107 let host = url.host_str().unwrap_or("");
108 let path = url.path();
109 let combined = match (host, path) {
110 ("", "") | ("", "/") => String::new(),
111 ("", p) => p.to_string(),
112 (h, "") | (h, "/") => h.to_string(),
113 (h, p) => format!("{h}{p}"),
114 };
115 SpgConnectOptions::from_str(&combined)
116 }
117
118 fn connect(&self) -> BoxFuture<'_, Result<SpgConnection, Error>> {
119 let path = self.path.clone();
120 let shared = std::sync::Arc::clone(&self.shared);
121 Box::pin(async move {
122 let inner = shared
123 .get_or_try_init(|| async {
124 match path {
125 None => Ok::<_, Error>(spg_embedded_tokio::AsyncDatabase::open_in_memory()),
126 Some(p) => spg_embedded_tokio::AsyncDatabase::open_path(p)
127 .await
128 .map_err(crate::error::engine_to_sqlx),
129 }
130 })
131 .await?
132 .clone();
133 Ok(SpgConnection::new(inner))
134 })
135 }
136
137 fn log_statements(mut self, level: LevelFilter) -> Self {
138 self.log_settings.log_statements(level);
139 self
140 }
141
142 fn log_slow_statements(mut self, level: LevelFilter, duration: Duration) -> Self {
143 self.log_settings.log_slow_statements(level, duration);
144 self
145 }
146}