(** * Library SortingNetworks *)

Require Export Permutations.

(** * Comparator networks *)


(** A comparator is a pair [(i,j)] of channel indices.  Running it on a
    sequence (see [apply]) leaves the pair of values in place when the value
    on channel [i] is at most the one on channel [j], and exchanges them
    otherwise. *)
Definition comparator : Set := (prod nat nat).

(** Equality of comparators is decidable. *)
Lemma comp_eq_dec : forall c c':comparator, {c=c'}+{c<>c'}.

(** [make m n] is the comparator [(n,m)].  The arguments are deliberately
    swapped so that [map (make k) l] pairs every index of [l] with the fixed
    second channel [k]; this is how [all_st_comps] uses it. *)
Definition make (m n:nat) : comparator := (n,m).

(** [x [<] y] is notation for the comparator [(x,y)]. *)
Global Notation "x [<] y" := (x,y) (at level 0).

(** A comparator network is a list of comparators; [full_apply] runs them in
    list order, head first. *)
Definition comparator_network : Set := list comparator.

(** Equality of comparator networks is decidable. *)
Lemma CN_eq_dec : forall C C' : comparator_network, {C = C'} + {C <> C'}.

(** [comp_channels n c] holds when both channels of [c] are indices below
    [n] and the two channels are distinct.  The comparator may point in
    either direction. *)
Definition comp_channels (n:nat) (c:comparator) : Prop :=
  match c with
  | (i, j) => i < n /\ j < n /\ i <> j
  end.

(** [channels n S]: every comparator occurring in [S] uses two distinct
    channels below [n] (see [comp_channels]). *)
Definition channels (n:nat) (S:comparator_network) : Prop :=
  forall c : comparator, In c S -> comp_channels n c.

(** Channel validity is preserved by concatenating two networks. *)
Lemma channels_app : forall n C C', channels n C -> channels n C' ->
                                    channels n (C ++ C').

(** Channel validity of a network carries over to its tail. *)
Lemma channels_cons : forall n S c, channels n (c :: S) -> channels n S.

(** [comp_standard n c] holds when both channels of [c] are below [n] and
    the first channel is strictly smaller than the second, i.e. the
    comparator points from a lower to a higher channel. *)
Definition comp_standard (n:nat) (c:comparator) : Prop :=
  match c with
  | (i, j) => i < n /\ j < n /\ i < j
  end.

(** [standard n S]: every comparator occurring in [S] is standard on [n]
    channels (see [comp_standard]). *)
Definition standard (n:nat) (S:comparator_network) : Prop :=
  forall c : comparator, In c S -> comp_standard n c.

(** Standard networks only use valid, distinct channels. *)
Lemma standard_channels : forall n C, standard n C -> channels n C.

(** Standardness is preserved by concatenating two networks. *)
Lemma standard_app : forall n C C', standard n C -> standard n C' ->
                                    standard n (C ++ C').

(** * All possible (standard) comparators *)


(** [till_n n] is the list [0; 1; ...; n] in increasing order, of length
    [S n].  Each step appends the new element at the end of the list built
    so far, so evaluating it takes time quadratic in [n]. *)
Fixpoint till_n (n:nat) : (list nat) :=
  match n with
  | O => (O :: nil)
  | S k => till_n k ++ (S k :: nil)
  end.

(** Completeness: every [i <= n] occurs in [till_n n]. *)
Lemma till_n_lemma : forall n i:nat, i<=n -> In i (till_n n).

(** Soundness: every element of [till_n n] is at most [n] (despite the
    name, the bound is non-strict). *)
Lemma till_n_lt : forall n i:nat, In i (till_n n) -> i <= n.

(** [till_n n] has exactly [S n] elements. *)
Lemma till_n_length: forall n, length (till_n n) = S n.

(** The element at position [m] of [till_n n] is [m] itself, for any
    default value [x]. *)
Lemma till_n_char : forall n m, m <= n -> forall x, nth m (till_n n) x = m.

(** [till_n n] has no duplicates. *)
Lemma till_n_NoDup : forall n, NoDup (till_n n).

(** The following properties of [NoDup] are proved using [till_n]. *)

(** Counting facts for duplicate-free lists of bounded values.
    NOTE(review): [all_lt] is not defined in this part of the file;
    presumably [all_lt n l] states that every element of [l] is below [n]
    -- confirm against its definition.

    A duplicate-free list of values below [n] has at most [n] elements. *)
Lemma NoDup_all_lt_count : forall l n, NoDup l -> all_lt n l -> length l <= n.

(** If such a list is shorter than [n], some value below [n] is missing
    from it. *)
Lemma NoDup_all_lt_prop : forall n l, NoDup l -> all_lt n l -> length l < n ->
                                      exists k, k < n /\ ~In k l.

(** Conversely, if some value below [n] is missing, the list is shorter
    than [n]. *)
Lemma NoDup_all_lt_lt : forall n l, NoDup l -> all_lt n l ->
                        forall k, k < n -> ~In k l -> length l < n.

(** [all_st_comps n] lists every standard comparator on [n] channels, i.e.
    every pair [(i,j)] with [i < j < n].  Pairs are grouped by increasing
    [j]; within a group [i] increases. *)
Fixpoint all_st_comps (n:nat) : list comparator :=
  match n with
  | O => nil
  (* One channel: no comparator.  This case is also needed because the
     general case with k = 0 would add the non-standard pair (0,0),
     since till_n (pred 0) = [0]. *)
  | S O => nil
  (* New top channel k >= 1: add (0,k), (1,k), ..., (k-1,k);
     recall that make k i = (i,k). *)
  | S k => all_st_comps k ++ (map (make k) (till_n (pred k)))
  end.

(** Completeness: every standard comparator is listed. *)
Lemma all_st_comps_lemma : forall n c, comp_standard n c -> In c (all_st_comps n).

(** Soundness: every listed comparator is standard. *)
Lemma all_st_comps_standard : forall n c,
                              In c (all_st_comps n) -> comp_standard n c.

(** Each standard comparator is listed exactly once. *)
Lemma all_st_comps_NoDup : forall n, NoDup (all_st_comps n).

(** * Execution *)


(** [apply c s] runs the comparator [c = (i,j)] on [s].  With [x] the value
    at [i] and [y] the value at [j]: if [x <= y] the sequence is returned
    unchanged; otherwise [x] is written at [j] and [y] at [i] (assuming
    [set s k v] writes [v] at position [k]), so the two values are
    exchanged -- compare [apply_transp_le] and [apply_transp_gt].
    No bound or distinctness check on [i], [j] is made here; those
    conditions are stated separately ([comp_channels], [comp_standard]).
    NOTE(review): declared with [Fixpoint] although it never calls itself;
    recent Coq versions warn about non-recursive fixpoints, but switching to
    [Definition] may change how [simpl] unfolds it in existing proofs. *)
Fixpoint apply (c:comparator) n (s:bin_seq n) : (bin_seq n) :=
  let (i,j):=c in let x:=(get s i) in let y:=(get s j) in
    match (le_lt_dec x y) with
    | left _ => s
    | right _ => set (set s j x) i y
    end.


(** If the two values are already in order, the comparator does nothing. *)
Lemma apply_transp_le : forall n, forall s:bin_seq n, forall x y,
  (get s x) <= (get s y) -> apply x[<]y s = s.

(** If the two values are out of order, the comparator acts as the
    transposition of channels [x] and [y]. *)
Lemma apply_transp_gt : forall n, forall s:bin_seq n, forall x y,
  (get s x) > (get s y) -> apply x[<]y s = apply_perm (transposition x y) s.

(** A comparator whose two channels coincide is the identity (stated as
    [s = apply ...]). *)
Lemma apply_id : forall n s i, s = apply (n:=n) i[<]i s.

(** On valid, distinct channels a comparator leaves [zeros] unchanged.
    NOTE(review): [zeros] is defined elsewhere; presumably it counts the
    zero entries of a sequence. *)
Lemma apply_zeros : forall n s i j, i < n -> j < n -> i<>j ->
                                    zeros (n:=n) (apply i[<]j s) = zeros s.

(** The same fact, phrased with [comp_channels]. *)
Lemma apply_zeros_folded : forall n s c, comp_channels n c ->
                                         zeros (apply (n:=n) c s) = zeros s.

(** Reversing a comparator: [x [<] y] equals [y [<] x] followed by the
    transposition of channels [x] and [y]. *)
Lemma apply_inv : forall x y n, forall s:bin_seq n, x < n -> y < n ->
      apply (x)[<](y) s = apply_perm (transposition x y) (apply (y)[<](x) s).

(** [full_apply S s] runs the comparators of [S] on [s] one after another,
    starting with the head of the list. *)
Fixpoint full_apply (S:comparator_network) n (s:bin_seq n) : (bin_seq n) :=
  match S with
  | nil => s
  (* the [_] is the size argument, passed explicitly inside this
     definition; elsewhere callers leave it implicit *)
  | cons c S' => full_apply S' _ (apply c s)
  end.


(** A network on valid channels leaves [zeros] unchanged. *)
Lemma full_apply_zeros : forall C n, forall s:bin_seq n, channels n C ->
                                     zeros (full_apply C s) = zeros s.

(** Running a concatenated network is running the first part and then the
    second. *)
Lemma full_apply_app : forall C C' n, forall s:bin_seq n,
                       full_apply (C++C') s = full_apply C' (full_apply C s).

(** A standard network leaves already sorted inputs unchanged. *)
Lemma standard_network_idemp : forall n S, standard n S -> forall s:bin_seq n,
                               sorted s -> full_apply S s = s.

(** * Sorting networks *)


(** [sorting_network n S] holds when [S] only uses valid, distinct channels
    below [n] and produces a sorted result on every input [s : bin_seq n]. *)
Definition sorting_network (n:nat) (S:comparator_network) : Prop :=
  channels n S /\ (forall s : bin_seq n, sorted (full_apply S s)).

(** [outputs C n] is the list of results of running [C] on each element of
    [all_bin_seqs n], in the same order. *)
Definition outputs (C:comparator_network) (n:nat) : list (bin_seq n) :=
  map (@full_apply C n) (all_bin_seqs n).

(** Networks with the same list of outputs still have the same outputs
    after the same network [C] is appended to both. *)
Lemma outputs_app : forall n C C' C'', outputs C' n = outputs C'' n ->
                                       outputs (C' ++ C) n = outputs (C'' ++ C) n.

(** Characterization used by [show_SN]: a network on valid channels is a
    sorting network as soon as every entry of the finite list [outputs C n]
    is sorted. *)
Theorem SN_char : forall C n, channels n C ->
                 (forall s, In s (outputs C n) -> sorted s) ->
                  sorting_network n C.

(** The following networks are sanity-check examples. *)

(** Example networks written with the [x [<] y] notation; every comparator
    listed points from a lower to a higher channel.  The highest channel
    index used is 1, 2, 3, 4 and 8 respectively, and the networks have 1, 3,
    5, 9 and 25 comparators.  [SN2], [SN3] and [SN4] are proved to be
    sorting networks below.  NOTE(review): no such theorem for [SN5] or
    [SN9] appears in this listing. *)
Definition SN2 := (0[<]1 :: nil).
Definition SN3 := (0[<]1 :: 0[<]2 :: 1[<]2 :: nil).
Definition SN4 := (0[<]1 :: 2[<]3 :: 0[<]2 :: 1[<]3 :: 1[<]2 :: nil).
Definition SN5 := (0[<]1 :: 2[<]3 :: 0[<]2 :: 1[<]3 :: 1[<]2 :: 3[<]4 :: 1[<]3 :: 0[<]1 :: 2[<]3 :: nil).
Definition SN9 := (0[<]1 :: 2[<]3 :: 4[<]5 :: 6[<]7 :: 1[<]3 :: 5[<]7 :: 0[<]2 :: 4[<]6 :: 1[<]5 :: 3[<]7 :: 0[<]4 :: 2[<]6 :: 1[<]8 :: 2[<]4 :: 1[<]2 :: 3[<]5 :: 4[<]8 :: 2[<]4 :: 6[<]8 :: 5[<]8 :: 3[<]6 :: 0[<]1 :: 3[<]4 :: 5[<]6 :: 7[<]8 :: nil).

(** [show_SN] attempts goals [sorting_network n C] for a concrete network
    [C] given as an explicit list.  After [split]:
    - first goal ([channels n C]): case analysis on the membership
      hypothesis for the explicit list, then each comparator's
      [comp_channels] condition is checked by [simpl] and [auto with arith];
    - second goal (sortedness): [SN_char] reduces it to the finitely many
      entries of [outputs C n], each handled by case analysis on membership
      followed by computation.
    NOTE(review): after [split] the second goal is the sortedness conjunct,
    not [sorting_network n C].  [apply SN_char] presumably succeeds because
    Coq's [apply] descends into the conjunction that the conclusion
    [sorting_network n C] unfolds to.  That would also re-open a [channels]
    goal, which the same case analysis discharges.  Confirm before
    restructuring this tactic. *)
Ltac show_SN :=
  split; [
    intros c Hc; repeat (elim Hc; clear Hc; intro Hc);
    rewrite <- Hc; simpl; auto with arith
  |
    apply SN_char; simpl; intros s Hs;
    repeat (elim Hs; clear Hs; intro Hs); rewrite <- Hs; simpl; auto with arith
  ].

(** Sanity checks: [SN2], [SN3] and [SN4] sort 2, 3 and 4 channels. *)
Theorem SN2_SN : sorting_network 2 SN2.

Theorem SN3_SN: sorting_network 3 SN3.

Theorem SN4_SN: sorting_network 4 SN4.


(** A sorting network computes [sort]. *)
Lemma SN_compute : forall n S, sorting_network n S ->
                   forall s:bin_seq n, full_apply S s = sort s.

(** A sorting network leaves already sorted inputs unchanged. *)
Lemma SN_idemp : forall n S, sorting_network n S ->
                 forall s:bin_seq n, sorted s -> full_apply S s = s.

(** Converse of [SN_char]: every entry of [outputs C n] of a sorting network
    is sorted.  The [channels n C] hypothesis is already the first conjunct
    of [sorting_network n C]; it is kept in the statement as is. *)
Lemma SN_sort : forall C n, channels n C -> sorting_network n C ->
                forall s, In s (outputs C n) -> sorted s.

(** Appending a standard network to a sorting network yields a sorting
    network. *)
Lemma SN_extend : forall n N S, sorting_network n N -> standard n S ->
                                sorting_network n (N++S).

(** Being a sorting network is decidable for networks on valid channels. *)
Lemma SN_dec : forall n C, channels n C ->
                           {sorting_network n C} + {~sorting_network n C}.

(** Given a list [l] of networks on [m] valid channels, either some member
    of [l] is a sorting network or none is. *)
Lemma exists_SN_dec : forall m l, (forall C, In C l -> channels m C) ->
                                  {exists C, In C l /\ sorting_network m C} +
                                  {forall C, In C l -> ~sorting_network m C}.