diff options
Diffstat (limited to 'src/Transfer/SyntaxToCore.hs')
| -rw-r--r-- | src/Transfer/SyntaxToCore.hs | 173 |
1 files changed, 121 insertions, 52 deletions
diff --git a/src/Transfer/SyntaxToCore.hs b/src/Transfer/SyntaxToCore.hs index b13579293..0d5907890 100644 --- a/src/Transfer/SyntaxToCore.hs +++ b/src/Transfer/SyntaxToCore.hs @@ -28,11 +28,11 @@ declsToCore :: [Decl] -> [Decl] declsToCore m = evalState (declsToCore_ m) newState declsToCore_ :: [Decl] -> C [Decl] -declsToCore_ = desugar +declsToCore_ = deriveDecls + >>> desugar + >>> compilePattDecls >>> numberMetas - >>> deriveDecls >>> replaceCons - >>> compilePattDecls >>> expandOrPatts >>> optimize @@ -61,13 +61,14 @@ numberMetas = mapM f return $ EVar $ Ident $ "?" ++ show (nextMeta st) -- FIXME: hack _ -> composOpM f t + -- -- * Pattern equations -- compilePattDecls :: [Decl] -> C [Decl] compilePattDecls [] = return [] -compilePattDecls (d@(ValueDecl x _ _):ds) = +compilePattDecls (d@(ValueDecl x _ _ _):ds) = do let (xs,rest) = span (isValueDecl x) ds d <- mergeDecls (d:xs) @@ -75,20 +76,26 @@ compilePattDecls (d@(ValueDecl x _ _):ds) = return (d:rs) compilePattDecls (d:ds) = liftM (d:) (compilePattDecls ds) --- | Take a non-empty list of pattern equations for the same --- function, and produce a single declaration. +-- | Checks if a declaration is a value declaration +-- of the given identifier. +isValueDecl :: Ident -> Decl -> Bool +isValueDecl x (ValueDecl y _ _ _) = x == y +isValueDecl _ _ = False + +-- | Take a non-empty list of pattern equations with guards +-- for the same function, and produce a single declaration. mergeDecls :: [Decl] -> C Decl -mergeDecls ds@(ValueDecl x p _:_) - = do let cs = [ (ps,rhs) | ValueDecl _ ps rhs <- ds ] - (pss,rhss) = unzip cs +mergeDecls ds@(ValueDecl x p _ _:_) + = do let cs = [ (ps,g,rhs) | ValueDecl _ ps g rhs <- ds ] + (pss,_,_) = unzip3 cs n = length p when (not (all ((== n) . length) pss)) $ fail $ "Pattern count mismatch for " ++ printTree x vs <- freshIdents n - let cases = map (\ (ps,rhs) -> Case (mkPRec ps) rhs) cs + let cases = map (\ (ps,g,rhs) -> Case (mkPRec ps) g rhs) cs c = ECase (mkERec (map EVar vs)) cases f = foldr (EAbs . VVar) c vs - return $ ValueDecl x [] f + return $ ValueDecl x [] GuardNo f where mkRec r f = r . zipWith (\i e -> f (Ident ("p"++show i)) e) [0..] mkPRec = mkRec PRec FieldPattern mkERec = mkRec ERec FieldValue @@ -118,6 +125,10 @@ derivators = [ ("Ord", deriveOrd) ] +-- +-- * Deriving instances of Compos +-- + deriveCompos :: Derivator deriveCompos t@(Ident ts) k cs = do @@ -128,7 +139,7 @@ deriveCompos t@(Ident ts) k cs = dt = apply (EVar (Ident "Compos")) [c, EVar t] r = ERec [FieldValue (Ident "composOp") co, FieldValue (Ident "composFold") cf] - return [TypeDecl d dt, ValueDecl d [] r] + return [TypeDecl d dt, ValueDecl d [] GuardNo r] deriveComposOp :: Ident -> Exp -> [(Ident,Exp)] -> C Exp deriveComposOp t k cs = @@ -149,9 +160,9 @@ deriveComposOp t k cs = EApp (EVar t') c | t' == t -> apply (e f) [c, e v] _ -> e v calls = zipWith rec vars (argumentTypes ct) - return $ Case (PCons ci (map PVar vars)) (apply (e ci) calls) + return $ Case (PCons ci (map PVar vars)) gtrue (apply (e ci) calls) cases <- mapM (uncurry mkCase) cs - let cases' = cases ++ [Case PWild (e x)] + let cases' = cases ++ [Case PWild gtrue (e x)] fb <- abstract (arity k) $ const $ pv f \-> pv x \-> ECase (e x) cases' return fb @@ -180,17 +191,61 @@ deriveComposFold t k cs = p = EProj (e r) (Ident "mplus") joinCalls [] = z joinCalls cs = foldr1 (\x y -> apply p [x,y]) cs - return $ Case (PCons ci (map PVar vars)) (joinCalls calls) + return $ Case (PCons ci (map PVar vars)) gtrue (joinCalls calls) cases <- mapM (uncurry mkCase) cs - let cases' = cases ++ [Case PWild (e x)] + let cases' = cases ++ [Case PWild gtrue (e x)] fb <- abstract (arity k) $ const $ pv f \-> pv x \-> ECase (e x) cases' return $ VWild \-> pv r \-> fb +-- +-- * Deriving instances of Show +-- + deriveShow :: Derivator deriveShow t k cs = fail $ "derive Show not implemented" +-- +-- * Deriving instances of Eq +-- + +-- FIXME: how do we require Eq instances for all +-- constructor arguments? + deriveEq :: Derivator -deriveEq t k cs = fail $ "derive Eq not implemented" +deriveEq t@(Ident tn) k cs = + do + let ats = argumentTypes k + d = Ident ("eq_"++tn) + dt <- abstractType ats (EApp (EVar (Ident "Eq")) . apply (EVar t)) + eq <- mkEq + r <- abstract (arity k) (\_ -> ERec [FieldValue (Ident "eq") eq]) + return [TypeDecl d dt, ValueDecl d [] GuardNo r] + where + mkEq = do + x <- freshIdent + cases <- mapM (uncurry mkEqCase) cs + return $ EAbs (VVar x) (ECase (EVar x) cases) + mkEqCase c ct = + do + let n = arity ct + vs1 <- freshIdents n + vs2 <- freshIdents n + y <- freshIdent + let p1 = PCons c (map PVar vs1) + p2 = PCons c (map PVar vs2) + es1 = map EVar vs1 + es2 = map EVar vs2 + tc | n == 0 = true + -- FIXME: using EEq doesn't work right now + | otherwise = foldr1 EAnd (zipWith EEq es1 es2) + c1 = Case p2 gtrue tc + c2 = Case PWild gtrue false + return $ Case p1 gtrue (EAbs (VVar y) (ECase (EVar y) [c1,c2])) + + +-- +-- * Deriving instances of Ord +-- deriveOrd :: Derivator deriveOrd t k cs = fail $ "derive Ord not implemented" @@ -268,10 +323,10 @@ removeUselessMatch = return . map f f x = case x of EAbs (VVar x) b -> case f b of - -- replace \x -> case x of { y -> e } with \y -> e, + -- replace \x -> case x of { y | True -> e } with \y -> e, -- if x is not free in e - ECase (EVar x') [Case (PVar y) e] - | x' == x && not (x `isFreeIn` e) + ECase (EVar x') [Case (PVar y) g e] + | x' == x && isTrueGuard g && not (x `isFreeIn` e) -> f (EAbs (VVar y) e) -- replace unused variable in lambda with wild card e | not (x `isFreeIn` e) -> f (EAbs VWild e) @@ -282,31 +337,33 @@ removeUselessMatch = return . map f v = if not (x `isFreeIn` e') then VWild else VVar x in EPi v (f t) e' -- replace unused variables in case patterns with wild cards - Case p e -> - let e' = f e - p' = f (removeUnusedVarPatts (freeVars e') p) - in Case p' e' + Case p (GuardExp g) e -> + let g' = f g + e' = f e + used = freeVars g' `Set.union` freeVars e' + p' = f (removeUnusedVarPatts used p) + in Case p' (GuardExp g') e' -- for value declarations without patterns, compilePattDecls -- generates pattern matching on the empty record, remove these - ECase (ERec []) [Case (PRec []) e] -> f e + ECase (ERec []) [Case (PRec []) g e] | isTrueGuard g -> f e -- if the pattern matching is on a single field of a record expression -- with only one field, there is no need to wrap it in a record ECase (ERec [FieldValue x e]) cs | all (isSingleFieldPattern x) (casePatterns cs) - -> f (ECase e [ Case p r | Case (PRec [FieldPattern _ p]) r <- cs ]) - -- for all fields in record matching where all patterns just + -> f (ECase e [ Case p g r | Case (PRec [FieldPattern _ p]) g r <- cs ]) + -- for all fields in record matching where all patterns for the field just -- bind variables, substitute in the field value (if it is a variable) - -- in the right hand sides. + -- in the guards and right hand sides. ECase (ERec fs) cs | all isPRec (casePatterns cs) -> - let g (FieldValue f v@(EVar _):fs) xs + let h (FieldValue f v@(EVar _):fs) xs | all (onlyBindsFieldToVariable f) (casePatterns xs) - = g fs (map (inlineField f v) xs) - g (f:fs) xs = let (fs',xs') = g fs xs in (f:fs',xs') - g [] xs = ([],xs) - inlineField f v (Case (PRec fps) e) = + = h fs (map (inlineField f v) xs) + h (f:fs) xs = let (fs',xs') = h fs xs in (f:fs',xs') + h [] xs = ([],xs) + inlineField f v (Case (PRec fps) (GuardExp g) e) = let p' = PRec [fp | fp@(FieldPattern f' _) <- fps, f' /= f] ss = zip (fieldPatternVars f fps) (repeat v) - in Case p' (substs ss e) - (fs',cs') = g fs cs + in Case p' (GuardExp (substs ss g)) (substs ss e) + (fs',cs') = h fs cs x' = ECase (ERec fs') cs' in if length fs' < length fs then f x' else composOp f x' -- Remove wild card patterns in record patterns @@ -314,6 +371,11 @@ removeUselessMatch = return . map f where wildcards = [fp | fp@(FieldPattern _ PWild) <- fps] _ -> composOp f x +isTrueGuard :: Guard -> Bool +isTrueGuard (GuardExp (EVar (Ident "True"))) = True +isTrueGuard GuardNo = True +isTrueGuard _ = False + removeUnusedVarPatts :: Set Ident -> Tree a -> Tree a removeUnusedVarPatts keep x = case x of PVar id | not (id `Set.member` keep) -> PWild @@ -325,7 +387,7 @@ isSingleFieldPattern x p = case p of _ -> False casePatterns :: [Case] -> [Pattern] -casePatterns cs = [p | Case p _ <- cs] +casePatterns cs = [p | Case p _ _ <- cs] isPRec :: Pattern -> Bool isPRec (PRec _) = True @@ -357,7 +419,7 @@ expandOrPatts = return . map f _ -> composOp f x expandCase :: Case -> [Case] -expandCase (Case p e) = [ Case p' e | p' <- expandPatt p ] +expandCase (Case p g e) = [ Case p' g e | p' <- expandPatt p ] expandPatt :: Pattern -> [Pattern] expandPatt p = case p of @@ -383,14 +445,15 @@ desugar = return . map f f x = case x of PListCons p1 p2 -> pListCons <| p1 <| p2 PList xs -> pList (map f [p | PListElem p <- xs]) + GuardNo -> gtrue EIf exp0 exp1 exp2 -> ifBool <| exp0 <| exp1 <| exp2 EDo bs e -> mkDo (map f bs) (f e) BindNoVar exp0 -> BindVar VWild <| exp0 EPiNoVar exp0 exp1 -> EPi VWild <| exp0 <| exp1 EBind exp0 exp1 -> appBind <| exp0 <| exp1 EBindC exp0 exp1 -> appBindC <| exp0 <| exp1 - EOr exp0 exp1 -> andBool <| exp0 <| exp1 - EAnd exp0 exp1 -> orBool <| exp0 <| exp1 + EOr exp0 exp1 -> orBool <| exp0 <| exp1 + EAnd exp0 exp1 -> andBool <| exp0 <| exp1 EEq exp0 exp1 -> overlBin "eq" <| exp0 <| exp1 ENe exp0 exp1 -> overlBin "ne" <| exp0 <| exp1 ELt exp0 exp1 -> overlBin "lt" <| exp0 <| exp1 @@ -457,14 +520,14 @@ appCons e1 e2 = apply (EVar (Ident "Cons")) [EMeta,e1,e2] -- andBool :: Exp -> Exp -> Exp -andBool e1 e2 = ifBool e1 e2 (var "False") +andBool e1 e2 = ifBool e1 e2 false orBool :: Exp -> Exp -> Exp -orBool e1 e2 = ifBool e1 (var "True") e2 +orBool e1 e2 = ifBool e1 true e2 ifBool :: Exp -> Exp -> Exp -> Exp -ifBool c t e = ECase c [Case (PCons (Ident "True") []) t, - Case (PCons (Ident "False") []) e] +ifBool c t e = ECase c [Case (PCons (Ident "True") []) gtrue t, + Case (PCons (Ident "False") []) gtrue e] -- -- * Substitution @@ -483,7 +546,7 @@ substs ss = f (Map.fromList ss) ELet ds e3 -> ELet [LetDef id (f ss e1) (f ss' e2) | LetDef id e1 e2 <- ds] (f ss' e3) where ss' = ss `mapMinusSet` letDefBinds ds - Case p e -> Case p (f ss' e) where ss' = ss `mapMinusSet` binds p + Case p g e -> Case p (f ss' g) (f ss' e) where ss' = ss `mapMinusSet` binds p EAbs (VVar id) e -> EAbs (VVar id) (f ss' e) where ss' = Map.delete id ss EPi (VVar id) e1 e2 -> EPi (VVar id) (f ss e1) (f ss' e2) where ss' = Map.delete id ss @@ -497,6 +560,15 @@ substs ss = f (Map.fromList ss) var :: String -> Exp var s = EVar (Ident s) +true :: Exp +true = var "True" + +false :: Exp +false = var "False" + +gtrue :: Guard +gtrue = GuardExp true + -- | Apply an expression to a list of arguments. apply :: Exp -> [Exp] -> Exp apply = foldl EApp @@ -511,7 +583,8 @@ abstract n f = -- | Abstract a type over some arguments. abstractType :: [Exp] -- ^ argument types - -> ([Exp] -> Exp) + -> ([Exp] -> Exp) -- ^ function from variable expressions + -- to the expression to return -> C Exp abstractType ts f = do @@ -551,7 +624,8 @@ freeVars = f (Set.unions (f exp3:map f (letDefRhss defs)) Set.\\ letDefBinds defs) :map f (letDefTypes defs) ECase exp cases -> f exp `Set.union` - Set.unions [ f e Set.\\ binds p | Case p e <- cases] + Set.unions [(f g `Set.union` f e) Set.\\ binds p + | Case p g e <- cases] EAbs (VVar id) exp -> Set.delete id (f exp) EPi (VVar id) exp1 exp2 -> f exp1 `Set.union` Set.delete id (f exp2) EVar i -> Set.singleton i @@ -568,7 +642,7 @@ countFreeOccur x = f f t = case t of ELet defs _ | x `Set.member` letDefBinds defs -> sum (map f (letDefTypes defs)) - Case p e | x `Set.member` binds p -> 0 + Case p _ _ | x `Set.member` binds p -> 0 EAbs (VVar id) _ | id == x -> 0 EPi (VVar id) exp1 _ | id == x -> f exp1 EVar id | id == x -> 1 @@ -584,11 +658,6 @@ binds = f PVar id -> Set.singleton id _ -> composOpMonoid f p --- | Checks if a declaration is a value declaration --- of the given identifier. -isValueDecl :: Ident -> Decl -> Bool -isValueDecl x (ValueDecl y _ _) = x == y -isValueDecl _ _ = False fromPRec :: [FieldPattern] -> [(Ident,Pattern)] fromPRec fps = [ (l,p) | FieldPattern l p <- fps ] |
