typed_clickhouse/
insert.rs

1use std::{future::Future, marker::PhantomData, mem, panic};
2
3use bytes::BytesMut;
4use hyper::{self, body, Body, Request};
5use serde::Serialize;
6use tokio::task::JoinHandle;
7use url::Url;
8
9use crate::{
10    error::{Error, Result},
11    introspection::{self, Reflection},
12    response::Response,
13    rowbinary, Client,
14};
15
/// One kibibyte, used to express the sizes below.
const KIB: usize = 1024;
/// Capacity, in bytes, preallocated for every buffer that collects serialized rows.
const BUFFER_SIZE: usize = 128 * KIB;
/// Once the buffer holds at least this many bytes after a write, it is sent as a body chunk.
/// The 1 KiB of headroom below `BUFFER_SIZE` means a row of up to 1 KiB can be appended to a
/// buffer that is still under this threshold without the buffer having to grow.
const MIN_CHUNK_SIZE: usize = BUFFER_SIZE - KIB;
18
19pub struct Insert<T> {
20    buffer: BytesMut,
21    sender: Option<body::Sender>,
22    handle: JoinHandle<Result<()>>,
23    _marker: PhantomData<fn() -> T>, // TODO: test contravariance.
24}
25
26impl<T> Insert<T> {
27    pub(crate) fn new(client: &Client, table: &str) -> Result<Self>
28    where
29        T: Reflection,
30    {
31        let mut url = Url::parse(&client.url).expect("TODO");
32        let mut pairs = url.query_pairs_mut();
33        pairs.clear();
34
35        if let Some(database) = &client.database {
36            pairs.append_pair("database", database);
37        }
38
39        let fields = introspection::join_field_names::<T>()
40            .expect("the row type must be a struct or a wrapper around it");
41
42        // TODO: what about escaping a table name?
43        // https://clickhouse.yandex/docs/en/query_language/syntax/#syntax-identifiers
44        let query = format!("INSERT INTO {}({}) FORMAT RowBinary", table, fields);
45        pairs.append_pair("query", &query);
46        drop(pairs);
47
48        let mut builder = Request::post(url.as_str());
49
50        if let Some(user) = &client.user {
51            builder = builder.header("X-ClickHouse-User", user);
52        }
53
54        if let Some(password) = &client.password {
55            builder = builder.header("X-ClickHouse-Key", password);
56        }
57
58        let (sender, body) = Body::channel();
59
60        let request = builder
61            .body(body)
62            .map_err(|err| Error::InvalidParams(Box::new(err)))?;
63
64        let future = client.client.request(request);
65        let handle = tokio::spawn(async move {
66            // TODO: should we read the body?
67            let _ = Response::from(future).resolve().await?;
68            Ok(())
69        });
70
71        Ok(Insert {
72            buffer: BytesMut::with_capacity(BUFFER_SIZE),
73            sender: Some(sender),
74            handle,
75            _marker: PhantomData,
76        })
77    }
78
79    pub fn write<'a>(&'a mut self, row: &T) -> impl Future<Output = Result<()>> + 'a + Send
80    where
81        T: Serialize,
82    {
83        let result = rowbinary::serialize_into(&mut self.buffer, row);
84
85        async move {
86            result?;
87            self.send_chunk_if_exceeds(MIN_CHUNK_SIZE).await?;
88            Ok(())
89        }
90    }
91
92    pub async fn end(mut self) -> Result<()> {
93        self.send_chunk_if_exceeds(1).await?;
94        drop(self.sender.take());
95
96        match (&mut self.handle).await {
97            Ok(res) => res,
98            Err(err) if err.is_panic() => panic::resume_unwind(err.into_panic()),
99            Err(err) => {
100                // TODO
101                Err(Error::Custom(format!("unexpected error: {}", err)))
102            }
103        }
104    }
105
106    async fn send_chunk_if_exceeds(&mut self, threshold: usize) -> Result<()> {
107        if self.buffer.len() >= threshold {
108            // Hyper uses non-trivial and inefficient (see benches) schema of buffering chunks.
109            // It's difficult to determine when allocations occur.
110            // So, instead we control it manually here and rely on the system allocator.
111            let chunk = mem::replace(&mut self.buffer, BytesMut::with_capacity(BUFFER_SIZE));
112
113            if let Some(sender) = &mut self.sender {
114                sender.send_data(chunk.freeze()).await?;
115            }
116        }
117
118        Ok(())
119    }
120}
121
122impl<T> Drop for Insert<T> {
123    fn drop(&mut self) {
124        if let Some(sender) = self.sender.take() {
125            sender.abort();
126        }
127    }
128}