Skip to main content

sentinel_driver/types/
range.rs

1use bytes::{BufMut, BytesMut};
2
3use crate::error::{Error, Result};
4use crate::types::{FromSql, Oid, ToSql};
5
// Bits of the leading flag byte in the binary range encoding used below.
/// The range is empty; no bound payloads follow the flag byte.
const RANGE_EMPTY: u8 = 1 << 0;
/// The lower bound value belongs to the range (inclusive).
const RANGE_LB_INC: u8 = 1 << 1;
/// The upper bound value belongs to the range (inclusive).
const RANGE_UB_INC: u8 = 1 << 2;
/// The lower side is unbounded; no lower-bound payload is present.
const RANGE_LB_INF: u8 = 1 << 3;
/// The upper side is unbounded; no upper-bound payload is present.
const RANGE_UB_INF: u8 = 1 << 4;
11
/// One end of a PostgreSQL range.
///
/// A bound either carries a value (which may or may not itself belong to the
/// range) or is absent altogether, meaning the range is open on that side.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RangeBound<T> {
    /// The bound value is part of the range.
    Inclusive(T),
    /// The bound value limits the range but is not part of it.
    Exclusive(T),
    /// No limit on this side of the range.
    Unbounded,
}
19
20/// PostgreSQL range type.
21///
22/// Generic over the element type `T`. The `range_oid` and `element_oid` must
23/// be provided since Rust generics cannot map to PG range OIDs automatically.
24#[derive(Debug, Clone, PartialEq, Eq)]
25pub struct PgRange<T> {
26    pub lower: RangeBound<T>,
27    pub upper: RangeBound<T>,
28    pub is_empty: bool,
29    pub range_oid: Oid,
30    pub element_oid: Oid,
31}
32
33impl<T> PgRange<T> {
34    /// Create an empty range.
35    pub fn empty(range_oid: Oid, element_oid: Oid) -> Self {
36        PgRange {
37            lower: RangeBound::Unbounded,
38            upper: RangeBound::Unbounded,
39            is_empty: true,
40            range_oid,
41            element_oid,
42        }
43    }
44}
45
46impl<T: ToSql> ToSql for PgRange<T> {
47    fn oid(&self) -> Oid {
48        self.range_oid
49    }
50
51    fn to_sql(&self, buf: &mut BytesMut) -> Result<()> {
52        if self.is_empty {
53            buf.put_u8(RANGE_EMPTY);
54            return Ok(());
55        }
56
57        let mut flags: u8 = 0;
58
59        match &self.lower {
60            RangeBound::Inclusive(_) => flags |= RANGE_LB_INC,
61            RangeBound::Exclusive(_) => {}
62            RangeBound::Unbounded => flags |= RANGE_LB_INF,
63        }
64
65        match &self.upper {
66            RangeBound::Inclusive(_) => flags |= RANGE_UB_INC,
67            RangeBound::Exclusive(_) => {}
68            RangeBound::Unbounded => flags |= RANGE_UB_INF,
69        }
70
71        buf.put_u8(flags);
72
73        // Encode lower bound
74        match &self.lower {
75            RangeBound::Inclusive(v) | RangeBound::Exclusive(v) => {
76                let len_pos = buf.len();
77                buf.put_i32(0); // placeholder
78                let data_start = buf.len();
79                v.to_sql(buf)?;
80                let data_len = (buf.len() - data_start) as i32;
81                buf[len_pos..len_pos + 4].copy_from_slice(&data_len.to_be_bytes());
82            }
83            RangeBound::Unbounded => {}
84        }
85
86        // Encode upper bound
87        match &self.upper {
88            RangeBound::Inclusive(v) | RangeBound::Exclusive(v) => {
89                let len_pos = buf.len();
90                buf.put_i32(0); // placeholder
91                let data_start = buf.len();
92                v.to_sql(buf)?;
93                let data_len = (buf.len() - data_start) as i32;
94                buf[len_pos..len_pos + 4].copy_from_slice(&data_len.to_be_bytes());
95            }
96            RangeBound::Unbounded => {}
97        }
98
99        Ok(())
100    }
101}
102
103impl<T: FromSql> PgRange<T> {
104    /// Decode a range from binary format. Requires OIDs since generic types
105    /// cannot determine them.
106    pub fn from_sql_with_oids(buf: &[u8], range_oid: Oid, element_oid: Oid) -> Result<Self> {
107        if buf.is_empty() {
108            return Err(Error::Decode("range: empty buffer".into()));
109        }
110
111        let flags = buf[0];
112
113        if flags & RANGE_EMPTY != 0 {
114            return Ok(PgRange::empty(range_oid, element_oid));
115        }
116
117        let mut offset = 1;
118
119        let lower = if flags & RANGE_LB_INF != 0 {
120            RangeBound::Unbounded
121        } else {
122            if offset + 4 > buf.len() {
123                return Err(Error::Decode("range: lower bound truncated".into()));
124            }
125            let len = i32::from_be_bytes([
126                buf[offset],
127                buf[offset + 1],
128                buf[offset + 2],
129                buf[offset + 3],
130            ]) as usize;
131            offset += 4;
132            if offset + len > buf.len() {
133                return Err(Error::Decode("range: lower bound data truncated".into()));
134            }
135            let val = T::from_sql(&buf[offset..offset + len])?;
136            offset += len;
137            if flags & RANGE_LB_INC != 0 {
138                RangeBound::Inclusive(val)
139            } else {
140                RangeBound::Exclusive(val)
141            }
142        };
143
144        let upper = if flags & RANGE_UB_INF != 0 {
145            RangeBound::Unbounded
146        } else {
147            if offset + 4 > buf.len() {
148                return Err(Error::Decode("range: upper bound truncated".into()));
149            }
150            let len = i32::from_be_bytes([
151                buf[offset],
152                buf[offset + 1],
153                buf[offset + 2],
154                buf[offset + 3],
155            ]) as usize;
156            offset += 4;
157            if offset + len > buf.len() {
158                return Err(Error::Decode("range: upper bound data truncated".into()));
159            }
160            let val = T::from_sql(&buf[offset..offset + len])?;
161            if flags & RANGE_UB_INC != 0 {
162                RangeBound::Inclusive(val)
163            } else {
164                RangeBound::Exclusive(val)
165            }
166        };
167
168        Ok(PgRange {
169            lower,
170            upper,
171            is_empty: false,
172            range_oid,
173            element_oid,
174        })
175    }
176}