(* The library to support a probabilistic embedded domain specific language *)
(* Probabilistic inference procedures *)
(* These procedures implement different exploration strategies over a
lazy search tree, transforming a search tree into a flatter form *)
open Ptypes
open Printf
(* We also use polymorphic maps, pMap.* in the current directory *)
;;
(* ------------------------------------------------------------------------ *)
(* Exact inference strategies: deterministic search procedures *)
(* Explore and flatten the tree: perform exact inference to the given depth *)
(* If maxdepth is None, explore as far as possible *)
let explore (maxdepth : int option) (choices : 'a pV) : 'a pV =
  (* Whether the children of a suspension met at [depth] may be expanded
     further; with no limit, everything is expanded. *)
  let below_limit depth =
    match maxdepth with
    | None -> true
    | Some limit -> depth < limit in
  (* [scale] is the probability of the path to the current list of
     alternatives; [vals] accumulates answer masses, [pending] the
     suspensions left unexpanded (already weighted by their path). *)
  let rec walk scale depth expand ((vals, pending) as acc) = function
    | [] -> acc
    | (w, V v) :: tl ->
        let vals = PMap.insert_with (+.) v (w *. scale) vals in
        walk scale depth expand (vals, pending) tl
    | (w, C th) :: tl when expand ->
        (* fully walk this subtree before moving on to the siblings *)
        let acc =
          walk (w *. scale) (succ depth) (below_limit depth) acc (th ()) in
        walk scale depth expand acc tl
    | (w, node) :: tl ->
        walk scale depth expand (vals, (w *. scale, node) :: pending) tl in
  let (vals, pending) = walk 1.0 0 true (PMap.empty, []) choices in
  PMap.foldi (fun v w tail -> (w, V v) :: tail) vals pending;;
(* Threshold just under 1.0: a total probability mass below it is treated
   as incomplete, which tolerates floating-point rounding when comparing
   with 1.0. *)
let nearly_one = 1.0 -. 1.0e-7;;
(* Explore but do not flatten the tree:
perform exact inference to the given depth
We still pick out all the produced answers and note the failures. *)
let shallow_explore maxdepth (choices : 'a pV) : 'a pV =
  (* Put the still-suspended children [kids] of a branch of weight [w]
     back onto [pending] as a single suspension. A branch with no mass
     left is dropped. Otherwise, if the children no longer sum to (nearly)
     1, they are renormalized and their total mass is moved into the
     branch weight. *)
  let reattach w kids pending =
    let mass = List.fold_left (fun s (q, _) -> s +. q) 0.0 kids in
    if mass = 0.0 then pending
    else if mass < nearly_one then begin
      let kids = List.map (fun (q, x) -> (q /. mass, x)) kids in
      (w *. mass, C (fun () -> kids)) :: pending
    end
    else (w, C (fun () -> kids)) :: pending in
  (* [scale] is the path probability, used only to weight answers;
     suspensions kept in [pending] carry weights relative to their parent. *)
  let rec walk scale depth vals pending = function
    | [] -> (vals, pending)
    | (w, V v) :: tl ->
        walk scale depth (PMap.insert_with (+.) v (w *. scale) vals) pending tl
    | node :: tl when depth >= maxdepth ->
        walk scale depth vals (node :: pending) tl
    | (w, C th) :: tl ->
        let (vals, kids) = walk (scale *. w) (succ depth) vals [] (th ()) in
        walk scale depth vals (reattach w kids pending) tl in
  let (vals, pending) = walk 1.0 0 PMap.empty [] choices in
  PMap.foldi (fun v w tail -> (w, V v) :: tail) vals pending;;
(* Explore the tree till we find the first success -- the first leaf
(V v) -- and return the resulting tree. If the tree turns out to
have no leaves, return the empty tree.
*)
(* Breadth-first search for the first leaf. The frontier is kept as a
   functional FIFO queue: the list [front] followed by [back] reversed.
   Each expansion therefore costs time proportional to the number of new
   children, not to the size of the whole frontier (the previous
   [rest @ children] copied the frontier at every step). When a value
   reaches the head of the frontier, the whole frontier is returned in
   order; children carry the product of the probabilities along their path.
   An exhausted frontier yields the empty tree. *)
let first_success (choices : 'a pV) : 'a pV =
  let rec bfs front back =
    match front, back with
    | [], [] -> []
    | [], _ -> bfs (List.rev back) []
    | (_, Ptypes.V _) :: _, [] -> front
    | (_, Ptypes.V _) :: _, _ -> front @ List.rev back
    | (pt, Ptypes.C t) :: rest, _ ->
        let children = List.map (fun (p, v) -> (pt *. p, v)) (t ()) in
        bfs rest (List.rev_append children back)
  in
  bfs choices []
(* ------------------------------------------------------------------------ *)
(* Semi-Exact inference strategies: deterministic search procedures *)
(* over a subtree *)
(* A bounds estimator: obtain the bounds on the probability
of evidence.
The object probabilistic program must return (), or fail.
Currently I don't know how to assign bounds when several values
may be returned.
This restriction seems consistent with Problog, which too determines
bounds on the probability of a query.
We traverse the tree breadth-first. If the number of unexplored branches
reaches the threshold, we discard the branch with the lowest
probability mass. A discarded branch with the probability mass p contributes
0 to the current lower bound and p to the current upper bound.
A successful branch with mass p contributes p to both bounds.
A failed branch contributes 0 to both bounds.
*)
let bounded_explore maxsize (choices : unit pV) : prob * prob =
  (* Remove one suspended thunk from the bucket returned by [extract]
     (which removes the heaviest or lightest bucket of the agenda). The
     other thunks sharing that weight are put back. *)
  let pop extract agenda =
    let ((w, th :: others), agenda) = extract agenda in
    let agenda =
      match others with
      | [] -> agenda
      | _ -> PMap.add w others agenda in
    (w, th, agenda) in
  (* Scan a list of alternatives reached with path probability [scale].
     Values add their mass to both bounds. Suspensions are queued on the
     agenda when [expand] holds; otherwise their mass is added to the
     upper bound only. *)
  let rec scan expand scale lo hi agenda size = function
    | [] -> step lo hi agenda size
    | (w, V _) :: tl ->
        let mass = scale *. w in
        scan expand scale (lo +. mass) (hi +. mass) agenda size tl
    | (w, C th) :: tl when expand ->
        let agenda = PMap.insert_with (@) (scale *. w) [th] agenda in
        scan expand scale lo hi agenda (size + 1) tl
    | (w, C _) :: tl ->
        scan expand scale lo (hi +. scale *. w) agenda size tl
  (* Pick the next queued branch. Below the size limit, the heaviest
     branch is explored. At or above it, the lightest branch is forced
     once and discarded. *)
  and step lo hi agenda size =
    if size = 0 then (lo, hi)
    else begin
      let expand = size < maxsize in
      let extract =
        if expand then PMap.delete_find_max else PMap.delete_find_min in
      let (w, th, agenda) = pop extract agenda in
      scan expand w lo hi agenda (size - 1) (th ())
    end in
  scan true 1.0 0. 0. PMap.empty 0 choices
;;
(* The convergence, however, is not as good as sampling on sorted.ml *)
(* ------------------------------------------------------------------------ *)
(* Approximate inference strategies: *)
(* Trace a few paths from the root to a leaf of the search tree *)
(* The following procedures are non-deterministic; they use a given selector*)
(* procedure, of the type 'selector', to choose among the alternatives. *)
(* For top-level inference, the selector uses system random generator. *)
(* Naive, rejection sampling: the baseline *)
let rejection_sample_dist (selector: 'a vc selector) nsamples ch : 'a pV =
  (* Follow one path from the root of [ch] down to a leaf or a failure.
     [weight] is the importance weight accumulated so far. A leaf credits
     its value in [acc]; a failure leaves [acc] unchanged. *)
  let rec trace weight acc = function
    | [] -> acc
    | [(w, V v)] -> PMap.insert_with (+.) v (w *. weight) acc
    | [(w, C th)] -> trace (w *. weight) acc (th ())
    | alternatives ->
        (* let the selector pick one alternative, scaling by its mass *)
        let (mass, picked) = selector alternatives in
        trace (weight *. mass) acc [(1.0, picked)] in
  (* Draw [remaining] more samples, then average over all of them. *)
  let rec repeat acc remaining =
    if remaining = 0 then begin
      printf "rejection_sample: done %d worlds\n" nsamples;
      let total = float_of_int nsamples in
      PMap.foldi (fun v w tail -> (w /. total, V v) :: tail) acc []
    end
    else repeat (trace 1.0 acc ch) (pred remaining) in
  repeat PMap.empty nsamples
;;
(* Sample a distribution with a look-ahead exploration *)
(* A single sample can give us more than one data point: if one of
the choices is a definite value, we note it right away, with
its weight. The rest of the choices will be re-scaled automatically.
*)
(* Given a sampler, a function 'seed->'seed, run it a certain number
of times and return the resulting seed and the number of runs
*)
(* A sampling driver packaged as a record so that its function is
   polymorphic in the accumulator type 'seed. [r.sample_runner seed step]
   applies [step] to [seed] some number of times (that number is chosen by
   the runner). It returns the final accumulator together with the number
   of times [step] was run. *)
type sample_runner =
{sample_runner : 'seed. 'seed -> ('seed -> 'seed) -> 'seed * int};;
(* Importance sampling over the search tree [ch], with one step of
   look-ahead. Values found during look-ahead are credited exactly,
   weighted by their path probability. Only the remaining suspended
   branches are sampled, using [selector]. [sample_runner] decides how
   many samples to draw. The result is the estimated distribution: each
   value's accumulated mass divided by the number of samples. If the
   initial pre-exploration already resolves the whole tree, the exact
   distribution is returned without sampling. *)
let sample_dist (selector : 'a pV selector) (sample_runner : sample_runner)
    ch : 'a pV =
  (* Expand one branch (p, x) by a single step. [pcontrib] is the
     probability of the path leading to it.
     - A value is credited to [ans] at once.
     - A suspension is forced once. A failure (no children) is dropped,
       and a single value child is credited at once. Otherwise the
       children become a new alternative, pushed onto [acc].
     Kept children are renormalized to sum to (nearly) 1, and the missing
     mass is moved into the alternative's own weight. An alternative whose
     children carry no mass at all is dropped, as shallow_explore does:
     it cannot contribute anything, and renormalizing it would divide by
     zero and spread NaN weights. *)
  let look_ahead pcontrib (ans,acc) = function
    | (p,V v) -> (PMap.insert_with (+.) v (p *. pcontrib) ans, acc)
    | (p,C t) -> begin
        match t () with
        | [] -> (ans,acc)
        | [(p1,V v)] ->
            (PMap.insert_with (+.) v (p *. p1 *. pcontrib) ans, acc)
        | ch ->
            let ptotal = List.fold_left (fun pa (p,_) -> pa +. p) 0.0 ch in
            (ans,
             if ptotal = 0.0 then acc
             else if ptotal < nearly_one then
               (p *. ptotal, List.map (fun (p,x) -> (p /. ptotal,x)) ch)::acc
             else (p, ch)::acc)
      end in
  (* Follow a single sampled path to its end, accumulating into [ans]. *)
  let rec loop pcontrib ans = function
    | [(p,V v)] -> PMap.insert_with (+.) v (p *. pcontrib) ans
    | [] -> ans
    | [(p,C th)] -> loop (p *. pcontrib) ans (th ())
    | ch -> (* look ahead on every alternative, then choose one randomly *)
        begin
          match List.fold_left (look_ahead pcontrib) (ans,[]) ch with
          | (ans,[]) -> ans
          | (ans,cch) ->
              let (ptotal,th) = selector cch in
              loop (pcontrib *. ptotal) ans th
        end in
  (* Draw one sample, starting from the pre-explored alternatives [cch]. *)
  let toploop pcontrib ans cch =
    let (ptotal,th) = selector cch in
    loop (pcontrib *. ptotal) ans th in
  (* Run the sampler, then merge in the exactly known values [vals].
     [vals] are multiplied by the number of samples, so the final division
     by that number leaves them unchanged. *)
  let driver pcontrib vals cch =
    let (ans,nsamples) =
      sample_runner.sample_runner PMap.empty
        (fun ans -> toploop pcontrib ans cch) in
    let ns = float_of_int nsamples in
    let ans = PMap.foldi
        (fun v p ans ->
           PMap.insert_with (+.) v (ns *. p) ans) vals ans in
    printf "sample_importance: done %d worlds\n" nsamples;
    PMap.foldi (fun v p a -> (p /. ns,V v)::a) ans [] in
  (* Pre-explore the initial alternatives; sample only if some remain. *)
  let rec make_threads pcontrib ans ch =
    match List.fold_left (look_ahead pcontrib) (ans,[]) ch with
    | (ans,[]) -> (* pre-exploration solved the problem *)
        PMap.foldi (fun v p a -> (p,V v)::a) ans []
    | (ans,[(p,ch)]) -> (* only one choice, make more *)
        make_threads (pcontrib *. p) ans ch
    (* List.rev is for literal compatibility with an earlier version *)
    | (ans,cch) -> driver pcontrib ans (List.rev cch)
  in
  make_threads 1.0 PMap.empty ch
;;
(* Another idea for a better sampler: given the list of threads, the pV
structure, split it in two and sample the two parts separately; then combine the
results. Preferably sample all high-weight threads separately from low-weight
threads, so that low-probability threads are not starved by high-probability
ones.
*)
(*
The sample_dist above works great; yet there are cases where one-step
look-ahead is insufficient. Consider the following program
(* select a random point with 0..9 square *)
let random_pos () = (uniform 10, uniform 10)
let model () =
let _ = geometric_bounded 3 0.98 in
let np = 0 in
let np = if not (flip 0.98) then np else
if random_pos () <> (3,5) then fail () else succ np in
let np = if not (flip 0.98) then np else
if random_pos () <> (3,5) then fail () else succ np in
let np = if not (flip 0.98) then np else
if random_pos () <> (3,5) then fail () else succ np in
if np <> 3 then fail ()
;;
let [(9.41192000000000293e-07, V ())] =
exact_reify model;;
Alas, doing importance sampling on the above model fails.
sample_importance (random_selector 17) 1500 model;;
gives no samples. The reason of course is the final test: we have
to go through a long series of choices before we do the test, at which
point it fails. We should bring the evidence checking closer to the
point of choice... Perhaps lightweight constraint solving may help?
Here is a different idea for an improved importance sampling:
when the sampling process encounters a failure, as in these lines
let rec loop pcontrib ans = function
| [(p,V v)] -> PMap.insert_with (+.) v (p *. pcontrib) ans
| [] -> ans (* failure detected *)
accumulate pcontrib (probably weighed by 1/nsamples) as the failure
probability of the main thread, selected by the driver.
Use the failure probability to scale the probabilities of the main
thread to affect the selection by the main driver (we should
keep the scaling factor around to correct for the probabilities of the
found answers). So that the more
failures are reported by a thread, the less likely the driver will
select it.
*)
(* The following reification procedures didn't show any compelling advantage *)
(*
let observe_la test maxdepth th =
reflect (reify (Some maxdepth) (fun () -> (observe test th)));;
(* Now we not only force the delayed branch, but also flatten it,
up to the given depth. The simple sample_reify is the case of
maxdepth = 0
*)
let sample_explore_reify
randomseed nsamples maxdepth (thunk : unit -> 'a) : 'a pV =
sample_dist randomseed nsamples (explore (Some maxdepth))
(reify (Some maxdepth) thunk);;
(* For large models, the memoizer may run out of memory... *)
let sample_explore_reify_memoize
randomseed nsamples maxdepth (thunk : unit -> 'a) : 'a pV =
let mem_explore ch =
List.map (function (p,C th) -> (p, let v = lazy (explore (Some maxdepth)
(th ()))
in C (fun () -> Lazy.force v))
| (p,V v) -> (p, V v)) ch in
sample_dist randomseed nsamples mem_explore
(reify (Some maxdepth) thunk);;
*)
(* ------------------------------------------------------------------------ *)
(* Utilities *)
(* Estimate the approximation error by computing the mean and the
* standard deviation over multiple runs of the sampler *)
(* Run [sampler] once for every seed in [randomseed1 .. randomseed2]. For
   each value that appears, return (value, mean probability, standard
   deviation). A value missing from a run counts as probability 0 for that
   run, because the averages are taken over all runs. The sampler must
   return a flattened distribution (values only).
   @raise Invalid_argument if the sampler returns an unexplored branch. *)
let statistics (randomseed1, randomseed2) (sampler : (int -> 'a pV)) =
  (* value -> (sum of p, sum of p^2) over the runs *)
  let answers = Hashtbl.create 17 in
  let record (p, vc) =
    match vc with
    | V v ->
        let prev = try Some (Hashtbl.find answers v) with Not_found -> None in
        (match prev with
         | Some (pold, p2old) ->
             Hashtbl.replace answers v (pold +. p, p2old +. p *. p)
         | None -> Hashtbl.add answers v (p, p *. p))
    | C _ ->
        invalid_arg "statistics: the sampler returned an unexplored branch" in
  for randomseed = randomseed1 to randomseed2 do
    List.iter record (sampler randomseed)
  done;
  let n = float_of_int (randomseed2 - randomseed1 + 1) in
  Hashtbl.fold
    (fun v (p,p2) a ->
       (* Population variance E[p^2] - E[p]^2. Rounding can make it
          slightly negative (e.g. identical results in every run), and
          sqrt would then return NaN, so clamp it at 0. *)
       let var = (p2 -. p *. p /. n) /. n in
       let sd = if var < 0.0 then 0.0 else sqrt var in
       (v, p /. n, sd) :: a)
    answers [];;
(* Normalize the distribution. We also return the total probability mass,
the normalization constant.
*)
(* Rescale the weights of [l] so they sum to 1, returning the original
   total mass alongside the rescaled list. *)
let normalize l =
  let total = List.fold_left (fun sum (w, _) -> w +. sum) 0.0 l in
  let rescaled = List.map (fun (w, x) -> (w /. total, x)) l in
  (total, rescaled);;
(* Time the execution *)
(* Run [thunk], print the processor time it took, and return its result. *)
let timeit thunk =
  let started = Sys.time () in
  let result = thunk () in
  let elapsed = Sys.time () -. started in
  Printf.printf "\nTime spent: %g sec\n" elapsed;
  result;;