#ifndef LISTEX_H
#define LISTEX_H

#include "list.h"

/*@

// reverseHead(head, tail): the last element of the non-empty list
// cons(head, tail). reverse_head_tail_lemma relates this value to the
// first element of reverse(cons(head, tail)).
fixpoint t reverseHead<t>(t head, list<t> tail) {
    switch (tail) {
        case nil:
            // Single-element list: head is also its last element.
            return head;
        case cons(next, rest):
            // Discard head and keep walking towards the end of the list.
            return reverseHead(next, rest);
    }
}

// reverseTail(head, tail): every element of the non-empty list
// cons(head, tail) except its last one, listed in reverse order.
// reverse_head_tail_lemma relates this value to the tail of
// reverse(cons(head, tail)).
fixpoint list<t> reverseTail<t>(t head, list<t> tail) {
    switch (tail) {
        case nil:
            // Single-element list: nothing remains once the last element is removed.
            return nil;
        case cons(next, rest):
            // Process the shorter list first, then put head at the very end.
            return append(reverseTail(next, rest), cons(head, nil));
    }
}

// Shows that reversing the non-empty list cons(h, t) yields a list whose
// first element is reverseHead(h, t) and whose remainder is reverseTail(h, t).
// The proof is by structural induction on t.
lemma void reverse_head_tail_lemma<t>(t h, list<t> t)
    requires true;
    ensures reverse(cons(h, t)) == cons(reverseHead(h, t), reverseTail(h, t));
{
    switch (t) {
        case nil:
            // Base case: no explicit proof step is written here.
        case cons(next, rest):
            // Inductive step: apply the lemma to the shorter list cons(next, rest).
            reverse_head_tail_lemma(next, rest);
    }
}

// before(x, xs): the elements of xs that precede the first occurrence of x.
// When x does not occur in xs, the result contains every element of xs.
fixpoint list<t> before<t>(t x, list<t> xs) {
    switch (xs) {
        case nil:
            return nil;
        case cons(y, ys):
            // Stop at the first match; otherwise keep y and continue with ys.
            return x == y ? nil : cons(y, before(x, ys));
    }
}

// after(x, xs): the elements of xs that follow the first occurrence of x.
// When x does not occur in xs, the result is nil.
fixpoint list<t> after<t>(t x, list<t> xs) {
    switch (xs) {
        case nil:
            return nil;
        case cons(y, ys):
            // On the first match everything behind it is the answer;
            // otherwise skip y and keep searching in ys.
            return x == y ? ys : after(x, ys);
    }
}

// Unfolds drop(k, xs) by one element.
// Requires: k is within [0, length(xs)] and drop(k, xs) is known to be non-empty.
// Ensures:  drop(k, xs) is nth(k, xs) followed by drop(k + 1, xs), and in
//           addition k is strictly less than length(xs).
// Declaration only; no proof body appears in this header.
lemma void drop_k_plus_one_alt<t>(int k, list<t> xs);
    requires 0 <= k &*& k <= length(xs) &*& drop(k, xs) != nil;
    ensures drop(k, xs) == cons(nth(k, xs), drop(k + 1, xs)) &*& length(xs) > k;

// Extends a foreach chunk over the first n elements of xs by one element.
// Requires: n is a valid index into xs, foreach(take(n, xs), p) for some
//           predicate p, and a separate p chunk for the element nth(n, xs).
// Ensures:  both chunks are merged into foreach(take(n + 1, xs), p).
// Declaration only; no proof body appears in this header.
lemma void foreach_take_plus_one_unseparate<a>(int n, list<a> xs);
    requires 0 <= n &*& n < length(xs) &*& foreach(take(n, xs), ?p) &*& p(nth(n, xs));
    ensures foreach(take(n + 1, xs), p);

// Splits foreach(xs, p) around the member x.
// Requires: foreach(xs, p) for some predicate p, and x is a member of xs.
// Ensures:  three separate chunks: foreach over before(x, xs), p(x) itself,
//           and foreach over after(x, xs).
// Declaration only; no proof body appears in this header.
lemma void foreach_separate<t>(list<t> xs, t x);
    requires foreach(xs, ?p) &*& mem(x, xs) == true;
    ensures foreach(before(x, xs), p) &*& p(x) &*& foreach(after(x, xs), p);

// Reassembles foreach(xs, p) from the three chunks that foreach_separate produces.
// Requires: x is a member of xs, plus foreach over before(x, xs), p(x),
//           and foreach over after(x, xs), all for the same predicate p.
// Ensures:  foreach(xs, p).
// Declaration only; no proof body appears in this header.
lemma void foreach_unseparate<t>(list<t> xs, t x);
    requires mem(x, xs) == true &*& foreach(before(x, xs), ?p) &*& p(x) &*& foreach(after(x, xs), p);
    ensures foreach(xs, p);

// Splits foreach(es, p) around the element at index i.
// Requires: foreach(es, p) for some predicate p, and i is a valid index into es.
// Ensures:  three separate chunks: foreach over take(i, es), p(nth(i, es)),
//           and foreach over drop(i + 1, es).
// Declaration only; no proof body appears in this header.
lemma void foreach_separate_ith<t>(int i, list<t> es);
    requires foreach(es, ?p) &*& 0 <= i &*& i < length(es);
    ensures foreach(take(i, es), p) &*& p(nth(i, es)) &*& foreach(drop(i + 1, es), p);

// Reassembles foreach(es, p) from the three chunks that foreach_separate_ith
// produces, with the element at index i being the same nth(i, es) as in es.
// Requires: foreach over take(i, es), p(nth(i, es)), foreach over
//           drop(i + 1, es), and i is a valid index into es.
// Ensures:  foreach(es, p).
// Declaration only; no proof body appears in this header.
lemma void foreach_unseparate_ith_nochange<t>(int i, list<t> es);
    requires foreach(take(i, es), ?p) &*& p(nth(i, es)) &*& foreach(drop(i + 1, es), p) &*& 0 <= i &*& i < length(es);
    ensures foreach(es, p);

// foreach3(xs, ys, zs, p): the three-list analogue of foreach. It holds a
// p(x, y, z) chunk for each position, taking x, y and z from the same position
// of xs, ys and zs respectively.
// When at least one list is non-empty, the body demands that all three are
// cons cells, so the predicate can only hold for lists of equal length.
predicate foreach3(list<void *> xs, list<void *> ys, list<void *>zs, predicate(void *, void *, void *) p) =
    // All three lists empty: nothing is owned.
    xs == nil && ys == nil && zs == nil ?
        true
    :
        // Otherwise: one p chunk for the three heads, plus foreach3 for the three tails.
        p(?x, ?y, ?z) &*& foreach3(?xs0, ?ys0, ?zs0, p) &*&
        xs == cons(x, xs0) &*&
        ys == cons(y, ys0) &*&
        zs == cons(z, zs0);

@*/

#endif