-- implementation of probM (to the extent described in the DSL paper)
-- via the ContT monad
module ProbM where
import qualified Data.Map as Map
import Prelude hiding (flip)
import Control.Monad.Cont
import Debug.Trace
import System.Random
-- | Probabilities are represented as Float throughout this module.
type Prob = Float
-- The PV tree representing the distribution
-- | A node of the distribution tree: either a final value V,
-- or a suspended, not-yet-explored subtree C.
data VC a = V a | C (PV a)
-- | A (sub)distribution: a weighted list of nodes.  Weights along a
-- root-to-leaf path multiply; totals need not sum to 1 (failure prunes mass).
type PV a = [(Prob, VC a)]
-- | Leaves show their value; suspended subtrees print opaquely as "C".
instance Show a => Show (VC a) where
  show node = case node of
    V a -> show a
    C _ -> "C"
-- | Embed a single value as the trivial, certain (weight 1.0) distribution.
pv_unit :: a -> PV a
pv_unit v = [(1.0, V v)]
-- The normalization (flattening) of the tree
-- | Normalize (flatten) a PV tree, merging equal answers.
-- 'Nothing' explores exhaustively; 'Just d' stops expanding below depth d
-- and returns the unexplored branches as suspended C nodes.
-- Path weights are multiplied down the tree and summed per value in a Map.
explore :: Ord a => Maybe Int -> PV a -> PV a
explore maxdepth choices = loop [] (make_jobque Map.empty 1.0 0 choices)
  where
    -- Add probability mass pcontrib to the answer v
    add_answer pcontrib v anss = Map.insertWith (+) v pcontrib anss
    -- Split one level into finished answers and pending (weight,depth,tree) jobs
    make_jobque anss p depth choices =
      foldr (\e (anss,jq) ->
               case e of
                 (pt,V v) -> (add_answer (pt * p) v anss,jq)
                 (pt,C t) -> (anss,(pt * p, depth, t):jq))
            (anss,[]) choices
    -- No jobs left: emit the accumulated answers in front of the suspensions.
    -- (foldrWithKey is the non-deprecated spelling of the old foldWithKey.)
    loop susp (anss,[]) = Map.foldrWithKey (\v p a -> (p,V v):a) susp anss
    -- Within the depth budget: expand this branch one more level
    loop susp (anss,(p, d, t):rest) | maybe True (d <) maxdepth =
      let (anss',newjq) = make_jobque anss p (succ d) t in
      loop susp (anss',newjq ++ rest)
    -- Budget exhausted: keep the branch suspended, with its path weight
    loop susp (anss,(p, d, t):rest) = loop ((p,C t):susp) (anss,rest)
-- Main probabilistic operations
type ProbM r a = Cont (PV r) a
-- | The primitive weighted choice: each alternative becomes a suspended
-- branch in which the continuation is applied to that alternative's value.
dist :: [(Prob, a)] -> ProbM r a
dist lst = Cont $ \k -> [ (p, C (k v)) | (p, v) <- lst ]
-- | The failing computation: an empty choice (used to reject worlds
-- that contradict the observed evidence).
fail_ :: ProbM r a
fail_ = dist []
-- | Run a probabilistic computation with 'pv_unit' as the final
-- continuation, obtaining its (unflattened) distribution tree.
reify0 :: ProbM a a -> PV a
reify0 m = runCont m pv_unit
-- | Exact inference: reify and then flatten the whole tree (no depth bound).
exact_reify :: Ord a => ProbM a a -> PV a
exact_reify = explore Nothing . reify0
-- | Inject a reified distribution back into the monad.
-- The first clause is debugging instrumentation: its guard always fails,
-- but evaluating it prints how many branches are being reflected.
reflect :: PV a -> ProbM r a
reflect lst | trace (show (length lst) ++ " worlds to reflect") False =
  undefined
reflect lst = Cont $ \k -> map (f k) lst
  where
    -- Graft the continuation under every leaf, preserving the tree shape
    f k (p, V v) = (p, C (k v))
    f k (p, C t) = (p, C (map (f k) t))
-- Derived (convenience) operations
flip p = dist [(p, True), (1-p, False)]
-- Uniform choice from [0..(n-1)]
-- | Each outcome gets weight 1/n; the weight of 0 is computed as the
-- remainder 1 - (n-1)/n so the weights sum to exactly 1.0.
uniform :: Int -> ProbM r Int
uniform 1 = return 0
uniform n | n > 1 = loop 0.0 [] (pred n)
  where
    p = 1.0 / fromIntegral n
    loop pacc acc 0 = dist ((1 - pacc, 0):acc)
    loop pacc acc i = loop (pacc + p) ((p,i):acc) (pred i)
-- Previously a silent pattern-match failure; now an explicit error
uniform n = error ("uniform: the number of choices must be positive: " ++ show n)
-- Naive, rejection sampling: for comparison
type RandSeed = Int
-- | Naive rejection sampling, for comparison: run nsamples independent
-- walks down the tree, each time committing to one branch chosen at
-- random in proportion to its weight; average the collected answers.
rejection_sample_dist :: Ord a => RandSeed -> Int -> PV a -> PV a
rejection_sample_dist randomseed nsamples ch =
  let (answers,_) =
        ntimes nsamples (loop 1.0 ch) (Map.empty, mkStdGen randomseed)
  -- foldrWithKey is the non-deprecated spelling of the old foldWithKey
  in Map.foldrWithKey (\v p a -> (p / fromIntegral nsamples,V v):a)
       [] answers
  where
    -- Iterate f n times over the (answers, generator) state
    ntimes 0 f seed = seed
    ntimes n f seed = ntimes (pred n) f (f seed)
    add_answer pcontrib v answers = Map.insertWith (+) v pcontrib answers
    -- One sampling walk, carrying the accumulated path weight pcontrib
    loop pcontrib ch (answers,g) =
      case ch of
        [(p,V v)]  -> (add_answer (p * pcontrib) v answers, g)
        []         -> (answers,g)
        [(p,C th)] -> loop (p * pcontrib) th (answers,g)
        ch -> -- choosing one thread randomly
          let ptotal = foldr (\(p,_) pa -> pa + p) 0 ch
              (r',g') = random g  -- 0 <= r' < 1
              r = r' * ptotal     -- 0 <= r < ptotal
              selection pcum' ((p,th):rest) =
                let pcum = pcum' + p in
                if r < pcum
                  then loop (pcontrib * ptotal) [(1.0,th)] (answers,g')
                  else selection pcum rest
          in selection 0.0 ch
-- Sample a distribution with a look-ahead exploration
{-
A single sample can give us more than one data point: if one of
the choices is a definite value, we note it right away, with
its weight. The rest of the choices will be re-scaled automatically.
The argument explorer is the look-ahead explorer of the branch,
giving an expanded and flattened branch.
-}
-- | Importance sampling with a look-ahead explorer (see the comment above):
-- definite top-level values are credited immediately; each sample walks the
-- remaining branches, letting the explorer expand/flatten a branch before a
-- weighted random commitment.
sample_dist :: Ord a => RandSeed -> Int -> (PV a -> PV a) -> PV a -> PV a
sample_dist randomseed nsamples explorer ch =
  -- Credit top-level definite answers once, scaled by nsamples (the final
  -- division by nsamples undoes the scaling); queue the rest for sampling.
  let (answers',choices) = foldr (\e (answers,jq) ->
        case e of
          (p,V v) -> (add_answer (fromIntegral nsamples * p) v answers,jq)
          e -> (answers,e:jq)) (Map.empty,[]) ch in
  let (answers,_) =
        ntimes nsamples (loop 1.0 choices) (answers', mkStdGen randomseed)
  -- foldrWithKey is the non-deprecated spelling of the old foldWithKey
  in Map.foldrWithKey (\v p a -> (p / fromIntegral nsamples,V v):a)
       [] answers
  where
    ntimes 0 f seed = seed
    ntimes n f seed = ntimes (pred n) f (f seed)
    add_answer pcontrib v answers = Map.insertWith (+) v pcontrib answers
    get_ptotal ch = foldr (\(p,_) pa -> pa + p) 0 ch -- total weight of choices
    -- One sampling step down the (explorer-expanded) tree
    loop pcontrib ch (answers,g) =
      case ch of
        [(p,V v)]  -> (add_answer (p * pcontrib) v answers, g)
        []         -> (answers,g)
        [(p,C th)] -> loop (p * pcontrib) (explorer th) (answers,g)
        ch -> -- choosing one thread randomly
          make_selection pcontrib g (foldl (look_ahead pcontrib) (answers,[]) ch)
    -- Explore the branch a bit: record definite answers right away and
    -- renormalize the weights of the branches that remain
    look_ahead pcontrib (answers,acc) (p,V v)
      = (add_answer (p * pcontrib) v answers, acc)
    look_ahead pcontrib (answers,acc) (p,C t) =
      case explorer t of
        [] -> (answers,acc)
        [(p1,V v)] -> (add_answer (p * p1 * pcontrib) v answers, acc)
        ch ->
          let ptotal = get_ptotal ch in
          (answers, (p * ptotal, map (\(p,x) -> (p / ptotal,x)) ch):acc)
    -- Choose one surviving thread in proportion to its weight and commit
    make_selection pcontrib g (answers,[]) = (answers,g)
    make_selection pcontrib g (answers,ch) =
      let ptotal = get_ptotal ch
          (r',g') = random g  -- 0 <= r' < 1
          r = r' * ptotal     -- 0 <= r < ptotal
          selection pcum' ((p,th):rest) =
            let pcum = pcum' + p in
            if r < pcum
              then loop (pcontrib * ptotal) th (answers,g')
              else selection pcum rest
      in selection 0.0 ch
-- | Reify a computation and importance-sample it with the identity
-- (no-op) look-ahead explorer.  (Type signature added.)
sample_reify :: Ord a => RandSeed -> Int -> ProbM a a -> PV a
sample_reify randomseed nsamples m =
  sample_dist randomseed nsamples id (reify0 m)
-- Estimate the approximation error by computing the mean and the
-- standard deviation over multiple runs of sampler
-- | Estimate the sampler's approximation error: run it once for every seed
-- in [randomseed1..randomseed2] and report, per answer value, the mean
-- estimate and the sample standard deviation across the runs.
statistics :: Ord a =>
              (RandSeed, RandSeed) -> (RandSeed -> PV a) ->
              [(a, Float, Float)]
statistics (randomseed1, randomseed2) sampler = loop Map.empty randomseed1
  where
    -- All seeds done: turn the per-value (sum p, sum p^2) accumulators
    -- into (mean, stddev).  foldrWithKey replaces the deprecated foldWithKey.
    loop answers seed | seed > randomseed2 =
      let n = fromIntegral (randomseed2 - randomseed1 + 1) in
      Map.foldrWithKey
        (\v (p,p2) a -> (v, p / n, sqrt ((p2 - p * p / n) / (n-1))) : a)
        [] answers
    loop answers seed = loop (foldl add_answer answers (sampler seed))
                             (succ seed)
    -- NOTE(review): only V leaves are handled; a C node here would be a
    -- pattern-match failure.  The samplers above emit only V entries.
    add_answer answers (p, V v) =
      Map.insertWith (\(p,p2) (p',p2') -> (p+p', p2+p2')) v (p, p * p) answers
-- Sprinkler example
-- | The classic sprinkler model: the grass may be wetted by rain, by the
-- sprinkler, or spontaneously; condition on wet grass and return rain.
grass_model = do
  rain <- flip 0.3
  sprinkler <- flip 0.5
  -- the three wetness causes, sampled in the same order as before
  wet_by_rain <- liftM (rain &&) (flip 0.9)
  wet_by_sprinkler <- liftM (sprinkler &&) (flip 0.8)
  wet_other <- flip 0.1
  let grass_is_wet = (wet_by_rain || wet_by_sprinkler) || wet_other
  if grass_is_wet then return rain else fail_
-- Reification of the sprinkler model and its depth-bounded exploration;
-- the bracketed comments record the outputs of a sample run.
scp1 = reify0 grass_model
-- [(0.3,C),(0.7,C)]
scpe0 = explore (Just 0) scp1
-- [(0.7,C),(0.3,C)]
scpe1 = explore (Just 1) scp1
-- [(0.35,C),(0.35,C),(0.15,C),(0.15,C)]
scpe2 = explore (Just 2) scp1
-- [(3.5000008e-2,C),(0.315,C),(3.5000008e-2,C),(0.315,C),
-- (1.5000004e-2,C),(0.135,C),(1.5000004e-2,C),(0.135,C)]
scpe3 = explore (Just 3) scp1
-- [(7.000001e-3,C),(2.8000006e-2,C),(6.299999e-2,C),(0.252,C),
-- (7.000001e-3,C),(2.8000006e-2,C),(6.299999e-2,C),(0.252,C),
-- (3.0000007e-3,C),(1.2000004e-2,C),(2.6999999e-2,C),(0.108,C),
-- (3.0000007e-3,C),(1.2000004e-2,C),(2.6999999e-2,C),(0.108,C)]
scpe4 = explore (Just 4) scp1
{-
[(6.300001e-3,C),(7.0000015e-4,C),(2.5200006e-2,C),(2.8000006e-3,C),
(5.669999e-2,C),(6.2999995e-3,C),(0.2268,C),(2.52e-2,C),(6.300001e-3,C),
(7.0000015e-4,C),(2.5200006e-2,C),(2.8000006e-3,C),(5.669999e-2,C),
(6.2999995e-3,C),(0.2268,C),(2.52e-2,C),(2.7000005e-3,C),(3.0000007e-4,C),
(1.0800003e-2,C),(1.2000004e-3,C),(2.4299998e-2,C),(2.6999998e-3,C),
(9.72e-2,C),(1.08e-2,C),(2.7000005e-3,C),(3.0000007e-4,C),
(1.0800003e-2,C),(1.2000004e-3,C),(2.4299998e-2,C),(2.6999998e-3,C),
(9.72e-2,C),(1.08e-2,C)]
-}
scpe5 = explore (Just 5) scp1
-- [(0.322,False),(0.2838,True)]
-- Ken's test
-- | Ken's alarm test: the alarm probability depends on whether an
-- earthquake and/or a burglary occurred.
talarm = do
  earthquake <- flip 0.01
  burglary <- flip 0.1
  case (earthquake, burglary) of
    (True,  True)  -> flip 0.99
    (True,  False) -> flip 0.2
    (False, True)  -> flip 0.98
    (False, False) -> flip 0.01
-- | Exact inference for the alarm model.
talarmr = exact_reify talarm
-- [(0.89128,V False),(0.10872,V True)]
{- test2.ibl
let x =
obs { z : 'a } in
dist [ 0.01 : { z = 'a, w = 'b },
0.02 : { z = 'a, w = 'c },
0.97 : { z = 'd, w = 'e } ]
in
let y =
if x.w == 'b
then dist [ 0.9 : true, 0.1 : false ]
else if x.w == 'c
then dist [ 0.6 : true, 0.4 : false ]
else dist [ 0.2 : true, 0.8 : false ]
in
y
-}
-- Could use better `records'
-- | Stand-in for the IBL record { z, w } from test2.ibl above.
data TIBL2 = TIBL2{lz :: Char, lw :: Char}
-- | Haskell transcription of test2.ibl: draw a record, observe z = 'a',
-- then choose the boolean according to the record's w field.
testibl2 = do
  r <- dist [ (0.01, TIBL2{lz = 'a', lw = 'b'}),
              (0.02, TIBL2{lz = 'a', lw = 'c'}),
              (0.97, TIBL2{lz = 'd', lw = 'e'}) ]
  -- the observation: keep only records whose z field is 'a'
  x <- if lz r == 'a' then return r else fail_
  if lw x == 'b' then flip 0.9
    else if lw x == 'c' then flip 0.6
    else flip 0.2
-- | Exact inference for the test2.ibl transcription.
testibl2_r = exact_reify testibl2
-- [(9.0e-3,False),(2.1e-2,True)]
-- OCaml: [(0.021, V true); (0.00900000000000000105, V false)]
-- | XOR of n fair coin flips, where the sub-distribution for n-1 flips is
-- reified exactly and reflected back (so the tree stays small; reflect
-- prints its "worlds to reflect" trace each time).
flips_xor' n = loop n
  where
    loop 1 = flip 0.5
    loop n = do
      r <- reflect (exact_reify (loop (n-1)))
      liftM (r /=) $ flip 0.5
-- | Exact XOR of 10 coins via the repeated reify/reflect strategy.
rtflips_10xor' = exact_reify (flips_xor' 10)
{-
2 worlds to reflect
2 worlds to reflect
2 worlds to reflect
2 worlds to reflect
2 worlds to reflect
2 worlds to reflect
2 worlds to reflect
2 worlds to reflect
2 worlds to reflect
[(0.5,False),(0.5,True)]
-}
-- | A fair coin tossed by a drunkard who loses it 90% of the time;
-- a lost coin aborts the current world (fail_).
drunk_coin =
  flip 0.5 >>= \outcome ->
  flip 0.9 >>= \dropped ->
  if dropped then fail_ else return outcome
-- Compute AND of n tosses of the drunk coin
-- The following is not lazy enough
-- If we got tail, we don't care about the other flips
-- dcoin_and n = liftM2 (&&) drunk_coin (dcoin_and (n-1))
dcoin_and :: Int -> ProbM r Bool
dcoin_and 1 = drunk_coin
dcoin_and n | n > 1 =
  drunk_coin >>= (\x -> if x then dcoin_and (n-1) else return False)
-- Previously n < 1 recursed forever; report it instead
dcoin_and n = error ("dcoin_and: need at least one toss, given " ++ show n)
-- | Exact inference for the AND of 10 drunk-coin tosses.
dcoin_and_exact = exact_reify (dcoin_and 10)
-- [(5.263159e-2,False),(9.765647e-14,True)]
{- Exact probability
let () = assert (
reify None (fun () -> dcoin_and 10)
= [(9.76562499999997764e-14, V true); (0.0526315789473632833, V false)]);;
(* reify: done; 11 accepted 10 rejected 0 left *)
(* Thus we managed to do with only 21 threads *)
-}
-- | Sharing of flips: one coin used twice, so v && v is just the coin.
tflip2_shared = liftM (\v -> v && v) (flip 0.5)
-- Sampling experiments; the bracketed comments record observed outputs.
tflip2_shared_rej = rejection_sample_dist 1 100 (reify0 tflip2_shared)
-- [(0.48,False),(0.52,True)]
-- OCaml: [(0.48, V true); (0.52, V false)])
tflip2_shared_rej' = rejection_sample_dist 17 100 (reify0 tflip2_shared)
-- [(0.47,False),(0.53,True)]
-- Pure rejection sample of drunk coin
dcoin_and_rej100 =
  rejection_sample_dist 17 100 (reify0 $ dcoin_and 10)
-- [(5.0e-2,False)]
-- OCaml: [(0.01, V false)]);;
dcoin_and_rej10000 =
  rejection_sample_dist 17 10000 (reify0 $ dcoin_and 10)
-- [(5.25e-2,False)]
-- OCaml: [(0.052, V false)]);;
-- Importance sampling
dcoin_and_imp100 =
  sample_reify 17 100 (dcoin_and 10)
-- [(5.380622e-2,False)]
dcoin_and_imp5000 =
  sample_reify 17 5000 (dcoin_and 10)
-- [(5.2586384e-2,False),(6.0000134e-14,True)]
-- OCaml:
-- [(1.19999999999999741e-13, V true); (0.0527161389576979306, V false)]);;
-- The result below is pretty good: the averages are nice, and
-- the variance is relatively low. Note that the estimate for True
-- is significant (to about 2.5 sigma)
dcoin_and_simp5000 =
  statistics (1,50) (\seed -> sample_reify seed 5000 (dcoin_and 10))
-- [(False,5.2693073e-2,5.962375e-4),(True,9.400026e-14,3.686706e-14)]
-- ------------------------------------------------------------------------
-- Investigation of memoization and sharing
-- First we note that our exploration of the tree uses memoization
-- That is, as reified tree is explored by sampling, the result
-- of exploration is memoized.
-- | A full binary tree of n fair coin flips, producing the list of outcomes.
full_2tree 1 = liftM (: []) (flip 0.5)
full_2tree n = liftM2 (:) (flip 0.5) (full_2tree (pred n))
-- | Observe that all n flips came up True; the trace message counts how
-- many times the full list had to be (re)built during sampling.
full_2tree_obs n = do
  t <- full_2tree n
  if trace "full_2tree built" t == replicate n True
    then return () else fail_
{-
*ProbM> rejection_sample_dist 17 20 (reify0 $ full_2tree_obs 1)
full_2tree built -- printed 2 times
[(0.8,())]
*ProbM> rejection_sample_dist 17 20 (reify0 $ full_2tree_obs 2)
full_2tree built -- printed 4 times
[(0.45,())]
*ProbM> rejection_sample_dist 17 20 (reify0 $ full_2tree_obs 3)
full_2tree built -- printed 7 times (not 8! times)
[(0.2,())]
*ProbM> rejection_sample_dist 17 20 (reify0 $ full_2tree_obs 4)
full_2tree built -- printed 12 times (not 16 or 20 times)
[(0.15,())]
*ProbM> rejection_sample_dist 17 20 (reify0 $ full_2tree_obs 5)
full_2tree built -- printed 16 times
[(0.1,())]
Note that in the last example, `full_2tree built' is printed 16 times:
not 20 times (the number of samples) and not 32 times (the number of leaves).
That means that some of the samples explored already explored branches:
the value for that branch has already been computed and memoized.
*ProbM> sample_reify 17 20 (full_2tree_obs 1)
full_2tree built -- printed 2 times
[(0.5,())]
*ProbM> sample_reify 17 20 (full_2tree_obs 2)
full_2tree built -- printed 4 times
[(0.1,())]
*ProbM> sample_reify 17 20 (full_2tree_obs 3)
full_2tree built -- printed 8 times
[(0.1,())]
*ProbM> sample_reify 17 20 (full_2tree_obs 4)
full_2tree built -- printed 14 times
[(7.5e-2,())]
*ProbM> sample_reify 17 20 (full_2tree_obs 5)
full_2tree built -- printed 24 times
[(5.0e-2,())]
The importance sampling with the look-ahead beam explores more leaves.
The result for `sample_reify 17 20 (full_2tree_obs 4)' again
demonstrates that some samples follow the already explored paths, and so
not recomputed.
-}
-- | A full 10-ary tree of n uniform draws from [0..9], as a list.
full_10tree 1 = liftM (: []) (uniform 10)
full_10tree n = liftM2 (:) (uniform 10) (full_10tree (pred n))
-- | Observe that every one of the n draws equals 1; the trace message
-- counts how many times the full list had to be built.
full_10tree_obs n = do
  t <- full_10tree n
  if trace "full_10tree built" t == replicate n 1
    then return () else fail_
{-
*ProbM> rejection_sample_dist 17 20 (reify0 $ full_10tree_obs 1)
[(0.15,())]
(0.01 secs, 524892 bytes)
*ProbM> rejection_sample_dist 17 20 (reify0 $ full_10tree_obs 2)
[]
(0.01 secs, 524892 bytes)
*ProbM> rejection_sample_dist 17 20 (reify0 $ full_10tree_obs 7)
[]
(0.02 secs, 1454516 bytes)
and the same amount of memory for levels 7, 8 and 9.
This is consistent with rejection sampling: the sample explores only
a small part of the tree. It traverses only 20 paths...
*ProbM> sample_reify 17 20 (full_10tree_obs 1)
[(0.10000001,())]
(0.01 secs, 521604 bytes)
*ProbM> sample_reify 17 20 (full_10tree_obs 2)
[(5.0e-3,())]
*ProbM> sample_reify 17 20 (full_10tree_obs 3)
[]
(0.03 secs, 1360188 bytes)
*ProbM> sample_reify 17 20 (full_10tree_obs 4)
[]
(0.05 secs, 2090732 bytes)
*ProbM> sample_reify 17 20 (full_10tree_obs 5)
[]
(0.08 secs, 2805808 bytes)
*ProbM> sample_reify 17 20 (full_10tree_obs 6)
[]
(0.10 secs, 4271072 bytes)
*ProbM> sample_reify 17 20 (full_10tree_obs 7)
[]
(0.13 secs, 4972572 bytes)
*ProbM> sample_reify 17 20 (full_10tree_obs 8)
[]
(0.15 secs, 5687908 bytes)
*ProbM> sample_reify 17 20 (full_10tree_obs 9)
[]
(0.17 secs, 6393220 bytes)
*ProbM> sample_reify 17 20 (full_10tree_obs 10)
[]
(0.20 secs, 7838288 bytes)
As more nodes are explored, more is memoized.
sample_reify 17 20000 (full_10tree_obs 10)
Man, that was a bad idea. GHCi took 1.1GB, swapped most of my other processes
and almost froze my system. I had to kill GHCi.
So, memory leak due to memoization is real.
-}
-- | Check that l is exactly the n-element list of all 1s, short-circuiting
-- on the first mismatch.  Now total: a length mismatch yields False instead
-- of the previous head/tail crash on the empty list.
ones 0 [] = True
ones 0 _  = False
ones _ [] = False
ones n (x:xs) = x == 1 && ones (pred n) xs
-- | As 'full_10tree', but instrumented with trace messages at every level,
-- to reveal which nodes get rebuilt vs. memoized during sampling.
full_10tree' 1 = uniform 10 >>= return . (:[])
full_10tree' n = do
  trace ("full_10tree' level " ++ show n) (return ())
  x <- uniform 10
  xs <- trace ("full_10tree' for " ++ show (n-1)) $
        full_10tree' (pred n)
  return (x:xs)
-- | Observe all-ones via the short-circuiting 'ones' predicate; the whole
-- list is still built before the check (see the trace discussion below).
full_10tree_obs' n = do
  t <- full_10tree' n
  if ones n (trace "full_10tree built" t)
    then return () else fail_
{-
The trace of the execution of
rejection_sample_dist 17 20 (reify0 $ full_10tree_obs' 4)
confirms that memoization is occurring (there are fewer
`full_10tree' level 3' messages than `full_10tree' level 1' messages -- meaning
nodes up in the path are memoized) and that we have to build the
whole list before we can check against the evidence (the evidence
being the list of all ones). Indeed, we always see this pattern of trace
messages
full_10tree' for 1
full_10tree built
and never see
full_10tree' level 3
full_10tree built
the latter would have corresponded to the evidence check at near the root
(root being level 4), without the need to construct the leaves (level 1).
No wonder for tree of depth 3 even the look-ahead sample of 20 fails,
and there is no hope for trees of depth 8 or higher. Note this example
corresponds quite closely to Avi's Music example.
-}