1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
mod extend;
pub mod socket;
mod tls;

#[cfg(feature = "etl")]
use std::net::IpAddr;
use std::{borrow::Cow, fmt::Debug, net::SocketAddr};

use extend::WebSocketExt;
use futures_util::{io::BufReader, Future};
use lru::LruCache;
use serde::de::{DeserializeOwned, IgnoredAny};
use socket::ExaSocket;
use sqlx_core::Error as SqlxError;
pub use tls::WithMaybeTlsExaSocket;

use self::extend::PlainWebSocket;
use super::stream::QueryResultStream;
#[cfg(feature = "etl")]
use crate::responses::Hosts;
use crate::{
    command::{Command, ExaCommand},
    error::{ExaProtocolError, ExaResultExt},
    options::ExaConnectOptionsRef,
    responses::{
        DataChunk, DescribeStatement, ExaAttributes, PreparedStatement, QueryResult, Results,
        SessionInfo,
    },
};

/// A websocket connected to the Exasol, providing lower level
/// interaction with the database.
#[derive(Debug)]
pub struct ExaWebSocket {
    pub ws: WebSocketExt,
    pub attributes: ExaAttributes,
    pub pending_rollback: bool,
}

impl ExaWebSocket {
    const WS_SCHEME: &'static str = "ws";
    const WSS_SCHEME: &'static str = "wss";

    pub(crate) async fn new(
        host: &str,
        port: u16,
        socket: ExaSocket,
        options: ExaConnectOptionsRef<'_>,
        with_tls: bool,
    ) -> Result<(Self, SessionInfo), SqlxError> {
        let scheme = match with_tls {
            true => Self::WSS_SCHEME,
            false => Self::WS_SCHEME,
        };

        let host = format!("{scheme}://{host}:{port}");

        let (ws, _) = async_tungstenite::client_async(host, BufReader::new(socket))
            .await
            .to_sqlx_err()?;

        let attributes = ExaAttributes {
            compression_enabled: options.compression,
            fetch_size: options.fetch_size,
            encryption_enabled: with_tls,
            statement_cache_capacity: options.statement_cache_capacity,
            ..Default::default()
        };

        // Login is always uncompressed
        let mut plain_ws = PlainWebSocket(ws);
        let session_info = plain_ws.login(options).await?;
        let ws = WebSocketExt::new(plain_ws.0, attributes.compression_enabled);

        let mut this = Self {
            ws,
            attributes,
            pending_rollback: false,
        };

        this.get_attributes().await?;

        Ok((this, session_info))
    }

    /// Executes a [`Command`] and returns a [`QueryResultStream`].
    pub async fn get_result_stream<'a, C, F>(
        &'a mut self,
        cmd: Command,
        rs_handle: &mut Option<u16>,
        future_maker: C,
    ) -> Result<QueryResultStream<'_, C, F>, SqlxError>
    where
        C: Fn(&'a mut ExaWebSocket, u16, usize) -> Result<F, SqlxError>,
        F: Future<Output = Result<(DataChunk, &'a mut ExaWebSocket), SqlxError>> + 'a,
    {
        // Close the previously opened result set, if any.
        if let Some(handle) = rs_handle.take() {
            self.close_result_set(handle).await?;
        }

        let query_result = self.get_query_result(cmd).await?;
        *rs_handle = query_result.handle();

        QueryResultStream::new(self, query_result, future_maker)
    }

    pub async fn get_query_result(&mut self, cmd: Command) -> Result<QueryResult, SqlxError> {
        self.send_and_recv::<Results>(cmd).await.map(From::from)
    }

    pub async fn close_result_set(&mut self, handle: u16) -> Result<(), SqlxError> {
        let cmd = ExaCommand::new_close_result(handle).try_into()?;
        self.send_cmd_ignore_response(cmd).await?;
        Ok(())
    }

    pub async fn create_prepared(&mut self, cmd: Command) -> Result<PreparedStatement, SqlxError> {
        self.send_and_recv(cmd).await
    }

    pub async fn describe(&mut self, cmd: Command) -> Result<DescribeStatement, SqlxError> {
        self.send_and_recv(cmd).await
    }

    pub async fn close_prepared(&mut self, handle: u16) -> Result<(), SqlxError> {
        let cmd = ExaCommand::new_close_prepared(handle).try_into()?;
        self.send_cmd_ignore_response(cmd).await
    }

    pub async fn fetch_chunk(&mut self, cmd: Command) -> Result<DataChunk, SqlxError> {
        self.send_and_recv(cmd).await
    }

    pub async fn set_attributes(&mut self) -> Result<(), SqlxError> {
        let cmd = ExaCommand::new_set_attributes(&self.attributes).try_into()?;
        self.send_cmd_ignore_response(cmd).await
    }

    #[cfg(feature = "etl")]
    pub async fn get_hosts(&mut self) -> Result<Vec<IpAddr>, SqlxError> {
        let host_ip = self.socket_addr().ip();
        let cmd = ExaCommand::new_get_hosts(host_ip).try_into()?;
        self.send_and_recv::<Hosts>(cmd).await.map(From::from)
    }

    pub async fn get_attributes(&mut self) -> Result<(), SqlxError> {
        let cmd = ExaCommand::GetAttributes.try_into()?;
        self.send_cmd_ignore_response(cmd).await
    }

    pub fn begin(&mut self) -> Result<(), SqlxError> {
        // Exasol does not have nested transactions.
        if self.attributes.open_transaction {
            return Err(ExaProtocolError::TransactionAlreadyOpen)?;
        }

        // The next time a query is executed, the transaction will be started.
        // We could eagerly start it as well, but that implies one more
        // round-trip to the server and back with no benefit.
        self.attributes.autocommit = false;
        self.attributes.open_transaction = true;
        Ok(())
    }

    pub async fn commit(&mut self) -> Result<(), SqlxError> {
        // It's fine to set this before executing the command
        // since we want to commit anyway.
        self.attributes.autocommit = true;

        // Just changing `autocommit` attribute implies a COMMIT,
        // but we would still have to send a command to the server
        // to update it, so we might as well be explicit.
        let cmd = ExaCommand::new_execute("COMMIT;", &self.attributes).try_into()?;
        self.send_cmd_ignore_response(cmd).await?;

        self.attributes.open_transaction = false;
        Ok(())
    }

    /// Sends a rollback to the database and awaits the response.
    pub async fn rollback(&mut self) -> Result<(), SqlxError> {
        let cmd = ExaCommand::new_execute("ROLLBACK;", &self.attributes).try_into()?;
        self.raw_send(cmd).await?;
        self.recv::<Option<IgnoredAny>>().await?;

        // We explicitly set the attribute after
        // we know the rollback was successful.
        self.attributes.autocommit = true;
        self.attributes.open_transaction = false;
        self.pending_rollback = false;
        Ok(())
    }

    pub async fn ping(&mut self) -> Result<(), SqlxError> {
        self.ws.ping().await
    }

    pub async fn disconnect(&mut self) -> Result<(), SqlxError> {
        let cmd = ExaCommand::Disconnect.try_into()?;
        self.send_cmd_ignore_response(cmd).await
    }

    pub async fn close(&mut self) -> Result<(), SqlxError> {
        self.ws.close().await
    }

    /// Returns the metadata related to a [`PreparedStatement`].
    /// If the metadata is not yet cached, a request will be first
    /// made to the database to prepare the statement and cache its metadata.
    pub async fn get_or_prepare<'a>(
        &mut self,
        cache: &'a mut LruCache<String, PreparedStatement>,
        sql: &str,
        persist: bool,
    ) -> Result<Cow<'a, PreparedStatement>, SqlxError> {
        // The double look-up is required to avoid a borrow checker limitation.
        //
        // See: https://github.com/rust-lang/rust/issues/54663
        if cache.contains(sql) {
            return Ok(Cow::Borrowed(cache.get(sql).unwrap()));
        }

        let cmd = ExaCommand::new_create_prepared(sql).try_into()?;
        let prepared = self.create_prepared(cmd).await?;

        if persist {
            if let Some((_, old)) = cache.push(sql.to_owned(), prepared) {
                self.close_prepared(old.statement_handle).await?;
            }

            return Ok(Cow::Borrowed(cache.get(sql).unwrap()));
        }

        Ok(Cow::Owned(prepared))
    }

    /// Tries to take advance of the batch SQL execution command provided by Exasol.
    /// The command however implies splitting the SQL into an array of statements.
    ///
    /// Since this is finicky, we'll only use it for migration purposes, where
    /// batch SQL is actually expected/encouraged.
    ///
    /// If batch execution fails, we will try to separate statements and execute them
    /// one by one.
    #[cfg(feature = "migrate")]
    pub async fn execute_batch(&mut self, sql: &str) -> Result<(), SqlxError> {
        // Trim the query end so we don't have an empty statement at the end of the array.
        let sql = sql.trim_end();
        let sql = sql.strip_suffix(';').unwrap_or(sql);

        // Do a naive yet valiant attempt at splitting the query.
        let sql_batch = sql.split(';').collect();
        let cmd = ExaCommand::new_execute_batch(sql_batch, &self.attributes).try_into()?;

        // Run batch SQL command
        match self.send_cmd_ignore_response(cmd).await {
            Ok(()) => return Ok(()),
            Err(e) => tracing::warn!(
                "failed to execute batch SQL: {e}; will attempt sequential execution"
            ),
        };

        // If we got here then batch execution failed.
        // This will pretty much be happening if  there are ';' literals in the query.
        //
        // So.. we'll do an incremental identification of queries, still splitting by ';'
        // but this time, if a statement execution fails, we concatenate it with the next string
        // up until the next ';'.
        let mut result = Ok(());
        let mut position = 0;
        let mut sql_start = 0;

        let handle_err_fn = |err: SqlxError, result: &mut Result<(), SqlxError>| {
            // Exasol doesn't seem to have a dedicated code for malformed queries.
            // There's `42000` but it does not look to only be related to syntax errors.
            //
            // So we at least check if this is a database error and continue if so.
            // Otherwise something else is wrong and we can fail early.
            let db_err = match &err {
                SqlxError::Database(_) => err,
                _ => return Err(err),
            };

            tracing::warn!("error running statement: {db_err}; perhaps it's incomplete?");

            // We wanna store the first error occurred.
            // If at some point, after further concatenations, the query
            // succeeds, then we'll set the result to `Ok`.
            if result.is_ok() {
                *result = Err(db_err);
            }

            Ok(())
        };

        while let Some(sql_end) = sql[position..].find(';') {
            // Found a separator, split the SQL.
            let sql = sql[sql_start..position + sql_end].trim();

            let cmd = ExaCommand::new_execute(sql, &self.attributes).try_into()?;
            // Next lookup will be after the just encountered separator.
            position += sql_end + 1;

            if let Err(err) = self.send_cmd_ignore_response(cmd).await {
                handle_err_fn(err, &mut result)?;
            } else {
                // Yay!!!
                sql_start = position;
                result = Ok(());
            }
        }

        // We need to run the remaining statement, if any.
        let sql = sql[sql_start..].trim();

        if !sql.is_empty() {
            let cmd = ExaCommand::new_execute(sql, &self.attributes).try_into()?;
            if let Err(err) = self.send_cmd_ignore_response(cmd).await {
                handle_err_fn(err, &mut result)?;
            } else {
                result = Ok(());
            }
        }

        result
    }

    pub fn socket_addr(&self) -> SocketAddr {
        self.ws.socket_addr()
    }

    /// Utility method for when the `response_data` field of the [`Response`] is not of any
    /// interest. Note that attributes will still get updated.
    async fn send_cmd_ignore_response(&mut self, cmd: Command) -> Result<(), SqlxError> {
        self.send_and_recv::<Option<IgnoredAny>>(cmd).await?;
        Ok(())
    }

    /// Sends a [`Command`] to the database and returns the `response_data` field of the
    /// [`Response`]. We will also await on the response from a previously started command, if
    /// needed.
    async fn send_and_recv<T>(&mut self, cmd: Command) -> Result<T, SqlxError>
    where
        T: DeserializeOwned + Debug,
    {
        self.send(cmd).await?;
        self.recv().await
    }

    /// Sends a [`Command`] to the database, also ensuring a rollback is issued if pending.
    pub(crate) async fn send(&mut self, cmd: Command) -> Result<(), SqlxError> {
        // If a rollback is pending, which really only happens when a transaction is dropped
        // then we need to send that first before the actual command.
        if self.pending_rollback {
            self.rollback().await?;
            self.pending_rollback = false;
        }

        self.raw_send(cmd).await
    }

    /// Sends a [`Command`] to the database.
    async fn raw_send(&mut self, cmd: Command) -> Result<(), SqlxError> {
        let cmd = cmd.into_inner();
        tracing::debug!("sending command to database: {cmd}");

        self.ws.send(cmd).await
    }

    /// Receives a database response containing attributes,
    /// processes the attributes and returns the inner data.
    pub(crate) async fn recv<T>(&mut self) -> Result<T, SqlxError>
    where
        T: DeserializeOwned + Debug,
    {
        let (response_data, attributes) = Result::from(self.ws.recv().await?)?;

        if let Some(attributes) = attributes {
            tracing::debug!("updating connection attributes using:\n{attributes:#?}");
            self.attributes.update(attributes);
        }

        tracing::trace!("database response:\n{response_data:#?}");

        Ok(response_data)
    }
}