1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;

use async_broadcast::Receiver;
use futures_util::future::try_join_all;
use hyper::header::{HeaderName, HeaderValue};
use proto::consumer_api_server::ConsumerApiServer;
use proto::index_api_server::IndexApiServer;
use proto::reflection_api_server::ReflectionApiServer;
use proto::search_api_server::SearchApiServer;
use summa_core::configs::ConfigProxy;
use summa_core::utils::random::generate_request_id;
use summa_proto::proto;
use tokio_stream::wrappers::TcpListenerStream;
use tonic::codec::CompressionEncoding;
use tonic::transport::Server;
use tonic_web::GrpcWebLayer;
use tower::ServiceBuilder;
use tower_http::classify::GrpcFailureClass;
use tower_http::set_header::SetRequestHeaderLayer;
use tower_http::trace::TraceLayer;
use tracing::{info, info_span, instrument, warn, Instrument, Span};

use crate::apis::consumer::ConsumerApiImpl;
use crate::apis::index::IndexApiImpl;
use crate::apis::reflection::ReflectionApiImpl;
use crate::apis::search::SearchApiImpl;
use crate::errors::SummaServerResult;
use crate::services::Index;
use crate::utils::thread_handler::ControlMessage;

/// GRPC server exposing [API](crate::apis)
///
/// Holds the shared handles needed to build the gRPC services. Constructing an `Api` only
/// clones those handles; listeners are bound later, in [`Api::prepare_serving_future`].
pub struct Api {
    /// Shared handle to the server configuration; its `api` section is read when the
    /// serving future is prepared.
    server_config_holder: Arc<dyn ConfigProxy<crate::configs::server::Config>>,
    /// Index service handle that is handed to every API implementation.
    index_service: Index,
}

impl Api {
    /// Trace callback for an incoming request: emits an `info` event carrying the URI path.
    ///
    /// The span argument is required by the callback signature but is not used here.
    fn on_request(request: &hyper::Request<hyper::Body>, _span: &Span) {
        let path = request.uri().path();
        info!(path = ?path);
    }
    /// Trace callback for a produced response: emits an `info` event with the elapsed
    /// `duration` and the numeric HTTP status code.
    ///
    /// The span argument is required by the callback signature but is not used here.
    fn on_response<T>(response: &hyper::Response<T>, duration: Duration, _span: &Span) {
        let status = response.status().as_u16();
        info!(duration = ?duration, status = ?status);
    }
    fn on_failure(error: GrpcFailureClass, duration: Duration, _span: &Span) {
        warn!(error = ?error, duration = ?duration);
    }

    /// Creates the per-request `request` span, tagged with the `request-id` and
    /// `session-id` header values.
    ///
    /// # Panics
    /// Panics if either header is absent while the span is being recorded. The middleware
    /// stack built in `prepare_serving_future` inserts both headers when they are missing.
    fn set_span(request: &hyper::Request<hyper::Body>) -> Span {
        let headers = request.headers();
        // The lookups stay inside the macro so they are evaluated exactly where the
        // span's field values are built.
        info_span!(
            "request",
            request_id = ?headers.get("request-id").expect("request-id must be set"),
            session_id = ?headers.get("session-id").expect("session-id must be set"),
        )
    }

    /// Binds `endpoint` synchronously, switches the socket to non-blocking mode and hands
    /// it over to tokio.
    ///
    /// # Errors
    /// Any I/O error from binding, from changing the blocking mode, or from the tokio
    /// conversion is converted into the crate error type.
    ///
    /// # Panics
    /// Per tokio's documentation, `TcpListener::from_std` panics when called outside a
    /// runtime with I/O enabled.
    fn set_listener(endpoint: &str) -> SummaServerResult<tokio::net::TcpListener> {
        let listener = std::net::TcpListener::bind(endpoint)
            .and_then(|socket| socket.set_nonblocking(true).map(|()| socket))
            .and_then(tokio::net::TcpListener::from_std)?;
        Ok(listener)
    }

    /// New GRPC server
    pub fn new(server_config_holder: &Arc<dyn ConfigProxy<crate::configs::server::Config>>, index_service: &Index) -> SummaServerResult<Api> {
        Ok(Api {
            server_config_holder: server_config_holder.clone(),
            index_service: index_service.clone(),
        })
    }

    /// Builds every gRPC service and returns a future that serves them until shutdown.
    ///
    /// The work happens in two phases:
    /// 1. Eagerly, while this method is awaited: the API implementations are constructed, the
    ///    middleware stack is assembled from the current `api` config section, and the
    ///    listening sockets are bound. Failures in this phase are returned from this method.
    /// 2. Lazily, when the returned future is awaited: the gRPC server (and the HTTP server,
    ///    if `http_endpoint` is configured) runs until `recv` on `terminator` returns.
    ///
    /// # Errors
    /// Returns an error if an API implementation cannot be created or a listener cannot be
    /// bound. The returned future resolves to an error if a server fails while serving.
    ///
    /// # Panics
    /// Panics if the gRPC reflection service cannot be built from the bundled descriptor set.
    #[instrument("lifecycle", skip_all)]
    pub async fn prepare_serving_future(&self, terminator: Receiver<ControlMessage>) -> SummaServerResult<impl Future<Output = SummaServerResult<()>>> {
        // One boxed future per server (gRPC, plus HTTP when configured); joined at the end.
        let mut futures: Vec<Box<dyn Future<Output = SummaServerResult<()>> + Send>> = vec![];
        let index_service = self.index_service.clone();
        let consumer_api = ConsumerApiImpl::new(&index_service)?;
        let index_api = IndexApiImpl::new(&self.server_config_holder, &index_service)?;
        let reflection_api = ReflectionApiImpl::new(&index_service)?;
        let search_api = SearchApiImpl::new(&index_service)?;
        // Standard gRPC reflection endpoint built from the compiled proto descriptors.
        // NOTE(review): `include_reflection_service(false)` presumably keeps the reflection
        // service itself out of the advertised service list — confirm against tonic-reflection.
        let grpc_reflection_service = tonic_reflection::server::Builder::configure()
            .include_reflection_service(false)
            .register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET)
            .build()
            .expect("cannot build grpc server");
        // Snapshot of the `api` config section; everything below uses this copy.
        let api_config = self.server_config_holder.read().await.get().api.clone();

        // Middleware stack for the gRPC router. Order matters: the two header layers are added
        // before the trace layer so that `request-id` and `session-id` exist (generated when the
        // client did not send them) by the time `Api::set_span` reads them. Concurrency limit
        // and buffer size come from the config snapshot.
        let layer = ServiceBuilder::new()
            .layer(SetRequestHeaderLayer::if_not_present(HeaderName::from_static("request-id"), |_: &_| {
                Some(HeaderValue::from_str(&generate_request_id()).expect("invalid generated request id"))
            }))
            .layer(SetRequestHeaderLayer::if_not_present(HeaderName::from_static("session-id"), |_: &_| {
                Some(HeaderValue::from_str(&generate_request_id()).expect("invalid generated session id"))
            }))
            .concurrency_limit(api_config.concurrency_limit)
            .buffer(api_config.buffer)
            .layer(
                TraceLayer::new_for_grpc()
                    .make_span_with(Api::set_span)
                    .on_request(Api::on_request)
                    .on_response(Api::on_response)
                    .on_failure(Api::on_failure),
            )
            .into_inner();

        let consumer_service = ConsumerApiServer::new(consumer_api);
        // Shadows the `Index` handle above: from here on `index_service` is the tonic server
        // wrapper. Only the index and search services get gzip and message-size limits.
        let mut index_service = IndexApiServer::new(index_api)
            .accept_compressed(CompressionEncoding::Gzip)
            .send_compressed(CompressionEncoding::Gzip);
        // NOTE(review): the config field is named `max_frame_size_bytes` but is applied here as
        // the maximum encoded/decoded *message* size — confirm this is intended.
        if let Some(max_from_size_bytes) = api_config.max_frame_size_bytes {
            index_service = index_service
                .max_decoding_message_size(max_from_size_bytes as usize)
                .max_encoding_message_size(max_from_size_bytes as usize);
        }
        let reflection_service = ReflectionApiServer::new(reflection_api);
        let mut search_service = SearchApiServer::new(search_api)
            .accept_compressed(CompressionEncoding::Gzip)
            .send_compressed(CompressionEncoding::Gzip);
        if let Some(max_from_size_bytes) = api_config.max_frame_size_bytes {
            search_service = search_service
                .max_decoding_message_size(max_from_size_bytes as usize)
                .max_encoding_message_size(max_from_size_bytes as usize);
        }

        // Services are cloned into the gRPC router because the originals are moved into the
        // HTTP router below when `http_endpoint` is set.
        // NOTE(review): the transport frame size is the configured value divided by 256; the
        // rationale for that factor is not visible in this file — confirm.
        let grpc_router = Server::builder()
            .layer(layer)
            .max_frame_size(api_config.max_frame_size_bytes.map(|x| x / 256))
            .add_service(grpc_reflection_service)
            .add_service(consumer_service.clone())
            .add_service(index_service.clone())
            .add_service(reflection_service.clone())
            .add_service(search_service.clone());
        // Bound here, outside the async block, so a bind failure is returned by this method
        // rather than by the serving future.
        let grpc_listener = Api::set_listener(&api_config.grpc_endpoint)?;
        let mut grpc_terminator = terminator.clone();
        futures.push(Box::new(async move {
            grpc_router
                // The inner future is the shutdown signal: it logs the endpoint when first
                // polled and completes once `recv` returns, whether with a message or an error.
                .serve_with_incoming_shutdown(TcpListenerStream::new(grpc_listener), async move {
                    info!(action = "binded", endpoint = ?api_config.grpc_endpoint);
                    let signal_result = grpc_terminator.recv().await;
                    info!(action = "sigterm_received", received = ?signal_result);
                })
                // `parent: None` makes this a root span instead of a child of the caller's span.
                .instrument(info_span!(parent: None, "lifecycle", mode = "grpc"))
                .await?;
            info!(action = "terminated");
            Ok(())
        }));

        // Optional second server for HTTP/1 clients via gRPC-Web.
        // NOTE(review): unlike the gRPC router, this router does not receive `layer` (no
        // request-id/session-id headers, tracing, concurrency limit or buffer), does not set
        // `max_frame_size`, and does not register the gRPC reflection service — confirm this
        // is intended.
        if let Some(http_endpoint) = api_config.http_endpoint {
            let http_router = Server::builder()
                .accept_http1(true)
                .layer(GrpcWebLayer::new())
                .add_service(consumer_service)
                .add_service(index_service)
                .add_service(reflection_service)
                .add_service(search_service);
            let http_listener = Api::set_listener(&http_endpoint)?;
            let mut http_terminator = terminator.clone();
            futures.push(Box::new(async move {
                http_router
                    // Same shutdown-signal pattern as the gRPC server above.
                    .serve_with_incoming_shutdown(TcpListenerStream::new(http_listener), async move {
                        info!(action = "binded", endpoint = ?http_endpoint);
                        let signal_result = http_terminator.recv().await;
                        info!(action = "sigterm_received", received = ?signal_result);
                    })
                    .instrument(info_span!(parent: None, "lifecycle", mode = "http"))
                    .await?;
                info!(action = "terminated");
                Ok(())
            }));
        }
        // Drive all servers together; `Box::into_pin` turns each boxed future into a pinned
        // one so it can be polled. The first server error short-circuits the join.
        Ok(async move {
            try_join_all(futures.into_iter().map(Box::into_pin)).await?;
            Ok(())
        })
    }
}