tame_oauth/gcp/
service_account.rs

1use std::convert::TryInto;
2
3use super::{
4    jwt::{self, Algorithm, Header, Key},
5    TokenResponse,
6};
7use crate::{
8    error::{self, Error},
9    id_token::{
10        AccessTokenRequest, AccessTokenResponse, IdTokenOrRequest, IdTokenProvider, IdTokenRequest,
11        IdTokenResponse,
12    },
13    token::{RequestReason, Token, TokenOrRequest, TokenProvider},
14    token_cache::CachedTokenProvider,
15    IdToken,
16};
17
/// OAuth 2.0 `grant_type` form value identifying a JWT bearer assertion grant.
const GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:jwt-bearer";
19
20/// Minimal parts needed from a GCP service account key for token acquisition
21#[derive(serde::Deserialize, Debug, Clone)]
22pub struct ServiceAccountInfo {
23    /// The private key we use to sign
24    pub private_key: String,
25    /// The unique id used as the issuer of the JWT claim
26    pub client_email: String,
27    /// The URI we send the token requests to, eg <https://oauth2.googleapis.com/token>
28    pub token_uri: String,
29}
30
31#[derive(serde::Deserialize, Debug)]
32struct IdTokenResponseBody {
33    /// The actual token
34    token: String,
35}
36
37impl ServiceAccountInfo {
38    /// Deserializes service account from a byte slice. This data is typically
39    /// acquired by reading a service account JSON file from disk
40    pub fn deserialize<T>(key_data: T) -> Result<Self, Error>
41    where
42        T: AsRef<[u8]>,
43    {
44        let slice = key_data.as_ref();
45
46        let account_info: Self = serde_json::from_slice(slice)?;
47        Ok(account_info)
48    }
49}
50
51/// A token provider for a GCP service account.
52/// Caches tokens internally.
53pub type ServiceAccountProvider = CachedTokenProvider<ServiceAccountProviderInner>;
54impl ServiceAccountProvider {
55    pub fn new(info: ServiceAccountInfo) -> Result<Self, Error> {
56        Ok(CachedTokenProvider::wrap(ServiceAccountProviderInner::new(
57            info,
58        )?))
59    }
60
61    /// Gets the [`ServiceAccountInfo`] this was created for
62    pub fn get_account_info(&self) -> &ServiceAccountInfo {
63        &self.inner().info
64    }
65}
66
67/// A token provider for a GCP service account. Should not be used directly as it is not cached. Use `ServiceAccountProvider` instead.
68pub struct ServiceAccountProviderInner {
69    info: ServiceAccountInfo,
70    priv_key: Vec<u8>,
71}
72
73impl std::fmt::Debug for ServiceAccountProviderInner {
74    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75        f.debug_struct("ServiceAccountProviderInner")
76            .finish_non_exhaustive()
77    }
78}
79
80impl ServiceAccountProviderInner {
81    /// Creates a new `ServiceAccountAccess` given the provided service
82    /// account info. This can fail if the private key is encoded incorrectly.
83    pub fn new(info: ServiceAccountInfo) -> Result<Self, Error> {
84        let key_string = info
85            .private_key
86            .split("-----")
87            .nth(2)
88            .ok_or(Error::InvalidKeyFormat)?;
89
90        // Strip out all of the newlines
91        let key_string = key_string.split_whitespace().fold(
92            String::with_capacity(key_string.len()),
93            |mut s, line| {
94                s.push_str(line);
95                s
96            },
97        );
98
99        let key_bytes = data_encoding::BASE64.decode(key_string.as_bytes())?;
100
101        Ok(Self {
102            info,
103            priv_key: key_bytes,
104        })
105    }
106
107    /// Gets the [`ServiceAccountInfo`] this was created for
108    pub fn get_account_info(&self) -> &ServiceAccountInfo {
109        &self.info
110    }
111
112    fn prepare_access_token_request<'a, S, I, T>(
113        &self,
114        subject: Option<T>,
115        scopes: I,
116    ) -> Result<AccessTokenRequest, Error>
117    where
118        S: AsRef<str> + 'a,
119        I: IntoIterator<Item = &'a S>,
120        T: Into<String>,
121    {
122        let scopes = scopes
123            .into_iter()
124            .map(|s| s.as_ref())
125            .collect::<Vec<_>>()
126            .join(" ");
127
128        let issued_at = std::time::SystemTime::now()
129            .duration_since(std::time::SystemTime::UNIX_EPOCH)?
130            .as_secs() as i64;
131
132        let claims = jwt::Claims {
133            issuer: self.info.client_email.clone(),
134            scope: scopes,
135            audience: self.info.token_uri.clone(),
136            expiration: issued_at + 3600 - 5, // Give us some wiggle room near the hour mark
137            issued_at,
138            subject: subject.map(|s| s.into()),
139        };
140
141        let assertion = jwt::encode(
142            &Header::new(Algorithm::RS256),
143            &claims,
144            Key::Pkcs8(&self.priv_key),
145        )?;
146
147        let body = url::form_urlencoded::Serializer::new(String::new())
148            .append_pair("grant_type", GRANT_TYPE)
149            .append_pair("assertion", &assertion)
150            .finish();
151
152        let body = Vec::from(body);
153
154        let request = http::Request::builder()
155            .method("POST")
156            .uri(&self.info.token_uri)
157            .header(
158                http::header::CONTENT_TYPE,
159                "application/x-www-form-urlencoded",
160            )
161            .header(http::header::CONTENT_LENGTH, body.len())
162            .body(body)?;
163
164        Ok(request)
165    }
166}
167
168impl TokenProvider for ServiceAccountProviderInner {
169    /// Like [`ServiceAccountProviderInner::get_token`], but allows the JWT "subject"
170    /// to be passed in.
171    fn get_token_with_subject<'a, S, I, T>(
172        &self,
173        subject: Option<T>,
174        scopes: I,
175    ) -> Result<TokenOrRequest, Error>
176    where
177        S: AsRef<str> + 'a,
178        I: IntoIterator<Item = &'a S>,
179        T: Into<String>,
180    {
181        let request = self.prepare_access_token_request(subject, scopes)?;
182        Ok(TokenOrRequest::Request {
183            reason: RequestReason::ParametersChanged,
184            request,
185            scope_hash: 0,
186        })
187    }
188
189    /// Handle responses from the token URI request we generated in
190    /// `get_token`. This method deserializes the response and stores
191    /// the token in a local cache, so that future lookups for the
192    /// same scopes don't require new http requests.
193    fn parse_token_response<S>(
194        &self,
195        _hash: u64,
196        response: http::Response<S>,
197    ) -> Result<Token, Error>
198    where
199        S: AsRef<[u8]>,
200    {
201        let (parts, body) = response.into_parts();
202
203        if !parts.status.is_success() {
204            let body_bytes = body.as_ref();
205
206            if parts
207                .headers
208                .get(http::header::CONTENT_TYPE)
209                .and_then(|ct| ct.to_str().ok())
210                == Some("application/json; charset=utf-8")
211            {
212                if let Ok(auth_error) = serde_json::from_slice::<error::AuthError>(body_bytes) {
213                    return Err(Error::Auth(auth_error));
214                }
215            }
216
217            return Err(Error::HttpStatus(parts.status));
218        }
219
220        let token_res: TokenResponse = serde_json::from_slice(body.as_ref())?;
221        let token: Token = token_res.into();
222
223        Ok(token)
224    }
225}
226
227impl IdTokenProvider for ServiceAccountProviderInner {
228    fn get_id_token(&self, _audience: &str) -> Result<IdTokenOrRequest, Error> {
229        let request = self
230            .prepare_access_token_request(None::<&str>, &["https://www.googleapis.com/auth/iam"])?;
231
232        Ok(IdTokenOrRequest::AccessTokenRequest {
233            request,
234            reason: RequestReason::ParametersChanged,
235            audience_hash: 0,
236        })
237    }
238
239    fn get_id_token_with_access_token<S>(
240        &self,
241        audience: &str,
242        response: AccessTokenResponse<S>,
243    ) -> Result<IdTokenRequest, Error>
244    where
245        S: AsRef<[u8]>,
246    {
247        let token = self.parse_token_response(0, response)?;
248
249        let sa_email = self.info.client_email.clone();
250        // See https://cloud.google.com/iam/docs/creating-short-lived-service-account-credentials#sa-credentials-oidc
251        // for details on what it is we're doing
252        let json_body = serde_json::to_vec(&serde_json::json!({
253            "audience": audience,
254            "includeEmail": true,
255        }))?;
256
257        let token_header_value: http::HeaderValue = token.try_into()?;
258
259        let request = http::Request::builder()
260            .method("POST")
261            .uri(format!("https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/{}:generateIdToken", sa_email))
262            .header(
263                http::header::CONTENT_TYPE,
264                "application/json; charset=utf-8",
265            )
266            .header(http::header::CONTENT_LENGTH, json_body.len())
267            .header(http::header::AUTHORIZATION, token_header_value)
268            .body(json_body)?;
269
270        Ok(request)
271    }
272
273    fn parse_id_token_response<S>(
274        &self,
275        _hash: u64,
276        response: IdTokenResponse<S>,
277    ) -> Result<IdToken, Error>
278    where
279        S: AsRef<[u8]>,
280    {
281        let (parts, body) = response.into_parts();
282
283        if !parts.status.is_success() {
284            let body_bytes = body.as_ref();
285
286            if parts
287                .headers
288                .get(http::header::CONTENT_TYPE)
289                .and_then(|ct| ct.to_str().ok())
290                == Some("application/json; charset=utf-8")
291            {
292                if let Ok(auth_error) = serde_json::from_slice::<error::AuthError>(body_bytes) {
293                    return Err(Error::Auth(auth_error));
294                }
295            }
296
297            return Err(Error::HttpStatus(parts.status));
298        }
299
300        let token_res: IdTokenResponseBody = serde_json::from_slice(body.as_ref())?;
301        let token = IdToken::new(token_res.token)?;
302
303        Ok(token)
304    }
305}