viceroy_lib/wiggle_abi/
req_impl.rs

//! `fastly_req` hostcall implementations.
2use std::net::IpAddr;
3
4use super::types::SendErrorDetail;
5use super::SecretStoreError;
6use crate::cache::CacheOverride;
7use crate::config::ClientCertInfo;
8use crate::secret_store::SecretLookup;
9
10use {
11    crate::{
12        config::Backend,
13        error::Error,
14        pushpin::{PushpinRedirectInfo, PushpinRedirectRequestInfo},
15        session::{AsyncItem, PeekableTask, Session, ViceroyRequestMetadata},
16        upstream,
17        wiggle_abi::{
18            fastly_http_downstream::FastlyHttpDownstream,
19            fastly_http_req::FastlyHttpReq,
20            headers::HttpHeaders,
21            types::{
22                BackendConfigOptions, BodyHandle, CacheOverrideTag, ClientCertVerifyResult,
23                ContentEncodings, DynamicBackendConfig, FramingHeadersMode, HttpVersion,
24                InspectInfo, InspectInfoMask, MultiValueCursor, MultiValueCursorResult,
25                PendingRequestHandle, RequestHandle, ResponseHandle,
26            },
27        },
28    },
29    fastly_shared::{INVALID_BODY_HANDLE, INVALID_REQUEST_HANDLE, INVALID_RESPONSE_HANDLE},
30    http::{HeaderValue, Method, Uri},
31    hyper::http::request::Request,
32    wiggle::{GuestMemory, GuestPtr},
33};
34
35#[wiggle::async_trait]
36impl FastlyHttpReq for Session {
37    fn body_downstream_get(
38        &mut self,
39        _memory: &mut GuestMemory<'_>,
40    ) -> Result<(RequestHandle, BodyHandle), Error> {
41        let req_handle = self.downstream_request();
42        let body_handle = self.downstream_request_body();
43        Ok((req_handle, body_handle))
44    }
45
46    fn cache_override_set(
47        &mut self,
48        _memory: &mut GuestMemory<'_>,
49        req_handle: RequestHandle,
50        tag: CacheOverrideTag,
51        ttl: u32,
52        stale_while_revalidate: u32,
53    ) -> Result<(), Error> {
54        let overrides = CacheOverride::from_abi(u32::from(tag), ttl, stale_while_revalidate, None)
55            .ok_or(Error::InvalidArgument)?;
56
57        self.request_parts_mut(req_handle)?
58            .extensions
59            .insert(overrides);
60
61        Ok(())
62    }
63
64    fn cache_override_v2_set(
65        &mut self,
66        memory: &mut GuestMemory<'_>,
67        req_handle: RequestHandle,
68        tag: CacheOverrideTag,
69        ttl: u32,
70        stale_while_revalidate: u32,
71        sk: GuestPtr<[u8]>,
72    ) -> Result<(), Error> {
73        let sk = if sk.len() > 0 {
74            let sk = memory.as_slice(sk)?.ok_or(Error::SharedMemory)?;
75            let sk = HeaderValue::from_bytes(&sk).map_err(|_| Error::InvalidArgument)?;
76            Some(sk)
77        } else {
78            None
79        };
80
81        let overrides = CacheOverride::from_abi(u32::from(tag), ttl, stale_while_revalidate, sk)
82            .ok_or(Error::InvalidArgument)?;
83
84        self.request_parts_mut(req_handle)?
85            .extensions
86            .insert(overrides);
87
88        Ok(())
89    }
90
91    fn downstream_server_ip_addr(
92        &mut self,
93        memory: &mut GuestMemory<'_>,
94        // Must be a 16-byte array:
95        addr_octets_ptr: GuestPtr<u8>,
96    ) -> Result<u32, Error> {
97        FastlyHttpDownstream::downstream_server_ip_addr(
98            self,
99            memory,
100            self.downstream_request(),
101            addr_octets_ptr,
102        )
103    }
104
105    fn downstream_client_ip_addr(
106        &mut self,
107        memory: &mut GuestMemory<'_>,
108        // Must be a 16-byte array:
109        addr_octets_ptr: GuestPtr<u8>,
110    ) -> Result<u32, Error> {
111        FastlyHttpDownstream::downstream_client_ip_addr(
112            self,
113            memory,
114            self.downstream_request(),
115            addr_octets_ptr,
116        )
117    }
118
119    fn downstream_client_h2_fingerprint(
120        &mut self,
121        memory: &mut GuestMemory<'_>,
122        h2fp_out: GuestPtr<u8>,
123        h2fp_max_len: u32,
124        nwritten_out: GuestPtr<u32>,
125    ) -> Result<(), Error> {
126        FastlyHttpDownstream::downstream_client_h2_fingerprint(
127            self,
128            memory,
129            self.downstream_request(),
130            h2fp_out,
131            h2fp_max_len,
132            nwritten_out,
133        )
134    }
135
136    fn downstream_client_request_id(
137        &mut self,
138        memory: &mut GuestMemory<'_>,
139        reqid_out: GuestPtr<u8>,
140        reqid_max_len: u32,
141        nwritten_out: GuestPtr<u32>,
142    ) -> Result<(), Error> {
143        FastlyHttpDownstream::downstream_client_request_id(
144            self,
145            memory,
146            self.downstream_request(),
147            reqid_out,
148            reqid_max_len,
149            nwritten_out,
150        )
151    }
152
153    fn downstream_client_oh_fingerprint(
154        &mut self,
155        memory: &mut GuestMemory<'_>,
156        ohfp_out: GuestPtr<u8>,
157        ohfp_max_len: u32,
158        nwritten_out: GuestPtr<u32>,
159    ) -> Result<(), Error> {
160        FastlyHttpDownstream::downstream_client_oh_fingerprint(
161            self,
162            memory,
163            self.downstream_request(),
164            ohfp_out,
165            ohfp_max_len,
166            nwritten_out,
167        )
168    }
169
170    fn downstream_client_ddos_detected(
171        &mut self,
172        memory: &mut GuestMemory<'_>,
173    ) -> Result<u32, Error> {
174        FastlyHttpDownstream::downstream_client_ddos_detected(
175            self,
176            memory,
177            self.downstream_request(),
178        )
179    }
180
181    fn downstream_tls_cipher_openssl_name(
182        &mut self,
183        memory: &mut GuestMemory<'_>,
184        cipher_out: GuestPtr<u8>,
185        cipher_max_len: u32,
186        nwritten_out: GuestPtr<u32>,
187    ) -> Result<(), Error> {
188        FastlyHttpDownstream::downstream_client_oh_fingerprint(
189            self,
190            memory,
191            self.downstream_request(),
192            cipher_out,
193            cipher_max_len,
194            nwritten_out,
195        )
196    }
197
198    #[allow(unused_variables)] // FIXME ACF 2022-05-03: Remove this directive once implemented.
199    fn upgrade_websocket(
200        &mut self,
201        memory: &mut GuestMemory<'_>,
202        backend_name: GuestPtr<str>,
203    ) -> Result<(), Error> {
204        Err(Error::NotAvailable("WebSocket upgrade"))
205    }
206
207    #[allow(unused_variables)] // FIXME ACF 2022-10-03: Remove this directive once implemented.
208    fn redirect_to_websocket_proxy(
209        &mut self,
210        memory: &mut GuestMemory<'_>,
211        backend_name: GuestPtr<str>,
212    ) -> Result<(), Error> {
213        Err(Error::NotAvailable("Redirect to WebSocket proxy"))
214    }
215
216    #[allow(unused_variables)] // FIXME ACF 2022-10-03: Remove this directive once implemented.
217    fn redirect_to_grip_proxy(
218        &mut self,
219        memory: &mut GuestMemory<'_>,
220        backend_name: GuestPtr<str>,
221    ) -> Result<(), Error> {
222        let backend_name = memory
223            .as_str(backend_name)?
224            .ok_or(Error::SharedMemory)?
225            .to_string();
226        let redirect_info = PushpinRedirectInfo {
227            backend_name,
228            request_info: None,
229        };
230
231        self.redirect_downstream_to_pushpin(redirect_info)?;
232        Ok(())
233    }
234
235    fn redirect_to_websocket_proxy_v2(
236        &mut self,
237        _memory: &mut GuestMemory<'_>,
238        _req_handle: RequestHandle,
239        _backend: GuestPtr<str>,
240    ) -> Result<(), Error> {
241        Err(Error::NotAvailable("Redirect to WebSocket proxy"))
242    }
243
244    fn redirect_to_grip_proxy_v2(
245        &mut self,
246        memory: &mut GuestMemory<'_>,
247        req_handle: RequestHandle,
248        backend_name: GuestPtr<str>,
249    ) -> Result<(), Error> {
250        let backend_name = memory
251            .as_str(backend_name)?
252            .ok_or(Error::SharedMemory)?
253            .to_string();
254        let req = self.request_parts(req_handle)?;
255        let redirect_info = PushpinRedirectInfo {
256            backend_name,
257            request_info: Some(PushpinRedirectRequestInfo::from_parts(req)),
258        };
259
260        self.redirect_downstream_to_pushpin(redirect_info)?;
261        Ok(())
262    }
263
264    fn downstream_tls_protocol(
265        &mut self,
266        memory: &mut GuestMemory<'_>,
267        protocol_out: GuestPtr<u8>,
268        protocol_max_len: u32,
269        nwritten_out: GuestPtr<u32>,
270    ) -> Result<(), Error> {
271        FastlyHttpDownstream::downstream_tls_protocol(
272            self,
273            memory,
274            self.downstream_request(),
275            protocol_out,
276            protocol_max_len,
277            nwritten_out,
278        )
279    }
280
281    fn downstream_tls_client_hello(
282        &mut self,
283        memory: &mut GuestMemory<'_>,
284        chello_out: GuestPtr<u8>,
285        chello_max_len: u32,
286        nwritten_out: GuestPtr<u32>,
287    ) -> Result<(), Error> {
288        FastlyHttpDownstream::downstream_tls_client_hello(
289            self,
290            memory,
291            self.downstream_request(),
292            chello_out,
293            chello_max_len,
294            nwritten_out,
295        )
296    }
297
298    fn downstream_tls_raw_client_certificate(
299        &mut self,
300        memory: &mut GuestMemory<'_>,
301        cert_out: GuestPtr<u8>,
302        cert_max_len: u32,
303        nwritten_out: GuestPtr<u32>,
304    ) -> Result<(), Error> {
305        FastlyHttpDownstream::downstream_tls_raw_client_certificate(
306            self,
307            memory,
308            self.downstream_request(),
309            cert_out,
310            cert_max_len,
311            nwritten_out,
312        )
313    }
314
315    fn downstream_tls_client_cert_verify_result(
316        &mut self,
317        memory: &mut GuestMemory<'_>,
318    ) -> Result<ClientCertVerifyResult, Error> {
319        FastlyHttpDownstream::downstream_tls_client_cert_verify_result(
320            self,
321            memory,
322            self.downstream_request(),
323        )
324    }
325
326    fn downstream_tls_ja3_md5(
327        &mut self,
328        memory: &mut GuestMemory<'_>,
329        ja3_md5_out: GuestPtr<u8>,
330    ) -> Result<u32, Error> {
331        FastlyHttpDownstream::downstream_tls_ja3_md5(
332            self,
333            memory,
334            self.downstream_request(),
335            ja3_md5_out,
336        )
337    }
338
339    fn downstream_tls_ja4(
340        &mut self,
341        memory: &mut GuestMemory<'_>,
342        ja4_out: GuestPtr<u8>,
343        ja4_max_len: u32,
344        nwritten_out: GuestPtr<u32>,
345    ) -> Result<(), Error> {
346        FastlyHttpDownstream::downstream_tls_ja4(
347            self,
348            memory,
349            self.downstream_request(),
350            ja4_out,
351            ja4_max_len,
352            nwritten_out,
353        )
354    }
355
356    fn downstream_compliance_region(
357        &mut self,
358        memory: &mut GuestMemory<'_>,
359        // Must be a 16-byte array:
360        region_out: GuestPtr<u8>,
361        region_max_len: u32,
362        nwritten_out: GuestPtr<u32>,
363    ) -> Result<(), Error> {
364        FastlyHttpDownstream::downstream_compliance_region(
365            self,
366            memory,
367            self.downstream_request(),
368            region_out,
369            region_max_len,
370            nwritten_out,
371        )
372    }
373
374    fn framing_headers_mode_set(
375        &mut self,
376        _memory: &mut GuestMemory<'_>,
377        req_handle: RequestHandle,
378        mode: FramingHeadersMode,
379    ) -> Result<(), Error> {
380        let extensions = &mut self.request_parts_mut(req_handle)?.extensions;
381
382        match extensions.get_mut::<ViceroyRequestMetadata>() {
383            None => {
384                extensions.insert(ViceroyRequestMetadata {
385                    framing_headers_mode: mode,
386                    ..Default::default()
387                });
388            }
389            Some(vrm) => {
390                vrm.framing_headers_mode = mode;
391            }
392        }
393
394        Ok(())
395    }
396
    /// Registers a new backend at runtime from guest-supplied configuration.
    ///
    /// The backend is registered under `name`. `upstream_dynamic` becomes the authority of
    /// the backend URI (path `/`). `backend_info_mask` selects which fields of the
    /// `DynamicBackendConfig` read from `backend_info` are honored.
    ///
    /// Errors:
    /// - `Error::InvalidArgument`:
    ///   - `name` / `upstream_dynamic` is not UTF-8;
    ///   - the `RESERVED` bit or any unknown mask bit is set;
    ///   - host override, CA cert, or cert hostname is zero-length or over its size limit;
    ///   - the SNI hostname exceeds 1024 bytes or differs from an explicit cert hostname.
    /// - `Error::SharedMemory` when a guest slice cannot be borrowed.
    /// - `Error::SecretStoreError` when the client key secret cannot be resolved.
    /// - `Error::BackendNameRegistryError(name)` when `add_backend` rejects the name.
    ///
    /// Errors from header-value, UTF-8, PEM, client-cert, or URI parsing propagate via `?`.
    fn register_dynamic_backend(
        &mut self,
        memory: &mut GuestMemory<'_>,
        name: GuestPtr<str>,
        upstream_dynamic: GuestPtr<str>,
        backend_info_mask: BackendConfigOptions,
        backend_info: GuestPtr<DynamicBackendConfig>,
    ) -> Result<(), Error> {
        let name = {
            let name_slice = memory.to_vec(name.as_bytes())?;
            String::from_utf8(name_slice).map_err(|_| Error::InvalidArgument)?
        };
        let origin_name = {
            let origin_name_slice = memory.to_vec(upstream_dynamic.as_bytes())?;
            String::from_utf8(origin_name_slice).map_err(|_| Error::InvalidArgument)?
        };
        let config = memory.read(backend_info)?;

        // If someone set our reserved bit, error. We might need it later, and we don't
        // want anyone using it early.
        if backend_info_mask.contains(BackendConfigOptions::RESERVED) {
            return Err(Error::InvalidArgument);
        }

        // If someone has set any bits we don't know about, also return an error, as
        // there's either bad data or an API compatibility problem.
        if backend_info_mask != BackendConfigOptions::from_bits_truncate(backend_info_mask.bits()) {
            return Err(Error::InvalidArgument);
        }

        // Optional Host header override: 1..=1024 bytes of UTF-8 that must also be a
        // valid header value.
        let override_host = if backend_info_mask.contains(BackendConfigOptions::HOST_OVERRIDE) {
            if config.host_override_len == 0 {
                return Err(Error::InvalidArgument);
            }

            if config.host_override_len > 1024 {
                return Err(Error::InvalidArgument);
            }

            let byte_slice =
                memory.to_vec(config.host_override.as_array(config.host_override_len))?;

            let string = String::from_utf8(byte_slice).map_err(|_| Error::InvalidArgument)?;

            Some(HeaderValue::from_str(&string)?)
        } else {
            None
        };

        let scheme = if backend_info_mask.contains(BackendConfigOptions::USE_SSL) {
            "https"
        } else {
            "http"
        };

        // Custom CA certificates (PEM, 1..=64 KiB) are only read for https backends;
        // otherwise the CA_CERT option is silently ignored and the list is empty.
        let ca_certs =
            if (scheme == "https") && backend_info_mask.contains(BackendConfigOptions::CA_CERT) {
                if config.ca_cert_len == 0 {
                    return Err(Error::InvalidArgument);
                }

                if config.ca_cert_len > (64 * 1024) {
                    return Err(Error::InvalidArgument);
                }

                let byte_slice = memory
                    .as_slice(config.ca_cert.as_array(config.ca_cert_len))?
                    .ok_or(Error::SharedMemory)?;
                let mut byte_cursor = std::io::Cursor::new(&byte_slice[..]);
                rustls_pemfile::certs(&mut byte_cursor)?
                    .drain(..)
                    .map(rustls::Certificate)
                    .collect()
            } else {
                vec![]
            };

        // Explicit certificate hostname (1..=1024 bytes of UTF-8). It is `mut` because an
        // SNI hostname below may fill it in when it was not given explicitly.
        let mut cert_host = if backend_info_mask.contains(BackendConfigOptions::CERT_HOSTNAME) {
            if config.cert_hostname_len == 0 {
                return Err(Error::InvalidArgument);
            }

            if config.cert_hostname_len > 1024 {
                return Err(Error::InvalidArgument);
            }

            let byte_slice = memory
                .as_slice(config.cert_hostname.as_array(config.cert_hostname_len))?
                .ok_or(Error::SharedMemory)?;

            Some(std::str::from_utf8(&byte_slice)?.to_owned())
        } else {
            None
        };

        // SNI handling:
        // - SNI_HOSTNAME not set: SNI stays enabled.
        // - SNI_HOSTNAME set with a zero-length hostname: SNI is disabled.
        // - Otherwise the SNI hostname must equal `cert_host` if one was given, or it
        //   becomes `cert_host`.
        let use_sni = if backend_info_mask.contains(BackendConfigOptions::SNI_HOSTNAME) {
            if config.sni_hostname_len == 0 {
                false
            } else if config.sni_hostname_len > 1024 {
                return Err(Error::InvalidArgument);
            } else {
                let byte_slice = memory
                    .as_slice(config.sni_hostname.as_array(config.sni_hostname_len))?
                    .ok_or(Error::SharedMemory)?;
                let sni_hostname = std::str::from_utf8(&byte_slice)?;
                if let Some(cert_host) = &cert_host {
                    if cert_host != sni_hostname {
                        // because we're using rustls, we cannot support distinct SNI and cert hostnames
                        return Err(Error::InvalidArgument);
                    }
                } else {
                    cert_host = Some(sni_hostname.to_owned())
                }

                true
            }
        } else {
            true
        };

        // Client certificate plus the private key referenced by the `client_key` secret
        // handle. Standard lookups go through the secret stores; injected secrets carry
        // their plaintext directly.
        // NOTE(review): unlike the other fields, `client_certificate_len` is not checked
        // for zero or an upper bound here. Confirm whether a limit is intended.
        let client_cert = if backend_info_mask.contains(BackendConfigOptions::CLIENT_CERT) {
            let cert_slice = memory
                .as_slice(
                    config
                        .client_certificate
                        .as_array(config.client_certificate_len),
                )?
                .ok_or(Error::SharedMemory)?;
            let key_lookup =
                self.secret_lookup(config.client_key)
                    .ok_or(Error::SecretStoreError(
                        SecretStoreError::InvalidSecretHandle(config.client_key),
                    ))?;
            let key = match &key_lookup {
                SecretLookup::Standard {
                    store_name,
                    secret_name,
                } => self
                    .secret_stores()
                    .get_store(store_name)
                    .ok_or(Error::SecretStoreError(
                        SecretStoreError::InvalidSecretHandle(config.client_key),
                    ))?
                    .get_secret(secret_name)
                    .ok_or(Error::SecretStoreError(
                        SecretStoreError::InvalidSecretHandle(config.client_key),
                    ))?
                    .plaintext(),

                SecretLookup::Injected { plaintext } => plaintext,
            };

            Some(ClientCertInfo::new(&cert_slice, key)?)
        } else {
            None
        };

        let grpc = backend_info_mask.contains(BackendConfigOptions::GRPC);

        let new_backend = Backend {
            uri: Uri::builder()
                .scheme(scheme)
                .authority(origin_name)
                .path_and_query("/")
                .build()?,
            override_host,
            cert_host,
            use_sni,
            grpc,
            client_cert,
            ca_certs,
        };

        // `add_backend` returning false means the name could not be registered
        // (presumably already taken — confirm against `Session::add_backend`).
        if !self.add_backend(&name, new_backend) {
            return Err(Error::BackendNameRegistryError(name));
        }

        Ok(())
    }
576
577    fn new(&mut self, _memory: &mut GuestMemory<'_>) -> Result<RequestHandle, Error> {
578        let (parts, _) = Request::new(()).into_parts();
579        Ok(self.insert_request_parts(parts))
580    }
581
582    fn header_names_get(
583        &mut self,
584        memory: &mut GuestMemory<'_>,
585        req_handle: RequestHandle,
586        buf: GuestPtr<u8>,
587        buf_len: u32,
588        cursor: MultiValueCursor,
589        ending_cursor_out: GuestPtr<MultiValueCursorResult>,
590        nwritten_out: GuestPtr<u32>,
591    ) -> Result<(), Error> {
592        let headers = &self.request_parts(req_handle)?.headers;
593        multi_value_result!(
594            memory,
595            headers.names_get(memory, buf, buf_len, cursor, nwritten_out),
596            ending_cursor_out
597        )
598    }
599
600    fn original_header_names_get(
601        &mut self,
602        memory: &mut GuestMemory<'_>,
603        buf: GuestPtr<u8>,
604        buf_len: u32,
605        cursor: MultiValueCursor,
606        ending_cursor_out: GuestPtr<MultiValueCursorResult>,
607        nwritten_out: GuestPtr<u32>,
608    ) -> Result<(), Error> {
609        FastlyHttpDownstream::downstream_original_header_names(
610            self,
611            memory,
612            self.downstream_request(),
613            buf,
614            buf_len,
615            cursor,
616            ending_cursor_out,
617            nwritten_out,
618        )
619    }
620
621    fn original_header_count(&mut self, memory: &mut GuestMemory<'_>) -> Result<u32, Error> {
622        FastlyHttpDownstream::downstream_original_header_count(
623            self,
624            memory,
625            self.downstream_request(),
626        )
627    }
628
629    fn header_value_get(
630        &mut self,
631        memory: &mut GuestMemory<'_>,
632        req_handle: RequestHandle,
633        name: GuestPtr<[u8]>,
634        value: GuestPtr<u8>,
635        value_max_len: u32,
636        nwritten_out: GuestPtr<u32>,
637    ) -> Result<(), Error> {
638        let headers = &self.request_parts(req_handle)?.headers;
639        headers.value_get(memory, name, value, value_max_len, nwritten_out)
640    }
641
642    fn header_values_get(
643        &mut self,
644        memory: &mut GuestMemory<'_>,
645        req_handle: RequestHandle,
646        name: GuestPtr<[u8]>,
647        buf: GuestPtr<u8>,
648        buf_len: u32,
649        cursor: MultiValueCursor,
650        ending_cursor_out: GuestPtr<MultiValueCursorResult>,
651        nwritten_out: GuestPtr<u32>,
652    ) -> Result<(), Error> {
653        let headers = &self.request_parts(req_handle)?.headers;
654        multi_value_result!(
655            memory,
656            headers.values_get(memory, name, buf, buf_len, cursor, nwritten_out),
657            ending_cursor_out
658        )
659    }
660
661    fn header_values_set(
662        &mut self,
663        memory: &mut GuestMemory<'_>,
664        req_handle: RequestHandle,
665        name: GuestPtr<[u8]>,
666        values: GuestPtr<[u8]>,
667    ) -> Result<(), Error> {
668        let headers = &mut self.request_parts_mut(req_handle)?.headers;
669        headers.values_set(memory, name, values)
670    }
671
672    fn header_insert(
673        &mut self,
674        memory: &mut GuestMemory<'_>,
675        req_handle: RequestHandle,
676        name: GuestPtr<[u8]>,
677        value: GuestPtr<[u8]>,
678    ) -> Result<(), Error> {
679        let headers = &mut self.request_parts_mut(req_handle)?.headers;
680        HttpHeaders::insert(headers, memory, name, value)
681    }
682
683    fn header_append(
684        &mut self,
685        memory: &mut GuestMemory<'_>,
686        req_handle: RequestHandle,
687        name: GuestPtr<[u8]>,
688        value: GuestPtr<[u8]>,
689    ) -> Result<(), Error> {
690        let headers = &mut self.request_parts_mut(req_handle)?.headers;
691        HttpHeaders::append(headers, memory, name, value)
692    }
693
694    fn header_remove(
695        &mut self,
696        memory: &mut GuestMemory<'_>,
697        req_handle: RequestHandle,
698        name: GuestPtr<[u8]>,
699    ) -> Result<(), Error> {
700        let headers = &mut self.request_parts_mut(req_handle)?.headers;
701        HttpHeaders::remove(headers, memory, name)
702    }
703
704    fn method_get(
705        &mut self,
706        memory: &mut GuestMemory<'_>,
707        req_handle: RequestHandle,
708        buf: GuestPtr<u8>,
709        buf_len: u32,
710        nwritten_out: GuestPtr<u32>,
711    ) -> Result<(), Error> {
712        let req = self.request_parts(req_handle)?;
713        let req_method = &req.method;
714        let req_method_bytes = req_method.to_string().into_bytes();
715
716        if req_method_bytes.len() > buf_len as usize {
717            // Write out the number of bytes necessary to fit this method, or zero on overflow to
718            // signal an error condition.
719            memory.write(nwritten_out, req_method_bytes.len().try_into().unwrap_or(0))?;
720            return Err(Error::BufferLengthError {
721                buf: "method",
722                len: "method_max_len",
723            });
724        }
725
726        let req_method_len = u32::try_from(req_method_bytes.len())
727            .expect("smaller than method_max_len means it must fit");
728
729        memory.copy_from_slice(&req_method_bytes, buf.as_array(req_method_len))?;
730        memory.write(nwritten_out, req_method_len)?;
731
732        Ok(())
733    }
734
735    fn method_set(
736        &mut self,
737        memory: &mut GuestMemory<'_>,
738        req_handle: RequestHandle,
739        method: GuestPtr<str>,
740    ) -> Result<(), Error> {
741        let method_ref = &mut self.request_parts_mut(req_handle)?.method;
742        let method_slice = memory
743            .as_slice(method.as_bytes())?
744            .ok_or(Error::SharedMemory)?;
745        *method_ref = Method::from_bytes(method_slice)?;
746
747        Ok(())
748    }
749
750    fn uri_get(
751        &mut self,
752        memory: &mut GuestMemory<'_>,
753        req_handle: RequestHandle,
754        buf: GuestPtr<u8>,
755        buf_len: u32,
756        nwritten_out: GuestPtr<u32>,
757    ) -> Result<(), Error> {
758        let req = self.request_parts(req_handle)?;
759        let req_uri_bytes = req.uri.to_string().into_bytes();
760
761        if req_uri_bytes.len() > buf_len as usize {
762            // Write out the number of bytes necessary to fit this method, or zero on overflow to
763            // signal an error condition.
764            memory.write(nwritten_out, req_uri_bytes.len().try_into().unwrap_or(0))?;
765            return Err(Error::BufferLengthError {
766                buf: "uri",
767                len: "uri_max_len",
768            });
769        }
770        let req_uri_len =
771            u32::try_from(req_uri_bytes.len()).expect("smaller than uri_max_len means it must fit");
772
773        memory.copy_from_slice(&req_uri_bytes, buf.as_array(req_uri_len))?;
774        memory.write(nwritten_out, req_uri_len)?;
775
776        Ok(())
777    }
778
779    fn uri_set(
780        &mut self,
781        memory: &mut GuestMemory<'_>,
782        req_handle: RequestHandle,
783        uri: GuestPtr<str>,
784    ) -> Result<(), Error> {
785        let uri_ref = &mut self.request_parts_mut(req_handle)?.uri;
786        let req_uri_bytes = memory
787            .as_slice(uri.as_bytes())?
788            .ok_or(Error::SharedMemory)?;
789
790        *uri_ref = Uri::try_from(req_uri_bytes)?;
791        Ok(())
792    }
793
794    fn version_get(
795        &mut self,
796        _memory: &mut GuestMemory<'_>,
797        req_handle: RequestHandle,
798    ) -> Result<HttpVersion, Error> {
799        let req = self.request_parts(req_handle)?;
800        HttpVersion::try_from(req.version).map_err(|msg| Error::Unsupported { msg })
801    }
802
803    fn version_set(
804        &mut self,
805        _memory: &mut GuestMemory<'_>,
806        req_handle: RequestHandle,
807        version: HttpVersion,
808    ) -> Result<(), Error> {
809        let req = self.request_parts_mut(req_handle)?;
810
811        let version = hyper::Version::try_from(version)?;
812        req.version = version;
813        Ok(())
814    }
815
816    async fn send(
817        &mut self,
818        memory: &mut GuestMemory<'_>,
819        req_handle: RequestHandle,
820        body_handle: BodyHandle,
821        backend_bytes: GuestPtr<str>,
822    ) -> Result<(ResponseHandle, BodyHandle), Error> {
823        let backend_bytes_slice = memory
824            .as_slice(backend_bytes.as_bytes())?
825            .ok_or(Error::SharedMemory)?;
826        let backend_name = std::str::from_utf8(&backend_bytes_slice)?;
827
828        // prepare the request
829        let req_parts = self.take_request_parts(req_handle)?;
830        let req_body = self.take_body(body_handle)?;
831        let req = Request::from_parts(req_parts, req_body);
832        let backend = self
833            .backend(backend_name)
834            .ok_or_else(|| Error::UnknownBackend(backend_name.to_owned()))?;
835
836        // synchronously send the request
837        let resp = upstream::send_request(req, backend, self.tls_config()).await?;
838        Ok(self.insert_response(resp))
839    }
840
841    async fn send_v2(
842        &mut self,
843        memory: &mut GuestMemory<'_>,
844        req_handle: RequestHandle,
845        body_handle: BodyHandle,
846        backend_bytes: GuestPtr<str>,
847        _error_detail: GuestPtr<SendErrorDetail>,
848    ) -> Result<(ResponseHandle, BodyHandle), Error> {
849        // This initial implementation ignores the error detail field
850        self.send(memory, req_handle, body_handle, backend_bytes)
851            .await
852    }
853
854    async fn send_v3(
855        &mut self,
856        memory: &mut GuestMemory<'_>,
857        req_handle: RequestHandle,
858        body_handle: BodyHandle,
859        backend_bytes: GuestPtr<str>,
860        error_detail: GuestPtr<SendErrorDetail>,
861    ) -> Result<(ResponseHandle, BodyHandle), Error> {
862        self.send_v2(memory, req_handle, body_handle, backend_bytes, error_detail)
863            .await
864    }
865
866    async fn send_async(
867        &mut self,
868        memory: &mut GuestMemory<'_>,
869        req_handle: RequestHandle,
870        body_handle: BodyHandle,
871        backend_bytes: GuestPtr<str>,
872    ) -> Result<PendingRequestHandle, Error> {
873        let backend_bytes_slice = memory
874            .as_slice(backend_bytes.as_bytes())?
875            .ok_or(Error::SharedMemory)?;
876        let backend_name = std::str::from_utf8(&backend_bytes_slice)?;
877
878        // prepare the request
879        let req_parts = self.take_request_parts(req_handle)?;
880        let req_body = self.take_body(body_handle)?;
881        let req = Request::from_parts(req_parts, req_body);
882        let backend = self
883            .backend(backend_name)
884            .ok_or_else(|| Error::UnknownBackend(backend_name.to_owned()))?;
885
886        // asynchronously send the request
887        let task =
888            PeekableTask::spawn(upstream::send_request(req, backend, self.tls_config())).await;
889
890        // return a handle to the pending task
891        Ok(self.insert_pending_request(task))
892    }
893
894    async fn send_async_v2(
895        &mut self,
896        memory: &mut GuestMemory<'_>,
897        req_handle: RequestHandle,
898        body_handle: BodyHandle,
899        backend_bytes: GuestPtr<str>,
900        streaming: u32,
901    ) -> Result<PendingRequestHandle, Error> {
902        if streaming == 1 {
903            self.send_async_streaming(memory, req_handle, body_handle, backend_bytes)
904                .await
905        } else {
906            self.send_async(memory, req_handle, body_handle, backend_bytes)
907                .await
908        }
909    }
910
911    async fn send_async_streaming(
912        &mut self,
913        memory: &mut GuestMemory<'_>,
914        req_handle: RequestHandle,
915        body_handle: BodyHandle,
916        backend_bytes: GuestPtr<str>,
917    ) -> Result<PendingRequestHandle, Error> {
918        let backend_bytes_slice = memory
919            .as_slice(backend_bytes.as_bytes())?
920            .ok_or(Error::SharedMemory)?;
921        let backend_name = std::str::from_utf8(backend_bytes_slice)?;
922
923        // prepare the request
924        let req_parts = self.take_request_parts(req_handle)?;
925        let req_body = self.begin_streaming(body_handle)?;
926        let req = Request::from_parts(req_parts, req_body);
927        let backend = self
928            .backend(backend_name)
929            .ok_or_else(|| Error::UnknownBackend(backend_name.to_owned()))?;
930
931        // asynchronously send the request
932        let task =
933            PeekableTask::spawn(upstream::send_request(req, backend, self.tls_config())).await;
934
935        // return a handle to the pending task
936        Ok(self.insert_pending_request(task))
937    }
938
939    // note: The first value in the return tuple represents whether the request is done: 0 when not
940    // done, 1 when done.
941    async fn pending_req_poll(
942        &mut self,
943        _memory: &mut GuestMemory<'_>,
944        pending_req_handle: PendingRequestHandle,
945    ) -> Result<(u32, ResponseHandle, BodyHandle), Error> {
946        if self.async_item_mut(pending_req_handle.into())?.is_ready() {
947            let resp = self
948                .take_pending_request(pending_req_handle)?
949                .recv()
950                .await?;
951            let (resp_handle, resp_body_handle) = self.insert_response(resp);
952            Ok((1, resp_handle, resp_body_handle))
953        } else {
954            Ok((0, INVALID_REQUEST_HANDLE.into(), INVALID_BODY_HANDLE.into()))
955        }
956    }
957
958    async fn pending_req_poll_v2(
959        &mut self,
960        memory: &mut GuestMemory<'_>,
961        pending_req_handle: PendingRequestHandle,
962        _error_detail: GuestPtr<SendErrorDetail>,
963    ) -> Result<(u32, ResponseHandle, BodyHandle), Error> {
964        // This initial implementation ignores the error detail field
965        self.pending_req_poll(memory, pending_req_handle).await
966    }
967
968    async fn pending_req_wait(
969        &mut self,
970        _memory: &mut GuestMemory<'_>,
971        pending_req_handle: PendingRequestHandle,
972    ) -> Result<(ResponseHandle, BodyHandle), Error> {
973        let pending_req = self
974            .take_pending_request(pending_req_handle)?
975            .recv()
976            .await?;
977        Ok(self.insert_response(pending_req))
978    }
979
980    async fn pending_req_wait_v2(
981        &mut self,
982        memory: &mut GuestMemory<'_>,
983        pending_req_handle: PendingRequestHandle,
984        _error_detail: GuestPtr<SendErrorDetail>,
985    ) -> Result<(ResponseHandle, BodyHandle), Error> {
986        // This initial implementation ignores the error detail field
987        self.pending_req_wait(memory, pending_req_handle).await
988    }
989
990    // First element of return tuple is the "done index"
991    async fn pending_req_select(
992        &mut self,
993        memory: &mut GuestMemory<'_>,
994        pending_req_handles: GuestPtr<[PendingRequestHandle]>,
995    ) -> Result<(u32, ResponseHandle, BodyHandle), Error> {
996        if pending_req_handles.len() == 0 {
997            return Err(Error::InvalidArgument);
998        }
999        let pending_req_handles = pending_req_handles.cast::<[u32]>();
1000
1001        // perform the select operation
1002        let done_index = self
1003            .select_impl(
1004                memory
1005                    // TODO: `GuestMemory::as_slice` only supports guest pointers to u8 slices in
1006                    // wiggle 22.0.0, but `GuestMemory::to_vec` supports guest pointers to slices
1007                    // of arbitrary types. As `GuestMemory::to_vec` will copy the contents of the
1008                    // slice out of guest memory, we should switch this to `GuestMemory::as_slice`
1009                    // once it is polymorphic in the element type of the slice.
1010                    .to_vec(pending_req_handles)?
1011                    .into_iter()
1012                    .map(|handle| PendingRequestHandle::from(handle).into()),
1013            )
1014            .await? as u32;
1015
1016        let item = self.take_async_item(
1017            PendingRequestHandle::from(memory.read(pending_req_handles.get(done_index).unwrap())?)
1018                .into(),
1019        )?;
1020
1021        let outcome = match item {
1022            AsyncItem::PendingReq(task) => match task {
1023                PeekableTask::Complete(res) => match res {
1024                    Ok(res) => {
1025                        let (resp_handle, body_handle) = self.insert_response(res);
1026                        (done_index, resp_handle, body_handle)
1027                    }
1028                    Err(_) => (
1029                        done_index,
1030                        INVALID_RESPONSE_HANDLE.into(),
1031                        INVALID_BODY_HANDLE.into(),
1032                    ),
1033                },
1034                _ => panic!("Pending request was not completed"),
1035            },
1036            _ => panic!("AsyncItem was not a pending request"),
1037        };
1038
1039        Ok(outcome)
1040    }
1041
1042    async fn pending_req_select_v2(
1043        &mut self,
1044        memory: &mut GuestMemory<'_>,
1045        pending_req_handles: GuestPtr<[PendingRequestHandle]>,
1046        _error_detail: GuestPtr<SendErrorDetail>,
1047    ) -> Result<(u32, ResponseHandle, BodyHandle), Error> {
1048        // This initial implementation ignores the error detail field
1049        self.pending_req_select(memory, pending_req_handles).await
1050    }
1051
1052    fn fastly_key_is_valid(&mut self, memory: &mut GuestMemory<'_>) -> Result<u32, Error> {
1053        FastlyHttpDownstream::fastly_key_is_valid(self, memory, self.downstream_request())
1054    }
1055
1056    fn close(
1057        &mut self,
1058        _memory: &mut GuestMemory<'_>,
1059        req_handle: RequestHandle,
1060    ) -> Result<(), Error> {
1061        // We don't do anything with the parts, but we do pass the error up if
1062        // the handle given doesn't exist
1063        self.take_request_parts(req_handle)?;
1064        Ok(())
1065    }
1066
1067    fn auto_decompress_response_set(
1068        &mut self,
1069        _memory: &mut GuestMemory<'_>,
1070        req_handle: RequestHandle,
1071        encodings: ContentEncodings,
1072    ) -> Result<(), Error> {
1073        // NOTE: We're going to hide this flag in the extensions of the request in order to decrease
1074        // the book-keeping burden inside Session. The flag will get picked up later, in `send_request`.
1075        let extensions = &mut self.request_parts_mut(req_handle)?.extensions;
1076
1077        match extensions.get_mut::<ViceroyRequestMetadata>() {
1078            None => {
1079                extensions.insert(ViceroyRequestMetadata {
1080                    auto_decompress_encodings: encodings,
1081                    ..Default::default()
1082                });
1083            }
1084            Some(vrm) => {
1085                vrm.auto_decompress_encodings = encodings;
1086            }
1087        }
1088
1089        Ok(())
1090    }
1091
1092    fn inspect(
1093        &mut self,
1094        memory: &mut GuestMemory<'_>,
1095        ds_req: RequestHandle,
1096        ds_body: BodyHandle,
1097        info_mask: InspectInfoMask,
1098        info: GuestPtr<InspectInfo>,
1099        buf: GuestPtr<u8>,
1100        buf_len: u32,
1101    ) -> Result<u32, Error> {
1102        // Make sure we're given valid handles, even though we won't use them.
1103        let _ = self.request_parts(ds_req)?;
1104        let _ = self.body(ds_body)?;
1105
1106        // Make sure the InspectInfo looks good, even though we won't use it.
1107        let info = memory.read(info)?;
1108        let info_string_or_err = |flag, str_field: GuestPtr<u8>, len_field| {
1109            if info_mask.contains(flag) {
1110                if len_field == 0 {
1111                    return Err(Error::InvalidArgument);
1112                }
1113
1114                let byte_vec = memory.to_vec(str_field.as_array(len_field))?;
1115                let s = String::from_utf8(byte_vec).map_err(|_| Error::InvalidArgument)?;
1116
1117                Ok(s)
1118            } else {
1119                // For now, corp and workspace arguments are required to actually generate the hostname,
1120                // but in the future the lookaside service will be generated using the customer ID, and
1121                // it will be okay for them to be unspecified or empty.
1122                Err(Error::InvalidArgument)
1123            }
1124        };
1125
1126        let _ = info_string_or_err(InspectInfoMask::CORP, info.corp, info.corp_len)?;
1127        let _ = info_string_or_err(
1128            InspectInfoMask::WORKSPACE,
1129            info.workspace,
1130            info.workspace_len,
1131        )?;
1132
1133        if info_mask.contains(InspectInfoMask::OVERRIDE_CLIENT_IP) {
1134            let _ = read_guest_ip(
1135                memory,
1136                &info.override_client_ip_ptr,
1137                info.override_client_ip_len,
1138            )?;
1139        }
1140
1141        // Return the mock NGWAF response.
1142        let ngwaf_resp = self.ngwaf_response();
1143        let ngwaf_resp_len = ngwaf_resp.len();
1144
1145        match u32::try_from(ngwaf_resp_len) {
1146            Ok(ngwaf_resp_len) if ngwaf_resp_len <= buf_len => {
1147                memory.copy_from_slice(ngwaf_resp.as_bytes(), buf.as_array(ngwaf_resp_len))?;
1148
1149                Ok(ngwaf_resp_len)
1150            }
1151            _ => Err(Error::BufferLengthError {
1152                buf: "buf",
1153                len: "buf_len",
1154            }),
1155        }
1156    }
1157
1158    fn on_behalf_of(
1159        &mut self,
1160        _memory: &mut GuestMemory<'_>,
1161        _ds_req: RequestHandle,
1162        _service: GuestPtr<str>,
1163    ) -> Result<(), Error> {
1164        Err(Error::Unsupported {
1165            msg: "on_behalf_of is not supported in Viceroy",
1166        })
1167    }
1168}
1169
1170fn try_ip_from_bytes<const N: usize>(bytes: Option<&[u8]>) -> Result<IpAddr, Error>
1171where
1172    IpAddr: From<[u8; N]>,
1173{
1174    bytes
1175        .and_then(|bs| <[u8; N]>::try_from(bs).ok())
1176        .map(IpAddr::from)
1177        .ok_or(Error::InvalidArgument)
1178}
1179
1180fn read_guest_ip(
1181    memory: &mut GuestMemory<'_>,
1182    bytes: &GuestPtr<u8>,
1183    len: u32,
1184) -> Result<Option<IpAddr>, Error> {
1185    let bytes = memory.as_slice(bytes.as_array(len as u32))?;
1186
1187    match len {
1188        0 => Ok(None),
1189        4 => try_ip_from_bytes::<4>(bytes).map(Some),
1190        16 => try_ip_from_bytes::<16>(bytes).map(Some),
1191        _ => Err(Error::InvalidArgument),
1192    }
1193}