summaryrefslogtreecommitdiff
path: root/src/runtime/haskell/PGF/Probabilistic.hs
blob: 37db7f7ff8b60f4ac482843c73025479a390e0da (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
module PGF.Probabilistic
         ( Probabilities(..)
         , mkProbabilities                 -- :: PGF -> M.Map CId Double -> Probabilities
         , defaultProbabilities            -- :: PGF -> Probabilities
         , getProbabilities
         , setProbabilities
         , showProbabilities               -- :: Probabilities -> String
         , readProbabilitiesFromFile       -- :: FilePath -> PGF -> IO Probabilities

         , probTree
         , rankTreesByProbs
         , mkProbDefs
         ) where

import PGF.CId
import PGF.Data
import PGF.Macros

import qualified Data.Map as Map
import Data.List (sortBy,partition,nub,mapAccumL)
import Data.Maybe (fromMaybe) --, fromJust

-- | Probability tables for a grammar: one table for the functions of the
-- abstract syntax and one for its categories.
data Probabilities = Probs
  { funProbs :: Map.Map CId Double
    -- ^ Probability of each function.
  , catProbs :: Map.Map CId (Double, [(Double, CId)])
    -- ^ For each category: its own probability together with the
    -- (probability, function) pairs of the functions listed for it.
  }

-- | Renders the category table as text. For every category (in ascending
-- key order) it emits one line for the category itself, followed by one line
-- for each of its functions. Each line holds the name, a tab character and
-- the probability.
showProbabilities :: Probabilities -> String
showProbabilities probs =
  unlines [ line
          | (cat, (catP, fns)) <- Map.toList (catProbs probs)
          , line <- entry cat catP : [ entry fn fnP | (fnP, fn) <- fns ]
          ]
  where
    entry name p = showCId name ++ "\t" ++ show p

-- | Reads probabilities from a text file.
--
-- Every line with at least two whitespace-separated words is read as a
-- name followed by a real number: the probability mass allocated to that
-- name. Further words on the line are ignored, as are lines with fewer
-- than two words. Names may be functions or categories, since
-- 'mkProbabilities' looks up both in the resulting map. Names absent from
-- the file are handled as described for 'mkProbabilities'.
--
-- Throws an 'IOError' (a user error naming the file and line number) when
-- the number on some line cannot be parsed as a 'Double'. Any exception
-- raised by 'readFile' is propagated.
readProbabilitiesFromFile :: FilePath -> PGF -> IO Probabilities
readProbabilitiesFromFile file pgf = do
  s <- readFile file
  -- Parse every entry eagerly. A malformed number is reported here, with
  -- its location, instead of as a lazy "no parse" error later (or never,
  -- for names the grammar does not contain). Forcing every line also
  -- consumes the whole lazily read file.
  ps0 <- sequence [ parseEntry lineNo f p
                  | (lineNo, f:p:_) <- zip [1 :: Int ..] (map words (lines s)) ]
  return $ mkProbabilities pgf (Map.fromList ps0)
  where
    parseEntry lineNo f p =
      case [d | (d, "") <- reads p] of
        (d:_) -> return (mkCId f, d)
        []    -> ioError (userError (file ++ ":" ++ show lineNo
                                     ++ ": cannot parse probability " ++ show p
                                     ++ " for " ++ f))

-- | Builds probability tables. The second argument maps function and
-- category names to their known probabilities.
--
-- A category that is not in the map gets probability 0. Within each
-- category, a function that is in the map keeps its probability. The
-- functions that are not in the map share evenly whatever is left of a
-- total mass of 1 after the known ones (never less than 0). Each category's
-- functions are sorted from most to least probable. The function table is
-- built from these per-category lists.
mkProbabilities :: PGF -> Map.Map CId Double -> Probabilities
mkProbabilities pgf probs = Probs { funProbs = byFun, catProbs = byCat }
  where
    -- One entry per category of the grammar.
    byCat = Map.mapWithKey tabulate (cats (abstract pgf))

    -- Flatten the per-category function lists into one function table.
    byFun = Map.fromList [ (fn, p) | (_, fns) <- Map.elems byCat, (p, fn) <- fns ]

    tabulate cat (_, fns, _) =
      ( Map.findWithDefault 0 cat probs
      , sortBy mostProbableFirst (distribute [ (Map.lookup fn probs, fn) | (_, fn) <- fns ])
      )

    mostProbableFirst (p, _) (q, _) = compare q p

    -- Known probabilities are kept; unknown ones get an equal share of the
    -- mass the known ones leave unallocated (clamped at 0).
    distribute entries = [ (fromMaybe share known, fn) | (known, fn) <- entries ]
      where
        unknown = length [ () | (Nothing, _) <- entries ]
        share
          | unknown == 0 = 0
          | otherwise    = max 0 ((1 - sum [ d | (Just d, _) <- entries ]) / fromIntegral unknown)

-- | Returns the default even distribution. No probability is known in
-- advance, so 'mkProbabilities' gives each of the n functions of a category
-- probability 1/n and every category probability 0.
defaultProbabilities :: PGF -> Probabilities
defaultProbabilities pgf = mkProbabilities pgf nothingKnown
  where
    nothingKnown = Map.empty

-- | Extracts the probabilities currently stored in the abstract syntax of
-- the grammar. It collects each function's probability, and each
-- category's probability together with its (probability, function) list.
getProbabilities :: PGF -> Probabilities
getProbabilities pgf = Probs (Map.map funProb (funs abstr)) (Map.map catProb (cats abstr))
  where
    abstr = abstract pgf
    funProb (_, _, _, p) = p
    catProb (_, fns, p)  = (p, fns)

-- | Stores probability tables in the abstract syntax of a grammar.
--
-- Only functions and categories that already exist in the grammar are
-- touched:
--
-- * each one found in the tables gets the table's probability (and, for a
--   category, the table's function list as well);
-- * grammar entries missing from the tables keep their current values;
-- * table keys unknown to the grammar are ignored.
setProbabilities :: Probabilities -> PGF -> PGF
setProbabilities probs pgf = pgf { abstract = abstr { funs = funs', cats = cats' } }
  where
    abstr = abstract pgf
    funs' = overlay (\(ty,a,df,_) p -> (ty,a,df,p)) (funs abstr) (funProbs probs)
    cats' = overlay (\(hypos,_,_) (p,fns) -> (hypos,fns,p)) (cats abstr) (catProbs probs)

    -- Rewrite every entry of base whose key also occurs in updates. The
    -- result has exactly the keys of base.
    overlay combine base updates =
      Map.mapWithKey (\key old -> maybe old (combine old) (Map.lookup key updates)) base

-- | Computes the probability of a given tree: the product of the
-- probabilities recorded in the abstract syntax for every function
-- occurring in it. A type annotation is looked through. The following
-- contribute a factor of 1:
--
-- * functions missing from the abstract syntax;
-- * any node other than an application, a function or a type annotation
--   (for example variables, literals and implicit arguments).
probTree :: PGF -> Expr -> Double
probTree pgf t = case t of
  EApp f e   -> probTree pgf f * probTree pgf e
  EFun f     -> maybe 1 (\(_,_,_,p) -> p) (Map.lookup f (funs (abstract pgf)))
  -- An annotation does not change the tree, so score the annotated
  -- expression instead of treating the whole subtree as probability 1.
  ETyped e _ -> probTree pgf e
  _          -> 1

-- | Pairs every tree with its probability (see 'probTree') and sorts the
-- pairs from the most to the least probable. The sort is stable, so trees
-- with equal probability keep their input order.
rankTreesByProbs :: PGF -> [Expr] -> [(Expr,Double)]
rankTreesByProbs pgf ts = sortBy mostProbableFirst (zip ts (map (probTree pgf) ts))
  where
    mostProbableFirst (_, p) (_, q) = q `compare` p


-- | Builds definitions of probability functions for the abstract syntax.
--
-- Every category @c@ of the abstract syntax, except String, Int and Float,
-- contributes to the second component:
--
-- * a definition named 'p_f' @c@, whose type is built by @mkType@ and which
--   has one equation per function of @c@, built by @mkEquation@;
-- * the auxiliary definitions that 'computeConstrs' collected for the
--   functions of @c@ whose equation list is non-empty. These are typed
--   simply as Float.
--
-- The first component gathers the groups of constant names ('k_f', 'k_i')
-- that 'computeConstrs' returned for all categories.
--
-- NOTE(review): this function is partial.
--
-- * It pattern-matches @Just@ on the lookup of every category function in
--   'funs'.
-- * It uses @env !! i@ and an irrefutable @splitAt@ pattern in @update@.
-- * @update@ only handles EApp, EFun and EVar in argument and result type
--   indices.
mkProbDefs :: PGF -> ([[CId]],[(CId,Type,[Equation])])
mkProbDefs pgf =
  -- cs: one entry per category (built-in String/Int/Float excluded), with
  -- its hypotheses renamed to v1, v2, ... and its functions paired with
  -- their types looked up in funs.
  let cs = [(c,hyps,fns) | (c,(hyps0,fs,_)) <- Map.toList (cats (abstract pgf)),
                           not (elem c [cidString,cidInt,cidFloat]),
                           let hyps = zipWith (\(bt,_,ty) n -> (bt,mkCId ('v':show n),ty))
                                              hyps0
                                              [1..]
                               fns  = [(f,ty) | (_,f) <- fs, 
                                               let Just (ty,_,_,_) = Map.lookup f (funs (abstract pgf))]
           ]
      -- For each category:
      --   * run computeConstrs on the result-type indices of its functions;
      --   * build one equation per function, threading the counter ngen
      --     used for fresh p_<n> names;
      --   * emit the category definition, followed by the non-empty
      --     auxiliary definitions found in eqs_map.
      ((_,css),eqss) = mapAccumL (\(ngen,css) (c,hyps,fns) -> 
              let st0      = (1,Map.empty)
                  ((_,eqs_map),cs) = computeConstrs pgf st0 [(fn,[],es) | (fn,(DTyp _ _ es)) <- fns]
                  (ngen', eqs) = mapAccumL (mkEquation eqs_map hyps) ngen fns
                  ceqs     = [(id,DTyp [] cidFloat [],reverse eqs) | (id,eqs) <- Map.toList eqs_map, not (null eqs)]
              in ((ngen',cs:css),(p_f c, mkType c hyps, eqs):ceqs)) (1,[]) cs
  in (reverse (concat css),concat eqss)
  where
    -- Wrap an expression (resp. pattern) in EImplArg (resp. PImplArg)
    -- unless the binding type is Explicit.
    mkEImplArg bt e
      | bt == Explicit = e
      | otherwise      = EImplArg e
      
    mkPImplArg bt p
      | bt == Explicit = p
      | otherwise      = PImplArg p

    -- Type of p_<c>: the (renamed) hypotheses of c, then one more hypothesis
    -- (built by mkHypo) of type c applied to variables for those hypotheses.
    -- The variables are de Bruijn indices; the last hypothesis is 0. The
    -- result type is Float.
    mkType c hyps =
      DTyp (hyps++[mkHypo (DTyp [] c es)]) cidFloat []
      where
        is = reverse [0..length hyps-1]
        es = [mkEImplArg bt (EVar i) | (i,(bt,_,_)) <- zip is hyps]

    -- Signature handed to normalForm: the abstract functions, paired with a
    -- lookup that always answers Nothing.
    sig = (funs (abstract pgf), \_ -> Nothing)
    
    -- Builds the equation of p_<c> for one function fn of type ty.
    -- Left-hand side: PTilde patterns for the normalised result-type indices,
    -- then fn applied to pattern variables v1, v2, ... (implicit where the
    -- argument binding is implicit).
    -- Right-hand side: the product of the factors below. It is a "mult"
    -- application, or the single factor itself, or 1 when there is none.
    --   fs1: k_<fn>, or p_<fn> when ceqs has non-empty equations for it
    --        (no factor when they are empty), applied to the vs1 arguments;
    --   fs2: a fresh p_<n> per dependency group in vs2 (n taken from ngen);
    --   fs3: p_<cat> for every argument in vs3, applied to its normalised
    --        type indices and to the argument itself.
    mkEquation ceqs hyps ngen (fn,ty@(DTyp args _ es)) =
      let fs1         = case Map.lookup (p_f fn) ceqs of
                          Nothing              -> [mkApp (k_f fn) (map (\(i,_) -> EVar (k-i-1)) vs1)]
                          Just eqs | null eqs  -> []
                                   | otherwise -> [mkApp (p_f fn) (map (\(i,_) -> EVar (k-i-1)) vs1)]
          (ngen',fs2) = mapAccumL mkFactor2 ngen vs2
          fs3         = map mkFactor3 vs3
          eq = Equ (map mkTildeP xes++[PApp fn (zipWith mkArgP [1..] args)])
                   (mkMult (fs1++fs2++fs3))
      in (ngen',eq)
      where
        xes = map (normalForm sig k env) es

        -- Implicit-argument wrappers stay outside the PTilde pattern.
        mkTildeP e =
          case e of
            EImplArg e -> PImplArg (PTilde e)
            e          ->           PTilde e

        -- Pattern variable v<n> for the n-th argument (1-based).
        mkArgP n (bt,_,_) = mkPImplArg bt (PVar (mkCId ('v':show n)))

        mkMult []  = ELit (LFlt 1)
        mkMult [e] = e
        mkMult es  = mkApp (mkCId "mult") es

        -- Factor p_<ngen> applied to the arguments of one dependency group.
        -- It consumes one counter value. Argument position i among k
        -- arguments is the variable EVar (k-i-1).
        mkFactor2 ngen (src,dst) =
          let vs = [EVar (k-i-1) | (i,ty) <- src]
          in (ngen+1,mkApp (p_i ngen) vs)

        mkFactor3 (i,DTyp _ c es) =
          let v = EVar (k-i-1)
          in mkApp (p_f c) (map (normalForm sig k env) es++[v])

        (k,env,vs1,vs2,vs3) = mkDeps ty

        -- Analyses the arguments of ty and returns:
        --   k   - the number of arguments;
        --   env - one VGen per named argument;
        --   vs1 - arguments mentioned by the result-type indices;
        --   vs2 - dependency groups kept by closure, in discovery order;
        --   vs3 - arguments, or merged groups of arguments, with no
        --         remaining dependants.
        -- The dependency list has one entry per argument position, in order:
        -- ([(position, type)], positions referring to it). Position k (one
        -- past the last argument) stands for the result-type indices.
        mkDeps (DTyp args _ es) =
          let (k,env,dep1) = updateArgs 0 [] [] args
              dep2         = foldl (update k env) dep1 es
              (vs2,vs3)    = closure k dep2 [] []
              vs1          = concat [src | (src,dst) <- dep2, elem k dst]
          in (k,map (\k -> VGen k []) env,vs1,reverse vs2,vs3)
          where
            -- Walks the arguments left to right. First it records which
            -- earlier arguments the current one's type indices mention, then
            -- it appends an entry for the current argument. env holds the
            -- positions of named (non-wildcard) arguments, most recent first.
            updateArgs k env dep []                              = (k,env,dep)
            updateArgs k env dep ((_,x,ty@(DTyp _ _ es)) : args) =
              let dep1 = foldl (update k env) dep es ++ [([(k,ty)],[])]
                  env1 | x == wildCId =     env
                       | otherwise    = k : env
              in updateArgs (k+1) env1 dep1 args

            -- Records that position k refers to every variable in e. EVar i
            -- is resolved to an argument position through env.
            update k env dep e =
              case e of
                EApp e1 e2 -> update k env (update k env dep e1) e2
                EFun _     -> dep
                EVar i     -> let (dep1,(src,dst):dep2) = splitAt (env !! i) dep
                              in dep1++(src,k:dst):dep2

            -- Groups arguments by shared dependants:
            --   * an entry without dependants goes to vs3;
            --   * an entry sharing a dependant with later entries is merged
            --     with them and reconsidered (dependants inside the merged
            --     group are dropped);
            --   * otherwise the entry goes to vs2, unless its only
            --     dependant is position k.
            closure k []               vs2 vs3 = (vs2,vs3)
            closure k ((src,dst):deps) vs2 vs3
              | null dst   = closure k deps vs2 (vs3++src)
              | otherwise  =
                  let (deps1,deps2) = partition (\(src',dst') -> not (null [v1 | v1 <- dst, v2 <- dst', v1 == v2])) deps
                      deps3 = (src,dst):deps1
                      src2  = concatMap fst deps3
                      dst2  = [v | v <- concatMap snd deps3
                                 , lookup v src2 == Nothing]
                      dep2  = (src2,dst2)
                      dst'  = nub dst
                  in if null deps1
                       then if dst' == [k]
                              then closure k deps2 vs2 vs3
                              else closure k deps2 ((src,dst') : vs2) vs3
                       else closure k (dep2 : deps2) vs2 vs3
{-
        mkNewSig src =
          DTyp (mkArgs 0 0 [] src) cidFloat []
          where
            mkArgs k l env []                      = []
            mkArgs k l env ((i,DTyp _ c es) : src)
               | i == k    = let ty = DTyp [] c (map (normalForm sig k env) es)
                             in (Explicit,wildCId,ty) : mkArgs (k+1) (l+1) (VGen l [] : env) src
               | otherwise = mkArgs (k+1) l (VMeta 0 env [] : env) src
-}
-- | State threaded through 'computeConstrs'. It pairs a counter used to
-- name fresh constants with 'k_i' with the equations collected so far,
-- keyed by the name ('p_f') of the probability function they belong to.
type CState = (Int,Map.Map CId [Equation])

-- | Recursively partitions functions by the head constructors of their
-- remaining expressions, recording equations for generated probability
-- functions along the way.
--
-- Each input triple is (function, patterns collected so far, expressions
-- still to be matched); 'mkProbDefs' starts with no patterns and the
-- result-type indices of each function. The 'CState' counter names fresh
-- constants ('k_i'), and its map accumulates equations keyed by 'p_f' of
-- the function. Returns the updated state and groups of constant names
-- ('k_f' or 'k_i').
--
-- NOTE(review): this function is partial.
--
-- * When the first triple has no expressions left but others remain, every
--   triple must have no expressions left (@mk_k@).
-- * Otherwise every triple must have at least one expression (@split@).
-- * @extract@ only handles EFun, EVar, EApp, ETyped and EImplArg.
-- * @addArgs@ assumes every constructor is present in 'funs'.
computeConstrs :: PGF -> CState -> [(CId,[Patt],[Expr])] -> (CState,[[CId]])
computeConstrs pgf (ngen,eqs_map) fns@((id,pts,[]):rest)
  -- A single function is left: add the equation "pts => 1.0" under p_<id>.
  -- When there are no patterns, no equation is added, but the key is still
  -- created if absent.
  | null rest =
     let eqs_map' = 
           Map.insertWith (++) (p_f id)
                               (if null pts
                                  then []
                                  else [Equ pts (ELit (LFlt 1.0))])
                               eqs_map
     in ((ngen,eqs_map'),[])
  -- Several fully matched functions are left. Each is named by a constant:
  -- k_<id> when it has no patterns; otherwise a fresh k_<n>, which is also
  -- recorded as the right-hand side of the equation "pts => k_<n>" under
  -- p_<id>. All the names form one group.
  | otherwise =
     let (st,ks) = mapAccumL mk_k (ngen,eqs_map) fns

         mk_k (ngen,eqs_map) (id,pts,[])
           | null pts  = ((ngen,eqs_map),k_f id)
           | otherwise = let eqs_map' = 
                               Map.insertWith (++) 
                                              (p_f id) 
                                              [Equ pts (EFun (k_i ngen))]
                                              eqs_map
                         in ((ngen+1,eqs_map'),k_i ngen)

     in (st,[ks])
-- Otherwise, match the first remaining expression of every triple and
-- recurse into each resulting (pattern, triples) group, threading the state.
computeConstrs pgf st fns =
  let (st',res) = mapAccumL (\st (p,fns) -> computeConstrs pgf st fns)
                            st
                            (computeConstr fns)
  in (st',concat res)
  where
    -- Groups the triples by the head of their first expression.
    computeConstr fns = merge (split fns (Map.empty,[]))

    -- Builds one group per constructor pattern; the patterns of the
    -- constructor's arguments are computed recursively by addArg.
    -- Variable-headed triples are added to every group, with the group's
    -- pattern appended to their patterns. They also form a group of their
    -- own under PWild.
    merge (cns,vrs) =
      [(p,fns++[(id,ps++[p],es) | (id,ps,es) <- vrs])
                                | (p,fns) <- concatMap addArgs (Map.toList cns)]
      ++
      if null vrs 
        then []
        else [(PWild,[(id,ps++[PWild],es) | (id,ps,es) <- vrs])]
      where
        -- Expands a constructor group into one (PApp cn args, triples) pair
        -- per combination of argument patterns, one argument at a time.
        addArgs (cn,fns) = addArg (length args) cn [] fns
          where
            Just (DTyp args _ _es,_,_,_) = Map.lookup cn (funs (abstract pgf))

        addArg 0 cn ps fns = [(PApp cn (reverse ps),fns)]
        addArg n cn ps fns = concat [addArg (n-1) cn (arg:ps) fns' | (arg,fns') <- computeConstr fns]

    -- Classifies each triple by its first expression. Type annotations and
    -- implicit-argument wrappers are looked through.
    --   * Function-headed: grouped under the function name, with the
    --     application's arguments prepended to the remaining expressions.
    --   * Variable-headed: kept in vrs, with the variable (and anything
    --     applied to it) dropped.
    split []                   (cns,vrs) = (cns,vrs)
    split ((id, ps, e:es):fns) (cns,vrs) = split fns (extract e [])
      where
        extract (EFun cn)     args = (Map.insertWith (++) cn [(id,ps,args++es)] cns, vrs)
        extract (EVar i)      args = (cns, (id,ps,es):vrs)
        extract (EApp e1 e2)  args = extract e1 (e2:args)
        extract (ETyped e ty) args = extract e args
        extract (EImplArg e)  args = extract e args

-- | Name of the generated probability definition for a category or
-- function: @p_@ followed by its identifier.
p_f :: CId -> CId
p_f c = mkCId ("p_"++showCId c)

-- | Name of a numbered generated probability definition: @p_@ followed by
-- the shown number.
p_i :: Show a => a -> CId
p_i i = mkCId ("p_"++show i)

-- | Name of the generated constant for a function: @k_@ followed by its
-- identifier.
k_f :: CId -> CId
k_f f = mkCId ("k_"++showCId f)

-- | Name of a numbered generated constant: @k_@ followed by the shown
-- number.
k_i :: Show a => a -> CId
k_i i = mkCId ("k_"++show i)