1use std::collections::HashMap;
8use std::sync::Arc;
9
10use serde::{Deserialize, Serialize};
11use tokio::sync::RwLock;
12
13use crate::barrier::Barrier;
14use crate::error::DatabaseError;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct DatabaseConfig {
19 pub name: String,
21 pub plugin: String,
23 pub connection_url: String,
25 pub max_open_connections: u32,
27 pub allowed_roles: Vec<String>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct DatabaseRole {
34 pub name: String,
36 pub db_name: String,
38 pub creation_statements: Vec<String>,
40 pub revocation_statements: Vec<String>,
42 pub default_ttl_secs: i64,
44 pub max_ttl_secs: i64,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct DatabaseCredentials {
51 pub username: String,
53 pub password: String,
55}
56
57pub struct DatabaseEngine {
63 barrier: Arc<Barrier>,
64 prefix: String,
65 configs: RwLock<HashMap<String, DatabaseConfig>>,
67 roles: RwLock<HashMap<String, DatabaseRole>>,
69}
70
71impl DatabaseEngine {
72 pub fn new(barrier: Arc<Barrier>, prefix: String) -> Self {
78 Self {
79 barrier,
80 prefix,
81 configs: RwLock::new(HashMap::new()),
82 roles: RwLock::new(HashMap::new()),
83 }
84 }
85
86 fn config_key(&self, name: &str) -> String {
87 format!("{}config/{}", self.prefix, name)
88 }
89
90 fn role_key(&self, name: &str) -> String {
91 format!("{}roles/{}", self.prefix, name)
92 }
93
94 pub async fn configure(&self, config: DatabaseConfig) -> Result<(), DatabaseError> {
101 if config.name.is_empty() {
102 return Err(DatabaseError::InvalidConfig {
103 reason: "connection name is required".to_owned(),
104 });
105 }
106 if config.connection_url.is_empty() {
107 return Err(DatabaseError::InvalidConfig {
108 reason: "connection_url is required".to_owned(),
109 });
110 }
111 if config.plugin != "postgresql" && config.plugin != "mysql" {
112 return Err(DatabaseError::InvalidConfig {
113 reason: format!("unsupported plugin '{}', expected 'postgresql' or 'mysql'", config.plugin),
114 });
115 }
116
117 let data = serde_json::to_vec(&config).map_err(|e| DatabaseError::Internal {
118 reason: format!("serialization failed: {e}"),
119 })?;
120 self.barrier
121 .put(&self.config_key(&config.name), &data)
122 .await?;
123 self.configs.write().await.insert(config.name.clone(), config);
124 Ok(())
125 }
126
127 pub async fn get_config(&self, name: &str) -> Result<DatabaseConfig, DatabaseError> {
133 if let Some(cfg) = self.configs.read().await.get(name) {
135 return Ok(cfg.clone());
136 }
137 let data = self
138 .barrier
139 .get(&self.config_key(name))
140 .await?
141 .ok_or_else(|| DatabaseError::NotFound {
142 name: name.to_owned(),
143 })?;
144 let config: DatabaseConfig =
145 serde_json::from_slice(&data).map_err(|e| DatabaseError::Internal {
146 reason: format!("deserialization failed: {e}"),
147 })?;
148 self.configs.write().await.insert(name.to_owned(), config.clone());
149 Ok(config)
150 }
151
152 pub async fn delete_config(&self, name: &str) -> Result<(), DatabaseError> {
158 self.barrier.delete(&self.config_key(name)).await?;
159 self.configs.write().await.remove(name);
160 Ok(())
161 }
162
163 pub async fn list_configs(&self) -> Result<Vec<String>, DatabaseError> {
169 let prefix = format!("{}config/", self.prefix);
170 let keys = self.barrier.list(&prefix).await?;
171 Ok(keys
172 .into_iter()
173 .filter_map(|k| k.strip_prefix(&prefix).map(String::from))
174 .collect())
175 }
176
177 pub async fn create_role(&self, role: DatabaseRole) -> Result<(), DatabaseError> {
183 if role.name.is_empty() {
184 return Err(DatabaseError::InvalidConfig {
185 reason: "role name is required".to_owned(),
186 });
187 }
188 if role.db_name.is_empty() {
189 return Err(DatabaseError::InvalidConfig {
190 reason: "db_name is required".to_owned(),
191 });
192 }
193 if role.creation_statements.is_empty() {
194 return Err(DatabaseError::InvalidConfig {
195 reason: "creation_statements is required".to_owned(),
196 });
197 }
198 self.get_config(&role.db_name).await?;
200
201 let data = serde_json::to_vec(&role).map_err(|e| DatabaseError::Internal {
202 reason: format!("serialization failed: {e}"),
203 })?;
204 self.barrier.put(&self.role_key(&role.name), &data).await?;
205 self.roles.write().await.insert(role.name.clone(), role);
206 Ok(())
207 }
208
209 pub async fn get_role(&self, name: &str) -> Result<DatabaseRole, DatabaseError> {
215 if let Some(role) = self.roles.read().await.get(name) {
216 return Ok(role.clone());
217 }
218 let data = self
219 .barrier
220 .get(&self.role_key(name))
221 .await?
222 .ok_or_else(|| DatabaseError::RoleNotFound {
223 name: name.to_owned(),
224 })?;
225 let role: DatabaseRole =
226 serde_json::from_slice(&data).map_err(|e| DatabaseError::Internal {
227 reason: format!("deserialization failed: {e}"),
228 })?;
229 self.roles.write().await.insert(name.to_owned(), role.clone());
230 Ok(role)
231 }
232
233 pub async fn delete_role(&self, name: &str) -> Result<(), DatabaseError> {
239 self.barrier.delete(&self.role_key(name)).await?;
240 self.roles.write().await.remove(name);
241 Ok(())
242 }
243
244 pub async fn list_roles(&self) -> Result<Vec<String>, DatabaseError> {
250 let prefix = format!("{}roles/", self.prefix);
251 let keys = self.barrier.list(&prefix).await?;
252 Ok(keys
253 .into_iter()
254 .filter_map(|k| k.strip_prefix(&prefix).map(String::from))
255 .collect())
256 }
257
258 pub async fn generate_credentials(
270 &self,
271 role_name: &str,
272 ) -> Result<(DatabaseCredentials, DatabaseRole), DatabaseError> {
273 let role = self.get_role(role_name).await?;
274 let _config = self.get_config(&role.db_name).await?;
276
277 let username = format!("v-{}-{}", role_name, &uuid::Uuid::new_v4().to_string()[..8]);
278 let password = uuid::Uuid::new_v4().to_string().replace('-', "");
279
280 let creds = DatabaseCredentials { username, password };
281 Ok((creds, role))
282 }
283}