1use crate::client::SendRequest;
15use crate::error::Error;
16use crate::types::TokenInfo;
17
18use std::{io, path::PathBuf};
19
20use base64::Engine as _;
21
22use http::header;
23use http_body_util::BodyExt;
24#[cfg(all(feature = "aws-lc-rs", not(feature = "ring")))]
25use rustls::crypto::aws_lc_rs as crypto_provider;
26#[cfg(feature = "ring")]
27use rustls::crypto::ring as crypto_provider;
28use rustls::{self, pki_types::PrivateKeyDer, sign::SigningKey};
29use serde::{Deserialize, Serialize};
30use time::OffsetDateTime;
31use url::form_urlencoded;
32
/// OAuth 2.0 grant type sent as the `grant_type` form field when exchanging a signed
/// JWT assertion for an access token.
const GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:jwt-bearer";
/// Fixed JOSE header for every assertion this module signs. It is base64url-encoded as
/// the first JWT segment. `RS256` matches the `RSA_PKCS1_SHA256` scheme chosen by the signer.
const GOOGLE_RS256_HEAD: &str = r#"{"alg":"RS256","typ":"JWT"}"#;
35
36fn append_base64<T: AsRef<[u8]> + ?Sized>(s: &T, out: &mut String) {
38 base64::engine::general_purpose::URL_SAFE.encode_string(s, out)
39}
40
41fn decode_rsa_key(pem_pkcs8: &str) -> Result<PrivateKeyDer, io::Error> {
43 let private_key = rustls_pemfile::pkcs8_private_keys(&mut pem_pkcs8.as_bytes()).next();
44
45 match private_key {
46 Some(Ok(key)) => Ok(PrivateKeyDer::Pkcs8(key)),
47 None => Err(io::Error::new(
48 io::ErrorKind::InvalidInput,
49 "Not enough private keys in PEM",
50 )),
51 Some(Err(_)) => Err(io::Error::new(
52 io::ErrorKind::InvalidInput,
53 "Error reading key from PEM",
54 )),
55 }
56}
57
/// A service account key as stored in a JSON key file.
///
/// Only `private_key`, `client_email` and `token_uri` are used by the flow in this
/// module. The remaining optional fields are carried through (de)serialization unchanged.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ServiceAccountKey {
    /// Key type. Serialized under the JSON name `type`, because `type` is a Rust keyword.
    #[serde(rename = "type")]
    pub key_type: Option<String>,
    /// Project the account belongs to (not read by this module).
    pub project_id: Option<String>,
    /// Identifier of the private key (not read by this module).
    pub private_key_id: Option<String>,
    /// PEM-encoded PKCS#8 RSA private key used to sign JWT assertions.
    pub private_key: String,
    /// Service account email. It is used as the JWT `iss` claim.
    pub client_email: String,
    /// Client id (not read by this module).
    pub client_id: Option<String>,
    /// Authorization endpoint (not read by this module).
    pub auth_uri: Option<String>,
    /// Token endpoint. It is used both as the JWT `aud` claim and as the URL the
    /// assertion is POSTed to.
    pub token_uri: String,
    /// Provider certificate URL (not read by this module).
    pub auth_provider_x509_cert_url: Option<String>,
    /// Client certificate URL (not read by this module).
    pub client_x509_cert_url: Option<String>,
}
88
89#[derive(Serialize, Debug)]
92struct Claims<'a> {
93 iss: &'a str,
94 aud: &'a str,
95 exp: i64,
96 iat: i64,
97 #[serde(rename = "sub")]
98 subject: Option<&'a str>,
99 scope: String,
100}
101
102impl<'a> Claims<'a> {
103 fn new<T>(key: &'a ServiceAccountKey, scopes: &[T], subject: Option<&'a str>) -> Self
104 where
105 T: AsRef<str>,
106 {
107 let iat = OffsetDateTime::now_utc().unix_timestamp();
108 let expiry = iat + 3600 - 5; let scope = crate::helper::join(scopes, " ");
111 Claims {
112 iss: &key.client_email,
113 aud: &key.token_uri,
114 exp: expiry,
115 iat,
116 subject,
117 scope,
118 }
119 }
120}
121
/// Signs JWT assertions with a service account's RSA private key using RS256
/// (`RSA_PKCS1_SHA256`).
pub(crate) struct JWTSigner {
    /// Signer returned by `choose_scheme`. It is built once in `JWTSigner::new` and
    /// reused for every assertion.
    signer: Box<dyn rustls::sign::Signer>,
}
126
127impl JWTSigner {
128 fn new(private_key: &str) -> Result<Self, io::Error> {
129 let key = decode_rsa_key(private_key)?;
130 let signing_key = crypto_provider::sign::RsaSigningKey::new(&key)
131 .map_err(|_| io::Error::new(io::ErrorKind::Other, "Couldn't initialize signer"))?;
132 let signer = signing_key
133 .choose_scheme(&[rustls::SignatureScheme::RSA_PKCS1_SHA256])
134 .ok_or_else(|| {
135 io::Error::new(io::ErrorKind::Other, "Couldn't choose signing scheme")
136 })?;
137 Ok(JWTSigner { signer })
138 }
139
140 fn sign_claims(&self, claims: &Claims) -> Result<String, rustls::Error> {
141 let mut jwt_head = Self::encode_claims(claims);
142 let signature = self.signer.sign(jwt_head.as_bytes())?;
143 jwt_head.push('.');
144 append_base64(&signature, &mut jwt_head);
145 Ok(jwt_head)
146 }
147
148 fn encode_claims(claims: &Claims) -> String {
151 let mut head = String::new();
152 append_base64(GOOGLE_RS256_HEAD, &mut head);
153 head.push('.');
154 append_base64(&serde_json::to_string(&claims).unwrap(), &mut head);
155 head
156 }
157}
158
/// Options for constructing a service account flow.
pub struct ServiceAccountFlowOpts {
    /// Where the service account key comes from: a file path or an already-parsed key.
    pub(crate) key: FlowOptsKey,
    /// Optional value for the JWT `sub` claim. Presumably the user to impersonate
    /// (domain-wide delegation); confirm against the public builder API.
    pub(crate) subject: Option<String>,
}
163
/// Source of the service account key for [`ServiceAccountFlowOpts`].
pub(crate) enum FlowOptsKey {
    /// Path to a key file. It is read with `crate::read_service_account_key` when the
    /// flow is created.
    Path(PathBuf),
    /// A key that has already been parsed. It is unboxed and used as-is.
    Key(Box<ServiceAccountKey>),
}
171
/// Obtains access tokens by signing JWT assertions with a service account key.
pub struct ServiceAccountFlow {
    /// The loaded service account key (source of `iss`, `aud` and the token endpoint).
    key: ServiceAccountKey,
    /// Optional `sub` claim included in every assertion.
    subject: Option<String>,
    /// RS256 signer built once from `key.private_key`.
    signer: JWTSigner,
}
178
179impl ServiceAccountFlow {
180 pub(crate) async fn new(opts: ServiceAccountFlowOpts) -> Result<Self, io::Error> {
181 let key = match opts.key {
182 FlowOptsKey::Path(path) => crate::read_service_account_key(path).await?,
183 FlowOptsKey::Key(key) => *key,
184 };
185
186 let signer = JWTSigner::new(&key.private_key)?;
187 Ok(ServiceAccountFlow {
188 key,
189 subject: opts.subject,
190 signer,
191 })
192 }
193
194 pub(crate) async fn token<T>(
196 &self,
197 hyper_client: &impl SendRequest,
198 scopes: &[T],
199 ) -> Result<TokenInfo, Error>
200 where
201 T: AsRef<str>,
202 {
203 let claims = Claims::new(&self.key, scopes, self.subject.as_deref());
204 let signed = self.signer.sign_claims(&claims).map_err(|_| {
205 Error::LowLevelError(io::Error::new(
206 io::ErrorKind::Other,
207 "unable to sign claims",
208 ))
209 })?;
210 let rqbody = form_urlencoded::Serializer::new(String::new())
211 .extend_pairs(&[("grant_type", GRANT_TYPE), ("assertion", signed.as_str())])
212 .finish();
213 let request = http::Request::post(&self.key.token_uri)
214 .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
215 .body(rqbody)
216 .unwrap();
217 log::debug!("requesting token from service account: {:?}", request);
218 let (head, body) = hyper_client.request(request).await?.into_parts();
219 let body = body.collect().await?.to_bytes();
220 log::debug!("received response; head: {:?}, body: {:?}", head, body);
221 TokenInfo::from_json(&body)
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use crate::helper::read_service_account_key;

    /// Example key file used by the offline tests below.
    const TEST_PRIVATE_KEY_PATH: &str = "examples/Sanguine-69411a0c0eea.json";

    /// End-to-end token request against the real token endpoint.
    ///
    /// Ignored by default because it needs network access and a key the endpoint still
    /// accepts. Run it explicitly with `cargo test -- --ignored`. Previously this was a
    /// dead-code function that could never be run.
    #[cfg(feature = "hyper-rustls")]
    #[tokio::test]
    #[ignore = "requires network access and valid service account credentials"]
    async fn test_service_account_e2e() {
        let acc = ServiceAccountFlow::new(ServiceAccountFlowOpts {
            key: FlowOptsKey::Path(TEST_PRIVATE_KEY_PATH.into()),
            subject: None,
        })
        .await
        .unwrap();
        let client = crate::client::HttpClient::new(
            hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
                .build(
                    hyper_rustls::HttpsConnectorBuilder::new()
                        .with_provider_and_native_roots(crypto_provider::default_provider())
                        .unwrap()
                        .https_only()
                        .enable_http1()
                        .enable_http2()
                        .build(),
                ),
            None,
        );
        println!(
            "{:?}",
            acc.token(&client, &["https://www.googleapis.com/auth/pubsub"])
                .await
        );
        println!(
            "{:?}",
            acc.token(
                &client,
                &["https://some.scope/likely-to-hand-out-id-tokens"]
            )
            .await
        );
    }

    /// Claims are populated from the key file, and the lifetime is one hour minus 5s.
    #[tokio::test]
    async fn test_jwt_initialize_claims() {
        let key = read_service_account_key(TEST_PRIVATE_KEY_PATH)
            .await
            .unwrap();
        let scopes = vec!["scope1", "scope2", "scope3"];
        let claims = Claims::new(&key, &scopes, None);

        assert_eq!(
            claims.iss,
            "oauth2-public-test@sanguine-rhythm-105020.iam.gserviceaccount.com".to_string()
        );
        assert_eq!(claims.scope, "scope1 scope2 scope3".to_string());
        assert_eq!(
            claims.aud,
            "https://accounts.google.com/o/oauth2/token".to_string()
        );
        assert!(claims.exp > 1000000000);
        assert!(claims.iat < claims.exp);
        assert_eq!(claims.exp - claims.iat, 3595);
    }

    /// Signing succeeds, and the first JWT segment is the encoded RS256 header.
    #[tokio::test]
    async fn test_jwt_sign() {
        let key = read_service_account_key(TEST_PRIVATE_KEY_PATH)
            .await
            .unwrap();
        let scopes = vec!["scope1", "scope2", "scope3"];
        let signer = JWTSigner::new(&key.private_key).unwrap();
        let claims = Claims::new(&key, &scopes, None);
        let signature = signer.sign_claims(&claims);

        assert!(signature.is_ok());

        let signature = signature.unwrap();
        assert_eq!(
            signature.split('.').next().unwrap(),
            "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
        );
    }
}