sql_mel/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![doc = include_str!("../README.md")]
3
4use async_std::stream::StreamExt;
5use async_std::sync::{Arc as AsyncArc, RwLock as AsyncRwLock};
6use core::time::Duration;
7use melodium_core::{common::executive::ResultStatus, *};
8use melodium_macro::{check, mel_model, mel_package, mel_treatment};
9use sqlx::any::{AnyArguments, AnyRow};
10use sqlx::postgres::any::AnyTypeInfoKind;
11use sqlx::query::Query;
12use sqlx::Any;
13use sqlx::{any::AnyPoolOptions, AnyPool, Column, QueryBuilder, Row};
14use std::{
15    collections::HashMap,
16    sync::{Arc, Weak},
17};
18use std_mel::data::*;
19
/// Rewrites every occurrence of `bind_symbol` in `sql_to_bind` into PostgreSQL's
/// positional placeholders `$1`, `$2`, … in left-to-right order.
///
/// The original text is scanned once, so the generated `$n` markers are never
/// rescanned. This keeps the rewrite linear and correct even when `bind_symbol`
/// itself contains `$` or digits. An empty `bind_symbol` matches nothing and the
/// SQL is returned unchanged.
///
/// The replacement is purely textual: occurrences inside SQL string literals or
/// comments are rewritten as well.
fn postgres_bind_replace(sql_to_bind: String, bind_symbol: &str) -> String {
    if bind_symbol.is_empty() {
        return sql_to_bind;
    }

    let mut pieces = sql_to_bind.split(bind_symbol);
    let mut bound = String::with_capacity(sql_to_bind.len());
    // `split` always yields at least one piece: the text before the first symbol.
    bound.push_str(pieces.next().unwrap_or_default());
    for (index, piece) in pieces.enumerate() {
        bound.push_str(&format!("${}", index + 1));
        bound.push_str(piece);
    }
    bound
}
31
32fn bind_value<'q>(
33    query: Query<'q, Any, AnyArguments<'q>>,
34    value: &Value,
35) -> Query<'q, Any, AnyArguments<'q>> {
36    match value {
37        Value::Void(_) => query.bind(None::<bool>),
38        Value::I8(n) => query.bind(*n as i16),
39        Value::I16(n) => query.bind(*n),
40        Value::I32(n) => query.bind(*n as i32),
41        Value::I64(n) => query.bind(*n as i64),
42        Value::I128(n) => query.bind(*n as f64),
43        Value::U8(n) => query.bind(*n as i16),
44        Value::U16(n) => query.bind(*n as i32),
45        Value::U32(n) => query.bind(*n as i64),
46        Value::U64(n) => query.bind(*n as f64),
47        Value::U128(n) => query.bind(*n as f64),
48        Value::F32(n) => query.bind(*n),
49        Value::F64(n) => query.bind(*n),
50        Value::Bool(b) => query.bind(*b),
51        Value::Byte(n) => query.bind(vec![*n]),
52        Value::Char(c) => query.bind(c.to_string()),
53        Value::String(s) => query.bind(s.clone()),
54        Value::Vec(_) => query.bind(None::<bool>),
55        Value::Option(o) => match o {
56            None => query.bind(None::<bool>),
57            Some(v) => bind_value(query, v),
58        },
59        Value::Data(d) => {
60            if value
61                .datatype()
62                .implements(&melodium_core::common::descriptor::DataTrait::ToString)
63            {
64                query.bind(d.to_string())
65            } else {
66                query.bind(None::<bool>)
67            }
68        }
69    }
70}
71
72fn get_row_as_map(row: &AnyRow) -> Map {
73    let mut map = HashMap::with_capacity(row.len());
74    for column in row.columns() {
75        map.insert(
76            column.name().to_string(),
77            match column.type_info().kind() {
78                AnyTypeInfoKind::Null => Value::Option(None),
79                AnyTypeInfoKind::Bool => row
80                    .try_get::<bool, _>(column.ordinal())
81                    .map(|b| Value::Bool(b))
82                    .unwrap_or_else(|_| Value::Option(None)),
83                AnyTypeInfoKind::SmallInt => row
84                    .try_get::<i16, _>(column.ordinal())
85                    .map(|n| Value::I16(n))
86                    .unwrap_or_else(|_| Value::Option(None)),
87                AnyTypeInfoKind::Integer => row
88                    .try_get::<i32, _>(column.ordinal())
89                    .map(|n| Value::I32(n))
90                    .unwrap_or_else(|_| Value::Option(None)),
91                AnyTypeInfoKind::BigInt => row
92                    .try_get::<i64, _>(column.ordinal())
93                    .map(|n| Value::I64(n))
94                    .unwrap_or_else(|_| Value::Option(None)),
95                AnyTypeInfoKind::Real => row
96                    .try_get::<f32, _>(column.ordinal())
97                    .map(|n| Value::F32(n))
98                    .unwrap_or_else(|_| Value::Option(None)),
99                AnyTypeInfoKind::Double => row
100                    .try_get::<f64, _>(column.ordinal())
101                    .map(|n| Value::F64(n))
102                    .unwrap_or_else(|_| Value::Option(None)),
103                AnyTypeInfoKind::Text => row
104                    .try_get::<String, _>(column.ordinal())
105                    .map(|s| Value::String(s))
106                    .unwrap_or_else(|_| Value::Option(None)),
107                AnyTypeInfoKind::Blob => row
108                    .try_get::<Vec<u8>, _>(column.ordinal())
109                    .map(|d| Value::Vec(d.into_iter().map(|v| Value::Byte(v)).collect()))
110                    .unwrap_or_else(|_| Value::Option(None)),
111            },
112        );
113    }
114    Map::new_with(map)
115}
116
/// SQL connection pool model.
///
/// Parameters:
/// - `url`: connection URL.
/// - `max_connections`: default 10.
/// - `min_connections`: default 0.
/// - `acquire_timeout`: milliseconds, default 10000.
/// - `idle_timeout`: optional, milliseconds, default 600000.
/// - `max_lifetime`: optional, milliseconds, default 1800000.
///
/// If the pool cannot be created from these parameters, the error message is
/// sent once through the `failure` source.
#[derive(Debug)]
#[mel_model(
    param url string none
    param max_connections u32 10
    param min_connections u32 0
    param acquire_timeout u64 10000
    param idle_timeout Option<u64> 600000
    param max_lifetime Option<u64> 1800000
    source failure () () (
        failure Block<string>
    )
    initialize initialize
    continuous (continuous)
    shutdown shutdown
)]
pub struct SqlPool {
    /// Link back to the generated model, used to read parameters and emit the `failure` source.
    model: Weak<SqlPoolModel>,
    /// Pool created during initialization; stays `None` if creation failed.
    pool: AsyncRwLock<Option<AsyncArc<AnyPool>>>,
    /// Error returned while creating the pool, if any.
    error: AsyncRwLock<Option<sqlx::Error>>,
}
137
138impl SqlPool {
139    fn new(model: Weak<SqlPoolModel>) -> Self {
140        Self {
141            model,
142            pool: AsyncRwLock::new(None),
143            error: AsyncRwLock::new(None),
144        }
145    }
146
147    fn initialize(&self) {
148        sqlx::any::install_default_drivers();
149
150        let model = self.model.upgrade().unwrap();
151
152        match AnyPoolOptions::new()
153            .max_connections(model.get_max_connections())
154            .min_connections(model.get_min_connections())
155            .acquire_timeout(Duration::from_millis(model.get_acquire_timeout()))
156            .idle_timeout(
157                model
158                    .get_idle_timeout()
159                    .map(|millis| Duration::from_millis(millis)),
160            )
161            .max_lifetime(
162                model
163                    .get_max_lifetime()
164                    .map(|millis| Duration::from_millis(millis)),
165            )
166            .connect_lazy(&model.get_url())
167        {
168            Ok(pool) => async_std::task::block_on(async {
169                *self.pool.write().await = Some(AsyncArc::new(pool));
170            }),
171            Err(error) => async_std::task::block_on(async {
172                *self.error.write().await = Some(error);
173            }),
174        }
175    }
176
177    async fn continuous(&self) {
178        if let Some(error) = self.error.read().await.as_ref() {
179            let model = self.model.upgrade().unwrap();
180            let error = error.to_string();
181            model
182                .new_failure(
183                    None,
184                    &HashMap::new(),
185                    Some(Box::new(move |mut outputs| {
186                        let failure = outputs.get("failure");
187                        vec![Box::new(Box::pin(async move {
188                            let _ = failure.send_one(Value::String(error)).await;
189                            failure.close().await;
190                            ResultStatus::Ok
191                        }))]
192                    })),
193                )
194                .await;
195        }
196    }
197
198    fn shutdown(&self) {
199        async_std::task::block_on(async {
200            if let Some(pool) = self.pool.read().await.as_ref() {
201                pool.close().await;
202            }
203        });
204    }
205
206    fn invoke_source(&self, _source: &str, _params: HashMap<String, Value>) {}
207
208    pub(crate) async fn pool(&self) -> Result<AsyncArc<AnyPool>, sqlx::Error> {
209        match self.pool.read().await.as_ref() {
210            Some(pool) => Ok(AsyncArc::clone(pool)),
211            None => Err(sqlx::Error::PoolClosed),
212        }
213    }
214}
215
216#[mel_treatment(
217    input trigger Block<void>
218    output affected Block<u64>
219    output failure Block<string>
220    model pool SqlPool
221)]
222pub async fn execute_raw(sql: string) {
223    match SqlPoolModel::into(pool).inner().pool().await {
224        Ok(pool) => match sqlx::raw_sql(&sql).execute(&*pool).await {
225            Ok(result) => {
226                let _ = affected.send_one(Value::U64(result.rows_affected())).await;
227            }
228            Err(error) => {
229                let _ = failure.send_one(error.to_string().into()).await;
230            }
231        },
232        Err(error) => {
233            let _ = failure.send_one(error.to_string().into()).await;
234        }
235    }
236}
237
238#[mel_treatment(
239    input bind Block<Map>
240    output affected Block<u64>
241    output failure Block<string>
242    default bind_symbol "?"
243    model pool SqlPool
244)]
245pub async fn execute(sql: string, bindings: Vec<string>, bind_symbol: string) {
246    if let Ok(bind) = bind.recv_one().await.map(|val| {
247        GetData::<Arc<dyn Data>>::try_data(val)
248            .unwrap()
249            .downcast_arc::<Map>()
250            .unwrap()
251    }) {
252        match SqlPoolModel::into(pool).inner().pool().await {
253            Ok(pool) => {
254                let sql = match pool.connect_options().database_url.scheme() {
255                    "postgres" => postgres_bind_replace(sql, &bind_symbol),
256                    _ => sql,
257                };
258                let mut query = sqlx::query(&sql);
259
260                for binding in &bindings {
261                    if let Some(val) = bind.map.get(binding) {
262                        query = bind_value(query, val);
263                    } else {
264                        query = query.bind(None::<bool>);
265                    }
266                }
267
268                match query.execute(&*pool).await {
269                    Ok(result) => {
270                        let _ = affected.send_one(Value::U64(result.rows_affected())).await;
271                    }
272                    Err(error) => {
273                        let _ = failure.send_one(error.to_string().into()).await;
274                    }
275                }
276            }
277            Err(error) => {
278                let _ = failure.send_one(error.to_string().into()).await;
279            }
280        }
281    }
282}
283
284#[mel_treatment(
285    input bind Stream<Map>
286    output affected Stream<u64>
287    output failure Stream<string>
288    default bind_symbol "?"
289    default stop_on_failure true
290    model pool SqlPool
291)]
292pub async fn execute_each(
293    sql: string,
294    bindings: Vec<string>,
295    bind_symbol: string,
296    stop_on_failure: bool,
297) {
298    match SqlPoolModel::into(pool).inner().pool().await {
299        Ok(pool) => {
300            while let Ok(bind) = bind.recv_one().await.map(|val| {
301                GetData::<Arc<dyn Data>>::try_data(val)
302                    .unwrap()
303                    .downcast_arc::<Map>()
304                    .unwrap()
305            }) {
306                let sql = match pool.connect_options().database_url.scheme() {
307                    "postgres" => postgres_bind_replace(sql.clone(), &bind_symbol),
308                    _ => sql.clone(),
309                };
310                let mut query = sqlx::query(&sql);
311
312                for binding in &bindings {
313                    if let Some(val) = bind.map.get(binding) {
314                        query = bind_value(query, val);
315                    } else {
316                        query = query.bind(None::<bool>);
317                    }
318                }
319
320                match query.execute(&*pool).await {
321                    Ok(result) => {
322                        let _ = affected.send_one(Value::U64(result.rows_affected())).await;
323                    }
324                    Err(error) => {
325                        let _ = failure.send_one(error.to_string().into()).await;
326                        if stop_on_failure {
327                            break;
328                        }
329                    }
330                }
331            }
332        }
333        Err(error) => {
334            let _ = failure.send_one(error.to_string().into()).await;
335        }
336    }
337}
338
339#[mel_treatment(
340    default separator ", "
341    default stop_on_failure true
342    default bind_limit 65535
343    default bind_symbol "?"
344    input bind Stream<Map>
345    output affected Stream<u64>
346    output failure Stream<string>
347    model pool SqlPool
348)]
349pub async fn execute_batch(
350    base: string,
351    batch: string,
352    bindings: Vec<string>,
353    bind_symbol: string,
354    bind_limit: u64,
355    separator: string,
356    stop_on_failure: bool,
357) {
358    let limit = bind_limit.min(65535);
359    let batch_max = limit / bindings.len() as u64;
360
361    match SqlPoolModel::into(pool).inner().pool().await {
362        Ok(pool) => 'main: loop {
363            let mut query_builder = QueryBuilder::new(base.as_str());
364
365            let mut full_batch = Vec::with_capacity(batch_max as usize);
366            for _ in 0..batch_max {
367                if let Ok(bind) = bind.recv_one().await.map(|val| {
368                    GetData::<Arc<dyn Data>>::try_data(val)
369                        .unwrap()
370                        .downcast_arc::<Map>()
371                        .unwrap()
372                }) {
373                    full_batch.push(bind);
374                } else {
375                    break;
376                }
377            }
378
379            if full_batch.is_empty() {
380                break;
381            }
382
383            let mut query = query_builder
384                .push({
385                    let batch = std::iter::repeat(batch.as_str())
386                        .take(full_batch.len())
387                        .collect::<Vec<_>>()
388                        .join(&separator);
389                    match pool.connect_options().database_url.scheme() {
390                        "postgres" => postgres_bind_replace(batch, &bind_symbol),
391                        _ => batch,
392                    }
393                })
394                .build();
395
396            for b in full_batch {
397                for binding in &bindings {
398                    if let Some(val) = b.map.get(binding) {
399                        query = bind_value(query, val);
400                    } else {
401                        query = query.bind(None::<bool>);
402                    }
403                }
404            }
405
406            match query.execute(&*pool).await {
407                Ok(result) => {
408                    let _ = affected.send_one(Value::U64(result.rows_affected())).await;
409                }
410                Err(error) => {
411                    let _ = failure.send_one(error.to_string().into()).await;
412                    if stop_on_failure {
413                        break 'main;
414                    }
415                }
416            }
417        },
418        Err(error) => {
419            let _ = failure.send_one(error.to_string().into()).await;
420        }
421    }
422}
423
424#[mel_treatment(
425    input bind Block<Map>
426    output data Stream<Map>
427    output failure Block<string>
428    default bind_symbol "?"
429    model pool SqlPool
430)]
431pub async fn fetch(sql: string, bindings: Vec<string>, bind_symbol: string) {
432    if let Ok(bind) = bind.recv_one().await.map(|val| {
433        GetData::<Arc<dyn Data>>::try_data(val)
434            .unwrap()
435            .downcast_arc::<Map>()
436            .unwrap()
437    }) {
438        match SqlPoolModel::into(pool).inner().pool().await {
439            Ok(pool) => {
440                let sql = match pool.connect_options().database_url.scheme() {
441                    "postgres" => postgres_bind_replace(sql, &bind_symbol),
442                    _ => sql,
443                };
444                let mut query = sqlx::query(&sql);
445
446                for binding in &bindings {
447                    if let Some(val) = bind.map.get(binding) {
448                        query = bind_value(query, val);
449                    } else {
450                        query = query.bind(None::<bool>);
451                    }
452                }
453
454                let mut stream = query.fetch(&*pool);
455                while let Some(row) = stream.next().await {
456                    match row {
457                        Ok(row) => {
458                            let map = get_row_as_map(&row);
459                            check!(
460                                data.send_one(Value::Data(Arc::new(map) as Arc<dyn Data>))
461                                    .await
462                            )
463                        }
464                        Err(error) => {
465                            let _ = failure.send_one(error.to_string().into()).await;
466                            break;
467                        }
468                    }
469                }
470            }
471            Err(error) => {
472                let _ = failure.send_one(error.to_string().into()).await;
473            }
474        }
475    }
476}
477
478#[mel_treatment(
479    default separator ", "
480    default stop_on_failure true
481    default bind_limit 65535
482    default bind_symbol "?"
483    input bind Stream<Map>
484    output data Stream<Map>
485    output failure Stream<string>
486    model pool SqlPool
487)]
488pub async fn fetch_batch(
489    base: string,
490    batch: string,
491    bindings: Vec<string>,
492    bind_limit: u64,
493    bind_symbol: string,
494    separator: string,
495    stop_on_failure: bool,
496) {
497    let limit = bind_limit.min(65535);
498    let batch_max = limit / bindings.len() as u64;
499
500    match SqlPoolModel::into(pool).inner().pool().await {
501        Ok(pool) => 'main: loop {
502            let mut query_builder = QueryBuilder::new(base.as_str());
503
504            let mut full_batch = Vec::with_capacity(batch_max as usize);
505            for _ in 0..batch_max {
506                if let Ok(bind) = bind.recv_one().await.map(|val| {
507                    GetData::<Arc<dyn Data>>::try_data(val)
508                        .unwrap()
509                        .downcast_arc::<Map>()
510                        .unwrap()
511                }) {
512                    full_batch.push(bind);
513                } else {
514                    break;
515                }
516            }
517
518            if full_batch.is_empty() {
519                break;
520            }
521
522            let mut query = query_builder
523                .push({
524                    let batch = std::iter::repeat(batch.as_str())
525                        .take(full_batch.len())
526                        .collect::<Vec<_>>()
527                        .join(&separator);
528                    match pool.connect_options().database_url.scheme() {
529                        "postgres" => postgres_bind_replace(batch, &bind_symbol),
530                        _ => batch,
531                    }
532                })
533                .build();
534
535            for b in full_batch {
536                for binding in &bindings {
537                    if let Some(val) = b.map.get(binding) {
538                        query = bind_value(query, val);
539                    } else {
540                        query = query.bind(None::<bool>);
541                    }
542                }
543            }
544
545            let mut stream = query.fetch(&*pool);
546            'result: while let Some(row) = stream.next().await {
547                match row {
548                    Ok(row) => {
549                        let map = get_row_as_map(&row);
550
551                        let _ = data
552                            .send_one(Value::Data(Arc::new(map) as Arc<dyn Data>))
553                            .await;
554                    }
555                    Err(error) => {
556                        let _ = failure.send_one(error.to_string().into()).await;
557                        if stop_on_failure {
558                            break 'main;
559                        } else {
560                            break 'result;
561                        }
562                    }
563                }
564            }
565        },
566        Err(error) => {
567            let _ = failure.send_one(error.to_string().into()).await;
568        }
569    }
570}
571
// Package declaration macro from `melodium_macro`; presumably it registers the
// models and treatments defined in this crate — confirm against the macro's docs.
mel_package!();