1
use std::borrow::Borrow;
2
use std::cmp::Ordering;
3
use std::collections::VecDeque;
4
use std::fmt;
5
use std::hash::Hash;
6
use std::hash::Hasher;
7
use std::marker::PhantomData;
8
use std::ops::Deref;
9

            
10
use mcrl2_sys::atermpp::ffi;
11
use mcrl2_sys::cxx::UniquePtr;
12
use utilities::PhantomUnsend;
13

            
14
use crate::aterm::SymbolRef;
15
use crate::aterm::THREAD_TERM_POOL;
16

            
17
use super::global_aterm_pool::GLOBAL_TERM_POOL;
18

            
19
/// This represents a lifetime bound reference to an existing ATerm that is
20
/// protected somewhere statically.
21
///
22
/// Can be 'static if the term is protected in a container or ATerm. That means
23
/// we either return &'a ATermRef<'static> or with a concrete lifetime
24
/// ATermRef<'a>. However, this means that the functions for ATermRef cannot use
25
/// the associated lifetime for the results parameters, as that would allow us
26
/// to acquire the 'static lifetime. This occasionally gives rise to issues
27
/// where we look at the argument of a term and want to return it's name, but
28
/// this is not allowed since the temporary returned by the argument is dropped.
29
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord)]
30
pub struct ATermRef<'a> {
31
    term: *const ffi::_aterm,
32
    marker: PhantomData<&'a ()>,
33
}
34

            
35
/// These are safe because terms are never modified. Garbage collection is
36
/// always performed with exclusive access and uses relaxed atomics to perform
37
/// some interior mutability.
38
unsafe impl Send for ATermRef<'_> {}
39
unsafe impl Sync for ATermRef<'_> {}
40

            
41
impl Default for ATermRef<'_> {
42
110135710
    fn default() -> Self {
43
110135710
        ATermRef {
44
110135710
            term: std::ptr::null(),
45
110135710
            marker: PhantomData,
46
110135710
        }
47
110135710
    }
48
}
49

            
50
impl<'a> ATermRef<'a> {
51
    /// Protects the reference on the thread local protection pool.
52
251641616
    pub fn protect(&self) -> ATerm {
53
251641616
        if self.is_default() {
54
            ATerm::default()
55
        } else {
56
251641616
            THREAD_TERM_POOL.with_borrow_mut(|tp| tp.protect(self.term))
57
        }
58
251641616
    }
59

            
60
    /// Protects the reference on the global protection pool.
61
200
    pub fn protect_global(&self) -> ATermGlobal {
62
200
        if self.is_default() {
63
            ATermGlobal::default()
64
        } else {
65
200
            GLOBAL_TERM_POOL.lock().protect(self.term)
66
        }
67
200
    }
68

            
69
    /// This allows us to extend our borrowed lifetime from 'a to 'b based on
70
    /// existing parent term which has lifetime 'b.
71
    ///
72
    /// The main usecase is to establish transitive lifetimes. For example given
73
    /// a term t from which we borrow `u = t.arg(0)` then we cannot have
74
    /// u.arg(0) live as long as t since the intermediate temporary u is
75
    /// dropped. However, since we know that u.arg(0) is a subterm of `t` we can
76
    /// upgrade its lifetime to the lifetime of `t` using this function.
77
    ///
78
    /// # Safety
79
    ///
80
    /// This function might only be used if witness is a parent term of the
81
    /// current term.
82
421296067
    pub fn upgrade<'b: 'a>(&'a self, parent: &ATermRef<'b>) -> ATermRef<'b> {
83
421296067
        debug_assert!(
84
5865169658
            parent.iter().any(|t| t.copy() == *self),
85
            "Upgrade has been used on a witness that is not a parent term"
86
        );
87

            
88
421296067
        ATermRef::new(self.term)
89
421296067
    }
90

            
91
    /// A private unchecked version of [`ATermRef::upgrade`] to use in iterators.
92
13929153948
    unsafe fn upgrade_unchecked<'b: 'a>(&'a self, _parent: &ATermRef<'b>) -> ATermRef<'b> {
93
13929153948
        ATermRef::new(self.term)
94
13929153948
    }
95

            
96
    /// Obtains the underlying pointer
97
6495663597
    pub(crate) unsafe fn get(&self) -> *const ffi::_aterm {
98
6495663597
        self.term
99
6495663597
    }
100
}
101

            
102
impl<'a> ATermRef<'a> {
103
32226611237
    pub(crate) fn new(term: *const ffi::_aterm) -> ATermRef<'a> {
104
32226611237
        ATermRef {
105
32226611237
            term,
106
32226611237
            marker: PhantomData,
107
32226611237
        }
108
32226611237
    }
109
}
110

            
111
impl ATermRef<'_> {
112
    /// Returns the indexed argument of the term
113
7413503860
    pub fn arg(&self, index: usize) -> ATermRef<'_> {
114
7413503860
        self.require_valid();
115
7413503860
        debug_assert!(
116
7413503860
            index < self.get_head_symbol().arity(),
117
            "arg({index}) is not defined for term {:?}",
118
            self
119
        );
120

            
121
        unsafe {
122
7413503860
            ATermRef {
123
7413503860
                term: ffi::get_term_argument(self.term, index),
124
7413503860
                marker: PhantomData,
125
7413503860
            }
126
7413503860
        }
127
7413503860
    }
128

            
129
    /// Returns the list of arguments as a collection
130
5904464695
    pub fn arguments(&self) -> ATermArgs<'_> {
131
5904464695
        self.require_valid();
132
5904464695

            
133
5904464695
        ATermArgs::new(self.copy())
134
5904464695
    }
135

            
136
    /// Makes a copy of the term with the same lifetime as itself.
137
17573940181
    pub fn copy(&self) -> ATermRef<'_> {
138
17573940181
        ATermRef::new(self.term)
139
17573940181
    }
140

            
141
    /// Returns whether the term is the default term (not initialised)
142
49159599137
    pub fn is_default(&self) -> bool {
143
49159599137
        self.term.is_null()
144
49159599137
    }
145

            
146
    /// Returns true iff this is an aterm_list
147
170651
    pub fn is_list(&self) -> bool {
148
170651
        unsafe { ffi::aterm_is_list(self.term) }
149
170651
    }
150

            
151
    /// Returns true iff this is the empty aterm_list
152
170652
    pub fn is_empty_list(&self) -> bool {
153
170652
        unsafe { ffi::aterm_is_empty_list(self.term) }
154
170652
    }
155

            
156
    /// Returns true iff this is a aterm_int
157
133104
    pub fn is_int(&self) -> bool {
158
133104
        unsafe { ffi::aterm_is_int(self.term) }
159
133104
    }
160

            
161
    /// Returns the head function symbol of the term.
162
14005997586
    pub fn get_head_symbol(&self) -> SymbolRef<'_> {
163
14005997586
        self.require_valid();
164
14005997586
        unsafe { ffi::get_aterm_function_symbol(self.term).into() }
165
14005997586
    }
166

            
167
    /// Returns an iterator over all arguments of the term that runs in pre order traversal of the term trees.
168
421412294
    pub fn iter(&self) -> TermIterator<'_> {
169
421412294
        TermIterator::new(self.copy())
170
421412294
    }
171

            
172
    /// Panics if the term is default
173
34967311887
    pub fn require_valid(&self) {
174
34967311887
        debug_assert!(
175
34967311887
            !self.is_default(),
176
            "This function can only be called on valid terms, i.e., not default terms"
177
        );
178
34967311887
    }
179
}
180

            
181
impl fmt::Display for ATermRef<'_> {
182
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
183
        self.require_valid();
184
        write!(f, "{:?}", self)
185
    }
186
}
187

            
188
impl fmt::Debug for ATermRef<'_> {
189
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190
        if self.is_default() {
191
            write!(f, "<default>")?;
192
        } else {
193
            unsafe {
194
                write!(f, "{}", ffi::print_aterm(self.term))?;
195
            }
196
        }
197

            
198
        Ok(())
199
    }
200
}
201

            
202
/// The protected version of [ATermRef], mostly derived from it.
203
#[derive(Default)]
204
pub struct ATerm {
205
    pub(crate) term: ATermRef<'static>,
206
    pub(crate) root: usize,
207

            
208
    // ATerm is not Send because it uses thread-local state for its protection
209
    // mechanism.
210
    _marker: PhantomUnsend,
211
}
212

            
213
impl ATerm {
214
    /// Obtains the underlying pointer
215
    ///
216
    /// # Safety
217
    /// Should not be modified in any way.
218
6696793
    pub(crate) unsafe fn get(&self) -> *const ffi::_aterm {
219
6696793
        self.term.get()
220
6696793
    }
221

            
222
    /// Creates a new term from the given reference and protection set root
223
    /// entry.
224
302220841
    pub(crate) fn new(term: ATermRef<'static>, root: usize) -> ATerm {
225
302220841
        ATerm {
226
302220841
            term,
227
302220841
            root,
228
302220841
            _marker: PhantomData,
229
302220841
        }
230
302220841
    }
231
}
232

            
233
impl Drop for ATerm {
234
349935467
    fn drop(&mut self) {
235
349935467
        if !self.is_default() {
236
302220841
            THREAD_TERM_POOL.with_borrow_mut(|tp| {
237
302220841
                tp.drop(self);
238
302220841
            })
239
47714626
        }
240
349935467
    }
241
}
242

            
243
impl Clone for ATerm {
244
209421216
    fn clone(&self) -> Self {
245
209421216
        self.copy().protect()
246
209421216
    }
247
}
248

            
249
impl Deref for ATerm {
250
    type Target = ATermRef<'static>;
251

            
252
1314072937
    fn deref(&self) -> &Self::Target {
253
1314072937
        &self.term
254
1314072937
    }
255
}
256

            
257
impl<'a> Borrow<ATermRef<'a>> for ATerm {
    /// Allows looking up [ATerm] keys in maps and sets using an [ATermRef].
    /// This is consistent because hashing and comparison of [ATerm] only look
    /// at the referenced term.
    fn borrow(&self) -> &ATermRef<'a> {
        let Self { term, .. } = self;
        term
    }
}
262

            
263
impl fmt::Display for ATerm {
264
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
265
        write!(f, "{}", self.copy())
266
    }
267
}
268

            
269
impl fmt::Debug for ATerm {
270
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
271
        write!(f, "{:?}", self.copy())
272
    }
273
}
274

            
275
impl Hash for ATerm {
276
133835213
    fn hash<H: Hasher>(&self, state: &mut H) {
277
133835213
        self.term.hash(state)
278
133835213
    }
279
}
280

            
281
impl PartialEq for ATerm {
282
174680860
    fn eq(&self, other: &Self) -> bool {
283
174680860
        self.term.eq(&other.term)
284
174680860
    }
285
}
286

            
287
impl PartialOrd for ATerm {
288
91608246
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
289
91608246
        Some(self.term.cmp(&other.term))
290
91608246
    }
291
}
292

            
293
impl Ord for ATerm {
294
    fn cmp(&self, other: &Self) -> Ordering {
295
        self.term.cmp(&other.term)
296
    }
297
}
298

            
299
impl Eq for ATerm {}
300

            
301
// Some convenient conversions.
302
impl From<UniquePtr<ffi::aterm>> for ATerm {
303
47315
    fn from(value: UniquePtr<ffi::aterm>) -> Self {
304
47315
        THREAD_TERM_POOL.with_borrow_mut(|tp| unsafe { tp.protect(ffi::aterm_address(&value)) })
305
47315
    }
306
}
307

            
308
impl From<&ffi::aterm> for ATerm {
309
41286
    fn from(value: &ffi::aterm) -> Self {
310
41286
        THREAD_TERM_POOL.with_borrow_mut(|tp| unsafe { tp.protect(ffi::aterm_address(value)) })
311
41286
    }
312
}
313

            
314
/// The same as [ATerm] but protected on the global protection set. This allows
315
/// the term to be Send and Sync among threads.
316
#[derive(Default)]
317
pub struct ATermGlobal {
318
    pub(crate) term: ATermRef<'static>,
319
    pub(crate) root: usize,
320
}
321

            
322
impl Drop for ATermGlobal {
323
200
    fn drop(&mut self) {
324
200
        if !self.is_default() {
325
200
            GLOBAL_TERM_POOL.lock().drop_term(self);
326
200
        }
327
200
    }
328
}
329

            
330
impl Clone for ATermGlobal {
331
    fn clone(&self) -> Self {
332
        self.copy().protect_global()
333
    }
334
}
335

            
336
impl Deref for ATermGlobal {
337
    type Target = ATermRef<'static>;
338

            
339
600
    fn deref(&self) -> &Self::Target {
340
600
        &self.term
341
600
    }
342
}
343

            
344
impl Hash for ATermGlobal {
345
    fn hash<H: Hasher>(&self, state: &mut H) {
346
        self.term.hash(state)
347
    }
348
}
349

            
350
impl PartialEq for ATermGlobal {
351
    fn eq(&self, other: &Self) -> bool {
352
        self.term.eq(&other.term)
353
    }
354
}
355

            
356
impl PartialOrd for ATermGlobal {
357
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
358
        Some(self.term.cmp(&other.term))
359
    }
360
}
361

            
362
impl Ord for ATermGlobal {
363
    fn cmp(&self, other: &Self) -> Ordering {
364
        self.term.cmp(&other.term)
365
    }
366
}
367

            
368
impl Eq for ATermGlobal {}
369

            
370
impl fmt::Display for ATermGlobal {
371
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
372
        write!(f, "{}", self.copy())
373
    }
374
}
375

            
376
impl fmt::Debug for ATermGlobal {
377
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
378
        write!(f, "{:?}", self.copy())
379
    }
380
}
381

            
382
impl From<ATerm> for ATermGlobal {
    /// Protects the term globally; the thread-local protection held by
    /// `value` is released when `value` is dropped at the end of this call.
    fn from(value: ATerm) -> Self {
        value.term.protect_global()
    }
}
387

            
388
pub struct ATermList<T> {
389
    term: ATerm,
390
    _marker: PhantomData<T>,
391
}
392

            
393
impl<T: From<ATerm>> ATermList<T> {
    /// Returns the first element of the list, protected and converted to `T`.
    ///
    /// NOTE(review): on an empty list `arg(0)` is expected to trip the debug
    /// arity assertion; check `is_empty` first.
    pub fn head(&self) -> T {
        let first = self.term.arg(0).protect();
        T::from(first)
    }
}
399

            
400
impl<T> ATermList<T> {
401
    /// Returns true iff the list is empty.
402
170652
    pub fn is_empty(&self) -> bool {
403
170652
        self.term.is_empty_list()
404
170652
    }
405

            
406
    /// Obtain the tail, i.e. the remainder, of the list.
407
129364
    pub fn tail(&self) -> ATermList<T> {
408
129364
        self.term.arg(1).into()
409
129364
    }
410

            
411
    /// Returns an iterator over all elements in the list.
412
41287
    pub fn iter(&self) -> ATermListIter<T> {
413
41287
        ATermListIter { current: self.clone() }
414
41287
    }
415
}
416

            
417
impl<T> Clone for ATermList<T> {
418
41287
    fn clone(&self) -> Self {
419
41287
        ATermList {
420
41287
            term: self.term.clone(),
421
41287
            _marker: PhantomData,
422
41287
        }
423
41287
    }
424
}
425

            
426
impl<T> From<ATermList<T>> for ATerm {
427
    fn from(value: ATermList<T>) -> Self {
428
        value.term
429
    }
430
}
431

            
432
impl<T: From<ATerm>> Iterator for ATermListIter<T> {
    type Item = T;

    /// Yields the head of the remaining list and advances to its tail.
    fn next(&mut self) -> Option<Self::Item> {
        if self.current.is_empty() {
            return None;
        }

        let element = self.current.head();
        let rest = self.current.tail();
        self.current = rest;
        Some(element)
    }
}
445

            
446
impl<T> From<ATerm> for ATermList<T> {
447
1
    fn from(value: ATerm) -> Self {
448
1
        debug_assert!(value.term.is_list(), "Can only convert a aterm_list");
449
1
        ATermList::<T> {
450
1
            term: value,
451
1
            _marker: PhantomData,
452
1
        }
453
1
    }
454
}
455

            
456
impl<'a, T> From<ATermRef<'a>> for ATermList<T> {
457
170650
    fn from(value: ATermRef<'a>) -> Self {
458
170650
        debug_assert!(value.is_list(), "Can only convert a aterm_list");
459
170650
        ATermList::<T> {
460
170650
            term: value.protect(),
461
170650
            _marker: PhantomData,
462
170650
        }
463
170650
    }
464
}
465

            
466
impl<T: From<ATerm>> IntoIterator for ATermList<T> {
467
    type IntoIter = ATermListIter<T>;
468
    type Item = T;
469

            
470
    fn into_iter(self) -> Self::IntoIter {
471
        self.iter()
472
    }
473
}
474

            
475
impl<T: From<ATerm>> IntoIterator for &ATermList<T> {
476
    type IntoIter = ATermListIter<T>;
477
    type Item = T;
478

            
479
    fn into_iter(self) -> Self::IntoIter {
480
        self.iter()
481
    }
482
}
483

            
484
/// An iterator over the arguments of a term.
485
#[derive(Default)]
486
pub struct ATermArgs<'a> {
487
    term: ATermRef<'a>,
488
    arity: usize,
489
    index: usize,
490
}
491

            
492
impl<'a> ATermArgs<'a> {
493
5904464695
    fn new(term: ATermRef<'a>) -> ATermArgs<'a> {
494
5904464695
        let arity = term.get_head_symbol().arity();
495
5904464695
        ATermArgs { term, arity, index: 0 }
496
5904464695
    }
497

            
498
1947348
    pub fn is_empty(&self) -> bool {
499
1947348
        self.arity == 0
500
1947348
    }
501
}
502

            
503
impl<'a> Iterator for ATermArgs<'a> {
504
    type Item = ATermRef<'a>;
505

            
506
107495691
    fn next(&mut self) -> Option<Self::Item> {
507
107495691
        if self.index < self.arity {
508
66457460
            let res = unsafe { Some(self.term.arg(self.index).upgrade_unchecked(&self.term)) };
509
66457460

            
510
66457460
            self.index += 1;
511
66457460
            res
512
        } else {
513
41038231
            None
514
        }
515
107495691
    }
516
}
517

            
518
impl DoubleEndedIterator for ATermArgs<'_> {
519
12799612478
    fn next_back(&mut self) -> Option<Self::Item> {
520
12799612478
        if self.index < self.arity {
521
6931348244
            let res = unsafe { Some(self.term.arg(self.arity - 1).upgrade_unchecked(&self.term)) };
522
6931348244

            
523
6931348244
            self.arity -= 1;
524
6931348244
            res
525
        } else {
526
5868264234
            None
527
        }
528
12799612478
    }
529
}
530

            
531
impl ExactSizeIterator for ATermArgs<'_> {
532
61971474
    fn len(&self) -> usize {
533
61971474
        self.arity - self.index
534
61971474
    }
535
}
536

            
537
pub struct ATermListIter<T> {
538
    current: ATermList<T>,
539
}
540

            
541
/// An iterator over all subterms of the given [ATerm] in preorder traversal, i.e.,
542
/// for f(g(a), b) we visit f(g(a), b), g(a), a, b.
543
pub struct TermIterator<'a> {
544
    queue: VecDeque<ATermRef<'a>>,
545
}
546

            
547
impl TermIterator<'_> {
548
421412294
    pub fn new(t: ATermRef) -> TermIterator {
549
421412294
        TermIterator {
550
421412294
            queue: VecDeque::from([t]),
551
421412294
        }
552
421412294
    }
553
}
554

            
555
impl<'a> Iterator for TermIterator<'a> {
556
    type Item = ATermRef<'a>;
557

            
558
5868377040
    fn next(&mut self) -> Option<Self::Item> {
559
5868377040
        match self.queue.pop_back() {
560
5868264234
            Some(term) => {
561
                // Put subterms in the queue
562
6931348842
                for argument in term.arguments().rev() {
563
6931348244
                    unsafe {
564
6931348244
                        self.queue.push_back(argument.upgrade_unchecked(&term));
565
6931348244
                    }
566
                }
567

            
568
5868264234
                Some(term)
569
            }
570
112806
            None => None,
571
        }
572
5868377040
    }
573
}
574

            
575
#[cfg(test)]
mod tests {
    use std::sync::Mutex;
    use std::thread;

    use test_log::test;

    use crate::aterm::random_term;
    use crate::aterm::TermPool;
    use rand::rngs::StdRng;
    use rand::Rng;
    use rand::SeedableRng;

    use super::*;

    /// Make sure that every subterm has the same number of arguments as its arity.
    fn verify_term(term: &ATermRef<'_>) {
        for subterm in term.iter() {
            assert_eq!(
                subterm.get_head_symbol().arity(),
                subterm.arguments().len(),
                "The arity matches the number of arguments."
            )
        }
    }

    #[test]
    fn test_term_iterator() {
        let mut tp = TermPool::new();
        let t = tp.from_string("f(g(a),b)").unwrap();

        let mut result = t.iter();
        assert_eq!(result.next().unwrap(), tp.from_string("f(g(a),b)").unwrap().copy());
        assert_eq!(result.next().unwrap(), tp.from_string("g(a)").unwrap().copy());
        assert_eq!(result.next().unwrap(), tp.from_string("a").unwrap().copy());
        assert_eq!(result.next().unwrap(), tp.from_string("b").unwrap().copy());
        assert!(result.next().is_none(), "All subterms have been visited.");
    }

    #[test]
    fn test_aterm_list() {
        let mut tp = TermPool::new();
        let list: ATermList<ATerm> = tp.from_string("[f,g,h,i]").unwrap().into();

        assert!(!list.is_empty());

        // Convert into normal vector.
        let values: Vec<ATerm> = list.iter().collect();

        assert_eq!(values.len(), 4);
        assert_eq!(values[0], tp.from_string("f").unwrap());
        assert_eq!(values[1], tp.from_string("g").unwrap());
        assert_eq!(values[2], tp.from_string("h").unwrap());
        assert_eq!(values[3], tp.from_string("i").unwrap());
    }

    #[test]
    fn test_global_aterm_pool_parallel() {
        let seed: u64 = rand::rng().random();
        println!("seed: {seed}");

        let terms: Mutex<Vec<ATermGlobal>> = Mutex::new(vec![]);

        thread::scope(|s| {
            for _ in 0..2 {
                s.spawn(|| {
                    let mut tp = TermPool::new();

                    let mut rng = StdRng::seed_from_u64(seed);
                    for _ in 0..100 {
                        let t = random_term(
                            &mut tp,
                            &mut rng,
                            &[("f".to_string(), 2)],
                            &["a".to_string(), "b".to_string()],
                            10,
                        );

                        terms.lock().unwrap().push(t.clone().into());

                        tp.collect();

                        verify_term(&t);
                    }
                });
            }
        });

        // The globally protected terms must have survived the collections above.
        for term in &*terms.lock().unwrap() {
            verify_term(term);
        }
    }
}