spongefish/
domain_separator.rs1use core::{fmt, fmt::Arguments};
55
56use rand::rngs::StdRng;
57
58#[cfg(feature = "sha3")]
59use crate::VerifierState;
60use crate::{DuplexSpongeInterface, Encoding, ProverState, StdHash, Unit};
61
/// Type-state marker for a [`DomainSeparator`] that has no instance bound yet.
///
/// [`DomainSeparator::new`] starts in this state, and `DomainSeparator::instance`
/// is only offered while the instance slot still holds this marker.
#[derive(Clone, Copy, Debug, Default)]
pub struct WithoutInstance;
80
/// Type-state wrapper around the instance value bound to a [`DomainSeparator`].
///
/// The wrapped value is private to this module; it is installed by
/// `DomainSeparator::instance` and read back when a prover or verifier state
/// is built.
pub struct WithInstance<I>(I);
92
/// Type-state marker for a [`DomainSeparator`] whose session has not been chosen.
///
/// This is the default session parameter of [`DomainSeparator`]; the `session`
/// and `without_session` builder steps are only available in this state.
#[derive(Clone, Copy, Debug, Default)]
pub struct WithoutSession;
96
/// Type-state wrapper around the session value bound to a [`DomainSeparator`].
///
/// The wrapped value is visible inside the crate only.
pub struct WithSession<S>(pub(crate) S);

/// Renders as `WithSession(<inner>)`, delegating to the inner value's `Debug`.
impl<S> fmt::Debug for WithSession<S>
where
    S: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(inner) = self;
        let mut tuple = f.debug_tuple("WithSession");
        tuple.field(inner);
        tuple.finish()
    }
}
105
/// Session value meaning "deliberately no session".
///
/// `DomainSeparator::without_session` installs this value; it encodes to an
/// empty sequence of units.
#[derive(Clone, Copy, Debug, Default)]
pub struct NoSession;
112
113impl<T: Unit> Encoding<[T]> for NoSession {
114 fn encode(&self) -> impl AsRef<[T]> {
115 let empty: [T; 0] = [];
116 empty
117 }
118}
119
120pub struct DomainSeparator<I, S = WithoutSession> {
124 pub protocol: [u8; 64],
126 pub session: S,
128 instance: I,
130}
131
132impl DomainSeparator<WithoutInstance, WithoutSession> {
133 #[must_use]
134 pub const fn new(protocol: [u8; 64]) -> Self {
135 Self {
136 protocol,
137 session: WithoutSession,
138 instance: WithoutInstance,
139 }
140 }
141}
142
143impl<I> DomainSeparator<I, WithoutSession> {
144 #[must_use]
149 pub fn session<S>(self, value: S) -> DomainSeparator<I, WithSession<S>> {
150 DomainSeparator {
151 protocol: self.protocol,
152 session: WithSession(value),
153 instance: self.instance,
154 }
155 }
156
157 #[must_use]
159 pub fn without_session(self) -> DomainSeparator<I, WithSession<NoSession>> {
160 self.session(NoSession)
161 }
162}
163
164impl<S> DomainSeparator<WithoutInstance, WithSession<S>> {
165 pub fn instance<I>(self, value: I) -> DomainSeparator<WithInstance<I>, WithSession<S>> {
166 DomainSeparator {
167 protocol: self.protocol,
168 session: self.session,
169 instance: WithInstance(value),
170 }
171 }
172}
173
174impl<I, S> DomainSeparator<WithInstance<I>, WithSession<S>>
175where
176 I: Encoding,
177 S: Encoding,
178{
179 #[cfg(feature = "sha3")]
180 #[must_use]
181 pub fn std_prover(&self) -> ProverState {
182 let mut prover_state = ProverState::from(StdHash::from_protocol_id(self.protocol));
183 prover_state.public_message(&self.session.0);
184 prover_state.public_message(&self.instance.0);
185 prover_state
186 }
187
188 #[cfg(feature = "sha3")]
189 #[must_use]
190 pub fn std_verifier<'ver>(&self, narg_string: &'ver [u8]) -> VerifierState<'ver, StdHash> {
191 let mut verifier_state =
192 VerifierState::from_parts(StdHash::from_protocol_id(self.protocol), narg_string);
193 verifier_state.public_message(&self.session.0);
194 verifier_state.public_message(&self.instance.0);
195 verifier_state
196 }
197}
198
199impl<I, S> DomainSeparator<WithInstance<I>, WithSession<S>> {
200 pub fn to_prover<H>(&self, h: H) -> ProverState<H, StdRng>
201 where
202 H: DuplexSpongeInterface,
203 [u8; 64]: Encoding<[H::U]>,
204 S: Encoding<[H::U]>,
205 I: Encoding<[H::U]>,
206 {
207 let mut prover_state = ProverState::from(h);
208 prover_state.public_message(&self.protocol);
209 prover_state.public_message(&self.session.0);
210 prover_state.public_message(&self.instance.0);
211 prover_state
212 }
213
214 pub fn to_verifier<'ver, H>(&self, h: H, narg_string: &'ver [u8]) -> VerifierState<'ver, H>
215 where
216 H: DuplexSpongeInterface,
217 [u8; 64]: Encoding<[H::U]>,
218 S: Encoding<[H::U]>,
219 I: Encoding<[H::U]>,
220 {
221 let mut verifier_state = VerifierState::from_parts(h, narg_string);
222 verifier_state.public_message(&self.protocol);
223 verifier_state.public_message(&self.session.0);
224 verifier_state.public_message(&self.instance.0);
225 verifier_state
226 }
227}
228
229#[inline]
230#[must_use]
231pub fn protocol_id(args: Arguments) -> [u8; 64] {
232 if let Some(message) = args.as_str() {
233 return pad_identifier(message.as_bytes());
234 }
235
236 let formatted = alloc::fmt::format(args);
237 pad_identifier(formatted.as_bytes())
238}
239
240#[inline]
241#[must_use]
242pub fn session_id(args: Arguments) -> [u8; 64] {
243 if let Some(message) = args.as_str() {
244 return derive_session_id(message.as_bytes());
245 }
246
247 let formatted = alloc::fmt::format(args);
248 derive_session_id(formatted.as_bytes())
249}
250
251#[inline]
252#[doc(hidden)]
253#[must_use]
254pub fn session_id_from_str<S>(value: &S) -> [u8; 64]
255where
256 S: AsRef<str> + ?Sized,
257{
258 derive_session_id(value.as_ref().as_bytes())
259}
260
/// Copies `identifier` to the front of a 64-byte array; the remainder stays zero.
///
/// Because the padding is zero bytes, an identifier and the same identifier
/// with trailing zero bytes produce the same array.
///
/// # Panics
///
/// Panics if `identifier` is longer than 64 bytes.
fn pad_identifier(identifier: &[u8]) -> [u8; 64] {
    let len = identifier.len();
    assert!(len <= 64, "protocol identifier must fit in 64 bytes");

    let mut padded = [0u8; 64];
    let (head, _zero_tail) = padded.split_at_mut(len);
    head.copy_from_slice(identifier);
    padded
}
271
272fn derive_session_id(session: &[u8]) -> [u8; 64] {
273 let mut sponge = StdHash::from_protocol_id(pad_identifier(b"fiat-shamir/session-id"));
274 sponge.absorb(session);
275
276 let mut session_id = [0u8; 64];
277 sponge.squeeze(&mut session_id[32..]);
278 session_id
279}