(* Generators in OCaml
We translate the in-order tree traversal example from the 2004 article
`Generators in Icon, Python, and Scheme'
http://okmij.org/ftp/Scheme/enumerators-callcc.html#Generators
into OCaml, using delimited continuations rather than call/cc + mutation.
The code is shorter, and it even type-checks.
*)
(* A sample program Python programmers seem to be proud of, an in-order
traversal of a tree:
>>> # A recursive generator that generates Tree leaves in in-order.
>>> def inorder(t):
...     if t:
...         for x in inorder(t.left):
...             yield x
...         yield t.label
...         for x in inorder(t.right):
...             yield x
Given below is the complete implementation in OCaml.
*)
open Delimcc;;
(* A few preliminaries: define the tree and build a sample tree *)
(* Labels carried by the internal nodes of a tree. *)
type label = int

(* A binary tree: [Leaf] is empty; [Node (l, left, right)] carries label [l]. *)
type tree = Leaf | Node of label * tree * tree;;

(* [make_full_tree depth] builds a complete binary tree with [depth] levels
   of [Node]s.  Labels follow heap numbering: the root is 1 and the children
   of the node labelled [l] are labelled [2*l] (left) and [2*l+1] (right).
   [make_full_tree 0] is [Leaf].
   @raise Invalid_argument if [depth] is negative.  Without this check a
   negative depth would never reach the [0] base case below and would
   recurse until the stack overflows. *)
let make_full_tree depth =
  if depth < 0 then invalid_arg "make_full_tree: negative depth";
  let rec loop label = function
    | 0 -> Leaf
    | n -> Node (label, loop (2*label) (pred n), loop (2*label+1) (pred n))
  in loop 1 depth
;;
(* val make_full_tree : int -> tree = <fun> *)
(* The sample tree: three levels of nodes, labelled 1 to 7 in heap order. *)
let tree1 : tree = make_full_tree 3;;
(*
val tree1 : tree =
  Node (1, Node (2, Node (4, Leaf, Leaf), Node (5, Leaf, Leaf)),
   Node (3, Node (6, Leaf, Leaf), Node (7, Leaf, Leaf)))
*)
(* In Python, `yield' is a keyword. In OCaml, it is an ordinary function.
Furthermore, it is a user-defined function, written in one line of code.
To get generators, there is no need to extend the language.
*)
(* [yield p v] suspends the computation delimited by prompt [p] and offers
   [v] to the consumer.  [shift0] captures the rest of that computation as
   [resume].  The handler becomes the answer of the delimited computation:
   a function that takes the consumer [body], applies it to [v], and then
   continues the suspended computation with [resume ()].  That continuation
   again produces a function expecting the consumer, so the same [body] is
   passed along to the next suspension point. *)
let yield p v =
  shift0 p (fun resume body ->
      body v;
      resume () body);;
(* val yield : (('a -> 'b) -> 'c) Delimcc.prompt -> 'a -> unit = <fun> *)
(* The in_order function itself: compare with the Python version *)
(* [in_order p tree] walks [tree] in order (left subtree, node label, right
   subtree), handing each label to [yield p] as it is reached. *)
let rec in_order p tree =
  match tree with
  | Leaf -> ()
  | Node (label, left, right) ->
    in_order p left;
    yield p label;
    in_order p right
;;
(* val in_order : ((label -> unit) -> unit) prompt -> tree -> unit = <fun> *)
(* The enumerator: the for-loop essentially *)
(* [enumerate coll body] applies [body] to every label of [coll], in order. *)
let enumerate coll body =
  let prompt = new_prompt () in
  (* Run the traversal under a fresh prompt.  Each [yield] makes the
     delimited computation answer with a function awaiting the consumer.
     When the traversal finishes, the answer is a function that ignores
     the consumer. *)
  let run =
    push_prompt prompt (fun () ->
        in_order prompt coll;
        fun _ -> ())
  in
  run body
;;
(* val enumerate : tree -> (label -> unit) -> unit = <fun> *)
(* The result of traversing the sample tree in order *)
(* Print the labels of [tree1] in order, each followed by a space. *)
let () = enumerate tree1 (Printf.printf "%d ");;
(* 4 2 5 1 6 3 7 *)