1use std::{
2 collections::{BTreeMap, HashMap},
3 net::SocketAddr,
4 sync::{Arc, Mutex, OnceLock},
5};
6
7use crate::config::{matches_convert::convert_config_to_kernel, plugin_filter_dto::global_batch_mount_plugin, PluginConfig, SgProtocolConfig, SgTlsMode};
8
9use hyper::Version;
10use spacegate_config::{BackendHost, Config, ConfigItem};
11use spacegate_kernel::{
12 helper_layers::reload::Reloader,
13 listener::SgListen,
14 service::http_gateway::{builder::default_gateway_route_fallback, create_http_router, HttpRouterService},
15 ArcHyperService, BoxError,
16};
17use spacegate_plugin::{mount::MountPointIndex, PluginRepository};
18use std::time::Duration;
19use std::vec::Vec;
20use tokio::time::timeout;
21use tracing::{debug, error, info, instrument, warn};
22
23use tokio_rustls::rustls::{self, pki_types::PrivateKeyDer};
24use tokio_util::sync::CancellationToken;
25
26fn collect_http_route(
27 gateway_name: Arc<str>,
28 http_routes: impl IntoIterator<Item = (String, crate::SgHttpRoute)>,
29) -> Result<HashMap<String, spacegate_kernel::service::http_route::HttpRoute>, BoxError> {
30 http_routes
31 .into_iter()
32 .map(|(name, route)| {
33 let route_name: Arc<str> = name.clone().into();
34 let mount_index = MountPointIndex::HttpRoute {
35 gateway: gateway_name.clone(),
36 route: route_name.clone(),
37 };
38 let plugins = route.plugins;
39 let rules = route.rules;
40 let rules = rules
41 .into_iter()
42 .enumerate()
43 .map(|(rule_index, route_rule)| {
44 let mount_index = MountPointIndex::HttpRouteRule {
45 rule: rule_index,
46 gateway: gateway_name.clone(),
47 route: route_name.clone(),
48 };
49 let mut builder = spacegate_kernel::service::http_route::HttpRouteRule::builder();
50 builder = if let Some(matches) = route_rule.matches {
51 builder.matches(matches.into_iter().map(convert_config_to_kernel).collect::<Result<Vec<_>, _>>()?)
52 } else {
53 builder.match_all()
54 };
55 let backends = route_rule
56 .backends
57 .into_iter()
58 .enumerate()
59 .map(|(backend_index, backend)| {
60 let mount_index = MountPointIndex::HttpBackend {
61 backend: backend_index,
62 rule: rule_index,
63 gateway: gateway_name.clone(),
64 route: route_name.clone(),
65 };
66 let host = backend.get_host();
67 let mut builder = spacegate_kernel::service::http_route::HttpBackend::builder();
68 let plugins = backend.plugins;
69 #[cfg(feature = "k8s")]
70 {
71 use crate::extension::k8s_service::K8sService;
72 use spacegate_kernel::helper_layers::map_request::{add_extension::add_extension, MapRequestLayer};
73 use spacegate_kernel::BoxLayer;
74 if let BackendHost::K8sService(ref data) = backend.host {
75 let namespace_ext = K8sService(data.clone().into());
76 builder = builder.plugin(BoxLayer::new(MapRequestLayer::new(add_extension(namespace_ext, true))))
78 }
79 }
80 builder = builder.host(host);
81 if let Some(port) = backend.port {
82 builder = builder.port(port)
83 }
84 if let Some(timeout) = backend.timeout_ms.map(|timeout| Duration::from_millis(timeout as u64)) {
85 builder = builder.timeout(timeout)
86 }
87 let mut layer = if let BackendHost::File { path } = backend.host {
88 builder.file().path(path).build()
89 } else if let Some(protocol) = backend.protocol {
90 builder.schema(protocol.to_string()).build()
91 } else {
92 builder.build()
93 };
94 if backend.downgrade_http2.is_some() {
95 if let spacegate_kernel::service::http_route::Backend::Http { version, .. } = &mut layer.backend {
96 version.replace(Version::HTTP_11);
97 }
98 }
99 global_batch_mount_plugin(plugins, &mut layer, mount_index);
100 Result::<_, BoxError>::Ok(layer)
101 })
102 .collect::<Result<Vec<_>, _>>()?;
103 builder = builder.backends(backends);
104 if let Some(timeout) = route_rule.timeout_ms {
105 builder = builder.timeout(Duration::from_millis(timeout as u64));
106 }
107 let mut layer = builder.build();
108 global_batch_mount_plugin(route_rule.plugins, &mut layer, mount_index);
109 Result::<_, BoxError>::Ok(layer)
110 })
111 .collect::<Result<Vec<_>, _>>()?;
112 let mut layer =
113 spacegate_kernel::service::http_route::HttpRoute::builder().hostnames(route.hostnames.unwrap_or_default()).rules(rules).priority(route.priority).build();
114 global_batch_mount_plugin(plugins, &mut layer, mount_index);
115 Ok((name, layer))
116 })
117 .collect::<Result<HashMap<String, _>, _>>()
118}
119
120pub(crate) fn create_service(item: ConfigItem, reloader: Reloader<HttpRouterService>) -> Result<ArcHyperService, BoxError> {
122 let gateway_name: Arc<str> = item.gateway.name.into();
123 let http_routes = item.routes;
124 let routes = collect_http_route(gateway_name.clone(), http_routes)?;
125 let plugins = item.gateway.plugins.clone();
126 let mut builder = spacegate_kernel::service::http_gateway::Gateway::builder(gateway_name.clone());
127 if let Some(enable) = item.gateway.parameters.enable_x_request_id {
128 builder = builder.x_request_id(enable);
129 }
130 let mut layer = builder.http_routers(routes).http_route_reloader(reloader).build();
131 global_batch_mount_plugin(plugins, &mut layer, MountPointIndex::Gateway { gateway: gateway_name });
132 let service = layer.as_service();
133 Ok(service)
134}
135
136pub(crate) fn create_router_service(gateway_name: Arc<str>, http_routes: BTreeMap<String, crate::SgHttpRoute>) -> Result<HttpRouterService, BoxError> {
138 let routes = collect_http_route(gateway_name, http_routes.clone())?;
139 let service = create_http_router(routes.values(), default_gateway_route_fallback());
140 Ok(service)
141}
142
143pub struct RunningSgGateway {
152 pub gateway_name: Arc<str>,
153 token: CancellationToken,
154 handle: tokio::task::JoinHandle<()>,
155 pub reloader: Reloader<HttpRouterService>,
156 shutdown_timeout: Duration,
157}
158impl std::fmt::Debug for RunningSgGateway {
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 f.debug_struct("RunningSgGateway").field("shutdown_timeout", &self.shutdown_timeout).finish()
161 }
162}
163
164pub static GLOBAL_STORE: OnceLock<Arc<Mutex<HashMap<String, RunningSgGateway>>>> = OnceLock::new();
165impl RunningSgGateway {
166 pub async fn global_init(config: Config, signal: CancellationToken) {
167 for (id, spec) in config.plugins.into_inner() {
168 if let Err(err) = PluginRepository::global().create_or_update_instance(PluginConfig { id: id.clone(), spec }) {
169 tracing::error!("[SG.Config] fail to init plugin [{id}]: {err}", id = id.to_string());
170 }
171 }
172 for (name, item) in config.gateways {
173 match RunningSgGateway::create(item, signal.child_token()) {
174 Ok(inst) => RunningSgGateway::global_save(name, inst),
175 Err(e) => {
176 tracing::error!("[SG.Config] fail to init gateway [{name}]: {e}")
177 }
178 }
179 }
180 }
181 pub async fn global_reset() {
182 let store = Self::global_store();
183 let mut task = tokio::task::JoinSet::new();
184 {
185 let mut g_store = store.lock().expect("poisoned lock");
186 for (_, s) in g_store.drain() {
187 task.spawn(s.shutdown());
188 }
189 }
190 while let Some(res) = task.join_next().await {
191 res.expect("tokio join error")
192 }
193 PluginRepository::global().clear_instances()
194 }
195
196 pub fn global_store() -> Arc<Mutex<HashMap<String, RunningSgGateway>>> {
197 GLOBAL_STORE.get_or_init(Default::default).clone()
198 }
199 pub fn global_save(gateway_name: impl Into<String>, gateway: RunningSgGateway) {
200 let global_store = Self::global_store();
201 let mut global_store = global_store.lock().expect("poisoned lock");
202 global_store.insert(gateway_name.into(), gateway);
203 }
204
205 pub fn global_remove(gateway_name: impl AsRef<str>) -> Option<RunningSgGateway> {
206 let global_store = Self::global_store();
207 let mut global_store = global_store.lock().expect("poisoned lock");
208 global_store.remove(gateway_name.as_ref())
209 }
210
211 pub fn global_update(gateway_name: impl AsRef<str>, http_routes: BTreeMap<String, crate::SgHttpRoute>) -> Result<(), BoxError> {
212 let gateway_name = gateway_name.as_ref();
213 let service = create_router_service(gateway_name.to_string().into(), http_routes)?;
214 let reloader = {
215 let store = Self::global_store();
216 let global_store = store.lock().expect("poisoned lock");
217 if let Some(gw) = global_store.get(gateway_name) {
218 gw.reloader.clone()
219 } else {
220 warn!("no such gateway in global repository: {gateway_name}");
221 return Ok(());
222 }
223 };
224 reloader.reload(service);
225 Ok(())
226 }
227 #[instrument(fields(gateway=%config_item.gateway.name), skip_all, err)]
229 pub fn create(config_item: ConfigItem, cancel_token: CancellationToken) -> Result<Self, BoxError> {
230 #[allow(unused_mut)]
231 #[cfg(feature = "cache")]
233 {
234 if let Some(url) = &config_item.gateway.parameters.redis_url {
235 let url: Arc<str> = url.clone().into();
236 tracing::trace!("Initialize cache client...url:{url}");
240 spacegate_ext_redis::RedisClientRepo::global().add(&config_item.gateway.name, url.as_ref());
241 }
242 }
243 tracing::info!("[SG.Server] start gateway");
244 let reloader = <Reloader<HttpRouterService>>::default();
245 let gateway = config_item.gateway.clone();
246 let service = create_service(config_item, reloader.clone())?;
247 if gateway.listeners.is_empty() {
248 error!("[SG.Server] Missing Listeners");
249 }
250
251 let gateway_name: Arc<str> = Arc::from(gateway.name.to_string());
252 let mut listens: Vec<SgListen> = Vec::new();
253 for listener in &gateway.listeners {
254 let ip = listener.ip.unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
255 let addr = SocketAddr::new(ip, listener.port);
256 let mut listen = SgListen::new(addr, cancel_token.child_token());
257 if let SgProtocolConfig::Https { ref tls } = listener.protocol {
258 tracing::debug!("[SG.Server] Tls is init...mode:{:?}", tls.mode);
259 if SgTlsMode::Terminate == tls.mode {
260 {
261 let certs = rustls_pemfile::certs(&mut tls.cert.as_bytes()).filter_map(Result::ok).collect::<Vec<_>>();
262 let mut tls_key = tls.key.as_bytes();
263 let mut keys = rustls_pemfile::read_all(&mut tls_key).filter_map(Result::ok);
264
265 let key = keys.find_map(|key| {
266 debug!("key item: {:?}", key);
267 match key {
268 rustls_pemfile::Item::Pkcs1Key(k) => Some(PrivateKeyDer::Pkcs1(k)),
269 rustls_pemfile::Item::Pkcs8Key(k) => Some(PrivateKeyDer::Pkcs8(k)),
270 rustls_pemfile::Item::Sec1Key(k) => Some(PrivateKeyDer::Sec1(k)),
271 rest => {
272 warn!("Unsupported key type: {:?}", rest);
273 None
274 }
275 }
276 });
277 if let Some(key) = key {
278 info!("[SG.Server] using cert key {key:?}");
279 let _ = rustls::crypto::ring::default_provider().install_default();
280 let builder = rustls::ServerConfig::builder().with_no_client_auth();
281 let mut tls_server_cfg = if let Some(ref host_name) = listener.hostname {
282 info!("Using SNI resolver");
283 let mut resolver = rustls::server::ResolvesServerCertUsingSni::new();
284 let provider = rustls::crypto::CryptoProvider::get_default().expect("should installed");
285 let signed_key = provider.key_provider.load_private_key(key)?;
286 let ck = rustls::sign::CertifiedKey::new(certs, signed_key);
287 resolver.add(host_name, ck)?;
288 builder.with_cert_resolver(Arc::new(resolver))
289 } else {
290 info!("Using single cert");
291 builder.with_single_cert(certs, key)?
292 };
293 tls_server_cfg.alpn_protocols = vec![b"http/1.1".to_vec(), b"h2".to_vec()];
294 tls_server_cfg.ignore_client_order = true;
295 tls_server_cfg.enable_secret_extraction = true;
296 listen.add_service(service.clone().https(tls_server_cfg))
297 } else {
298 error!("[SG.Server] Can not found a valid Tls private key");
299 }
300 };
301 }
302 } else {
303 listen.add_service(service.clone().http());
304 }
305 listens.push(listen)
306 }
307
308 let cancel_task = cancel_token.clone().cancelled_owned();
310 let handle = {
311 let gateway_name = gateway_name.clone();
312 tokio::task::spawn(async move {
313 let mut join_set = tokio::task::JoinSet::new();
314 for listen in listens {
315 join_set.spawn(async move {
316 let id = listen.listener_id.clone();
317 if let Err(e) = listen.listen().await {
318 tracing::error!("[Sg.Server] listen error: {e}")
319 }
320 tracing::info!("[Sg.Server] listener[{id}] quit listening")
321 });
322 }
323 tracing::info!(gateway = gateway_name.as_ref(), "[Sg.Server] start all listeners");
324 cancel_task.await;
325 while let Some(result) = join_set.join_next().await {
326 if let Err(_e) = result {}
327 }
328 tracing::info!(gateway = gateway_name.as_ref(), "[Sg.Server] cancelled");
329 })
330 };
331 tracing::info!("[SG.Server] start finished");
332 Ok(RunningSgGateway {
333 gateway_name: gateway_name.clone(),
334 token: cancel_token,
335 handle,
336 shutdown_timeout: Duration::from_secs(10),
337 reloader,
338 })
339 }
340
341 pub async fn shutdown(self) {
343 self.token.cancel();
344 #[cfg(feature = "cache")]
345 {
346 let name = self.gateway_name.clone();
347 tracing::trace!("[SG.Cache] Remove cache client...");
348 spacegate_ext_redis::global_repo().remove(name.as_ref());
349 }
350 match timeout(self.shutdown_timeout, self.handle).await {
351 Ok(_) => {}
352 Err(e) => {
353 tracing::warn!("[SG.Server] Wait shutdown timeout:{e}");
354 }
355 };
356 tracing::info!("[SG.Server] Gateway shutdown");
357 }
358}