use crate::{
Opcode,
Operand,
RegistersCaller,
RegistersCallerCircuit,
RegistersLoad,
RegistersLoadCircuit,
RegistersStore,
RegistersStoreCircuit,
StackMatches,
StackProgram,
};
use console::{
network::prelude::*,
program::{
Entry,
EntryType,
Identifier,
Literal,
LiteralType,
Owner,
Plaintext,
PlaintextType,
Record,
Register,
RegisterType,
Value,
ValueType,
},
types::Field,
};
use indexmap::IndexMap;
/// The `cast` instruction.
///
/// Assembles the values of `operands` into a single value of type `register_type`
/// (a struct or a record) and writes it to `destination`.
/// Textual form: `cast <operand> ... into <destination> as <register_type>`.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Cast<N: Network> {
    /// The input operands, in the order they are assigned to members/entries.
    operands: Vec<Operand<N>>,
    /// The register that receives the constructed value.
    destination: Register<N>,
    /// The type being constructed.
    register_type: RegisterType<N>,
}
impl<N: Network> Cast<N> {
    /// Returns the opcode identifying this instruction (`Opcode::Cast`).
    #[inline]
    pub const fn opcode() -> Opcode {
        Opcode::Cast
    }

    /// Returns the operands that are cast into the destination type, in order.
    #[inline]
    pub fn operands(&self) -> &[Operand<N>] {
        self.operands.as_slice()
    }

    /// Returns the registers written by this instruction.
    ///
    /// A cast always writes exactly one register, so the vector has a single element.
    #[inline]
    pub fn destinations(&self) -> Vec<Register<N>> {
        let only_destination = self.destination.clone();
        vec![only_destination]
    }

    /// Returns the type that the operands are cast into.
    #[inline]
    pub const fn register_type(&self) -> &RegisterType<N> {
        &self.register_type
    }
}
impl<N: Network> Cast<N> {
    /// Evaluates the `cast` instruction on console (non-circuit) values.
    ///
    /// Loads every operand from `registers`, assembles them into a value of
    /// `self.register_type`, and stores that value in `self.destination`.
    ///
    /// - Struct targets are delegated to `cast_to_struct`.
    /// - Record targets take the owner address as the first operand, followed by one operand per
    ///   record entry. The record nonce is derived by hashing the transition view key (`tvk`)
    ///   together with the destination register's locator.
    ///
    /// # Errors
    /// Fails if an operand cannot be loaded, if the target is a literal or an external record, if
    /// the operand count or operand types do not match the target definition, or if the `tvk`
    /// cannot be obtained from `registers`.
    #[inline]
    pub fn evaluate(
        &self,
        stack: &(impl StackMatches<N> + StackProgram<N>),
        registers: &mut (impl RegistersCaller<N> + RegistersLoad<N> + RegistersStore<N>),
    ) -> Result<()> {
        // Load the operand values.
        let inputs: Vec<_> = self.operands.iter().map(|operand| registers.load(stack, operand)).try_collect()?;

        match self.register_type {
            RegisterType::Plaintext(PlaintextType::Literal(..)) => bail!("Casting to literal is currently unsupported"),
            RegisterType::Plaintext(PlaintextType::Struct(struct_name)) => {
                self.cast_to_struct(stack, registers, struct_name, inputs)
            }
            RegisterType::Record(record_name) => {
                if inputs.len() < N::MIN_RECORD_ENTRIES {
                    bail!("Casting to a record requires at least {} operand", N::MIN_RECORD_ENTRIES)
                }
                let record_type = stack.program().get_record(&record_name)?;
                // One operand for the owner, followed by one operand per record entry.
                if inputs.len() != record_type.entries().len() + 1 {
                    bail!(
                        "Casting to the record {} requires {} operands, but {} were provided",
                        record_type.name(),
                        record_type.entries().len() + 1,
                        inputs.len()
                    )
                }
                // The first operand must be an address. It becomes the record owner, with the
                // visibility declared by the record definition.
                let owner: Owner<N, Plaintext<N>> = match &inputs[0] {
                    Value::Plaintext(Plaintext::Literal(Literal::Address(owner), ..)) => {
                        match record_type.owner().is_public() {
                            true => Owner::Public(*owner),
                            false => Owner::Private(Plaintext::Literal(Literal::Address(*owner), Default::default())),
                        }
                    }
                    _ => bail!("Invalid record 'owner'"),
                };
                // Type-check each remaining operand against its entry, then wrap it with the
                // entry's declared visibility.
                // NOTE(review): the count check above reserves exactly one owner operand, while
                // this skips `N::MIN_RECORD_ENTRIES`; the two agree only if that constant is 1.
                let mut entries = IndexMap::new();
                for (entry, (entry_name, entry_type)) in
                    inputs.iter().skip(N::MIN_RECORD_ENTRIES).zip_eq(record_type.entries())
                {
                    let register_type = RegisterType::from(ValueType::from(*entry_type));
                    let plaintext = match entry {
                        Value::Plaintext(plaintext) => {
                            stack.matches_register_type(&Value::Plaintext(plaintext.clone()), &register_type)?;
                            plaintext.clone()
                        }
                        Value::Record(..) => bail!("Casting a record into a record entry is illegal"),
                    };
                    match entry_type {
                        EntryType::Constant(..) => entries.insert(*entry_name, Entry::Constant(plaintext)),
                        EntryType::Public(..) => entries.insert(*entry_name, Entry::Public(plaintext)),
                        EntryType::Private(..) => entries.insert(*entry_name, Entry::Private(plaintext)),
                    };
                }
                // Derive the nonce deterministically from the tvk and the destination locator.
                let index = Field::from_u64(self.destination.locator());
                let randomizer = N::hash_to_scalar_psd2(&[registers.tvk()?, index])?;
                let nonce = N::g_scalar_multiply(&randomizer);
                let record = Record::<N, Plaintext<N>>::from_plaintext(owner, entries, nonce)?;
                registers.store(stack, &self.destination, Value::Record(record))
            }
            RegisterType::ExternalRecord(_locator) => {
                bail!("Illegal operation: Cannot cast to an external record.")
            }
        }
    }

    /// Executes the `cast` instruction on circuit values.
    ///
    /// This mirrors `evaluate`. Operand types are checked against the ejected console values,
    /// struct member and record entry names are injected as circuit constants, and the record
    /// nonce is derived in-circuit from the circuit `tvk` and the (constant) destination locator.
    ///
    /// # Errors
    /// Fails under the same conditions as `evaluate`.
    #[inline]
    pub fn execute<A: circuit::Aleo<Network = N>>(
        &self,
        stack: &(impl StackMatches<N> + StackProgram<N>),
        registers: &mut (impl RegistersCallerCircuit<N, A> + RegistersLoadCircuit<N, A> + RegistersStoreCircuit<N, A>),
    ) -> Result<()> {
        use circuit::{Eject, Inject};

        // Load the operands as circuit values.
        let inputs: Vec<_> =
            self.operands.iter().map(|operand| registers.load_circuit(stack, operand)).try_collect()?;

        match self.register_type {
            RegisterType::Plaintext(PlaintextType::Literal(..)) => bail!("Casting to literal is currently unsupported"),
            RegisterType::Plaintext(PlaintextType::Struct(struct_name)) => {
                if inputs.len() < N::MIN_STRUCT_ENTRIES {
                    bail!("Casting to a struct requires at least {} operand", N::MIN_STRUCT_ENTRIES)
                }
                let struct_ = stack.program().get_struct(&struct_name)?;
                if inputs.len() != struct_.members().len() {
                    bail!(
                        "Casting to the struct {} requires {} operands, but {} were provided",
                        struct_.name(),
                        struct_.members().len(),
                        inputs.len()
                    )
                }
                let mut members = IndexMap::new();
                for (member, (member_name, member_type)) in inputs.iter().zip_eq(struct_.members()) {
                    let register_type = RegisterType::Plaintext(*member_type);
                    let plaintext = match member {
                        circuit::Value::Plaintext(plaintext) => {
                            // Type-check against the ejected console value.
                            stack.matches_register_type(
                                &circuit::Value::Plaintext(plaintext.clone()).eject_value(),
                                &register_type,
                            )?;
                            plaintext.clone()
                        }
                        circuit::Value::Record(..) => {
                            bail!("Casting a record into a struct member is illegal")
                        }
                    };
                    // Member names come from the program definition, so they are injected as constants.
                    members.insert(circuit::Identifier::constant(*member_name), plaintext);
                }
                let struct_ = circuit::Plaintext::Struct(members, Default::default());
                registers.store_circuit(stack, &self.destination, circuit::Value::Plaintext(struct_))
            }
            RegisterType::Record(record_name) => {
                if inputs.len() < N::MIN_RECORD_ENTRIES {
                    bail!("Casting to a record requires at least {} operand", N::MIN_RECORD_ENTRIES)
                }
                let record_type = stack.program().get_record(&record_name)?;
                // One operand for the owner, followed by one operand per record entry.
                if inputs.len() != record_type.entries().len() + 1 {
                    bail!(
                        "Casting to the record {} requires {} operands, but {} were provided",
                        record_type.name(),
                        record_type.entries().len() + 1,
                        inputs.len()
                    )
                }
                // The first operand must be an address; it becomes the (public or private) owner.
                let owner: circuit::Owner<A, circuit::Plaintext<A>> = match &inputs[0] {
                    circuit::Value::Plaintext(circuit::Plaintext::Literal(circuit::Literal::Address(owner), ..)) => {
                        match record_type.owner().is_public() {
                            true => circuit::Owner::Public(owner.clone()),
                            false => circuit::Owner::Private(circuit::Plaintext::Literal(
                                circuit::Literal::Address(owner.clone()),
                                Default::default(),
                            )),
                        }
                    }
                    _ => bail!("Invalid record 'owner'"),
                };
                // NOTE(review): same `MIN_RECORD_ENTRIES` vs `+ 1` assumption as in `evaluate`.
                let mut entries = IndexMap::new();
                for (entry, (entry_name, entry_type)) in
                    inputs.iter().skip(N::MIN_RECORD_ENTRIES).zip_eq(record_type.entries())
                {
                    let register_type = RegisterType::from(ValueType::from(*entry_type));
                    let plaintext = match entry {
                        circuit::Value::Plaintext(plaintext) => {
                            stack.matches_register_type(
                                &circuit::Value::Plaintext(plaintext.clone()).eject_value(),
                                &register_type,
                            )?;
                            plaintext.clone()
                        }
                        circuit::Value::Record(..) => bail!("Casting a record into a record entry is illegal"),
                    };
                    let entry_name = circuit::Identifier::constant(*entry_name);
                    match entry_type {
                        EntryType::Constant(..) => entries.insert(entry_name, circuit::Entry::Constant(plaintext)),
                        EntryType::Public(..) => entries.insert(entry_name, circuit::Entry::Public(plaintext)),
                        EntryType::Private(..) => entries.insert(entry_name, circuit::Entry::Private(plaintext)),
                    };
                }
                // Derive the nonce in-circuit from the tvk and the constant destination locator.
                let index = circuit::Field::constant(Field::from_u64(self.destination.locator()));
                let randomizer = A::hash_to_scalar_psd2(&[registers.tvk_circuit()?, index]);
                let nonce = A::g_scalar_multiply(&randomizer);
                let record = circuit::Record::<A, circuit::Plaintext<A>>::from_plaintext(owner, entries, nonce)?;
                registers.store_circuit(stack, &self.destination, circuit::Value::Record(record))
            }
            RegisterType::ExternalRecord(_locator) => {
                bail!("Illegal operation: Cannot cast to an external record.")
            }
        }
    }

    /// Finalizes the `cast` instruction on console values inside a finalize block.
    ///
    /// Only struct targets are permitted here. Records and external records are rejected.
    ///
    /// # Errors
    /// Fails if an operand cannot be loaded, if the target is not a struct, or if the operands
    /// do not match the struct definition.
    #[inline]
    pub fn finalize(
        &self,
        stack: &(impl StackMatches<N> + StackProgram<N>),
        registers: &mut (impl RegistersLoad<N> + RegistersStore<N>),
    ) -> Result<()> {
        let inputs: Vec<_> = self.operands.iter().map(|operand| registers.load(stack, operand)).try_collect()?;

        match self.register_type {
            RegisterType::Plaintext(PlaintextType::Literal(..)) => bail!("Casting to literal is currently unsupported"),
            RegisterType::Plaintext(PlaintextType::Struct(struct_name)) => {
                self.cast_to_struct(stack, registers, struct_name, inputs)
            }
            RegisterType::Record(_record_name) => {
                bail!("Illegal operation: Cannot cast to a record in a finalize block.")
            }
            RegisterType::ExternalRecord(_locator) => {
                bail!("Illegal operation: Cannot cast to an external record.")
            }
        }
    }

    /// Statically checks the operand types and returns the output type of this instruction.
    ///
    /// `input_types[i]` is the type of `self.operands[i]`. On success the result is a single
    /// element, `self.register_type`.
    ///
    /// # Errors
    /// Fails if the number of input types differs from the number of operands, if the target
    /// is a literal or an external record, or if any input type does not match the
    /// corresponding struct member or record entry.
    #[inline]
    pub fn output_types(
        &self,
        stack: &impl StackProgram<N>,
        input_types: &[RegisterType<N>],
    ) -> Result<Vec<RegisterType<N>>> {
        // The instruction expects one input type per operand.
        ensure!(
            input_types.len() == self.operands.len(),
            "Instruction '{}' expects {} operands, found {} operands",
            Self::opcode(),
            self.operands.len(),
            input_types.len(),
        );

        match self.register_type {
            RegisterType::Plaintext(PlaintextType::Literal(..)) => bail!("Casting to literal is currently unsupported"),
            RegisterType::Plaintext(PlaintextType::Struct(struct_name)) => {
                let struct_ = stack.program().get_struct(&struct_name)?;
                ensure!(
                    input_types.len() >= N::MIN_STRUCT_ENTRIES,
                    "Casting to a struct requires at least {} operand",
                    N::MIN_STRUCT_ENTRIES
                );
                ensure!(
                    input_types.len() == struct_.members().len(),
                    "Casting to the struct {} requires {} operands, but {} were provided",
                    struct_.name(),
                    struct_.members().len(),
                    input_types.len()
                );
                for ((_, member_type), input_type) in struct_.members().iter().zip_eq(input_types) {
                    match input_type {
                        RegisterType::Plaintext(plaintext_type) => {
                            ensure!(
                                member_type == plaintext_type,
                                "Struct '{struct_name}' member type mismatch: expected '{member_type}', found '{plaintext_type}'"
                            )
                        }
                        RegisterType::Record(record_name) => bail!(
                            "Struct '{struct_name}' member type mismatch: expected '{member_type}', found record '{record_name}'"
                        ),
                        RegisterType::ExternalRecord(locator) => bail!(
                            "Struct '{struct_name}' member type mismatch: expected '{member_type}', found external record '{locator}'"
                        ),
                    }
                }
            }
            RegisterType::Record(record_name) => {
                let record = stack.program().get_record(&record_name)?;
                ensure!(
                    input_types.len() >= N::MIN_RECORD_ENTRIES,
                    "Casting to a record requires at least {} operands",
                    N::MIN_RECORD_ENTRIES
                );
                ensure!(
                    input_types.len() == record.entries().len() + 1,
                    "Casting to the record {} requires {} operands, but {} were provided",
                    record.name(),
                    record.entries().len() + 1,
                    input_types.len()
                );
                ensure!(
                    input_types[0] == RegisterType::Plaintext(PlaintextType::Literal(LiteralType::Address)),
                    "Casting to a record requires the first operand to be an address"
                );
                for (input_type, (_, entry_type)) in
                    input_types.iter().skip(N::MIN_RECORD_ENTRIES).zip_eq(record.entries())
                {
                    match input_type {
                        RegisterType::Plaintext(plaintext_type) => match entry_type {
                            EntryType::Constant(entry_type)
                            | EntryType::Public(entry_type)
                            | EntryType::Private(entry_type) => {
                                ensure!(
                                    entry_type == plaintext_type,
                                    "Record '{record_name}' entry type mismatch: expected '{entry_type}', found '{plaintext_type}'"
                                )
                            }
                        },
                        // Bind the input's record name separately so the message still names
                        // the target record (`record_name`) instead of shadowing it.
                        RegisterType::Record(input_record_name) => bail!(
                            "Record '{record_name}' entry type mismatch: expected '{entry_type}', found record '{input_record_name}'"
                        ),
                        RegisterType::ExternalRecord(locator) => bail!(
                            "Record '{record_name}' entry type mismatch: expected '{entry_type}', found external record '{locator}'"
                        ),
                    }
                }
            }
            RegisterType::ExternalRecord(_locator) => {
                bail!("Illegal operation: Cannot cast to an external record.")
            }
        }

        Ok(vec![self.register_type])
    }
}
impl<N: Network> Cast<N> {
fn cast_to_struct(
&self,
stack: &(impl StackMatches<N> + StackProgram<N>),
registers: &mut impl RegistersStore<N>,
struct_name: Identifier<N>,
inputs: Vec<Value<N>>,
) -> Result<()> {
if inputs.len() < N::MIN_STRUCT_ENTRIES {
bail!("Casting to a struct requires at least {} operand", N::MIN_STRUCT_ENTRIES)
}
let struct_ = stack.program().get_struct(&struct_name)?;
if inputs.len() != struct_.members().len() {
bail!(
"Casting to the struct {} requires {} operands, but {} were provided",
struct_.name(),
struct_.members().len(),
inputs.len()
)
}
let mut members = IndexMap::new();
for (member, (member_name, member_type)) in inputs.iter().zip_eq(struct_.members()) {
let plaintext = match member {
Value::Plaintext(plaintext) => {
stack.matches_plaintext(plaintext, member_type)?;
plaintext.clone()
}
Value::Record(..) => bail!("Casting a record into a struct member is illegal"),
};
members.insert(*member_name, plaintext);
}
let struct_ = Plaintext::Struct(members, Default::default());
registers.store(stack, &self.destination, Value::Plaintext(struct_))
}
}
impl<N: Network> Parser for Cast<N> {
#[inline]
fn parse(string: &str) -> ParserResult<Self> {
fn parse_operand<N: Network>(string: &str) -> ParserResult<Operand<N>> {
let (string, _) = Sanitizer::parse_whitespaces(string)?;
Operand::parse(string)
}
let (string, _) = tag(*Self::opcode())(string)?;
let (string, operands) = many1(parse_operand)(string)?;
let (string, _) = Sanitizer::parse_whitespaces(string)?;
let (string, _) = tag("into")(string)?;
let (string, _) = Sanitizer::parse_whitespaces(string)?;
let (string, destination) = Register::parse(string)?;
let (string, _) = Sanitizer::parse_whitespaces(string)?;
let (string, _) = tag("as")(string)?;
let (string, _) = Sanitizer::parse_whitespaces(string)?;
let (string, register_type) = RegisterType::parse(string)?;
let max_operands = match register_type {
RegisterType::Plaintext(_) => N::MAX_STRUCT_ENTRIES,
RegisterType::Record(_) | RegisterType::ExternalRecord(_) => N::MAX_RECORD_ENTRIES,
};
match operands.len() <= max_operands {
true => Ok((string, Self { operands, destination, register_type })),
false => {
map_res(fail, |_: ParserResult<Self>| Err(error("Failed to parse 'cast' opcode: too many operands")))(
string,
)
}
}
}
}
impl<N: Network> FromStr for Cast<N> {
    type Err = Error;

    /// Parses a `Cast` from `string` and requires the entire input to be consumed.
    #[inline]
    fn from_str(string: &str) -> Result<Self> {
        let (remainder, object) = match Self::parse(string) {
            Ok(parsed) => parsed,
            Err(error) => bail!("Failed to parse string. {error}"),
        };
        ensure!(remainder.is_empty(), "Failed to parse string. Found invalid character in: \"{remainder}\"");
        Ok(object)
    }
}
impl<N: Network> Debug for Cast<N> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
Display::fmt(self, f)
}
}
impl<N: Network> Display for Cast<N> {
    /// Writes `cast <operand> ... into <destination> as <register_type>`.
    ///
    /// Logs to stderr and returns `fmt::Error` if the operand count is zero or exceeds the
    /// maximum for the target type.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let max_operands = match self.register_type() {
            RegisterType::Record(_) | RegisterType::ExternalRecord(_) => N::MAX_RECORD_ENTRIES,
            RegisterType::Plaintext(_) => N::MAX_STRUCT_ENTRIES,
        };

        let count = self.operands.len();
        if !(1..=max_operands).contains(&count) {
            eprintln!("The number of operands must be nonzero and <= {max_operands}");
            return Err(fmt::Error);
        }

        write!(f, "{} ", Self::opcode())?;
        for operand in &self.operands {
            write!(f, "{operand} ")?;
        }
        write!(f, "into {} as {}", self.destination, self.register_type)
    }
}
impl<N: Network> FromBytes for Cast<N> {
fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
let num_operands = u8::read_le(&mut reader)? as usize;
if num_operands.is_zero() || num_operands > N::MAX_RECORD_ENTRIES {
return Err(error(format!("The number of operands must be nonzero and <= {}", N::MAX_RECORD_ENTRIES)));
}
let mut operands = Vec::with_capacity(num_operands);
for _ in 0..num_operands {
operands.push(Operand::read_le(&mut reader)?);
}
let destination = Register::read_le(&mut reader)?;
let register_type = RegisterType::read_le(&mut reader)?;
let max_operands = match register_type {
RegisterType::Plaintext(_) => N::MAX_STRUCT_ENTRIES,
RegisterType::Record(_) | RegisterType::ExternalRecord(_) => N::MAX_RECORD_ENTRIES,
};
if num_operands.is_zero() || num_operands > max_operands {
return Err(error(format!("The number of operands must be nonzero and <= {max_operands}")));
}
Ok(Self { operands, destination, register_type })
}
}
impl<N: Network> ToBytes for Cast<N> {
    /// Writes the `Cast` as little-endian bytes.
    ///
    /// Layout: a `u8` operand count, the operands, the destination register, then the target
    /// register type. Fails if the operand count is zero or exceeds the target type's maximum.
    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
        let max_operands = match self.register_type() {
            RegisterType::Record(_) | RegisterType::ExternalRecord(_) => N::MAX_RECORD_ENTRIES,
            RegisterType::Plaintext(_) => N::MAX_STRUCT_ENTRIES,
        };

        let count = self.operands.len();
        if !(1..=max_operands).contains(&count) {
            return Err(error(format!("The number of operands must be nonzero and <= {max_operands}")));
        }

        // The check above bounds `count` by `max_operands`.
        // NOTE(review): `as u8` is lossless only if max_operands <= 255 — confirm.
        (count as u8).write_le(&mut writer)?;
        for operand in &self.operands {
            operand.write_le(&mut writer)?;
        }
        self.destination.write_le(&mut writer)?;
        self.register_type.write_le(&mut writer)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use console::{network::Testnet3, program::Identifier};

    type CurrentNetwork = Testnet3;

    /// Returns `"cast r0 r1 ... r{count - 1} "` together with the register operands it encodes.
    fn cast_prefix(count: usize) -> (String, Vec<Operand<CurrentNetwork>>) {
        let registers: String = (0..count).map(|i| format!("r{i} ")).collect();
        let operands = (0..count).map(|i| Operand::Register(Register::Locator(i as u64))).collect();
        (format!("cast {registers}"), operands)
    }

    #[test]
    fn test_parse() {
        let (string, cast) =
            Cast::<CurrentNetwork>::parse("cast r0.owner r0.token_amount into r1 as token.record").unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        assert_eq!(cast.operands.len(), 2, "The number of operands is incorrect");

        let member = |name: &str| -> Operand<CurrentNetwork> {
            Operand::Register(Register::Member(0, vec![Identifier::from_str(name).unwrap()]))
        };
        assert_eq!(cast.operands[0], member("owner"), "The first operand is incorrect");
        assert_eq!(cast.operands[1], member("token_amount"), "The second operand is incorrect");
        assert_eq!(cast.destination, Register::Locator(1), "The destination register is incorrect");
        assert_eq!(
            cast.register_type,
            RegisterType::Record(Identifier::from_str("token").unwrap()),
            "The value type is incorrect"
        );
    }

    #[test]
    fn test_parse_cast_into_plaintext_max_operands() {
        let max = CurrentNetwork::MAX_STRUCT_ENTRIES;
        let (mut string, operands) = cast_prefix(max);
        string.push_str(&format!("into r{max} as foo"));

        let (string, cast) = Cast::<CurrentNetwork>::parse(&string).unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        assert_eq!(cast.operands.len(), max, "The number of operands is incorrect");
        assert_eq!(cast.operands, operands, "The operands are incorrect");
        assert_eq!(cast.destination, Register::Locator(max as u64), "The destination register is incorrect");
        assert_eq!(
            cast.register_type,
            RegisterType::Plaintext(PlaintextType::Struct(Identifier::from_str("foo").unwrap())),
            "The value type is incorrect"
        );
    }

    #[test]
    fn test_parse_cast_into_record_max_operands() {
        let max = CurrentNetwork::MAX_RECORD_ENTRIES;
        let (mut string, operands) = cast_prefix(max);
        string.push_str(&format!("into r{max} as token.record"));

        let (string, cast) = Cast::<CurrentNetwork>::parse(&string).unwrap();
        assert!(string.is_empty(), "Parser did not consume all of the string: '{string}'");
        assert_eq!(cast.operands.len(), max, "The number of operands is incorrect");
        assert_eq!(cast.operands, operands, "The operands are incorrect");
        assert_eq!(cast.destination, Register::Locator(max as u64), "The destination register is incorrect");
        assert_eq!(
            cast.register_type,
            RegisterType::Record(Identifier::from_str("token").unwrap()),
            "The value type is incorrect"
        );
    }

    #[test]
    fn test_parse_cast_into_record_too_many_operands() {
        let too_many = CurrentNetwork::MAX_RECORD_ENTRIES + 1;
        let (mut string, _) = cast_prefix(too_many);
        string.push_str(&format!("into r{too_many} as token.record"));
        assert!(Cast::<CurrentNetwork>::parse(&string).is_err(), "Parser did not error");
    }

    #[test]
    fn test_parse_cast_into_plaintext_too_many_operands() {
        let too_many = CurrentNetwork::MAX_STRUCT_ENTRIES + 1;
        let (mut string, _) = cast_prefix(too_many);
        string.push_str(&format!("into r{too_many} as foo"));
        assert!(Cast::<CurrentNetwork>::parse(&string).is_err(), "Parser did not error");
    }
}