serde_odbc/
col_binding.rs

1/*
2This file is part of serde-odbc.
3
4serde-odbc is free software: you can redistribute it and/or modify
5it under the terms of the GNU Lesser General Public License as published by
6the Free Software Foundation, either version 3 of the License, or
7(at your option) any later version.
8
9serde-odbc is distributed in the hope that it will be useful,
10but WITHOUT ANY WARRANTY; without even the implied warranty of
11MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12GNU Lesser General Public License for more details.
13
14You should have received a copy of the GNU Lesser General Public License
15along with serde-odbc.  If not, see <http://www.gnu.org/licenses/>.
16*/
17use std::mem::size_of;
18use std::ptr::null;
19
20use odbc_sys::{
21    SQLSetStmtAttr, SQLHSTMT, SQLLEN, SQLPOINTER, SQL_ATTR_ROWS_FETCHED_PTR,
22    SQL_ATTR_ROW_ARRAY_SIZE, SQL_ATTR_ROW_BIND_TYPE,
23};
24use serde::ser::Serialize;
25
26use super::col_binder::bind_cols;
27use super::error::{OdbcResult, Result};
28
29pub trait ColBinding {
30    fn new() -> Self;
31
32    type Cols;
33    fn cols(&self) -> &Self::Cols;
34
35    unsafe fn bind(&mut self, stmt: SQLHSTMT) -> Result<()>;
36
37    fn fetch(&mut self) -> bool;
38}
39
40pub struct Cols<C: Copy + Default + Serialize> {
41    data: C,
42    last_data: *const C,
43}
44
/// Column binding that binds nothing; its `bind` is a no-op.
pub struct NoCols {
    /// Unit placeholder so that `cols()` has a value to borrow.
    data: (),
}
48
49pub struct RowSet<C: Copy + Default + Serialize> {
50    data: Vec<C>,
51    last_data: *const C,
52    last_size: usize,
53    rows_fetched: SQLLEN,
54}
55
56impl<C: Copy + Default + Serialize> ColBinding for Cols<C> {
57    fn new() -> Self {
58        Cols {
59            data: Default::default(),
60            last_data: null(),
61        }
62    }
63
64    type Cols = C;
65    fn cols(&self) -> &Self::Cols {
66        &self.data
67    }
68
69    unsafe fn bind(&mut self, stmt: SQLHSTMT) -> Result<()> {
70        let data = &self.data as *const C;
71
72        if self.last_data != data {
73            bind_cols(stmt, &*data)?;
74            self.last_data = data;
75        }
76
77        Ok(())
78    }
79
80    fn fetch(&mut self) -> bool {
81        true
82    }
83}
84
85impl ColBinding for NoCols {
86    fn new() -> Self {
87        NoCols { data: () }
88    }
89
90    type Cols = ();
91    fn cols(&self) -> &Self::Cols {
92        &self.data
93    }
94
95    unsafe fn bind(&mut self, _stmt: SQLHSTMT) -> Result<()> {
96        Ok(())
97    }
98
99    fn fetch(&mut self) -> bool {
100        true
101    }
102}
103
104impl<C: Copy + Default + Serialize> ColBinding for RowSet<C> {
105    fn new() -> Self {
106        RowSet {
107            data: Vec::new(),
108            last_data: null(),
109            last_size: 0,
110            rows_fetched: 0,
111        }
112    }
113
114    type Cols = Vec<C>;
115    fn cols(&self) -> &Self::Cols {
116        &self.data
117    }
118
119    unsafe fn bind(&mut self, stmt: SQLHSTMT) -> Result<()> {
120        let capacity = self.data.capacity();
121        self.data.resize(capacity, Default::default());
122
123        let data = self.data.first().unwrap() as *const C;
124        let size = self.data.len();
125
126        if self.last_data != data {
127            bind_cols(stmt, &*data)?;
128            self.last_data = data;
129        }
130
131        if self.last_size != size {
132            Self::bind_row_set(stmt, size, &mut self.rows_fetched)?;
133            self.last_size = size;
134        }
135
136        Ok(())
137    }
138
139    fn fetch(&mut self) -> bool {
140        self.data.truncate(self.rows_fetched as usize);
141        self.rows_fetched != 0
142    }
143}
144
145impl<C: Copy + Default + Serialize> RowSet<C> {
146    pub fn fetch_size(&self) -> usize {
147        self.data.capacity()
148    }
149
150    pub fn set_fetch_size(&mut self, size: usize) {
151        let capacity = self.data.capacity();
152        if size > capacity {
153            self.data.reserve(size - capacity);
154        }
155    }
156
157    unsafe fn bind_row_set(stmt: SQLHSTMT, size: usize, rows_fetched: &mut SQLLEN) -> Result<()> {
158        SQLSetStmtAttr(
159            stmt,
160            SQL_ATTR_ROW_BIND_TYPE,
161            size_of::<C>() as SQLPOINTER,
162            0,
163        )
164        .check()?;
165
166        SQLSetStmtAttr(stmt, SQL_ATTR_ROW_ARRAY_SIZE, size as SQLPOINTER, 0).check()?;
167
168        SQLSetStmtAttr(
169            stmt,
170            SQL_ATTR_ROWS_FETCHED_PTR,
171            (rows_fetched as *mut SQLLEN) as SQLPOINTER,
172            0,
173        )
174        .check()
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{
        connection::{Connection, Environment},
        param_binding::{NoParams, Params},
        statement::Statement,
        tests::CONN_STR,
    };

    /// Fetches 128 rows through a `RowSet`, first in four batches of 32 and
    /// then in a single batch with a 256-row buffer.
    ///
    /// Connects via `CONN_STR`, so it presumably needs a reachable data source.
    #[test]
    fn bind_row_set() {
        let env = Environment::new().unwrap();
        let conn = Connection::new(&env, CONN_STR).unwrap();

        // Set up the table.
        {
            let mut create: Statement<NoParams, NoCols> =
                Statement::new(&conn, "CREATE TEMPORARY TABLE tbl (col INTEGER NOT NULL)").unwrap();
            create.exec().unwrap();
        }

        // Fill it with the values 0..128.
        {
            let mut insert: Statement<Params<i32>, NoCols> =
                Statement::new(&conn, "INSERT INTO tbl (col) VALUES (?)").unwrap();
            for value in 0..128 {
                *insert.params() = value;
                insert.exec().unwrap();
            }
        }

        // The buffer is smaller than the result: four full batches, then none.
        {
            let mut select: Statement<NoParams, RowSet<i32>> =
                Statement::new(&conn, "SELECT col FROM tbl ORDER BY col").unwrap();
            select.set_fetch_size(32);
            assert!(32 == select.fetch_size());
            select.exec().unwrap();
            for batch in 0..4 {
                assert!(select.fetch().unwrap());
                assert_eq!(32, select.cols().len());
                for (offset, value) in select.cols().iter().enumerate() {
                    assert_eq!(32 * batch + offset, *value as usize);
                }
            }
            assert!(!select.fetch().unwrap());
        }

        // The buffer is larger than the result: one partial batch, then none.
        {
            let mut select: Statement<NoParams, RowSet<i32>> =
                Statement::new(&conn, "SELECT col FROM tbl ORDER BY col").unwrap();
            select.set_fetch_size(256);
            assert!(256 == select.fetch_size());
            select.exec().unwrap();
            assert!(select.fetch().unwrap());
            assert_eq!(128, select.cols().len());
            for (index, value) in select.cols().iter().enumerate() {
                assert_eq!(index, *value as usize);
            }
            assert!(!select.fetch().unwrap());
        }
    }
}