use super::switchboard;
use ethers::core::{
abi::ParamType,
types::{
transaction::eip712::{make_type_hash, EIP712Domain},
Address, Bytes, U256,
},
};
use inflector::Inflector;
use std::str::FromStr;
use switchboard_common::EvmTransaction;
use syn::{
parse::Error, spanned::Spanned, Data, Expr, Fields, GenericArgument, Lit, PathArguments,
Result as SynResult, Type,
};
#[derive(
Clone,
::ethers::contract::EthAbiType,
::ethers::contract::EthAbiCodec,
Default,
Debug,
PartialEq,
Eq,
Hash,
)]
pub struct Transaction {
pub expiration_time_seconds: ::ethers::core::types::U256,
pub gas_limit: ::ethers::core::types::U256,
pub value: ::ethers::core::types::U256,
pub to: ::ethers::core::types::Address,
pub from: ::ethers::core::types::Address,
pub data: ::ethers::core::types::Bytes,
}
type Eip712Error = ethers::core::types::transaction::eip712::Eip712Error;
impl Transaction {
fn type_hash(&self) -> ::core::result::Result<[u8; 32], Eip712Error> {
let input: ::syn::DeriveInput = ::syn::parse_quote! { struct Transaction {
pub expiration_time_seconds: ::ethers::core::types::U256,
pub gas_limit: ::ethers::core::types::U256,
pub value: ::ethers::core::types::U256,
pub to: ::ethers::core::types::Address,
pub from: ::ethers::core::types::Address,
pub data: ::ethers::core::types::Bytes,
} };
let primary_type = input.clone().ident;
let parsed_fields = parse_fields(&input).unwrap();
let type_hash = make_type_hash(primary_type.to_string(), &parsed_fields);
Ok(type_hash)
}
#[inline]
fn domain_separator(
&self,
domain: ethers::core::types::transaction::eip712::EIP712Domain,
) -> ::core::result::Result<[u8; 32], Eip712Error> {
let domain_separator = domain.separator();
let _domain_str = serde_json::to_string(&domain).unwrap();
Ok(domain_separator)
}
fn struct_hash(&self) -> ::core::result::Result<[u8; 32], Eip712Error> {
let mut items = vec![ethers::core::abi::Token::Uint(
ethers::core::types::U256::from(&Self::type_hash(&self)?[..]),
)];
if let ethers::core::abi::Token::Tuple(tokens) =
ethers::core::abi::Tokenizable::into_token(::core::clone::Clone::clone(self))
{
items.reserve(tokens.len());
for token in tokens {
match &token {
ethers::core::abi::Token::Tuple(_t) => {
return Err(Eip712Error::NestedEip712StructNotImplemented);
}
_ => {
items.push(
ethers::core::types::transaction::eip712::encode_eip712_type(token),
);
}
}
}
}
let struct_hash = ethers::core::utils::keccak256(ethers::core::abi::encode(&items));
Ok(struct_hash)
}
pub fn encode_eip712(
&self,
domain: ethers::core::types::transaction::eip712::EIP712Domain,
) -> std::result::Result<[u8; 32], Eip712Error> {
let domain_separator = self.domain_separator(domain)?;
let struct_hash = self.struct_hash()?;
let digest_input = [&[0x19, 0x01], &domain_separator[..], &struct_hash[..]].concat();
Ok(ethers::core::utils::keccak256(digest_input))
}
}
/// Extracts `(camelCaseName, ParamType)` pairs, in declaration order, from a struct
/// with named fields.
///
/// A leading `r#` (raw identifier) is stripped before the name is camel-cased.
///
/// # Errors
/// Returns a `syn::Error` when any of the following holds:
/// - `input` is an enum or a union;
/// - the struct has tuple or unit fields;
/// - a field carries an attribute whose path contains `eip712` (nested structs are
///   unsupported);
/// - a field type cannot be mapped by `find_parameter_type`.
pub fn parse_fields(
    input: &syn::DeriveInput,
) -> SynResult<Vec<(String, ethers::core::abi::ParamType)>> {
    let named_fields = match &input.data {
        Data::Struct(data_struct) => match &data_struct.fields {
            Fields::Named(named) => named,
            _ => return Err(Error::new(input.span(), "unnamed fields are not supported")),
        },
        Data::Enum(data_enum) => {
            return Err(Error::new(
                data_enum.enum_token.span,
                "Eip712 is not derivable for enums",
            ))
        }
        Data::Union(data_union) => {
            return Err(Error::new(
                data_union.union_token.span,
                "Eip712 is not derivable for unions",
            ))
        }
    };

    named_fields
        .named
        .iter()
        .map(|field| {
            // Named fields always carry an identifier.
            let raw_ident = field.ident.as_ref().unwrap().to_string();
            let camel_name = raw_ident
                .strip_prefix("r#")
                .unwrap_or(&raw_ident)
                .to_camel_case();

            let nested_marker = field
                .attrs
                .iter()
                .find(|attr| attr.path().segments.iter().any(|seg| seg.ident == "eip712"));
            if let Some(attr) = nested_marker {
                return Err(Error::new(
                    attr.span(),
                    "nested Eip712 struct are not yet supported",
                ));
            }

            let param = find_parameter_type(&field.ty)?;
            Ok((camel_name, param))
        })
        .collect()
}
/// Maps a Rust type to the ABI `ParamType` used in EIP-712 type strings.
///
/// Supported shapes:
/// - `[T; N]` with an integer-literal `N` maps to `FixedArray(T, N)`, except
///   `[u8; 32]`, which maps to `FixedBytes(32)`.
/// - Any path with a `Vec<T, ..>` segment maps to `Array(T)`; the first generic
///   argument is used as the element type.
/// - `(A, B, ..)` maps to `Tuple([A, B, ..])`.
/// - A path whose sole identifier (or last segment) is `Address`, `Bytes`, `Uint8`,
///   `String`, `bool`, `usize` or `isize`, or any name accepted by `parse_param_type`.
///
/// # Errors
/// Returns a `syn::Error` spanning the offending type when it cannot be mapped. This
/// includes a `Vec` segment whose first generic argument is missing or is not a type;
/// that case previously hit a `debug_assert!`/`unwrap()` panic.
pub fn find_parameter_type(ty: &Type) -> core::result::Result<ParamType, Error> {
    const ERROR: &str = "Failed to derive proper ABI from array field";
    match ty {
        Type::Array(arr) => {
            let ty = find_parameter_type(&arr.elem)?;
            if let Expr::Lit(ref expr) = arr.len {
                if let Lit::Int(ref len) = expr.lit {
                    if let Ok(len) = len.base10_parse::<usize>() {
                        return match (ty, len) {
                            (ParamType::Uint(8), 32) => Ok(ParamType::FixedBytes(32)),
                            (ty, len) => Ok(ParamType::FixedArray(Box::new(ty), len)),
                        };
                    }
                }
            }
            Err(Error::new(arr.span(), ERROR))
        }
        Type::Path(ty) => {
            if let Some(segment) = ty.path.segments.iter().find(|s| s.ident == "Vec") {
                if let PathArguments::AngleBracketed(ref args) = segment.arguments {
                    // `Vec<T>` or `Vec<T, A>`: the element type is the first argument.
                    // An empty or non-type argument list falls through to the
                    // identifier lookup below, which reports an error instead of panicking.
                    if let Some(GenericArgument::Type(elem)) = args.args.first() {
                        return find_parameter_type(elem)
                            .map(|kind| ParamType::Array(Box::new(kind)));
                    }
                }
            }
            ty.path
                .get_ident()
                .or_else(|| ty.path.segments.last().map(|s| &s.ident))
                .and_then(|ident| match ident.to_string().as_str() {
                    "Address" => Some(ParamType::Address),
                    "Bytes" => Some(ParamType::Bytes),
                    "Uint8" => Some(ParamType::Uint(8)),
                    "String" => Some(ParamType::String),
                    "bool" => Some(ParamType::Bool),
                    "usize" => Some(ParamType::Uint(64)),
                    "isize" => Some(ParamType::Int(64)),
                    s => parse_param_type(s),
                })
                .ok_or_else(|| Error::new(ty.span(), ERROR))
        }
        Type::Tuple(ty) => ty
            .elems
            .iter()
            .map(find_parameter_type)
            .collect::<core::result::Result<Vec<_>, _>>()
            .map(ParamType::Tuple),
        _ => Err(Error::new(ty.span(), ERROR)),
    }
}
/// Maps a primitive-style Rust type name to an ABI `ParamType`, keyed on its first
/// character.
///
/// - `H<bits>` / `h<bits>` maps to `FixedBytes(bits / 8)`, e.g. `H160` -> `FixedBytes(20)`.
///   The bit count is integer-divided by 8 without further validation.
/// - `U<bits>` / `u<bits>` maps to `Uint(bits)`.
/// - `I<bits>` / `i<bits>` maps to `Int(bits)`.
///
/// Returns `None` for an empty string, any other leading character, or a remainder
/// that does not parse as a `usize`.
pub fn parse_param_type(s: &str) -> Option<ParamType> {
    let mut chars = s.chars();
    let prefix = chars.next()?;
    let width = chars.as_str();

    match prefix {
        'H' | 'h' => width
            .parse::<usize>()
            .ok()
            .map(|bits| ParamType::FixedBytes(bits / 8)),
        'U' | 'u' => width.parse::<usize>().ok().map(ParamType::Uint),
        'I' | 'i' => width.parse::<usize>().ok().map(ParamType::Int),
        _ => None,
    }
}
/// Converts a string-encoded `EvmTransaction` into the typed EIP-712 representation.
///
/// # Panics
/// `From` cannot report failure, so malformed input panics with a message naming the
/// offending field. This happens when:
/// - `gas_limit` or `value` is rejected by `U256::from_str_radix(_, 10)`;
/// - `to` or `from` is rejected by `Address::from_str`;
/// - `data` is rejected by `Bytes::from_str`.
impl From<&EvmTransaction> for Transaction {
    fn from(tx: &EvmTransaction) -> Self {
        Transaction {
            expiration_time_seconds: U256::from(tx.expiration_time_seconds),
            gas_limit: U256::from_str_radix(&tx.gas_limit, 10)
                .expect("EvmTransaction.gas_limit must be a base-10 unsigned integer"),
            value: U256::from_str_radix(&tx.value, 10)
                .expect("EvmTransaction.value must be a base-10 unsigned integer"),
            to: Address::from_str(&tx.to).expect("EvmTransaction.to must be a valid address"),
            from: Address::from_str(&tx.from)
                .expect("EvmTransaction.from must be a valid address"),
            data: Bytes::from_str(&tx.data)
                .expect("EvmTransaction.data must be valid hex-encoded bytes"),
        }
    }
}
/// Copies a `switchboard::Transaction` field-by-field into the local EIP-712 type.
///
/// Only `data` needs an explicit clone; the other fields are copied out of the borrow.
impl From<&switchboard::Transaction> for Transaction {
    fn from(tx: &switchboard::Transaction) -> Self {
        let switchboard::Transaction {
            expiration_time_seconds,
            gas_limit,
            value,
            to,
            from,
            data,
            ..
        } = tx;

        Self {
            expiration_time_seconds: *expiration_time_seconds,
            gas_limit: *gas_limit,
            value: *value,
            to: *to,
            from: *from,
            data: data.clone(),
        }
    }
}
pub fn get_transaction_hash(
name: String,
version: String,
chain_id: u64,
verifying_contract: Address,
transaction: switchboard::Transaction,
) -> std::result::Result<[u8; 32], Eip712Error> {
let tx = Transaction::from(&transaction);
let domain = EIP712Domain {
name: Some(name.into()),
version: Some(version.into()),
chain_id: Some(chain_id.into()),
verifying_contract: Some(verifying_contract),
salt: None,
};
tx.encode_eip712(domain.clone())
}