1#![recursion_limit = "256"]
2use std::fmt::{Debug, Display};
3
4use log::warn;
5use maplit::hashmap;
6use once_cell::sync::OnceCell;
7
8use taos_query::prelude::Code;
9use taos_query::util::Edition;
10use taos_query::{DsnError, IntoDsn, RawResult};
11
12pub mod stmt;
13pub use stmt::Stmt;
14
15pub mod consumer;
17pub use consumer::{Consumer, Offset, TmqBuilder};
18
19pub mod query;
20pub use query::ResultSet;
21pub use query::Taos;
22
23use query::Error as QueryError;
24use query::WsConnReq;
25
26pub mod schemaless;
27
28pub(crate) use taos_query::block_in_place_or_global;
29use tokio::net::TcpStream;
30use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
31use tokio_tungstenite::MaybeTlsStream;
32use tokio_tungstenite::{connect_async_with_config, WebSocketStream};
33
34#[allow(unused_imports)]
35use ws_tool::codec::{AsyncDeflateCodec, WindowBit};
36
37pub mod client;
38pub use client::ClientConfig;
39
/// How a websocket connection authenticates against the server.
#[derive(Debug, Clone)]
pub enum WsAuth {
    /// Token authentication; `TaosBuilder` appends it to endpoint URLs as `?token=...`.
    Token(String),
    /// Username and password, in that order.
    Plain(String, String),
}
45
46#[derive(Debug, Clone)]
47pub struct TaosBuilder {
48 scheme: &'static str, addr: String,
50 auth: WsAuth,
51 database: Option<String>,
52 server_version: OnceCell<String>,
53 conn_mode: Option<u32>,
55}
56
57#[derive(Debug, thiserror::Error)]
58pub struct Error {
59 code: Code,
60 source: anyhow::Error,
61}
62
63impl Error {
64 pub const fn errno(&self) -> Code {
65 self.code
66 }
67 pub fn errstr(&self) -> String {
68 self.source.to_string()
69 }
70}
71
72impl Display for Error {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 f.write_str(&self.source.to_string())
75 }
76}
77
78impl From<DsnError> for Error {
79 fn from(err: DsnError) -> Self {
80 Error {
81 code: Code::FAILED,
82 source: err.into(),
83 }
84 }
85}
86impl From<query::asyn::Error> for Error {
87 fn from(err: query::asyn::Error) -> Self {
88 Error {
89 code: Code::FAILED,
90 source: err.into(),
91 }
92 }
93}
94
95impl taos_query::TBuilder for TaosBuilder {
96 type Target = Taos;
97
98 fn available_params() -> &'static [&'static str] {
99 &["token"]
100 }
101
102 fn from_dsn<D: IntoDsn>(dsn: D) -> RawResult<Self> {
103 Self::from_dsn(dsn.into_dsn()?)
104 }
105
106 fn client_version() -> &'static str {
107 "0"
108 }
109 fn ping(&self, taos: &mut Self::Target) -> RawResult<()> {
110 taos_query::Queryable::exec(taos, "select server_status()").map(|_| ())
111 }
112
113 fn ready(&self) -> bool {
114 true
115 }
116
117 fn build(&self) -> RawResult<Self::Target> {
118 Ok(Taos {
119 dsn: self.clone(),
120 async_client: OnceCell::new(),
121 async_sml: OnceCell::new(),
122 })
123 }
124
125 fn server_version(&self) -> RawResult<&str> {
126 if let Some(v) = self.server_version.get() {
127 Ok(v.as_str())
128 } else {
129 let conn = self.build()?;
130 use taos_query::prelude::sync::Queryable;
131 let v: String = Queryable::query_one(&conn, "select server_version()")?.unwrap();
132 Ok(match self.server_version.try_insert(v) {
133 Ok(v) => v.as_str(),
134 Err((v, _)) => v.as_str(),
135 })
136 }
137 }
138 fn is_enterprise_edition(&self) -> RawResult<bool> {
139 if self.addr.matches(".cloud.tdengine.com").next().is_some()
140 || self.addr.matches(".cloud.taosdata.com").next().is_some()
141 {
142 return Ok(true);
143 }
144
145 let taos = self.build()?;
146
147 use taos_query::prelude::sync::Queryable;
148 let grant: RawResult<Option<(String, bool)>> = Queryable::query_one(
149 &taos,
150 "select version, (expire_time < now) from information_schema.ins_cluster",
151 );
152
153 let edition = if let Ok(Some((edition, expired))) = grant {
154 Edition::new(edition, expired)
155 } else {
156 let grant: RawResult<Option<(String, (), String)>> =
157 Queryable::query_one(&taos, "show grants");
158
159 if let Ok(Some((edition, _, expired))) = grant {
160 Edition::new(
161 edition.trim(),
162 expired.trim() == "false" || expired.trim() == "unlimited",
163 )
164 } else {
165 warn!("Can't check enterprise edition with either \"show cluster\" or \"show grants\"");
166 Edition::new("unknown", true)
167 }
168 };
169 Ok(edition.is_enterprise_edition())
170 }
171
172 fn get_edition(&self) -> RawResult<Edition> {
173 if self.addr.matches(".cloud.tdengine.com").next().is_some()
174 || self.addr.matches(".cloud.taosdata.com").next().is_some()
175 {
176 let edition = Edition::new("cloud", false);
177 return Ok(edition);
178 }
179
180 let taos = self.build()?;
181
182 use taos_query::prelude::sync::Queryable;
183 let grant: RawResult<Option<(String, bool)>> = Queryable::query_one(
184 &taos,
185 "select version, (expire_time < now) from information_schema.ins_cluster",
186 );
187
188 let edition = if let Ok(Some((edition, expired))) = grant {
189 Edition::new(edition, expired)
190 } else {
191 let grant: RawResult<Option<(String, (), String)>> =
192 Queryable::query_one(&taos, "show grants");
193
194 if let Ok(Some((edition, _, expired))) = grant {
195 Edition::new(
196 edition.trim(),
197 expired.trim() == "false" || expired.trim() == "unlimited",
198 )
199 } else {
200 warn!("Can't check enterprise edition with either \"show cluster\" or \"show grants\"");
201 Edition::new("unknown", true)
202 }
203 };
204 Ok(edition)
205 }
206}
207
208#[async_trait::async_trait]
209impl taos_query::AsyncTBuilder for TaosBuilder {
210 type Target = Taos;
211
212 fn from_dsn<D: IntoDsn>(dsn: D) -> RawResult<Self> {
213 Self::from_dsn(dsn.into_dsn()?)
214 }
215
216 fn client_version() -> &'static str {
217 "0"
218 }
219 async fn ping(&self, taos: &mut Self::Target) -> RawResult<()> {
220 taos_query::AsyncQueryable::exec(taos, "select server_status()")
221 .await
222 .map(|_| ())
223 }
224
225 async fn ready(&self) -> bool {
226 true
227 }
228
229 async fn build(&self) -> RawResult<Self::Target> {
230 Ok(Taos {
231 dsn: self.clone(),
232 async_client: OnceCell::new(),
233 async_sml: OnceCell::new(),
234 })
235 }
236
237 async fn server_version(&self) -> RawResult<&str> {
238 if let Some(v) = self.server_version.get() {
239 Ok(v.as_str())
240 } else {
241 let conn = <Self as taos_query::AsyncTBuilder>::build(self).await?;
242 use taos_query::prelude::AsyncQueryable;
243 let v: String = AsyncQueryable::query_one(&conn, "select server_version()")
244 .await?
245 .unwrap();
246 Ok(match self.server_version.try_insert(v) {
247 Ok(v) => v.as_str(),
248 Err((v, _)) => v.as_str(),
249 })
250 }
251 }
252 async fn is_enterprise_edition(&self) -> RawResult<bool> {
253 use taos_query::prelude::AsyncQueryable;
254
255 let taos = self.build().await?;
256 taos.exec("select server_status()").await?;
258
259 match self.addr.matches(".cloud.tdengine.com").next().is_some()
260 || self.addr.matches(".cloud.taosdata.com").next().is_some()
261 {
262 true => return Ok(true),
263 false => (),
264 }
265
266 let grant: RawResult<Option<(String, bool)>> = AsyncQueryable::query_one(
267 &taos,
268 "select version, (expire_time < now) from information_schema.ins_cluster",
269 )
270 .await;
271
272 let edition = if let Ok(Some((edition, expired))) = grant {
273 Edition::new(edition, expired)
274 } else {
275 let grant: RawResult<Option<(String, (), String)>> =
276 AsyncQueryable::query_one(&taos, "show grants").await;
277
278 if let Ok(Some((edition, _, expired))) = grant {
279 Edition::new(
280 edition.trim(),
281 expired.trim() == "false" || expired.trim() == "unlimited",
282 )
283 } else {
284 warn!("Can't check enterprise edition with either \"show cluster\" or \"show grants\"");
285 Edition::new("unknown", true)
286 }
287 };
288 Ok(edition.is_enterprise_edition())
289 }
290
291 async fn get_edition(&self) -> RawResult<Edition> {
292 use taos_query::prelude::AsyncQueryable;
293
294 let taos = self.build().await?;
295 taos.exec("select server_status()").await?;
297
298 match self.addr.matches(".cloud.tdengine.com").next().is_some()
299 || self.addr.matches(".cloud.taosdata.com").next().is_some()
300 {
301 true => {
302 let edition = Edition::new("cloud", false);
303 return Ok(edition);
304 }
305 false => (),
306 }
307
308 let grant: RawResult<Option<(String, bool)>> = AsyncQueryable::query_one(
309 &taos,
310 "select version, (expire_time < now) from information_schema.ins_cluster",
311 )
312 .await;
313
314 let edition = if let Ok(Some((edition, expired))) = grant {
315 Edition::new(edition, expired)
316 } else {
317 let grant: RawResult<Option<(String, (), String)>> =
318 AsyncQueryable::query_one(&taos, "show grants").await;
319
320 if let Ok(Some((edition, _, expired))) = grant {
321 Edition::new(
322 edition.trim(),
323 expired.trim() == "false" || expired.trim() == "unlimited",
324 )
325 } else {
326 warn!("Can't check enterprise edition with either \"show cluster\" or \"show grants\"");
327 Edition::new("unknown", true)
328 }
329 };
330 Ok(edition)
331 }
332}
333
334impl TaosBuilder {
335 pub fn from_dsn(dsn: impl IntoDsn) -> RawResult<Self> {
336 let mut dsn = dsn.into_dsn()?;
337 let scheme = match (dsn.driver.as_str(), dsn.protocol.as_deref()) {
338 ("ws" | "http", _) => "ws",
339 ("wss" | "https", _) => "wss",
340 ("taos" | "taosws" | "tmq", Some("ws" | "http") | None) => "ws",
341 ("taos" | "taosws" | "tmq", Some("wss" | "https")) => "wss",
342 _ => Err(DsnError::InvalidDriver(dsn.to_string()))?,
343 };
344
345 let conn_mode = match dsn.params.get("conn_mode") {
346 Some(s) => match s.parse::<u32>() {
347 Ok(num) => Some(num),
348 Err(_) => Err(DsnError::InvalidDriver(dsn.to_string()))?,
349 },
350 None => None,
351 };
352
353 let token = dsn.params.remove("token");
354
355 let addr = match dsn.addresses.first() {
356 Some(addr) => {
357 if addr.port.is_none() && addr.host.as_deref() == Some("localhost") {
358 "localhost:6041".to_string()
359 } else {
360 addr.to_string()
361 }
362 }
363 None => "localhost:6041".to_string(),
364 };
365
366 if let Some(token) = token {
373 Ok(TaosBuilder {
374 scheme,
375 addr,
376 auth: WsAuth::Token(token),
377 database: dsn.subject,
378 server_version: OnceCell::new(),
379 conn_mode,
381 })
382 } else {
383 let username = dsn.username.unwrap_or_else(|| "root".to_string());
384 let password = dsn.password.unwrap_or_else(|| "taosdata".to_string());
385 Ok(TaosBuilder {
386 scheme,
387 addr,
388 auth: WsAuth::Plain(username, password),
389 database: dsn.subject,
390 server_version: OnceCell::new(),
391 conn_mode,
393 })
394 }
395 }
396 pub(crate) fn to_query_url(&self) -> String {
397 match &self.auth {
398 WsAuth::Token(token) => {
399 format!("{}://{}/rest/ws?token={}", self.scheme, self.addr, token)
400 }
401 WsAuth::Plain(_, _) => format!("{}://{}/rest/ws", self.scheme, self.addr),
402 }
403 }
404
405 pub(crate) fn to_stmt_url(&self) -> String {
406 match &self.auth {
407 WsAuth::Token(token) => {
408 format!("{}://{}/rest/stmt?token={}", self.scheme, self.addr, token)
409 }
410 WsAuth::Plain(_, _) => format!("{}://{}/rest/stmt", self.scheme, self.addr),
411 }
412 }
413
414 pub(crate) fn to_tmq_url(&self) -> String {
415 match &self.auth {
416 WsAuth::Token(token) => {
417 format!("{}://{}/rest/tmq?token={}", self.scheme, self.addr, token)
418 }
419 WsAuth::Plain(_, _) => format!("{}://{}/rest/tmq", self.scheme, self.addr),
420 }
421 }
422
423 pub(crate) fn to_schemaless_url(&self) -> String {
424 match &self.auth {
425 WsAuth::Token(token) => {
426 format!(
427 "{}://{}/rest/schemaless?token={}",
428 self.scheme, self.addr, token
429 )
430 }
431 WsAuth::Plain(_, _) => format!("{}://{}/rest/schemaless", self.scheme, self.addr),
432 }
433 }
434
435 pub(crate) fn to_ws_url(&self) -> String {
436 match &self.auth {
437 WsAuth::Token(token) => {
438 format!("{}://{}/ws?token={}", self.scheme, self.addr, token)
439 }
440 WsAuth::Plain(_, _) => format!("{}://{}/ws", self.scheme, self.addr),
441 }
442 }
443
444 pub(crate) fn to_conn_request(&self) -> WsConnReq {
445 let mode = match self.conn_mode {
446 Some(1) => Some(0), _ => None,
448 };
449
450 match &self.auth {
451 WsAuth::Token(_token) => WsConnReq {
452 user: Some("root".to_string()),
453 password: Some("taosdata".to_string()),
454 db: self.database.as_ref().map(Clone::clone),
455 mode,
456 },
457 WsAuth::Plain(user, pass) => WsConnReq {
458 user: Some(user.to_string()),
459 password: Some(pass.to_string()),
460 db: self.database.as_ref().map(Clone::clone),
461 mode,
462 },
463 }
464 }
465
466 pub(crate) async fn build_stream(
467 &self,
468 url: String,
469 ) -> RawResult<WebSocketStream<MaybeTlsStream<TcpStream>>> {
470 let mut config = WebSocketConfig::default();
471 config.max_frame_size = None;
472
473 let res = connect_async_with_config(self.to_ws_url(), Some(config), false)
474 .await
475 .map_err(|err| {
476 let err_string = err.to_string();
477 if err_string.contains("401 Unauthorized") {
478 QueryError::Unauthorized(self.to_ws_url())
479 } else {
480 err.into()
481 }
482 });
483
484 let (ws, _) = match res {
485 Ok(res) => res,
486 Err(err) => {
487 if err.to_string().contains("404 Not Found") || err.to_string().contains("400") {
488 connect_async_with_config(&url, Some(config), false)
489 .await
490 .map_err(|err| {
491 let err_string = err.to_string();
492 if err_string.contains("401 Unauthorized") {
493 QueryError::Unauthorized(url)
494 } else {
495 err.into()
496 }
497 })?
498 } else {
499 return Err(err.into());
500 }
501 }
502 };
503 Ok(ws)
504 }
505
506 pub(crate) async fn ws_tool_build_stream(
507 &self,
508 url: String,
509 ) -> RawResult<AsyncDeflateCodec<tokio::io::BufStream<ws_tool::stream::AsyncStream>>> {
510 let mut config = ClientConfig::default();
511
512 #[cfg(feature = "deflate")]
513 {
514 config.window = Some(WindowBit::Fifteen);
515 config.extra_headers = hashmap! {
516 "Accept-Encoding".to_string() => "gzip, deflate".to_string(),
517 };
518 }
519 #[cfg(not(feature = "deflate"))]
520 {
521 config.window = None;
522 config.extra_headers = hashmap! {
523 "Accept-Encoding".to_string() => "gzip".to_string(),
524 };
525 }
526
527 log::trace!(
528 "ws_tool config window: {:?}, headers: {:?}",
529 &config.window,
530 &config.extra_headers
531 );
532
533 let res: Result<
534 AsyncDeflateCodec<tokio::io::BufStream<ws_tool::stream::AsyncStream>>,
535 QueryError,
536 > = config
537 .async_connect_with(self.to_ws_url(), AsyncDeflateCodec::check_fn)
538 .await
539 .map_err(|err| {
540 let err_string = err.to_string();
541 if err_string.contains("401 Unauthorized") {
542 QueryError::Unauthorized(self.to_ws_url())
543 } else {
544 err.into()
545 }
546 });
547
548 let ws: AsyncDeflateCodec<tokio::io::BufStream<ws_tool::stream::AsyncStream>> = match res {
549 Ok(res) => res,
550 Err(err) => {
551 let uri = url.clone();
552 if err.to_string().contains("404 Not Found") || err.to_string().contains("400") {
553 config
554 .async_connect_with(uri, AsyncDeflateCodec::check_fn)
555 .await
556 .map_err(|err| {
557 let err_string = err.to_string();
558 if err_string.contains("401 Unauthorized") {
559 QueryError::Unauthorized(url)
560 } else {
561 err.into()
562 }
563 })?
564 } else {
565 return Err(err.into());
566 }
567 }
568 };
569
570 Ok(ws)
571 }
572
573 pub(crate) async fn build_tmq_stream(
574 &self,
575 url: String,
576 ) -> RawResult<AsyncDeflateCodec<tokio::io::BufStream<ws_tool::stream::AsyncStream>>> {
577 let mut config = ClientConfig::default();
578
579 #[cfg(feature = "deflate")]
580 {
581 config.window = Some(WindowBit::Fifteen);
582 config.extra_headers = hashmap! {
583 "Accept-Encoding".to_string() => "gzip, deflate".to_string(),
584 };
585 }
586 #[cfg(not(feature = "deflate"))]
587 {
588 config.window = None;
589 config.extra_headers = hashmap! {
590 "Accept-Encoding".to_string() => "gzip".to_string(),
591 };
592 }
593
594 log::trace!(
595 "ws_tool config window: {:?}, headers: {:?}",
596 &config.window,
597 &config.extra_headers
598 );
599
600 let ws: Result<
601 AsyncDeflateCodec<tokio::io::BufStream<ws_tool::stream::AsyncStream>>,
602 QueryError,
603 > = config
604 .async_connect_with(url.clone(), AsyncDeflateCodec::check_fn)
605 .await
606 .map_err(|err| {
607 let err_string = err.to_string();
608 if err_string.contains("401 Unauthorized") {
609 QueryError::Unauthorized(url)
610 } else {
611 err.into()
612 }
613 });
614
615 Ok(ws?)
616 }
617}
618
#[cfg(feature = "rustls")]
#[cfg(test)]
mod lib_tests {

    use crate::{
        query::infra::{ToMessage, WsRecv, WsSend},
        *,
    };
    use futures::{SinkExt, StreamExt};
    use std::time::Duration;
    use tracing::*;
    use tracing_subscriber::util::SubscriberInitExt;
    use ws_tool::frame::OpCode;

    /// How long to wait for the server's reply to the `version` request.
    const REPLY_TIMEOUT: Duration = Duration::from_secs(1);

    /// Connects via `build_stream` (tungstenite), sends `version`, and checks that the
    /// first reply decodes as a `WsRecv` with code 0.
    ///
    /// Uses `TEST_CLOUD_DSN` when set, otherwise `http://localhost:6041`.
    #[cfg(feature = "rustls")]
    #[tokio::test]
    async fn test_build_stream() -> Result<(), anyhow::Error> {
        let _subscriber = tracing_subscriber::fmt::fmt()
            .with_max_level(Level::INFO)
            .with_file(true)
            .with_line_number(true)
            .finish();
        let _ = _subscriber.try_init();

        let dsn = std::env::var("TEST_CLOUD_DSN")
            .unwrap_or_else(|_| "http://localhost:6041".to_string());
        let builder = TaosBuilder::from_dsn(dsn).unwrap();
        let url = builder.to_query_url();
        info!("url: {}", url);
        let ws = builder.build_stream(url).await.unwrap();
        trace!("ws: {:?}", ws);

        let (mut sender, mut reader) = ws.split();

        let version = WsSend::Version;
        sender.send(version.to_tungstenite_msg()).await?;

        // Await the reply in the test body, not in a detached task. That way a failed
        // assertion fails the test, and a closed stream cannot spin in a loop.
        let msg = tokio::time::timeout(REPLY_TIMEOUT, reader.next())
            .await?
            .ok_or_else(|| anyhow::anyhow!("websocket closed before replying"))??;
        let recv: WsRecv = serde_json::from_str(msg.to_text()?)?;
        info!("recv: {:?}", recv);
        assert_eq!(recv.code, 0);

        Ok(())
    }

    /// Connects via `ws_tool_build_stream`, sends `version`, and checks that the first
    /// text frame decodes as a `WsRecv` with code 0. Binary frames are printed and
    /// other opcodes are skipped.
    #[cfg(feature = "rustls")]
    #[tokio::test]
    async fn test_ws_tool_build_stream() -> Result<(), anyhow::Error> {
        let _subscriber = tracing_subscriber::fmt::fmt()
            .with_max_level(Level::DEBUG)
            .with_file(true)
            .with_line_number(true)
            .finish();
        let _ = _subscriber.try_init();

        let dsn = std::env::var("TEST_CLOUD_DSN")
            .unwrap_or_else(|_| "http://localhost:6041".to_string());

        let builder = TaosBuilder::from_dsn(dsn).unwrap();
        let url = builder.to_query_url();
        let ws = builder.ws_tool_build_stream(url).await.unwrap();

        // `split` yields (receiving half, sending half).
        let (mut reader, mut writer) = ws.split();

        let version = WsSend::Version;
        writer
            .send(OpCode::Text, &serde_json::to_vec(&version)?)
            .await?;

        let recv: WsRecv = tokio::time::timeout(REPLY_TIMEOUT, async {
            loop {
                let (header, payload) = reader
                    .receive()
                    .await
                    .expect("failed to receive a websocket frame");
                trace!("header.code: {:?}, payload: {:?}", &header.code, &payload);
                match header.code {
                    OpCode::Binary => println!("{:?}", payload),
                    OpCode::Text => break serde_json::from_slice::<WsRecv>(&payload),
                    _ => (),
                }
            }
        })
        .await
        .expect("timed out waiting for the version reply")?;
        info!("recv: {:?}", recv);
        assert_eq!(recv.code, 0);

        Ok(())
    }
}
721
#[cfg(feature = "deflate")]
#[cfg(test)]
mod lib_deflate_tests {

    use crate::{
        query::infra::{ToMessage, WsRecv, WsSend},
        *,
    };
    use futures::{SinkExt, StreamExt};
    use std::time::Duration;
    use tracing::*;
    use tracing_subscriber::util::SubscriberInitExt;
    use ws_tool::frame::OpCode;

    /// How long to wait for the server's reply to the `version` request.
    const REPLY_TIMEOUT: Duration = Duration::from_secs(1);

    /// Same as the plain `ws_tool` test, but built with the `deflate` feature: connects,
    /// sends `version`, and checks that the first text frame decodes as a `WsRecv` with
    /// code 0.
    ///
    /// Uses `TEST_CLOUD_DSN` when set, otherwise `http://localhost:6041`.
    #[cfg(feature = "deflate")]
    #[tokio::test]
    async fn test_ws_tool_build_stream_with_deflate() -> Result<(), anyhow::Error> {
        let _subscriber = tracing_subscriber::fmt::fmt()
            .with_max_level(Level::DEBUG)
            .with_file(true)
            .with_line_number(true)
            .finish();
        let _ = _subscriber.try_init();

        let dsn = std::env::var("TEST_CLOUD_DSN")
            .unwrap_or_else(|_| "http://localhost:6041".to_string());

        let builder = TaosBuilder::from_dsn(dsn).unwrap();
        let url = builder.to_query_url();
        let ws = builder.ws_tool_build_stream(url).await.unwrap();

        // `split` yields (receiving half, sending half).
        let (mut reader, mut writer) = ws.split();

        let version = WsSend::Version;
        writer
            .send(OpCode::Text, &serde_json::to_vec(&version)?)
            .await?;

        // Await the reply in the test body, not in a detached task, so a failed
        // assertion actually fails the test.
        let recv: WsRecv = tokio::time::timeout(REPLY_TIMEOUT, async {
            loop {
                let (header, payload) = reader
                    .receive()
                    .await
                    .expect("failed to receive a websocket frame");
                trace!("header.code: {:?}, payload: {:?}", &header.code, &payload);
                match header.code {
                    OpCode::Binary => println!("{:?}", payload),
                    OpCode::Text => break serde_json::from_slice::<WsRecv>(&payload),
                    _ => (),
                }
            }
        })
        .await
        .expect("timed out waiting for the version reply")?;
        info!("recv: {:?}", recv);
        assert_eq!(recv.code, 0);

        Ok(())
    }
}