Skip to main content

solana_address/
derive.rs

1use {
2    crate::{Address, MAX_SEEDS, PDA_MARKER},
3    core::{mem::MaybeUninit, slice::from_raw_parts},
4    sha2_const_stable::Sha256,
5    solana_sha256_hasher::hashv,
6};
7
8impl Address {
9    /// Derive a [program address][pda] from the given seeds, optional bump and
10    /// program id.
11    ///
12    /// [pda]: https://solana.com/docs/core/pda
13    ///
14    /// In general, the derivation uses an optional bump (byte) value to ensure a
15    /// valid PDA (off-curve) is generated. Even when a program stores a bump to
16    /// derive a program address, it is necessary to use the
17    /// [`Address::create_program_address`] to validate the derivation. In
18    /// most cases, the program has the correct seeds for the derivation, so it would
19    /// be sufficient to just perform the derivation and compare it against the
20    /// expected resulting address.
21    ///
22    /// This function avoids the cost of the `create_program_address` syscall
23    /// (`1500` compute units) by directly computing the derived address
24    /// calculating the hash of the seeds, bump and program id using the
25    /// `sol_sha256` syscall.
26    ///
27    /// # Important
28    ///
29    /// This function differs from [`Address::create_program_address`] in that
30    /// it does not perform a validation to ensure that the derived address is a valid
31    /// (off-curve) program derived address. It is intended for use in cases where the
32    /// seeds, bump, and program id are known to be valid, and the caller wants to derive
33    /// the address without incurring the cost of the `create_program_address` syscall.
34    #[inline]
35    pub fn derive_address<const N: usize>(
36        seeds: &[&[u8]; N],
37        bump: Option<u8>,
38        program_id: &Address,
39    ) -> Address {
40        const {
41            assert!(N < MAX_SEEDS, "number of seeds must be less than MAX_SEEDS");
42        }
43
44        let mut data = [const { MaybeUninit::<&[u8]>::uninit() }; MAX_SEEDS + 2];
45        let mut i = 0;
46
47        while i < N {
48            // SAFETY: `data` is guaranteed to have enough space for `N` seeds,
49            // so `i` will always be within bounds.
50            unsafe {
51                data.get_unchecked_mut(i).write(seeds.get_unchecked(i));
52            }
53            i += 1;
54        }
55
56        // SAFETY: `data` is guaranteed to have enough space for `MAX_SEEDS + 2`
57        // elements, and `MAX_SEEDS` is larger than `N`.
58        unsafe {
59            if bump.is_some() {
60                data.get_unchecked_mut(i).write(bump.as_slice());
61                i += 1;
62            }
63            data.get_unchecked_mut(i).write(program_id.as_ref());
64            data.get_unchecked_mut(i + 1).write(PDA_MARKER.as_ref());
65        }
66
67        let hash = hashv(unsafe { from_raw_parts(data.as_ptr() as *const &[u8], i + 2) });
68        Address::from(hash.to_bytes())
69    }
70
71    /// Derive a [program address][pda] from the given seeds, optional bump and
72    /// program id.
73    ///
74    /// [pda]: https://solana.com/docs/core/pda
75    ///
76    /// In general, the derivation uses an optional bump (byte) value to ensure a
77    /// valid PDA (off-curve) is generated.
78    ///
79    /// This function is intended for use in `const` contexts - i.e., the seeds and
80    /// bump are known at compile time and the program id is also a constant. It avoids
81    /// the cost of the `create_program_address` syscall (`1500` compute units) by
82    /// directly computing the derived address using the SHA-256 hash of the seeds,
83    /// bump and program id.
84    ///
85    /// # Important
86    ///
87    /// This function differs from [`Address::create_program_address`] in that
88    /// it does not perform a validation to ensure that the derived address is a valid
89    /// (off-curve) program derived address. It is intended for use in cases where the
90    /// seeds, bump, and program id are known to be valid, and the caller wants to derive
91    /// the address without incurring the cost of the `create_program_address` syscall.
92    ///
93    /// This function is a compile-time constant version of [`Address::derive_address`].
94    /// It has worse performance than `derive_address`, so only use this function in
95    /// `const` contexts, where all parameters are known at compile-time.
96    pub const fn derive_address_const<const N: usize>(
97        seeds: &[&[u8]; N],
98        bump: Option<u8>,
99        program_id: &Address,
100    ) -> Address {
101        const {
102            assert!(N < MAX_SEEDS, "number of seeds must be less than MAX_SEEDS");
103        }
104
105        let mut hasher = Sha256::new();
106        let mut i = 0;
107
108        while i < seeds.len() {
109            hasher = hasher.update(seeds[i]);
110            i += 1;
111        }
112
113        // TODO: replace this with `bump.as_slice()` when the MSRV is
114        // upgraded to `1.84.0+`.
115        Address::new_from_array(if let Some(bump) = bump {
116            hasher
117                .update(&[bump])
118                .update(program_id.as_array())
119                .update(PDA_MARKER)
120                .finalize()
121        } else {
122            hasher
123                .update(program_id.as_array())
124                .update(PDA_MARKER)
125                .finalize()
126        })
127    }
128
129    /// Attempt to derive a valid [program derived address][pda] (PDA) and its corresponding
130    /// bump seed.
131    ///
132    /// [pda]: https://solana.com/docs/core/cpi#program-derived-addresses
133    ///
134    /// The main difference between this method and [`Address::derive_address`]
135    /// is that this method iterates through all possible bump seed values (starting from
136    /// `255` and decrementing) until it finds a valid (off-curve) program derived address.
137    ///
138    /// If a valid PDA is found, it returns the PDA and the bump seed used to derive it;
139    /// otherwise, it returns `None`.
140    #[inline]
141    pub fn derive_program_address<const N: usize>(
142        seeds: &[&[u8]; N],
143        program_id: &Address,
144    ) -> Option<(Address, u8)> {
145        let mut bump = u8::MAX;
146
147        loop {
148            let address = Self::derive_address(seeds, Some(bump), program_id);
149
150            // Check if the derived address is a valid (off-curve)
151            // program derived address.
152            if !address.is_on_curve() {
153                return Some((address, bump));
154            }
155
156            // If the derived address is on-curve, decrement the bump and
157            // try again until all possible bump values are tested.
158            if bump == 0 {
159                return None;
160            }
161
162            bump -= 1;
163        }
164    }
165}
166
#[cfg(test)]
mod tests {
    use crate::Address;

    /// Both derivation flavours must reproduce `find_program_address`, whether
    /// the bump is passed through the `bump` argument or appended as a seed.
    #[test]
    fn test_derive_address() {
        let program_id = Address::new_from_array([1u8; 32]);
        let seeds: &[&[u8]; 2] = &[b"seed1", b"seed2"];
        let (expected, bump) = Address::find_program_address(seeds, &program_id);

        // Bump supplied through the dedicated parameter.
        assert_eq!(expected, Address::derive_address(seeds, Some(bump), &program_id));
        assert_eq!(expected, Address::derive_address_const(seeds, Some(bump), &program_id));

        // Bump appended to the seed list instead.
        let bump_seed = [bump];
        let seeds_with_bump: &[&[u8]; 3] = &[b"seed1", b"seed2", &bump_seed];
        assert_eq!(expected, Address::derive_address(seeds_with_bump, None, &program_id));
        assert_eq!(expected, Address::derive_address_const(seeds_with_bump, None, &program_id));
    }

    /// `derive_program_address` must select the same address and bump as
    /// `find_program_address`.
    #[test]
    fn test_program_derive_address() {
        let program_id = Address::new_unique();
        let seeds: &[&[u8]; 3] = &[b"derived", b"programm", b"address"];

        let expected = Address::find_program_address(seeds, &program_id);
        let actual = Address::derive_program_address(seeds, &program_id);

        assert_eq!(actual, Some(expected));
    }
}