Skip to main content

tank_core/
util.rs

1use crate::{AsValue, ColumnDef, DynQuery, TableRef, Value};
2use proc_macro2::TokenStream;
3use quote::{ToTokens, TokenStreamExt, quote};
4use rust_decimal::prelude::ToPrimitive;
5use serde_json::{Map, Number, Value as JsonValue};
6use std::{
7    borrow::Cow,
8    cmp::min,
9    collections::BTreeMap,
10    ffi::{CStr, CString},
11    ptr,
12};
13use syn::Path;
14
/// Iterator adapter that wraps one of two iterator types yielding the same item.
///
/// Useful when a function must return one of two different concrete iterators
/// from different branches without boxing.
#[derive(Clone)]
pub enum EitherIterator<A, B>
where
    A: Iterator,
    B: Iterator<Item = A::Item>,
{
    /// The first kind of iterator.
    Left(A),
    /// The second kind of iterator.
    Right(B),
}

impl<A, B> Iterator for EitherIterator<A, B>
where
    A: Iterator,
    B: Iterator<Item = A::Item>,
{
    type Item = A::Item;

    /// Advance whichever iterator is wrapped.
    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Left(a) => a.next(),
            Self::Right(b) => b.next(),
        }
    }

    /// Forward the wrapped iterator's bounds so that consumers such as
    /// `collect` can pre-allocate instead of seeing the default `(0, None)`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            Self::Left(a) => a.size_hint(),
            Self::Right(b) => b.size_hint(),
        }
    }
}
39
40pub fn value_to_json(v: &Value) -> Option<JsonValue> {
41    Some(match v {
42        _ if v.is_null() => JsonValue::Null,
43        Value::Boolean(Some(v), ..) => JsonValue::Bool(*v),
44        Value::Int8(Some(v), ..) => JsonValue::Number(Number::from_i128(*v as _)?),
45        Value::Int16(Some(v), ..) => JsonValue::Number(Number::from_i128(*v as _)?),
46        Value::Int32(Some(v), ..) => JsonValue::Number(Number::from_i128(*v as _)?),
47        Value::Int64(Some(v), ..) => JsonValue::Number(Number::from_i128(*v as _)?),
48        Value::Int128(Some(v), ..) => JsonValue::Number(Number::from_i128(*v as _)?),
49        Value::UInt8(Some(v), ..) => JsonValue::Number(Number::from_u128(*v as _)?),
50        Value::UInt16(Some(v), ..) => JsonValue::Number(Number::from_u128(*v as _)?),
51        Value::UInt32(Some(v), ..) => JsonValue::Number(Number::from_u128(*v as _)?),
52        Value::UInt64(Some(v), ..) => JsonValue::Number(Number::from_u128(*v as _)?),
53        Value::UInt128(Some(v), ..) => JsonValue::Number(Number::from_u128(*v as _)?),
54        Value::Float32(Some(v), ..) => JsonValue::Number(Number::from_f64(*v as _)?),
55        Value::Float64(Some(v), ..) => JsonValue::Number(Number::from_f64(*v as _)?),
56        Value::Decimal(Some(v), ..) => JsonValue::Number(Number::from_f64(v.to_f64()?)?),
57        Value::Char(Some(v), ..) => JsonValue::String(v.to_string()),
58        Value::Varchar(Some(v), ..) => JsonValue::String(v.to_string()),
59        Value::Blob(Some(v), ..) => JsonValue::Array(
60            v.iter()
61                .map(|v| Number::from_u128(*v as _).map(JsonValue::Number))
62                .collect::<Option<_>>()?,
63        ),
64        v @ (Value::Date(Some(..), ..)
65        | Value::Time(Some(..), ..)
66        | Value::Timestamp(Some(..), ..)
67        | Value::TimestampWithTimezone(Some(..), ..)) => JsonValue::String(v.to_string()),
68        Value::Interval(Some(_v), ..) => {
69            return None;
70        }
71        Value::Uuid(Some(v), ..) => JsonValue::String(v.to_string()),
72        Value::Array(Some(v), ..) => {
73            JsonValue::Array(v.iter().map(value_to_json).collect::<Option<_>>()?)
74        }
75        Value::List(Some(v), ..) => {
76            JsonValue::Array(v.iter().map(value_to_json).collect::<Option<_>>()?)
77        }
78        Value::Map(Some(v), ..) => {
79            let mut map = Map::new();
80            for (k, v) in v.iter() {
81                let Ok(k) = String::try_from_value(k.clone()) else {
82                    return None;
83                };
84                let Some(v) = value_to_json(v) else {
85                    return None;
86                };
87                map.insert(k, v)?;
88            }
89            JsonValue::Object(map)
90        }
91        Value::Json(Some(v), ..) => v.clone(),
92        Value::Struct(Some(v), ..) => {
93            let mut map = Map::new();
94            for (k, v) in v.iter() {
95                let Some(v) = value_to_json(v) else {
96                    return None;
97                };
98                map.insert(k.clone(), v)?;
99            }
100            JsonValue::Object(map)
101        }
102        Value::Unknown(Some(v), ..) => JsonValue::String(v.clone()),
103        _ => {
104            return None;
105        }
106    })
107}
108
109/// Quote a `BTreeMap<K, V>` into tokens.
110pub fn quote_btree_map<K: ToTokens, V: ToTokens>(value: &BTreeMap<K, V>) -> TokenStream {
111    let mut tokens = TokenStream::new();
112    for (k, v) in value {
113        let ks = k.to_token_stream();
114        let vs = v.to_token_stream();
115        tokens.append_all(quote! {
116            (#ks, #vs),
117        });
118    }
119    quote! {
120        ::std::collections::BTreeMap::from([
121            #tokens
122        ])
123    }
124}
125
126/// Quote a `Cow<T>` preserving borrowed vs owned status for generated code.
127pub fn quote_cow<T: ToOwned + ToTokens + ?Sized>(value: &Cow<T>) -> TokenStream
128where
129    <T as ToOwned>::Owned: ToTokens,
130{
131    match value {
132        Cow::Borrowed(v) => quote! { ::std::borrow::Cow::Borrowed(#v) },
133        Cow::Owned(v) => quote! { ::std::borrow::Cow::Owned(#v) },
134    }
135}
136
137/// Quote an `Option<T>` into tokens.
138pub fn quote_option<T: ToTokens>(value: &Option<T>) -> TokenStream {
139    match value {
140        None => quote! { None },
141        Some(v) => quote! { Some(#v) },
142    }
143}
144
145/// Determine if the trailing segments of a `syn::Path` match the expected identifiers.
146pub fn matches_path(path: &Path, expect: &[&str]) -> bool {
147    let len = min(path.segments.len(), expect.len());
148    path.segments
149        .iter()
150        .rev()
151        .take(len)
152        .map(|v| &v.ident)
153        .eq(expect.iter().rev().take(len))
154}
155
156/// Write an iterator of items separated by a delimiter into a string.
157pub fn separated_by<T, F>(
158    out: &mut DynQuery,
159    values: impl IntoIterator<Item = T>,
160    mut f: F,
161    separator: &str,
162) where
163    F: FnMut(&mut DynQuery, T),
164{
165    let mut len = out.len();
166    for v in values {
167        if out.len() > len {
168            out.push_str(separator);
169        }
170        len = out.len();
171        f(out, v);
172    }
173}
174
175/// Write, escaping occurrences of `search` char with `replace` while copying into buffer.
176pub fn write_escaped(out: &mut DynQuery, value: &str, search: char, replace: &str) {
177    let mut position = 0;
178    for (i, c) in value.char_indices() {
179        if c == search {
180            out.push_str(&value[position..i]);
181            out.push_str(replace);
182            position = i + c.len_utf8();
183        }
184    }
185    out.push_str(&value[position..]);
186}
187
/// Convert anything byte-like into a `CString`.
///
/// Interior NUL bytes, which a `CString` cannot contain, are replaced with
/// `?`. Should construction still fail, an empty `CString` is returned.
pub fn as_c_string(str: impl Into<Vec<u8>>) -> CString {
    let mut bytes: Vec<u8> = str.into();
    for byte in bytes.iter_mut().filter(|byte| **byte == 0) {
        *byte = b'?';
    }
    CString::new(bytes).unwrap_or_default()
}
198
/// Read an error message from a C string pointer.
///
/// Returns the message decoded lossily as UTF-8, or a fixed fallback text when
/// the pointer is null.
///
/// # Safety
///
/// This function is not marked `unsafe`, but it dereferences the raw pointer:
/// when non-null, `*ptr` must point to a valid NUL-terminated string that
/// stays alive and unmodified for as long as the returned value is used.
pub fn error_message_from_ptr<'a>(ptr: &'a *const i8) -> Cow<'a, str> {
    if ptr.is_null() {
        return Cow::Borrowed("Unknown error: could not extract the error message");
    }
    // The cast keeps this compiling on targets where `c_char` is `u8` rather than `i8`.
    // SAFETY: the pointer is non-null; the caller guarantees it references a valid
    // NUL-terminated string (see the `# Safety` section above).
    let message = unsafe { CStr::from_ptr((*ptr).cast()) };
    message.to_string_lossy()
}
208
/// Strip the longest prefix of `input` whose characters all satisfy
/// `predicate`, and return it.
///
/// `input` is advanced past the returned prefix. When the first character does
/// not satisfy the predicate (or `input` is empty), an empty string is
/// returned and `input` is left as it was.
pub fn consume_while<'s>(
    input: &mut &'s str,
    mut predicate: impl FnMut(&char) -> bool,
) -> &'s str {
    let source: &'s str = *input;
    // Byte offset of the first rejected character, or the whole string if none is rejected
    let split = source
        .find(|c: char| !predicate(&c))
        .unwrap_or(source.len());
    let (head, tail) = source.split_at(split);
    *input = tail;
    head
}
223
/// Strip a leading run of ASCII digits from `input` and return it.
///
/// When `SIGNED` is true, a single leading `+` or `-` is consumed as part of
/// the result as well; note that the sign is consumed even if no digit
/// follows it. `input` is advanced past the returned slice, which may be
/// empty.
pub fn extract_number<'s, const SIGNED: bool>(input: &mut &'s str) -> &'s str {
    let source: &'s str = *input;
    let has_sign = SIGNED && matches!(source.as_bytes().first(), Some(b'+') | Some(b'-'));
    let sign_len = usize::from(has_sign);
    // ASCII digits are single bytes, so counting bytes gives the same offset as counting chars
    let digits = source[sign_len..]
        .bytes()
        .take_while(u8::is_ascii_digit)
        .count();
    let (number, rest) = source.split_at(sign_len + digits);
    *input = rest;
    number
}
238
239pub fn column_def(name: &str, table: &TableRef) -> Option<&'static ColumnDef> {
240    table.columns.iter().find(|v| v.name() == name)
241}
242
/// Map a month number (`1..=12`) to the corresponding `Month` variant.
///
/// Parameters:
/// * `$month`: integer expression holding the month number
/// * `$throw`: expression evaluated for any other number (typically diverging,
///   e.g. a `return` or an error)
///
/// A type named `Month` with the twelve English month variants must be in
/// scope where the macro is invoked.
#[macro_export]
macro_rules! number_to_month {
    ($month:expr, $throw:expr $(,)?) => {
        match $month {
            1 => Month::January,
            2 => Month::February,
            3 => Month::March,
            4 => Month::April,
            5 => Month::May,
            6 => Month::June,
            7 => Month::July,
            8 => Month::August,
            9 => Month::September,
            10 => Month::October,
            11 => Month::November,
            12 => Month::December,
            _ => $throw,
        }
    };
}
263
/// Map a `Month` variant to its month number (`1..=12`).
///
/// Parameters:
/// * `$month`: expression evaluating to a `Month`
///
/// A type named `Month` with exactly the twelve English month variants must be
/// in scope where the macro is invoked.
#[macro_export]
macro_rules! month_to_number {
    ($month:expr $(,)?) => {
        match $month {
            Month::January => 1,
            Month::February => 2,
            Month::March => 3,
            Month::April => 4,
            Month::May => 5,
            Month::June => 6,
            Month::July => 7,
            Month::August => 8,
            Month::September => 9,
            Month::October => 10,
            Month::November => 11,
            Month::December => 12,
        }
    };
}
283
/// Run a fragment-writing expression, wrapping its output in parentheses when
/// a condition holds.
///
/// Parameters:
/// * `$out`: identifier of the output buffer (must provide `push(char)`)
/// * `$cond`: boolean expression, evaluated once before anything is written
/// * `$v`: expression that writes the fragment; evaluated exactly once
#[macro_export]
macro_rules! possibly_parenthesized {
    ($out:ident, $cond:expr, $v:expr) => {{
        let parenthesize = $cond;
        if parenthesize {
            $out.push('(');
        }
        $v;
        if parenthesize {
            $out.push(')');
        }
    }};
}
297
/// Maximum number of bytes of the input kept by `truncate_long!`:
/// 2000 in builds with debug assertions enabled, 497 otherwise.
pub const TRUNCATE_LONG_LIMIT: usize = if cfg!(debug_assertions) {
    2000
} else {
    497
};

/// Truncate long strings for logging and error messages purpose.
///
/// Yields at most `TRUNCATE_LONG_LIMIT` bytes from the start of the input
/// (trimmed of surrounding whitespace), followed by `...\n` when the input was
/// longer than the limit. The cut is moved back to the nearest UTF-8 character
/// boundary, so multi-byte characters never cause a panic.
///
/// * `truncate_long!(query)` returns a `format_args!` value without
///   allocating; the `query` expression is evaluated several times, so it
///   should be a plain variable or field.
/// * `truncate_long!(query, true)` evaluates `query` just once and returns a
///   `String`.
///
/// # Examples
/// ```ignore
/// use tank_core::truncate_long;
/// let short = "SELECT 1";
/// assert_eq!(format!("{}", truncate_long!(short)), "SELECT 1");
/// let long = format!("SELECT {}", "X".repeat(3000));
/// let logged = format!("{}", truncate_long!(long));
/// assert!(logged.starts_with("SELECT XXXXXX"));
/// assert!(logged.ends_with("...\n"));
/// ```
#[macro_export]
macro_rules! truncate_long {
    ($query:expr) => {
        format_args!(
            "{}{}",
            &$query[..{
                // Back off to a char boundary: slicing inside a multi-byte character panics
                let mut end = ::std::cmp::min($query.len(), $crate::TRUNCATE_LONG_LIMIT);
                while !$query.is_char_boundary(end) {
                    end -= 1;
                }
                end
            }]
            .trim(),
            if $query.len() > $crate::TRUNCATE_LONG_LIMIT {
                "...\n"
            } else {
                ""
            },
        )
    };
    ($query:expr,true) => {{
        let query = $query;
        // Back off to a char boundary: slicing inside a multi-byte character panics
        let mut end = ::std::cmp::min(query.len(), $crate::TRUNCATE_LONG_LIMIT);
        while !query.is_char_boundary(end) {
            end -= 1;
        }
        format!(
            "{}{}",
            &query[..end].trim(),
            if query.len() > $crate::TRUNCATE_LONG_LIMIT {
                "...\n"
            } else {
                ""
            },
        )
    }};
}
353
/// Send a value through a channel, logging the error if the send fails.
///
/// The failure is reported with `log::error!` and otherwise ignored; the macro
/// evaluates to `()`.
///
/// Parameters:
/// * `$tx`: sender channel
/// * `$value`: value to be sent
///
/// *Example*:
/// ```ignore
/// send_value!(tx, Ok(QueryResult::Row(row)));
/// ```
#[macro_export]
macro_rules! send_value {
    ($tx:ident, $value:expr) => {{
        match $tx.send($value) {
            Err(e) => {
                log::error!("{e:#}");
            }
            Ok(_) => {}
        }
    }};
}
373
/// Incrementally accumulates tokens from a speculative parse stream until one
/// of the supplied parsers succeeds.
///
/// At every position the parsers are tried in order on a fork of the input;
/// the first one that succeeds consumes its tokens and ends the scan, and the
/// remaining parsers are not run. Otherwise a single token tree is moved into
/// the accumulated stream and the scan continues.
///
/// Returns `(accumulated_tokens, (parser1_option, parser2_option, ...))` with
/// at most one `Some(T)`: the first successful parser. All options are `None`
/// when the input is exhausted before any parser succeeds.
///
/// The expansion uses `TokenStream`, `TokenTree`, `fork`/`advance_to` and the
/// `?` operator, so those names must be in scope and the enclosing function
/// must return a compatible `Result`.
#[doc(hidden)]
#[macro_export]
macro_rules! take_until {
    ($original:expr, $($parser:expr),+ $(,)?) => {{
        let macro_local_input = $original.fork();
        let mut macro_local_result = (
            TokenStream::new(),
            // One `None` slot per parser
            ($({
                let _ = $parser;
                None
            }),+),
        );
        loop {
            if macro_local_input.is_empty() {
                break;
            }
            let mut parsed = false;
            let produced = ($({
                if parsed {
                    // An earlier parser already matched and advanced the input:
                    // running this one too would consume further tokens.
                    None
                } else {
                    let attempt = macro_local_input.fork();
                    if let Ok(content) = ($parser)(&attempt) {
                        macro_local_input.advance_to(&attempt);
                        parsed = true;
                        Some(content)
                    } else {
                        None
                    }
                }
            }),+);
            if parsed {
                macro_local_result.1 = produced;
                break;
            }
            // No parser matched here: keep the token and move on
            macro_local_result.0.append(macro_local_input.parse::<TokenTree>()?);
        }
        $original.advance_to(&macro_local_input);
        macro_local_result
    }};
}
416
/// Implement the `Executor` trait for a transaction wrapper type by
/// delegating each operation to an underlying connection object.
///
/// This reduces boilerplate across driver implementations. The macro expands
/// into an `impl Executor for $transaction` with forwarding methods for
/// `accepts_multiple_statements`, `do_prepare`, `run`, `fetch`, `execute`,
/// and `append`.
///
/// Parameters:
/// * `$driver`: concrete driver type.
/// * `$transaction`: transaction wrapper type, optionally followed by a single
///   lifetime parameter (e.g. `YourDBTransaction<'c>`).
/// * `$connection`: field name on the transaction pointing to the connection.
///
/// # Examples
/// ```ignore
/// use crate::{YourDBConnection, YourDBDriver};
/// use tank_core::{Error, Result, Transaction, impl_executor_transaction};
///
/// pub struct YourDBTransaction<'c> {
///     connection: &'c mut YourDBConnection,
/// }
///
/// impl_executor_transaction!(YourDBDriver, YourDBTransaction<'c>, connection);
///
/// impl<'c> Transaction<'c> for YourDBTransaction<'c> { ... }
/// ```
#[macro_export]
macro_rules! impl_executor_transaction {
    // The lifetime is optional in the pattern; transactions normally carry one.
    // Paths go through `$crate` so the expansion does not depend on the name
    // under which the invoking crate imports this crate.
    ($driver:ty, $transaction:ident $(< $lt:lifetime >)?, $connection:ident) => {
        impl $(<$lt>)? $crate::Executor for $transaction $(<$lt>)? {
            type Driver = $driver;

            fn accepts_multiple_statements(&self) -> bool {
                self.$connection.accepts_multiple_statements()
            }

            fn do_prepare(
                &mut self,
                sql: String,
            ) -> impl ::std::future::Future<
                Output = $crate::Result<$crate::Query<Self::Driver>>,
            > + Send {
                self.$connection.do_prepare(sql)
            }

            fn run<'s>(
                &'s mut self,
                query: impl $crate::AsQuery<Self::Driver> + 's,
            ) -> impl $crate::stream::Stream<
                Item = $crate::Result<$crate::QueryResult>,
            > + Send {
                self.$connection.run(query)
            }

            fn fetch<'s>(
                &'s mut self,
                query: impl $crate::AsQuery<Self::Driver> + 's,
            ) -> impl $crate::stream::Stream<
                Item = $crate::Result<$crate::Row>,
            > + Send
            + 's {
                self.$connection.fetch(query)
            }

            fn execute<'s>(
                &'s mut self,
                query: impl $crate::AsQuery<Self::Driver> + 's,
            ) -> impl ::std::future::Future<
                Output = $crate::Result<$crate::RowsAffected>,
            > + Send {
                self.$connection.execute(query)
            }

            fn append<'a, E, It>(
                &mut self,
                entities: It,
            ) -> impl ::std::future::Future<
                Output = $crate::Result<$crate::RowsAffected>,
            > + Send
            where
                E: $crate::Entity + 'a,
                It: IntoIterator<Item = &'a E> + Send,
                <It as IntoIterator>::IntoIter: Send,
            {
                self.$connection.append(entities)
            }
        }
    };
}
501
/// Expands to an empty block, i.e. evaluates to `()`.
///
/// NOTE(review): the name suggests this should yield the current timestamp in
/// milliseconds, but as written it produces no value; confirm whether this is
/// an intentional placeholder.
#[macro_export]
macro_rules! current_timestamp_ms {
    () => {{}};
}