Skip to main content

scion_sdk_reqwest_connect_rpc/
client.rs

1// Copyright 2025 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! Connect-RPC client library using reqwest.
15
16use std::{borrow::Cow, sync::Arc, time::Duration};
17
18use anyhow::Context as _;
19use bytes::Bytes;
20use reqwest::header::{self, HeaderMap, HeaderValue};
21use thiserror::Error;
22use tracing::Instrument;
23
24use crate::{
25    error::CrpcError,
26    token_source::{TokenSource, TokenSourceError},
27};
28
29/// Connect RPC client error.
30#[derive(Debug, Error)]
31pub enum CrpcClientError {
32    /// Error that occurs when there is a connection issue.
33    #[error("connection error {context}: {source:#?}")]
34    ConnectionError {
35        /// Additional context about the connection error.
36        context: Cow<'static, str>,
37        /// The underlying source error.
38        source: Box<dyn std::error::Error + Send + Sync + 'static>,
39    },
40    /// Error returned by the server.
41    #[error("server returned an error: {0:#?}")]
42    CrpcError(CrpcError),
43    /// Error decoding the response body.
44    #[error("failed to decode response body: {context}: {source:#?}")]
45    DecodeError {
46        /// Additional context about the decoding error.
47        context: Cow<'static, str>,
48        /// The underlying source error.
49        source: Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
50        /// The response body, if available.
51        body: Option<Bytes>,
52    },
53    /// Error retrieving a token from the token source.
54    #[error("failed to retrieve token: {0}")]
55    TokenSourceError(#[from] TokenSourceError),
56}
57
58const APPLICATION_PROTO: &str = "application/proto";
59
60/// A Connect-RPC client.
61pub struct CrpcClient {
62    http_client: reqwest::Client,
63    base_url: url::Url,
64    token_source: Option<Arc<dyn TokenSource>>,
65    user_agent: HeaderValue,
66}
67
68impl CrpcClient {
69    /// Creates a new [`CrpcClient`] for the given base URL.
70    pub fn new(base_url: &url::Url) -> anyhow::Result<Self> {
71        let http_client = reqwest::ClientBuilder::new()
72            .timeout(Duration::from_secs(30))
73            .build()
74            .context("error creating HTTP client")?;
75
76        Self::new_with_client(base_url, http_client)
77    }
78
79    /// Creates a new [`CrpcClient`] for the given base URL and explicit [`reqwest::Client`].
80    pub fn new_with_client(
81        base_url: &url::Url,
82        http_client: reqwest::Client,
83    ) -> anyhow::Result<Self> {
84        let user_agent =
85            HeaderValue::from_str(&format!("reqwest-crpc {}", env!("CARGO_PKG_VERSION")))
86                .context("error creating user agent header")?;
87
88        Ok(CrpcClient {
89            http_client,
90            base_url: base_url.clone(),
91            token_source: None,
92            user_agent,
93        })
94    }
95
96    /// Uses given token source for authentication of all following requests.
97    pub fn use_token_source(&mut self, token_source: Arc<dyn TokenSource>) -> &mut Self {
98        self.token_source = Some(token_source);
99        self
100    }
101
102    /// Sets the user agent header for all following requests.
103    pub fn use_user_agent(&mut self, user_agent: &str) -> anyhow::Result<&mut Self> {
104        self.user_agent = HeaderValue::from_str(user_agent)
105            .with_context(|| format!("error creating user agent header from {user_agent}"))?;
106        Ok(self)
107    }
108
109    /// Unary RPC request.
110    pub async fn unary_request<Req, Res>(
111        &self,
112        path: &str,
113        req: Req,
114    ) -> Result<Res, CrpcClientError>
115    where
116        Req: prost::Message + Default,
117        Res: prost::Message + Default,
118    {
119        self.do_unary_request(path, req)
120            .instrument(tracing::info_span!("request", %path, id = rand::random::<u16>()))
121            .await
122    }
123
124    /// Sends a unary request to the endhost API.
125    async fn do_unary_request<Req, Res>(&self, path: &str, req: Req) -> Result<Res, CrpcClientError>
126    where
127        Req: prost::Message + Default,
128        Res: prost::Message + Default,
129    {
130        let url = self.base_url.join(path).map_err(|e| {
131            CrpcClientError::ConnectionError {
132                context: "error joining base URL and path".into(),
133                source: e.into(),
134            }
135        })?;
136
137        let mut headers = HeaderMap::with_capacity(3);
138        headers.insert(
139            header::CONTENT_TYPE,
140            header::HeaderValue::from_static(APPLICATION_PROTO),
141        );
142        headers.insert(header::USER_AGENT, self.user_agent.clone());
143
144        tracing::trace!(?url, ?headers, "Sending crpc unary request");
145
146        if let Some(token_source) = &self.token_source {
147            let token = token_source.get_token().await?;
148            let token_header = header::HeaderValue::from_str(&token_source.format_header(token))
149                .map_err(|e| {
150                    CrpcClientError::TokenSourceError(
151                        format!("error formatting token as header value: {e:?}").into(),
152                    )
153                })?;
154
155            headers.insert(header::AUTHORIZATION, token_header);
156        }
157
158        let body = req.encode_to_vec();
159        let response = self
160            .http_client
161            .post(url)
162            .body(reqwest::Body::from(body))
163            .headers(headers)
164            .send()
165            .await
166            .map_err(|e| {
167                CrpcClientError::ConnectionError {
168                    context: "error sending request".into(),
169                    source: e.into(),
170                }
171            })?;
172
173        tracing::trace!(status=%response.status(), body_len=%response.content_length().unwrap_or(0), "Received crpc unary response");
174
175        let status = response.status();
176        if !status.is_success() {
177            let response_raw = response
178                .text()
179                .await
180                .unwrap_or_else(|_| "<failed to read body>".to_string());
181
182            // Try to parse the body as a CrpcError, otherwise create a generic one.
183            match serde_json::from_str::<CrpcError>(&response_raw) {
184                Ok(crpc_err) => {
185                    return Err(CrpcClientError::CrpcError(crpc_err));
186                }
187                Err(_) => {
188                    return Err(CrpcClientError::CrpcError(CrpcError::new(
189                        status.into(),
190                        response_raw,
191                    )));
192                }
193            }
194        }
195
196        let body = response.bytes().await.map_err(|e| {
197            CrpcClientError::DecodeError {
198                context: "error reading response body".into(),
199                source: Some(e.into()),
200                body: None,
201            }
202        })?;
203
204        Res::decode(&body[..]).map_err(|e| {
205            CrpcClientError::DecodeError {
206                context: "error decoding response body".into(),
207                source: Some(e.into()),
208                body: Some(body.clone()),
209            }
210        })
211    }
212}