Skip to main content

tank_core/
util.rs

1use crate::{AsValue, DynQuery, Value};
2use proc_macro2::TokenStream;
3use quote::{ToTokens, TokenStreamExt, quote};
4use rust_decimal::prelude::ToPrimitive;
5use serde_json::{Map, Number, Value as JsonValue};
6use std::{
7    borrow::Cow,
8    cmp::min,
9    collections::BTreeMap,
10    ffi::{CStr, CString},
11    fmt::Write,
12    ptr,
13};
14use syn::Path;
15use time::{Date, Time};
16
/// Iterator adapter that holds one of two iterator types yielding the same item.
///
/// Whichever variant is stored is the iterator that gets driven; this lets a
/// function return one of two different concrete iterator types from separate
/// branches under a single return type.
#[derive(Clone)]
pub enum EitherIterator<A, B>
where
    A: Iterator,
    B: Iterator<Item = A::Item>,
{
    /// Wraps an iterator of the first type.
    Left(A),
    /// Wraps an iterator of the second type.
    Right(B),
}
27
28impl<A, B> Iterator for EitherIterator<A, B>
29where
30    A: Iterator,
31    B: Iterator<Item = A::Item>,
32{
33    type Item = A::Item;
34    fn next(&mut self) -> Option<Self::Item> {
35        match self {
36            Self::Left(a) => a.next(),
37            Self::Right(b) => b.next(),
38        }
39    }
40}
41
42pub fn value_to_json(v: &Value) -> Option<JsonValue> {
43    Some(match v {
44        _ if v.is_null() => JsonValue::Null,
45        Value::Boolean(Some(v), ..) => JsonValue::Bool(*v),
46        Value::Int8(Some(v), ..) => JsonValue::Number(Number::from_i128(*v as _)?),
47        Value::Int16(Some(v), ..) => JsonValue::Number(Number::from_i128(*v as _)?),
48        Value::Int32(Some(v), ..) => JsonValue::Number(Number::from_i128(*v as _)?),
49        Value::Int64(Some(v), ..) => JsonValue::Number(Number::from_i128(*v as _)?),
50        Value::Int128(Some(v), ..) => JsonValue::Number(Number::from_i128(*v as _)?),
51        Value::UInt8(Some(v), ..) => JsonValue::Number(Number::from_u128(*v as _)?),
52        Value::UInt16(Some(v), ..) => JsonValue::Number(Number::from_u128(*v as _)?),
53        Value::UInt32(Some(v), ..) => JsonValue::Number(Number::from_u128(*v as _)?),
54        Value::UInt64(Some(v), ..) => JsonValue::Number(Number::from_u128(*v as _)?),
55        Value::UInt128(Some(v), ..) => JsonValue::Number(Number::from_u128(*v as _)?),
56        Value::Float32(Some(v), ..) => JsonValue::Number(Number::from_f64(*v as _)?),
57        Value::Float64(Some(v), ..) => JsonValue::Number(Number::from_f64(*v as _)?),
58        Value::Decimal(Some(v), ..) => JsonValue::Number(Number::from_f64(v.to_f64()?)?),
59        Value::Char(Some(v), ..) => JsonValue::String(v.to_string()),
60        Value::Varchar(Some(v), ..) => JsonValue::String(v.to_string()),
61        Value::Blob(Some(v), ..) => JsonValue::Array(
62            v.iter()
63                .map(|v| Number::from_u128(*v as _).map(JsonValue::Number))
64                .collect::<Option<_>>()?,
65        ),
66        Value::Date(Some(v), ..) => {
67            JsonValue::String(format!("{:04}-{:02}-{:02}", v.year(), v.month(), v.day()))
68        }
69        Value::Time(Some(v), ..) => {
70            let mut out = String::new();
71            print_timer(
72                &mut out,
73                "",
74                v.hour() as _,
75                v.minute(),
76                v.second(),
77                v.nanosecond(),
78            );
79            JsonValue::String(out)
80        }
81        Value::Timestamp(Some(v), ..) => {
82            let date = v.date();
83            let time = v.time();
84            let mut out = String::new();
85            print_date(&mut out, "", &date);
86            out.push(' ');
87            print_timer(
88                &mut out,
89                "",
90                time.hour() as _,
91                time.minute(),
92                time.second(),
93                time.nanosecond(),
94            );
95            JsonValue::String(out)
96        }
97        Value::TimestampWithTimezone(Some(v), ..) => {
98            let date = v.date();
99            let time = v.time();
100            let mut out = String::new();
101            print_date(&mut out, "", &date);
102            out.push(' ');
103            print_timer(
104                &mut out,
105                "",
106                time.hour() as _,
107                time.minute(),
108                time.second(),
109                time.nanosecond(),
110            );
111            let (h, m, s) = v.offset().as_hms();
112            out.push(' ');
113            if h >= 0 {
114                out.push('+');
115            } else {
116                out.push('-');
117            }
118            let offset = Time::from_hms(h.abs() as _, m.abs() as _, s.abs() as _).ok()?;
119            print_timer(
120                &mut out,
121                "",
122                offset.hour() as _,
123                offset.minute(),
124                offset.second(),
125                offset.nanosecond(),
126            );
127            JsonValue::String(out)
128        }
129        Value::Interval(Some(_v), ..) => {
130            return None;
131        }
132        Value::Uuid(Some(v), ..) => JsonValue::String(v.to_string()),
133        Value::Array(Some(v), ..) => {
134            JsonValue::Array(v.iter().map(value_to_json).collect::<Option<_>>()?)
135        }
136        Value::List(Some(v), ..) => {
137            JsonValue::Array(v.iter().map(value_to_json).collect::<Option<_>>()?)
138        }
139        Value::Map(Some(v), ..) => {
140            let mut map = Map::new();
141            for (k, v) in v.iter() {
142                let Ok(k) = String::try_from_value(k.clone()) else {
143                    return None;
144                };
145                let Some(v) = value_to_json(v) else {
146                    return None;
147                };
148                map.insert(k, v)?;
149            }
150            JsonValue::Object(map)
151        }
152        Value::Json(Some(v), ..) => v.clone(),
153        Value::Struct(Some(v), ..) => {
154            let mut map = Map::new();
155            for (k, v) in v.iter() {
156                let Some(v) = value_to_json(v) else {
157                    return None;
158                };
159                map.insert(k.clone(), v)?;
160            }
161            JsonValue::Object(map)
162        }
163        Value::Unknown(Some(v), ..) => JsonValue::String(v.clone()),
164        _ => {
165            return None;
166        }
167    })
168}
169
170/// Quote a `BTreeMap<K, V>` into tokens.
171pub fn quote_btree_map<K: ToTokens, V: ToTokens>(value: &BTreeMap<K, V>) -> TokenStream {
172    let mut tokens = TokenStream::new();
173    for (k, v) in value {
174        let ks = k.to_token_stream();
175        let vs = v.to_token_stream();
176        tokens.append_all(quote! {
177            (#ks, #vs),
178        });
179    }
180    quote! {
181        ::std::collections::BTreeMap::from([
182            #tokens
183        ])
184    }
185}
186
187/// Quote a `Cow<T>` preserving borrowed vs owned status for generated code.
188pub fn quote_cow<T: ToOwned + ToTokens + ?Sized>(value: &Cow<T>) -> TokenStream
189where
190    <T as ToOwned>::Owned: ToTokens,
191{
192    match value {
193        Cow::Borrowed(v) => quote! { ::std::borrow::Cow::Borrowed(#v) },
194        Cow::Owned(v) => quote! { ::std::borrow::Cow::Borrowed(#v) },
195    }
196}
197
198/// Quote an `Option<T>` into tokens.
199pub fn quote_option<T: ToTokens>(value: &Option<T>) -> TokenStream {
200    match value {
201        None => quote! { None },
202        Some(v) => quote! { Some(#v) },
203    }
204}
205
206/// Determine if the trailing segments of a `syn::Path` match the expected identifiers.
207pub fn matches_path(path: &Path, expect: &[&str]) -> bool {
208    let len = min(path.segments.len(), expect.len());
209    path.segments
210        .iter()
211        .rev()
212        .take(len)
213        .map(|v| &v.ident)
214        .eq(expect.iter().rev().take(len))
215}
216
217/// Write an iterator of items separated by a delimiter into a string.
218pub fn separated_by<T, F>(
219    out: &mut DynQuery,
220    values: impl IntoIterator<Item = T>,
221    mut f: F,
222    separator: &str,
223) where
224    F: FnMut(&mut DynQuery, T),
225{
226    let mut len = out.len();
227    for v in values {
228        if out.len() > len {
229            out.push_str(separator);
230        }
231        len = out.len();
232        f(out, v);
233    }
234}
235
236/// Write, escaping occurrences of `search` char with `replace` while copying into buffer.
237pub fn write_escaped(out: &mut DynQuery, value: &str, search: char, replace: &str) {
238    let mut position = 0;
239    for (i, c) in value.char_indices() {
240        if c == search {
241            out.push_str(&value[position..i]);
242            out.push_str(replace);
243            position = i + 1;
244        }
245    }
246    out.push_str(&value[position..]);
247}
248
/// Build a `CString` from anything convertible into bytes.
///
/// Every NUL byte in the input is replaced with `?` before the conversion; if
/// `CString::new` still fails, the default (empty) `CString` is returned.
pub fn as_c_string(str: impl Into<Vec<u8>>) -> CString {
    let mut bytes: Vec<u8> = str.into();
    for byte in &mut bytes {
        if *byte == 0 {
            *byte = b'?';
        }
    }
    CString::new(bytes).unwrap_or_default()
}
259
/// Read an error message from a C string pointer.
///
/// Returns the pointed-to text, with invalid UTF-8 replaced lossily, or a
/// fixed fallback message when the pointer is null.
///
/// The function is not marked `unsafe`, but it dereferences the pointer: a
/// non-null `*ptr` must point to a NUL-terminated string that stays valid and
/// unmodified for as long as the returned value is used.
pub fn error_message_from_ptr<'a>(ptr: &'a *const i8) -> Cow<'a, str> {
    let raw = *ptr;
    if raw == ptr::null() {
        return Cow::Borrowed("Unknown error: could not extract the error message");
    }
    // SAFETY: `raw` is non-null (checked above); the caller guarantees it points
    // to a valid NUL-terminated string that outlives the returned value.
    // `cast` keeps this compiling on targets where `c_char` is `u8`, not `i8`.
    let message = unsafe { CStr::from_ptr(raw.cast()) };
    message.to_string_lossy()
}
269
/// Consume the longest prefix of `input` whose characters all satisfy
/// `predicate`, returning that prefix.
///
/// `input` is advanced past the returned slice. When the first character does
/// not match (or `input` is empty) an empty string is returned and `input` is
/// left untouched.
pub fn consume_while<'s>(
    input: &mut &'s str,
    mut predicate: impl FnMut(&char) -> bool,
) -> &'s str {
    let text: &'s str = *input;
    // Work with byte offsets: counting characters and slicing by that count
    // cuts the wrong span (or panics) once a multi-byte character is consumed.
    let end = text
        .char_indices()
        .find(|(_, c)| !predicate(c))
        .map_or(text.len(), |(offset, _)| offset);
    let (consumed, rest) = text.split_at(end);
    *input = rest;
    consumed
}
280
/// Split a leading run of ASCII digits off `input` and return it.
///
/// When `SIGNED` is `true`, a single leading `+` or `-` is taken as part of the
/// result, even if no digit follows it. `input` is advanced past the returned
/// slice; the result is empty when nothing matches.
pub fn extract_number<'s, const SIGNED: bool>(input: &mut &'s str) -> &'s str {
    let text: &'s str = *input;
    let bytes = text.as_bytes();
    // Sign and digits are all ASCII, so byte positions are also char boundaries.
    let sign_len = usize::from(SIGNED && matches!(bytes.first(), Some(b'+' | b'-')));
    let digit_len = bytes[sign_len..]
        .iter()
        .take_while(|byte| byte.is_ascii_digit())
        .count();
    let (number, rest) = text.split_at(sign_len + digit_len);
    *input = rest;
    number
}
295
296pub fn print_date(out: &mut impl Write, quote: &str, date: &Date) {
297    let _ = write!(
298        out,
299        "{quote}{:04}-{:02}-{:02}{quote}",
300        date.year(),
301        date.month() as u8,
302        date.day(),
303    );
304}
305
/// Write a time of day as `HH:MM:SS.fraction` into `out`, wrapped in `quote`
/// on both sides.
///
/// `ns` is the sub-second part in nanoseconds. Trailing zeros of the fraction
/// are dropped, but at least one digit is always printed (`.0` for a whole
/// second). Any error reported by the writer is discarded.
pub fn print_timer(out: &mut impl Write, quote: &str, h: i64, m: u8, s: u8, ns: u32) {
    // Start from all nine nanosecond digits and strip one trailing zero per round.
    let (mut subsecond, mut width) = (ns, 9);
    loop {
        if width <= 1 || subsecond % 10 != 0 {
            break;
        }
        subsecond /= 10;
        width -= 1;
    }
    let _ = write!(
        out,
        "{quote}{h:02}:{m:02}:{s:02}.{subsecond:0width$}{quote}",
    );
}
318
/// Map a month number (`1` to `12`) to the matching `Month` variant.
///
/// Parameters:
/// * first: expression yielding the month number.
/// * second: expression evaluated for any other number; it may diverge
///   (`return`, `panic!`, ...) or produce a `Month`.
///
/// `Month` is referenced unqualified, so it must be in scope where the macro
/// is invoked.
#[macro_export]
macro_rules! number_to_month {
    ($number:expr, $otherwise:expr $(,)?) => {
        match $number {
            1 => Month::January,
            2 => Month::February,
            3 => Month::March,
            4 => Month::April,
            5 => Month::May,
            6 => Month::June,
            7 => Month::July,
            8 => Month::August,
            9 => Month::September,
            10 => Month::October,
            11 => Month::November,
            12 => Month::December,
            _ => $otherwise,
        }
    };
}
339
/// Map a `Month` variant to its number, `1` for January through `12` for
/// December.
///
/// The result is an integer literal, so its type is inferred from the place
/// where the macro is used. `Month` is referenced unqualified, so it must be
/// in scope where the macro is invoked.
#[macro_export]
macro_rules! month_to_number {
    ($value:expr $(,)?) => {
        match $value {
            Month::January => 1,
            Month::February => 2,
            Month::March => 3,
            Month::April => 4,
            Month::May => 5,
            Month::June => 6,
            Month::July => 7,
            Month::August => 8,
            Month::September => 9,
            Month::October => 10,
            Month::November => 11,
            Month::December => 12,
        }
    };
}
359
#[macro_export]
/// Run a fragment-writing expression, surrounding its output with `(` and `)`
/// when a condition holds.
///
/// Parameters:
/// * `$out`: identifier of the buffer; it must offer `push(char)`.
/// * `$cond`: evaluated once, before anything is written.
/// * `$v`: expression that writes the fragment; evaluated exactly once.
macro_rules! possibly_parenthesized {
    ($out:ident, $cond:expr, $v:expr) => {{
        let wrap_in_parentheses = $cond;
        if wrap_in_parentheses {
            $out.push('(');
        }
        $v;
        if wrap_in_parentheses {
            $out.push(')');
        }
    }};
}
373
#[macro_export]
/// Truncate long strings for logging and error messages.
///
/// The single-argument form returns a `format_args!` value; the `true` form
/// returns a `String`. Both yield at most the first 497 bytes of the input,
/// cut back to the nearest character boundary and trimmed of surrounding
/// whitespace, followed by `...` and a newline when the input is longer than
/// 497 bytes. Shorter input is passed through trimmed, with nothing appended.
///
/// The single-argument form evaluates its argument several times, so pass a
/// cheap place expression such as a variable. With `true` as the second
/// argument the first argument is evaluated just once.
///
/// # Examples
/// ```ignore
/// use tank_core::truncate_long;
/// let short = "SELECT 1";
/// assert_eq!(format!("{}", truncate_long!(short)), "SELECT 1");
/// let long = format!("SELECT {}", "X".repeat(600));
/// let logged = format!("{}", truncate_long!(long));
/// assert!(logged.starts_with("SELECT XXXXXX"));
/// assert!(logged.ends_with("...\n"));
/// ```
macro_rules! truncate_long {
    ($query:expr) => {
        format_args!(
            "{}{}",
            &$query[..{
                // Back off to a char boundary: slicing in the middle of a
                // multi-byte UTF-8 sequence would panic.
                let mut end = ::std::cmp::min($query.len(), 497);
                while !$query.is_char_boundary(end) {
                    end -= 1;
                }
                end
            }]
            .trim(),
            if $query.len() > 497 { "...\n" } else { "" },
        )
    };
    ($query:expr,true) => {{
        let query = $query;
        // Back off to a char boundary: slicing in the middle of a multi-byte
        // UTF-8 sequence would panic.
        let mut end = ::std::cmp::min(query.len(), 497);
        while !query.is_char_boundary(end) {
            end -= 1;
        }
        format!(
            "{}{}",
            &query[..end].trim(),
            if query.len() > 497 { "...\n" } else { "" },
        )
    }};
}
409
/// Send a value through a channel, logging the error if the send fails.
///
/// The failure is reported with `log::error!` and otherwise ignored; the macro
/// itself evaluates to `()`.
///
/// Parameters:
/// * `$tx`: identifier of the sender; its `send` method must return a `Result`.
/// * `$value`: value to be sent.
///
/// *Example*:
/// ```ignore
/// send_value!(tx, Ok(QueryResult::Row(row)));
/// ```
#[macro_export]
macro_rules! send_value {
    ($tx:ident, $value:expr) => {{
        match $tx.send($value) {
            Ok(_) => {}
            Err(e) => {
                log::error!("{e:#}");
            }
        }
    }};
}
429
/// Incrementally accumulates tokens from a speculative parse stream until one
/// of the supplied parsers succeeds.
///
/// At each position the parsers are tried in order on a fork of the stream.
/// The first one that succeeds consumes its input and stops the scan; the
/// remaining parsers are not run. If none succeeds, one token tree is moved
/// into the accumulated tokens and the scan continues.
///
/// Returns `(accumulated_tokens, (parser1_option, parser2_option, ...))` with
/// at most one `Some(T)`: the first successful parser. Every slot is `None`
/// when the stream is exhausted without a match. `$original` is advanced past
/// everything that was consumed.
///
/// The expansion uses `?`, so it must be invoked in a function returning a
/// compatible `Result`. `TokenStream`, `TokenTree`, and the traits providing
/// `append` and `advance_to` are referenced unqualified and must be in scope
/// at the call site.
#[doc(hidden)]
#[macro_export]
macro_rules! take_until {
    ($original:expr, $($parser:expr),+ $(,)?) => {{
        let macro_local_input = $original.fork();
        let mut macro_local_result = (
            TokenStream::new(),
            ($({
                let _ = $parser;
                None
            }),+),
        );
        while !macro_local_input.is_empty() {
            let mut parsed = false;
            let produced = ($({
                if parsed {
                    // An earlier parser already matched at this position:
                    // leave the remaining slots empty and the rest unconsumed.
                    None
                } else {
                    let attempt = macro_local_input.fork();
                    if let Ok(content) = ($parser)(&attempt) {
                        macro_local_input.advance_to(&attempt);
                        parsed = true;
                        Some(content)
                    } else {
                        None
                    }
                }
            }),+);
            if parsed {
                macro_local_result.1 = produced;
                break;
            }
            macro_local_result.0.append(macro_local_input.parse::<TokenTree>()?);
        }
        $original.advance_to(&macro_local_input);
        macro_local_result
    }};
}
472
#[macro_export]
/// Implement the `Executor` trait for a transaction wrapper type by
/// delegating each operation to an underlying connection object.
///
/// This reduces boilerplate across driver implementations. The macro expands
/// into an `impl Executor for $transaction` with forwarding methods for
/// `accepts_multiple_statements`, `do_prepare`, `run`, `fetch`, `execute`,
/// and `append`.
///
/// Items are referenced through `$crate`, so the expansion does not depend on
/// the name under which the invoking crate imports this crate.
///
/// Parameters:
/// * `$driver`: concrete driver type.
/// * `$transaction`: transaction wrapper type, optionally followed by a single
///   lifetime parameter (e.g. `YourDBTransaction<'c>`).
/// * `$connection`: field name on the transaction pointing to the connection.
///
/// # Examples
/// ```ignore
/// use crate::{YourDBConnection, YourDBDriver};
/// use tank_core::{Error, Result, Transaction, impl_executor_transaction};
///
/// pub struct YourDBTransaction<'c> {
///     connection: &'c mut YourDBConnection,
/// }
///
/// impl_executor_transaction!(YourDBDriver, YourDBTransaction<'c>, connection);
///
/// impl<'c> Transaction<'c> for YourDBTransaction<'c> { ... }
/// ```
macro_rules! impl_executor_transaction {
    // The lifetime is optional in the grammar; transaction types normally carry one.
    ($driver:ty, $transaction:ident $(< $lt:lifetime >)?, $connection:ident) => {
        impl $(<$lt>)? $crate::Executor for $transaction $(<$lt>)? {
            type Driver = $driver;

            fn accepts_multiple_statements(&self) -> bool {
                self.$connection.accepts_multiple_statements()
            }

            fn do_prepare(
                &mut self,
                sql: String,
            ) -> impl ::std::future::Future<Output = $crate::Result<$crate::Query<Self::Driver>>> + Send
            {
                self.$connection.do_prepare(sql)
            }

            fn run<'s>(
                &'s mut self,
                query: impl $crate::AsQuery<Self::Driver> + 's,
            ) -> impl $crate::stream::Stream<
                Item = $crate::Result<$crate::QueryResult>,
            > + Send {
                self.$connection.run(query)
            }

            fn fetch<'s>(
                &'s mut self,
                query: impl $crate::AsQuery<Self::Driver> + 's,
            ) -> impl $crate::stream::Stream<
                Item = $crate::Result<$crate::RowLabeled>,
            > + Send
            + 's {
                self.$connection.fetch(query)
            }

            fn execute<'s>(
                &'s mut self,
                query: impl $crate::AsQuery<Self::Driver> + 's,
            ) -> impl ::std::future::Future<Output = $crate::Result<$crate::RowsAffected>> + Send {
                self.$connection.execute(query)
            }

            fn append<'a, E, It>(
                &mut self,
                entities: It,
            ) -> impl ::std::future::Future<Output = $crate::Result<$crate::RowsAffected>> + Send
            where
                E: $crate::Entity + 'a,
                It: ::std::iter::IntoIterator<Item = &'a E> + Send,
                <It as ::std::iter::IntoIterator>::IntoIter: Send,
            {
                self.$connection.append(entities)
            }
        }
    }
}
556}