1use std::{
2 io::{Error, Write},
3 sync::Arc,
4};
5
6use rustls::{pki_types::ServerName, ClientConfig, RootCertStore};
7use tokio::{
8 io::{AsyncReadExt, AsyncWriteExt},
9 net::TcpStream,
10};
11use tokio_rustls::TlsConnector;
12
13use super::{
14 app::App,
15 dbs::adapter::{NoCertificateVerification, DB},
16 init::{Addr, DBConfig},
17};
18
19#[derive(Debug)]
21pub struct Tool {
22 pub(crate) stop: Option<(Arc<Addr>, i64)>,
23 pub(crate) root: Arc<String>,
24 pub(crate) db: Arc<DB>,
25 pub(crate) install_end: bool,
26}
27
28impl Tool {
29 pub(crate) fn new(db: Arc<DB>, stop: Option<(Arc<Addr>, i64)>, root: Arc<String>) -> Tool {
30 Tool { stop, root, db, install_end: false }
31 }
32
33 pub fn install_end(&mut self) {
35 if !self.db.in_use() {
36 self.install_end = true;
37 }
38 }
39
40 pub(crate) fn stop(&mut self) {
42 if let Some((rpc, stop)) = self.stop.take() {
43 App::stop(rpc, stop);
44 }
45 }
46
47 pub fn get_cpu(&self) -> usize {
49 num_cpus::get()
50 }
51
52 pub fn get_root(&self) -> Arc<String> {
54 Arc::clone(&self.root)
55 }
56
57 pub async fn check_db(&self, config: DBConfig, sql: Option<Vec<String>>) -> Result<String, String> {
59 DB::check_db(&config, sql).await
60 }
61
62 pub fn get_db_type(&self) -> &'static str {
64 #[cfg(feature = "pgsql")]
65 return "PostgreSQL";
66 #[cfg(feature = "mssql")]
67 return "MS Sql Server";
68 #[cfg(not(any(feature = "pgsql", feature = "mssql")))]
69 return "Not defined";
70 }
71
72 pub async fn get_install_sql(&self) -> Result<String, std::io::Error> {
74 let addr = "raw.githubusercontent.com:443";
75 let domain = "raw.githubusercontent.com";
76
77 #[cfg(feature = "mssql")]
78 let url_path = "/tryteex/tiny-web/refs/heads/main/sql/lib-install-mssql.sql";
79 #[cfg(feature = "pgsql")]
80 let url_path = "/tryteex/tiny-web/refs/heads/main/sql/lib-install-pgsql.sql";
81 #[cfg(not(any(feature = "pgsql", feature = "mssql")))]
82 let url_path = "/tryteex/tiny-web/refs/heads/main/sql/lib-install-nosql.sql";
83
84 let stream = TcpStream::connect(addr).await?;
85
86 let mut config = ClientConfig::builder().with_root_certificates(RootCertStore::empty()).with_no_client_auth();
87 config.dangerous().set_certificate_verifier(Arc::new(NoCertificateVerification {}));
88 let tls_connector = TlsConnector::from(Arc::new(config));
89 let server_name = ServerName::try_from(domain).unwrap();
90
91 let mut tls_stream = tls_connector.connect(server_name, stream).await?;
92
93 let request = format!("GET {} HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n", url_path, domain);
94 tls_stream.write_all(request.as_bytes()).await?;
95
96 let mut response = Vec::new();
97 tls_stream.read_to_end(&mut response).await?;
98
99 let response_str = String::from_utf8_lossy(&response);
100 let sql = if let Some(body_start) = response_str.find("\r\n\r\n") {
101 &response_str[body_start + 4..]
102 } else {
103 return Err(std::io::Error::new(std::io::ErrorKind::InvalidData, ""));
104 };
105
106 Ok(sql.to_owned())
107 }
108
109 pub fn save_config_file(&self, data: &str) -> Result<(), Error> {
111 let mut file = std::fs::File::create(format!("{}/tiny.toml", self.root))?;
112 file.write_all(data.as_bytes())?;
113
114 Ok(())
115 }
116}