Skip to main content

solana_address/
derive.rs

1use {
2    crate::{Address, MAX_SEEDS, PDA_MARKER},
3    core::{mem::MaybeUninit, slice::from_raw_parts},
4    sha2_const_stable::Sha256,
5    solana_sha256_hasher::hashv,
6};
7
8impl Address {
9    /// Derive a [program address][pda] from the given seeds, optional bump and
10    /// program id.
11    ///
12    /// [pda]: https://solana.com/docs/core/pda
13    ///
14    /// In general, the derivation uses an optional bump (byte) value to ensure a
15    /// valid PDA (off-curve) is generated. Even when a program stores a bump to
16    /// derive a program address, it is necessary to use the
17    /// [`Address::create_program_address`] to validate the derivation. In
18    /// most cases, the program has the correct seeds for the derivation, so it would
19    /// be sufficient to just perform the derivation and compare it against the
20    /// expected resulting address.
21    ///
22    /// This function avoids the cost of the `create_program_address` syscall
23    /// (`1500` compute units) by directly computing the derived address
24    /// calculating the hash of the seeds, bump and program id using the
25    /// `sol_sha256` syscall.
26    ///
27    /// # Important
28    ///
29    /// This function differs from [`Address::create_program_address`] in that
30    /// it does not perform a validation to ensure that the derived address is a valid
31    /// (off-curve) program derived address. It is intended for use in cases where the
32    /// seeds, bump, and program id are known to be valid, and the caller wants to derive
33    /// the address without incurring the cost of the `create_program_address` syscall.
34    pub fn derive_address<const N: usize>(
35        seeds: &[&[u8]; N],
36        bump: Option<u8>,
37        program_id: &Address,
38    ) -> Address {
39        const {
40            assert!(N < MAX_SEEDS, "number of seeds must be less than MAX_SEEDS");
41        }
42
43        let mut data = [const { MaybeUninit::<&[u8]>::uninit() }; MAX_SEEDS + 2];
44        let mut i = 0;
45
46        while i < N {
47            // SAFETY: `data` is guaranteed to have enough space for `N` seeds,
48            // so `i` will always be within bounds.
49            unsafe {
50                data.get_unchecked_mut(i).write(seeds.get_unchecked(i));
51            }
52            i += 1;
53        }
54
55        // SAFETY: `data` is guaranteed to have enough space for `MAX_SEEDS + 2`
56        // elements, and `MAX_SEEDS` is larger than `N`.
57        unsafe {
58            if bump.is_some() {
59                data.get_unchecked_mut(i).write(bump.as_slice());
60                i += 1;
61            }
62            data.get_unchecked_mut(i).write(program_id.as_ref());
63            data.get_unchecked_mut(i + 1).write(PDA_MARKER.as_ref());
64        }
65
66        let hash = hashv(unsafe { from_raw_parts(data.as_ptr() as *const &[u8], i + 2) });
67        Address::from(hash.to_bytes())
68    }
69
70    /// Derive a [program address][pda] from the given seeds, optional bump and
71    /// program id.
72    ///
73    /// [pda]: https://solana.com/docs/core/pda
74    ///
75    /// In general, the derivation uses an optional bump (byte) value to ensure a
76    /// valid PDA (off-curve) is generated.
77    ///
78    /// This function is intended for use in `const` contexts - i.e., the seeds and
79    /// bump are known at compile time and the program id is also a constant. It avoids
80    /// the cost of the `create_program_address` syscall (`1500` compute units) by
81    /// directly computing the derived address using the SHA-256 hash of the seeds,
82    /// bump and program id.
83    ///
84    /// # Important
85    ///
86    /// This function differs from [`Address::create_program_address`] in that
87    /// it does not perform a validation to ensure that the derived address is a valid
88    /// (off-curve) program derived address. It is intended for use in cases where the
89    /// seeds, bump, and program id are known to be valid, and the caller wants to derive
90    /// the address without incurring the cost of the `create_program_address` syscall.
91    ///
92    /// This function is a compile-time constant version of [`Address::derive_address`].
93    /// It has worse performance than `derive_address`, so only use this function in
94    /// `const` contexts, where all parameters are known at compile-time.
95    pub const fn derive_address_const<const N: usize>(
96        seeds: &[&[u8]; N],
97        bump: Option<u8>,
98        program_id: &Address,
99    ) -> Address {
100        const {
101            assert!(N < MAX_SEEDS, "number of seeds must be less than MAX_SEEDS");
102        }
103
104        let mut hasher = Sha256::new();
105        let mut i = 0;
106
107        while i < seeds.len() {
108            hasher = hasher.update(seeds[i]);
109            i += 1;
110        }
111
112        // TODO: replace this with `bump.as_slice()` when the MSRV is
113        // upgraded to `1.84.0+`.
114        Address::new_from_array(if let Some(bump) = bump {
115            hasher
116                .update(&[bump])
117                .update(program_id.as_array())
118                .update(PDA_MARKER)
119                .finalize()
120        } else {
121            hasher
122                .update(program_id.as_array())
123                .update(PDA_MARKER)
124                .finalize()
125        })
126    }
127}
128
#[cfg(test)]
mod tests {
    use crate::{Address, MAX_SEEDS};

    /// Both derivation functions must match `find_program_address`, whether the
    /// bump is passed via the `bump` argument or appended as the last seed.
    #[test]
    fn test_derive_address() {
        let program_id = Address::new_from_array([1u8; 32]);
        let seeds: &[&[u8]; 2] = &[b"seed1", b"seed2"];
        let (address, bump) = Address::find_program_address(seeds, &program_id);

        let derived_address = Address::derive_address(seeds, Some(bump), &program_id);
        let derived_address_const = Address::derive_address_const(seeds, Some(bump), &program_id);

        assert_eq!(address, derived_address);
        assert_eq!(address, derived_address_const);

        let extended_seeds: &[&[u8]; 3] = &[b"seed1", b"seed2", &[bump]];

        let derived_address = Address::derive_address(extended_seeds, None, &program_id);
        let derived_address_const =
            Address::derive_address_const(extended_seeds, None, &program_id);

        assert_eq!(address, derived_address);
        assert_eq!(address, derived_address_const);
    }

    /// With no user seeds, only the bump, program id and marker are hashed.
    #[test]
    fn test_derive_address_without_seeds() {
        let program_id = Address::new_from_array([2u8; 32]);
        let seeds: &[&[u8]; 0] = &[];
        let (address, bump) = Address::find_program_address(seeds, &program_id);

        assert_eq!(address, Address::derive_address(seeds, Some(bump), &program_id));
        assert_eq!(address, Address::derive_address_const(seeds, Some(bump), &program_id));

        let bump_only: &[&[u8]; 1] = &[&[bump]];
        assert_eq!(address, Address::derive_address(bump_only, None, &program_id));
        assert_eq!(address, Address::derive_address_const(bump_only, None, &program_id));
    }

    /// `MAX_SEEDS - 1` is the largest seed count accepted, leaving room for the bump.
    #[test]
    fn test_derive_address_max_seeds() {
        let program_id = Address::new_from_array([3u8; 32]);
        let seed: &[u8] = b"seed";
        let seeds = [seed; MAX_SEEDS - 1];
        let (address, bump) = Address::find_program_address(&seeds, &program_id);

        assert_eq!(address, Address::derive_address(&seeds, Some(bump), &program_id));
        assert_eq!(address, Address::derive_address_const(&seeds, Some(bump), &program_id));
    }

    /// `derive_address_const` must be evaluable at compile time and agree with
    /// the runtime derivation.
    #[test]
    fn test_derive_address_const_in_const_context() {
        const PROGRAM_ID: Address = Address::new_from_array([4u8; 32]);
        const SEEDS: &[&[u8]; 2] = &[b"seed1", b"seed2"];
        const BUMP: u8 = 255;
        const DERIVED: Address = Address::derive_address_const(SEEDS, Some(BUMP), &PROGRAM_ID);

        assert_eq!(DERIVED, Address::derive_address(SEEDS, Some(BUMP), &PROGRAM_ID));
    }
}