1use crate::headers::{with_authentication, with_request_headers};
2use crate::{
3 Client, Compression,
4 error::{Error, Result},
5 request_body::{ChunkSender, RequestBody},
6 response::Response,
7 settings,
8};
9use bytes::{Bytes, BytesMut};
10use hyper::{self, Request};
11use std::ops::ControlFlow;
12use std::task::{Context, Poll, ready};
13use std::{cmp, future::Future, io, mem, panic, pin::Pin, time::Duration};
14use tokio::io::AsyncWrite;
15use tokio::{
16 task::JoinHandle,
17 time::{Instant, Sleep},
18};
19use url::Url;
20
21#[cfg(feature = "lz4")]
22pub use compression::CompressedData;
23
// Default buffer capacity in bytes (256 KiB), used by `InsertFormatted::buffered()`.
const BUFFER_SIZE: usize = 256 * 1024;
26
/// Streams pre-formatted data as the chunked body of a single HTTP `POST` request
/// that carries the given SQL statement.
///
/// The request is started lazily, the first time data is sent (or readiness is polled).
/// Dropping the value aborts the body sender if the request is still active.
#[must_use]
pub struct InsertFormatted {
    /// Lifecycle of the underlying request.
    state: InsertState,
    /// Compression mode, copied from the client when the insert is created.
    #[cfg(feature = "lz4")]
    compression: Compression,
    /// Optional limit on how long sending may wait for the body channel to become ready.
    send_timeout: Option<Timeout>,
    /// Optional limit on how long to wait for the request task to finish.
    end_timeout: Option<Timeout>,
    /// A single boxed timer that is re-armed and shared by both timeouts.
    sleep: Pin<Box<Sleep>>,
}
53
/// A timeout that is measured with an externally owned `Sleep` timer.
struct Timeout {
    /// How long to wait, counted from the moment the timer is armed.
    duration: Duration,
    /// `true` while the timer has been armed for this timeout and has not yet
    /// fired or been reset.
    is_set: bool,
}
58
/// Lifecycle of the HTTP request behind an insert.
enum InsertState {
    /// No request has been issued yet; the insert still owns its private copy of
    /// the client and the SQL statement.
    NotStarted {
        client: Box<Client>,
        sql: String,
    },
    /// The request is in flight: `sender` feeds body chunks and `handle` is the
    /// spawned task that awaits the response.
    Active {
        sender: ChunkSender,
        handle: JoinHandle<Result<()>>,
    },
    /// The sender has been dropped; only the result of the request task is awaited.
    Terminated {
        handle: JoinHandle<Result<()>>,
    },
    /// Nothing is left to wait for.
    Completed,
}
73
74impl InsertState {
75 #[inline(always)]
76 fn is_not_started(&self) -> bool {
77 matches!(self, Self::NotStarted { .. })
78 }
79
80 fn sender(&mut self) -> Option<&mut ChunkSender> {
81 match self {
82 InsertState::Active { sender, .. } => Some(sender),
83 _ => None,
84 }
85 }
86
87 fn handle(&mut self) -> Option<&mut JoinHandle<Result<()>>> {
88 match self {
89 InsertState::Active { handle, .. } | InsertState::Terminated { handle } => Some(handle),
90 _ => None,
91 }
92 }
93
94 fn client_with_sql(&self) -> Option<(&Client, &str)> {
95 match self {
96 InsertState::NotStarted { client, sql } => Some((client, sql)),
97 _ => None,
98 }
99 }
100
101 #[inline]
102 fn expect_client_mut(&mut self) -> &mut Client {
103 let Self::NotStarted { client, .. } = self else {
104 panic!("cannot modify client options while an insert is in-progress")
105 };
106
107 client
108 }
109
110 fn terminated(&mut self) {
111 match mem::replace(self, InsertState::Completed) {
112 InsertState::NotStarted { .. } | InsertState::Completed => (),
113 InsertState::Active { handle, .. } => {
114 *self = InsertState::Terminated { handle };
115 }
116 InsertState::Terminated { handle } => {
117 *self = InsertState::Terminated { handle };
118 }
119 }
120 }
121}
122
123impl InsertFormatted {
124 pub(crate) fn new(client: &Client, sql: String) -> Self {
125 Self {
126 state: InsertState::NotStarted {
127 client: Box::new(client.clone()),
128 sql,
129 },
130 #[cfg(feature = "lz4")]
131 compression: client.compression,
132 send_timeout: None,
133 end_timeout: None,
134 sleep: Box::pin(tokio::time::sleep(Duration::new(0, 0))),
135 }
136 }
137
138 pub fn with_timeouts(
153 mut self,
154 send_timeout: Option<Duration>,
155 end_timeout: Option<Duration>,
156 ) -> Self {
157 self.set_timeouts(send_timeout, end_timeout);
158 self
159 }
160
161 pub fn with_roles(mut self, roles: impl IntoIterator<Item = impl Into<String>>) -> Self {
173 self.state.expect_client_mut().set_roles(roles);
174 self
175 }
176
177 pub fn with_default_roles(mut self) -> Self {
187 self.state.expect_client_mut().clear_roles();
188 self
189 }
190
191 #[track_caller]
197 pub fn with_option(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
198 self.state.expect_client_mut().set_option(name, value);
199 self
200 }
201
202 pub(crate) fn set_timeouts(
203 &mut self,
204 send_timeout: Option<Duration>,
205 end_timeout: Option<Duration>,
206 ) {
207 self.send_timeout = Timeout::new_opt(send_timeout);
208 self.end_timeout = Timeout::new_opt(end_timeout);
209 }
210
211 pub fn buffered(self) -> BufInsertFormatted {
217 self.buffered_with_capacity(BUFFER_SIZE)
218 }
219
220 pub fn buffered_with_capacity(self, capacity: usize) -> BufInsertFormatted {
226 BufInsertFormatted::new(self, capacity)
227 }
228
229 pub async fn send(&mut self, data: Bytes) -> Result<()> {
241 #[cfg(feature = "lz4")]
242 let data = if self.compression.is_lz4() {
243 CompressedData::from_slice(&data).0
244 } else {
245 data
246 };
247
248 self.send_inner(data).await
249 }
250
251 async fn send_inner(&mut self, mut data: Bytes) -> Result<()> {
252 if self.state.is_not_started() {
253 self.init_request()?;
254 }
255
256 std::future::poll_fn(move |cx| {
257 loop {
258 match self.try_send(mem::take(&mut data)) {
260 ControlFlow::Break(Ok(())) => return Poll::Ready(Ok(())),
261 ControlFlow::Break(Err(_)) => {
262 return self.poll_wait_handle(cx);
264 }
265 ControlFlow::Continue(unsent) => {
266 data = unsent;
267 ready!(self.poll_ready(cx))?;
269 }
270 }
271 }
272 })
273 .await
274 }
275
276 #[inline]
277 pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
278 if self.state.is_not_started() {
279 self.init_request()?;
280 }
281
282 let Some(sender) = self.state.sender() else {
283 return Poll::Ready(Err(Error::Network("channel closed".into())));
284 };
285
286 match sender.poll_ready(cx) {
287 Poll::Ready(true) => {
288 Timeout::reset_opt(self.send_timeout.as_mut());
289 Poll::Ready(Ok(()))
290 }
291 Poll::Ready(false) => Poll::Ready(Err(Error::Network("channel closed".into()))),
292 Poll::Pending => {
293 ready!(Timeout::poll_opt(
294 self.send_timeout.as_mut(),
295 self.sleep.as_mut(),
296 cx
297 ));
298 self.abort();
299 Poll::Ready(Err(Error::TimedOut))
300 }
301 }
302 }
303
304 #[inline(always)]
305 pub(crate) fn try_send(&mut self, bytes: Bytes) -> ControlFlow<Result<()>, Bytes> {
306 let Some(sender) = self.state.sender() else {
307 return ControlFlow::Break(Err(Error::Network("channel closed".into())));
308 };
309
310 sender
311 .try_send(bytes)
312 .map_break(|res| res.map_err(|e| Error::Network(e.into())))
313 }
314
315 pub async fn end(mut self) -> Result<()> {
322 std::future::poll_fn(|cx| self.poll_end(cx)).await
323 }
324
325 pub(crate) fn poll_end(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
326 self.state.terminated();
327 self.poll_wait_handle(cx)
328 }
329
330 fn poll_wait_handle(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
331 let Some(handle) = self.state.handle() else {
332 return Poll::Ready(Ok(()));
333 };
334
335 let Poll::Ready(res) = Pin::new(&mut *handle).poll(cx) else {
336 ready!(Timeout::poll_opt(
337 self.end_timeout.as_mut(),
338 self.sleep.as_mut(),
339 cx
340 ));
341
342 handle.abort();
344 return Poll::Ready(Err(Error::TimedOut));
345 };
346
347 let res = match res {
348 Ok(res) => res,
349 Err(err) if err.is_panic() => panic::resume_unwind(err.into_panic()),
350 Err(err) => Err(Error::Custom(format!("unexpected error: {err}"))),
351 };
352
353 self.state = InsertState::Completed;
354
355 Poll::Ready(res)
356 }
357
358 #[cold]
359 #[track_caller]
360 #[inline(never)]
361 fn init_request(&mut self) -> Result<()> {
362 debug_assert!(matches!(self.state, InsertState::NotStarted { .. }));
363 let (client, sql) = self.state.client_with_sql().unwrap(); let mut url = Url::parse(&client.url).map_err(|err| Error::InvalidParams(err.into()))?;
366 let mut pairs = url.query_pairs_mut();
367 pairs.clear();
368
369 if let Some(database) = &client.database {
370 pairs.append_pair(settings::DATABASE, database);
371 }
372
373 pairs.append_pair(settings::QUERY, sql);
374
375 if client.compression.is_lz4() {
376 pairs.append_pair(settings::DECOMPRESS, "1");
377 }
378
379 for (name, value) in &client.options {
380 pairs.append_pair(name, value);
381 }
382
383 drop(pairs);
384
385 let mut builder = Request::post(url.as_str());
386 builder = with_request_headers(builder, &client.headers, &client.products_info);
387 builder = with_authentication(builder, &client.authentication);
388
389 let (sender, body) = RequestBody::chunked();
390
391 let request = builder
392 .body(body)
393 .map_err(|err| Error::InvalidParams(Box::new(err)))?;
394
395 let future = client.http.request(request);
396 let handle =
398 tokio::spawn(async move { Response::new(future, Compression::None).finish().await });
399
400 self.state = InsertState::Active { handle, sender };
401 Ok(())
402 }
403
404 pub(crate) fn abort(&mut self) {
405 if let Some(sender) = self.state.sender() {
406 sender.abort();
407 }
408 }
409}
410
411impl Drop for InsertFormatted {
412 fn drop(&mut self) {
413 self.abort();
414 }
415}
416
/// An [`InsertFormatted`] with an in-memory buffer in front of it.
///
/// Written bytes are collected in the buffer and sent to the inner insert as
/// one chunk per flush.
pub struct BufInsertFormatted {
    /// The insert that receives the flushed chunks.
    insert: InsertFormatted,
    /// Bytes written but not yet flushed.
    buffer: BytesMut,
    /// Requested buffer capacity; a write triggers a flush once the buffer
    /// holds at least this many bytes.
    capacity: usize,
}
424
425impl BufInsertFormatted {
426 fn new(insert: InsertFormatted, capacity: usize) -> Self {
427 Self {
428 insert,
429 buffer: BytesMut::with_capacity(capacity),
430 capacity,
431 }
432 }
433
434 #[inline(always)]
436 pub fn buf_len(&self) -> usize {
437 self.buffer.len()
438 }
439
440 #[inline(always)]
448 pub fn capacity(&self) -> usize {
449 self.buffer.capacity()
450 }
451
452 #[inline(always)]
453 pub(crate) fn buffer_mut(&mut self) -> &mut BytesMut {
454 &mut self.buffer
455 }
456
457 pub(crate) fn expect_client_mut(&mut self) -> &mut Client {
458 self.insert.state.expect_client_mut()
459 }
460
461 pub(crate) fn set_timeouts(
462 &mut self,
463 send_timeout: Option<Duration>,
464 end_timeout: Option<Duration>,
465 ) {
466 self.insert.set_timeouts(send_timeout, end_timeout);
467 }
468
469 #[inline(always)]
473 pub fn write_buffered(&mut self, data: &[u8]) {
474 self.buffer.extend_from_slice(data);
475 }
476
477 #[inline(always)]
488 pub async fn write(&mut self, data: &[u8]) -> Result<usize> {
489 std::future::poll_fn(|cx| self.poll_write_inner(data, cx)).await
490 }
491
492 #[inline(always)]
494 fn poll_write_inner(&mut self, data: &[u8], cx: &mut Context<'_>) -> Poll<Result<usize>> {
495 self.init_request_if_required()?;
498
499 if self.buffer.len() >= self.capacity {
502 ready!(self.poll_flush_inner(cx))?;
503 debug_assert!(self.buffer.is_empty());
504 }
505
506 if self.capacity == 0 {
509 self.buffer.extend_from_slice(data);
510 return Poll::Ready(Ok(data.len()));
511 }
512
513 let remaining_capacity = self.capacity - self.buffer.len();
515
516 let write_len = cmp::min(remaining_capacity, data.len());
517
518 self.buffer.extend_from_slice(&data[..write_len]);
519 Poll::Ready(Ok(write_len))
520 }
521
522 #[inline(always)]
526 pub async fn flush(&mut self) -> Result<()> {
527 std::future::poll_fn(|cx| self.poll_flush_inner(cx)).await
528 }
529
530 #[inline(always)]
531 fn poll_flush_inner(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
532 if self.buffer.is_empty() {
533 return Poll::Ready(Ok(()));
534 }
535
536 ready!(self.insert.poll_ready(cx))?;
537
538 let data = self.buffer.split().freeze();
539
540 #[cfg(feature = "lz4")]
541 let data = if self.insert.compression.is_lz4() {
542 CompressedData::from(data).0
543 } else {
544 data
545 };
546
547 let ControlFlow::Break(res) = self.insert.try_send(data) else {
548 unreachable!("BUG: we just checked that `ChunkSender` was ready")
549 };
550
551 Poll::Ready(res)
552 }
553
554 #[inline(always)]
558 pub async fn end(&mut self) -> Result<()> {
559 std::future::poll_fn(|cx| self.poll_end(cx)).await
560 }
561
562 #[inline(always)]
563 fn poll_end(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
564 if !self.buffer.is_empty() {
565 ready!(self.poll_flush_inner(cx))?;
566 debug_assert!(self.buffer.is_empty());
567 }
568
569 self.insert.poll_end(cx)
570 }
571
572 #[inline]
575 pub(crate) fn init_request_if_required(&mut self) -> Result<bool> {
576 if self.insert.state.is_not_started() {
577 self.insert.init_request().map(|_| true)
578 } else {
579 Ok(false)
580 }
581 }
582
583 pub(crate) fn abort(&mut self) {
584 self.insert.abort();
585 }
586}
587
588impl AsyncWrite for BufInsertFormatted {
589 #[inline(always)]
590 fn poll_write(
591 mut self: Pin<&mut Self>,
592 cx: &mut Context<'_>,
593 buf: &[u8],
594 ) -> Poll<std::result::Result<usize, io::Error>> {
595 self.poll_write_inner(buf, cx).map_err(Into::into)
596 }
597
598 #[inline(always)]
599 fn poll_flush(
600 mut self: Pin<&mut Self>,
601 cx: &mut Context<'_>,
602 ) -> Poll<std::result::Result<(), io::Error>> {
603 self.poll_flush_inner(cx).map_err(Into::into)
604 }
605
606 #[inline(always)]
607 fn poll_shutdown(
608 mut self: Pin<&mut Self>,
609 cx: &mut Context<'_>,
610 ) -> Poll<std::result::Result<(), io::Error>> {
611 self.poll_end(cx).map_err(Into::into)
612 }
613}
614
615impl Timeout {
616 fn new_opt(duration: Option<Duration>) -> Option<Self> {
617 duration.map(|duration| Self {
618 duration,
619 is_set: false,
620 })
621 }
622
623 #[inline(always)]
625 fn poll_opt(this: Option<&mut Self>, sleep: Pin<&mut Sleep>, cx: &mut Context<'_>) -> Poll<()> {
626 if let Some(this) = this {
627 this.poll(sleep, cx)
628 } else {
629 Poll::Pending
630 }
631 }
632
633 #[inline]
634 fn poll(&mut self, mut sleep: Pin<&mut Sleep>, cx: &mut Context<'_>) -> Poll<()> {
635 if !self.is_set
636 && let Some(deadline) = Instant::now().checked_add(self.duration)
637 {
638 sleep.as_mut().reset(deadline);
639 self.is_set = true;
640 }
641
642 ready!(sleep.as_mut().poll(cx));
643 self.is_set = false;
644
645 Poll::Ready(())
646 }
647
648 #[inline(always)]
649 fn reset_opt(this: Option<&mut Self>) {
650 if let Some(this) = this {
651 this.is_set = false;
652 }
653 }
654}
655
#[cfg(feature = "lz4")]
mod compression {
    use crate::error::{Error, Result};
    use crate::insert_formatted::InsertFormatted;
    use bytes::Bytes;

    /// Data that has already been LZ4-compressed and can be passed to
    /// [`InsertFormatted::send_compressed`].
    #[cfg_attr(docsrs, doc(cfg(feature = "lz4")))]
    pub struct CompressedData(pub(crate) Bytes);

    impl CompressedData {
        /// Compresses `slice` with LZ4.
        ///
        /// # Panics
        /// If the compressor reports an error, which is treated as a bug.
        #[inline(always)]
        pub fn from_slice(slice: &[u8]) -> Self {
            let compressed = crate::compression::lz4::compress(slice)
                .expect("BUG: `lz4::compress()` should not error");

            Self(compressed)
        }
    }

    /// Compresses anything that can be viewed as a byte slice.
    impl<T> From<T> for CompressedData
    where
        T: AsRef<[u8]>,
    {
        #[inline(always)]
        fn from(value: T) -> Self {
            let bytes: &[u8] = value.as_ref();
            CompressedData::from_slice(bytes)
        }
    }

    impl InsertFormatted {
        /// Sends data that the caller has already compressed.
        ///
        /// # Errors
        /// Returns [`Error::Compression`] if LZ4 compression is not selected for
        /// this insert; otherwise the same errors as sending uncompressed data.
        pub async fn send_compressed(&mut self, data: CompressedData) -> Result<()> {
            if self.compression.is_lz4() {
                let CompressedData(payload) = data;
                return self.send_inner(payload).await;
            }

            Err(Error::Compression(
                "attempting to send compressed data, but compression is not enabled".into(),
            ))
        }
    }
}