viceroy_lib/wiggle_abi/shielding.rs

1use crate::config::Backend;
2use crate::error::Error;
3use crate::session::Session;
4use crate::wiggle_abi::{fastly_shielding, types};
5use http::Uri;
6use std::str::FromStr;
7
8impl fastly_shielding::FastlyShielding for Session {
9    fn shield_info(
10        &mut self,
11        memory: &mut wiggle::GuestMemory<'_>,
12        name: wiggle::GuestPtr<str>,
13        out_buffer: wiggle::GuestPtr<u8>,
14        out_buffer_max_len: u32,
15    ) -> Result<u32, Error> {
16        // Validate the input name and then return the unsupported error.
17        let Some(name) = memory.as_str(name)?.map(str::to_string) else {
18            return Err(Error::ValueAbsent);
19        };
20
21        let running_on = self.shielding_sites().is_local(&name);
22        let unencrypted = self
23            .shielding_sites()
24            .get_unencrypted(&name)
25            .map(|x| x.to_string())
26            .unwrap_or_default();
27        let encrypted = self
28            .shielding_sites()
29            .get_encrypted(&name)
30            .map(|x| x.to_string())
31            .unwrap_or_default();
32
33        if !running_on && unencrypted.is_empty() {
34            return Err(Error::InvalidArgument);
35        }
36
37        let mut output_bytes = Vec::new();
38
39        output_bytes.push(if running_on { 1u8 } else { 0 });
40        output_bytes.extend_from_slice(unencrypted.as_bytes());
41        output_bytes.push(0);
42        output_bytes.extend_from_slice(encrypted.as_bytes());
43        output_bytes.push(0);
44
45        let target_len = output_bytes.len() as u32;
46
47        if target_len > out_buffer_max_len {
48            return Err(Error::BufferLengthError {
49                buf: "shielding_info",
50                len: "info.len()",
51            });
52        }
53
54        memory.copy_from_slice(&output_bytes, out_buffer.as_array(target_len))?;
55        Ok(target_len)
56    }
57
58    fn backend_for_shield(
59        &mut self,
60        memory: &mut wiggle::GuestMemory<'_>,
61        shield_name: wiggle::GuestPtr<str>,
62        shield_backend_options: types::ShieldBackendOptions,
63        shield_backend_config: wiggle::GuestPtr<types::ShieldBackendConfig>,
64        out_buffer: wiggle::GuestPtr<u8>,
65        out_buffer_max_len: u32,
66    ) -> Result<u32, Error> {
67        // Validate our inputs and then return the unsupported error.
68        let Some(shield_uri) = memory.as_str(shield_name)?.map(str::to_string) else {
69            return Err(Error::ValueAbsent);
70        };
71
72        if shield_backend_options.contains(types::ShieldBackendOptions::RESERVED) {
73            return Err(Error::InvalidArgument);
74        }
75
76        let config = memory.read(shield_backend_config)?;
77
78        if shield_backend_options.contains(types::ShieldBackendOptions::USE_CACHE_KEY) {
79            let field_string = config.cache_key.as_array(config.cache_key_len).cast();
80            if memory.as_str(field_string)?.is_none() {
81                return Err(Error::InvalidArgument);
82            }
83        }
84
85        let Ok(uri) = Uri::from_str(&shield_uri) else {
86            return Err(Error::InvalidArgument);
87        };
88
89        let new_name = format!("******{uri}*****");
90        let new_backend = Backend {
91            uri,
92            override_host: None,
93            cert_host: None,
94            use_sni: false,
95            grpc: false,
96            client_cert: None,
97            ca_certs: Vec::new(),
98        };
99
100        if !self.add_backend(&new_name, new_backend) {
101            return Err(Error::BackendNameRegistryError(new_name));
102        }
103
104        let new_name_bytes = new_name.as_bytes().to_vec();
105
106        let target_len = new_name_bytes.len() as u32;
107
108        if target_len > out_buffer_max_len {
109            return Err(Error::BufferLengthError {
110                buf: "shielding_backend",
111                len: "name.len()",
112            });
113        }
114
115        memory.copy_from_slice(&new_name_bytes, out_buffer.as_array(target_len))?;
116
117        Ok(target_len)
118    }
119}