Monadic Shortcut Fusion

Runtime comparison

OSR/WOPT OSR/ROFF
f1 93.42% 93.51%
f2 94.08% 95.84%
f3 63.27% 65%
f4 69.66% 69.14%
f5 97% 97.1%
f6 8.66% 8.69%
f7 13.71% 13.72%

OSR = Optimizations with shortcut rules: ghc -fglasgow-exts -O fi.hs

WOPT = Without optimizations: ghc --make -fglasgow-exts fi.hs

ROFF = Without optimizations, all rules off: ghc -fglasgow-exts -frules-off fi.hs

Testsuite

The test suite consists of the module MonShortcut.hs, which contains the definitions of fold and monadic build for the datatypes involved, and the individual test programs (f1.hs to f7.hs).

MonShortcut.hs

module MonShortcut where

-- Monadic map: runs the action m and applies the pure function f to its
-- result.
-- NOTE(review): the do-block below desugars to m >>= \a -> return (f a),
-- the very shape that "bind->mmap" rewrites -- confirm the rule cannot
-- turn mmap's own body into a self-call.
mmap :: Monad m => ( a -> b) -> (m a -> m b)
mmap f m = do {a <- m; return (f a)}
-- Never inlined, so calls to mmap stay visible to the rewrite rules that
-- have mmap on their left-hand side.
{-# NOINLINE mmap #-}

-- Rewrites a bind whose continuation only returns a pure function of the
-- bound value into a call to mmap.
{-# RULES "bind->mmap" 
      forall m f. m >>= (\a -> return (f a)) = mmap f m 
 #-}

------- mbuild for lists

-- Monadic list producer: the generator g is abstracted over the list
-- constructors and is instantiated here with (:) and [].
mbuild :: Monad m => (forall b. (a -> b -> b) -> b ->  m b) -> m [a]
mbuild g = g (:) []
-- Inlining is delayed until phase 1.
-- NOTE(review): presumably so that "foldr/mbuild" can fire before mbuild
-- disappears -- confirm against the GHC phase-control documentation.
{-# INLINE [1] mbuild #-}

-- Shortcut rule: a foldr mapped (with mmap) over the result of mbuild is
-- replaced by passing the fold's operator and seed straight to the generator.
{-# RULES "foldr/mbuild" 
      forall c n (g :: forall b. (a -> b -> b) -> b -> m b). 
      mmap (foldr c n)(mbuild g) = g c n
 #-}

------- mbuild for arithmetic expressions

-- | Arithmetic expressions: integer literals combined by binary addition.
data Exp
  = Num Int      -- ^ integer literal
  | Add Exp Exp  -- ^ sum of two subexpressions
  deriving (Show)

-- | Fold over an expression: @num@ replaces every 'Num' constructor and
-- @add@ replaces every 'Add' constructor.
foldE :: (Int -> a) -> (a -> a -> a) -> Exp -> a
foldE num add = go
  where
    go expr = case expr of
      Num k   -> num k
      Add l r -> add (go l) (go r)

-- Monadic expression producer: the generator g is abstracted over the Exp
-- constructors and is instantiated here with Num and Add.
mbuildE   :: Monad m => (forall a. (Int -> a) -> (a -> a -> a) -> m a) -> m Exp
mbuildE g  = g Num Add 
-- Inlining is delayed until phase 1.
-- NOTE(review): presumably so that "foldE/mbuildE" can fire first -- confirm.
{-# INLINE [1]  mbuildE #-}

-- Shortcut rule: a foldE mapped (with mmap) over the result of mbuildE is
-- replaced by passing the fold's operators straight to the generator.
{-# RULES "foldE/mbuildE"
      forall num add (g :: forall a. (Int -> a) -> (a -> a -> a) -> m a). 
      mmap (foldE num add) (mbuildE g) = g num add 
 #-}

f1.hs

We avoid GHC's length function because it is defined with an accumulating parameter and would therefore have to be expressed as a higher-order fold.

module Main where

import IO
import MonShortcut
import Control.Monad

------- zipWithM in terms of mbuild

-- Expresses zipWithM through mbuild, with gzipWM as the generator that is
-- abstracted over the list constructors.
{-# RULES "zipWithM->mbuild"
      forall f xs ys. zipWithM f xs ys = mbuild (gzipWM f xs ys)
 #-}

-- | Generator behind zipWithM: runs @f@ on corresponding elements of the
-- two lists, combining the results with @c@ and finishing with @n@ as soon
-- as either list is exhausted.
gzipWM f as bs c n = case (as, bs) of
  (a : as', b : bs') -> do
    z  <- f a b
    zs <- gzipWM f as' bs' c n
    return (c z zs)
  _ -> return n


------- length' in terms of foldr

-- Lets the optimiser treat length' as a foldr that counts the elements.
{-# RULES "length'->foldr"
    length' = foldr (\_ acc -> acc + 1) 0
 #-}

-- | Number of elements of a list, written with plain (non-accumulating)
-- recursion.
length' list = case list of
  []         -> 0
  (_ : rest) -> 1 + length' rest

------- Auxiliary definitions

-- | Benchmark input: the numbers from 1 to 1000000.
xs = [1 .. 1000000]

-- | Print the 'show' rendering of a value (no newline) and hand the value
-- back.
put x = putStr (show x) >> return x

------- Test

-- Benchmark body: applies put to the pairwise sums of xs with itself via
-- zipWithM, then returns the length' of the collected results.
f :: IO Int 
f  = do as <- zipWithM (\x y -> put (x+y)) xs xs
        return (length' as)

-- Runs the benchmark and discards its result.
main = f >> return ()

f2.hs

We avoid GHC's sum function because it is defined with an accumulating parameter and would therefore have to be expressed as a higher-order fold.

module Main where

import IO
import MonShortcut
import Control.Monad

------- mapM in terms of mbuild

-- Expresses mapM through mbuild, with gmapM as the generator that is
-- abstracted over the list constructors.
{-# RULES "mapM->mbuild"
    forall f xs. mapM f xs = mbuild (gmapM f xs)
 #-}

-- | Generator behind mapM: runs @f@ on every element in order, combining
-- the results with @c@ and finishing with @n@.
gmapM f items c n = case items of
  []         -> return n
  (x : rest) -> do
    y  <- f x
    ys <- gmapM f rest c n
    return (c y ys)

------- sum' in terms of foldr

-- Lets the optimiser treat sum' as a foldr with (+) and 0.
{-# RULES "sum'->foldr"
    sum' = foldr (+) 0
 #-}

-- | Sum of a list, written with plain (non-accumulating) recursion.
sum' list = case list of
  []         -> 0
  (x : rest) -> x + sum' rest

------- Auxiliary definitions

-- | Benchmark input: the numbers from 1 to 1000000.
xs = [1 .. 1000000]

-- | Print the 'show' rendering of a value (no newline) and hand the value
-- back.
put x = putStr (show x) >> return x

------- Examples

-- Benchmark body: applies put to the square of every element of xs via
-- mapM, then returns the sum' of the collected results.
f :: IO Int 
f  = do as <- mapM (\x -> put (x*x)) xs
        return (sum' as)

-- Runs the benchmark and discards its result.
main = f >> return ()

f3.hs

module Main where

import IO
import MonShortcut
import Control.Monad

------- hGetContents in terms of mbuild

-- Expresses hGetContents through mbuild, with hGetC as the generator that
-- is abstracted over the list constructors.
{-# RULES "hGetContents->mbuild" 
      forall h. hGetContents h = mbuild (hGetC h)
 #-}

-- | Generator behind hGetContents: reads the handle one character at a
-- time, combining the characters with @c@. At end of file the handle is
-- closed and @n@ is returned.
hGetC h c n = do
  atEnd <- hIsEOF h
  if atEnd
    then do
      hClose h
      return n
    else do
      ch   <- hGetChar h
      rest <- hGetC h c n
      return (c ch rest)

------- filter in terms of foldr

-- Expresses filter as a foldr that keeps an element only when p holds for it.
{-# RULES "filter->foldr" 
      forall p. filter p = foldr (\x xs-> if p x then (x:xs) else xs) []
 #-}

------- Test

-- Benchmark body: reads the whole contents of the handle and returns them
-- with every newline character removed.
f = \h -> do cs <- hGetContents h
             return (filter (/='\n') cs)
-- Kept out of line so the benchmark body is not inlined into main.
{-# NOINLINE f #-}

-- Opens sblp08.tex for reading, runs f on it and prints the result.
main = do h <- openFile "sblp08.tex" ReadMode 
          zs <- f h
          putStr zs     

f4.hs

module Main where

import IO
import MonShortcut
import Control.Monad

------- hGetContents in terms of mbuild

-- Expresses hGetContents through mbuild, with hGetC as the generator that
-- is abstracted over the list constructors.
{-# RULES "hGetContents->mbuild" 
      forall h. hGetContents h = mbuild (hGetC h)
 #-}

-- | Generator behind hGetContents: reads the handle one character at a
-- time, combining the characters with @c@. At end of file the handle is
-- closed and @n@ is returned.
hGetC h c n = do
  atEnd <- hIsEOF h
  if atEnd
    then do
      hClose h
      return n
    else do
      ch   <- hGetChar h
      rest <- hGetC h c n
      return (c ch rest)

------- length' in terms of foldr

-- Lets the optimiser treat length' as a foldr that counts the elements.
{-# RULES "length'->foldr"
      length' = foldr (\_ acc -> acc + 1) 0
 #-}

-- | Number of elements of a list, written with plain (non-accumulating)
-- recursion.
length' list = case list of
  []         -> 0
  (_ : rest) -> 1 + length' rest

------- Test

-- Benchmark body: reads the whole contents of the handle and returns their
-- length' (the number of characters).
f = \h -> do cs <- hGetContents h
             return (length' cs)
-- Kept out of line so the benchmark body is not inlined into main.
{-# NOINLINE f #-}

-- Opens sblp08.tex for reading and runs f on it; the computed length is the
-- result of the action and is not printed.
main = do h  <- openFile "sblp08.tex" ReadMode 
          f h

f5.hs

module Main where

import IO
import MonShortcut
import Control.Monad

------- mapM in terms of mbuild

-- Expresses mapM through mbuild, with gmapM as the generator that is
-- abstracted over the list constructors.
{-# RULES "mapM->mbuild"
      forall f xs. mapM f xs = mbuild (gmapM f xs)
 #-}

-- | Generator behind mapM: runs @f@ on every element in order, combining
-- the results with @c@ and finishing with @n@.
gmapM f items c n = case items of
  []         -> return n
  (x : rest) -> do
    y  <- f x
    ys <- gmapM f rest c n
    return (c y ys)

------- filter in terms of foldr

-- Expresses filter as a foldr that keeps an element only when p holds for it.
{-# RULES "filter->foldr" 
      forall p. filter p = foldr (\x xs-> if p x then (x:xs) else xs) []
 #-}

------- Auxiliary definitions

-- | Benchmark input: the numbers from 1 to 1000000.
xs = [1 .. 1000000]

-- | Print the 'show' rendering of a value (no newline) and hand the value
-- back.
put x = putStr (show x) >> return x

------- Test

-- Benchmark body: applies put to the square of every element of xs via
-- mapM, then returns only the even results.
f :: IO [Int] 
f = do ns <- mapM (\x -> put (x*x)) xs
       return (filter even ns)

-- Runs the benchmark and discards its result.
main = f >> return ()

f6.hs

module Main where

import IO
import MonShortcut
import Control.Monad
import GHC.Base

------- Parser Monad

-- | A parser maps an input string to a list of (result, remaining input)
-- pairs.
newtype Parser a = P (String -> [(a, String)])

instance Monad Parser where
  -- Yield the value without consuming any input.
  return a = P (\input -> [(a, input)])
  -- Run @p@, then run the parser selected by @f@ on each remaining input
  -- and concatenate all outcomes.
  p >>= f  = P (\input ->
    concatMap (\(a, rest) -> parse (f a) rest) (parse p input))

-- | Unwrap a parser into its underlying function.
parse :: Parser a -> String -> [(a, String)]
parse (P run) = run

-- | The parser that yields no result on any input.
pzero :: Parser a
pzero = P (const [])

-- | Choice: keeps at most the first result of @p@ followed by @q@ on the
-- same input.
(<|>) :: Parser a -> Parser a -> Parser a
P p <|> P q = P (\input -> take 1 (p input ++ q input))

-- | Consume one character; yields no result on empty input.
item :: Parser Char
item = P step
  where
    step []         = []
    step (c : rest) = [(c, rest)]

------- parsing digits

-- | Consume one character and yield its numeric value if it is a decimal
-- digit; otherwise yield no result.
digit :: Parser Int
digit = item >>= toDigit
  where
    toDigit c
      | isDigit c = return (ord c - ord '0')
      | otherwise = pzero

-- | True exactly for the characters '0' through '9'.
isDigit c = '0' <= c && c <= '9'

-- | Parse a (possibly empty) run of digits into the list of their values.
digits :: Parser [Int]
digits = oneMore <|> return []
  where
    -- One digit followed by the remaining run.
    oneMore = do
      d  <- digit
      ds <- digits
      return (d : ds)

------- parser for arithmetic expressions

{- parsing -}

-- | Parser for sums of numbers: either a number, a plus sign and a further
-- expression, or a single number; the alternatives are tried in that order.
expression :: Parser Exp
{- INLINE [0] expression -}
expression = sumOfTerms <|> singleTerm
  where
    -- A number followed by '+' and a further expression.
    sumOfTerms = do
      n <- number
      plusop
      e <- expression
      return (Add (Num n) e)
    -- A lone number.
    singleTerm = do
      n <- number
      return (Num n)

-- | Parse a number: the value component of 'numpow10', dropping the power
-- of ten.
number :: Parser Int
number = do
  (value, _) <- numpow10
  return value

-- | Parse a run of digits into the pair (value, 10^number-of-digits);
-- an empty run gives (0, 1).
numpow10 = leadingDigit <|> return (0, 1)
  where
    -- One digit in front of the remaining run: the digit is weighted by the
    -- power of ten of what follows it.
    leadingDigit = do
      d      <- digit
      (n, p) <- numpow10
      return (d * p + n, 10 * p)

-- | Consume one character and succeed only if it is a plus sign.
plusop = do
  c <- item
  if c == '+'
    then return ()
    else pzero

{- evaluation -}

-- | Evaluate an expression: a literal is its own value, an addition is the
-- sum of the values of its operands.
eval :: Exp -> Int
eval expr = case expr of
  Num n   -> n
  Add l r -> eval l + eval r
{-# INLINE [1] eval #-}

{- parsing & evaluation -}

-- Benchmark body: parses an expression and returns its value computed by
-- eval.
--evalexp :: Parser Int
evalexp  = do e <- expression
              return (eval e)
-- Kept out of line so the benchmark body is not inlined into main.
{-# NOINLINE evalexp #-}

-- Express the consumer eval as a foldE, and the producer expression as an
-- mbuildE over the generator gexp.
{-# RULES 
"eval->foldE"          eval = foldE id (+)
"expression->mbuildE"  expression = mbuildE gexp
 #-}
     
-- | Generator behind 'expression': the same parser with the constructors
-- abstracted, @num@ standing for Num and @add@ for Add.
gexp num add = sumOfTerms <|> singleTerm
  where
    -- A number followed by '+' and a further expression.
    sumOfTerms = do
      n <- number
      plusop
      e <- gexp num add
      return (add (num n) e)
    -- A lone number.
    singleTerm = do
      n <- number
      return (num n)

------- Auxiliary definitions

-- | Endless repetition of the text "+2+3".
s23 = cycle "+2+3"

-- | Benchmark input: the character '1' followed by the first 1000000
-- characters of 's23'.
expr = '1' : take 1000000 s23

-- | Print the 'show' rendering of a value (no newline) and hand the value
-- back.
put x = putStr (show x) >> return x

------- Test

-- | Runs 'evalexp' on the benchmark input and prints the value of the first
-- parse. An empty parse result is reported as an explicit I/O error rather
-- than through the partial function 'head'.
main = case parse evalexp expr of
  (value, _) : _ -> print value
  []             -> ioError (userError "f6: evalexp produced no parse for the test input")

f7.hs

module Main where

import IO
import MonShortcut
import Control.Monad
import GHC.Base

------- Parser Monad

-- | A parser maps an input string to a list of (result, remaining input)
-- pairs.
newtype Parser a = P (String -> [(a, String)])

instance Monad Parser where
  -- Yield the value without consuming any input.
  return a = P (\input -> [(a, input)])
  -- Run @p@, then run the parser selected by @f@ on each remaining input
  -- and concatenate all outcomes.
  p >>= f  = P (\input ->
    concatMap (\(a, rest) -> parse (f a) rest) (parse p input))

-- | Unwrap a parser into its underlying function.
parse :: Parser a -> String -> [(a, String)]
parse (P run) = run

-- | The parser that yields no result on any input.
pzero :: Parser a
pzero = P (const [])

-- | Choice: keeps at most the first result of @p@ followed by @q@ on the
-- same input.
(<|>) :: Parser a -> Parser a -> Parser a
P p <|> P q = P (\input -> take 1 (p input ++ q input))

-- | Consume one character; yields no result on empty input.
item :: Parser Char
item = P step
  where
    step []         = []
    step (c : rest) = [(c, rest)]

------- parsing digits

-- | Consume one character and yield its numeric value if it is a decimal
-- digit; otherwise yield no result.
digit :: Parser Int
digit = item >>= toDigit
  where
    toDigit c
      | isDigit c = return (ord c - ord '0')
      | otherwise = pzero

-- | True exactly for the characters '0' through '9'.
isDigit c = '0' <= c && c <= '9'

-- | Parse a (possibly empty) run of digits into the list of their values.
digits :: Parser [Int]
{- INLINE [0] digits -}
digits = oneMore <|> return []
  where
    -- One digit followed by the remaining run.
    oneMore = do
      d  <- digit
      ds <- digits
      return (d : ds)

-- | Sum of a list, written with plain (non-accumulating) recursion.
sum' list = case list of
  []         -> 0
  (x : rest) -> x + sum' rest

-- Express the consumer sum' as a foldr, and the producer digits as an
-- mbuild over the generator gdig.
{-# RULES 
"sum'->foldr"      sum'   = foldr (+) 0
"digits->mbuild"   digits = mbuild gdig
 #-}

-- | Generator behind 'digits': the same parser with the list constructors
-- abstracted, @c@ standing for (:) and @n@ for [].
gdig c n = oneMore <|> return n
  where
    -- One digit followed by the remaining run.
    oneMore = do
      d  <- digit
      ds <- gdig c n
      return (c d ds)

------- divisible by 3

-- Benchmark body: parses a run of digits and returns their sum computed by
-- sum'.
--sumDigits :: Parser Int
sumDigits = do {ds <- digits; return (sum' ds)}
-- Kept out of line so the benchmark body is not inlined into its caller.
{-# NOINLINE sumDigits #-}

-- | Parse a run of digits and report whether their sum is a multiple of 3.
divby3 :: Parser Bool
divby3 = do
  total <- sumDigits
  return (total `mod` 3 == 0)

------- Auxiliary definitions

-- | Endless repetition of the text "123".
s123 = cycle "123"

-- | Benchmark input: the first 500000 characters of 's123'.
number = take 500000 s123

-- | Print the 'show' rendering of a value (no newline) and hand the value
-- back.
put x = putStr (show x) >> return x

------- Examples

-- | Runs 'divby3' on the benchmark input and prints the answer of the first
-- parse. An empty parse result is reported as an explicit I/O error rather
-- than through the partial function 'head'.
main = case parse divby3 number of
  (answer, _) : _ -> print answer
  []              -> ioError (userError "f7: divby3 produced no parse for the test input")