{-# LANGUAGE GADTs, DataKinds, KindSignatures #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
-- User-friendlier syntax for writing the model
-- (constructing the node trace)
module Syntax (
Distribution,
bern, categorical, uniformly, uniform, normal, gamma, beta,
Model, MCMCM, SExp, pair, if_,
dist, diracN, condition,
mcmC
) where
import Control.Monad.State
import Control.Applicative
import qualified Data.IntSet as Set
import qualified Memory as Mem
import Metropolis
import Distribution (logDensity)
-- reexporting and renaming
type Distribution a = DistK KResampleable a
bern = bernS
categorical :: Eq a => [(a,Double)] -> Distribution a
categorical = categoricalS
uniformly :: Eq a => [a] -> Distribution a
uniformly = uniformlyS
uniform = uniformS
normal = normalS
gamma = gammaS
beta = betaS
{-
Main design criteria for the user-visible Syntax:
-- we want to statically prevent conditioning where the observed values
are the ones computed during the course of the program.
The observed value must be the one known before the program starts.
-- proper conditioning is done only on sampleable nodes, on random
variables -- not on arbitrary boolean formulas
So, we have to statically distinguish nodes (or, random variables),
constructed from distributions, and expressions involving nodes.
To avoid the Borel paradox.
General conditioning could be programmed, e.g., with simulated
annealing.
-- The whole program should be a node (but we can convert an expression
-- to a node if needed to be)
One may wish to write niceex1 as follows, in a nice functional syntax
niceex1 =
let c = uniformN (Val 0) (Val 1)
d = bernN (Val 0.5)
e = bernN c
in diracN (liftN2 (&&) e d)
but this will not work in the slightly more complex example
let c = uniformN (Val 0) (Val 1)
d = bernN (Val 0.5)
e = betaN c c
...
which uses 'c' twice. Since c is a computation to build the node, the
naive implementation will duplicate 'c'. We need to detect
sharing. This is the same main problem that vexes us in representing
circuits in Haskell. So, the model construction has to be in a
monad -- at least so that we can generate unique names and detect
sharing.
-}
-- A model produces a stochastic expression, which is converted
-- to a node if needed.
-- That is, in the end, the whole program is represented by
-- the final node, which depends on the other nodes.
type Model a = MCMCM (SExp a)
-- Stochastic expression
data SExp a where
Val :: a -> SExp a -- immediate (a Haskell) value, e.g., literal
NRef :: NodeRef a -> SExp a -- reference to a constructed node
-- General computation over possibly several nodes
NT :: NTree a -> (a -> b) -> SExp b
data NTree a where
NTLeaf :: NodeRef a -> NTree a
NTPair :: NTree a -> NTree b -> NTree (a,b)
-- Perform a computation on a sigle SExp, doing a bit of
-- partial evaluation
-- That is why we represent the node computations as (NT t g),
-- with the separate tuple of nodes and the combining function.
instance Functor SExp where
fmap f (Val x) = Val (f x)
fmap f (NRef x) = NT (NTLeaf x) f
fmap f (NT t g) = NT t (f . g)
instance Applicative SExp where
pure = Val
Val f <*> Val y = Val (f y)
-- NRef f <*> Val y = Cannot happen: there are no distributions over
-- functions
Val f <*> NRef y = NT (NTLeaf y) f
-- (NRef x) (NRef y) = cannot happen
NT t g <*> Val y = NT t (\ab -> g ab y)
Val f <*> NT t g = NT t (\ab -> f (g ab))
NT t g <*> NRef y = NT (NTPair t (NTLeaf y)) (uncurry g)
-- NRef x <*> NT t g = cannot happen
NT t1 g1 <*> NT t2 g2 = NT (NTPair t1 t2) (\ (ab,cd) -> g1 ab (g2 cd))
-- But SExp is NOT a Monad. Can you guess why?
instance Num a => Num (SExp a) where
fromInteger = pure . fromInteger
x + y = liftA2 (+) x y
x - y = liftA2 (-) x y
x * y = liftA2 (*) x y
negate = fmap negate
abs = fmap abs
signum = fmap signum
instance Fractional a => Fractional (SExp a) where
(/) = liftA2 (/)
recip = fmap recip
fromRational = pure . fromRational
pair :: SExp a -> SExp b -> SExp (a,b)
pair = liftA2 (,)
type family Args a :: *
type instance Args (DistK k a) = Model a
type instance Args (b -> a) = SExp b -> Args a
class StochasticPrim a where
dist :: a -> Args a
instance Show a => StochasticPrim (DistK k a) where
dist d = fmap NRef $ with_node_none d
instance Show b => StochasticPrim (a -> DistK k b) where
dist d x = fmap NRef $ with_node_ctor x d
instance Show c => StochasticPrim (a -> b -> DistK k c) where
dist d x y = fmap NRef $ with_node_ctor (pair x y) (uncurry d)
{-
-- Node constructors for the end user. These ones are exported.
uniformN :: NodeCTor k1 Double -> NodeCTor k2 Double -> Model Double
uniformN x y = fmap NRef $
with_node_ctor (pair x y) (uncurry uniform)
-}
-- Since Dirac is special and frequently used, we define the synonym for
-- dist dirac x
-- It turns a general node computation or a literal into a single node.
-- NB: if x is just a NoderRef, (diracN x) returns this reference (shares)
-- rather than creating a copy of it!
diracN :: Show a => SExp a -> Model a
diracN (Val x) = fmap NRef $ create_node Nothing (diracS x)
diracN x@NRef{} = return x
diracN x = fmap NRef $ with_node_ctor x diracS
-- It is now statically ensured that we can condition only on
-- external data. We cannot condition on the results produced
-- during the computation.
-- condition is the transformation on DistK, which is lifted
-- functorially
class Condition a where
type CondR a :: *
type CondV a :: *
condition :: CondV a -> a -> CondR a
instance Condition (DistK KResampleable a) where
type CondR (DistK KResampleable a) = DistK KObserved a
type CondV (DistK KResampleable a) = a
condition = conditioned
instance Condition b => Condition (a -> b) where
type CondR (a -> b) = a -> CondR b
type CondV (a -> b) = CondV b
condition v f = \x -> condition v (f x)
mcmC :: Show a => Integer -> Model a -> [a]
mcmC limit m = mcmc limit (m >>= to_node)
-- These are internal functions and should not be exported and used
-- by the end user
eval_nctor :: SExp a -> a
eval_nctor (Val x) = x
eval_nctor (NRef x) = nref_val x
eval_nctor (NT t g) = g $ eval_ntree t
eval_ntree :: NTree a -> a
eval_ntree (NTLeaf x) = nref_val x
eval_ntree (NTPair t1 t2) = (eval_ntree t1, eval_ntree t2)
-- If it is not a node, make a Dirac node
to_node :: Show a => SExp a -> MCMCM (NodeRef a)
to_node (Val x) = create_node Nothing (diracS x)
to_node (NRef x) = return x
to_node x = with_node_ctor x diracS
-- Generalization of with_node for the arbitrary tuple of dependent
-- nodes
with_node_ctor :: Show b => SExp a -> (a -> DistK k b) -> TraceM b
with_node_ctor (Val x) df = create_node Nothing (df x)
with_node_ctor (NRef x) df = with_node x df
with_node_ctor (NT (NTLeaf x) g) df = with_node x (df . g)
with_node_ctor (NT (NTPair (NTLeaf x) (NTLeaf y)) g) df =
with_node2 (x,y) (df . g)
with_node_ctor ndr@(NT t g) df = do
let dist = df (eval_nctor ndr)
nnew <- create_node (Just $ doupdate t (df . g)) dist
register_dependencies t (Mem.addr nnew)
return nnew
where
-- The update never resamples. The val may change because it depends
-- on the node value, and LL may also change even if val stays the same.
-- The self NodeRef is always the last argument.
doupdate :: NTree b -> (b -> DistK k a) -> NodeRef a ->
MCMCM (NodeRef a, NodeSet)
doupdate ndrold df nsold = do
ns <- refresh_nref nsold -- node being updated
(tv,ts) <- reeval_tree ndrold -- node on which we depend
let self = Mem.dref ns
tnow <- fmap now get
-- the parent node must have been updated
when (ts <= tstamp self) $ fail (unwords ["dependency violation!",
"node", show (Mem.addr ns), "depends on tree",
"whose tstamp", show ts])
let self' = case df tv of
-- Dirac is treated specially, see the note at the end of the file.
Dirac v -> self{val = v,tstamp=tnow}
-- don't update the tstamp if only LL changes
-- !!! Don't forget to update the distribution
Resampleable d ->
self{ll = logDensity d (val self), dst = d}
Observed d _ ->
self{ll = logDensity d (val self), dst = d}
return (Mem.modify_ref self' ns,Set.empty)
-- The node na depends on all nodes in NTree. Register these
-- dependencies
register_dependencies :: NTree a -> NodeAddress -> MCMCM ()
register_dependencies (NTLeaf ndr) na = new_dependency ndr na
register_dependencies (NTPair t1 t2) na = do
register_dependencies t1 na
register_dependencies t2 na
-- re-evaluate the tree and return the value and the max timestamp
reeval_tree :: NTree a -> MCMCM (a,TStamp)
reeval_tree (NTLeaf ndrold) = do
ndr <- refresh_nref ndrold
let nd = Mem.dref ndr
ts = tstamp nd
return $ (val $ nd, ts)
reeval_tree (NTPair t1 t2) = do
(t1v,ts1) <- reeval_tree t1
(t2v,ts2) <- reeval_tree t2
return ((t1v,t2v),max ts1 ts2)
if_ :: Show a => SExp Bool -> Model a -> Model a -> Model a
if_ test th el = do
-- Always make a new node, for the sake of ifnode
entry <- with_node_ctor test diracS
fmap NRef $ ifnode entry (th >>= to_node) (el >>= to_node)
-- Examples in the user-friendly syntax
niceex1 = do
c <- dist uniform 0 1
d <- dist bern 0.5
e <- dist bern c
diracN ((&&) <$> e <*> d)
niceex1r = mcmC 7 niceex1
-- All True
nice_prog_mult_conditions c1 c2 = do
b <- dist beta 1 1
dist (c1 `condition` bern) b
dist (c2 `condition` bern) b
diracN b
niceexmcr11 = sum $ mcmC 1000 (nice_prog_mult_conditions True False)
-- 499.799988743311
niceexmcr12 = sum $ mcmC 1000 (nice_prog_mult_conditions False False)
-- 257.2928205920028
prog_mixture1 cv = do
c <- dist bern 0.5
if_ c (dist (cv `condition` normal) 1 1)
(dist (cv `condition` uniform) 0 3)
return c
prog_mixture1_run = mcmC 20 (prog_mixture1 (-2))
-- all true
exif1 = do
c <- dist bern 0.5
d <- if_ c (dist normal ((fromIntegral . fromEnum) <$> c) 1)
(diracN 5)
e <- if_ (not <$> c) (dist normal 20 1) (diracN d)
diracN (d + e)
exif1r = mcmC 7 exif1
exif2 = do
c <- dist bern 0.5
if_ c (cl1 10) (cl1 20)
where
cl1 x = do
d <- dist bern 0.5
if_ d cl3 (diracN x)
cl3 = do
a <- dist uniform 0 1
b <- dist uniform 0 1
diracN (a + b)
exif2r = mcmC 17 exif2
-- Programs from my messages about the problem with conditionals
tac1 = do
c <- dist bern 0.5
if_ c
(dist normal 0 1)
(dist uniform 10 20)
return c
tac1_run = length . filter id $ mcmC 1000 tac1
-- 620 True
tac2 = do
c <- dist bern 0.5
if_ c
(diracN (Val True) >> dist normal 0 1)
(dist uniform 10 20)
return c
tac2_run = length . filter id $ mcmC 1000 tac2
-- 523 True
tac3 = do
c <- dist bern 0.5
if_ c
(diracN (Val True) >> dist normal 0 1)
(dist uniform 10 20 >>= diracN)
return c
tac3_run = length . filter id $ mcmC 1000 tac3
-- 532 True
-- tac2 and tac3 produce identical results, as expected:
-- since diracN is unit, e >>= diracN is indeed the same as diracN
tac10 = do
c <- dist bern 0.5
if_ c
-- then
(dist (0 `condition` normal) 0 1)
-- else
(dist (0 `condition` uniform) 10 20)
return c
tac10_run = length . filter id $ mcmC 100 tac10
-- 100 -- all True
tac11 = do
c <- dist bern 0.5
if_ c
-- then
(dist (0 `condition` normal) 0 1)
-- else
(do
d <- diracN 10
dist (0 `condition` uniform) d 20)
return c
tac11_run = length . filter id $ mcmC 100 tac11
-- 100 -- all True