Skip to main content

clickhouse/
insert_formatted.rs

1use crate::headers::{with_authentication, with_request_headers};
2use crate::{
3    Client, Compression,
4    error::{Error, Result},
5    request_body::{ChunkSender, RequestBody},
6    response::Response,
7    settings,
8};
9use bytes::{Bytes, BytesMut};
10use hyper::{self, Request};
11use std::ops::ControlFlow;
12use std::task::{Context, Poll, ready};
13use std::{cmp, future::Future, io, mem, panic, pin::Pin, time::Duration};
14use tokio::io::AsyncWrite;
15use tokio::{
16    task::JoinHandle,
17    time::{Instant, Sleep},
18};
19use url::Url;
20
21#[cfg(any(feature = "lz4", feature = "zstd"))]
22pub use compression::CompressedData;
23
/// Desired maximum size of a single frame handed to the connection: 256 KiB.
///
/// Also serves as the default capacity of the buffer created by `InsertFormatted::buffered()`.
const BUFFER_SIZE: usize = 256 << 10;
26
27/// Performs one `INSERT`, sending pre-formatted data.
28///
29/// The [`InsertFormatted::end`] method must be called to finalize the `INSERT`.
30/// Otherwise, the whole `INSERT` will be aborted.
31///
32/// Rows are sent progressively to spread network load.
33///
34/// # Note: Not Validated
35/// Unlike [`Insert`][crate::insert::Insert] and [`Inserter`][crate::inserter::Inserter],
36/// this does not perform any validation on the submitted data.
37///
38/// Only the use of self-describing formats (e.g. CSV, TabSeparated, JSON) is recommended.
39///
40/// See the [list of supported formats](https://clickhouse.com/docs/interfaces/formats)
41/// for details.
42#[must_use]
43pub struct InsertFormatted {
44    state: InsertState,
45    #[cfg(any(feature = "lz4", feature = "zstd"))]
46    compression: Compression,
47    send_timeout: Option<Timeout>,
48    end_timeout: Option<Timeout>,
49    // Use boxed `Sleep` to reuse a timer entry, it improves performance.
50    // Also, `tokio::time::timeout()` significantly increases a future's size.
51    sleep: Pin<Box<Sleep>>,
52    span: tracing::Span,
53}
54
/// An optional limit on how long one operation may stay pending.
///
/// It does not own a timer; instead it arms a `Sleep` that is shared with other timeouts.
struct Timeout {
    /// Maximum time the operation may remain pending.
    duration: Duration,
    /// `true` while the shared `Sleep` is armed with this timeout's deadline.
    is_set: bool,
}
59
60enum InsertState {
61    NotStarted {
62        client: Box<Client>,
63        sql: String,
64    },
65    Active {
66        sender: ChunkSender,
67        handle: JoinHandle<Result<()>>,
68        sent_bytes: u64,
69        encoded_bytes: u64,
70    },
71    Terminated {
72        handle: JoinHandle<Result<()>>,
73    },
74    Completed,
75}
76
77impl InsertState {
78    #[inline(always)]
79    fn is_not_started(&self) -> bool {
80        matches!(self, Self::NotStarted { .. })
81    }
82
83    fn sender(&mut self) -> Option<&mut ChunkSender> {
84        match self {
85            InsertState::Active { sender, .. } => Some(sender),
86            _ => None,
87        }
88    }
89
90    fn handle(&mut self) -> Option<&mut JoinHandle<Result<()>>> {
91        match self {
92            InsertState::Active { handle, .. } | InsertState::Terminated { handle } => Some(handle),
93            _ => None,
94        }
95    }
96
97    fn client_with_sql(&self) -> Option<(&Client, &str)> {
98        match self {
99            InsertState::NotStarted { client, sql } => Some((client, sql)),
100            _ => None,
101        }
102    }
103
104    #[inline]
105    fn expect_client_mut(&mut self) -> &mut Client {
106        let Self::NotStarted { client, .. } = self else {
107            panic!("cannot modify client settings while an insert is in-progress")
108        };
109
110        client
111    }
112
113    fn terminated(&mut self, span: &tracing::Span) {
114        match mem::replace(self, InsertState::Completed) {
115            InsertState::NotStarted { .. } | InsertState::Completed => (),
116            InsertState::Active {
117                handle,
118                sent_bytes,
119                encoded_bytes,
120                ..
121            } => {
122                *self = InsertState::Terminated { handle };
123
124                tracing::record_all!(
125                    span,
126                    clickhouse.request.sent_bytes = sent_bytes,
127                    clickhouse.request.encoded_bytes = encoded_bytes,
128                );
129            }
130            InsertState::Terminated { handle } => {
131                *self = InsertState::Terminated { handle };
132            }
133        }
134    }
135}
136
137impl InsertFormatted {
138    pub(crate) fn new(client: &Client, sql: String, collection_name: Option<&str>) -> Self {
139        // https://opentelemetry.io/docs/specs/semconv/db/sql/
140        // TODO: write our own Semantic Conventions for ClickHouse
141        Self {
142            span: tracing::info_span!(
143                "clickhouse.insert",
144                // OTel conventional fields
145                // Note that `Empty` or `Option::None` fields are not reported,
146                // so we can avoid adding noise to logs when the `opentelemetry` feature is disabled.
147                otel.status_code = tracing::field::Empty,
148                otel.kind = cfg!(feature = "opentelemetry").then_some("client"),
149                error.type = tracing::field::Empty,
150                db.system.name = cfg!(feature = "opentelemetry").then_some("clickhouse"),
151                // Only log full query text at TRACE level
152                // Important that this is taken before client-side parameters are populated
153                // FIXME: we can't use `enabled!` due to https://github.com/tokio-rs/tracing/issues/2448
154                // but we don't want to log the full query at all verbosity levels.
155                // db.query.text = tracing::enabled!(tracing::Level::TRACE).then_some(&sql),
156                // TODO: generate summary
157                db.query.summary = tracing::field::Empty,
158                db.operation.name = "INSERT",
159                db.collection.name = collection_name,
160                // ClickHouse-specific extension fields
161                clickhouse.request.session_id = client.get_setting(settings::SESSION_ID),
162                clickhouse.request.query_id = client.get_setting(settings::QUERY_ID),
163                clickhouse.request.sent_rows = tracing::field::Empty,
164                clickhouse.request.sent_bytes = tracing::field::Empty,
165                clickhouse.request.encoded_bytes = tracing::field::Empty,
166            ),
167            state: InsertState::NotStarted {
168                client: Box::new(client.clone()),
169                sql,
170            },
171            #[cfg(any(feature = "lz4", feature = "zstd"))]
172            compression: client.compression,
173            send_timeout: None,
174            end_timeout: None,
175            sleep: Box::pin(tokio::time::sleep(Duration::new(0, 0))),
176        }
177    }
178
179    /// Sets timeouts for different operations.
180    ///
181    /// `send_timeout` restricts time on sending a data chunk to a socket.
182    /// `None` disables the timeout, it's a default.
183    /// It's roughly equivalent to `tokio::time::timeout(insert.write(...))`.
184    ///
185    /// `end_timeout` restricts time on waiting for a response from the CH
186    /// server. Thus, it includes all work needed to handle `INSERT` by the
187    /// CH server, e.g. handling all materialized views and so on.
188    /// `None` disables the timeout, it's a default.
189    /// It's roughly equivalent to `tokio::time::timeout(insert.end(...))`.
190    ///
191    /// These timeouts are much more performant (~x10) than wrapping `write()`
192    /// and `end()` calls into `tokio::time::timeout()`.
193    pub fn with_timeouts(
194        mut self,
195        send_timeout: Option<Duration>,
196        end_timeout: Option<Duration>,
197    ) -> Self {
198        self.set_timeouts(send_timeout, end_timeout);
199        self
200    }
201
202    /// Configure the [roles] to use when executing `INSERT` statements.
203    ///
204    /// Overrides any roles previously set by this method, [`InsertFormatted::with_setting`],
205    /// [`Client::with_roles`] or [`Client::with_setting`].
206    ///
207    /// An empty iterator may be passed to clear the set roles.
208    ///
209    /// [roles]: https://clickhouse.com/docs/operations/access-rights#role-management
210    ///
211    /// # Panics
212    /// If called after the request is started, e.g., after [`InsertFormatted::send`].
213    pub fn with_roles(mut self, roles: impl IntoIterator<Item = impl Into<String>>) -> Self {
214        self.state.expect_client_mut().set_roles(roles);
215        self
216    }
217
218    /// Clear any explicit [roles] previously set on this `Insert` or inherited from [`Client`].
219    ///
220    /// Overrides any roles previously set by [`InsertFormatted::with_roles`], [`InsertFormatted::with_setting`],
221    /// [`Client::with_roles`] or [`Client::with_setting`].
222    ///
223    /// [roles]: https://clickhouse.com/docs/operations/access-rights#role-management
224    ///
225    /// # Panics
226    /// If called after the request is started, e.g., after [`InsertFormatted::send`].
227    pub fn with_default_roles(mut self) -> Self {
228        self.state.expect_client_mut().clear_roles();
229        self
230    }
231
232    /// Similar to [`Client::with_option`], but for this particular INSERT
233    /// statement only.
234    ///
235    /// # Panics
236    /// If called after the request is started, e.g., after [`InsertFormatted::send`].
237    #[track_caller]
238    #[deprecated(since = "0.14.3", note = "please use `with_setting` instead")]
239    pub fn with_option(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
240        self.state.expect_client_mut().set_setting(name, value);
241        self
242    }
243
244    /// Similar to [`Client::with_setting`], but for this particular INSERT
245    /// statement only.
246    ///
247    /// # Panics
248    /// If called after the request is started, e.g., after [`InsertFormatted::send`].
249    #[track_caller]
250    pub fn with_setting(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
251        self.state.expect_client_mut().set_setting(name, value);
252        self
253    }
254
255    pub(crate) fn set_timeouts(
256        &mut self,
257        send_timeout: Option<Duration>,
258        end_timeout: Option<Duration>,
259    ) {
260        self.send_timeout = Timeout::new_opt(send_timeout);
261        self.end_timeout = Timeout::new_opt(end_timeout);
262    }
263
264    pub(crate) fn span(&self) -> &tracing::Span {
265        &self.span
266    }
267
268    /// Wrap this `InsertFormatted` with a buffer of a default size.
269    ///
270    /// The returned type also implements [`AsyncWrite`].
271    ///
272    /// To set the capacity, use [`Self::buffered_with_capacity()`].
273    pub fn buffered(self) -> BufInsertFormatted {
274        self.buffered_with_capacity(BUFFER_SIZE)
275    }
276
277    /// Wrap this `InsertFormatted` with a buffer of a given size.
278    ///
279    /// The returned type also implements [`AsyncWrite`].
280    ///
281    /// If `capacity == 0`, the buffer is flushed between every write regardless of size.
282    pub fn buffered_with_capacity(self, capacity: usize) -> BufInsertFormatted {
283        BufInsertFormatted::new(self, capacity)
284    }
285
286    /// Send a chunk of data.
287    ///
288    /// If compression is enabled, the data is compressed first.
289    ///
290    /// To pre-compress the data, use [`Self::send_compressed()`] instead.
291    ///
292    /// # Note: Unbuffered
293    /// This immediately compresses and queues the data to be sent on the connection
294    /// without waiting for more chunks. For best performance, chunks should not be too small.
295    ///
296    /// Use [`Self::buffered()`] for a buffered implementation which also implements [`AsyncWrite`].
297    pub async fn send(&mut self, data: Bytes) -> Result<()> {
298        let original_size = to_u64_saturating(data.len());
299
300        #[cfg(any(feature = "lz4", feature = "zstd"))]
301        let data = if self.compression.is_enabled() {
302            CompressedData::new(&data, self.compression)?.compressed
303        } else {
304            data
305        };
306
307        self.send_inner(data, original_size).await
308    }
309
310    async fn send_inner(&mut self, mut data: Bytes, original_size: u64) -> Result<()> {
311        if self.state.is_not_started() {
312            self.init_request()?;
313        }
314
315        std::future::poll_fn(move |cx| {
316            loop {
317                // Potentially cheaper than cloning `data` which touches the refcount
318                match self.try_send(mem::take(&mut data), original_size) {
319                    ControlFlow::Break(Ok(())) => return Poll::Ready(Ok(())),
320                    ControlFlow::Break(Err(_)) => {
321                        // If the channel is closed, we should return the actual error
322                        return self.poll_wait_handle(cx);
323                    }
324                    ControlFlow::Continue(unsent) => {
325                        data = unsent;
326                        // Shorter code-path if we just try to send the data first
327                        ready!(self.poll_ready(cx))?;
328                    }
329                }
330            }
331        })
332        .await
333    }
334
335    #[inline]
336    pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
337        if self.state.is_not_started() {
338            self.init_request()?;
339        }
340
341        let Some(sender) = self.state.sender() else {
342            return Poll::Ready(Err(Error::Network("channel closed".into())));
343        };
344
345        match sender.poll_ready(cx) {
346            Poll::Ready(true) => {
347                Timeout::reset_opt(self.send_timeout.as_mut());
348                Poll::Ready(Ok(()))
349            }
350            Poll::Ready(false) => Poll::Ready(Err(Error::Network("channel closed".into()))),
351            Poll::Pending => {
352                ready!(Timeout::poll_opt(
353                    self.send_timeout.as_mut(),
354                    self.sleep.as_mut(),
355                    cx
356                ));
357                self.abort();
358                Poll::Ready(Err(Error::TimedOut))
359            }
360        }
361    }
362
363    #[inline(always)]
364    pub(crate) fn try_send(
365        &mut self,
366        bytes: Bytes,
367        original_size: u64,
368    ) -> ControlFlow<Result<()>, Bytes> {
369        let InsertState::Active {
370            sender,
371            sent_bytes,
372            encoded_bytes,
373            ..
374        } = &mut self.state
375        else {
376            return ControlFlow::Break(Err(Error::Network("channel closed".into())));
377        };
378
379        let send_size = bytes.len();
380
381        sender.try_send(bytes).map_break(|res| match res {
382            Ok(()) => {
383                *sent_bytes += to_u64_saturating(send_size);
384                *encoded_bytes += original_size;
385                Ok(())
386            }
387            Err(e) => Err(Error::Network(e.into())),
388        })
389    }
390
391    /// Ends `INSERT`, the server starts processing the data.
392    ///
393    /// Succeeds if the server returns 200, that means the `INSERT` was handled
394    /// successfully, including all materialized views and quorum writes.
395    ///
396    /// NOTE: If this isn't called, the whole `INSERT` is aborted.
397    pub async fn end(mut self) -> Result<()> {
398        std::future::poll_fn(|cx| self.poll_end(cx)).await
399    }
400
401    pub(crate) fn poll_end(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
402        self.state.terminated(&self.span);
403        self.poll_wait_handle(cx)
404    }
405
406    fn poll_wait_handle(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
407        let Some(handle) = self.state.handle() else {
408            return Poll::Ready(Ok(()));
409        };
410
411        let Poll::Ready(res) = Pin::new(&mut *handle).poll(cx) else {
412            ready!(Timeout::poll_opt(
413                self.end_timeout.as_mut(),
414                self.sleep.as_mut(),
415                cx
416            ));
417
418            // We can do nothing useful here, so just shut down the background task.
419            handle.abort();
420            tracing::debug!("insert timed out");
421            return Poll::Ready(Err(Error::TimedOut));
422        };
423
424        let res = match res {
425            Ok(res) => res,
426            Err(err) if err.is_panic() => panic::resume_unwind(err.into_panic()),
427            Err(err) => Err(Error::Custom(format!("unexpected error: {err}"))),
428        };
429
430        self.state = InsertState::Completed;
431
432        tracing::trace!("finished insert");
433
434        Poll::Ready(res.inspect_err(|e| e.record_in_current_span("error from insert query")))
435    }
436
437    #[cold]
438    #[track_caller]
439    #[inline(never)]
440    fn init_request(&mut self) -> Result<()> {
441        debug_assert!(matches!(self.state, InsertState::NotStarted { .. }));
442        let (client, sql) = self.state.client_with_sql().unwrap(); // checked above
443
444        let _span = self.span.enter();
445
446        tracing::trace!("beginning insert");
447
448        let mut url = Url::parse(&client.url).map_err(|err| Error::InvalidParams(err.into()))?;
449        let mut pairs = url.query_pairs_mut();
450        pairs.clear();
451
452        if let Some(database) = &client.database {
453            pairs.append_pair(settings::DATABASE, database);
454        }
455
456        pairs.append_pair(settings::QUERY, sql);
457
458        if client.compression.is_enabled() {
459            pairs.append_pair(settings::DECOMPRESS, "1");
460        }
461
462        for (name, value) in &client.settings {
463            pairs.append_pair(name, value);
464        }
465
466        drop(pairs);
467
468        let mut builder = Request::post(url.as_str());
469        builder = with_request_headers(builder, &client.headers, &client.products_info);
470        builder = with_authentication(builder, &client.authentication);
471
472        let (sender, body) = RequestBody::chunked();
473
474        let request = builder.body(body).map_err(|err| {
475            let err = Error::InvalidParams(Box::new(err));
476            err.record_in_current_span("invalid params in insert request");
477            err
478        })?;
479
480        let future = client.http.request(request);
481
482        // Ensure the span created internally is captured as a child of the current span.
483        let mut response = Response::new(future, Compression::None);
484
485        // TODO: introduce `Executor` to allow bookkeeping of spawned tasks.
486        let handle = tokio::spawn(async move { response.finish().await });
487
488        self.state = InsertState::Active {
489            handle,
490            sender,
491            sent_bytes: 0,
492            encoded_bytes: 0,
493        };
494        Ok(())
495    }
496
497    pub(crate) fn abort(&mut self) {
498        let _span = self.span.enter();
499
500        if let InsertState::Active { sender, .. } = &mut self.state {
501            sender.abort();
502        }
503
504        self.state.terminated(&self.span);
505    }
506}
507
508impl Drop for InsertFormatted {
509    fn drop(&mut self) {
510        self.abort();
511    }
512}
513
514/// A wrapper around [`InsertFormatted`] which buffers writes.
515pub struct BufInsertFormatted {
516    insert: InsertFormatted,
517    buffer: BytesMut,
518    /// Nominal capacity, stored separately because [`Self::write_buffered()`] can grow the buffer.
519    capacity: usize,
520}
521
522impl BufInsertFormatted {
523    fn new(insert: InsertFormatted, capacity: usize) -> Self {
524        Self {
525            insert,
526            buffer: BytesMut::with_capacity(capacity),
527            capacity,
528        }
529    }
530
531    /// Return the number of buffered bytes.
532    #[inline(always)]
533    pub fn buf_len(&self) -> usize {
534        self.buffer.len()
535    }
536
537    /// Return the current capacity of the buffer.
538    ///
539    /// Note: Size is Not Constant
540    /// This may be smaller than the original capacity if part of the buffer
541    /// is still being used by the connection.
542    ///
543    /// This may be larger if a call to [`Self::write_buffered()`] caused the buffer to expand.
544    #[inline(always)]
545    pub fn capacity(&self) -> usize {
546        self.buffer.capacity()
547    }
548
549    #[inline(always)]
550    pub(crate) fn buffer_mut(&mut self) -> &mut BytesMut {
551        &mut self.buffer
552    }
553
554    pub(crate) fn expect_client_mut(&mut self) -> &mut Client {
555        self.insert.state.expect_client_mut()
556    }
557
558    pub(crate) fn set_timeouts(
559        &mut self,
560        send_timeout: Option<Duration>,
561        end_timeout: Option<Duration>,
562    ) {
563        self.insert.set_timeouts(send_timeout, end_timeout);
564    }
565
566    pub(crate) fn span(&self) -> &tracing::Span {
567        self.insert.span()
568    }
569
570    /// Write data to the buffer without waiting for it to be flushed.
571    ///
572    /// May cause the buffer to resize to fit the data.
573    #[inline(always)]
574    pub fn write_buffered(&mut self, data: &[u8]) {
575        self.buffer.extend_from_slice(data);
576    }
577
578    /// Write some data to the buffer, flushing first if it is already full.
579    ///
580    /// Returns the number of bytes written, which may be less than `data.len()` if the remaining
581    /// capacity was smaller.
582    ///
583    /// Cancel-safe. Until this returns `Ok(n)`, the contents of `data` are not yet written to the
584    /// buffer.
585    // `#[inline]` is *supposed* to work on `async fn`
586    // https://doc.rust-lang.org/reference/attributes/codegen.html#r-attributes.codegen.inline.async
587    // but it's apparently not implemented yet: https://github.com/rust-lang/rust/pull/149245
588    #[inline(always)]
589    pub async fn write(&mut self, data: &[u8]) -> Result<usize> {
590        std::future::poll_fn(|cx| self.poll_write_inner(data, cx)).await
591    }
592
593    // `poll_write` but it returns `crate::Result` instead of `io::Result`
594    #[inline(always)]
595    fn poll_write_inner(&mut self, data: &[u8], cx: &mut Context<'_>) -> Poll<Result<usize>> {
596        // We don't want to wait for the buffer to be full before we start the request,
597        // in the event of an error.
598        self.init_request_if_required()?;
599
600        // Capacity calculations change a little bit from those in, e.g., `tokio::io::BufWriter`
601        // since we always need to copy into the buffer to send chunks on the connection.
602        if self.buffer.len() >= self.capacity {
603            ready!(self.poll_flush_inner(cx))?;
604            debug_assert!(self.buffer.is_empty());
605        }
606
607        // Eliminates the need for a special check in `write_all()`;
608        // we need to copy to *some* buffer anyway because of how this type works.
609        if self.capacity == 0 {
610            self.buffer.extend_from_slice(data);
611            return Poll::Ready(Ok(data.len()));
612        }
613
614        // Guaranteed to be >= 1 by the above checks.
615        let remaining_capacity = self.capacity - self.buffer.len();
616
617        let write_len = cmp::min(remaining_capacity, data.len());
618
619        self.buffer.extend_from_slice(&data[..write_len]);
620        Poll::Ready(Ok(write_len))
621    }
622
623    /// Flush the buffer to the server as a single chunk.
624    ///
625    /// If [compression is enabled][Client::with_compression], the full buffer will be compressed.
626    #[inline(always)]
627    pub async fn flush(&mut self) -> Result<()> {
628        std::future::poll_fn(|cx| self.poll_flush_inner(cx)).await
629    }
630
631    #[inline(always)]
632    fn poll_flush_inner(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
633        if self.buffer.is_empty() {
634            return Poll::Ready(Ok(()));
635        }
636
637        ready!(self.insert.poll_ready(cx))?;
638
639        let data = self.buffer.split().freeze();
640
641        let original_size: u64 = data.len().try_into().unwrap_or(u64::MAX);
642
643        #[cfg(any(feature = "lz4", feature = "zstd"))]
644        let data = if self.insert.compression.is_enabled() {
645            CompressedData::new(&data, self.insert.compression)?.compressed
646        } else {
647            data
648        };
649
650        let ControlFlow::Break(res) = self.insert.try_send(data, original_size) else {
651            unreachable!("BUG: we just checked that `ChunkSender` was ready")
652        };
653
654        Poll::Ready(res)
655    }
656
657    /// Flushes the buffer, then calls [`InsertFormatted::end()`].
658    ///
659    /// Cancel-safe.
660    #[inline(always)]
661    pub async fn end(&mut self) -> Result<()> {
662        std::future::poll_fn(|cx| self.poll_end(cx)).await
663    }
664
665    #[inline(always)]
666    fn poll_end(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
667        if !self.buffer.is_empty() {
668            ready!(self.poll_flush_inner(cx))?;
669            debug_assert!(self.buffer.is_empty());
670        }
671
672        self.insert.poll_end(cx)
673    }
674
675    /// Returns `Ok(true)` if the request was freshly started, `Err(...)` on error,
676    /// or `Ok(false)` otherwise.
677    #[inline]
678    pub(crate) fn init_request_if_required(&mut self) -> Result<bool> {
679        if self.insert.state.is_not_started() {
680            self.insert.init_request().map(|_| true)
681        } else {
682            Ok(false)
683        }
684    }
685
686    pub(crate) fn abort(&mut self) {
687        self.insert.abort();
688    }
689}
690
691impl AsyncWrite for BufInsertFormatted {
692    #[inline(always)]
693    fn poll_write(
694        mut self: Pin<&mut Self>,
695        cx: &mut Context<'_>,
696        buf: &[u8],
697    ) -> Poll<std::result::Result<usize, io::Error>> {
698        self.poll_write_inner(buf, cx).map_err(Into::into)
699    }
700
701    #[inline(always)]
702    fn poll_flush(
703        mut self: Pin<&mut Self>,
704        cx: &mut Context<'_>,
705    ) -> Poll<std::result::Result<(), io::Error>> {
706        self.poll_flush_inner(cx).map_err(Into::into)
707    }
708
709    #[inline(always)]
710    fn poll_shutdown(
711        mut self: Pin<&mut Self>,
712        cx: &mut Context<'_>,
713    ) -> Poll<std::result::Result<(), io::Error>> {
714        self.poll_end(cx).map_err(Into::into)
715    }
716}
717
718impl Timeout {
719    fn new_opt(duration: Option<Duration>) -> Option<Self> {
720        duration.map(|duration| Self {
721            duration,
722            is_set: false,
723        })
724    }
725
726    /// Returns `Poll::Pending` if `None`.
727    #[inline(always)]
728    fn poll_opt(this: Option<&mut Self>, sleep: Pin<&mut Sleep>, cx: &mut Context<'_>) -> Poll<()> {
729        if let Some(this) = this {
730            this.poll(sleep, cx)
731        } else {
732            Poll::Pending
733        }
734    }
735
736    #[inline]
737    fn poll(&mut self, mut sleep: Pin<&mut Sleep>, cx: &mut Context<'_>) -> Poll<()> {
738        if !self.is_set
739            && let Some(deadline) = Instant::now().checked_add(self.duration)
740        {
741            sleep.as_mut().reset(deadline);
742            self.is_set = true;
743        }
744
745        ready!(sleep.as_mut().poll(cx));
746        self.is_set = false;
747
748        Poll::Ready(())
749    }
750
751    #[inline(always)]
752    fn reset_opt(this: Option<&mut Self>) {
753        if let Some(this) = this {
754            this.is_set = false;
755        }
756    }
757}
758
/// Converts a `usize` length to `u64`, clamping to `u64::MAX` instead of truncating.
fn to_u64_saturating(n: usize) -> u64 {
    u64::try_from(n).unwrap_or(u64::MAX)
}
762
// Every item in here needs one of the compression features; gating the whole
// module avoids repeating the feature flag on each of them.
#[cfg(any(feature = "lz4", feature = "zstd"))]
mod compression {
    use crate::Compression;
    use crate::error::{Error, Result};
    use crate::insert_formatted::{InsertFormatted, to_u64_saturating};
    use bytes::Bytes;

    /// A chunk of pre-compressed data.
    #[cfg_attr(docsrs, doc(cfg(any(feature = "lz4", feature = "zstd"))))]
    pub struct CompressedData {
        /// The compressed payload, sent as-is.
        pub(crate) compressed: Bytes,
        /// Length of the payload before compression.
        pub(crate) original_size: u64,
    }

    impl CompressedData {
        /// Compresses `data` with the given `compression` method.
        ///
        /// # Errors
        /// Returns [`Error::Compression`] if `compression` is [`Compression::None`].
        /// Any error returned by the underlying compressor is propagated as well.
        pub fn new(data: &[u8], compression: Compression) -> Result<Self> {
            let compressed = match compression {
                Compression::None => {
                    return Err(Error::Compression(
                        "cannot pre-compress data when compression is disabled".into(),
                    ));
                }
                #[cfg(feature = "lz4")]
                #[allow(deprecated)]
                Compression::Lz4 | Compression::Lz4Hc(_) => {
                    crate::compression::lz4::compress(data)?
                }
                #[cfg(feature = "zstd")]
                Compression::Zstd(level) => {
                    crate::compression::zstd::compress(data, Some(level))?
                }
            };

            Ok(Self {
                compressed,
                original_size: to_u64_saturating(data.len()),
            })
        }

        /// Compresses a slice of bytes with LZ4.
        ///
        /// # Panics
        /// If LZ4 compression reports an error, which is considered a bug.
        #[cfg(feature = "lz4")]
        #[deprecated(note = "use `CompressedData::new()` instead")]
        #[inline(always)]
        pub fn from_slice(slice: &[u8]) -> Self {
            let original_size = to_u64_saturating(slice.len());
            let compressed = crate::compression::lz4::compress(slice)
                .expect("BUG: `lz4::compress()` should not error");

            Self {
                compressed,
                original_size,
            }
        }
    }

    /// LZ4-compresses any byte-like value; see the deprecated `CompressedData::from_slice`.
    #[cfg(feature = "lz4")]
    impl<T: AsRef<[u8]>> From<T> for CompressedData {
        #[inline(always)]
        #[allow(deprecated)]
        fn from(value: T) -> Self {
            Self::from_slice(value.as_ref())
        }
    }

    impl InsertFormatted {
        /// Send a chunk of pre-compressed data.
        ///
        /// # Errors
        /// In addition to network errors, this will return [`Error::Compression`] if the
        /// [`Client`][crate::Client] does not have compression enabled.
        pub async fn send_compressed(&mut self, data: CompressedData) -> Result<()> {
            if !self.compression.is_enabled() {
                return Err(Error::Compression(
                    "attempting to send compressed data, but compression is not enabled".into(),
                ));
            }

            let CompressedData {
                compressed,
                original_size,
            } = data;

            self.send_inner(compressed, original_size).await
        }
    }
}