1use crate::client::InnerClient;
2use crate::codec::FrontendMessage;
3use crate::connection::RequestMessages;
4use crate::error::SqlState;
5use crate::types::{Field, Kind, Oid, Type};
6use crate::{query, slice_iter};
7use crate::{Column, Error, Statement};
8use bytes::Bytes;
9use fallible_iterator::FallibleIterator;
10use futures_util::{pin_mut, TryStreamExt};
11use log::debug;
12use postgres_protocol::message::backend::Message;
13use postgres_protocol::message::frontend;
14use std::future::Future;
15use std::pin::Pin;
16use std::sync::atomic::{AtomicUsize, Ordering};
17use std::sync::Arc;
18
/// Looks up the catalog metadata for a single type, selected by OID (`$1`).
///
/// Column order matters: `get_type` reads the result positionally as
/// `typname`, `typtype`, `typelem`, `rngsubtype`, `typbasetype`, `nspname`,
/// `typrelid` (indices 0 through 6). `rngsubtype` comes from a LEFT OUTER JOIN
/// on `pg_range`, so it is NULL for every type that is not a range.
const TYPEINFO_QUERY: &str = "\
SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid
FROM pg_catalog.pg_type t
LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid
INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
WHERE t.oid = $1
";
26
/// Variant of `TYPEINFO_QUERY` that does not touch `pg_catalog.pg_range`.
///
/// Range types (and therefore `pg_range`) only exist from PostgreSQL 9.2 on.
/// `typeinfo_statement` switches to this query when preparing the primary one
/// fails with `UNDEFINED_TABLE`. The range subtype column is replaced by a
/// constant `NULL::OID` so the result keeps the same seven-column shape.
const TYPEINFO_FALLBACK_QUERY: &str = "\
SELECT t.typname, t.typtype, t.typelem, NULL::OID, t.typbasetype, n.nspname, t.typrelid
FROM pg_catalog.pg_type t
INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
WHERE t.oid = $1
";
34
/// Lists the labels of the enum type whose OID is `$1`, ordered by
/// `enumsortorder`.
const TYPEINFO_ENUM_QUERY: &str = "\
SELECT enumlabel
FROM pg_catalog.pg_enum
WHERE enumtypid = $1
ORDER BY enumsortorder
";
41
/// Variant of `TYPEINFO_ENUM_QUERY` for servers whose `pg_enum` has no
/// `enumsortorder` column (it was introduced in PostgreSQL 9.1).
///
/// `typeinfo_enum_statement` switches to this query when preparing the
/// primary one fails with `UNDEFINED_COLUMN`; labels are ordered by row OID
/// instead.
const TYPEINFO_ENUM_FALLBACK_QUERY: &str = "\
SELECT enumlabel
FROM pg_catalog.pg_enum
WHERE enumtypid = $1
ORDER BY oid
";
49
/// Lists the attributes (name and type OID) of the relation whose OID is
/// `$1`, in attribute-number order.
///
/// Dropped columns are excluded, and `attnum > 0` filters out the rows with
/// non-positive attribute numbers (PostgreSQL uses those for system columns).
const TYPEINFO_COMPOSITE_QUERY: &str = "\
SELECT attname, atttypid
FROM pg_catalog.pg_attribute
WHERE attrelid = $1
AND NOT attisdropped
AND attnum > 0
ORDER BY attnum
";
58
59static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
60
61pub async fn prepare(
62 client: &Arc<InnerClient>,
63 query: &str,
64 types: &[Type],
65) -> Result<Statement, Error> {
66 let name = format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst));
67 let buf = encode(client, &name, query, types)?;
68 let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
69
70 match responses.next().await? {
71 Message::ParseComplete => {}
72 _ => return Err(Error::unexpected_message()),
73 }
74
75 let parameter_description = match responses.next().await? {
76 Message::ParameterDescription(body) => body,
77 _ => return Err(Error::unexpected_message()),
78 };
79
80 let row_description = match responses.next().await? {
81 Message::RowDescription(body) => Some(body),
82 Message::NoData => None,
83 _ => return Err(Error::unexpected_message()),
84 };
85
86 let mut parameters = vec![];
87 let mut it = parameter_description.parameters();
88 while let Some(oid) = it.next().map_err(Error::parse)? {
89 let type_ = get_type(client, oid).await?;
90 parameters.push(type_);
91 }
92
93 let mut columns = vec![];
94 if let Some(row_description) = row_description {
95 let mut it = row_description.fields();
96 while let Some(field) = it.next().map_err(Error::parse)? {
97 let type_ = get_type(client, field.type_oid()).await?;
98 let column = Column::new(field.name().to_string(), type_);
99 columns.push(column);
100 }
101 }
102
103 Ok(Statement::new(client, name, parameters, columns))
104}
105
106fn prepare_rec<'a>(
107 client: &'a Arc<InnerClient>,
108 query: &'a str,
109 types: &'a [Type],
110) -> Pin<Box<dyn Future<Output = Result<Statement, Error>> + 'a + Send>> {
111 Box::pin(prepare(client, query, types))
112}
113
114fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result<Bytes, Error> {
115 if types.is_empty() {
116 debug!("preparing query {}: {}", name, query);
117 } else {
118 debug!("preparing query {} with types {:?}: {}", name, types, query);
119 }
120
121 client.with_buf(|buf| {
122 frontend::parse(name, query, types.iter().map(Type::oid), buf).map_err(Error::encode)?;
123 frontend::describe(b'S', name, buf).map_err(Error::encode)?;
124 frontend::sync(buf);
125 Ok(buf.split().freeze())
126 })
127}
128
129async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
130 if let Some(type_) = Type::from_oid(oid) {
131 return Ok(type_);
132 }
133
134 if let Some(type_) = client.type_(oid) {
135 return Ok(type_);
136 }
137
138 let stmt = typeinfo_statement(client).await?;
139
140 let rows = query::query(client, stmt, slice_iter(&[&oid])).await?;
141 pin_mut!(rows);
142
143 let row = match rows.try_next().await? {
144 Some(row) => row,
145 None => return Err(Error::unexpected_message()),
146 };
147
148 let name: String = row.try_get(0)?;
149 let type_: i8 = row.try_get(1)?;
150 let elem_oid: Oid = row.try_get(2)?;
151 let rngsubtype: Option<Oid> = row.try_get(3)?;
152 let basetype: Oid = row.try_get(4)?;
153 let schema: String = row.try_get(5)?;
154 let relid: Oid = row.try_get(6)?;
155
156 let kind = if type_ == b'e' as i8 {
157 let variants = get_enum_variants(client, oid).await?;
158 Kind::Enum(variants)
159 } else if type_ == b'p' as i8 {
160 Kind::Pseudo
161 } else if basetype != 0 {
162 let type_ = get_type_rec(client, basetype).await?;
163 Kind::Domain(type_)
164 } else if elem_oid != 0 {
165 let type_ = get_type_rec(client, elem_oid).await?;
166 Kind::Array(type_)
167 } else if relid != 0 {
168 let fields = get_composite_fields(client, relid).await?;
169 Kind::Composite(fields)
170 } else if let Some(rngsubtype) = rngsubtype {
171 let type_ = get_type_rec(client, rngsubtype).await?;
172 Kind::Range(type_)
173 } else {
174 Kind::Simple
175 };
176
177 let type_ = Type::new(name, oid, kind, schema);
178 client.set_type(oid, &type_);
179
180 Ok(type_)
181}
182
183fn get_type_rec<'a>(
184 client: &'a Arc<InnerClient>,
185 oid: Oid,
186) -> Pin<Box<dyn Future<Output = Result<Type, Error>> + Send + 'a>> {
187 Box::pin(get_type(client, oid))
188}
189
190async fn typeinfo_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
191 if let Some(stmt) = client.typeinfo() {
192 return Ok(stmt);
193 }
194
195 let stmt = match prepare_rec(client, TYPEINFO_QUERY, &[]).await {
196 Ok(stmt) => stmt,
197 Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_TABLE) => {
198 prepare_rec(client, TYPEINFO_FALLBACK_QUERY, &[]).await?
199 }
200 Err(e) => return Err(e),
201 };
202
203 client.set_typeinfo(&stmt);
204 Ok(stmt)
205}
206
207async fn get_enum_variants(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<String>, Error> {
208 let stmt = typeinfo_enum_statement(client).await?;
209
210 query::query(client, stmt, slice_iter(&[&oid]))
211 .await?
212 .and_then(|row| async move { row.try_get(0) })
213 .try_collect()
214 .await
215}
216
217async fn typeinfo_enum_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
218 if let Some(stmt) = client.typeinfo_enum() {
219 return Ok(stmt);
220 }
221
222 let stmt = match prepare_rec(client, TYPEINFO_ENUM_QUERY, &[]).await {
223 Ok(stmt) => stmt,
224 Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_COLUMN) => {
225 prepare_rec(client, TYPEINFO_ENUM_FALLBACK_QUERY, &[]).await?
226 }
227 Err(e) => return Err(e),
228 };
229
230 client.set_typeinfo_enum(&stmt);
231 Ok(stmt)
232}
233
234async fn get_composite_fields(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<Field>, Error> {
235 let stmt = typeinfo_composite_statement(client).await?;
236
237 let rows = query::query(client, stmt, slice_iter(&[&oid]))
238 .await?
239 .try_collect::<Vec<_>>()
240 .await?;
241
242 let mut fields = vec![];
243 for row in rows {
244 let name = row.try_get(0)?;
245 let oid = row.try_get(1)?;
246 let type_ = get_type_rec(client, oid).await?;
247 fields.push(Field::new(name, type_));
248 }
249
250 Ok(fields)
251}
252
253async fn typeinfo_composite_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
254 if let Some(stmt) = client.typeinfo_composite() {
255 return Ok(stmt);
256 }
257
258 let stmt = prepare_rec(client, TYPEINFO_COMPOSITE_QUERY, &[]).await?;
259
260 client.set_typeinfo_composite(&stmt);
261 Ok(stmt)
262}