Skip to main content

use_pg_type/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
/// A PostgreSQL type name label.
///
/// Labels created with [`PgTypeName::new`] (and the `FromStr` / `TryFrom<&str>`
/// impls) have surrounding whitespace trimmed. Those constructors reject empty labels
/// and labels containing control characters. [`PgTypeName::built_in`] and
/// [`PgTypeName::array_of`] build labels directly from already-valid parts.
///
/// No case folding or other normalization is applied. Equality, ordering and hashing
/// therefore compare the exact stored text (`"INTEGER"` and `"integer"` differ).
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct PgTypeName(String);
10
11impl PgTypeName {
12    /// Creates a PostgreSQL type name label.
13    ///
14    /// # Errors
15    ///
16    /// Returns [`PgTypeError`] when the label is empty or contains control characters.
17    pub fn new(input: impl AsRef<str>) -> Result<Self, PgTypeError> {
18        validate_type_label(input.as_ref()).map(|value| Self(value.to_owned()))
19    }
20
21    /// Creates the canonical type name for a built-in type.
22    #[must_use]
23    pub fn built_in(ty: PgBuiltInType) -> Self {
24        Self(ty.as_str().to_owned())
25    }
26
27    /// Creates an array-like type label from an element type name.
28    #[must_use]
29    pub fn array_of(element: &Self) -> Self {
30        Self(format!("{}[]", element.as_str()))
31    }
32
33    /// Returns the stored type name.
34    #[must_use]
35    pub fn as_str(&self) -> &str {
36        &self.0
37    }
38
39    /// Returns `true` when the label uses PostgreSQL array suffix syntax.
40    #[must_use]
41    pub fn is_array_label(&self) -> bool {
42        self.0.ends_with("[]")
43    }
44}
45
46impl AsRef<str> for PgTypeName {
47    fn as_ref(&self) -> &str {
48        self.as_str()
49    }
50}
51
52impl fmt::Display for PgTypeName {
53    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
54        formatter.write_str(self.as_str())
55    }
56}
57
58impl FromStr for PgTypeName {
59    type Err = PgTypeError;
60
61    fn from_str(input: &str) -> Result<Self, Self::Err> {
62        Self::new(input)
63    }
64}
65
66impl TryFrom<&str> for PgTypeName {
67    type Error = PgTypeError;
68
69    fn try_from(value: &str) -> Result<Self, Self::Error> {
70        Self::new(value)
71    }
72}
73
/// Broad PostgreSQL type categories.
///
/// Each variant has a stable lowercase label returned by [`PgTypeCategory::as_str`].
/// [`PgBuiltInType::category`] assigns every built-in type to one of these categories.
/// That mapping never produces `UserDefined`, `Enum`, `Composite`, `Domain`, `Range`
/// or `Pseudo`.
///
/// The derived ordering follows declaration order. The default is
/// [`PgTypeCategory::UserDefined`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum PgTypeCategory {
    /// User-defined types (label `user-defined`). This is the default.
    #[default]
    UserDefined,
    /// Boolean types (label `boolean`).
    Boolean,
    /// Numeric types (label `numeric`): the integer types, `numeric`, `real` and
    /// `double precision`.
    Numeric,
    /// Character string types (label `string`): `text`, `character varying` and
    /// `character`.
    String,
    /// Binary data types (label `binary`), such as `bytea`.
    Binary,
    /// Date and time types (label `date-time`).
    DateTime,
    /// UUID types (label `uuid`).
    Uuid,
    /// JSON types (label `json`): `json` and `jsonb`.
    Json,
    /// Network address types (label `network`): `inet`, `cidr`, `macaddr` and
    /// `macaddr8`.
    Network,
    /// Array types (label `array`).
    Array,
    /// Enum types (label `enum`).
    Enum,
    /// Composite types (label `composite`).
    Composite,
    /// Domain types (label `domain`).
    Domain,
    /// Range types (label `range`).
    Range,
    /// Pseudo-types (label `pseudo`).
    Pseudo,
}
94
95impl PgTypeCategory {
96    /// Returns a stable lowercase category label.
97    #[must_use]
98    pub const fn as_str(self) -> &'static str {
99        match self {
100            Self::UserDefined => "user-defined",
101            Self::Boolean => "boolean",
102            Self::Numeric => "numeric",
103            Self::String => "string",
104            Self::Binary => "binary",
105            Self::DateTime => "date-time",
106            Self::Uuid => "uuid",
107            Self::Json => "json",
108            Self::Network => "network",
109            Self::Array => "array",
110            Self::Enum => "enum",
111            Self::Composite => "composite",
112            Self::Domain => "domain",
113            Self::Range => "range",
114            Self::Pseudo => "pseudo",
115        }
116    }
117}
118
119impl fmt::Display for PgTypeCategory {
120    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
121        formatter.write_str(self.as_str())
122    }
123}
124
/// Common PostgreSQL built-in type labels.
///
/// [`PgBuiltInType::as_str`] (and `Display`) yields each variant's canonical
/// spelling. The `FromStr` implementation parses that spelling back, along with
/// several common aliases. [`PgBuiltInType::category`] assigns each variant a
/// [`PgTypeCategory`].
///
/// The derived ordering follows declaration order, not the alphabetical order of the
/// labels. The default is [`PgBuiltInType::Text`].
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum PgBuiltInType {
    /// `text`. This is the default.
    #[default]
    Text,
    /// `boolean`.
    Bool,
    /// `smallint`.
    SmallInt,
    /// `integer`.
    Integer,
    /// `bigint`.
    BigInt,
    /// `numeric`.
    Numeric,
    /// `real`.
    Real,
    /// `double precision`.
    DoublePrecision,
    /// `character varying`.
    Varchar,
    /// `character`.
    Char,
    /// `bytea`.
    Bytea,
    /// `date`.
    Date,
    /// `time`.
    Time,
    /// `timestamp`.
    Timestamp,
    /// `timestamp with time zone`.
    TimestampTz,
    /// `uuid`.
    Uuid,
    /// `json`.
    Json,
    /// `jsonb`.
    Jsonb,
    /// `inet`.
    Inet,
    /// `cidr`.
    Cidr,
    /// `macaddr`.
    Macaddr,
    /// `macaddr8`.
    Macaddr8,
    /// `array`. `FromStr` also maps the label `anyarray` to this variant.
    Array,
}
153
154impl PgBuiltInType {
155    /// Returns the canonical PostgreSQL type spelling.
156    #[must_use]
157    pub const fn as_str(self) -> &'static str {
158        match self {
159            Self::Text => "text",
160            Self::Bool => "boolean",
161            Self::SmallInt => "smallint",
162            Self::Integer => "integer",
163            Self::BigInt => "bigint",
164            Self::Numeric => "numeric",
165            Self::Real => "real",
166            Self::DoublePrecision => "double precision",
167            Self::Varchar => "character varying",
168            Self::Char => "character",
169            Self::Bytea => "bytea",
170            Self::Date => "date",
171            Self::Time => "time",
172            Self::Timestamp => "timestamp",
173            Self::TimestampTz => "timestamp with time zone",
174            Self::Uuid => "uuid",
175            Self::Json => "json",
176            Self::Jsonb => "jsonb",
177            Self::Inet => "inet",
178            Self::Cidr => "cidr",
179            Self::Macaddr => "macaddr",
180            Self::Macaddr8 => "macaddr8",
181            Self::Array => "array",
182        }
183    }
184
185    /// Returns the broad type category.
186    #[must_use]
187    pub const fn category(self) -> PgTypeCategory {
188        match self {
189            Self::Bool => PgTypeCategory::Boolean,
190            Self::SmallInt
191            | Self::Integer
192            | Self::BigInt
193            | Self::Numeric
194            | Self::Real
195            | Self::DoublePrecision => PgTypeCategory::Numeric,
196            Self::Text | Self::Varchar | Self::Char => PgTypeCategory::String,
197            Self::Bytea => PgTypeCategory::Binary,
198            Self::Date | Self::Time | Self::Timestamp | Self::TimestampTz => {
199                PgTypeCategory::DateTime
200            }
201            Self::Uuid => PgTypeCategory::Uuid,
202            Self::Json | Self::Jsonb => PgTypeCategory::Json,
203            Self::Inet | Self::Cidr | Self::Macaddr | Self::Macaddr8 => PgTypeCategory::Network,
204            Self::Array => PgTypeCategory::Array,
205        }
206    }
207}
208
209impl fmt::Display for PgBuiltInType {
210    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
211        formatter.write_str(self.as_str())
212    }
213}
214
215impl FromStr for PgBuiltInType {
216    type Err = PgTypeError;
217
218    fn from_str(input: &str) -> Result<Self, Self::Err> {
219        match normalized_type_label(input)?.as_str() {
220            "bool" | "boolean" => Ok(Self::Bool),
221            "smallint" | "int2" => Ok(Self::SmallInt),
222            "integer" | "int" | "int4" => Ok(Self::Integer),
223            "bigint" | "int8" => Ok(Self::BigInt),
224            "numeric" | "decimal" => Ok(Self::Numeric),
225            "real" | "float4" => Ok(Self::Real),
226            "double precision" | "float8" => Ok(Self::DoublePrecision),
227            "text" => Ok(Self::Text),
228            "varchar" | "character varying" => Ok(Self::Varchar),
229            "char" | "character" => Ok(Self::Char),
230            "bytea" => Ok(Self::Bytea),
231            "date" => Ok(Self::Date),
232            "time" | "time without time zone" => Ok(Self::Time),
233            "timestamp" | "timestamp without time zone" => Ok(Self::Timestamp),
234            "timestamptz" | "timestamp with time zone" => Ok(Self::TimestampTz),
235            "uuid" => Ok(Self::Uuid),
236            "json" => Ok(Self::Json),
237            "jsonb" => Ok(Self::Jsonb),
238            "inet" => Ok(Self::Inet),
239            "cidr" => Ok(Self::Cidr),
240            "macaddr" => Ok(Self::Macaddr),
241            "macaddr8" => Ok(Self::Macaddr8),
242            "array" | "anyarray" => Ok(Self::Array),
243            _ => Err(PgTypeError::UnknownBuiltInType),
244        }
245    }
246}
247
248impl TryFrom<&str> for PgBuiltInType {
249    type Error = PgTypeError;
250
251    fn try_from(value: &str) -> Result<Self, Self::Error> {
252        value.parse()
253    }
254}
255
/// Optional primitive wrapper for a PostgreSQL type OID.
///
/// The field is private and [`PgTypeOid::new`] rejects zero, so a constructed
/// `PgTypeOid` always holds a non-zero `u32`. The wrapper attaches no catalog meaning
/// to the number; it does not map OIDs to type names.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct PgTypeOid(u32);
259
260impl PgTypeOid {
261    /// Creates an OID wrapper.
262    ///
263    /// # Errors
264    ///
265    /// Returns [`PgTypeError::InvalidOid`] when `value` is zero.
266    pub const fn new(value: u32) -> Result<Self, PgTypeError> {
267        if value == 0 {
268            Err(PgTypeError::InvalidOid)
269        } else {
270            Ok(Self(value))
271        }
272    }
273
274    /// Returns the raw OID value.
275    #[must_use]
276    pub const fn get(self) -> u32 {
277        self.0
278    }
279}
280
281impl fmt::Display for PgTypeOid {
282    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
283        write!(formatter, "{}", self.0)
284    }
285}
286
/// Error returned when PostgreSQL type metadata is invalid.
///
/// None of the variants carries data, so the type is `Copy` and cheap to compare.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PgTypeError {
    /// The type label was empty or contained only whitespace.
    Empty,
    /// The type label, after trimming, contained a control character.
    ControlCharacter,
    /// The label passed validation but does not name a supported built-in type.
    UnknownBuiltInType,
    /// The type OID was zero.
    InvalidOid,
}
295
296impl fmt::Display for PgTypeError {
297    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
298        match self {
299            Self::Empty => formatter.write_str("PostgreSQL type label cannot be empty"),
300            Self::ControlCharacter => {
301                formatter.write_str("PostgreSQL type label cannot contain control characters")
302            }
303            Self::UnknownBuiltInType => {
304                formatter.write_str("unknown PostgreSQL built-in type label")
305            }
306            Self::InvalidOid => {
307                formatter.write_str("PostgreSQL type OID must be greater than zero")
308            }
309        }
310    }
311}
312
// `PgTypeError` derives `Debug` and implements `Display`, so the default `Error`
// methods suffice. No variant wraps an underlying cause, so `source()` stays `None`.
impl Error for PgTypeError {}
314
315fn validate_type_label(input: &str) -> Result<&str, PgTypeError> {
316    let trimmed = input.trim();
317    if trimmed.is_empty() {
318        return Err(PgTypeError::Empty);
319    }
320    if trimmed.chars().any(char::is_control) {
321        return Err(PgTypeError::ControlCharacter);
322    }
323    Ok(trimmed)
324}
325
326fn normalized_type_label(input: &str) -> Result<String, PgTypeError> {
327    let trimmed = validate_type_label(input)?;
328    Ok(trimmed
329        .replace('_', " ")
330        .split_whitespace()
331        .collect::<Vec<_>>()
332        .join(" ")
333        .to_ascii_lowercase())
334}
335
#[cfg(test)]
mod tests {
    use super::{PgBuiltInType, PgTypeCategory, PgTypeError, PgTypeName, PgTypeOid};

    /// Every `PgBuiltInType` variant, in declaration order.
    const ALL_BUILT_INS: [PgBuiltInType; 23] = [
        PgBuiltInType::Text,
        PgBuiltInType::Bool,
        PgBuiltInType::SmallInt,
        PgBuiltInType::Integer,
        PgBuiltInType::BigInt,
        PgBuiltInType::Numeric,
        PgBuiltInType::Real,
        PgBuiltInType::DoublePrecision,
        PgBuiltInType::Varchar,
        PgBuiltInType::Char,
        PgBuiltInType::Bytea,
        PgBuiltInType::Date,
        PgBuiltInType::Time,
        PgBuiltInType::Timestamp,
        PgBuiltInType::TimestampTz,
        PgBuiltInType::Uuid,
        PgBuiltInType::Json,
        PgBuiltInType::Jsonb,
        PgBuiltInType::Inet,
        PgBuiltInType::Cidr,
        PgBuiltInType::Macaddr,
        PgBuiltInType::Macaddr8,
        PgBuiltInType::Array,
    ];

    #[test]
    fn parses_common_built_in_types() -> Result<(), PgTypeError> {
        assert_eq!("bool".parse::<PgBuiltInType>()?, PgBuiltInType::Bool);
        assert_eq!("int4".parse::<PgBuiltInType>()?, PgBuiltInType::Integer);
        assert_eq!(
            "double precision".parse::<PgBuiltInType>()?,
            PgBuiltInType::DoublePrecision
        );
        assert_eq!(
            "timestamptz".parse::<PgBuiltInType>()?,
            PgBuiltInType::TimestampTz
        );
        assert_eq!("jsonb".parse::<PgBuiltInType>()?, PgBuiltInType::Jsonb);
        Ok(())
    }

    #[test]
    fn canonical_labels_round_trip() -> Result<(), PgTypeError> {
        for ty in ALL_BUILT_INS {
            assert_eq!(ty.as_str().parse::<PgBuiltInType>()?, ty);
            assert_eq!(PgBuiltInType::try_from(ty.as_str())?, ty);
            assert_eq!(ty.to_string(), ty.as_str());
            assert_eq!(PgTypeName::built_in(ty).as_str(), ty.as_str());
        }
        Ok(())
    }

    #[test]
    fn parsing_ignores_case_spacing_and_underscores() -> Result<(), PgTypeError> {
        assert_eq!(
            "  Double   PRECISION ".parse::<PgBuiltInType>()?,
            PgBuiltInType::DoublePrecision
        );
        assert_eq!(
            "timestamp_with_time_zone".parse::<PgBuiltInType>()?,
            PgBuiltInType::TimestampTz
        );
        assert_eq!("INT8".parse::<PgBuiltInType>()?, PgBuiltInType::BigInt);
        Ok(())
    }

    #[test]
    fn rejects_invalid_or_unknown_built_in_labels() {
        assert_eq!(
            "money".parse::<PgBuiltInType>(),
            Err(PgTypeError::UnknownBuiltInType)
        );
        assert_eq!("   ".parse::<PgBuiltInType>(), Err(PgTypeError::Empty));
        assert_eq!(
            "int\u{7}4".parse::<PgBuiltInType>(),
            Err(PgTypeError::ControlCharacter)
        );
    }

    #[test]
    fn renders_canonical_labels_and_categories() {
        assert_eq!(PgBuiltInType::Varchar.to_string(), "character varying");
        assert_eq!(PgBuiltInType::Inet.category(), PgTypeCategory::Network);
        assert_eq!(PgBuiltInType::Bool.category(), PgTypeCategory::Boolean);
        assert_eq!(PgBuiltInType::Bytea.category(), PgTypeCategory::Binary);
        assert_eq!(
            PgBuiltInType::TimestampTz.category(),
            PgTypeCategory::DateTime
        );
        assert_eq!(PgTypeCategory::Array.to_string(), "array");
        assert_eq!(PgTypeCategory::DateTime.to_string(), "date-time");
    }

    #[test]
    fn defaults_are_text_and_user_defined() {
        assert_eq!(PgBuiltInType::default(), PgBuiltInType::Text);
        assert_eq!(PgTypeCategory::default(), PgTypeCategory::UserDefined);
    }

    #[test]
    fn creates_type_names_and_arrays() {
        let text = PgTypeName::built_in(PgBuiltInType::Text);
        let array = PgTypeName::array_of(&text);
        assert_eq!(text.as_str(), "text");
        assert_eq!(array.to_string(), "text[]");
        assert!(array.is_array_label());
        assert!(!text.is_array_label());
        assert_eq!(PgTypeName::new(""), Err(PgTypeError::Empty));
    }

    #[test]
    fn type_names_are_trimmed_and_validated() -> Result<(), PgTypeError> {
        assert_eq!(PgTypeName::new("  integer  ")?.as_str(), "integer");
        assert_eq!(
            "integer".parse::<PgTypeName>()?,
            PgTypeName::try_from("integer")?
        );
        assert_eq!(PgTypeName::new(" \t "), Err(PgTypeError::Empty));
        assert_eq!(
            PgTypeName::new("int\u{0}4"),
            Err(PgTypeError::ControlCharacter)
        );
        Ok(())
    }

    #[test]
    fn wraps_oids_without_binding_catalog_meaning() -> Result<(), PgTypeError> {
        let oid = PgTypeOid::new(23)?;
        assert_eq!(oid.get(), 23);
        assert_eq!(oid.to_string(), "23");
        assert_eq!(PgTypeOid::new(0), Err(PgTypeError::InvalidOid));
        Ok(())
    }

    #[test]
    fn errors_have_descriptive_messages() {
        assert_eq!(
            PgTypeError::Empty.to_string(),
            "PostgreSQL type label cannot be empty"
        );
        assert_eq!(
            PgTypeError::InvalidOid.to_string(),
            "PostgreSQL type OID must be greater than zero"
        );
        let boxed: Box<dyn std::error::Error> = Box::new(PgTypeError::UnknownBuiltInType);
        assert_eq!(boxed.to_string(), "unknown PostgreSQL built-in type label");
    }
}