1
use std::fmt;
2

            
3
use itertools::Itertools;
4
use mcrl2::data::is_data_variable;
5
use mcrl2::data::DataVariable;
6
use mcrl2::data::DataVariableRef;
7

            
8
use crate::utilities::ExplicitPosition;
9
use crate::utilities::PositionIndexed;
10
use crate::utilities::PositionIterator;
11
use crate::Rule;
12

            
13
/// An equivalence class is a variable with (multiple) positions. This is
14
/// necessary for non-linear patterns.
15
///
16
/// # Example
17
/// Suppose we have a pattern f(x,x), where x is a variable. Then it will have
18
/// one equivalence class storing "x" and the positions 1 and 2. The function
19
/// equivalences_hold checks whether the term has the same term on those
20
/// positions. For example, it will returns false on the term f(a, b) and true
21
/// on the term f(a, a).
22
#[derive(Hash, Clone, Eq, PartialEq, Ord, PartialOrd, Debug)]
23
pub struct EquivalenceClass {
24
    pub(crate) variable: DataVariable,
25
    pub(crate) positions: Vec<ExplicitPosition>,
26
}
27

            
28
/// Derives the positions in a pattern with same variable (for non-linear patters)
29
107501
pub fn derive_equivalence_classes(rule: &Rule) -> Vec<EquivalenceClass> {
30
107501
    let mut var_equivalences = vec![];
31

            
32
1886193
    for (term, pos) in PositionIterator::new(rule.lhs.copy().into()) {
33
1886193
        if is_data_variable(&term) {
34
129932
            // Register the position of the variable
35
129932
            update_equivalences(&mut var_equivalences, &DataVariableRef::from(term), pos);
36
1756261
        }
37
    }
38

            
39
    // Discard variables that only occur once
40
123401
    var_equivalences.retain(|x| x.positions.len() > 1);
41
107501
    var_equivalences
42
107501
}
43

            
44
/// Checks if the equivalence classes hold for the given term.
45
13664731
pub fn check_equivalence_classes<'a, T, P>(term: &'a P, eqs: &[EquivalenceClass]) -> bool
46
13664731
where
47
13664731
    P: PositionIndexed<Target<'a> = T> + 'a,
48
13664731
    T: PartialEq,
49
13664731
{
50
13664731
    eqs.iter().all(|ec| {
51
1510376
        debug_assert!(
52
1510376
            ec.positions.len() >= 2,
53
            "An equivalence class must contain at least two positions"
54
        );
55

            
56
        // The term at the first position must be equivalent to all other positions.
57
1510376
        let mut iter_pos = ec.positions.iter();
58
1510376
        let first = iter_pos.next().unwrap();
59
1510376
        iter_pos.all(|other_pos| term.get_position(first) == term.get_position(other_pos))
60
13664731
    })
61
13664731
}
62

            
63
/// Adds the position of a variable to the equivalence classes
64
129932
fn update_equivalences(ve: &mut Vec<EquivalenceClass>, variable: &DataVariableRef<'_>, pos: ExplicitPosition) {
65
129932
    // Check if the variable was seen before
66
129932
    if ve.iter().any(|ec| ec.variable.copy() == *variable) {
67
8691
        for ec in ve.iter_mut() {
68
            // Find the equivalence class and add the position
69
8691
            if ec.variable.copy() == *variable && !ec.positions.iter().any(|x| x == &pos) {
70
6531
                ec.positions.push(pos);
71
6531
                break;
72
2160
            }
73
        }
74
123401
    } else {
75
123401
        // If the variable was not found at another position add a new equivalence class
76
123401
        ve.push(EquivalenceClass {
77
123401
            variable: variable.protect(),
78
123401
            positions: vec![pos],
79
123401
        });
80
123401
    }
81
129932
}
82

            
83
impl fmt::Display for EquivalenceClass {
84
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85
        write!(f, "{}{{ {} }}", self.variable, self.positions.iter().format(", "))
86
    }
87
}
88

            
#[cfg(test)]
mod tests {
    use ahash::AHashSet;
    use mcrl2::aterm::ATermRef;
    use mcrl2::aterm::TermPool;
    use mcrl2::data::DataVariable;

    use crate::test_utility::create_rewrite_rule;
    use crate::utilities::to_untyped_data_expression;

    use super::*;

    #[test]
    fn test_derive_equivalence_classes() {
        let mut tp = TermPool::new();
        let eq: Vec<EquivalenceClass> =
            derive_equivalence_classes(&create_rewrite_rule(&mut tp, "f(x, h(x))", "result", &["x"]).unwrap());

        assert_eq!(
            eq,
            vec![EquivalenceClass {
                variable: DataVariable::new(&mut tp, "x").into(),
                positions: vec![ExplicitPosition::new(&[2]), ExplicitPosition::new(&[3, 2])]
            }],
            "The derived equivalence classes are not as expected"
        );

        // The equivalence class holds when both occurrences of x are matched by the same subterm.
        let term = tp.from_string("f(a(b), h(a(b)))").unwrap();
        let expression = to_untyped_data_expression(&mut tp, &term, &AHashSet::new());

        let t: &ATermRef<'_> = &expression;
        assert!(
            check_equivalence_classes(t, &eq),
            "The equivalence classes are not checked correctly, equivalences: {:?} and term {}",
            &eq,
            &expression
        );

        // It must not hold when the occurrences of x are matched by different subterms.
        let differing_term = tp.from_string("f(a(b), h(b))").unwrap();
        let differing_expression = to_untyped_data_expression(&mut tp, &differing_term, &AHashSet::new());

        let t: &ATermRef<'_> = &differing_expression;
        assert!(
            !check_equivalence_classes(t, &eq),
            "The equivalence classes should not hold, equivalences: {:?} and term {}",
            &eq,
            &differing_expression
        );
    }
}