(* Meta-programming with delimited continuations *)
(* Writing an efficient specialized version of Gibonacci,
without using any fix-point combinators, etc.
*)
(* The following code relies on the delimcc library:
http://okmij.org/ftp/continuations/
Please make sure dlldelimcc.so is in your LD_LIBRARY_PATH
or in ld.conf-described paths.
*)
open Delimcc
open Printf;;
(* The original Gibonacci *)
let rec gib x y n =
  (* Unstaged Gibonacci: x at index 0, y at index 1, otherwise the sum of
     the two preceding terms. Exponential time, since both recursive calls
     recompute shared sub-results. Only 0 and 1 are base cases, so a
     negative n never reaches one (unbounded recursion). *)
  match n with
  | 0 -> x
  | 1 -> y
  | _ -> gib x y (n-1) + gib x y (n-2)
;;
(* gib 1 1 5;; gives 8 *)
(* Naively staged Gibonacci, to the statically known value of n *)
let rec gibgen x y n =
  (* Staged Gibonacci specialized to a present-stage n: x and y are code
     values, and the result is code that sums them following the
     recurrence. Shared sub-terms are generated repeatedly, because nothing
     is memoized or let-bound. *)
  match n with
  | 0 -> x
  | 1 -> y
  | _ -> .<.~(gibgen x y (n-1)) + .~(gibgen x y (n-2))>.
;;
(*
val gibgen : ('a, int) code -> ('a, int) code -> int -> ('a, int) code = <fun>
*)
(* Wrap the generated Gibonacci body in a future-stage two-argument
   function: x and y are future-stage variables, only n is known now. *)
let test_gibgen n = .<fun x y -> .~(gibgen .<x>. .<y>. n)>.;;
(* val test_gibgen : int -> ('a, int -> int -> int) code = <fun> *)
let test_gibgen5 = test_gibgen 5;;
(*
val test_gibgen5 : ('a, int -> int -> int) code =
.<fun x_1 ->
fun y_2 -> ((((y_2 + x_1) + y_2) + (y_2 + x_1)) + ((y_2 + x_1) + y_2))>.
*)
(* Clearly, the naive Gibonacci is inefficient.
The specialized code test_gibgen5 shows why:
the computation (y_2 + x_1) is repeated thrice within such a short fragment
*)
(* To improve Gibonacci, we have to add memoization *)
(* First we define the abstract data types of memoization table
with integer keys *)
(* For the sake of the closest correspondence with circle-shift.elf,
we use pairs to emulate 'a option data type. In the rest of the
code, 'a maybe is an abstract data type.
*)
module Maybe :
sig
  type 'a maybe
  val nothing : 'a maybe
  val just : 'a -> 'a maybe
  val ifnothing : 'a maybe -> bool
  val fromjust : 'a maybe -> 'a
end = struct
  (* A maybe is a pair: an "absent" flag, and a thunk producing the
     payload. Forcing the thunk of [nothing] raises Failure "nothing". *)
  type 'a maybe = bool * (unit -> 'a)
  let nothing = (true, fun () -> failwith "nothing")
  let just payload = (false, fun () -> payload)
  let ifnothing (absent, _) = absent
  let fromjust (_, force) = force ()
end;;
open Maybe;;
module Memo :
sig
  type 'a memo
  val empty : 'a memo
  val lookup : int -> 'a memo -> 'a maybe
  val ext : 'a memo -> int -> 'a -> 'a memo
end = struct
  (* A table is represented as a function from keys to maybe-values,
     mirroring circle-shift.elf. Callers only see the abstract type. *)
  type 'a memo = int -> 'a maybe
  let empty = fun _ -> nothing
  let lookup key tbl = tbl key
  (* Extension shadows any older binding of the same key. *)
  let ext tbl key value probe =
    if probe <> key then tbl probe else just value
end;;
open Memo;;
(* we can write the standard, textbook memoizer *)
(* It memoizes the result of the application of function f to the integer n.
*)
let new_memo () =
  (* Create a fresh mutable table and return a memoizer over it:
     [memo f n] returns the stored result for key n, or computes [f n],
     stores it, and returns it. *)
  let cache = ref empty in
  fun f n ->
    let found = lookup n !cache in
    if not (ifnothing found) then
      fromjust found                    (* hit: reuse the stored value *)
    else begin
      (* Miss: compute first. f may call this memoizer recursively and
         extend the table, so the table is read again only afterwards. *)
      let v = f n in
      cache := ext !cache n v;
      v
    end
;;
(* Now we can memoize Gibonacci and obtain an improved version *)
let gibo x y =
  (* Memoized (unstaged) Gibonacci. The table is created when [gibo x y]
     is evaluated and is shared by all recursive calls of the returned
     function. *)
  let memo = new_memo () in
  let rec go = function
    | 0 -> x
    | 1 -> y
    | k -> memo go (k-1) + memo go (k-2)
  in go
;;
let test_gibo5 = gibo 1 1 5;; (* 8 *)
let test_gibo30 = gibo 1 1 30;;
(* 1346269, without memoization it would've taken a while...*)
(* We may try to stage it, naively *)
let sgibo_naive x y =
  (* Staged Gibonacci memoized at the present stage only. The table stores
     code values, so generation work is shared, but each memoized code
     fragment is spliced in again wherever it is used. *)
  let memo = new_memo () in
  let rec go = function
    | 0 -> x
    | 1 -> y
    | k -> .<.~(memo go (k-1)) + .~(memo go (k-2))>.
  in go
;;
(* Generate the specialized function fun x y -> gib x y 5 with the
   naively memoized generator. *)
let test_sgibo_naive5 =
  .<fun x y -> .~(sgibo_naive .<x>. .<y>. 5)>.;;
(*
val test_sgibo_naive5 : ('a, int -> int -> int) code =
.<fun x_1 ->
fun y_2 -> ((((y_2 + x_1) + y_2) + (y_2 + x_1)) + ((y_2 + x_1) + y_2))>.
*)
(* Alas, the result shows duplicated computations. The result of
loop, in sgibo_naive, is a present-stage value but a future-stage
computation. We saved effort at the present stage but saved no
computation at the future stage. We need let-insertion to save
future-stage computations.
*)
(* But the let-insertion isn't that easy! The naive version *)
let sgibo1_naive x y =
  (* Naive attempt at let-insertion: each node let-binds its two
     sub-results locally. The memo table still holds the whole generated
     (let ...) expressions rather than variables, so they are duplicated
     wherever they are reused. *)
  let memo = new_memo () in
  let rec loop n =
    if n = 0 then x else
    if n = 1 then y else
    .<let t1 = .~(memo loop (n-1)) and t2 = .~(memo loop (n-2))
      in t1 + t2>.
  in loop
;;
(* Generate fun x y -> gib x y 5 with the naive let-inserting generator. *)
let test_sgibo1_naive5 =
  .<fun x y -> .~(sgibo1_naive .<x>. .<y>. 5)>.;;
(*
val test_sgibo1_naive5 : ('a, int -> int -> int) code =
.<fun x_1 ->
fun y_2 ->
let t1_3 =
let t1_5 =
let t1_7 = let t1_9 = y_2 and t2_10 = x_1 in (t1_9 + t2_10)
and t2_8 = y_2 in
(t1_7 + t2_8)
and t2_6 = let t1_9 = y_2 and t2_10 = x_1 in (t1_9 + t2_10) in
(t1_5 + t2_6)
and t2_4 =
let t1_7 = let t1_9 = y_2 and t2_10 = x_1 in (t1_9 + t2_10)
and t2_8 = y_2 in
(t1_7 + t2_8) in
(t1_3 + t2_4)>.
*)
(* The naive version obviously does no good: it creates even bigger
duplicated computations *)
(* We have to change the memo table implementation. Our memo table should
contain only those future-stage computations that are future-stage
values. So, we need to do let-insertion after we detected a miss.
But for that, we have to re-write everything in CPS. We have to write
the memo-table implementation in CPS:
*)
let new_memo_let_CPS () =
  (* CPS memoizer with let-insertion over a mutable table.
     [memo f n k] passes to k the memoized code for key n. On a miss it
     computes [f n], let-binds the resulting code to a fresh future-stage
     variable t, records .<t>. in the table, and continues with .<t>. inside
     the scope of that let. Only the variable, not the computation, is then
     reused. The table outlives the scope of t, which makes this approach
     unsafe (see the discussion below). *)
  let table = ref empty in
  fun f n k ->
    let r = lookup n !table in
    if ifnothing r
    then                                (* memo table miss *)
      f n                               (* compute the value *)
        (fun v -> .<let t = .~v in
                    .~(table := ext !table n .<t>.; k .<t>.)>.)
    else k (fromjust r)                 (* else return the memoized value *)
;;
(* but we also must re-write sgibo in CPS! *)
let sgibo_CPS x y =
  (* Staged Gibonacci in continuation-passing style, memoized with
     let-insertion: [go i k] passes the code for the i-th term to k. *)
  let memo = new_memo_let_CPS () in
  let rec go i k =
    match i with
    | 0 -> k x
    | 1 -> k y
    | _ ->
      memo go (i-1) (fun a ->
      memo go (i-2) (fun b ->
      k .<.~a + .~b>.))
  in go
;;
(* Generate fun x y -> gib x y 5 with the CPS let-inserting generator;
   the identity continuation returns the generated code unchanged. *)
let test_sgibo_CPS5 =
  .<fun x y -> .~(sgibo_CPS .<x>. .<y>. 5 (fun x ->x))>.;;
(*
val test_sgibo_CPS5 : ('a, int -> int -> int) code =
.<fun x_1 ->
fun y_2 ->
let t_3 = y_2 in
let t_4 = x_1 in
let t_5 = (t_3 + t_4) in
let t_6 = (t_5 + t_3) in let t_7 = (t_6 + t_5) in (t_7 + t_6)>.
*)
(* Now we get the desired result: no duplicate computations.
At the cost of changing all of our code, even sgibo, in CPS.
Memoization is no longer easy -- it becomes very intrusive.
*)
(* Not only is this approach inconvenient, it is also unsafe.
The mutation that maintains the table in new_memo_let_CPS
makes it unsafe. We store in the `global' memo table code
values like .<t>. -- with variables bound in a scope
that is narrower than the dynamic scope of the table.
*)
(* Let's make a simple `pessimization' of sgibo_CPS. Suppose the
programmer didn't want to rewrite gib in CPS, and continued to use
memoization in `direct style'.
*)
let sgibo1_bad x y =
  (* Deliberately broken: the direct-style generator is adapted to the CPS
     memoizer, and each lookup is closed with the identity continuation.
     The let inserted by the memoizer then wraps only that one operand, but
     the shared table lets other operands reference its variable outside
     the let's scope. The generated code is ill-scoped. *)
  let memo = new_memo_let_CPS () in
  let rec go i =
    match i with
    | 0 -> x
    | 1 -> y
    | _ ->
      .<.~(memo (fun j k -> k (go j)) (i-1) (fun c -> c)) +
        .~(memo (fun j k -> k (go j)) (i-2) (fun c -> c))>.
  in go
;;
(* Generate fun x y -> gib x y 5 with the broken generator; the result is
   expected to reference let-bound variables outside their scope. *)
let test_sgibo1_bad =
  .<fun x y -> .~(sgibo1_bad .<x>. .<y>. 5)>.;;
(*
val test_sgibo1_bad : ('a, int -> int -> int) code =
.<fun x_1 ->
fun y_2 ->
(let t_7 = (t_6 + t_5) in t_7 +
let t_6 =
(let t_5 = (t_3 + let t_4 = x_1 in t_4) in t_5 + let t_3 = y_2 in t_3) in
t_6)>.
*)
(* Although the result appears efficient -- only four additions --
it is incorrect! Please notice how variable t_6 is referenced before
it is bound. Attempting to run this code gives
.! test_sgibo1_bad;;
Unbound value t_6
Exception: Trx.TypeCheckingError.
*)
(* To rely on MetaOCaml's type soundness, we must not use any side effects
in our code generator. We could write our memoizing gib without state,
by including state-passing in our continuation-passing, as follows.
*)
(* State-passing CPS memoizer with let-insertion; no mutation.
   The memo table is an explicit argument, and the continuation k takes
   both the code value and the (possibly extended) table. On a miss,
   [f n] is computed, its code is let-bound to t, and k continues inside
   that let with .<t>. and a table extended with n |-> .<t>.. The extended
   table is only passed forward within the scope of t. *)
let new_memo_let_CPS_only f n k table =
  let r = lookup n table in
  if ifnothing r
  then
    f n
      (fun v table -> .<let t = .~v in .~(k .<t>. (ext table n .<t>.))>.)
      table
  else
    k (fromjust r) table
;;
let sgibo_CPS_only x y =
  (* Staged Gibonacci in CPS with explicit state passing: [go i k] passes
     the code for the i-th term, and the current memo table, to k. *)
  let rec go i k =
    match i with
    | 0 -> k x
    | 1 -> k y
    | _ ->
      new_memo_let_CPS_only go (i-1) (fun a ->
      new_memo_let_CPS_only go (i-2) (fun b ->
      k .<.~a + .~b>.))
  in go
;;
(* Generate fun x y -> gib x y 5 with the state-passing generator. The
   final continuation drops the table and returns the code; the initial
   table is empty. *)
let test_sgibo_CPS_only5 =
  .<fun x y -> .~(sgibo_CPS_only .<x>. .<y>. 5 (fun r table -> r) empty)>.;;
(* Our solution, with delimited continuations *)
(* We write the let-insertion memo table thusly: *)
(* Direct-style let-inserting memoizer built on delimited control.
   The memo table is not a mutable cell. It is threaded through the
   continuation captured up to prompt p, whose answer is a function of the
   table (top_fn supplies the initial, empty table). [new_memo_let p f n]
   returns the code for [f n], generating it only on a miss. *)
let new_memo_let p =
  fun f n -> (* do the lookup first *)
    (* Capture the rest of the computation up to p, look n up in the current
       table, and resume with that same table. *)
    let r = shift p (fun k table -> k (lookup n table) table) in
    if ifnothing r
    then let v = f n in
      (* Miss: emit let t = v at the prompt, so the computation is generated
         once. Extend the table with the variable .<t>., and resume the
         captured computation with .<t>. inside the scope of t. *)
      shift p (fun k table ->
        .<let t = .~v in
          .~(let table' = ext table n .<t>. in
             k .<t>. table')>.)
    else fromjust r (* value found *)
;;
let top_fn thunk =
  (* Push a fresh prompt and run [thunk p] under it. The answer type at the
     prompt is a function of the memo table; when the thunk returns
     normally, that function ignores the table. The resulting function is
     applied to the empty table outside the prompt. *)
  let p = new_prompt () in
  let with_table = push_prompt p (fun () -> let v = thunk p in fun _ -> v) in
  with_table empty
;;
(* We now write the staged optimal Gibonacci in direct style, just like
sgibo_naive. The only difference from sgibo_naive is the use of
new_memo_let and the corresponding p argument (which is an artefact
of OCaml's delimcc library; the extra argument is not needed in
our calculus or its Twelf implementation).
*)
let sgibo p x y =
  (* Staged Gibonacci in direct style, memoized with delimited-control
     let-insertion under prompt p. *)
  let memo = new_memo_let p in
  let rec go = function
    | 0 -> x
    | 1 -> y
    | i -> .<.~(memo go (i-1)) + .~(memo go (i-2))>.
  in go
;;
(* Generate fun x y -> gib x y n using the delimited-control generator;
   top_fn installs the prompt and the empty memo table. *)
let test_sgibon n =
  .<fun x y -> .~(top_fn(fun p -> sgibo p .<x>. .<y>. n))>.;;
let test_sgibo5 = test_sgibon 5;;
(*
val test_sgibo5 : ('a, int -> int -> int) code =
.<fun x_1 ->
fun y_2 ->
let t_3 = y_2 in
let t_4 = x_1 in
let t_5 = (t_3 + t_4) in
let t_6 = (t_5 + t_3) in let t_7 = (t_6 + t_5) in (t_7 + t_6)>.
*)
let test_sgibo5r =
  let gib5 = .! test_sgibo5 in (* run the generated code *)
  gib5 1 1;; (* 8 *)
let test_sgibo8 = test_sgibon 8;;
(*
val test_sgibo8 : ('a, int -> int -> int) code =
.<fun x_1 ->
fun y_2 ->
let t_3 = x_1 in
let t_4 = y_2 in
let t_5 = (t_4 + t_3) in
let t_6 = (t_5 + t_4) in
let t_7 = (t_6 + t_5) in
let t_8 = (t_7 + t_6) in
let t_9 = (t_8 + t_7) in let t_10 = (t_9 + t_8) in (t_10 + t_9)>.
*)
let test_sgibo8r =
  let gib8 = .! test_sgibo8 in (* run the generated code *)
  gib8 1 1;; (* 34 *)
(* Our calculus is safe, its Twelf implementation is safe, but MetaOCaml
does not at present enforce our restriction. We have to verify it
manually. See the end of the file fib.ml for an example.
*)