struqture_py_macros/
lib.rs

1// Copyright © 2021-2023 HQS Quantum Simulations GmbH. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
4// in compliance with the License. You may obtain a copy of the License at
5//
6//     http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software distributed under the
9// License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
10// express or implied. See the License for the specific language governing permissions and
11// limitations under the License.
12
13use quote::format_ident;
14use std::collections::HashSet;
15use syn::parse::{Parse, ParseStream};
16use syn::punctuated::Punctuated;
17use syn::{Ident, Token, Type, TypePath};
18
19mod product_wrapper;
20use product_wrapper::productwrapper;
21
22/// Attribute macro for constructing the pyo3 implementation for mixed indices.
23#[proc_macro_attribute]
24pub fn product_wrapper(
25    metadata: proc_macro::TokenStream,
26    input: proc_macro::TokenStream,
27) -> proc_macro::TokenStream {
28    productwrapper(metadata, input)
29}
30
31mod noiseless_system_wrapper;
32use noiseless_system_wrapper::noiselesswrapper;
33
34/// Attribute macro for constructing the pyo3 implementation for noiseless systems.
35#[proc_macro_attribute]
36pub fn noiseless_system_wrapper(
37    metadata: proc_macro::TokenStream,
38    input: proc_macro::TokenStream,
39) -> proc_macro::TokenStream {
40    noiselesswrapper(metadata, input)
41}
42
43mod noisy_system_wrapper;
44use noisy_system_wrapper::noisywrapper;
45
46/// Attribute macro for constructing the pyo3 implementation for noisy systems.
47#[proc_macro_attribute]
48pub fn noisy_system_wrapper(
49    metadata: proc_macro::TokenStream,
50    input: proc_macro::TokenStream,
51) -> proc_macro::TokenStream {
52    noisywrapper(metadata, input)
53}
54
55mod mappings;
56use mappings::mappings_macro;
57
58/// Attribute macro for constructing the pyo3 implementation for mappings.
59#[proc_macro_attribute]
60pub fn mappings(
61    metadata: proc_macro::TokenStream,
62    input: proc_macro::TokenStream,
63) -> proc_macro::TokenStream {
64    mappings_macro(metadata, input)
65}
66
// Helper functions

/// Parsed arguments of an attribute macro: the names of the identifiers given in the
/// attribute's comma-separated argument list.
///
/// The names are kept in a `HashSet`, so duplicates collapse into one entry and no order is
/// preserved. Membership is queried with `contains`.
// NOTE(review): an earlier comment said these arguments are used to identify structs belonging
// to enums; confirm against the wrapper modules that consume this type.
#[derive(Debug)]
struct AttributeMacroArguments(HashSet<String>);
71
72impl AttributeMacroArguments {
73    pub fn contains(&self, st: &str) -> bool {
74        self.0.contains(st)
75    }
76    pub fn _ids(&self) -> Vec<Ident> {
77        self.0
78            .clone()
79            .into_iter()
80            .map(|s| format_ident!("{}", s))
81            .collect()
82    }
83}
84
85fn strip_python_wrapper_name(ident: &Type) -> (String, proc_macro2::Ident) {
86    // get name of the interal struct (not the wrapper)
87    let type_path = match ident.clone() {
88        Type::Path(TypePath { path: p, .. }) => p,
89        _ => panic!("Trait only supports newtype variants with normal types of form path"),
90    };
91    let type_string = match type_path.get_ident() {
92        Some(ident_path) => ident_path.to_string(),
93        _ => match type_path.segments.last() {
94            Some(segment) => segment.ident.to_string(),
95            None => panic!("Can't extract string."),
96        },
97    };
98    // Cut off "Wrapper" at the end of the Impl name
99    let struct_name = type_string
100        .as_str()
101        .strip_suffix("Wrapper")
102        .expect("Not conform to Wrapper naming scheme.");
103    let struct_ident = format_ident!("{}", struct_name);
104    (struct_name.to_string(), struct_ident)
105}
106
107impl Parse for AttributeMacroArguments {
108    fn parse(input: ParseStream) -> syn::parse::Result<Self> {
109        // Parse arguments as comma separated list of idents
110        let arguments = Punctuated::<Ident, Token![,]>::parse_terminated(input)?;
111        Ok(Self(
112            arguments.into_iter().map(|id| id.to_string()).collect(),
113        ))
114    }
115}