{-# OPTIONS -fglasgow-exts -fth #-}
-- Reify the (compiled) code to its typed TH representation
-- (or, to be precise, to its dictionary _view_) and reflect/compile
-- that code back.
-- The code has to be spread over several modules because of Template
-- Haskell's stage restriction.
-- See DiffTest.hs for the reflection of the differentiated TH code back
-- into (machine) code.

module Diff where

import TypedCode

-- Lift Num, Fractional and Floating to code expressions: each arithmetic
-- method on (Code a) builds the application of the corresponding operator
-- code to the argument code(s).

instance Num a => Num (Code a) where
  (+) x y     = appC (appC op'add x) y
  (-) x y     = appC (appC op'sub x) y
  (*) x y     = appC (appC op'mul x) y
  negate      = appC op'negate
  fromInteger = integerC

instance Fractional a => Fractional (Code a) where
  (/) x y      = appC (appC op'div x) y
  recip        = appC op'recip
  fromRational = rationalC

instance Floating a => Floating (Code a) where
  pi  = op'pi
  sin = appC op'sin
  cos = appC op'cos

-- A polymorphic numeric constant can be instantiated to code.
testf1 :: Num a => a
testf1 = 1 + 2

testf1' = return (testf1 :: Code Int)

testf1'' :: IO ()
testf1'' = showQC testf1'
-- (GHC.Num.+) 1 2

-- We can define an ordinary (polymorphic) function...
test1f :: Num a => a -> a
test1f x = y + 1
  where
    y = x * x

-- ...and compile it as usual. At any point, we can also reify it into its
-- `dictionary view'.
-- The result is TH code, which we can print and compile back into code.
-- We can also differentiate the TH code, simplify it, partially evaluate
-- it, etc.
test1 :: Float
test1 = test1f 2.0

-- Instantiate test1f at a fresh differentiation variable, keeping the
-- variable alongside the resulting code.
test1c = do
  v <- new'diffVar
  return (test1f (var'exp v), v :: Var Float)

test1r = test1c >>= \(code, dv) -> reflectDF dv code

test1cp :: IO ()
test1cp = showQC test1r
-- and reflect it back, see DiffTest.hs

{- We must stress that there is no `reify' function. One may say it is
built into Haskell already.
*Diff> test1
5.0
*DiffTest> test1'
5.0
*Diff> test1cp
\dx_0 -> GHC.Num.+ (GHC.Num.* dx_0 dx_0) 1
-}

-- Symbolic differentiation of the reified, typed TH code expressions.
-- Differentiation is type-preserving: the derivative of a (Code a) is
-- again a (Code a).
-- Any code shape not listed below is rejected with a runtime error.
diffC :: (Floating a, Floating b) => Var b -> Code a -> Code a
diffC v c
  -- literals are constants
  | Just _      <- on'litC c          = 0
  -- a Left result gives derivative 1, a Right result 0 (presumably Left
  -- marks v itself and Right any other variable/constant -- confirm in
  -- TypedCode)
  | Just ev     <- on'varC v c        = either (const 1) (const 0) ev
  | Just (x, y) <- on'2opC op'add c   = d x + d y
  | Just (x, y) <- on'2opC op'sub c   = d x - d y
  -- product rule
  | Just (x, y) <- on'2opC op'mul c   = d x * y + x * d y
  -- quotient rule
  | Just (x, y) <- on'2opC op'div c   = (d x * y - x * d y) / (y * y)
  | Just x      <- on'1opC op'negate c = negate (d x)
  -- (1/x)' = -x' / x^2
  | Just x      <- on'1opC op'recip c = negate (d x) / (x * x)
  | Just x      <- on'1opC op'sin c   = d x * cos x
  | Just x      <- on'1opC op'cos c   = negate (d x * sin x)
  | otherwise                         = error $ "Cannot handle code: " ++ show c
  where
    -- derivative of a sub-expression with respect to the same variable
    d = diffC v

test1d = test1c >>= \(code, dv) -> reflectDF dv (diffC dv code)

test1dp :: IO ()
test1dp = showQC test1d

{-
*Diff> test1dp
\dx_0 -> (GHC.Num.+) ((GHC.Num.+) ((GHC.Num.*) 1 dx_0) ((GHC.Num.*) dx_0 1)) 0
-}

-- Simplification rules.
-- Simplification is type-preserving.
-- Simplification is, of course, an open-ended problem: we could even
-- recognize common sub-expressions and share them by introducing let
-- bindings. Below, however, we perform only trivial simplifications;
-- more rules can always be added later.
-- | Simplify code to a fixed point: run one simplification pass
-- ('simpleCL') after another until a pass reports that nothing changed.
-- Simplification is type-preserving.
simpleC :: Floating a => Var b -> Code a -> Code a
simpleC v c = maybe c (simpleC v) (simpleCL v c)

-- | One simplification pass over the code tree. Returns 'Nothing' when no
-- rule applies anywhere in the expression, otherwise 'Just' the rewritten
-- code.
--
-- The arguments of an operator are simplified first (see 'simple'recur' and
-- 'simple'recur1'); the operator's own rule is tried only once its arguments
-- no longer change. Every unary operator we can build (negate, recip, sin,
-- cos) is descended into, so sub-expressions under them get simplified too.
simpleCL :: Floating a => Var b -> Code a -> Maybe (Code a)
-- literals and variables are already as simple as they get
simpleCL _ c | Just _ <- on'litC c = Nothing
simpleCL v c | Just _ <- on'varC v c = Nothing
simpleCL v c | Just (x, y) <- on'2opC op'add c = simple'recur op'add sadd v x y
  where
    -- 0 + b = b,  a + 0 = a,  and constant folding
    sadd a b | Just 0 <- on'litRationalC a = Just b
    sadd a b | Just 0 <- on'litRationalC b = Just a
    sadd a b | (Just ra, Just rb) <- (on'litRationalC a, on'litRationalC b)
             = Just (fromRational (ra + rb))
    sadd _ _ = Nothing
simpleCL v c | Just (x, y) <- on'2opC op'sub c = simple'recur op'sub ssub v x y
  where
    -- a - 0 = a,  and constant folding
    ssub a b | Just 0 <- on'litRationalC b = Just a
    ssub a b | (Just ra, Just rb) <- (on'litRationalC a, on'litRationalC b)
             = Just (fromRational (ra - rb))
    -- 0 - b = negate b; tried after constant folding, so 0 - 1 still folds
    -- directly to the literal -1
    ssub a b | Just 0 <- on'litRationalC a = Just (negate b)
    ssub _ _ = Nothing
simpleCL v c | Just (x, y) <- on'2opC op'mul c = simple'recur op'mul smul v x y
  where
    -- 0 * b = a * 0 = 0,  1 * b = b,  a * 1 = a,  and constant folding
    smul a _ | Just 0 <- on'litRationalC a = Just (fromRational 0)
    smul _ b | Just 0 <- on'litRationalC b = Just (fromRational 0)
    smul a b | Just 1 <- on'litRationalC a = Just b
    smul a b | Just 1 <- on'litRationalC b = Just a
    smul a b | (Just ra, Just rb) <- (on'litRationalC a, on'litRationalC b)
             = Just (fromRational (ra * rb))
    smul _ _ = Nothing
simpleCL v c | Just (x, y) <- on'2opC op'div c = simple'recur op'div sdiv v x y
  where
    -- 0 / b = 0 (the divisor is assumed non-zero, as before)
    sdiv a _ | Just 0 <- on'litRationalC a = Just (fromRational 0)
    -- a / 1 = a
    sdiv a b | Just 1 <- on'litRationalC b = Just a
    -- constant folding; a literal zero divisor is left alone rather than
    -- raising a Ratio division-by-zero exception during simplification
    sdiv a b | (Just ra, Just rb) <- (on'litRationalC a, on'litRationalC b)
             , rb /= 0
             = Just (fromRational (ra / rb))
    sdiv _ _ = Nothing
simpleCL v c | Just x <- on'1opC op'negate c = simple'recur1 op'negate sneg v x
  where
    -- fold the negation of any literal (subsumes the old "negate 0 = 0")
    sneg a | Just r <- on'litRationalC a = Just (fromRational (negate r))
    sneg _ = Nothing
simpleCL v c | Just x <- on'1opC op'recip c = simple'recur1 op'recip srecip v x
  where
    -- fold the reciprocal of a non-zero literal
    srecip a | Just r <- on'litRationalC a, r /= 0 = Just (fromRational (recip r))
    srecip _ = Nothing
-- sin/cos have no rules of their own, but their arguments may simplify
simpleCL v c | Just x <- on'1opC op'sin c = simple'recur1 op'sin (const Nothing) v x
simpleCL v c | Just x <- on'1opC op'cos c = simple'recur1 op'cos (const Nothing) v x
simpleCL _ _ = Nothing

-- | Simplify both arguments of the binary operator @op@. If neither
-- changes, try the operator-specific rule @fn@ on them; otherwise rebuild
-- the application from the (partly) simplified arguments and report
-- progress.
simple'recur op fn v x y =
  case (simpleCL v x, simpleCL v y) of
    (Nothing, Nothing) -> fn x y
    (mx, my)           -> Just (op `appC` maybe x id mx `appC` maybe y id my)

-- | Unary counterpart of 'simple'recur'.
simple'recur1 op fn v x =
  case simpleCL v x of
    Nothing -> fn x
    Just x' -> Just (op `appC` x')

test1ds = test1c >>= \ (c,v) -> reflectDF v $ simpleC v $ diffC v c

test1dsp :: IO ()
test1dsp = showQC test1ds

{-
*Diff> test1dsp
\dx_0 -> GHC.Num.+ dx_0 dx_0
-}

-- And that's about it. Putting it all together gives us:
diff_fn :: Floating b => (forall a. Floating a => a -> a) -> QCode (b -> b)
diff_fn f = do
  v <- new'diffVar
  let body = f (var'exp v)                   -- reified body of the function
  reflectDF v . simpleC v . diffC v $ body   -- differentiate and simplify

-- A useful helper that shows the code of the function in question.
show_fn :: (forall a. Floating a => a -> a) -> IO ()
show_fn f = showQC (do
  v <- new'diffVar
  reflectDF v (f (var'exp v)))

-- We can either print the result of diff_fn, or compile it
-- (that is, splice it: see DiffTest.hs).

-- More examples
test2f :: Num a => a -> a
test2f x = foldl (\z c -> x*z + c) 0 [1,2,3]

test2n :: Float
test2n = test2f 4 -- 27.0

test2s :: IO ()
test2s = show_fn test2f

{-
*Diff> test2s
\dx_0 -> GHC.Num.+ (GHC.Num.* dx_0 (GHC.Num.+ (GHC.Num.* dx_0 (GHC.Num.+ (GHC.Num.* dx_0 0) 1)) 2)) 3
-}

test2ds :: IO ()
test2ds = showQC (diff_fn test2f)

{- Not too bad: 2*x + 2
*Diff> test2ds
\dx_0 -> GHC.Num.+ (GHC.Num.+ dx_0 2) dx_0
-}

{- The differentiated code can be `compiled back', see DiffTest.hs
test2dn = $(reflectQC (diff_fn test2f)) (4::Float) -- 10.0
-}

-- Check the constant folding
test11f :: Num a => a -> a
test11f x = 2*x + 3*x

test11ds :: IO ()
test11ds = showQC (diff_fn test11f)
-- \dx_0 -> 5%1

-- Here's a slightly more complex example:
test5f :: Floating a => a -> a
test5f x = sin (5*x + pi/2) + cos (1 / x)

test5n :: Float
test5n = test5f pi -- cos(1/pi)-1 == -5.023426e-2

test5ds :: IO ()
test5ds = showQC (diff_fn test5f)

{- which isn't too bad: quite optimal, actually
*Diff> test5ds
\dx_0 -> GHC.Num.+ (GHC.Num.* 5 (GHC.Float.cos (GHC.Num.+ (GHC.Num.* 5 dx_0) (GHC.Real./ GHC.Float.pi 2)))) (GHC.Num.negate (GHC.Num.* (GHC.Real./ ((-1)%1) (GHC.Num.* dx_0 dx_0)) (GHC.Float.sin (GHC.Real./ 1 dx_0))))
-}

-- One may evaluate the function test5f numerically, differentiate it
-- symbolically, check the result of differentiation -- and evaluate it
-- numerically right away. See test5dn in DiffTest.hs for the latter.

-- Partial derivatives work too: fix one argument to a constant and
-- differentiate with respect to the other.
test3f :: Fractional a => a -> a -> a
test3f x y = (x*y + (5*x*x)) / y

-- d test3f / dx, with the second argument fixed to y
test4x :: (Integral i, Floating b) => i -> QCode (b -> b)
test4x y = diff_fn (`test3f` fromIntegral y)

-- d test3f / dy, with the first argument fixed to x
test4y :: Floating b => Integer -> QCode (b -> b)
test4y x = diff_fn (test3f (fromInteger x))

test4xds :: IO ()
test4xds = showQC (test4x 1) -- 1 + 10*x

test4yds :: IO ()
test4yds = showQC (test4y 5)

{-
*DiffTest> test4yds
\dx_0 -> GHC.Real./ (GHC.Num.- (GHC.Num.* 5 dx_0) (GHC.Num.+ (GHC.Num.* 5 dx_0) (125%1))) (GHC.Num.* dx_0 dx_0)
-}

{- In DiffTest.hs
-- partial derivative with respect to x
test4xdn = $(reflectQC (test4x 1)) (2::Float) -- 21.0
-- partial derivative with respect to y
test4ydn = $(reflectQC (test4y 5)) (5::Float) -- -5.0
-}