Capucine Example: Merge Sort

An implementation of the merge sort algorithm.

Back to Capucine

Discussion

This example implements the merge sort algorithm. Apart from the trivial requirement that array sizes are non-negative, it uses no class invariant, but it shows that Capucine can handle real-life programs. It also shows how, thanks to region polymorphism, Capucine can keep the two sub-arrays of each recursive call in separate regions. Only the sortedness of the result is specified; the property that the output is a permutation of the input is left out. To compile it, go into the capucine directory and type:

ocamlbuild bench/merge.gwhy

Code

(* Logical arrays indexed by int.  [store(t, i, v)] denotes [t] with index
   [i] updated to [v] (a new logical value; [t] itself is unchanged), and
   [select(t, i)] reads index [i].  Their meaning is fixed by the axioms
   select_eq and select_neq. *)
logic type array (a)
logic function store (array (a), int, a): array (a)
logic function select (array (a), int): a

(* Reading back the index that was just written yields the written value.
   Note: the bound variable [a] shares its name with the type parameter
   [a] of [array (a)]. *)
axiom select_eq:
  forall a: array (a).
  forall i: int.
  forall v: a.
  [select(store(a, i, v), i)] = [v]

(* Writing at index [i] leaves every other index [j] unchanged. *)
axiom select_neq:
  forall a: array (a).
  forall i: int.
  forall j: int.
  forall v: a.
  [i] <> [j] ==>
  [select(store(a, i, v), j)] = [select(a, j)]

(* Finite Arrays *)

(* Field names used to access the pair stored in [Array]: [values] (used
   below as the backing [InfiniteArray]) and [size] (used below as the
   length).  NOTE(review): mapping inferred from the component order of
   [Array]'s pair type and from how the fields are used — confirm. *)
selector (values, size)

(* A heap cell whose contents are a logical [array (a)] value.  [Array]
   owns one of these as its backing store; it carries no bound on its
   indices, hence the name. *)
class InfiniteArray (a) =
  array (a)
end

(* A finite array.  It owns the region A that holds its backing
   [InfiniteArray] (the [values] field), and it stores its length (the
   [size] field).  The class invariant only requires the size to be
   non-negative. *)
class Array (a) =
  own A;
  (InfiniteArray (a) [A] * int)
  invariant (x) = [x.size] >= [0]
end

(*
predicate models(a: Array (a) [R], b: array (a)) =
  forall i: int.
  [0] <= [i] and [i] < [!a.size] ==>
  [select(!(!a.values), i)] = [select(b, i)]
*)

(* [models(a, b)]: the logical array [b] is exactly the contents of [a]'s
   backing store, including indices outside 0 .. size-1.  (A weaker
   variant that only constrains the in-bounds indices is kept commented
   out just above.) *)
predicate models(a: Array (a) [R], b: array (a)) =
  [b] = [!(!a.values)]

(* Allocate a new array in region R and set its length to [size], which
   must be non-negative.  The backing store is kept exactly as allocated,
   so nothing is specified about the initial contents. *)
val create(size: int): Array (a) [R]
  consumes R^e
  produces R^c
  requires [size] >= [0]
  ensures [!result.size] = [size]
  =
    let arr = new Array (a) [R] in
    (* Keep the freshly allocated backing store; overwrite only the length. *)
    arr := (!arr.values, size);
    arr

(* Read element [i] of [a]; [i] must lie in 0 .. size-1.  The result is
   element [i] of any logical array that models [a].  The region R is
   handed back in the same state it was received in. *)
val get(a: Array (a) [R], i: int): a
  consumes R^c
  produces R^c
  requires [0] <= [i] and [i] < [!a.size]
  ensures
    forall m: array (a).
    models ([a], [m]) ==>
    [result] = [select(m, i)]
  =
    select(!(!a.values), i)

(* Write [v] at index [i] of [a]; [i] must lie in 0 .. size-1.
   The postcondition relates the model before the call ([m1]) to the
   model after it ([m2]): index [i] now holds [v], and every other
   in-bounds index keeps its old value.
   NOTE(review): the postcondition does not state that [!a.size] is
   unchanged.  Callers that need the size after a [set] (e.g. for the
   bounds check of a later [set]) may be unable to prove it — confirm. *)
val set(a: Array (a) [R], i: int, v: a): unit
  consumes R^c
  produces R^c
  requires [0] <= [i] and [i] < [!a.size]
  ensures
    forall m1: array (a).
    forall m2: array (a).
    old(models ([a], [m1])) ==>
    models ([a], [m2]) ==>
    [select(m2, i)] = [v]
    and forall j: int.
    [0] <= [j] and [j] < [!a.size] ==>
    [j] <> [i] ==>
    [select(m1, j)] = [select(m2, j)]
  =
    (* Replace the backing logical array with an updated copy.
       NOTE(review): [focus] presumably grants write access to the region
       owned by [a] — confirm against the Capucine documentation. *)
    (focus !a.1) := store(!(!a.values), i, v)

(* A heap cell holding an int.  [copy] and [merge] use it as a mutable
   loop counter. *)
class Int =
  int
end

(* Return a new array in region R that holds a copy of the slice a[i..j]
   (both bounds inclusive), so the result has size j - i + 1.  [a] is
   only read; element k of [a] lands at index k - i of the result.
   NOTE(review): [set]'s postcondition does not state that the array size
   is preserved, and the loop invariant below does not restate
   [!r.size] = [j - i + 1].  Confirm that both the size postcondition and
   the bounds precondition of [set] inside the loop are still provable. *)
val copy(a: Array (a) [A], i: int, j: int): Array (a) [R]
  consumes A^c, R^e
  produces A^c, R^c
  requires [0] <= [i] and [i] <= [j] and [j] < [!a.size]
  ensures
    [!result.size] = [j - i + 1] and
    forall ma: array (a).
    forall mr: array (a).
    models ([a], [ma]) ==>
    models ([result], [mr]) ==>
    forall k: int.
    [i] <= [k] and [k] <= [j] ==>
    [select(ma, k)] = [select(mr, k - i)]
  =
    (* x walks from i to j; iteration x copies a[x] into slot x - i of r. *)
    let x = new Int in
    x := i;
    let r = (create(j - i + 1): Array (a) [R]) in
    (* Invariant: slots 0 .. !x - i - 1 of r already hold a[i .. !x - 1]. *)
    while !x <= j
      invariant
        [!x] >= [i]
        and
        forall ma: array (a).
        forall mr: array (a).
        models ([a], [ma]) ==>
        models ([r], [mr]) ==>
        forall k: int.
        [i] <= [k] and [k] < [!x] ==>
        [select(ma, k)] = [select(mr, k - i)]
    do (
      set(r, !x - i, get(a, !x));
      x := !x + 1
    );
    r

(* [sorted_up_to(a, size)]: the prefix of [a] at indices 0 .. size-1 is in
   non-decreasing order.  The bound [size] is exclusive: index [size]
   itself is not included. *)
predicate sorted_up_to(a: Array (int) [R], size: int) =
  forall m: array (int).
  models([a], [m]) ==>
  forall i: int.
  forall j: int.
  [0] <= [i] ==>
  [i] <= [j] ==>
  [j] < [size] ==>
  [select(m, i)] <= [select(m, j)]

(* The whole array, indices 0 .. size-1, is in non-decreasing order. *)
predicate sorted(a: Array (int) [R]) =
  sorted_up_to([a], [!a.size])

(* If the first [size] elements of [a] are sorted, and [b]'s backing store
   agrees with [a]'s on indices 0 .. size-1, then the first [size]
   elements of [b] are sorted as well.  Presumably this is meant to help
   the prover carry sortedness across updates that leave the prefix
   unchanged — confirm. *)
lemma sorted_footprint:
  forall a: Array (int) [R].
  forall b: Array (int) [R].
  forall size: int.
  sorted_up_to([a], [size]) ==>
  (forall i: int.
     [0] <= [i] and [i] < [size] ==>
     [select(!(!b.values), i)] = [select(!(!a.values), i)]) ==>
  sorted_up_to([b], [size])

(*
predicate branche1(x: int) = [x] >= [0]
predicate branche2(x: int) = [x] >= [0]
predicate branche3(x: int) = [x] >= [0]
predicate branche4(x: int) = [x] >= [0]
*)

(* Merge the sorted, non-empty arrays [a] and [b] into [c], whose size must
   be the sum of theirs.  Only [c] is written; [a] and [b] are only read.
   The postcondition states only that [c] is sorted.  It does not state
   that [c] holds exactly the elements of [a] and [b]. *)
val merge(a: Array (int) [A], b: Array (int) [B], c: Array (int) [C]): unit
  consumes A^c, B^c, C^c
  produces A^c, B^c, C^c
  requires
    [!c.size] = [!a.size + !b.size]
    and [!a.size] > [0]
    and [!b.size] > [0]
    and sorted([a])
    and sorted([b])
  ensures sorted([c])
  =
    (* Cursors: next unread index of [a], next unread index of [b], and
       next slot to fill in [c]. *)
    let ia = new Int in
    let ib = new Int in
    let ic = new Int in
    ia := 0;
    ib := 0;
    ic := 0;
    (* Invariant: the cursors stay in bounds with !ic = !ia + !ib, the
       filled prefix c[0 .. !ic - 1] is sorted, and every filled slot is
       <= every element of [a] and [b] that has not been consumed yet. *)
    while !ic < !c.size
      invariant
        [!ia] >= [0]
        and [!ia] <= [!a.size]
        and [!ib] >= [0]
        and [!ib] <= [!b.size]
        and [!ic] >= [0]
        and [!ic] <= [!c.size]
        and ([!ic] = [!ia + !ib])
        and sorted_up_to([c], [!ic])
        and forall i: int. [0] <= [i] and [i] < [!ic] ==>
        (forall j: int. [!ia] <= [j] and [j] < [!a.size] ==>
          [select(!(!c.values), i)] <= [select(!(!a.values), j)]) and
        (forall j: int. [!ib] <= [j] and [j] < [!b.size] ==>
          [select(!(!c.values), i)] <= [select(!(!b.values), j)])
    do (
      if !ia >= !a.size then (
        (* [a] is exhausted.  Because !ic < !c.size and !ic = !ia + !ib,
           [b] still has an element left. *)
        set(c, !ic, get(b, !ib));
        ib := !ib + 1
      ) else if !ib >= !b.size then (
        (* [b] is exhausted: take the next element of [a]. *)
        set(c, !ic, get(a, !ia));
        ia := !ia + 1
      ) else (
        (* Both inputs have elements left: copy the smaller head, and
           prefer [a] on ties. *)
        let va = get(a, !ia) in
        let vb = get(b, !ib) in
        if va <= vb then (
          set(c, !ic, va);
          ia := !ia + 1;
          (* The slot just filled is a lower bound of the remaining
             elements of both inputs.  These assertions mirror the last
             two invariant clauses for that slot. *)
          assert (forall j: int. [!ia] <= [j] and [j] < [!a.size] ==>
                    [select(!(!c.values), !ic)] <= [select(!(!a.values), j)]);
          assert (forall j: int. [!ib] <= [j] and [j] < [!b.size] ==>
                    [select(!(!c.values), !ic)] <= [select(!(!b.values), j)])
        ) else (
          set(c, !ic, vb);
          ib := !ib + 1;
          (* The same lower-bound facts as in the other branch. *)
          assert (forall j: int. [!ia] <= [j] and [j] < [!a.size] ==>
                    [select(!(!c.values), !ic)] <= [select(!(!a.values), j)]);
          assert (forall j: int. [!ib] <= [j] and [j] < [!b.size] ==>
                    [select(!(!c.values), !ic)] <= [select(!(!b.values), j)])
        )
      );
      ic := !ic + 1
    )

(* Sort [a] in place with a recursive merge sort.  The two halves are
   copied into new arrays that live in the local regions R1 and R2, so
   the recursive calls work on separate regions.  The halves are sorted
   recursively and then merged back into [a].  Only sortedness is
   specified; the permutation clause below is commented out. *)
val sort(a: Array (int) [A]): unit
  consumes A^c
  produces A^c
  ensures
    sorted([a])
(*  and permutation([a], [old(a)])*) 
  =
    region R1: Array (int) in
    region R2: Array (int) in
    (* Arrays of size 0 or 1 are already sorted, so nothing is done. *)
    if !a.size > 1 then (
      (* Left half: indices 0 .. size/2 - 1.  Right half: size/2 .. size-1.
         When size > 1 both halves are non-empty, as [merge] requires. *)
      let a1 = (copy(a, 0, !a.size / 2 - 1): Array (int) [R1]) in
      let a2 = (copy(a, !a.size / 2, !a.size - 1): Array (int) [R2]) in
      sort(a1);
      sort(a2);
      merge(a1, a2, a)
    )

(* Small driver: sort a 3-element array, then assert that its three
   elements end up in non-decreasing order.  The initialising writes are
   commented out, so the contents are whatever [create] leaves, and the
   assertion must follow from [sort]'s postcondition alone.
   NOTE(review): [sort] does not state that the array size is preserved,
   so it is unclear whether indices 0..2 are provably inside the sorted
   prefix — confirm. *)
val main(): unit =
  region R: Array (int) in
  let a = (create(3): Array(int) [R]) in
(*  set(a, 0, 42);
  set(a, 1, 13);
  set(a, 2, 42);*)
  sort(a);
  assert
    (forall m: array (int). models([a], [m]) ==>
       [select(m, 1)] <= [select(m, 2)]
       and [select(m, 0)] <= [select(m, 1)]
       and [select(m, 0)] <= [select(m, 2)])