serde_odbc/
col_binding.rs1use std::mem::size_of;
18use std::ptr::null;
19
20use odbc_sys::{
21 SQLSetStmtAttr, SQLHSTMT, SQLLEN, SQLPOINTER, SQL_ATTR_ROWS_FETCHED_PTR,
22 SQL_ATTR_ROW_ARRAY_SIZE, SQL_ATTR_ROW_BIND_TYPE,
23};
24use serde::ser::Serialize;
25
26use super::col_binder::bind_cols;
27use super::error::{OdbcResult, Result};
28
29pub trait ColBinding {
30 fn new() -> Self;
31
32 type Cols;
33 fn cols(&self) -> &Self::Cols;
34
35 unsafe fn bind(&mut self, stmt: SQLHSTMT) -> Result<()>;
36
37 fn fetch(&mut self) -> bool;
38}
39
/// Column binding that fetches a single row into one value of type `C`.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the type itself.
pub struct Cols<C> {
    /// Buffer the driver writes the current row into.
    data: C,
    /// Address of `data` at the last successful bind. When `self` moves, the address of
    /// `data` changes and the columns are rebound. Null until the first bind.
    last_data: *const C,
}
44
/// Column binding for statements without a result set.
///
/// Nothing is ever bound; the unit value stands in as the (empty) column buffer.
pub struct NoCols {
    // Unit placeholder so `cols()` has something to borrow.
    data: (),
}
48
49pub struct RowSet<C: Copy + Default + Serialize> {
50 data: Vec<C>,
51 last_data: *const C,
52 last_size: usize,
53 rows_fetched: SQLLEN,
54}
55
56impl<C: Copy + Default + Serialize> ColBinding for Cols<C> {
57 fn new() -> Self {
58 Cols {
59 data: Default::default(),
60 last_data: null(),
61 }
62 }
63
64 type Cols = C;
65 fn cols(&self) -> &Self::Cols {
66 &self.data
67 }
68
69 unsafe fn bind(&mut self, stmt: SQLHSTMT) -> Result<()> {
70 let data = &self.data as *const C;
71
72 if self.last_data != data {
73 bind_cols(stmt, &*data)?;
74 self.last_data = data;
75 }
76
77 Ok(())
78 }
79
80 fn fetch(&mut self) -> bool {
81 true
82 }
83}
84
85impl ColBinding for NoCols {
86 fn new() -> Self {
87 NoCols { data: () }
88 }
89
90 type Cols = ();
91 fn cols(&self) -> &Self::Cols {
92 &self.data
93 }
94
95 unsafe fn bind(&mut self, _stmt: SQLHSTMT) -> Result<()> {
96 Ok(())
97 }
98
99 fn fetch(&mut self) -> bool {
100 true
101 }
102}
103
104impl<C: Copy + Default + Serialize> ColBinding for RowSet<C> {
105 fn new() -> Self {
106 RowSet {
107 data: Vec::new(),
108 last_data: null(),
109 last_size: 0,
110 rows_fetched: 0,
111 }
112 }
113
114 type Cols = Vec<C>;
115 fn cols(&self) -> &Self::Cols {
116 &self.data
117 }
118
119 unsafe fn bind(&mut self, stmt: SQLHSTMT) -> Result<()> {
120 let capacity = self.data.capacity();
121 self.data.resize(capacity, Default::default());
122
123 let data = self.data.first().unwrap() as *const C;
124 let size = self.data.len();
125
126 if self.last_data != data {
127 bind_cols(stmt, &*data)?;
128 self.last_data = data;
129 }
130
131 if self.last_size != size {
132 Self::bind_row_set(stmt, size, &mut self.rows_fetched)?;
133 self.last_size = size;
134 }
135
136 Ok(())
137 }
138
139 fn fetch(&mut self) -> bool {
140 self.data.truncate(self.rows_fetched as usize);
141 self.rows_fetched != 0
142 }
143}
144
145impl<C: Copy + Default + Serialize> RowSet<C> {
146 pub fn fetch_size(&self) -> usize {
147 self.data.capacity()
148 }
149
150 pub fn set_fetch_size(&mut self, size: usize) {
151 let capacity = self.data.capacity();
152 if size > capacity {
153 self.data.reserve(size - capacity);
154 }
155 }
156
157 unsafe fn bind_row_set(stmt: SQLHSTMT, size: usize, rows_fetched: &mut SQLLEN) -> Result<()> {
158 SQLSetStmtAttr(
159 stmt,
160 SQL_ATTR_ROW_BIND_TYPE,
161 size_of::<C>() as SQLPOINTER,
162 0,
163 )
164 .check()?;
165
166 SQLSetStmtAttr(stmt, SQL_ATTR_ROW_ARRAY_SIZE, size as SQLPOINTER, 0).check()?;
167
168 SQLSetStmtAttr(
169 stmt,
170 SQL_ATTR_ROWS_FETCHED_PTR,
171 (rows_fetched as *mut SQLLEN) as SQLPOINTER,
172 0,
173 )
174 .check()
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{
        connection::{Connection, Environment},
        param_binding::{NoParams, Params},
        statement::Statement,
        tests::CONN_STR,
    };

    /// Inserts 0..128 into a temporary table, then reads it back in row sets.
    /// It uses one fetch size that divides the row count evenly and one that exceeds it.
    #[test]
    fn bind_row_set() {
        let env = Environment::new().unwrap();
        let conn = Connection::new(&env, CONN_STR).unwrap();

        {
            let mut stmt: Statement<NoParams, NoCols> =
                Statement::new(&conn, "CREATE TEMPORARY TABLE tbl (col INTEGER NOT NULL)").unwrap();
            stmt.exec().unwrap();
        }

        {
            let mut stmt: Statement<Params<i32>, NoCols> =
                Statement::new(&conn, "INSERT INTO tbl (col) VALUES (?)").unwrap();
            for i in 0..128 {
                *stmt.params() = i;
                stmt.exec().unwrap();
            }
        }

        // 128 rows in four full blocks of 32, then no more data.
        {
            let mut stmt: Statement<NoParams, RowSet<i32>> =
                Statement::new(&conn, "SELECT col FROM tbl ORDER BY col").unwrap();
            stmt.set_fetch_size(32);
            assert_eq!(32, stmt.fetch_size());
            stmt.exec().unwrap();
            for i in 0..4 {
                assert!(stmt.fetch().unwrap());
                assert_eq!(32, stmt.cols().len());
                for (j, col) in stmt.cols().iter().enumerate() {
                    assert_eq!(32 * i + j, *col as usize);
                }
            }
            assert!(!stmt.fetch().unwrap());
        }

        // A fetch size larger than the table yields one partial block.
        {
            let mut stmt: Statement<NoParams, RowSet<i32>> =
                Statement::new(&conn, "SELECT col FROM tbl ORDER BY col").unwrap();
            stmt.set_fetch_size(256);
            assert_eq!(256, stmt.fetch_size());
            stmt.exec().unwrap();
            assert!(stmt.fetch().unwrap());
            assert_eq!(128, stmt.cols().len());
            for (i, col) in stmt.cols().iter().enumerate() {
                assert_eq!(i, *col as usize);
            }
            assert!(!stmt.fetch().unwrap());
        }
    }
}