Requirement

Today I came across a requirement to generate a random instance of an enumeration type.

Unlike in Python, this is not convenient in Rust: it requires implementing a specific trait by hand. The simplest idea is to number the enum's variants, generate a random number, construct the corresponding variant, and, if that variant carries data, generate that data randomly as well (recursively).

use rand::distributions::{Distribution, Standard};

impl Distribution<Instruction> for Standard {
    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> Instruction {
        match rng.gen_range(0..459) {
            0 => Instruction::Unreachable,
            1 => Instruction::Nop,
            2 => Instruction::Block(BlockType::FunctionType(rng.gen())),
            3 => Instruction::Catch(rng.gen()),
            // ... estimated to run well over 2,000 lines
            _ => unreachable!(),
        }
    }
}
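
To make the manual approach concrete, here is a minimal, dependency-free sketch on a four-variant toy enum. The `XorShift` generator, the `Random` trait, and the `Instr` enum are hypothetical stand-ins for rand's `Rng`, `Standard: Distribution<T>`, and the real `Instruction`; they exist only so the example compiles with the standard library alone.

```rust
// Hypothetical xorshift64 generator standing in for `rand::Rng`,
// so this sketch compiles with the standard library alone.
struct XorShift(u64);

impl XorShift {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Value in `0..n` taken from the high bits (modulo bias ignored for brevity).
    fn below(&mut self, n: u64) -> u64 {
        (self.next_u64() >> 32) % n
    }
}

/// Hypothetical trait playing the role of `Standard: Distribution<T>`.
trait Random {
    fn random(rng: &mut XorShift) -> Self;
}

impl Random for u32 {
    fn random(rng: &mut XorShift) -> Self {
        rng.next_u64() as u32
    }
}

/// Toy four-variant stand-in for the real 459-variant `Instruction`.
#[derive(Debug)]
enum Instr {
    Unreachable,
    Nop,
    Catch(u32),
    Br { depth: u32 },
}

impl Random for Instr {
    fn random(rng: &mut XorShift) -> Self {
        // Number the variants, draw an index, build that variant, and fill
        // any payload by calling the field type's own `Random` impl.
        match rng.below(4) {
            0 => Instr::Unreachable,
            1 => Instr::Nop,
            2 => Instr::Catch(u32::random(rng)),
            3 => Instr::Br { depth: u32::random(rng) },
            _ => unreachable!(),
        }
    }
}
```

Every additional variant means another numbered arm plus a bump to the range bound; multiplied by 459, that is exactly the tedium described next.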

The requirement itself is simple. The problem is that this enum has a huge number of variants, 459 of them, so writing the impl the traditional way would take at least half a day of tedious work. As the snippet above shows, even this simple function would run to thousands of lines for this enum.

I hate this kind of simple but laborious work so much that I turned to Rust's procedural macros.

Procedural Macros

When I first learned Rust, I learned about declarative macros (macro_rules!), and I have used them in other projects. But they are essentially templates that generate code from the tokens you hand them, so they don't meet my needs this time. Procedural macros, on the other hand, are a great tool for code generation: they are functions that take the code itself as input, parse and process it as an abstract syntax tree, and can therefore implement very complex logic.

Writing procedural macros is more mentally demanding, and writing one that generates this much code might cost me a few hairs. But I'd rather lose a few hairs than waste my life writing a few thousand lines of boring code. I was also pleasantly surprised to find that rand used to ship a similar procedural macro for arbitrary structs, tuples, and enums in version 0.5. It is no longer maintained, but I could learn from it.

Define the #[derive] macro

My requirement is to automatically implement impl Distribution<Instruction> for Standard from the variant information of Instruction, so I need to write a #[derive] macro and apply it to Instruction.

#[derive(Debug, Rand)]
pub enum Instruction {...}

First we define the #[derive] procedural macro named Rand. This function receives the token stream of Instruction, parses it into an abstract syntax tree (AST), and finally generates a new token stream from the AST and our logic; that new stream is the final generated code.

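use proc_macro::TokenStream;
use syn::{parse_macro_input, DeriveInput};

// Entry point: the compiler calls this for every #[derive(Rand)].
#[proc_macro_derive(Rand)]
pub fn rand_derive(input: TokenStream) -> TokenStream {
    // Parse the annotated item into a syn AST.
    let ast = parse_macro_input!(input as DeriveInput);
    // Build the impl as a proc_macro2::TokenStream and hand it back to the compiler.
    let tokens = impl_rand_derive(&ast);
    TokenStream::from(tokens)
}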

For parsing token sequences into ASTs, the community generally uses the syn library, and for turning AST data structures back into token sequences it generally uses the quote library. When I looked them up today, I was surprised to find that both were developed by David Tolnay. Having browsed the crates he has published on crates.io, I find them really strong, so I suggest you check them out and worship him like crazy!

Parsing and generation

Once we have the abstract syntax tree, the top level is Instruction. We iterate through all its variants, analyze the shape of each one, and generate code from that information.

Each variant has one of three shapes:

  • Named: fields with names, like Named { x: u8, y: i32 }.
  • Unnamed: fields without names, like Unnamed(u8, i32).
  • Unit: no fields at all, like Unit.

For both Named and Unnamed variants, we iterate through all their fields and initialize each one with __rng.gen(), which in turn relies on the field type's own implementation, so generation recurses into the data.

Finally, we use the number of variants to generate the dispatch: a single variant is constructed directly, two variants are chosen with a random bool, and anything larger becomes a match on __rng.gen_range(0..len).

let name = &ast.ident;

let rand = if let syn::Data::Enum(ref data) = ast.data {
    let variants = &data.variants;
    let len = variants.len();

    // One constructor expression per variant.
    let mut arms = variants.iter().map(|variant| {
        let ident = &variant.ident;
        match &variant.fields {
            // Variant { a: T, b: U }: fill each named field with __rng.gen()
            syn::Fields::Named(fields) => {
                let inits = fields
                    .named
                    .iter()
                    .filter_map(|field| field.ident.as_ref())
                    .map(|field_ident| quote! { #field_ident: __rng.gen() })
                    .collect::<Vec<_>>();
                quote! { #name::#ident { #(#inits,)* } }
            }
            // Variant(T, U): one __rng.gen() per position
            syn::Fields::Unnamed(fields) => {
                let inits = fields
                    .unnamed
                    .iter()
                    .map(|_| quote! { __rng.gen() })
                    .collect::<Vec<_>>();
                quote! { #name::#ident(#(#inits),*) }
            }
            // Variant: nothing to generate
            syn::Fields::Unit => quote! { #name::#ident },
        }
    });

    match len {
        // A single variant: construct it directly.
        1 => quote! { #(#arms)* },
        // Two variants: a random bool picks one.
        2 => {
            let (a, b) = (arms.next(), arms.next());
            quote! { if __rng.gen() { #a } else { #b } }
        }
        // Otherwise: draw an index in 0..len and match on it.
        _ => {
            let mut cases = arms
                .enumerate()
                .map(|(index, arm)| quote! { #index => #arm })
                .collect::<Vec<_>>();
            cases.push(quote! { _ => unreachable!() });
            quote! { match __rng.gen_range(0..#len) { #(#cases,)* } }
        }
    }
} else {
    // Only enums are handled at this stage.
    unimplemented!()
};
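
The dispatch logic above can be modeled without syn or quote. The sketch below is a simplification under stated assumptions: the hypothetical `Fields` enum stands in for `syn::Fields`, and plain `String`s stand in for token streams, so the real quote! output differs in spacing and token details. It shows how the variant count selects between a direct constructor, an `if` on a random bool, and a `match` on `gen_range`.

```rust
/// Hypothetical simplified model of `syn::Fields`.
enum Fields {
    Named(Vec<&'static str>),
    Unnamed(usize),
    Unit,
}

/// Constructor expression for one variant, built as a String
/// instead of a token stream.
fn arm(enum_name: &str, variant: &str, fields: &Fields) -> String {
    match fields {
        Fields::Named(names) => {
            let inits: Vec<String> = names
                .iter()
                .map(|n| format!("{n}: __rng.gen()"))
                .collect();
            format!("{enum_name}::{variant} {{ {} }}", inits.join(", "))
        }
        Fields::Unnamed(n) => {
            let inits = vec!["__rng.gen()"; *n];
            format!("{enum_name}::{variant}({})", inits.join(", "))
        }
        Fields::Unit => format!("{enum_name}::{variant}"),
    }
}

/// Pick the dispatch shape from the number of variants.
fn body(enum_name: &str, variants: &[(&str, Fields)]) -> String {
    let arms: Vec<String> = variants
        .iter()
        .map(|(v, f)| arm(enum_name, v, f))
        .collect();
    match arms.len() {
        1 => arms[0].clone(),
        2 => format!("if __rng.gen() {{ {} }} else {{ {} }}", arms[0], arms[1]),
        len => {
            let mut cases: Vec<String> = arms
                .iter()
                .enumerate()
                .map(|(i, a)| format!("{i} => {a}"))
                .collect();
            cases.push("_ => unreachable!()".to_string());
            format!("match __rng.gen_range(0..{len}) {{ {} }}", cases.join(", "))
        }
    }
}
```

For a three-variant input such as `Nop`, `Catch(u32)`, and `Br { depth }`, `body` yields `match __rng.gen_range(0..3) { 0 => Instr::Nop, 1 => Instr::Catch(__rng.gen()), 2 => Instr::Br { depth: __rng.gen() }, _ => unreachable!() }`.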

I hate recursion

It immediately becomes clear that the recursive __rng.gen() calls in the Named and Unnamed branches above require every field type to implement the corresponding trait too. rand already implements it for the basic types; for the remaining types we must provide it ourselves, which means our procedural macro also has to work on structs.

So our function needs to handle non-enum types as well: structs and tuple structs (my requirement has no tuple structs, so I didn't implement them).

let rand = match ast.data {
    syn::Data::Struct(ref data) => {
        // Named fields only: tuple structs have no field idents and are not supported.
        let fields = data
            .fields
            .iter()
            .filter_map(|field| field.ident.as_ref())
            .map(|ident| quote! { #ident: __rng.gen() })
            .collect::<Vec<_>>();

        quote! { #name { #(#fields,)* } }
    }
    syn::Data::Enum(ref data) => {
        // the enum handling from the previous step goes here
    }
    // Unions are not supported.
    _ => unimplemented!(),
};
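
For the recursion to bottom out, every payload type needs its own implementation. The following sketch writes out by hand roughly what the macro would generate for a struct and for a two-variant enum that contains it. `XorShift`, `Random`, `MemArg`, and `Instr` are hypothetical; the std-only `Random` trait stands in for rand's `Standard: Distribution<T>`.

```rust
// Hypothetical xorshift64 generator standing in for `rand::Rng`.
struct XorShift(u64);

impl XorShift {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }
}

/// Hypothetical stand-in for `Standard: Distribution<T>`.
trait Random {
    fn random(rng: &mut XorShift) -> Self;
}

// "Basic types": rand covers these already; here they are written by hand.
impl Random for u32 {
    fn random(rng: &mut XorShift) -> Self {
        rng.next_u64() as u32
    }
}

impl Random for bool {
    fn random(rng: &mut XorShift) -> Self {
        rng.next_u64() >> 63 == 1
    }
}

/// Hypothetical payload struct with named fields.
#[derive(Debug, PartialEq)]
struct MemArg {
    align: u32,
    offset: u32,
}

// Roughly what the `syn::Data::Struct` branch expands to:
// each named field is filled by its own type's impl.
impl Random for MemArg {
    fn random(rng: &mut XorShift) -> Self {
        MemArg {
            align: u32::random(rng),
            offset: u32::random(rng),
        }
    }
}

#[derive(Debug)]
enum Instr {
    Nop,
    Load(MemArg),
}

// Two variants, so a random bool picks one; the `Load` arm only
// compiles because `MemArg` implements `Random` as well.
impl Random for Instr {
    fn random(rng: &mut XorShift) -> Self {
        if bool::random(rng) {
            Instr::Nop
        } else {
            Instr::Load(MemArg::random(rng))
        }
    }
}
```

Removing `impl Random for MemArg` makes `Instr::random` fail to compile, which is the same pressure that forced the macro to support structs.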

Testing showed that 458 of the 459 variants passed; the remaining one holds a Cow. It was really annoying: there is no way to implement this trait for Cow, since both the trait and Cow are defined in other crates, and a borrowed Cow does not even own its data; it only holds a reference to data stored elsewhere.

I quickly came up with a workaround: sacrifice a bit of performance and replace Cow with Vec. We still can't implement this trait for Vec (Vec is also defined externally), but the macro can detect the Vec type while parsing and emit code that generates random data of random length for it. I'm a little resourceful.

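let fields = field
    .unnamed
    .iter()
    .map(|field| {
        if inner_type_is_vec(&field.ty) {
            // Vec<T> has no Distribution impl, so emit code that draws a
            // random length below 100 and then that many Standard samples.
            quote! {{
                let len = __rng.gen_range(0..100);
                __rng.sample_iter(::rand::distributions::Standard)
                    .take(len)
                    .collect()
            }}
        } else {
            // Everything else relies on the field type's own impl.
            quote! { __rng.gen() }
        }
    })
    .collect::<Vec<_>>();

/// True if the type's last path segment is named `Vec`
/// (so `Vec<T>` and `std::vec::Vec<T>` both match; judged by name only).
fn inner_type_is_vec(ty: &syn::Type) -> bool {
    if let syn::Type::Path(syn::TypePath { path, .. }) = ty {
        if let Some(seg) = path.segments.last() {
            return seg.ident == "Vec";
        }
    }
    false
}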
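
The inline Vec trick can be sketched the same way. Because the orphan rule forbids `impl Distribution<Vec<T>> for Standard` (the trait, `Standard`, and `Vec` all live in other crates), the macro emits an expression instead of relying on an impl. The hypothetical `random_vec` helper below mirrors that expression: draw a length below 100, then draw that many elements. `XorShift` and `Random` are illustrative std-only stand-ins for rand's `Rng` and `Standard`.

```rust
// Hypothetical xorshift64 generator standing in for `rand::Rng`.
struct XorShift(u64);

impl XorShift {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Value in `0..n` taken from the high bits (modulo bias ignored).
    fn below(&mut self, n: u64) -> u64 {
        (self.next_u64() >> 32) % n
    }
}

/// Hypothetical stand-in for `Standard: Distribution<T>`.
trait Random {
    fn random(rng: &mut XorShift) -> Self;
}

impl Random for u32 {
    fn random(rng: &mut XorShift) -> Self {
        rng.next_u64() as u32
    }
}

/// Mirrors the tokens emitted for a `Vec<T>` field: a random length
/// below 100, then that many independently drawn elements.
fn random_vec<T: Random>(rng: &mut XorShift) -> Vec<T> {
    let len = rng.below(100) as usize;
    (0..len).map(|_| T::random(rng)).collect()
}
```

The upper bound of 100 is the same arbitrary cap the macro uses; it keeps generated instructions small rather than reflecting any limit of the data itself.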

Tested again: everything passed! Happy!

Summary

Learning procedural macros, writing one, writing test cases, and finally getting the tests to pass took a lot of effort. I felt quite accomplished until just now, when I found out that although rand no longer maintains this derive macro, there is a third-party maintained version; in my testing it failed only a few test cases and was entirely usable for my current needs. What a pain: had I found it earlier, I would have saved an afternoon of reinventing the wheel. Fortunately, the end result was good. With about 100 lines of procedural-macro code, I completed a task that would otherwise have taken 2,000+ lines, and, most importantly, it wasn't boring.

Rust's macro mechanism is really powerful and can be used for a lot of interesting things. For example, variadic-style calls such as println! and vec! are provided by macros, and serde implements serialization and deserialization through derive (procedural) macros. With procedural macros, a lot of work that other languages do at runtime can be moved to compile time, which can improve the performance and flexibility of Rust programs and gives us powerful means of expression and implementation.

It occurred to me that macros could also be used for code obfuscation and string-literal encryption; I'll try that out later.