witchcraft_server/
witchcraft.rs

1// Copyright 2022 Palantir Technologies, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use crate::blocking::conjure::ConjureBlockingEndpoint;
15use crate::blocking::pool::ThreadPool;
16use crate::debug::DiagnosticRegistry;
17use crate::endpoint::conjure::ConjureEndpoint;
18use crate::endpoint::extended_path::ExtendedPathEndpoint;
19use crate::endpoint::WitchcraftEndpoint;
20use crate::health::HealthCheckRegistry;
21use crate::readiness::ReadinessCheckRegistry;
22use crate::shutdown_hooks::ShutdownHooks;
23use crate::{blocking, RequestBody, ResponseWriter};
24use conjure_http::server::{AsyncService, BoxAsyncEndpoint, ConjureRuntime, Endpoint, Service};
25use conjure_runtime::ClientFactory;
26use futures_util::Future;
27use std::sync::Arc;
28use tokio::runtime::Handle;
29use witchcraft_metrics::MetricRegistry;
30use witchcraft_server_config::install::InstallConfig;
31
32/// The Witchcraft server context.
33pub struct Witchcraft {
34    pub(crate) metrics: Arc<MetricRegistry>,
35    pub(crate) health_checks: Arc<HealthCheckRegistry>,
36    pub(crate) readiness_checks: Arc<ReadinessCheckRegistry>,
37    pub(crate) diagnostics: Arc<DiagnosticRegistry>,
38    pub(crate) client_factory: ClientFactory,
39    pub(crate) handle: Handle,
40    pub(crate) install_config: InstallConfig,
41    pub(crate) thread_pool: Option<Arc<ThreadPool>>,
42    pub(crate) endpoints: Vec<Box<dyn WitchcraftEndpoint + Sync + Send>>,
43    pub(crate) shutdown_hooks: ShutdownHooks,
44    pub(crate) conjure_runtime: Arc<ConjureRuntime>,
45}
46
47impl Witchcraft {
48    /// Returns a reference to the server's metric registry.
49    #[inline]
50    pub fn metrics(&self) -> &Arc<MetricRegistry> {
51        &self.metrics
52    }
53
54    /// Returns a reference to the server's health check registry.
55    #[inline]
56    pub fn health_checks(&self) -> &Arc<HealthCheckRegistry> {
57        &self.health_checks
58    }
59
60    /// Returns a reference to the server's readiness check registry.
61    #[inline]
62    pub fn readiness_checks(&self) -> &Arc<ReadinessCheckRegistry> {
63        &self.readiness_checks
64    }
65
66    /// Returns a reference to the server's HTTP client factory.
67    #[inline]
68    pub fn client_factory(&self) -> &ClientFactory {
69        &self.client_factory
70    }
71
72    /// Returns a reference to the server's diagnostics registry.
73    #[inline]
74    pub fn diagnostics(&self) -> &Arc<DiagnosticRegistry> {
75        &self.diagnostics
76    }
77
78    /// Returns a reference to a handle to the server's Tokio runtime.
79    #[inline]
80    pub fn handle(&self) -> &Handle {
81        &self.handle
82    }
83
84    /// Installs an async service at the server's root.
85    pub fn app<T>(&mut self, service: T)
86    where
87        T: AsyncService<RequestBody, ResponseWriter>,
88    {
89        self.endpoints(None, service.endpoints(&self.conjure_runtime), true)
90    }
91
92    /// Installs an async service under the server's `/api` prefix.
93    pub fn api<T>(&mut self, service: T)
94    where
95        T: AsyncService<RequestBody, ResponseWriter>,
96    {
97        self.endpoints(Some("/api"), service.endpoints(&self.conjure_runtime), true)
98    }
99
100    pub(crate) fn endpoints(
101        &mut self,
102        prefix: Option<&str>,
103        endpoints: Vec<BoxAsyncEndpoint<'static, RequestBody, ResponseWriter>>,
104        track_metrics: bool,
105    ) {
106        let metrics = if track_metrics {
107            Some(&*self.metrics)
108        } else {
109            None
110        };
111
112        self.endpoints.extend(
113            endpoints
114                .into_iter()
115                .map(|e| Box::new(ConjureEndpoint::new(metrics, e)))
116                .map(|e| extend_path(e, self.install_config.context_path(), prefix)),
117        )
118    }
119
120    /// Installs a blocking service at the server's root.
121    pub fn blocking_app<T>(&mut self, service: T)
122    where
123        T: Service<blocking::RequestBody, blocking::ResponseWriter>,
124    {
125        self.blocking_endpoints(None, service.endpoints(&self.conjure_runtime))
126    }
127
128    /// Installs a blocking service under the server's `/api` prefix.
129    pub fn blocking_api<T>(&mut self, service: T)
130    where
131        T: Service<blocking::RequestBody, blocking::ResponseWriter>,
132    {
133        self.blocking_endpoints(Some("/api"), service.endpoints(&self.conjure_runtime))
134    }
135
136    fn blocking_endpoints(
137        &mut self,
138        prefix: Option<&str>,
139        endpoints: Vec<
140            Box<dyn Endpoint<blocking::RequestBody, blocking::ResponseWriter> + Sync + Send>,
141        >,
142    ) {
143        let thread_pool = self
144            .thread_pool
145            .get_or_insert_with(|| Arc::new(ThreadPool::new(&self.install_config, &self.metrics)));
146
147        self.endpoints.extend(
148            endpoints
149                .into_iter()
150                .map(|e| Box::new(ConjureBlockingEndpoint::new(&self.metrics, thread_pool, e)))
151                .map(|e| extend_path(e, self.install_config.context_path(), prefix)),
152        )
153    }
154
155    /// Adds a future that will be run when the server begins its shutdown process.
156    ///
157    /// The server will not shut down until the future completes or the configured shutdown timeout elapses.
158    pub fn on_shutdown<F>(&mut self, future: F)
159    where
160        F: Future<Output = ()> + 'static + Send,
161    {
162        self.shutdown_hooks.push(future)
163    }
164}
165
166fn extend_path(
167    endpoint: Box<dyn WitchcraftEndpoint + Sync + Send>,
168    context_path: &str,
169    prefix: Option<&str>,
170) -> Box<dyn WitchcraftEndpoint + Sync + Send> {
171    let context_path = if context_path == "/" {
172        ""
173    } else {
174        context_path
175    };
176    let prefix = format!("{context_path}{}", prefix.unwrap_or(""));
177
178    if prefix.is_empty() {
179        endpoint
180    } else {
181        Box::new(ExtendedPathEndpoint::new(endpoint, &prefix))
182    }
183}