1use crate::time_duration::TimeDuration;
2use crate::timestamp::Timestamp;
3use crate::{
4    algebraic_value::ser::ValueSerializer,
5    ser::{self, Serialize},
6    ProductType, ProductTypeElement,
7};
8use crate::{i256, u256};
9use core::fmt;
10use core::fmt::Write as _;
11use derive_more::{From, Into};
12
13pub trait Satn: ser::Serialize {
15    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
17        Writer::with(f, |f| self.serialize(SatnFormatter { f }))?;
18        Ok(())
19    }
20
21    fn fmt_psql(&self, f: &mut fmt::Formatter, ty: &PsqlType<'_>) -> fmt::Result {
23        Writer::with(f, |f| {
24            self.serialize(PsqlFormatter {
25                fmt: SatnFormatter { f },
26                ty,
27            })
28        })?;
29        Ok(())
30    }
31
32    fn to_satn(&self) -> String {
34        Wrapper::from_ref(self).to_string()
35    }
36
37    fn to_satn_pretty(&self) -> String {
39        format!("{:#}", Wrapper::from_ref(self))
40    }
41}
42
43impl<T: ser::Serialize + ?Sized> Satn for T {}
44
/// A transparent wrapper around `T` whose `Display` and `Debug` implementations
/// print the wrapped value in SATN form.
#[repr(transparent)]
pub struct Wrapper<T: ?Sized>(pub T);

impl<T: ?Sized> Wrapper<T> {
    /// Reinterprets `&T` as `&Wrapper<T>` without copying.
    pub fn from_ref(t: &T) -> &Self {
        let ptr = t as *const T as *const Self;
        // SAFETY: `Wrapper<T>` is `#[repr(transparent)]` over its single field `T`,
        // so it has the same layout and pointer metadata as `T`. The returned
        // reference inherits the lifetime of `t`, so it cannot outlive the value.
        unsafe { &*ptr }
    }
}
59
60impl<T: Satn + ?Sized> fmt::Display for Wrapper<T> {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62        self.0.fmt(f)
63    }
64}
65
66impl<T: Satn + ?Sized> fmt::Debug for Wrapper<T> {
67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68        self.0.fmt(f)
69    }
70}
71
/// A value paired with the [`PsqlType`] that describes it.
///
/// Its `Display` and `Debug` implementations print `value` through
/// [`Satn::fmt_psql`], using `ty`.
pub struct PsqlWrapper<'a, T: ?Sized> {
    /// Type information used to choose the PostgreSQL-style rendering.
    pub ty: PsqlType<'a>,
    /// The value to print.
    pub value: T,
}
79
80impl<T: ?Sized> PsqlWrapper<'_, T> {
81    pub fn from_ref(t: &T) -> &Self {
83        unsafe { &*(t as *const T as *const Self) }
86    }
87}
88
89impl<T: Satn + ?Sized> fmt::Display for PsqlWrapper<'_, T> {
90    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91        self.value.fmt_psql(f, &self.ty)
92    }
93}
94
95impl<T: Satn + ?Sized> fmt::Debug for PsqlWrapper<'_, T> {
96    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97        self.value.fmt_psql(f, &self.ty)
98    }
99}
100
/// Writes a sequence of entries separated by `SEP`, in either the compact or the
/// pretty (indented) layout of the wrapped [`Writer`].
struct EntryWrapper<'a, 'f, const SEP: char> {
    /// The writer the entries are written to.
    fmt: Writer<'a, 'f>,
    /// Whether an entry has already been written; controls separator and
    /// leading-newline emission in `entry`.
    has_fields: bool,
}
109
impl<'a, 'f, const SEP: char> EntryWrapper<'a, 'f, SEP> {
    /// Wraps `fmt`; no entries have been written yet.
    fn new(fmt: Writer<'a, 'f>) -> Self {
        Self { fmt, has_fields: false }
    }

    /// Writes one entry by calling `entry` with a writer for it, plus the
    /// surrounding separators.
    ///
    /// - Pretty mode: a newline is written before the first entry. Each entry is
    ///   written one indentation level deeper and followed by `SEP` and a newline.
    /// - Normal mode: every entry after the first is preceded by `SEP` and a space.
    ///
    /// `has_fields` is set to `true` whether or not writing succeeded.
    fn entry(&mut self, entry: impl FnOnce(Writer) -> fmt::Result) -> fmt::Result {
        // Immediately-invoked closure: lets `?` short-circuit the writes while the
        // `has_fields` update below still runs on every path.
        let res = (|| match &mut self.fmt {
            Writer::Pretty(f) => {
                if !self.has_fields {
                    f.write_char('\n')?;
                }
                // NOTE: if a write below fails, `?` returns before `indent` is restored.
                f.state.indent += 1;
                entry(Writer::Pretty(f.as_mut()))?;
                f.write_char(SEP)?;
                f.write_char('\n')?;
                f.state.indent -= 1;
                Ok(())
            }
            Writer::Normal(f) => {
                if self.has_fields {
                    f.write_char(SEP)?;
                    f.write_char(' ')?;
                }
                entry(Writer::Normal(f))
            }
        })();
        self.has_fields = true;
        res
    }
}
144
/// Destination for formatted output. It is either a plain `fmt::Formatter`
/// or an [`IndentedWriter`] used for pretty (`{:#}`) output.
enum Writer<'a, 'f> {
    /// Compact output, written directly to the formatter.
    Normal(&'a mut fmt::Formatter<'f>),
    /// Pretty output, indented according to the writer's [`IndentState`].
    Pretty(IndentedWriter<'a, 'f>),
}
152
153impl<'f> Writer<'_, 'f> {
154    fn with<R>(f: &mut fmt::Formatter<'_>, func: impl FnOnce(Writer<'_, '_>) -> R) -> R {
156        let mut state;
157        let f = if f.alternate() {
159            state = IndentState {
160                indent: 0,
161                on_newline: true,
162            };
163            Writer::Pretty(IndentedWriter { f, state: &mut state })
164        } else {
165            Writer::Normal(f)
166        };
167        func(f)
168    }
169
170    fn as_mut(&mut self) -> Writer<'_, 'f> {
172        match self {
173            Writer::Normal(f) => Writer::Normal(f),
174            Writer::Pretty(f) => Writer::Pretty(f.as_mut()),
175        }
176    }
177}
178
/// A `fmt::Write` adapter that writes the current indentation at the start of each line.
struct IndentedWriter<'a, 'f> {
    /// The formatter the indented text is written to.
    f: &'a mut fmt::Formatter<'f>,
    /// Indentation state. It is borrowed rather than owned so that writers
    /// reborrowed via `as_mut` share and update the same state.
    state: &'a mut IndentState,
}
184
/// Mutable state shared by the [`IndentedWriter`]s of one pretty-printing pass.
struct IndentState {
    /// Current nesting depth; each level is written as four spaces.
    indent: u32,
    /// Whether the most recently written text ended with `'\n'`. If so, the next
    /// text starts a new line and is indented first.
    on_newline: bool,
}
192
193impl<'f> IndentedWriter<'_, 'f> {
194    fn as_mut(&mut self) -> IndentedWriter<'_, 'f> {
196        IndentedWriter {
197            f: self.f,
198            state: self.state,
199        }
200    }
201}
202
203impl fmt::Write for IndentedWriter<'_, '_> {
204    fn write_str(&mut self, s: &str) -> fmt::Result {
205        for s in s.split_inclusive('\n') {
206            if self.state.on_newline {
207                for _ in 0..self.state.indent {
209                    self.f.write_str("    ")?;
210                }
211            }
212
213            self.state.on_newline = s.ends_with('\n');
214            self.f.write_str(s)?;
215        }
216        Ok(())
217    }
218}
219
220impl fmt::Write for Writer<'_, '_> {
221    fn write_str(&mut self, s: &str) -> fmt::Result {
222        match self {
223            Writer::Normal(f) => f.write_str(s),
224            Writer::Pretty(f) => f.write_str(s),
225        }
226    }
227}
228
/// A [`ser::Serializer`] that writes values as SATN text.
struct SatnFormatter<'a, 'f> {
    /// The sink the text is written to.
    f: Writer<'a, 'f>,
}
234
/// Error produced while writing SATN text; a thin wrapper around [`fmt::Error`].
///
/// `From`/`Into` are derived so that `?` can convert between the two types.
#[derive(From, Into)]
struct SatnError(fmt::Error);
238
239impl ser::Error for SatnError {
240    fn custom<T: fmt::Display>(_msg: T) -> Self {
241        Self(fmt::Error)
242    }
243}
244
245impl SatnFormatter<'_, '_> {
246    #[inline(always)]
248    fn write_fmt(&mut self, args: fmt::Arguments) -> Result<(), SatnError> {
249        self.f.write_fmt(args)?;
250        Ok(())
251    }
252}
253
254impl<'a, 'f> ser::Serializer for SatnFormatter<'a, 'f> {
255    type Ok = ();
256    type Error = SatnError;
257    type SerializeArray = ArrayFormatter<'a, 'f>;
258    type SerializeSeqProduct = SeqFormatter<'a, 'f>;
259    type SerializeNamedProduct = NamedFormatter<'a, 'f>;
260
261    fn serialize_bool(mut self, v: bool) -> Result<Self::Ok, Self::Error> {
262        write!(self, "{v}")
263    }
264    fn serialize_u8(mut self, v: u8) -> Result<Self::Ok, Self::Error> {
265        write!(self, "{v}")
266    }
267    fn serialize_u16(mut self, v: u16) -> Result<Self::Ok, Self::Error> {
268        write!(self, "{v}")
269    }
270    fn serialize_u32(mut self, v: u32) -> Result<Self::Ok, Self::Error> {
271        write!(self, "{v}")
272    }
273    fn serialize_u64(mut self, v: u64) -> Result<Self::Ok, Self::Error> {
274        write!(self, "{v}")
275    }
276    fn serialize_u128(mut self, v: u128) -> Result<Self::Ok, Self::Error> {
277        write!(self, "{v}")
278    }
279    fn serialize_u256(mut self, v: u256) -> Result<Self::Ok, Self::Error> {
280        write!(self, "{v}")
281    }
282    fn serialize_i8(mut self, v: i8) -> Result<Self::Ok, Self::Error> {
283        write!(self, "{v}")
284    }
285    fn serialize_i16(mut self, v: i16) -> Result<Self::Ok, Self::Error> {
286        write!(self, "{v}")
287    }
288    fn serialize_i32(mut self, v: i32) -> Result<Self::Ok, Self::Error> {
289        write!(self, "{v}")
290    }
291    fn serialize_i64(mut self, v: i64) -> Result<Self::Ok, Self::Error> {
292        write!(self, "{v}")
293    }
294    fn serialize_i128(mut self, v: i128) -> Result<Self::Ok, Self::Error> {
295        write!(self, "{v}")
296    }
297    fn serialize_i256(mut self, v: i256) -> Result<Self::Ok, Self::Error> {
298        write!(self, "{v}")
299    }
300    fn serialize_f32(mut self, v: f32) -> Result<Self::Ok, Self::Error> {
301        write!(self, "{v}")
302    }
303    fn serialize_f64(mut self, v: f64) -> Result<Self::Ok, Self::Error> {
304        write!(self, "{v}")
305    }
306
307    fn serialize_str(mut self, v: &str) -> Result<Self::Ok, Self::Error> {
308        write!(self, "\"{}\"", v)
309    }
310
311    fn serialize_bytes(mut self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
312        write!(self, "0x{}", hex::encode(v))
313    }
314
315    fn serialize_array(mut self, _len: usize) -> Result<Self::SerializeArray, Self::Error> {
316        write!(self, "[")?; Ok(ArrayFormatter {
318            f: EntryWrapper::new(self.f),
319        })
320    }
321
322    fn serialize_seq_product(self, len: usize) -> Result<Self::SerializeSeqProduct, Self::Error> {
323        self.serialize_named_product(len).map(|inner| SeqFormatter { inner })
325    }
326
327    fn serialize_named_product(mut self, _len: usize) -> Result<Self::SerializeNamedProduct, Self::Error> {
328        write!(self, "(")?; Ok(NamedFormatter {
330            f: EntryWrapper::new(self.f),
331            idx: 0,
332        })
333    }
334
335    fn serialize_variant<T: ser::Serialize + ?Sized>(
336        mut self,
337        _tag: u8,
338        name: Option<&str>,
339        value: &T,
340    ) -> Result<Self::Ok, Self::Error> {
341        write!(self, "(")?;
342        EntryWrapper::<','>::new(self.f.as_mut()).entry(|mut f| {
343            if let Some(name) = name {
344                write!(f, "{}", name)?;
345            }
346            write!(f, " = ")?;
347            value.serialize(SatnFormatter { f })?;
348            Ok(())
349        })?;
350        write!(self, ")")
351    }
352
353    unsafe fn serialize_bsatn(self, ty: &crate::AlgebraicType, bsatn: &[u8]) -> Result<Self::Ok, Self::Error> {
354        let res = unsafe { ValueSerializer.serialize_bsatn(ty, bsatn) };
360        let value = res.unwrap_or_else(|x| match x {});
361
362        value.serialize(self)
364    }
365
366    unsafe fn serialize_bsatn_in_chunks<'c, I: Clone + Iterator<Item = &'c [u8]>>(
367        self,
368        ty: &crate::AlgebraicType,
369        total_bsatn_len: usize,
370        bsatn: I,
371    ) -> Result<Self::Ok, Self::Error> {
372        let res = unsafe { ValueSerializer.serialize_bsatn_in_chunks(ty, total_bsatn_len, bsatn) };
379        let value = res.unwrap_or_else(|x| match x {});
380
381        value.serialize(self)
383    }
384
385    unsafe fn serialize_str_in_chunks<'c, I: Clone + Iterator<Item = &'c [u8]>>(
386        self,
387        total_len: usize,
388        string: I,
389    ) -> Result<Self::Ok, Self::Error> {
390        let res = unsafe { ValueSerializer.serialize_str_in_chunks(total_len, string) };
393        let value = res.unwrap_or_else(|x| match x {});
394
395        value.serialize(self)
398    }
399}
400
/// Serializes array elements as comma-separated entries; created by
/// `SatnFormatter::serialize_array`, which has already written `[`.
struct ArrayFormatter<'a, 'f> {
    /// Writes the elements and their separators.
    f: EntryWrapper<'a, 'f, ','>,
}
406
407impl ser::SerializeArray for ArrayFormatter<'_, '_> {
408    type Ok = ();
409    type Error = SatnError;
410
411    fn serialize_element<T: ser::Serialize + ?Sized>(&mut self, elem: &T) -> Result<(), Self::Error> {
412        self.f.entry(|f| elem.serialize(SatnFormatter { f }).map_err(|e| e.0))?;
413        Ok(())
414    }
415
416    fn end(mut self) -> Result<Self::Ok, Self::Error> {
417        write!(self.f.fmt, "]")?;
418        Ok(())
419    }
420}
421
/// Serializes an unnamed product by delegating to [`NamedFormatter`] with no names.
struct SeqFormatter<'a, 'f> {
    /// The named formatter that does the actual writing.
    inner: NamedFormatter<'a, 'f>,
}
427
428impl ser::SerializeSeqProduct for SeqFormatter<'_, '_> {
429    type Ok = ();
430    type Error = SatnError;
431
432    fn serialize_element<T: ser::Serialize + ?Sized>(&mut self, elem: &T) -> Result<(), Self::Error> {
433        ser::SerializeNamedProduct::serialize_element(&mut self.inner, None, elem)
434    }
435
436    fn end(self) -> Result<Self::Ok, Self::Error> {
437        ser::SerializeNamedProduct::end(self.inner)
438    }
439}
440
/// Serializes product elements as comma-separated `label = value` entries.
/// It is created by `SatnFormatter::serialize_named_product`, which has already
/// written `(`.
struct NamedFormatter<'a, 'f> {
    /// Writes the elements and their separators.
    f: EntryWrapper<'a, 'f, ','>,
    /// Position of the next element; used as the label for unnamed elements.
    idx: usize,
}
448
449impl ser::SerializeNamedProduct for NamedFormatter<'_, '_> {
450    type Ok = ();
451    type Error = SatnError;
452
453    fn serialize_element<T: ser::Serialize + ?Sized>(
454        &mut self,
455        name: Option<&str>,
456        elem: &T,
457    ) -> Result<(), Self::Error> {
458        let res = self.f.entry(|mut f| {
459            if let Some(name) = name {
461                write!(f, "{}", name)?;
462            } else {
463                write!(f, "{}", self.idx)?;
464            }
465            write!(f, " = ")?;
466            elem.serialize(SatnFormatter { f })?;
467            Ok(())
468        });
469        self.idx += 1;
470        res?;
471        Ok(())
472    }
473
474    fn end(mut self) -> Result<Self::Ok, Self::Error> {
475        write!(self.f.fmt, ")")?;
476        Ok(())
477    }
478}
479
/// An [`EntryWrapper`] plus the per-product state needed by the PostgreSQL formatter.
struct PsqlEntryWrapper<'a, 'f, const SEP: char> {
    /// Writes the elements and their separators.
    entry: EntryWrapper<'a, 'f, SEP>,
    /// Position of the next element. Only elements printed in plain SATN form are
    /// counted. Also used to look up the element type of a nested product.
    idx: usize,
    /// Type information for the product being printed.
    ty: &'a PsqlType<'a>,
}
486
/// Serializes the elements of a product in PostgreSQL-flavoured SATN.
struct PsqlNamedFormatter<'a, 'f> {
    /// The element writer and per-product state.
    f: PsqlEntryWrapper<'a, 'f, ','>,
    /// `true` until the opening `(` has been written.
    start: bool,
    /// The print format chosen for the most recently serialized element.
    use_fmt: PsqlPrintFmt,
}
496
497impl<'a, 'f> PsqlNamedFormatter<'a, 'f> {
498    pub fn new(ty: &'a PsqlType<'a>, f: Writer<'a, 'f>) -> Self {
499        Self {
500            start: true,
501            f: PsqlEntryWrapper {
502                entry: EntryWrapper::new(f),
503                idx: 0,
504                ty,
505            },
506            use_fmt: PsqlPrintFmt::Satn,
508        }
509    }
510}
511
512impl ser::SerializeNamedProduct for PsqlNamedFormatter<'_, '_> {
513    type Ok = ();
514    type Error = SatnError;
515
516    fn serialize_element<T: Satn + ser::Serialize + ?Sized>(
517        &mut self,
518        name: Option<&str>,
519        elem: &T,
520    ) -> Result<(), Self::Error> {
521        self.use_fmt = self.f.ty.use_fmt(name);
524        let res = self.f.entry.entry(|mut f| {
525            let PsqlType { tuple, field, idx } = self.f.ty;
526            if !self.use_fmt.is_special() {
527                if self.start {
528                    write!(f, "(")?;
529                    self.start = false;
530                }
531                if let Some(name) = name {
533                    write!(f, "{}", name)?;
534                } else {
535                    write!(f, "{}", idx)?;
536                }
537                write!(f, " = ")?;
538            }
539            let (tuple, field, idx) = if let Some(product) = field.algebraic_type.as_product() {
541                (product, &product.elements[self.f.idx], self.f.idx)
542            } else {
543                (*tuple, *field, *idx)
544            };
545
546            elem.serialize(PsqlFormatter {
547                fmt: SatnFormatter { f },
548                ty: &PsqlType { tuple, field, idx },
549            })?;
550
551            Ok(())
552        });
553
554        if !self.use_fmt.is_special() {
556            self.f.idx += 1;
557        }
558
559        res?;
560
561        Ok(())
562    }
563
564    fn end(mut self) -> Result<Self::Ok, Self::Error> {
565        if !self.use_fmt.is_special() {
566            write!(self.f.entry.fmt, ")")?;
567        }
568        Ok(())
569    }
570}
571
/// Serializes an unnamed product by delegating to [`PsqlNamedFormatter`] with no names.
struct PsqlSeqFormatter<'a, 'f> {
    /// The named formatter that does the actual writing.
    inner: PsqlNamedFormatter<'a, 'f>,
}
577
578impl ser::SerializeSeqProduct for PsqlSeqFormatter<'_, '_> {
579    type Ok = ();
580    type Error = SatnError;
581
582    fn serialize_element<T: ser::Serialize + ?Sized>(&mut self, elem: &T) -> Result<(), Self::Error> {
583        ser::SerializeNamedProduct::serialize_element(&mut self.inner, None, elem)
584    }
585
586    fn end(self) -> Result<Self::Ok, Self::Error> {
587        ser::SerializeNamedProduct::end(self.inner)
588    }
589}
590
/// How a value is rendered by the PostgreSQL-flavoured formatter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PsqlPrintFmt {
    /// Print as hexadecimal bytes (chosen for identity and connection-id values).
    Hex,
    /// Print an `i64` count of microseconds since the Unix epoch as a [`Timestamp`].
    Timestamp,
    /// Print an `i64` count of microseconds as a [`TimeDuration`].
    Duration,
    /// Print in plain SATN form.
    Satn,
}

impl PsqlPrintFmt {
    /// Returns `true` for every format other than [`PsqlPrintFmt::Satn`].
    fn is_special(&self) -> bool {
        !matches!(self, PsqlPrintFmt::Satn)
    }
}
609
/// Describes where a value sits in a product type, so the PostgreSQL formatter
/// can decide how to print it.
#[derive(Debug, Clone)]
pub struct PsqlType<'a> {
    /// The product type that contains the value.
    pub tuple: &'a ProductType,
    /// The element of `tuple` that the value belongs to.
    pub field: &'a ProductTypeElement,
    /// The index of `field` within `tuple`.
    pub idx: usize,
}
620
621impl PsqlType<'_> {
622    fn use_fmt(&self, name: Option<&str>) -> PsqlPrintFmt {
626        if self.tuple.is_identity()
627            || self.tuple.is_connection_id()
628            || self.field.algebraic_type.is_identity()
629            || self.field.algebraic_type.is_connection_id()
630            || name.map(ProductType::is_identity_tag).unwrap_or_default()
631            || name.map(ProductType::is_connection_id_tag).unwrap_or_default()
632        {
633            return PsqlPrintFmt::Hex;
634        };
635
636        if self.tuple.is_timestamp()
637            || self.field.algebraic_type.is_timestamp()
638            || name.map(ProductType::is_timestamp_tag).unwrap_or_default()
639        {
640            return PsqlPrintFmt::Timestamp;
641        };
642
643        if self.tuple.is_time_duration()
644            || self.field.algebraic_type.is_time_duration()
645            || name.map(ProductType::is_time_duration_tag).unwrap_or_default()
646        {
647            return PsqlPrintFmt::Duration;
648        };
649
650        PsqlPrintFmt::Satn
651    }
652}
653
/// A [`ser::Serializer`] that prints values like [`SatnFormatter`], except that
/// values whose type calls for it are printed as hex, timestamps, or durations.
struct PsqlFormatter<'a, 'f> {
    /// The plain SATN formatter, used for output and for non-special values.
    fmt: SatnFormatter<'a, 'f>,
    /// Type information used to choose the print format.
    ty: &'a PsqlType<'a>,
}
659
660impl<'a, 'f> ser::Serializer for PsqlFormatter<'a, 'f> {
661    type Ok = ();
662    type Error = SatnError;
663    type SerializeArray = ArrayFormatter<'a, 'f>;
664    type SerializeSeqProduct = PsqlSeqFormatter<'a, 'f>;
665    type SerializeNamedProduct = PsqlNamedFormatter<'a, 'f>;
666
667    fn serialize_bool(self, v: bool) -> Result<Self::Ok, Self::Error> {
668        self.fmt.serialize_bool(v)
669    }
670    fn serialize_u8(self, v: u8) -> Result<Self::Ok, Self::Error> {
671        self.fmt.serialize_u8(v)
672    }
673    fn serialize_u16(self, v: u16) -> Result<Self::Ok, Self::Error> {
674        self.fmt.serialize_u16(v)
675    }
676    fn serialize_u32(self, v: u32) -> Result<Self::Ok, Self::Error> {
677        self.fmt.serialize_u32(v)
678    }
679    fn serialize_u64(self, v: u64) -> Result<Self::Ok, Self::Error> {
680        self.fmt.serialize_u64(v)
681    }
682    fn serialize_u128(self, v: u128) -> Result<Self::Ok, Self::Error> {
683        match self.ty.use_fmt(None) {
684            PsqlPrintFmt::Hex => self.serialize_bytes(&v.to_be_bytes()),
685            _ => self.fmt.serialize_u128(v),
686        }
687    }
688    fn serialize_u256(self, v: u256) -> Result<Self::Ok, Self::Error> {
689        match self.ty.use_fmt(None) {
690            PsqlPrintFmt::Hex => self.serialize_bytes(&v.to_be_bytes()),
691            _ => self.fmt.serialize_u256(v),
692        }
693    }
694    fn serialize_i8(self, v: i8) -> Result<Self::Ok, Self::Error> {
695        self.fmt.serialize_i8(v)
696    }
697    fn serialize_i16(self, v: i16) -> Result<Self::Ok, Self::Error> {
698        self.fmt.serialize_i16(v)
699    }
700    fn serialize_i32(self, v: i32) -> Result<Self::Ok, Self::Error> {
701        self.fmt.serialize_i32(v)
702    }
703    fn serialize_i64(mut self, v: i64) -> Result<Self::Ok, Self::Error> {
704        match self.ty.use_fmt(None) {
705            PsqlPrintFmt::Duration => {
706                write!(self.fmt, "{}", TimeDuration::from_micros(v))?;
707                Ok(())
708            }
709            PsqlPrintFmt::Timestamp => {
710                write!(self.fmt, "{}", Timestamp::from_micros_since_unix_epoch(v))?;
711                Ok(())
712            }
713            _ => self.fmt.serialize_i64(v),
714        }
715    }
716    fn serialize_i128(self, v: i128) -> Result<Self::Ok, Self::Error> {
717        self.fmt.serialize_i128(v)
718    }
719    fn serialize_i256(self, v: i256) -> Result<Self::Ok, Self::Error> {
720        self.fmt.serialize_i256(v)
721    }
722    fn serialize_f32(self, v: f32) -> Result<Self::Ok, Self::Error> {
723        self.fmt.serialize_f32(v)
724    }
725    fn serialize_f64(self, v: f64) -> Result<Self::Ok, Self::Error> {
726        self.fmt.serialize_f64(v)
727    }
728
729    fn serialize_str(self, v: &str) -> Result<Self::Ok, Self::Error> {
730        self.fmt.serialize_str(v)
731    }
732
733    fn serialize_bytes(self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
734        self.fmt.serialize_bytes(v)
735    }
736
737    fn serialize_array(self, len: usize) -> Result<Self::SerializeArray, Self::Error> {
738        self.fmt.serialize_array(len)
739    }
740
741    fn serialize_seq_product(self, len: usize) -> Result<Self::SerializeSeqProduct, Self::Error> {
742        Ok(PsqlSeqFormatter {
743            inner: self.serialize_named_product(len)?,
744        })
745    }
746
747    fn serialize_named_product(self, _len: usize) -> Result<Self::SerializeNamedProduct, Self::Error> {
748        Ok(PsqlNamedFormatter::new(self.ty, self.fmt.f))
749    }
750
751    fn serialize_variant<T: ser::Serialize + ?Sized>(
752        self,
753        tag: u8,
754        name: Option<&str>,
755        value: &T,
756    ) -> Result<Self::Ok, Self::Error> {
757        self.fmt.serialize_variant(tag, name, value)
758    }
759
760    unsafe fn serialize_bsatn(self, ty: &crate::AlgebraicType, bsatn: &[u8]) -> Result<Self::Ok, Self::Error> {
761        unsafe { self.fmt.serialize_bsatn(ty, bsatn) }
763    }
764
765    unsafe fn serialize_bsatn_in_chunks<'c, I: Clone + Iterator<Item = &'c [u8]>>(
766        self,
767        ty: &crate::AlgebraicType,
768        total_bsatn_len: usize,
769        bsatn: I,
770    ) -> Result<Self::Ok, Self::Error> {
771        unsafe { self.fmt.serialize_bsatn_in_chunks(ty, total_bsatn_len, bsatn) }
773    }
774
775    unsafe fn serialize_str_in_chunks<'c, I: Clone + Iterator<Item = &'c [u8]>>(
776        self,
777        total_len: usize,
778        string: I,
779    ) -> Result<Self::Ok, Self::Error> {
780        unsafe { self.fmt.serialize_str_in_chunks(total_len, string) }
782    }
783}