summaryrefslogtreecommitdiff
path: root/src/Transfer/SyntaxToCore.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Transfer/SyntaxToCore.hs')
-rw-r--r--src/Transfer/SyntaxToCore.hs173
1 files changed, 121 insertions, 52 deletions
diff --git a/src/Transfer/SyntaxToCore.hs b/src/Transfer/SyntaxToCore.hs
index b13579293..0d5907890 100644
--- a/src/Transfer/SyntaxToCore.hs
+++ b/src/Transfer/SyntaxToCore.hs
@@ -28,11 +28,11 @@ declsToCore :: [Decl] -> [Decl]
declsToCore m = evalState (declsToCore_ m) newState
declsToCore_ :: [Decl] -> C [Decl]
-declsToCore_ = desugar
+declsToCore_ = deriveDecls
+ >>> desugar
+ >>> compilePattDecls
>>> numberMetas
- >>> deriveDecls
>>> replaceCons
- >>> compilePattDecls
>>> expandOrPatts
>>> optimize
@@ -61,13 +61,14 @@ numberMetas = mapM f
return $ EVar $ Ident $ "?" ++ show (nextMeta st) -- FIXME: hack
_ -> composOpM f t
+
--
-- * Pattern equations
--
compilePattDecls :: [Decl] -> C [Decl]
compilePattDecls [] = return []
-compilePattDecls (d@(ValueDecl x _ _):ds) =
+compilePattDecls (d@(ValueDecl x _ _ _):ds) =
do
let (xs,rest) = span (isValueDecl x) ds
d <- mergeDecls (d:xs)
@@ -75,20 +76,26 @@ compilePattDecls (d@(ValueDecl x _ _):ds) =
return (d:rs)
compilePattDecls (d:ds) = liftM (d:) (compilePattDecls ds)
--- | Take a non-empty list of pattern equations for the same
--- function, and produce a single declaration.
+-- | Checks if a declaration is a value declaration
+-- of the given identifier.
+isValueDecl :: Ident -> Decl -> Bool
+isValueDecl x (ValueDecl y _ _ _) = x == y
+isValueDecl _ _ = False
+
+-- | Take a non-empty list of pattern equations with guards
+-- for the same function, and produce a single declaration.
mergeDecls :: [Decl] -> C Decl
-mergeDecls ds@(ValueDecl x p _:_)
- = do let cs = [ (ps,rhs) | ValueDecl _ ps rhs <- ds ]
- (pss,rhss) = unzip cs
+mergeDecls ds@(ValueDecl x p _ _:_)
+ = do let cs = [ (ps,g,rhs) | ValueDecl _ ps g rhs <- ds ]
+ (pss,_,_) = unzip3 cs
n = length p
when (not (all ((== n) . length) pss))
$ fail $ "Pattern count mismatch for " ++ printTree x
vs <- freshIdents n
- let cases = map (\ (ps,rhs) -> Case (mkPRec ps) rhs) cs
+ let cases = map (\ (ps,g,rhs) -> Case (mkPRec ps) g rhs) cs
c = ECase (mkERec (map EVar vs)) cases
f = foldr (EAbs . VVar) c vs
- return $ ValueDecl x [] f
+ return $ ValueDecl x [] GuardNo f
where mkRec r f = r . zipWith (\i e -> f (Ident ("p"++show i)) e) [0..]
mkPRec = mkRec PRec FieldPattern
mkERec = mkRec ERec FieldValue
@@ -118,6 +125,10 @@ derivators = [
("Ord", deriveOrd)
]
+--
+-- * Deriving instances of Compos
+--
+
deriveCompos :: Derivator
deriveCompos t@(Ident ts) k cs =
do
@@ -128,7 +139,7 @@ deriveCompos t@(Ident ts) k cs =
dt = apply (EVar (Ident "Compos")) [c, EVar t]
r = ERec [FieldValue (Ident "composOp") co,
FieldValue (Ident "composFold") cf]
- return [TypeDecl d dt, ValueDecl d [] r]
+ return [TypeDecl d dt, ValueDecl d [] GuardNo r]
deriveComposOp :: Ident -> Exp -> [(Ident,Exp)] -> C Exp
deriveComposOp t k cs =
@@ -149,9 +160,9 @@ deriveComposOp t k cs =
EApp (EVar t') c | t' == t -> apply (e f) [c, e v]
_ -> e v
calls = zipWith rec vars (argumentTypes ct)
- return $ Case (PCons ci (map PVar vars)) (apply (e ci) calls)
+ return $ Case (PCons ci (map PVar vars)) gtrue (apply (e ci) calls)
cases <- mapM (uncurry mkCase) cs
- let cases' = cases ++ [Case PWild (e x)]
+ let cases' = cases ++ [Case PWild gtrue (e x)]
fb <- abstract (arity k) $ const $ pv f \-> pv x \-> ECase (e x) cases'
return fb
@@ -180,17 +191,61 @@ deriveComposFold t k cs =
p = EProj (e r) (Ident "mplus")
joinCalls [] = z
joinCalls cs = foldr1 (\x y -> apply p [x,y]) cs
- return $ Case (PCons ci (map PVar vars)) (joinCalls calls)
+ return $ Case (PCons ci (map PVar vars)) gtrue (joinCalls calls)
cases <- mapM (uncurry mkCase) cs
- let cases' = cases ++ [Case PWild (e x)]
+ let cases' = cases ++ [Case PWild gtrue (e x)]
fb <- abstract (arity k) $ const $ pv f \-> pv x \-> ECase (e x) cases'
return $ VWild \-> pv r \-> fb
+--
+-- * Deriving instances of Show
+--
+
deriveShow :: Derivator
deriveShow t k cs = fail $ "derive Show not implemented"
+--
+-- * Deriving instances of Eq
+--
+
+-- FIXME: how do we require Eq instances for all
+-- constructor arguments?
+
deriveEq :: Derivator
-deriveEq t k cs = fail $ "derive Eq not implemented"
+deriveEq t@(Ident tn) k cs =
+ do
+ let ats = argumentTypes k
+ d = Ident ("eq_"++tn)
+ dt <- abstractType ats (EApp (EVar (Ident "Eq")) . apply (EVar t))
+ eq <- mkEq
+ r <- abstract (arity k) (\_ -> ERec [FieldValue (Ident "eq") eq])
+ return [TypeDecl d dt, ValueDecl d [] GuardNo r]
+ where
+ mkEq = do
+ x <- freshIdent
+ cases <- mapM (uncurry mkEqCase) cs
+ return $ EAbs (VVar x) (ECase (EVar x) cases)
+ mkEqCase c ct =
+ do
+ let n = arity ct
+ vs1 <- freshIdents n
+ vs2 <- freshIdents n
+ y <- freshIdent
+ let p1 = PCons c (map PVar vs1)
+ p2 = PCons c (map PVar vs2)
+ es1 = map EVar vs1
+ es2 = map EVar vs2
+ tc | n == 0 = true
+ -- FIXME: using EEq doesn't work right now
+ | otherwise = foldr1 EAnd (zipWith EEq es1 es2)
+ c1 = Case p2 gtrue tc
+ c2 = Case PWild gtrue false
+ return $ Case p1 gtrue (EAbs (VVar y) (ECase (EVar y) [c1,c2]))
+
+
+--
+-- * Deriving instances of Ord
+--
deriveOrd :: Derivator
deriveOrd t k cs = fail $ "derive Ord not implemented"
@@ -268,10 +323,10 @@ removeUselessMatch = return . map f
f x = case x of
EAbs (VVar x) b ->
case f b of
- -- replace \x -> case x of { y -> e } with \y -> e,
+ -- replace \x -> case x of { y | True -> e } with \y -> e,
-- if x is not free in e
- ECase (EVar x') [Case (PVar y) e]
- | x' == x && not (x `isFreeIn` e)
+ ECase (EVar x') [Case (PVar y) g e]
+ | x' == x && isTrueGuard g && not (x `isFreeIn` e)
-> f (EAbs (VVar y) e)
-- replace unused variable in lambda with wild card
e | not (x `isFreeIn` e) -> f (EAbs VWild e)
@@ -282,31 +337,33 @@ removeUselessMatch = return . map f
v = if not (x `isFreeIn` e') then VWild else VVar x
in EPi v (f t) e'
-- replace unused variables in case patterns with wild cards
- Case p e ->
- let e' = f e
- p' = f (removeUnusedVarPatts (freeVars e') p)
- in Case p' e'
+ Case p (GuardExp g) e ->
+ let g' = f g
+ e' = f e
+ used = freeVars g' `Set.union` freeVars e'
+ p' = f (removeUnusedVarPatts used p)
+ in Case p' (GuardExp g') e'
-- for value declarations without patterns, compilePattDecls
-- generates pattern matching on the empty record, remove these
- ECase (ERec []) [Case (PRec []) e] -> f e
+ ECase (ERec []) [Case (PRec []) g e] | isTrueGuard g -> f e
-- if the pattern matching is on a single field of a record expression
-- with only one field, there is no need to wrap it in a record
ECase (ERec [FieldValue x e]) cs | all (isSingleFieldPattern x) (casePatterns cs)
- -> f (ECase e [ Case p r | Case (PRec [FieldPattern _ p]) r <- cs ])
- -- for all fields in record matching where all patterns just
+ -> f (ECase e [ Case p g r | Case (PRec [FieldPattern _ p]) g r <- cs ])
+ -- for all fields in record matching where all patterns for the field just
-- bind variables, substitute in the field value (if it is a variable)
- -- in the right hand sides.
+ -- in the guards and right hand sides.
ECase (ERec fs) cs | all isPRec (casePatterns cs) ->
- let g (FieldValue f v@(EVar _):fs) xs
+ let h (FieldValue f v@(EVar _):fs) xs
| all (onlyBindsFieldToVariable f) (casePatterns xs)
- = g fs (map (inlineField f v) xs)
- g (f:fs) xs = let (fs',xs') = g fs xs in (f:fs',xs')
- g [] xs = ([],xs)
- inlineField f v (Case (PRec fps) e) =
+ = h fs (map (inlineField f v) xs)
+ h (f:fs) xs = let (fs',xs') = h fs xs in (f:fs',xs')
+ h [] xs = ([],xs)
+ inlineField f v (Case (PRec fps) (GuardExp g) e) =
let p' = PRec [fp | fp@(FieldPattern f' _) <- fps, f' /= f]
ss = zip (fieldPatternVars f fps) (repeat v)
- in Case p' (substs ss e)
- (fs',cs') = g fs cs
+ in Case p' (GuardExp (substs ss g)) (substs ss e)
+ (fs',cs') = h fs cs
x' = ECase (ERec fs') cs'
in if length fs' < length fs then f x' else composOp f x'
-- Remove wild card patterns in record patterns
@@ -314,6 +371,11 @@ removeUselessMatch = return . map f
where wildcards = [fp | fp@(FieldPattern _ PWild) <- fps]
_ -> composOp f x
+isTrueGuard :: Guard -> Bool
+isTrueGuard (GuardExp (EVar (Ident "True"))) = True
+isTrueGuard GuardNo = True
+isTrueGuard _ = False
+
removeUnusedVarPatts :: Set Ident -> Tree a -> Tree a
removeUnusedVarPatts keep x = case x of
PVar id | not (id `Set.member` keep) -> PWild
@@ -325,7 +387,7 @@ isSingleFieldPattern x p = case p of
_ -> False
casePatterns :: [Case] -> [Pattern]
-casePatterns cs = [p | Case p _ <- cs]
+casePatterns cs = [p | Case p _ _ <- cs]
isPRec :: Pattern -> Bool
isPRec (PRec _) = True
@@ -357,7 +419,7 @@ expandOrPatts = return . map f
_ -> composOp f x
expandCase :: Case -> [Case]
-expandCase (Case p e) = [ Case p' e | p' <- expandPatt p ]
+expandCase (Case p g e) = [ Case p' g e | p' <- expandPatt p ]
expandPatt :: Pattern -> [Pattern]
expandPatt p = case p of
@@ -383,14 +445,15 @@ desugar = return . map f
f x = case x of
PListCons p1 p2 -> pListCons <| p1 <| p2
PList xs -> pList (map f [p | PListElem p <- xs])
+ GuardNo -> gtrue
EIf exp0 exp1 exp2 -> ifBool <| exp0 <| exp1 <| exp2
EDo bs e -> mkDo (map f bs) (f e)
BindNoVar exp0 -> BindVar VWild <| exp0
EPiNoVar exp0 exp1 -> EPi VWild <| exp0 <| exp1
EBind exp0 exp1 -> appBind <| exp0 <| exp1
EBindC exp0 exp1 -> appBindC <| exp0 <| exp1
- EOr exp0 exp1 -> andBool <| exp0 <| exp1
- EAnd exp0 exp1 -> orBool <| exp0 <| exp1
+ EOr exp0 exp1 -> orBool <| exp0 <| exp1
+ EAnd exp0 exp1 -> andBool <| exp0 <| exp1
EEq exp0 exp1 -> overlBin "eq" <| exp0 <| exp1
ENe exp0 exp1 -> overlBin "ne" <| exp0 <| exp1
ELt exp0 exp1 -> overlBin "lt" <| exp0 <| exp1
@@ -457,14 +520,14 @@ appCons e1 e2 = apply (EVar (Ident "Cons")) [EMeta,e1,e2]
--
andBool :: Exp -> Exp -> Exp
-andBool e1 e2 = ifBool e1 e2 (var "False")
+andBool e1 e2 = ifBool e1 e2 false
orBool :: Exp -> Exp -> Exp
-orBool e1 e2 = ifBool e1 (var "True") e2
+orBool e1 e2 = ifBool e1 true e2
ifBool :: Exp -> Exp -> Exp -> Exp
-ifBool c t e = ECase c [Case (PCons (Ident "True") []) t,
- Case (PCons (Ident "False") []) e]
+ifBool c t e = ECase c [Case (PCons (Ident "True") []) gtrue t,
+ Case (PCons (Ident "False") []) gtrue e]
--
-- * Substitution
@@ -483,7 +546,7 @@ substs ss = f (Map.fromList ss)
ELet ds e3 ->
ELet [LetDef id (f ss e1) (f ss' e2) | LetDef id e1 e2 <- ds] (f ss' e3)
where ss' = ss `mapMinusSet` letDefBinds ds
- Case p e -> Case p (f ss' e) where ss' = ss `mapMinusSet` binds p
+ Case p g e -> Case p (f ss' g) (f ss' e) where ss' = ss `mapMinusSet` binds p
EAbs (VVar id) e -> EAbs (VVar id) (f ss' e) where ss' = Map.delete id ss
EPi (VVar id) e1 e2 ->
EPi (VVar id) (f ss e1) (f ss' e2) where ss' = Map.delete id ss
@@ -497,6 +560,15 @@ substs ss = f (Map.fromList ss)
var :: String -> Exp
var s = EVar (Ident s)
+true :: Exp
+true = var "True"
+
+false :: Exp
+false = var "False"
+
+gtrue :: Guard
+gtrue = GuardExp true
+
-- | Apply an expression to a list of arguments.
apply :: Exp -> [Exp] -> Exp
apply = foldl EApp
@@ -511,7 +583,8 @@ abstract n f =
-- | Abstract a type over some arguments.
abstractType :: [Exp] -- ^ argument types
- -> ([Exp] -> Exp)
+ -> ([Exp] -> Exp) -- ^ function from variable expressions
+ -- to the expression to return
-> C Exp
abstractType ts f =
do
@@ -551,7 +624,8 @@ freeVars = f
(Set.unions (f exp3:map f (letDefRhss defs)) Set.\\ letDefBinds defs)
:map f (letDefTypes defs)
ECase exp cases -> f exp `Set.union`
- Set.unions [ f e Set.\\ binds p | Case p e <- cases]
+ Set.unions [(f g `Set.union` f e) Set.\\ binds p
+ | Case p g e <- cases]
EAbs (VVar id) exp -> Set.delete id (f exp)
EPi (VVar id) exp1 exp2 -> f exp1 `Set.union` Set.delete id (f exp2)
EVar i -> Set.singleton i
@@ -568,7 +642,7 @@ countFreeOccur x = f
f t = case t of
ELet defs _ | x `Set.member` letDefBinds defs ->
sum (map f (letDefTypes defs))
- Case p e | x `Set.member` binds p -> 0
+ Case p _ _ | x `Set.member` binds p -> 0
EAbs (VVar id) _ | id == x -> 0
EPi (VVar id) exp1 _ | id == x -> f exp1
EVar id | id == x -> 1
@@ -584,11 +658,6 @@ binds = f
PVar id -> Set.singleton id
_ -> composOpMonoid f p
--- | Checks if a declaration is a value declaration
--- of the given identifier.
-isValueDecl :: Ident -> Decl -> Bool
-isValueDecl x (ValueDecl y _ _) = x == y
-isValueDecl _ _ = False
fromPRec :: [FieldPattern] -> [(Ident,Pattern)]
fromPRec fps = [ (l,p) | FieldPattern l p <- fps ]