1use crate::decode::Decode;
2use crate::encode::{Encode, IsNull};
3use crate::error::BoxDynError;
4use crate::types::{PgPoint, Type};
5use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
6use sqlx_core::bytes::Buf;
7use sqlx_core::Error;
8use std::mem;
9use std::str::FromStr;
10
/// Number of bytes used to encode one `f64` coordinate in the binary `path` format.
const BYTE_WIDTH: usize = std::mem::size_of::<f64>();
12
13#[derive(Debug, Clone, PartialEq)]
35pub struct PgPath {
36 pub closed: bool,
37 pub points: Vec<PgPoint>,
38}
39
/// Fixed-size prefix of the binary `path` encoding.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct Header {
    /// Whether the path is closed (encoded as a single non-zero byte).
    is_closed: bool,
    /// Number of points that follow the header.
    length: usize,
}
45
46impl Type<Postgres> for PgPath {
47 fn type_info() -> PgTypeInfo {
48 PgTypeInfo::with_name("path")
49 }
50}
51
52impl PgHasArrayType for PgPath {
53 fn array_type_info() -> PgTypeInfo {
54 PgTypeInfo::with_name("_path")
55 }
56}
57
58impl<'r> Decode<'r, Postgres> for PgPath {
59 fn decode(value: PgValueRef<'r>) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
60 match value.format() {
61 PgValueFormat::Text => Ok(PgPath::from_str(value.as_str()?)?),
62 PgValueFormat::Binary => Ok(PgPath::from_bytes(value.as_bytes()?)?),
63 }
64 }
65}
66
67impl<'q> Encode<'q, Postgres> for PgPath {
68 fn produces(&self) -> Option<PgTypeInfo> {
69 Some(PgTypeInfo::with_name("path"))
70 }
71
72 fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
73 self.serialize(buf)?;
74 Ok(IsNull::No)
75 }
76}
77
78impl FromStr for PgPath {
79 type Err = Error;
80
81 fn from_str(s: &str) -> Result<Self, Self::Err> {
82 let closed = !s.contains('[');
83 let sanitised = s.replace(['(', ')', '[', ']', ' '], "");
84 let parts = sanitised.split(',').collect::<Vec<_>>();
85
86 let mut points = vec![];
87
88 if parts.len() % 2 != 0 {
89 return Err(Error::Decode(
90 format!("Unmatched pair in PATH: {}", s).into(),
91 ));
92 }
93
94 for chunk in parts.chunks_exact(2) {
95 if let [x_str, y_str] = chunk {
96 let x = parse_float_from_str(x_str, "could not get x")?;
97 let y = parse_float_from_str(y_str, "could not get y")?;
98
99 let point = PgPoint { x, y };
100 points.push(point);
101 }
102 }
103
104 if !points.is_empty() {
105 return Ok(PgPath { points, closed });
106 }
107
108 Err(Error::Decode(
109 format!("could not get path from {}", s).into(),
110 ))
111 }
112}
113
114impl PgPath {
115 fn header(&self) -> Header {
116 Header {
117 is_closed: self.closed,
118 length: self.points.len(),
119 }
120 }
121
122 fn from_bytes(mut bytes: &[u8]) -> Result<Self, BoxDynError> {
123 let header = Header::try_read(&mut bytes)?;
124
125 if bytes.len() != header.data_size() {
126 return Err(format!(
127 "expected {} bytes after header, got {}",
128 header.data_size(),
129 bytes.len()
130 )
131 .into());
132 }
133
134 if bytes.len() % BYTE_WIDTH * 2 != 0 {
135 return Err(format!(
136 "data length not divisible by pairs of {BYTE_WIDTH}: {}",
137 bytes.len()
138 )
139 .into());
140 }
141
142 let mut out_points = Vec::with_capacity(bytes.len() / (BYTE_WIDTH * 2));
143
144 while bytes.has_remaining() {
145 let point = PgPoint {
146 x: bytes.get_f64(),
147 y: bytes.get_f64(),
148 };
149 out_points.push(point)
150 }
151 Ok(PgPath {
152 closed: header.is_closed,
153 points: out_points,
154 })
155 }
156
157 fn serialize(&self, buff: &mut PgArgumentBuffer) -> Result<(), BoxDynError> {
158 let header = self.header();
159 buff.reserve(header.data_size());
160 header.try_write(buff)?;
161
162 for point in &self.points {
163 buff.extend_from_slice(&point.x.to_be_bytes());
164 buff.extend_from_slice(&point.y.to_be_bytes());
165 }
166 Ok(())
167 }
168
169 #[cfg(test)]
170 fn serialize_to_vec(&self) -> Vec<u8> {
171 let mut buff = PgArgumentBuffer::default();
172 self.serialize(&mut buff).unwrap();
173 buff.to_vec()
174 }
175}
176
177impl Header {
178 const HEADER_WIDTH: usize = mem::size_of::<i8>() + mem::size_of::<i32>();
179
180 fn data_size(&self) -> usize {
181 self.length * BYTE_WIDTH * 2
182 }
183
184 fn try_read(buf: &mut &[u8]) -> Result<Self, String> {
185 if buf.len() < Self::HEADER_WIDTH {
186 return Err(format!(
187 "expected PATH data to contain at least {} bytes, got {}",
188 Self::HEADER_WIDTH,
189 buf.len()
190 ));
191 }
192
193 let is_closed = buf.get_i8();
194 let length = buf.get_i32();
195
196 let length = usize::try_from(length).ok().ok_or_else(|| {
197 format!(
198 "received PATH data length: {length}. Expected length between 0 and {}",
199 usize::MAX
200 )
201 })?;
202
203 Ok(Self {
204 is_closed: is_closed != 0,
205 length,
206 })
207 }
208
209 fn try_write(&self, buff: &mut PgArgumentBuffer) -> Result<(), String> {
210 let is_closed = self.is_closed as i8;
211
212 let length = i32::try_from(self.length).map_err(|_| {
213 format!(
214 "PATH length exceeds allowed maximum ({} > {})",
215 self.length,
216 i32::MAX
217 )
218 })?;
219
220 buff.extend(is_closed.to_be_bytes());
221 buff.extend(length.to_be_bytes());
222
223 Ok(())
224 }
225}
226
227fn parse_float_from_str(s: &str, error_msg: &str) -> Result<f64, Error> {
228 s.parse().map_err(|_| Error::Decode(error_msg.into()))
229}
230
#[cfg(test)]
mod path_tests {

    use std::str::FromStr;

    use crate::types::PgPoint;

    use super::PgPath;

    /// Closed path with points (1, 2) and (3, 4).
    const PATH_CLOSED_BYTES: &[u8] = &[
        1, 0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0,
        64, 16, 0, 0, 0, 0, 0, 0,
    ];

    /// Open path with points (1, 2) and (3, 4).
    const PATH_OPEN_BYTES: &[u8] = &[
        0, 0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0,
        64, 16, 0, 0, 0, 0, 0, 0,
    ];

    /// Header declares 2 points (32 bytes) but only 28 bytes of data follow.
    const PATH_UNEVEN_POINTS: &[u8] = &[
        0, 0, 0, 0, 2, 63, 240, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 0, 0, 0, 0, 64, 8, 0, 0, 0, 0, 0, 0,
        64, 16, 0, 0,
    ];

    #[test]
    fn can_deserialise_path_type_bytes_closed() {
        let path = PgPath::from_bytes(PATH_CLOSED_BYTES).unwrap();
        assert_eq!(
            path,
            PgPath {
                closed: true,
                points: vec![PgPoint { x: 1.0, y: 2.0 }, PgPoint { x: 3.0, y: 4.0 }]
            }
        )
    }

    #[test]
    fn cannot_deserialise_path_type_uneven_point_bytes() {
        let err = PgPath::from_bytes(PATH_UNEVEN_POINTS).unwrap_err();
        assert_eq!(err.to_string(), "expected 32 bytes after header, got 28");
    }

    #[test]
    fn cannot_deserialise_path_type_truncated_header() {
        let err = PgPath::from_bytes(&[1, 0, 0]).unwrap_err();
        assert_eq!(
            err.to_string(),
            "expected PATH data to contain at least 5 bytes, got 3"
        );
    }

    #[test]
    fn cannot_deserialise_path_type_negative_length() {
        // Length field is the big-endian i32 -1.
        let err = PgPath::from_bytes(&[0, 255, 255, 255, 255]).unwrap_err();
        assert_eq!(
            err.to_string(),
            format!(
                "received PATH data length: -1. Expected length between 0 and {}",
                usize::MAX
            )
        );
    }

    #[test]
    fn can_deserialise_path_type_bytes_open() {
        let path = PgPath::from_bytes(PATH_OPEN_BYTES).unwrap();
        assert_eq!(
            path,
            PgPath {
                closed: false,
                points: vec![PgPoint { x: 1.0, y: 2.0 }, PgPoint { x: 3.0, y: 4.0 }]
            }
        )
    }

    #[test]
    fn can_deserialise_path_type_str_first_syntax() {
        let path = PgPath::from_str("[( 1, 2), (3, 4 )]").unwrap();
        assert_eq!(
            path,
            PgPath {
                closed: false,
                points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }]
            }
        );
    }

    #[test]
    fn cannot_deserialise_path_type_str_uneven_points_first_syntax() {
        let input_str = "[( 1, 2), (3)]";
        let err = PgPath::from_str(input_str).unwrap_err();

        assert_eq!(
            err.to_string(),
            format!("error occurred while decoding: Unmatched pair in PATH: {input_str}")
        );
    }

    #[test]
    fn can_deserialise_path_type_str_second_syntax() {
        let path = PgPath::from_str("(( 1, 2), (3, 4 ))").unwrap();
        assert_eq!(
            path,
            PgPath {
                closed: true,
                points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }]
            }
        );
    }

    #[test]
    fn can_deserialise_path_type_str_third_syntax() {
        let path = PgPath::from_str("(1, 2), (3, 4 )").unwrap();
        assert_eq!(
            path,
            PgPath {
                closed: true,
                points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }]
            }
        );
    }

    #[test]
    fn can_deserialise_path_type_str_fourth_syntax() {
        let path = PgPath::from_str("1, 2, 3, 4").unwrap();
        assert_eq!(
            path,
            PgPath {
                closed: true,
                points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }]
            }
        );
    }

    #[test]
    fn can_deserialise_path_type_str_float() {
        let path = PgPath::from_str("(1.1, 2.2), (3.3, 4.4)").unwrap();
        assert_eq!(
            path,
            PgPath {
                closed: true,
                points: vec![PgPoint { x: 1.1, y: 2.2 }, PgPoint { x: 3.3, y: 4.4 }]
            }
        );
    }

    #[test]
    fn can_serialise_path_type() {
        let path = PgPath {
            closed: true,
            points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }],
        };
        assert_eq!(path.serialize_to_vec(), PATH_CLOSED_BYTES);
    }

    #[test]
    fn can_serialise_open_path_type() {
        let path = PgPath {
            closed: false,
            points: vec![PgPoint { x: 1., y: 2. }, PgPoint { x: 3., y: 4. }],
        };
        assert_eq!(path.serialize_to_vec(), PATH_OPEN_BYTES);
    }

    #[test]
    fn serialise_then_deserialise_round_trips() {
        let path = PgPath {
            closed: false,
            points: vec![
                PgPoint { x: -1.5, y: 0.25 },
                PgPoint { x: 1e10, y: -3.75 },
                PgPoint { x: 0.0, y: 42.0 },
            ],
        };
        let bytes = path.serialize_to_vec();
        assert_eq!(PgPath::from_bytes(&bytes).unwrap(), path);
    }
}