1
use proc_macro2::TokenStream;
2

            
3
use quote::format_ident;
4
use quote::quote;
5
use quote::ToTokens;
6
use syn::parse_quote;
7
use syn::Item;
8
use syn::ItemMod;
9

            
10
8
pub(crate) fn mcrl2_derive_terms_impl(_attributes: TokenStream, input: TokenStream) -> TokenStream {
11
8
    // Parse the input tokens into a syntax tree
12
8
    let mut ast: ItemMod = syn::parse2(input.clone()).expect("mcrl2_term can only be applied to a module");
13

            
14
8
    if let Some((_, content)) = &mut ast.content {
15
        // Generated code blocks are added to this list.
16
8
        let mut added = vec![];
17

            
18
151
        for item in content.iter_mut() {
19
151
            match item {
20
29
                Item::Struct(object) => {
21
                    // If the struct is annotated with term we process it as a term.
22
68
                    if let Some(attr) = object.attrs.iter().find(|attr| attr.meta.path().is_ident("mcrl2_term")) {
23
                        // The #term(assertion) annotation must contain an assertion
24
29
                        let assertion = match attr.parse_args::<syn::Ident>() {
25
28
                            Ok(assertion) => {
26
28
                                let assertion_msg = format!("{assertion}");
27
28
                                quote!(
28
28
                                    debug_assert!(#assertion(&term), "Term {:?} does not satisfy {}", term, #assertion_msg)
29
28
                                )
30
                            }
31
1
                            Err(_x) => {
32
1
                                quote!()
33
                            }
34
                        };
35

            
36
                        // Add the expected derive macros to the input struct.
37
29
                        object
38
29
                            .attrs
39
29
                            .push(parse_quote!(#[derive(Clone, Default, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]));
40
29

            
41
29
                        // ALL structs in this module must contain the term.
42
29
                        assert!(
43
29
                            object.fields.iter().any(|field| {
44
29
                                if let Some(name) = &field.ident {
45
29
                                    name == "term"
46
                                } else {
47
                                    false
48
                                }
49
29
                            }),
50
                            "The struct {} in mod {} has no field 'term: ATerm'",
51
                            object.ident,
52
                            ast.ident
53
                        );
54

            
55
29
                        let name = format_ident!("{}", object.ident);
56
29

            
57
29
                        // Add a <name>Ref struct that contains the ATermRef<'a> and
58
29
                        // the implementation and both protect and borrow. Also add
59
29
                        // the conversion from and to an ATerm.
60
29
                        let name_ref = format_ident!("{}Ref", object.ident);
61
29
                        let generated: TokenStream = quote!(
62
29
                            impl #name {
63
29
                                pub fn copy<'a>(&'a self) -> #name_ref<'a> {
64
29
                                    self.term.copy().into()
65
29
                                }
66
29
                            }
67
29

            
68
29
                            impl From<ATerm> for #name {
69
29
                                fn from(term: ATerm) -> #name {
70
29
                                    #assertion;
71
29
                                    #name {
72
29
                                        term
73
29
                                    }
74
29
                                }
75
29
                            }
76
29

            
77
29
                            impl Into<ATerm> for #name {
78
29
                                fn into(self) -> ATerm {
79
29
                                    self.term
80
29
                                }
81
29
                            }
82
29

            
83
29
                            impl Deref for #name {
84
29
                                type Target = ATerm;
85
29

            
86
29
                                fn deref(&self) -> &Self::Target {
87
29
                                    &self.term
88
29
                                }
89
29
                            }
90
29

            
91
29
                            impl Borrow<ATerm> for #name {
92
29
                                fn borrow(&self) -> &ATerm {
93
29
                                    &self.term
94
29
                                }
95
29
                            }
96
29

            
97
29
                            impl Borrow<ATermRef<'static>> for #name {
98
29
                                fn borrow(&self) -> &ATermRef<'static> {
99
29
                                    &self.term
100
29
                                }
101
29
                            }
102
29

            
103
29
                            impl Markable for #name {
104
29
                                fn mark(&self, todo: Todo) {
105
29
                                    self.term.mark(todo);
106
29
                                }
107
29

            
108
29
                                fn contains_term(&self, term: &ATermRef<'_>) -> bool {
109
29
                                    &self.term.copy() == term
110
29
                                }
111
29

            
112
29
                                fn len(&self) -> usize {
113
29
                                    1
114
29
                                }
115
29
                            }
116
29

            
117
29
                            #[derive(Default, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
118
29
                            pub struct #name_ref<'a> {
119
29
                                pub(crate) term: ATermRef<'a>
120
29
                            }
121
29

            
122
29
                            impl<'a> #name_ref<'a> {
123
29
                                pub fn copy<'b>(&'b self) -> #name_ref<'b> {
124
29
                                    self.term.copy().into()
125
29
                                }
126
29

            
127
29
                                pub fn protect(&self) -> #name {
128
29
                                    self.term.protect().into()
129
29
                                }
130
29
                            }
131
29

            
132
29
                            impl<'a> From<ATermRef<'a>> for #name_ref<'a> {
133
29
                                fn from(term: ATermRef<'a>) -> #name_ref<'a> {
134
29
                                    #assertion;
135
29
                                    #name_ref {
136
29
                                        term
137
29
                                    }
138
29
                                }
139
29
                            }
140
29

            
141
29
                            impl<'a> Into<ATermRef<'a>> for #name_ref<'a> {
142
29
                                fn into(self) -> ATermRef<'a> {
143
29
                                    self.term
144
29
                                }
145
29
                            }
146
29

            
147
29
                            impl<'a> Deref for #name_ref<'a> {
148
29
                                type Target = ATermRef<'a>;
149
29

            
150
29
                                fn deref(&self) -> &Self::Target {
151
29
                                    &self.term
152
29
                                }
153
29
                            }
154
29

            
155
29
                            impl<'a> Borrow<ATermRef<'a>> for #name_ref<'a> {
156
29
                                fn borrow(&self) -> &ATermRef<'a> {
157
29
                                    &self.term
158
29
                                }
159
29
                            }
160
29

            
161
29
                            impl<'a> Markable for #name_ref<'a> {
162
29
                                fn mark(&self, todo: Todo) {
163
29
                                    self.term.mark(todo);
164
29
                                }
165
29

            
166
29
                                fn contains_term(&self, term: &ATermRef<'_>) -> bool {
167
29
                                    &self.term == term
168
29
                                }
169
29

            
170
29
                                fn len(&self) -> usize {
171
29
                                    1
172
29
                                }
173
29
                            }
174
29
                        );
175
29

            
176
29
                        added.push(Item::Verbatim(generated));
177
                    }
178
                }
179
62
                Item::Impl(implementation) => {
180
62
                    if !implementation
181
62
                        .attrs
182
62
                        .iter()
183
62
                        .any(|attr| attr.meta.path().is_ident("mcrl2_ignore"))
184
                    {
185
                        // Duplicate the implementation for the ATermRef struct that is generated above.
186
47
                        let mut ref_implementation = implementation.clone();
187
47

            
188
47
                        // Remove ignore functions
189
92
                        ref_implementation.items.retain(|item| match item {
190
92
                            syn::ImplItem::Fn(func) => {
191
95
                                !func.attrs.iter().any(|attr| attr.meta.path().is_ident("mcrl2_ignore"))
192
                            }
193
                            _ => true,
194
92
                        });
195

            
196
47
                        if let syn::Type::Path(path) = ref_implementation.self_ty.as_ref() {
197
47
                            // Build an identifier TestRef<'_>
198
47
                            let name_ref = format_ident!("{}Ref", path.path.get_ident().unwrap());
199
47
                            let path = parse_quote!(#name_ref <'_>);
200
47

            
201
47
                            ref_implementation.self_ty = Box::new(syn::Type::Path(syn::TypePath { qself: None, path }));
202
47

            
203
47
                            added.push(Item::Verbatim(ref_implementation.into_token_stream()));
204
47
                        }
205
15
                    }
206
                }
207
60
                _ => {
208
60
                    // Ignore the rest.
209
60
                }
210
            }
211
        }
212

            
213
8
        content.append(&mut added);
214
    }
215

            
216
    // Hand the output tokens back to the compiler
217
8
    ast.into_token_stream()
218
8
}
219

            
220
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::*;

    /// Runs the macro on the module source `input` and returns the expansion as a string.
    fn expand(input: &str) -> String {
        let tokens = TokenStream::from_str(input).unwrap();
        mcrl2_derive_terms_impl(TokenStream::default(), tokens).to_string()
    }

    #[test]
    fn test_macro() {
        let input = "
            mod anything {

                #[mcrl2_term(test)]
                #[derive(Debug)]
                struct Test {
                    term: ATerm,
                }

                impl Test {
                    fn a_function() {

                    }
                }
            }
        ";

        let result = expand(input);
        println!("{result}");

        // The borrowed counterpart and the predicate check in the conversions are generated.
        assert!(result.contains("TestRef"));
        assert!(result.contains("debug_assert"));
        // The impl block is duplicated for TestRef, so the method occurs twice.
        assert_eq!(result.matches("a_function").count(), 2);
    }

    #[test]
    fn test_ignore_and_missing_assertion() {
        let input = "
            mod anything {
                #[mcrl2_term]
                struct Test {
                    term: ATerm,
                }

                impl Test {
                    fn kept() {}

                    #[mcrl2_ignore]
                    fn dropped() {}
                }

                #[mcrl2_ignore]
                impl Test {
                    fn owned_only() {}
                }
            }
        ";

        let result = expand(input);

        // Without an argument to mcrl2_term no check is generated.
        assert!(!result.contains("debug_assert"));
        // Only non-ignored methods of non-ignored impl blocks are duplicated.
        assert_eq!(result.matches("kept").count(), 2);
        assert_eq!(result.matches("dropped").count(), 1);
        assert_eq!(result.matches("owned_only").count(), 1);
    }

    #[test]
    #[should_panic(expected = "has no field 'term: ATerm'")]
    fn test_missing_term_field() {
        let _ = expand(
            "
            mod anything {
                #[mcrl2_term]
                struct Test {
                    value: ATerm,
                }
            }
        ",
        );
    }
}