1
use std::fmt;
2

            
3
use itertools::Itertools;
4
use mcrl2::aterm::ATermRef;
5
use mcrl2::aterm::Markable;
6
use mcrl2::aterm::Protected;
7
use mcrl2::aterm::Protector;
8
use mcrl2::aterm::TermPool;
9
use mcrl2::aterm::Todo;
10
use mcrl2::data::is_data_expression;
11
use mcrl2::data::is_data_machine_number;
12
use mcrl2::data::is_data_variable;
13
use mcrl2::data::DataApplication;
14
use mcrl2::data::DataExpression;
15
use mcrl2::data::DataExpressionRef;
16
use mcrl2::data::DataFunctionSymbolRef;
17

            
18
use crate::utilities::PositionIndexed;
19
use crate::Rule;
20

            
21
use super::create_var_map;
22
use super::ExplicitPosition;
23
use super::PositionIterator;
24

            
25
use log::trace;
26

            
27
/// This stack is used to avoid recursion and also to keep track of terms in
28
/// normal forms by explicitly representing the rewrites of a right hand
29
/// side.
30
#[derive(Default)]
31
pub struct InnermostStack {
32
    pub configs: Protected<Vec<Config>>,
33
    pub terms: Protected<Vec<DataExpressionRef<'static>>>,
34
}
35

            
36
impl InnermostStack {
37
    /// Updates the InnermostStack to integrate the rhs_stack instructions.
38
11317516
    pub fn integrate(
39
11317516
        write_configs: &mut Protector<Vec<Config>>,
40
11317516
        write_terms: &mut Protector<Vec<DataExpressionRef<'static>>>,
41
11317516
        rhs_stack: &RHSStack,
42
11317516
        term: &DataExpression,
43
11317516
        result_index: usize,
44
11317516
    ) {
45
11317516
        // TODO: This ignores the first element of the stack, but that is kind of difficult to deal with.
46
11317516
        let top_of_stack = write_terms.len();
47
11317516
        write_terms.reserve(rhs_stack.stack_size - 1); // We already reserved space for the result.
48
20282209
        for _ in 0..rhs_stack.stack_size - 1 {
49
20282209
            write_terms.push(Default::default());
50
20282209
        }
51

            
52
11317516
        let mut first = true;
53
16056438
        for config in rhs_stack.innermost_stack.read().iter() {
54
16056438
            match config {
55
16056438
                Config::Construct(symbol, arity, offset) => {
56
16056438
                    if first {
57
10777086
                        // The first result must be placed on the original result index.
58
10777086
                        InnermostStack::add_result(write_configs, symbol.copy(), *arity, result_index);
59
10777087
                    } else {
60
5279352
                        // Otherwise, we put it on the end of the stack.
61
5279352
                        InnermostStack::add_result(write_configs, symbol.copy(), *arity, top_of_stack + offset - 1);
62
5279352
                    }
63
                }
64
                Config::Rewrite(_) => {
65
                    unreachable!("This case should not happen");
66
                }
67
                Config::Return() => {
68
                    unreachable!("This case should not happen");
69
                }
70
            }
71
16056438
            first = false;
72
        }
73
11317516
        trace!(
74
            "\t applied stack size: {}, substitution: {}, stack: [{}]",
75
            rhs_stack.stack_size,
76
            rhs_stack.variables.iter().format_with(", ", |element, f| {
77
                f(&format_args!("{} -> {}", element.0, element.1))
78
            }),
79
            rhs_stack.innermost_stack.read().iter().format("\n")
80
        );
81

            
82
11317516
        debug_assert!(
83
11317516
            rhs_stack.stack_size != 1 || rhs_stack.variables.len() <= 1,
84
            "There can only be a single variable in the right hand side"
85
        );
86
11317516
        if rhs_stack.stack_size == 1 && rhs_stack.variables.len() == 1 {
87
540430
            // This is a special case where we place the result on the correct position immediately.
88
540430
            // The right hand side is only a variable
89
540430
            let t: ATermRef<'_> = write_terms.protect(&term.get_position(&rhs_stack.variables[0].0));
90
540430
            write_terms[result_index] = t.into();
91
540430
        } else {
92
25779943
            for (position, index) in &rhs_stack.variables {
93
15002857
                // Add the positions to the stack.
94
15002857
                let t = write_terms.protect(&term.get_position(position));
95
15002857
                write_terms[top_of_stack + index - 1] = t.into();
96
15002857
            }
97
        }
98
11317516
    }
99

            
100
    /// Indicate that the given symbol with arity can be constructed at the given index.
101
41794097
    pub fn add_result(
102
41794097
        write_configs: &mut Protector<Vec<Config>>,
103
41794097
        symbol: DataFunctionSymbolRef<'_>,
104
41794097
        arity: usize,
105
41794097
        index: usize,
106
41794097
    ) {
107
41794097
        let symbol = write_configs.protect(&symbol.into());
108
41794097
        write_configs.push(Config::Construct(symbol.into(), arity, index));
109
41794097
    }
110

            
111
    /// Indicate that the term must be rewritten and its result must be placed at the given index.
112
25737659
    pub fn add_rewrite(
113
25737659
        write_configs: &mut Protector<Vec<Config>>,
114
25737659
        write_terms: &mut Protector<Vec<DataExpressionRef<'static>>>,
115
25737659
        term: DataExpressionRef<'_>,
116
25737659
        index: usize,
117
25737659
    ) {
118
25737659
        let term = write_terms.protect(&term);
119
25737659
        write_configs.push(Config::Rewrite(index));
120
25737659
        write_terms.push(term.into());
121
25737659
    }
122
}
123

            
124
#[derive(Hash, Eq, PartialEq, Ord, PartialOrd, Debug)]
125
pub enum Config {
126
    /// Rewrite the top of the stack and put result at the given index.
127
    Rewrite(usize),
128
    /// Constructs function symbol with given arity at the given index.
129
    Construct(DataFunctionSymbolRef<'static>, usize, usize),
130
    /// Yields the given index as returned term.
131
    Return(),
132
}
133

            
134
impl Markable for Config {
135
    fn mark(&self, todo: Todo<'_>) {
136
        if let Config::Construct(t, _, _) = self {
137
            let t: ATermRef<'_> = t.copy().into();
138
            t.mark(todo);
139
        }
140
    }
141

            
142
3669759001
    fn contains_term(&self, term: &ATermRef<'_>) -> bool {
143
3669759001
        if let Config::Construct(t, _, _) = self {
144
3469125012
            term == &<DataFunctionSymbolRef as Into<ATermRef>>::into(t.copy())
145
        } else {
146
200633989
            false
147
        }
148
3669759001
    }
149

            
150
    fn len(&self) -> usize {
151
        if let Config::Construct(_, _, _) = self {
152
            1
153
        } else {
154
            0
155
        }
156
    }
157
}
158

            
159
impl fmt::Display for InnermostStack {
160
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161
        writeln!(f, "Terms: [")?;
162
        for (i, term) in self.terms.read().iter().enumerate() {
163
            if !term.is_default() {
164
                writeln!(f, "{}\t{}", i, term)?;
165
            } else {
166
                writeln!(f, "{}\t<default>", i)?;
167
            }
168
        }
169
        writeln!(f, "]")?;
170

            
171
        writeln!(f, "Configs: [")?;
172
        for config in self.configs.read().iter() {
173
            writeln!(f, "\t{}", config)?;
174
        }
175
        write!(f, "]")
176
    }
177
}
178

            
179
impl fmt::Display for Config {
180
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
181
        match self {
182
            Config::Rewrite(result) => write!(f, "Rewrite({})", result),
183
            Config::Construct(symbol, arity, result) => {
184
                write!(f, "Construct({}, {}, {})", symbol, arity, result)
185
            }
186
            Config::Return() => write!(f, "Return()"),
187
        }
188
    }
189
}
190

            
191
/// A stack for the right-hand side.
192
pub struct RHSStack {
193
    /// The innermost rewrite stack for the right hand side and the positions that must be added to the stack.
194
    innermost_stack: Protected<Vec<Config>>,
195
    variables: Vec<(ExplicitPosition, usize)>,
196
    stack_size: usize,
197
}
198

            
199
impl RHSStack {
200
    /// Construct a new right-hand stack for a given equation/rewrite rule.
201
36082
    pub fn new(rule: &Rule) -> RHSStack {
202
36082
        let var_map = create_var_map(&rule.lhs.clone().into());
203
36082

            
204
36082
        // Compute the extra information for the InnermostRewriter.
205
36082
        let mut innermost_stack: Protected<Vec<Config>> = Protected::new(vec![]);
206
36082
        let mut variables = vec![];
207
36082
        let mut stack_size = 0;
208

            
209
808877
        for (term, position) in PositionIterator::new(rule.rhs.copy().into()) {
210
808877
            if let Some(index) = position.indices.last() {
211
772795
                if *index == 1 {
212
427592
                    continue; // Skip the function symbol.
213
345203
                }
214
36082
            }
215

            
216
381285
            if is_data_variable(&term) {
217
42818
                variables.push((
218
42818
                    var_map
219
42818
                        .get(&term.protect())
220
42818
                        .expect("All variables in the right hand side must occur in the left hand side")
221
42818
                        .clone(),
222
42818
                    stack_size,
223
42818
                ));
224
42818
                stack_size += 1;
225
338467
            } else if is_data_machine_number(&term) {
226
58223
                // Skip SortId(@NoValue) and OpId
227
280244
            } else if is_data_expression(&term) {
228
58223
                let t: DataExpressionRef = term.into();
229
58223
                let arity = t.data_arguments().len();
230
58223
                let mut write = innermost_stack.write();
231
58223
                let symbol = write.protect(&t.data_function_symbol().into());
232
58223
                write.push(Config::Construct(symbol.into(), arity, stack_size));
233
58223
                stack_size += 1;
234
222021
            } else {
235
222021
                // Skip intermediate terms such as UntypeSortUnknown.
236
222021
            }
237
        }
238

            
239
36082
        RHSStack {
240
36082
            innermost_stack,
241
36082
            stack_size,
242
36082
            variables,
243
36082
        }
244
36082
    }
245

            
246
    /// Evaluate the rhs stack for the given term and returns the result.
247
1
    pub fn evaluate(&self, tp: &mut TermPool, term: &DataExpression) -> DataExpression {
248
1
        let mut stack = InnermostStack::default();
249
1
        stack.terms.write().push(DataExpressionRef::default());
250
1

            
251
1
        InnermostStack::integrate(&mut stack.configs.write(), &mut stack.terms.write(), self, term, 0);
252
        loop {
253
4
            trace!("{}", stack);
254

            
255
4
            let mut write_configs = stack.configs.write();
256
4
            if let Some(config) = write_configs.pop() {
257
3
                match config {
258
3
                    Config::Construct(symbol, arity, index) => {
259
3
                        // Take the last arity arguments.
260
3
                        let mut write_terms = stack.terms.write();
261
3
                        let length = write_terms.len();
262
3

            
263
3
                        let arguments = &write_terms[length - arity..];
264

            
265
3
                        let term: DataExpression = if arguments.is_empty() {
266
                            symbol.protect().into()
267
                        } else {
268
3
                            DataApplication::new(tp, &symbol.copy(), arguments).into()
269
                        };
270

            
271
                        // Add the term on the stack.
272
3
                        write_terms.drain(length - arity..);
273
3
                        let t = write_terms.protect(&term);
274
3
                        write_terms[index] = t.into();
275
                    }
276
                    Config::Rewrite(_) => {
277
                        unreachable!("This case should not happen");
278
                    }
279
                    Config::Return() => {
280
                        unreachable!("This case should not happen");
281
                    }
282
                }
283
            } else {
284
1
                break;
285
1
            }
286
1
        }
287
1

            
288
1
        debug_assert!(
289
1
            stack.terms.read().len() == 1,
290
            "Expect exactly one term on the result stack"
291
        );
292

            
293
1
        let mut write_terms = stack.terms.write();
294
1

            
295
1
        write_terms
296
1
            .pop()
297
1
            .expect("The result should be the last element on the stack")
298
1
            .protect()
299
1
    }
300
}
301

            
302
impl Clone for RHSStack {
303
    fn clone(&self) -> Self {
304
        // TODO: It would make sense if Protected could implement Clone.
305
        let mut innermost_stack: Protected<Vec<Config>> = Protected::new(vec![]);
306

            
307
        let mut write = innermost_stack.write();
308
        for t in self.innermost_stack.read().iter() {
309
            match t {
310
                Config::Rewrite(x) => write.push(Config::Rewrite(*x)),
311
                Config::Construct(f, x, y) => {
312
                    let f = write.protect(&f.copy().into());
313
                    write.push(Config::Construct(f.into(), *x, *y));
314
                }
315
                Config::Return() => write.push(Config::Return()),
316
            }
317
        }
318
        drop(write);
319

            
320
        Self {
321
            variables: self.variables.clone(),
322
            stack_size: self.stack_size,
323
            innermost_stack,
324
        }
325
    }
326
}
327

            
#[cfg(test)]
mod tests {
    use super::*;
    use ahash::AHashSet;
    use mcrl2::aterm::TermPool;
    use mcrl2::data::DataFunctionSymbol;

    use crate::test_utility::create_rewrite_rule;
    use crate::utilities::to_untyped_data_expression;

    use test_log::test;

    #[test]
    fn test_rhs_stack() {
        let mut tp = TermPool::new();

        let rhs_stack =
            RHSStack::new(&create_rewrite_rule(&mut tp, "fact(s(N))", "times(s(N), fact(N))", &["N"]).unwrap());

        // Expected constructs as (symbol name, arity, slot offset); the two
        // occurrences of N occupy the remaining slots 3 and 4.
        let mut expected: Protected<Vec<Config>> = Protected::new(vec![]);
        let mut write = expected.write();
        for &(name, arity, offset) in &[("times", 2, 0), ("s", 1, 1), ("fact", 1, 2)] {
            let symbol = write.protect(&DataFunctionSymbol::new(&mut tp, name).copy().into());
            write.push(Config::Construct(symbol.into(), arity, offset));
        }
        drop(write);

        // Check if the resulting construction succeeded.
        assert_eq!(
            rhs_stack.innermost_stack, expected,
            "The resulting config stack is not as expected"
        );

        assert_eq!(rhs_stack.stack_size, 5, "The stack size does not match");

        // Test the evaluation
        let lhs = tp.from_string("fact(s(a))").unwrap();
        let lhs_expression = to_untyped_data_expression(&mut tp, &lhs, &AHashSet::new());

        let rhs = tp.from_string("times(s(a), fact(a))").unwrap();
        let rhs_expression = to_untyped_data_expression(&mut tp, &rhs, &AHashSet::new());

        assert_eq!(
            rhs_stack.evaluate(&mut tp, &lhs_expression),
            rhs_expression,
            "The rhs stack does not evaluate to the expected term"
        );
    }

    #[test]
    fn test_rhs_stack_variable() {
        let mut tp = TermPool::new();

        let rhs = RHSStack::new(&create_rewrite_rule(&mut tp, "f(x)", "x", &["x"]).unwrap());

        // A lone variable needs no construct instructions, only a single slot.
        assert!(
            rhs.innermost_stack.read().is_empty(),
            "The resulting config stack is not as expected"
        );

        assert_eq!(rhs.stack_size, 1, "The stack size does not match");
    }
}