1
// Author(s): Mark Bouwman
2

            
3
use std::collections::HashMap;
4
use std::collections::HashSet;
5

            
6
use crate::utilities::ExplicitPosition;
7
use mcrl2::aterm::ATerm;
8
use mcrl2::aterm::ATermRef;
9
use mcrl2::aterm::Symbol;
10
use mcrl2::aterm::TermBuilder;
11
use mcrl2::aterm::TermPool;
12
use mcrl2::aterm::Yield;
13
use mcrl2::data::is_data_variable;
14
use mcrl2::data::DataVariable;
15

            
16
/// A SemiCompressedTermTree (SCTT) is a mix between a [ATerm] and a syntax tree and is used
17
/// to represent the rhs of rewrite rules and the lhs and rhs of conditions.
18
///
19
/// It stores as much as possible in the term pool. Due to variables it cannot be fully compressed.
20
/// For variables it stores the position in the lhs of a rewrite rule where the concrete term can be
21
/// found that will replace the variable.
22
///
23
/// # Examples
24
/// For the rewrite rule and(true, true) = true, the SCTT of the rhs will be of type Compressed, with
25
/// a pointer to the term true.
26
///
27
/// For the rewrite rule minus(x, 0) = x, the SCTT of the rhs will be of type Variable, storing position
28
/// 1, the position of x in the lhs.
29
///
30
/// For the rewrite rule minus(s(x), s(y)) = minus(x, y), the SCTT of the rhs will be of type
31
/// Explicit, which will stored the head symbol 'minus' and two child SCTTs of type Variable.
32
#[derive(Debug, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
33
pub enum SemiCompressedTermTree {
34
    Explicit(ExplicitNode),
35
    Compressed(ATerm),
36
    Variable(ExplicitPosition),
37
}
38

            
39
/// Stores the head symbol and a SCTT for every argument explicitly.
40
#[derive(Debug, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)]
41
pub struct ExplicitNode {
42
    pub head: Symbol,
43
    pub children: Vec<SemiCompressedTermTree>,
44
}
45

            
46
use SemiCompressedTermTree::*;
47

            
48
use super::PositionIndexed;
49
use super::PositionIterator;
50

            
51
/// The reusable builder that [SemiCompressedTermTree::evaluate_with] takes as argument.
///
/// The `'static` lifetimes are placeholders: `evaluate_with` transmutes the builder to the
/// lifetime of the tree being evaluated for the duration of a single call.
pub type SCCTBuilder = TermBuilder<&'static SemiCompressedTermTree, &'static Symbol>;
52

            
53
impl SemiCompressedTermTree {
54
    /// Given an [ATerm] and a term pool this function instantiates the SCTT and computes a [ATerm].
55
    ///
56
    /// # Example
57
    /// For the SCTT belonging to the rewrite rule minus(s(x), s(y)) = minus(x, y)
58
    /// and the concrete lhs minus(s(0), s(0)) the evaluation will go as follows.
59
    /// evaluate will encounter an ExplicitNode and make two recursive calls to get the subterms.
60
    /// Both these recursive calls will return the term '0'.
61
    /// The term pool will be used to construct the term minus(0, 0).
62
10873421
    pub fn evaluate_with<'a>(&'a self, builder: &mut SCCTBuilder, t: &ATermRef<'_>, tp: &mut TermPool) -> ATerm {
63
10873421
        // TODO: Figure out if this can be done properly. This is safe because evaluate will always leave the
64
10873421
        // underlying vectors empty.
65
10873421
        let builder: &mut TermBuilder<&'a SemiCompressedTermTree, &'a Symbol> = unsafe { std::mem::transmute(builder) };
66
10873421

            
67
10873421
        builder
68
10873421
            .evaluate(
69
10873421
                tp,
70
10873421
                self,
71
39612248
                |_tp, args, node| {
72
39612248
                    match node {
73
10102856
                        Explicit(node) => {
74
                            // Create an ATerm with as arguments all the evaluated semi compressed term trees.
75
28738827
                            for i in 0..node.children.len() {
76
28738827
                                args.push(&node.children[i]);
77
28738827
                            }
78

            
79
10102856
                            Ok(Yield::Construct(&node.head))
80
                        }
81
12999491
                        Compressed(ct) => Ok(Yield::Term(ct.clone())),
82
16509901
                        Variable(p) => Ok(Yield::Term(t.get_position(p).protect())),
83
                    }
84
39612248
                },
85
10873421
                |tp, symbol, args| Ok(tp.create(symbol, args)),
86
10873421
            )
87
10873421
            .unwrap()
88
10873421
    }
89

            
90
    /// The same as [evaluate_with], but allocates a [SCCTBuilder] internally.
91
8077631
    pub fn evaluate(&self, t: &ATermRef<'_>, tp: &mut TermPool) -> ATerm {
92
8077631
        let mut builder = TermBuilder::<&SemiCompressedTermTree, &Symbol>::new();
93
8077631

            
94
8077631
        self.evaluate_with(&mut builder, t, tp)
95
8077631
    }
96

            
97
    /// Creates a SCTT from a term. The var_map parameter should specify where the variable can be
98
    /// found in the lhs of the rewrite rule.
99
1729945
    pub(crate) fn from_term(
100
1729945
        t: &ATermRef<'_>,
101
1729945
        var_map: &HashMap<DataVariable, ExplicitPosition>,
102
1729945
    ) -> SemiCompressedTermTree {
103
1729945
        if is_data_variable(t) {
104
107211
            Variable(
105
107211
                var_map
106
107211
                    .get(&t.protect())
107
107211
                    .unwrap_or_else(|| panic!("{t} not contained in variable mapping var_map"))
108
107211
                    .clone(),
109
107211
            )
110
1622734
        } else if t.arguments().is_empty() {
111
1065687
            Compressed(t.protect())
112
        } else {
113
557047
            let children = t
114
557047
                .arguments()
115
1599559
                .map(|c| SemiCompressedTermTree::from_term(&c, var_map))
116
557047
                .collect();
117
557047
            let node = ExplicitNode {
118
557047
                head: t.get_head_symbol().protect(),
119
557047
                children,
120
557047
            };
121
557047

            
122
1543662
            if node.children.iter().all(|c| c.is_compressed()) {
123
484663
                Compressed(t.protect())
124
            } else {
125
72384
                Explicit(node)
126
            }
127
        }
128
1729945
    }
129

            
130
    /// Used to check if a subterm is duplicated, for example "times(s(x), y) =
131
    /// plus(y, times(x,y))" is duplicating.
132
71421
    pub(crate) fn contains_duplicate_var_references(&self) -> bool {
133
71421
        let references = self.get_all_var_references();
134
71421
        let mut seen = HashSet::new();
135

            
136
111272
        for r in references {
137
50732
            if seen.contains(&r) {
138
10881
                return true;
139
39851
            }
140
39851
            seen.insert(r);
141
        }
142

            
143
60540
        false
144
71421
    }
145

            
146
    /// Get all positions to variables in the left hand side.
147
229163
    fn get_all_var_references(&self) -> Vec<ExplicitPosition> {
148
229163
        let mut result = vec![];
149
229163
        match self {
150
53771
            Explicit(en) => {
151
211513
                for n in &en.children {
152
157742
                    result.extend_from_slice(&n.get_all_var_references());
153
157742
                }
154
            }
155
111810
            Compressed(_) => {}
156
63582
            Variable(ep) => {
157
63582
                result.push(ep.clone());
158
63582
            }
159
        }
160

            
161
229163
        result
162
229163
    }
163

            
164
    /// Returns true iff this tree is compressed.
165
1543662
    fn is_compressed(&self) -> bool {
166
1543662
        matches!(self, Compressed(_))
167
1543662
    }
168
}
169

            
170
/// Create a mapping of variables to their position in the given term
171
215003
pub fn create_var_map(t: &ATerm) -> HashMap<DataVariable, ExplicitPosition> {
172
215003
    let mut result = HashMap::new();
173

            
174
3772382
    for (term, position) in PositionIterator::new(t.copy()) {
175
3772382
        if is_data_variable(&term) {
176
259864
            result.insert(term.protect().into(), position.clone());
177
3512518
        }
178
    }
179
215003
    result
180
215003
}
181

            
182
#[cfg(test)]
mod tests {
    use super::*;
    use ahash::AHashSet;
    use mcrl2::aterm::apply;
    use mcrl2::aterm::TermPool;

    /// Converts a slice of string slices into a set of owned variable names.
    ///
    /// Example: `var_map(&["x"])` yields the set containing only `"x"`.
    fn var_map(vars: &[&str]) -> AHashSet<String> {
        AHashSet::from_iter(vars.iter().map(|x| String::from(*x)))
    }

    /// Replaces every subterm of `t` whose head symbol name occurs in `variables` by a
    /// [DataVariable] with that name; for all other subterms the closure returns `None`.
    pub fn convert_variables(tp: &mut TermPool, t: &ATerm, variables: &AHashSet<String>) -> ATerm {
        apply(tp, t, &|tp, arg| {
            if variables.contains(arg.get_head_symbol().name()) {
                // Convert a constant variable, for example 'x', into an untyped variable.
                Some(DataVariable::new(tp, arg.get_head_symbol().name()).into())
            } else {
                None
            }
        })
    }

    /// A constant has no arguments and is stored compressed.
    #[test]
    fn test_constant() {
        let mut tp = TermPool::new();
        let t = tp.from_string("a").unwrap();

        let map = HashMap::new();
        let sctt = SemiCompressedTermTree::from_term(&t, &map);
        assert_eq!(sctt, Compressed(t));
    }

    /// A term without variables is compressed as a whole.
    #[test]
    fn test_compressible() {
        let mut tp = TermPool::new();
        let t = tp.from_string("f(a,a)").unwrap();

        let map = HashMap::new();
        let sctt = SemiCompressedTermTree::from_term(&t, &map);
        assert_eq!(sctt, Compressed(t));
    }

    /// A term whose arguments are all variables must be stored explicitly.
    #[test]
    fn test_not_compressible() {
        let mut tp = TermPool::new();
        let t = {
            let tmp = tp.from_string("f(x,x)").unwrap();
            convert_variables(&mut tp, &tmp, &var_map(&["x"]))
        };

        // Make a variable map with only x@2.
        let mut map = HashMap::new();
        map.insert(DataVariable::new(&mut tp, "x"), ExplicitPosition::new(&[2]));

        let sctt = SemiCompressedTermTree::from_term(&t, &map);

        let en = Explicit(ExplicitNode {
            head: tp.create_symbol("f", 2),
            children: vec![
                Variable(ExplicitPosition::new(&[2])), // Note that both point to the second occurrence of x.
                Variable(ExplicitPosition::new(&[2])),
            ],
        });

        assert_eq!(sctt, en);
    }

    /// Only the variable-free argument is compressed; the root stays explicit.
    #[test]
    fn test_partly_compressible() {
        let mut tp = TermPool::new();
        let t = {
            let tmp = tp.from_string("f(f(a,a),x)").unwrap();
            convert_variables(&mut tp, &tmp, &var_map(&["x"]))
        };
        let compressible = tp.from_string("f(a,a)").unwrap();

        // Make a variable map with only x@2.
        let mut map = HashMap::new();
        map.insert(DataVariable::new(&mut tp, "x"), ExplicitPosition::new(&[2]));

        let sctt = SemiCompressedTermTree::from_term(&t, &map);
        let en = Explicit(ExplicitNode {
            head: tp.create_symbol("f", 2),
            children: vec![Compressed(compressible), Variable(ExplicitPosition::new(&[2]))],
        });
        assert_eq!(sctt, en);
    }

    /// Evaluating against g(b) substitutes the subterm at position 1 for x.
    #[test]
    fn test_evaluation() {
        let mut tp = TermPool::new();
        let t_rhs = {
            let tmp = tp.from_string("f(f(a,a),x)").unwrap();
            convert_variables(&mut tp, &tmp, &var_map(&["x"]))
        };
        let t_lhs = tp.from_string("g(b)").unwrap();

        // Make a variable map with only x@1.
        let mut map = HashMap::new();
        map.insert(DataVariable::new(&mut tp, "x"), ExplicitPosition::new(&[1]));

        let sctt = SemiCompressedTermTree::from_term(&t_rhs, &map);

        let t_expected = tp.from_string("f(f(a,a),b)").unwrap();
        assert_eq!(sctt.evaluate(&t_lhs, &mut tp), t_expected);
    }

    /// The variable map contains every variable of the term.
    #[test]
    fn test_create_varmap() {
        let mut tp = TermPool::new();
        let t = {
            let tmp = tp.from_string("f(x,x)").unwrap();
            convert_variables(&mut tp, &tmp, &var_map(&["x"]))
        };
        let x = DataVariable::new(&mut tp, "x");

        let map = create_var_map(&t);
        assert!(map.contains_key(&x));
    }

    /// Two references to the same lhs position are reported as duplicating.
    #[test]
    fn test_is_duplicating() {
        let mut tp = TermPool::new();
        let t_rhs = {
            let tmp = tp.from_string("f(x,x)").unwrap();
            convert_variables(&mut tp, &tmp, &var_map(&["x"]))
        };

        // Make a variable map with only x@1.
        let mut map = HashMap::new();
        map.insert(DataVariable::new(&mut tp, "x"), ExplicitPosition::new(&[1]));

        let sctt = SemiCompressedTermTree::from_term(&t_rhs, &map);
        assert!(sctt.contains_duplicate_var_references(), "This sctt is duplicating");
    }
}