(* Problems illustrating the need for effects in code generation,
   especially the effects that cross future-stage binders.
   See talk-problems.pdf for explanations.
   This file also shows the unsatisfactory solutions for these problems
   in MetaOCaml.

   The file is written in (BER) MetaOCaml: .<e>. is a code bracket,
   .~e an escape and .!e runs the generated code. The last section
   also uses the Delimcc library. *)

(* ---------------------------------------------------------------------- *)
(* The warm-up example: Faulty Power *)

(* [power n x] is x to the power n; it recurses forever (stack overflow)
   for negative n. *)
let rec power n x =
  match n with
  | 0 -> 1
  | n -> x * power (n-1) x
;;
(* val power : int -> int -> int = <fun> *)

let () = assert (power 5 2 = 32);;

(* The staged power: given the code of x, generate the code of x^n *)
let rec spower n x =
  match n with
  | 0 -> .<1>.
  | n -> .<.~x * .~(spower (n-1) x)>.
;;
(* val spower : int -> ('a, int) code -> ('a, int) code = <fun> *)

(* But we also need a specialization. Here we see the future-stage binder *)
let spowern n = .<fun x -> .~(spower n .<x>.)>.;;
(* val spowern : int -> ('a, int -> int) code = <fun> *)

spowern 5;;
(* - : ('a, int -> int) code =
   .<fun x_1 -> (x_1 * (x_1 * (x_1 * (x_1 * (x_1 * 1)))))>. *)

(* But spowern is partial!
   spowern (-1);;
   Stack overflow during evaluation (looping recursion?).
*)
(* Not good error reporting! *)
(* We wish to throw exceptions and be able to recover from them: *)

exception BadArg

let rec spowerE n x =
  match n with
  | 0 -> .<1>.
  | n when n > 0 -> .<.~x * .~(spowerE (n-1) x)>.
  | _ -> raise BadArg
;;

let spowernE n = .<fun x -> .~(spowerE n .<x>.)>.;;

(* Interactive driver: read n from stdin and return the code of the
   n-th power function. A line that is not an integer (read_int raises
   Failure) or a negative n (spowerE raises BadArg under the binder)
   prints "Bad n!" and asks again. *)
let rec gpower () =
  print_endline "Enter n: ";
  let retry () = (print_endline "Bad n!"; gpower ()) in
  match (try Some (read_int ()) with Failure _ -> None) with
  | None -> retry ()
  | Some n -> (try spowernE n with BadArg -> retry ())
;;
(* gpower ();; *)

(* The key is spowernE: the exception is thrown under the binder
   but gets caught outside.
   If we write this in Haskell and use the Error monad, we see the problem
   right away: the escape requires a code value but spowerE n .<x>. is
   a computation.
*)

(* ---------------------------------------------------------------------- *)
(* Guard insertion: moving open code across future-stage binders *)

(* code performing a presumably complex computation *)
let complex_code = .<(print_endline "complex!"; 54)>.;;

(* Sample result *)
.<fun y -> .~complex_code + 10 / y>.;;

(* Elaboration: guard the division against a zero divisor.
   The result is not satisfactory: the assert is buried inside the
   arithmetic expression. We need to move the assert right after the
   binder... *)
let guarded_div x y = .<(assert (.~y <> 0); .~x / .~y)>. in
.<fun y -> .~complex_code + .~(guarded_div .<10>. .<y>.)>.;;
(* - : ('a, int -> int) code =
   .<fun y_15 -> (begin (print_endline "complex!"); 54 end +
                  begin assert (y_15 <> 0); (10 / y_15) end)>. *)

(* The stack of guard lists: one list per currently open new_ctx,
   the innermost context first. *)
let levels = ((ref []) : ('c,unit) code list list ref);;

(* List zipper; the first list is in the reverse order *)
type 'a listzip = 'a list * 'a * 'a list;;

(* List zipper; unzip to the n-th (0-based) element.
   Raises Invalid_argument if n is negative or the list is too short. *)
let unzip n lst =
  let rec loop acc = function
    | (0, h::t) -> (acc, h, t)
    | (n, h::t) -> loop (h::acc) (n-1, t)
    | (_, [])   -> invalid_arg "unzip: index out of range"
  in
  if n < 0 then invalid_arg "unzip: negative index"
  else loop [] (n, lst)
;;

(* Plug the focused element back in: the inverse of unzip *)
let zip (lb,x,la) = List.rev_append lb (x::la);;

(* [guarded_div lev x y] generates the code x / y and adds the guard
   assert (y <> 0) to the lev-th open context (0 is the innermost
   new_ctx). Nothing checks that y is in scope at that context:
   see test3bad below. *)
let guarded_div lev x y =
  let (lb,ctx,la) = unzip lev !levels in
  let ctx = .<assert (.~y <> 0)>. :: ctx in
  levels := zip (lb,ctx,la);
  .<.~x / .~y>.
;;

(* [new_ctx th] opens a new context, runs the generator th and prepends
   the guards accumulated in this context to the generated code.
   The context is closed even if th raises, so that the levels stack
   stays consistent for later guarded_div calls. *)
let new_ctx th =
  levels := [] :: !levels;
  let pop () =
    match !levels with
    | h::t -> levels := t; h
    | []   -> failwith "new_ctx: the context stack is empty"
  in
  let r = try th () with e -> ignore (pop ()); raise e in
  let guards = pop () in
  List.fold_right (fun e z -> .<(.~e; .~z)>.) guards r
;;

let test2 =
  .<fun y -> .~(new_ctx (fun () ->
     .<.~complex_code + .~(guarded_div 0 .<10>. .<y>.)>.))>.;;
(* val test2 : ('_a, int -> int) code =
   .<fun y_4 -> assert (y_4 <> 0);
     (begin (print_endline "complex!"); 54 end + (10 / y_4))>. *)

let test3 =
  .<fun y -> .~(new_ctx (fun () ->
     .<fun x -> .~(new_ctx (fun () ->
        .<.~complex_code + .~(guarded_div 1 .<x>. .<y>.)>.))>.))>.;;
(* val test3 : ('_a, int -> int -> int) code =
   .<fun y_5 -> assert (y_5 <> 0);
     fun x_6 -> (begin (print_endline "complex!"); 54 end + (x_6 / y_5))>. *)

let test4 =
  .<fun y -> .~(new_ctx (fun () ->
     .<(fun x -> .~(new_ctx (fun () ->
          .<.~complex_code + .~(guarded_div 1 .<x>. .<y>.)>.)))
       (.~complex_code + .~(guarded_div 0 .<5>. .<y-1>.))>.))>.;;
(* val test4 : ('_a, int -> int) code =
   .<fun y_7 -> assert (y_7 <> 0); assert ((y_7 - 1) <> 0);
     ((fun x_8 -> (begin (print_endline "complex!"); 54 end + (x_8 / y_7)))
      (begin (print_endline "complex!"); 54 end + (5 / (y_7 - 1))))>. *)

(* Danger! If we make a mistake, we can indeed end up with ill-scoped
   code. *)
let test3bad =
  .<fun y -> .~(new_ctx (fun () ->
     .<fun x -> .~(new_ctx (fun () ->
        .<.~complex_code + .~(guarded_div 1 .<y>. .<x>.)>.))>.))>.;;
(* val test3bad : ('_a, int -> int -> int) code =
   .<fun y_34 -> assert (x_35 <> 0);
     fun x_35 -> (begin (print_endline "complex!"); 54 end + (y_34 / x_35))>.
   The assert mentions x_35 before x_35 is bound. *)

(* ---------------------------------------------------------------------- *)
(* Loop tiling: let-insertion across binders
   Our example is loop tiling in matrix-vector multiplication *)

(* A sample matrix *)
let a = Array.make_matrix 5 10 0;;
let dimx a = Array.length a;;
let dimy a = Array.length a.(0);;
for i=0 to dimx a - 1 do
  for j=0 to dimy a - 1 do
    a.(i).(j) <- i + j
  done
done;;

(* A sample vector *)
let v = Array.init 10 (fun i -> i + 1);;
(* val v : int array = [|1; 2; 3; 4; 5; 6; 7; 8; 9; 10|] *)
let v' = Array.make 5 0;;

(* The expected product a * v, and a check that v' currently holds it *)
let expected_v' = [|330; 385; 440; 495; 550|];;
let check_v' () = assert (v' = expected_v');;

(* Ordinary matrix-vector product: v' <- a * v, where a.(i).(j) is
   indexed with i < n and j < m *)
(* We assume v is a long vector *)
let mvmul0 n m a v v' =
  Array.fill v' 0 n 0;
  for j=0 to m-1 do
    for i=0 to n-1 do
      v'.(i) <- v'.(i) + a.(i).(j) * v.(j)
    done
  done;;

mvmul0 (dimx a) (dimy a) a v v';;
check_v' ();;

(* Tiled computation, with the tile size b *)

(* [sloop lb ub step body] calls body lb, body (lb+step), ... for every
   value that does not exceed ub. Raises Invalid_argument when the range
   is non-empty but step is not positive (the loop would never end). *)
let rec sloop lb ub step body =
  if lb > ub then ()
  else if step <= 0 then invalid_arg "sloop: step must be positive"
  else begin body lb; sloop (lb + step) ub step body end
;;

(* The advantage of tiling: during multiplication, the array v is
   traversed repeatedly. By assumption, the array is long and so won't
   fit in cache. A tiled program deals with the array one chunk (of
   size b) at a time. The jj-th chunk is loaded into cache and used
   several times. When we are finished with the chunk, it won't be
   needed again and can safely be replaced in cache with another chunk.
*)
let mvmul1 b n m a v v' =
  Array.fill v' 0 n 0;
  sloop 0 (m-1) b (fun jj ->
    sloop 0 (n-1) b (fun ii ->
      for j=jj to min (jj+b-1) (m-1) do
        for i=ii to min (ii+b-1) (n-1) do
          v'.(i) <- v'.(i) + a.(i).(j) * v.(j)
        done
      done))
;;

(* Testing with various tile sizes b *)
List.iter (fun b -> mvmul1 b (dimx a) (dimy a) a v v'; check_v' ())
  [1; 2; 3; 4; 5];;

(* Straightforward staging of mvmul0, assuming statically known n and m *)
let gen_regular_loop lb ub body =
  .<for i=lb to ub do .~(body .<i>.) done>.;;

let gmvmul0 n m =
  .<fun a v v' ->
     Array.fill v' 0 n 0;
     .~(gen_regular_loop 0 (m-1) (fun j ->
          gen_regular_loop 0 (n-1) (fun i ->
            .<v'.(.~i) <- v'.(.~i) + a.(.~i).(.~j) * v.(.~j)>.)))
  >.
;;

gmvmul0 5 10;;
(* - : ('a, int array array -> int array -> int array -> unit) code =
   .<fun a_1 -> fun v_2 -> fun v'_3 ->
       (Array.fill v'_3 0 5 0);
       for i_4 = 0 to 9 do
         for i_5 = 0 to 4 do
           v'_3.(i_5) <- (v'_3.(i_5) + ((a_1.(i_5)).(i_4) * v_2.(i_4)))
         done
       done>. *)

(* Abstracting the loop *)
let gmvmul1 loop1 loop2 n m =
  .<fun a v v' ->
     Array.fill v' 0 n 0;
     .~(loop1 0 (m-1) (fun j ->
          loop2 0 (n-1) (fun i ->
            .<v'.(.~i) <- v'.(.~i) + a.(.~i).(.~j) * v.(.~j)>.)))
  >.
;;

let gcode1 = gmvmul1 gen_regular_loop gen_regular_loop 5 10;;
(* the same as above *)
Array.fill v' 0 5 0;;
(.!gcode1) a v v';;
check_v' ();;

(* Split the loop in two: strip-mining *)
let gen_nested_loop b lb ub body =
  .<sloop lb ub b (fun ii ->
      for i=ii to min (ii+b-1) ub do .~(body .<i>.) done)>.;;

let gcode2 = gmvmul1 (gen_nested_loop 2) (gen_nested_loop 2) 5 10;;
(* val gcode2 : ('a, int array array -> int array -> int array -> unit) code =
   .<fun a_21 -> fun v_22 -> fun v'_23 ->
     (Array.fill v'_23 0 5 0);
     (((* cross-stage persistent value (as id: sloop) *)) 0 9 2
        (fun ii_24 ->
          for i_25 = ii_24 to (min ((ii_24 + 2) - 1) 9) do
            (((* cross-stage persistent value (as id: sloop) *)) 0 4 2
               (fun ii_26 ->
                 for i_27 = ii_26 to (min ((ii_26 + 2) - 1) 4) do
                   v'_23.(i_27) <- (v'_23.(i_27) +
                                     ((a_21.(i_27)).(i_25) * v_22.(i_25)))
                 done))
          done))>. *)
Array.fill v' 0 5 0;;
(.!gcode2) a v v';;
check_v' ();;

(*
#directory "/home/oleg/Cache/ncaml/caml-shift/";;
#load "delimcc.cma";;
*)
open Delimcc;;

(* Tiling: gen_tile_loop uses shift to place the loop over tiles
   (sloop) where the prompt p was pushed, and returns to its own call
   site a loop generator for the elements of a single tile. *)
let gen_tile_loop p b lb ub =
  shift p (fun k ->
    .<sloop lb ub b (fun ii ->
        .~(k (fun body ->
              .<for i=ii to min (ii+b-1) ub do .~(body .<i>.) done>.)))>.)
;;
(* val gen_tile_loop :
     ('a, unit) code Delimcc.prompt -> int -> int -> int ->
     (('a, int) code -> ('a, 'b) code) -> ('a, unit) code = <fun> *)

(* Push the prompt p around the loop generator: this marks the place
   where gen_tile_loop p inserts its tile loops. *)
let insert_here p loop =
  fun lb ub body -> push_prompt p (fun () -> loop lb ub body);;
(* val insert_here :
     'a Delimcc.prompt -> ('b -> 'c -> 'd -> 'a) -> 'b -> 'c -> 'd -> 'a =
     <fun> *)

let tiled_code =
  let p = new_prompt () in
  gmvmul1 (insert_here p (gen_tile_loop p 2)) (gen_tile_loop p 2) 5 10
;;
(* val tiled_code :
     ('a, int array array -> int array -> int array -> unit) code =
   .<fun a_28 -> fun v_29 -> fun v'_30 ->
     (Array.fill v'_30 0 5 0);
     (((* cross-stage persistent value (as id: sloop) *)) 0 9 2
        (fun ii_31 ->
          (((* cross-stage persistent value (as id: sloop) *)) 0 4 2
             (fun ii_33 ->
               for i_32 = ii_31 to (min ((ii_31 + 2) - 1) 9) do
                 for i_34 = ii_33 to (min ((ii_33 + 2) - 1) 4) do
                   v'_30.(i_34) <- (v'_30.(i_34) +
                                     ((a_28.(i_34)).(i_32) * v_29.(i_32)))
                 done
               done))))>. *)
Array.fill v' 0 5 0;;
(.!tiled_code) a v v';;
check_v' ();;