tikv_client/common/
security.rs1use std::fs::File;
4use std::io::Read;
5use std::path::Path;
6use std::path::PathBuf;
7use std::time::Duration;
8
9use log::info;
10use regex::Regex;
11use tonic::transport::Channel;
12use tonic::transport::ClientTlsConfig;
13use tonic::transport::Identity;
14use tonic::transport::{Certificate, Endpoint};
15
16use crate::internal_err;
17use crate::Result;
18
19lazy_static::lazy_static! {
20 static ref SCHEME_REG: Regex = Regex::new(r"^\s*(https?://)").unwrap();
21}
22
23fn check_pem_file(tag: &str, path: &Path) -> Result<File> {
24 File::open(path)
25 .map_err(|e| internal_err!("failed to open {} to load {}: {:?}", path.display(), tag, e))
26}
27
28fn load_pem_file(tag: &str, path: &Path) -> Result<Vec<u8>> {
29 let mut file = check_pem_file(tag, path)?;
30 let mut key = vec![];
31 file.read_to_end(&mut key)
32 .map_err(|e| {
33 internal_err!(
34 "failed to load {} from path {}: {:?}",
35 tag,
36 path.display(),
37 e
38 )
39 })
40 .map(|_| key)
41}
42
/// TLS material used to connect to RPC servers.
///
/// The `Default` value has an empty `ca`. `SecurityManager::connect` treats an empty `ca` as
/// "TLS not configured" and uses a plain-text connection.
#[derive(Default)]
pub struct SecurityManager {
    /// Raw contents of the CA certificate file, passed to `Certificate::from_pem`.
    ca: Vec<u8>,
    /// Raw contents of the client certificate file, passed to `Identity::from_pem`.
    cert: Vec<u8>,
    /// Path to the client private key file. The key's contents are not kept in memory; they
    /// are read from this path whenever a TLS endpoint is built.
    key: PathBuf,
}
53
54impl SecurityManager {
55 pub fn load(
57 ca_path: impl AsRef<Path>,
58 cert_path: impl AsRef<Path>,
59 key_path: impl Into<PathBuf>,
60 ) -> Result<SecurityManager> {
61 let key_path = key_path.into();
62 check_pem_file("private key", &key_path)?;
63 Ok(SecurityManager {
64 ca: load_pem_file("ca", ca_path.as_ref())?,
65 cert: load_pem_file("certificate", cert_path.as_ref())?,
66 key: key_path,
67 })
68 }
69
70 pub async fn connect<Factory, Client>(
72 &self,
73 addr: &str,
75 factory: Factory,
76 ) -> Result<Client>
77 where
78 Factory: FnOnce(Channel) -> Client,
79 {
80 info!("connect to rpc server at endpoint: {:?}", addr);
81 let channel = if !self.ca.is_empty() {
82 self.tls_channel(addr).await?
83 } else {
84 self.default_channel(addr).await?
85 };
86 let ch = channel.connect().await?;
87
88 Ok(factory(ch))
89 }
90
91 async fn tls_channel(&self, addr: &str) -> Result<Endpoint> {
92 let addr = "https://".to_string() + &SCHEME_REG.replace(addr, "");
93 let builder = self.endpoint(addr.to_string())?;
94 let tls = ClientTlsConfig::new()
95 .ca_certificate(Certificate::from_pem(&self.ca))
96 .identity(Identity::from_pem(
97 &self.cert,
98 load_pem_file("private key", &self.key)?,
99 ));
100 let builder = builder.tls_config(tls)?;
101 Ok(builder)
102 }
103
104 async fn default_channel(&self, addr: &str) -> Result<Endpoint> {
105 let addr = "http://".to_string() + &SCHEME_REG.replace(addr, "");
106 self.endpoint(addr)
107 }
108
109 fn endpoint(&self, addr: String) -> Result<Endpoint> {
110 let endpoint = Channel::from_shared(addr)?
111 .tcp_keepalive(Some(Duration::from_secs(10)))
112 .keep_alive_timeout(Duration::from_secs(3));
113 Ok(endpoint)
114 }
115}
116
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::Write;

    use super::*;

    #[test]
    fn test_security() {
        let temp = tempfile::tempdir().unwrap();
        let ca_path = temp.path().join("ca");
        let cert_path = temp.path().join("cert");
        let key_path = temp.path().join("key");
        // Give each file a distinct one-byte body so the assertions can tell which file ended
        // up in which field.
        for (id, path) in [&ca_path, &cert_path, &key_path].iter().enumerate() {
            File::create(path).unwrap().write_all(&[id as u8]).unwrap();
        }

        let mgr = SecurityManager::load(&ca_path, &cert_path, &key_path).unwrap();
        assert_eq!(mgr.ca, vec![0]);
        assert_eq!(mgr.cert, vec![1]);
        // The private key is kept as a path, not as its contents.
        assert_eq!(mgr.key, key_path);

        let key = load_pem_file("private key", &key_path).unwrap();
        assert_eq!(key, vec![2]);
    }
}