1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![doc = include_str!("../README.md")]
3
4use async_std::stream::StreamExt;
5use async_std::sync::{Arc as AsyncArc, RwLock as AsyncRwLock};
6use core::time::Duration;
7use melodium_core::{common::executive::ResultStatus, *};
8use melodium_macro::{check, mel_model, mel_package, mel_treatment};
9use sqlx::any::{AnyArguments, AnyRow};
10use sqlx::postgres::any::AnyTypeInfoKind;
11use sqlx::query::Query;
12use sqlx::Any;
13use sqlx::{any::AnyPoolOptions, AnyPool, Column, QueryBuilder, Row};
14use std::{
15 collections::HashMap,
16 sync::{Arc, Weak},
17};
18use std_mel::data::*;
19
/// Rewrites generic bind placeholders into PostgreSQL positional parameters.
///
/// Every non-overlapping occurrence of `bind_symbol` in `sql_to_bind`, scanned left to right,
/// is replaced by `$1`, `$2`, … in order. The rewrite is done in a single pass, so the inserted
/// `$N` text is never scanned again: symbols such as `$` or digits are numbered correctly.
/// An empty `bind_symbol` matches nothing and the SQL is returned unchanged.
///
/// Occurrences inside SQL string literals or comments are replaced too; callers must pick a
/// symbol that does not appear there.
fn postgres_bind_replace(sql_to_bind: String, bind_symbol: &str) -> String {
    if bind_symbol.is_empty() {
        return sql_to_bind;
    }

    // `split` yields the text between placeholders; a `$N` goes before every piece but the first.
    let mut parts = sql_to_bind.split(bind_symbol);
    let mut result = String::with_capacity(sql_to_bind.len() + 8);
    result.push_str(parts.next().unwrap_or_default());
    for (index, part) in parts.enumerate() {
        result.push('$');
        result.push_str(&(index + 1).to_string());
        result.push_str(part);
    }
    result
}
31
32fn bind_value<'q>(
33 query: Query<'q, Any, AnyArguments<'q>>,
34 value: &Value,
35) -> Query<'q, Any, AnyArguments<'q>> {
36 match value {
37 Value::Void(_) => query.bind(None::<bool>),
38 Value::I8(n) => query.bind(*n as i16),
39 Value::I16(n) => query.bind(*n),
40 Value::I32(n) => query.bind(*n as i32),
41 Value::I64(n) => query.bind(*n as i64),
42 Value::I128(n) => query.bind(*n as f64),
43 Value::U8(n) => query.bind(*n as i16),
44 Value::U16(n) => query.bind(*n as i32),
45 Value::U32(n) => query.bind(*n as i64),
46 Value::U64(n) => query.bind(*n as f64),
47 Value::U128(n) => query.bind(*n as f64),
48 Value::F32(n) => query.bind(*n),
49 Value::F64(n) => query.bind(*n),
50 Value::Bool(b) => query.bind(*b),
51 Value::Byte(n) => query.bind(vec![*n]),
52 Value::Char(c) => query.bind(c.to_string()),
53 Value::String(s) => query.bind(s.clone()),
54 Value::Vec(_) => query.bind(None::<bool>),
55 Value::Option(o) => match o {
56 None => query.bind(None::<bool>),
57 Some(v) => bind_value(query, v),
58 },
59 Value::Data(d) => {
60 if value
61 .datatype()
62 .implements(&melodium_core::common::descriptor::DataTrait::ToString)
63 {
64 query.bind(d.to_string())
65 } else {
66 query.bind(None::<bool>)
67 }
68 }
69 }
70}
71
72fn get_row_as_map(row: &AnyRow) -> Map {
73 let mut map = HashMap::with_capacity(row.len());
74 for column in row.columns() {
75 map.insert(
76 column.name().to_string(),
77 match column.type_info().kind() {
78 AnyTypeInfoKind::Null => Value::Option(None),
79 AnyTypeInfoKind::Bool => row
80 .try_get::<bool, _>(column.ordinal())
81 .map(|b| Value::Bool(b))
82 .unwrap_or_else(|_| Value::Option(None)),
83 AnyTypeInfoKind::SmallInt => row
84 .try_get::<i16, _>(column.ordinal())
85 .map(|n| Value::I16(n))
86 .unwrap_or_else(|_| Value::Option(None)),
87 AnyTypeInfoKind::Integer => row
88 .try_get::<i32, _>(column.ordinal())
89 .map(|n| Value::I32(n))
90 .unwrap_or_else(|_| Value::Option(None)),
91 AnyTypeInfoKind::BigInt => row
92 .try_get::<i64, _>(column.ordinal())
93 .map(|n| Value::I64(n))
94 .unwrap_or_else(|_| Value::Option(None)),
95 AnyTypeInfoKind::Real => row
96 .try_get::<f32, _>(column.ordinal())
97 .map(|n| Value::F32(n))
98 .unwrap_or_else(|_| Value::Option(None)),
99 AnyTypeInfoKind::Double => row
100 .try_get::<f64, _>(column.ordinal())
101 .map(|n| Value::F64(n))
102 .unwrap_or_else(|_| Value::Option(None)),
103 AnyTypeInfoKind::Text => row
104 .try_get::<String, _>(column.ordinal())
105 .map(|s| Value::String(s))
106 .unwrap_or_else(|_| Value::Option(None)),
107 AnyTypeInfoKind::Blob => row
108 .try_get::<Vec<u8>, _>(column.ordinal())
109 .map(|d| Value::Vec(d.into_iter().map(|v| Value::Byte(v)).collect()))
110 .unwrap_or_else(|_| Value::Option(None)),
111 },
112 );
113 }
114 Map::new_with(map)
115}
116
/// Pool of SQL connections used by the treatments of this package.
///
/// The pool is built from `url` with `connect_lazy`, so no connection is opened when the model
/// is initialized. `acquire_timeout`, `idle_timeout` and `max_lifetime` are expressed in
/// milliseconds. If the pool cannot be configured, the error message is emitted through the
/// `failure` source.
#[derive(Debug)]
#[mel_model(
    param url string none
    param max_connections u32 10
    param min_connections u32 0
    param acquire_timeout u64 10000
    param idle_timeout Option<u64> 600000
    param max_lifetime Option<u64> 1800000
    source failure () () (
        failure Block<string>
    )
    initialize initialize
    continuous (continuous)
    shutdown shutdown
)]
pub struct SqlPool {
    // Back-reference to the Mélodium model whose parameters configure this pool.
    model: Weak<SqlPoolModel>,
    // The configured pool; stays `None` if `initialize` failed or has not run yet.
    pool: AsyncRwLock<Option<AsyncArc<AnyPool>>>,
    // Error returned by pool configuration in `initialize`, if any.
    error: AsyncRwLock<Option<sqlx::Error>>,
}
137
138impl SqlPool {
139 fn new(model: Weak<SqlPoolModel>) -> Self {
140 Self {
141 model,
142 pool: AsyncRwLock::new(None),
143 error: AsyncRwLock::new(None),
144 }
145 }
146
147 fn initialize(&self) {
148 sqlx::any::install_default_drivers();
149
150 let model = self.model.upgrade().unwrap();
151
152 match AnyPoolOptions::new()
153 .max_connections(model.get_max_connections())
154 .min_connections(model.get_min_connections())
155 .acquire_timeout(Duration::from_millis(model.get_acquire_timeout()))
156 .idle_timeout(
157 model
158 .get_idle_timeout()
159 .map(|millis| Duration::from_millis(millis)),
160 )
161 .max_lifetime(
162 model
163 .get_max_lifetime()
164 .map(|millis| Duration::from_millis(millis)),
165 )
166 .connect_lazy(&model.get_url())
167 {
168 Ok(pool) => async_std::task::block_on(async {
169 *self.pool.write().await = Some(AsyncArc::new(pool));
170 }),
171 Err(error) => async_std::task::block_on(async {
172 *self.error.write().await = Some(error);
173 }),
174 }
175 }
176
177 async fn continuous(&self) {
178 if let Some(error) = self.error.read().await.as_ref() {
179 let model = self.model.upgrade().unwrap();
180 let error = error.to_string();
181 model
182 .new_failure(
183 None,
184 &HashMap::new(),
185 Some(Box::new(move |mut outputs| {
186 let failure = outputs.get("failure");
187 vec![Box::new(Box::pin(async move {
188 let _ = failure.send_one(Value::String(error)).await;
189 failure.close().await;
190 ResultStatus::Ok
191 }))]
192 })),
193 )
194 .await;
195 }
196 }
197
198 fn shutdown(&self) {
199 async_std::task::block_on(async {
200 if let Some(pool) = self.pool.read().await.as_ref() {
201 pool.close().await;
202 }
203 });
204 }
205
206 fn invoke_source(&self, _source: &str, _params: HashMap<String, Value>) {}
207
208 pub(crate) async fn pool(&self) -> Result<AsyncArc<AnyPool>, sqlx::Error> {
209 match self.pool.read().await.as_ref() {
210 Some(pool) => Ok(AsyncArc::clone(pool)),
211 None => Err(sqlx::Error::PoolClosed),
212 }
213 }
214}
215
216#[mel_treatment(
217 input trigger Block<void>
218 output affected Block<u64>
219 output failure Block<string>
220 model pool SqlPool
221)]
222pub async fn execute_raw(sql: string) {
223 match SqlPoolModel::into(pool).inner().pool().await {
224 Ok(pool) => match sqlx::raw_sql(&sql).execute(&*pool).await {
225 Ok(result) => {
226 let _ = affected.send_one(Value::U64(result.rows_affected())).await;
227 }
228 Err(error) => {
229 let _ = failure.send_one(error.to_string().into()).await;
230 }
231 },
232 Err(error) => {
233 let _ = failure.send_one(error.to_string().into()).await;
234 }
235 }
236}
237
238#[mel_treatment(
239 input bind Block<Map>
240 output affected Block<u64>
241 output failure Block<string>
242 default bind_symbol "?"
243 model pool SqlPool
244)]
245pub async fn execute(sql: string, bindings: Vec<string>, bind_symbol: string) {
246 if let Ok(bind) = bind.recv_one().await.map(|val| {
247 GetData::<Arc<dyn Data>>::try_data(val)
248 .unwrap()
249 .downcast_arc::<Map>()
250 .unwrap()
251 }) {
252 match SqlPoolModel::into(pool).inner().pool().await {
253 Ok(pool) => {
254 let sql = match pool.connect_options().database_url.scheme() {
255 "postgres" => postgres_bind_replace(sql, &bind_symbol),
256 _ => sql,
257 };
258 let mut query = sqlx::query(&sql);
259
260 for binding in &bindings {
261 if let Some(val) = bind.map.get(binding) {
262 query = bind_value(query, val);
263 } else {
264 query = query.bind(None::<bool>);
265 }
266 }
267
268 match query.execute(&*pool).await {
269 Ok(result) => {
270 let _ = affected.send_one(Value::U64(result.rows_affected())).await;
271 }
272 Err(error) => {
273 let _ = failure.send_one(error.to_string().into()).await;
274 }
275 }
276 }
277 Err(error) => {
278 let _ = failure.send_one(error.to_string().into()).await;
279 }
280 }
281 }
282}
283
284#[mel_treatment(
285 input bind Stream<Map>
286 output affected Stream<u64>
287 output failure Stream<string>
288 default bind_symbol "?"
289 default stop_on_failure true
290 model pool SqlPool
291)]
292pub async fn execute_each(
293 sql: string,
294 bindings: Vec<string>,
295 bind_symbol: string,
296 stop_on_failure: bool,
297) {
298 match SqlPoolModel::into(pool).inner().pool().await {
299 Ok(pool) => {
300 while let Ok(bind) = bind.recv_one().await.map(|val| {
301 GetData::<Arc<dyn Data>>::try_data(val)
302 .unwrap()
303 .downcast_arc::<Map>()
304 .unwrap()
305 }) {
306 let sql = match pool.connect_options().database_url.scheme() {
307 "postgres" => postgres_bind_replace(sql.clone(), &bind_symbol),
308 _ => sql.clone(),
309 };
310 let mut query = sqlx::query(&sql);
311
312 for binding in &bindings {
313 if let Some(val) = bind.map.get(binding) {
314 query = bind_value(query, val);
315 } else {
316 query = query.bind(None::<bool>);
317 }
318 }
319
320 match query.execute(&*pool).await {
321 Ok(result) => {
322 let _ = affected.send_one(Value::U64(result.rows_affected())).await;
323 }
324 Err(error) => {
325 let _ = failure.send_one(error.to_string().into()).await;
326 if stop_on_failure {
327 break;
328 }
329 }
330 }
331 }
332 }
333 Err(error) => {
334 let _ = failure.send_one(error.to_string().into()).await;
335 }
336 }
337}
338
339#[mel_treatment(
340 default separator ", "
341 default stop_on_failure true
342 default bind_limit 65535
343 default bind_symbol "?"
344 input bind Stream<Map>
345 output affected Stream<u64>
346 output failure Stream<string>
347 model pool SqlPool
348)]
349pub async fn execute_batch(
350 base: string,
351 batch: string,
352 bindings: Vec<string>,
353 bind_symbol: string,
354 bind_limit: u64,
355 separator: string,
356 stop_on_failure: bool,
357) {
358 let limit = bind_limit.min(65535);
359 let batch_max = limit / bindings.len() as u64;
360
361 match SqlPoolModel::into(pool).inner().pool().await {
362 Ok(pool) => 'main: loop {
363 let mut query_builder = QueryBuilder::new(base.as_str());
364
365 let mut full_batch = Vec::with_capacity(batch_max as usize);
366 for _ in 0..batch_max {
367 if let Ok(bind) = bind.recv_one().await.map(|val| {
368 GetData::<Arc<dyn Data>>::try_data(val)
369 .unwrap()
370 .downcast_arc::<Map>()
371 .unwrap()
372 }) {
373 full_batch.push(bind);
374 } else {
375 break;
376 }
377 }
378
379 if full_batch.is_empty() {
380 break;
381 }
382
383 let mut query = query_builder
384 .push({
385 let batch = std::iter::repeat(batch.as_str())
386 .take(full_batch.len())
387 .collect::<Vec<_>>()
388 .join(&separator);
389 match pool.connect_options().database_url.scheme() {
390 "postgres" => postgres_bind_replace(batch, &bind_symbol),
391 _ => batch,
392 }
393 })
394 .build();
395
396 for b in full_batch {
397 for binding in &bindings {
398 if let Some(val) = b.map.get(binding) {
399 query = bind_value(query, val);
400 } else {
401 query = query.bind(None::<bool>);
402 }
403 }
404 }
405
406 match query.execute(&*pool).await {
407 Ok(result) => {
408 let _ = affected.send_one(Value::U64(result.rows_affected())).await;
409 }
410 Err(error) => {
411 let _ = failure.send_one(error.to_string().into()).await;
412 if stop_on_failure {
413 break 'main;
414 }
415 }
416 }
417 },
418 Err(error) => {
419 let _ = failure.send_one(error.to_string().into()).await;
420 }
421 }
422}
423
424#[mel_treatment(
425 input bind Block<Map>
426 output data Stream<Map>
427 output failure Block<string>
428 default bind_symbol "?"
429 model pool SqlPool
430)]
431pub async fn fetch(sql: string, bindings: Vec<string>, bind_symbol: string) {
432 if let Ok(bind) = bind.recv_one().await.map(|val| {
433 GetData::<Arc<dyn Data>>::try_data(val)
434 .unwrap()
435 .downcast_arc::<Map>()
436 .unwrap()
437 }) {
438 match SqlPoolModel::into(pool).inner().pool().await {
439 Ok(pool) => {
440 let sql = match pool.connect_options().database_url.scheme() {
441 "postgres" => postgres_bind_replace(sql, &bind_symbol),
442 _ => sql,
443 };
444 let mut query = sqlx::query(&sql);
445
446 for binding in &bindings {
447 if let Some(val) = bind.map.get(binding) {
448 query = bind_value(query, val);
449 } else {
450 query = query.bind(None::<bool>);
451 }
452 }
453
454 let mut stream = query.fetch(&*pool);
455 while let Some(row) = stream.next().await {
456 match row {
457 Ok(row) => {
458 let map = get_row_as_map(&row);
459 check!(
460 data.send_one(Value::Data(Arc::new(map) as Arc<dyn Data>))
461 .await
462 )
463 }
464 Err(error) => {
465 let _ = failure.send_one(error.to_string().into()).await;
466 break;
467 }
468 }
469 }
470 }
471 Err(error) => {
472 let _ = failure.send_one(error.to_string().into()).await;
473 }
474 }
475 }
476}
477
478#[mel_treatment(
479 default separator ", "
480 default stop_on_failure true
481 default bind_limit 65535
482 default bind_symbol "?"
483 input bind Stream<Map>
484 output data Stream<Map>
485 output failure Stream<string>
486 model pool SqlPool
487)]
488pub async fn fetch_batch(
489 base: string,
490 batch: string,
491 bindings: Vec<string>,
492 bind_limit: u64,
493 bind_symbol: string,
494 separator: string,
495 stop_on_failure: bool,
496) {
497 let limit = bind_limit.min(65535);
498 let batch_max = limit / bindings.len() as u64;
499
500 match SqlPoolModel::into(pool).inner().pool().await {
501 Ok(pool) => 'main: loop {
502 let mut query_builder = QueryBuilder::new(base.as_str());
503
504 let mut full_batch = Vec::with_capacity(batch_max as usize);
505 for _ in 0..batch_max {
506 if let Ok(bind) = bind.recv_one().await.map(|val| {
507 GetData::<Arc<dyn Data>>::try_data(val)
508 .unwrap()
509 .downcast_arc::<Map>()
510 .unwrap()
511 }) {
512 full_batch.push(bind);
513 } else {
514 break;
515 }
516 }
517
518 if full_batch.is_empty() {
519 break;
520 }
521
522 let mut query = query_builder
523 .push({
524 let batch = std::iter::repeat(batch.as_str())
525 .take(full_batch.len())
526 .collect::<Vec<_>>()
527 .join(&separator);
528 match pool.connect_options().database_url.scheme() {
529 "postgres" => postgres_bind_replace(batch, &bind_symbol),
530 _ => batch,
531 }
532 })
533 .build();
534
535 for b in full_batch {
536 for binding in &bindings {
537 if let Some(val) = b.map.get(binding) {
538 query = bind_value(query, val);
539 } else {
540 query = query.bind(None::<bool>);
541 }
542 }
543 }
544
545 let mut stream = query.fetch(&*pool);
546 'result: while let Some(row) = stream.next().await {
547 match row {
548 Ok(row) => {
549 let map = get_row_as_map(&row);
550
551 let _ = data
552 .send_one(Value::Data(Arc::new(map) as Arc<dyn Data>))
553 .await;
554 }
555 Err(error) => {
556 let _ = failure.send_one(error.to_string().into()).await;
557 if stop_on_failure {
558 break 'main;
559 } else {
560 break 'result;
561 }
562 }
563 }
564 }
565 },
566 Err(error) => {
567 let _ = failure.send_one(error.to_string().into()).await;
568 }
569 }
570}
571
// Package declaration macro from `melodium_macro`; presumably it gathers the `mel_model` and
// `mel_treatment` items defined in this crate into the Mélodium package (see melodium_macro docs).
mel_package!();