spacegate_shell/
server.rs

1use std::{
2    collections::{BTreeMap, HashMap},
3    net::SocketAddr,
4    sync::{Arc, Mutex, OnceLock},
5};
6
7use crate::config::{matches_convert::convert_config_to_kernel, plugin_filter_dto::global_batch_mount_plugin, PluginConfig, SgProtocolConfig, SgTlsMode};
8
9use hyper::Version;
10use spacegate_config::{BackendHost, Config, ConfigItem};
11use spacegate_kernel::{
12    helper_layers::reload::Reloader,
13    listener::SgListen,
14    service::http_gateway::{builder::default_gateway_route_fallback, create_http_router, HttpRouterService},
15    ArcHyperService, BoxError,
16};
17use spacegate_plugin::{mount::MountPointIndex, PluginRepository};
18use std::time::Duration;
19use std::vec::Vec;
20use tokio::time::timeout;
21use tracing::{debug, error, info, instrument, warn};
22
23use tokio_rustls::rustls::{self, pki_types::PrivateKeyDer};
24use tokio_util::sync::CancellationToken;
25
26fn collect_http_route(
27    gateway_name: Arc<str>,
28    http_routes: impl IntoIterator<Item = (String, crate::SgHttpRoute)>,
29) -> Result<HashMap<String, spacegate_kernel::service::http_route::HttpRoute>, BoxError> {
30    http_routes
31        .into_iter()
32        .map(|(name, route)| {
33            let route_name: Arc<str> = name.clone().into();
34            let mount_index = MountPointIndex::HttpRoute {
35                gateway: gateway_name.clone(),
36                route: route_name.clone(),
37            };
38            let plugins = route.plugins;
39            let rules = route.rules;
40            let rules = rules
41                .into_iter()
42                .enumerate()
43                .map(|(rule_index, route_rule)| {
44                    let mount_index = MountPointIndex::HttpRouteRule {
45                        rule: rule_index,
46                        gateway: gateway_name.clone(),
47                        route: route_name.clone(),
48                    };
49                    let mut builder = spacegate_kernel::service::http_route::HttpRouteRule::builder();
50                    builder = if let Some(matches) = route_rule.matches {
51                        builder.matches(matches.into_iter().map(convert_config_to_kernel).collect::<Result<Vec<_>, _>>()?)
52                    } else {
53                        builder.match_all()
54                    };
55                    let backends = route_rule
56                        .backends
57                        .into_iter()
58                        .enumerate()
59                        .map(|(backend_index, backend)| {
60                            let mount_index = MountPointIndex::HttpBackend {
61                                backend: backend_index,
62                                rule: rule_index,
63                                gateway: gateway_name.clone(),
64                                route: route_name.clone(),
65                            };
66                            let host = backend.get_host();
67                            let mut builder = spacegate_kernel::service::http_route::HttpBackend::builder();
68                            let plugins = backend.plugins;
69                            #[cfg(feature = "k8s")]
70                            {
71                                use crate::extension::k8s_service::K8sService;
72                                use spacegate_kernel::helper_layers::map_request::{add_extension::add_extension, MapRequestLayer};
73                                use spacegate_kernel::BoxLayer;
74                                if let BackendHost::K8sService(ref data) = backend.host {
75                                    let namespace_ext = K8sService(data.clone().into());
76                                    // need to add to front
77                                    builder = builder.plugin(BoxLayer::new(MapRequestLayer::new(add_extension(namespace_ext, true))))
78                                }
79                            }
80                            builder = builder.host(host);
81                            if let Some(port) = backend.port {
82                                builder = builder.port(port)
83                            }
84                            if let Some(timeout) = backend.timeout_ms.map(|timeout| Duration::from_millis(timeout as u64)) {
85                                builder = builder.timeout(timeout)
86                            }
87                            let mut layer = if let BackendHost::File { path } = backend.host {
88                                builder.file().path(path).build()
89                            } else if let Some(protocol) = backend.protocol {
90                                builder.schema(protocol.to_string()).build()
91                            } else {
92                                builder.build()
93                            };
94                            if backend.downgrade_http2.is_some() {
95                                if let spacegate_kernel::service::http_route::Backend::Http { version, .. } = &mut layer.backend {
96                                    version.replace(Version::HTTP_11);
97                                }
98                            }
99                            global_batch_mount_plugin(plugins, &mut layer, mount_index);
100                            Result::<_, BoxError>::Ok(layer)
101                        })
102                        .collect::<Result<Vec<_>, _>>()?;
103                    builder = builder.backends(backends);
104                    if let Some(timeout) = route_rule.timeout_ms {
105                        builder = builder.timeout(Duration::from_millis(timeout as u64));
106                    }
107                    let mut layer = builder.build();
108                    global_batch_mount_plugin(route_rule.plugins, &mut layer, mount_index);
109                    Result::<_, BoxError>::Ok(layer)
110                })
111                .collect::<Result<Vec<_>, _>>()?;
112            let mut layer =
113                spacegate_kernel::service::http_route::HttpRoute::builder().hostnames(route.hostnames.unwrap_or_default()).rules(rules).priority(route.priority).build();
114            global_batch_mount_plugin(plugins, &mut layer, mount_index);
115            Ok((name, layer))
116        })
117        .collect::<Result<HashMap<String, _>, _>>()
118}
119
120/// Create a gateway service from plugins and http_routes
121pub(crate) fn create_service(item: ConfigItem, reloader: Reloader<HttpRouterService>) -> Result<ArcHyperService, BoxError> {
122    let gateway_name: Arc<str> = item.gateway.name.into();
123    let http_routes = item.routes;
124    let routes = collect_http_route(gateway_name.clone(), http_routes)?;
125    let plugins = item.gateway.plugins.clone();
126    let mut builder = spacegate_kernel::service::http_gateway::Gateway::builder(gateway_name.clone());
127    if let Some(enable) = item.gateway.parameters.enable_x_request_id {
128        builder = builder.x_request_id(enable);
129    }
130    let mut layer = builder.http_routers(routes).http_route_reloader(reloader).build();
131    global_batch_mount_plugin(plugins, &mut layer, MountPointIndex::Gateway { gateway: gateway_name });
132    let service = layer.as_service();
133    Ok(service)
134}
135
136/// create a new sg gateway route, which can be sent to reloader
137pub(crate) fn create_router_service(gateway_name: Arc<str>, http_routes: BTreeMap<String, crate::SgHttpRoute>) -> Result<HttpRouterService, BoxError> {
138    let routes = collect_http_route(gateway_name, http_routes.clone())?;
139    let service = create_http_router(routes.values(), default_gateway_route_fallback());
140    Ok(service)
141}
142
143/// # Gateway
144/// A running spacegate gateway instance
145///
146/// It's created by calling [start](RunningSgGateway::start).
147///
148/// And you can use [shutdown](RunningSgGateway::shutdown) to shutdown it manually.
149///
150/// Though, after it has been dropped, it will shutdown automatically.
151pub struct RunningSgGateway {
152    pub gateway_name: Arc<str>,
153    token: CancellationToken,
154    handle: tokio::task::JoinHandle<()>,
155    pub reloader: Reloader<HttpRouterService>,
156    shutdown_timeout: Duration,
157}
158impl std::fmt::Debug for RunningSgGateway {
159    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160        f.debug_struct("RunningSgGateway").field("shutdown_timeout", &self.shutdown_timeout).finish()
161    }
162}
163
/// Process-wide registry of running gateways, keyed by gateway name.
///
/// Lazily initialized on first access through `RunningSgGateway::global_store`.
pub static GLOBAL_STORE: OnceLock<Arc<Mutex<HashMap<String, RunningSgGateway>>>> = OnceLock::new();
165impl RunningSgGateway {
166    pub async fn global_init(config: Config, signal: CancellationToken) {
167        for (id, spec) in config.plugins.into_inner() {
168            if let Err(err) = PluginRepository::global().create_or_update_instance(PluginConfig { id: id.clone(), spec }) {
169                tracing::error!("[SG.Config] fail to init plugin [{id}]: {err}", id = id.to_string());
170            }
171        }
172        for (name, item) in config.gateways {
173            match RunningSgGateway::create(item, signal.child_token()) {
174                Ok(inst) => RunningSgGateway::global_save(name, inst),
175                Err(e) => {
176                    tracing::error!("[SG.Config] fail to init gateway [{name}]: {e}")
177                }
178            }
179        }
180    }
181    pub async fn global_reset() {
182        let store = Self::global_store();
183        let mut task = tokio::task::JoinSet::new();
184        {
185            let mut g_store = store.lock().expect("poisoned lock");
186            for (_, s) in g_store.drain() {
187                task.spawn(s.shutdown());
188            }
189        }
190        while let Some(res) = task.join_next().await {
191            res.expect("tokio join error")
192        }
193        PluginRepository::global().clear_instances()
194    }
195
196    pub fn global_store() -> Arc<Mutex<HashMap<String, RunningSgGateway>>> {
197        GLOBAL_STORE.get_or_init(Default::default).clone()
198    }
199    pub fn global_save(gateway_name: impl Into<String>, gateway: RunningSgGateway) {
200        let global_store = Self::global_store();
201        let mut global_store = global_store.lock().expect("poisoned lock");
202        global_store.insert(gateway_name.into(), gateway);
203    }
204
205    pub fn global_remove(gateway_name: impl AsRef<str>) -> Option<RunningSgGateway> {
206        let global_store = Self::global_store();
207        let mut global_store = global_store.lock().expect("poisoned lock");
208        global_store.remove(gateway_name.as_ref())
209    }
210
211    pub fn global_update(gateway_name: impl AsRef<str>, http_routes: BTreeMap<String, crate::SgHttpRoute>) -> Result<(), BoxError> {
212        let gateway_name = gateway_name.as_ref();
213        let service = create_router_service(gateway_name.to_string().into(), http_routes)?;
214        let reloader = {
215            let store = Self::global_store();
216            let global_store = store.lock().expect("poisoned lock");
217            if let Some(gw) = global_store.get(gateway_name) {
218                gw.reloader.clone()
219            } else {
220                warn!("no such gateway in global repository: {gateway_name}");
221                return Ok(());
222            }
223        };
224        reloader.reload(service);
225        Ok(())
226    }
227    /// Start a gateway from plugins and http_routes
228    #[instrument(fields(gateway=%config_item.gateway.name), skip_all, err)]
229    pub fn create(config_item: ConfigItem, cancel_token: CancellationToken) -> Result<Self, BoxError> {
230        #[allow(unused_mut)]
231        // let mut builder_ext = hyper::http::Extensions::new();
232        #[cfg(feature = "cache")]
233        {
234            if let Some(url) = &config_item.gateway.parameters.redis_url {
235                let url: Arc<str> = url.clone().into();
236                // builder_ext.insert(crate::extension::redis_url::RedisUrl(url.clone()));
237                // builder_ext.insert(spacegate_kernel::extension::GatewayName(config.gateway.name.clone().into()));
238                // Initialize cache instances
239                tracing::trace!("Initialize cache client...url:{url}");
240                spacegate_ext_redis::RedisClientRepo::global().add(&config_item.gateway.name, url.as_ref());
241            }
242        }
243        tracing::info!("[SG.Server] start gateway");
244        let reloader = <Reloader<HttpRouterService>>::default();
245        let gateway = config_item.gateway.clone();
246        let service = create_service(config_item, reloader.clone())?;
247        if gateway.listeners.is_empty() {
248            error!("[SG.Server] Missing Listeners");
249        }
250
251        let gateway_name: Arc<str> = Arc::from(gateway.name.to_string());
252        let mut listens: Vec<SgListen> = Vec::new();
253        for listener in &gateway.listeners {
254            let ip = listener.ip.unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
255            let addr = SocketAddr::new(ip, listener.port);
256            let mut listen = SgListen::new(addr, cancel_token.child_token());
257            if let SgProtocolConfig::Https { ref tls } = listener.protocol {
258                tracing::debug!("[SG.Server] Tls is init...mode:{:?}", tls.mode);
259                if SgTlsMode::Terminate == tls.mode {
260                    {
261                        let certs = rustls_pemfile::certs(&mut tls.cert.as_bytes()).filter_map(Result::ok).collect::<Vec<_>>();
262                        let mut tls_key = tls.key.as_bytes();
263                        let mut keys = rustls_pemfile::read_all(&mut tls_key).filter_map(Result::ok);
264
265                        let key = keys.find_map(|key| {
266                            debug!("key item: {:?}", key);
267                            match key {
268                                rustls_pemfile::Item::Pkcs1Key(k) => Some(PrivateKeyDer::Pkcs1(k)),
269                                rustls_pemfile::Item::Pkcs8Key(k) => Some(PrivateKeyDer::Pkcs8(k)),
270                                rustls_pemfile::Item::Sec1Key(k) => Some(PrivateKeyDer::Sec1(k)),
271                                rest => {
272                                    warn!("Unsupported key type: {:?}", rest);
273                                    None
274                                }
275                            }
276                        });
277                        if let Some(key) = key {
278                            info!("[SG.Server] using cert key {key:?}");
279                            let _ = rustls::crypto::ring::default_provider().install_default();
280                            let builder = rustls::ServerConfig::builder().with_no_client_auth();
281                            let mut tls_server_cfg = if let Some(ref host_name) = listener.hostname {
282                                info!("Using SNI resolver");
283                                let mut resolver = rustls::server::ResolvesServerCertUsingSni::new();
284                                let provider = rustls::crypto::CryptoProvider::get_default().expect("should installed");
285                                let signed_key = provider.key_provider.load_private_key(key)?;
286                                let ck = rustls::sign::CertifiedKey::new(certs, signed_key);
287                                resolver.add(host_name, ck)?;
288                                builder.with_cert_resolver(Arc::new(resolver))
289                            } else {
290                                info!("Using single cert");
291                                builder.with_single_cert(certs, key)?
292                            };
293                            tls_server_cfg.alpn_protocols = vec![b"http/1.1".to_vec(), b"h2".to_vec()];
294                            tls_server_cfg.ignore_client_order = true;
295                            tls_server_cfg.enable_secret_extraction = true;
296                            listen.add_service(service.clone().https(tls_server_cfg))
297                        } else {
298                            error!("[SG.Server] Can not found a valid Tls private key");
299                        }
300                    };
301                }
302            } else {
303                listen.add_service(service.clone().http());
304            }
305            listens.push(listen)
306        }
307
308        // let cancel_guard = cancel_token.clone().drop_guard();
309        let cancel_task = cancel_token.clone().cancelled_owned();
310        let handle = {
311            let gateway_name = gateway_name.clone();
312            tokio::task::spawn(async move {
313                let mut join_set = tokio::task::JoinSet::new();
314                for listen in listens {
315                    join_set.spawn(async move {
316                        let id = listen.listener_id.clone();
317                        if let Err(e) = listen.listen().await {
318                            tracing::error!("[Sg.Server] listen error: {e}")
319                        }
320                        tracing::info!("[Sg.Server] listener[{id}] quit listening")
321                    });
322                }
323                tracing::info!(gateway = gateway_name.as_ref(), "[Sg.Server] start all listeners");
324                cancel_task.await;
325                while let Some(result) = join_set.join_next().await {
326                    if let Err(_e) = result {}
327                }
328                tracing::info!(gateway = gateway_name.as_ref(), "[Sg.Server] cancelled");
329            })
330        };
331        tracing::info!("[SG.Server] start finished");
332        Ok(RunningSgGateway {
333            gateway_name: gateway_name.clone(),
334            token: cancel_token,
335            handle,
336            shutdown_timeout: Duration::from_secs(10),
337            reloader,
338        })
339    }
340
341    /// Shutdown this gateway
342    pub async fn shutdown(self) {
343        self.token.cancel();
344        #[cfg(feature = "cache")]
345        {
346            let name = self.gateway_name.clone();
347            tracing::trace!("[SG.Cache] Remove cache client...");
348            spacegate_ext_redis::global_repo().remove(name.as_ref());
349        }
350        match timeout(self.shutdown_timeout, self.handle).await {
351            Ok(_) => {}
352            Err(e) => {
353                tracing::warn!("[SG.Server] Wait shutdown timeout:{e}");
354            }
355        };
356        tracing::info!("[SG.Server] Gateway shutdown");
357    }
358}