summaryrefslogtreecommitdiff
path: root/src/runtime
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime')
-rw-r--r--src/runtime/c/pgf/data.c7
-rw-r--r--src/runtime/c/pgf/data.h5
-rw-r--r--src/runtime/c/pgf/parser.c263
-rw-r--r--src/runtime/c/pgf/pgf.h3
-rw-r--r--src/runtime/c/pgf/reader.c51
-rw-r--r--src/runtime/c/utils/pgf-translate.c6
6 files changed, 256 insertions, 79 deletions
diff --git a/src/runtime/c/pgf/data.c b/src/runtime/c/pgf/data.c
index 36729b23f..74dba9cb8 100644
--- a/src/runtime/c/pgf/data.c
+++ b/src/runtime/c/pgf/data.c
@@ -3,6 +3,7 @@
#include <gu/type.h>
#include <gu/variant.h>
#include <gu/assert.h>
+#include <math.h>
bool
pgf_tokens_equal(PgfTokens t1, PgfTokens t2)
@@ -184,6 +185,12 @@ GU_DEFINE_TYPE(
GU_MEMBER(PgfCatFun, prob, double),
GU_MEMBER(PgfCatFun, fun, PgfCId));
+static float inf_float = INFINITY; /* registered with the PgfMetaChildMap type below */
+
+GU_DEFINE_TYPE(PgfMetaChildMap, GuMap, /* type descriptor for the maps stored in PgfCat.meta_child_probs */
+ gu_type(PgfCat), NULL, /* key type; NOTE(review): NULL presumably means "no custom hasher" -- confirm against GuMap */
+ gu_type(float), &inf_float); /* value type; &inf_float is presumably the value reported for absent keys -- confirm */
+
GU_DEFINE_TYPE(
PgfCat, struct,
GU_MEMBER(PgfCat, context, PgfHypos),
diff --git a/src/runtime/c/pgf/data.h b/src/runtime/c/pgf/data.h
index 63c26d318..7fe2fc7d3 100644
--- a/src/runtime/c/pgf/data.h
+++ b/src/runtime/c/pgf/data.h
@@ -145,11 +145,16 @@ struct PgfCatFun {
PgfCId fun;
};
+typedef GuMap PgfMetaChildMap; /* map type of PgfCat.meta_child_probs: float values looked up by PgfCat pointer */
+extern GU_DECLARE_TYPE(PgfMetaChildMap, GuMap); /* type descriptor, defined in data.c */
+
struct PgfCat {
// TODO: Add cid here
PgfHypos context;
float meta_prob;
+ float meta_token_prob;
+ PgfMetaChildMap* meta_child_probs;
GuLength n_functions;
PgfCatFun functions[]; // XXX: resolve to PgfFunDecl*?
diff --git a/src/runtime/c/pgf/parser.c b/src/runtime/c/pgf/parser.c
index 3d97b5a39..a05600884 100644
--- a/src/runtime/c/pgf/parser.c
+++ b/src/runtime/c/pgf/parser.c
@@ -492,7 +492,9 @@ pgf_item_set_curr_symbol(PgfItem* item, GuPool* pool)
static PgfItem*
pgf_new_item(int pos, PgfCCat* ccat, size_t lin_idx,
- PgfProduction prod, PgfItemBuf* conts, GuPool* pool)
+ PgfProduction prod, PgfItemBuf* conts,
+ float delta_prob,
+ GuPool* pool)
{
PgfItemBase* base = gu_new(PgfItemBase, pool);
base->ccat = ccat;
@@ -557,6 +559,7 @@ pgf_new_item(int pos, PgfCCat* ccat, size_t lin_idx,
best_cont->inside_prob-ccat->viterbi_prob+
best_cont->outside_prob;
}
+ item->outside_prob += delta_prob;
pgf_item_set_curr_symbol(item, pool);
return item;
@@ -650,7 +653,12 @@ pgf_parsing_combine(PgfParseState* before, PgfParseState* after,
nargs * sizeof(PgfPArg));
gu_seq_set(item->args, PgfPArg, nargs,
((PgfPArg) { .hypos = NULL, .ccat = cat }));
- item->inside_prob += cat->viterbi_prob;
+
+ PgfCIdMap* meta_child_probs =
+ item->base->ccat->cnccat->abscat->meta_child_probs;
+ item->inside_prob +=
+ cat->viterbi_prob+
+ gu_map_get(meta_child_probs, cat->cnccat->abscat, float);
PgfSymbol prev = item->curr_sym;
PgfSymbolCat* scat = (PgfSymbolCat*)
@@ -673,10 +681,11 @@ pgf_parsing_combine(PgfParseState* before, PgfParseState* after,
static void
pgf_parsing_production(PgfParseState* state,
PgfCCat* ccat, size_t lin_idx,
- PgfProduction prod, PgfItemBuf* conts)
+ PgfProduction prod, PgfItemBuf* conts,
+ float delta_prob)
{
PgfItem* item =
- pgf_new_item(state->offset, ccat, lin_idx, prod, conts, state->pool);
+ pgf_new_item(state->offset, ccat, lin_idx, prod, conts, delta_prob, state->pool);
gu_buf_heap_push(state->agenda, &pgf_item_prob_order, &item);
}
@@ -798,7 +807,7 @@ pgf_parsing_complete(PgfParseState* before, PgfParseState* after,
* i.e. process it. */
if (conts2) {
pgf_parsing_production(before, cat, i,
- prod, conts2);
+ prod, conts2, 0);
}
}
@@ -818,7 +827,7 @@ pgf_parsing_complete(PgfParseState* before, PgfParseState* after,
* i.e. process it. */
if (conts2) {
pgf_parsing_production(state, cat, i,
- prod, conts2);
+ prod, conts2, 0);
}
}
@@ -835,14 +844,15 @@ pgf_parsing_complete(PgfParseState* before, PgfParseState* after,
static void
pgf_parsing_td_predict(PgfParseState* before, PgfParseState* after,
- PgfItem* item, PgfCCat* ccat, size_t lin_idx)
+ PgfItem* item, PgfCCat* ccat, size_t lin_idx,
+ float delta_prob)
{
gu_enter("-> cat: %d", ccat->fid);
if (gu_seq_is_null(ccat->prods)) {
// Empty category
return;
}
-
+
PgfItemBuf* conts =
pgf_parsing_get_conts(before->conts_map, ccat, lin_idx,
before->pool, before->pool);
@@ -856,17 +866,17 @@ pgf_parsing_td_predict(PgfParseState* before, PgfParseState* after,
PgfProductionSeq prods = ccat->prods;
for (size_t i = 0; i < ccat->n_synprods; i++) {
PgfProduction prod =
- gu_seq_get(prods, PgfProduction, i);
- pgf_parsing_production(before, ccat, lin_idx, prod, conts);
+ gu_seq_get(prods, PgfProduction, i);
+ pgf_parsing_production(before, ccat, lin_idx, prod, conts, delta_prob);
}
-
+
if (ccat->cnccat->abscat->meta_prob != INFINITY &&
ccat->fid < before->ps->concr->total_cats) {
// Top-down prediction for meta rules
PgfItem *item =
- pgf_new_item(before->offset, ccat, lin_idx, before->ps->meta_prod, conts, before->pool);
+ pgf_new_item(before->offset, ccat, lin_idx, before->ps->meta_prod, conts, 0, before->pool);
item->inside_prob =
- 1000000 + ccat->cnccat->abscat->meta_prob * 1000;
+ ccat->cnccat->abscat->meta_prob;
gu_buf_heap_push(before->agenda, &pgf_item_prob_order, &item);
}
@@ -880,7 +890,7 @@ pgf_parsing_td_predict(PgfParseState* before, PgfParseState* after,
new_item->base->lin_idx == lin_idx &&
gu_seq_length(new_item->args) == 0) {
pgf_parsing_production(before, ccat, lin_idx,
- new_item->base->prod, conts);
+ new_item->base->prod, conts, 0);
}
}
}
@@ -901,7 +911,7 @@ pgf_parsing_td_predict(PgfParseState* before, PgfParseState* after,
PgfProductionApply* papp = i.data;
if (gu_seq_length(papp->args) == 0) {
pgf_parsing_production(before, ccat, lin_idx,
- prod, conts);
+ prod, conts, 0);
}
break;
}
@@ -931,67 +941,107 @@ pgf_parsing_td_predict(PgfParseState* before, PgfParseState* after,
gu_exit("<-");
}
-static void
+static float
pgf_parsing_bu_predict(PgfParseState* before, PgfParseState* after,
- PgfItem* item, PgfItem* meta_item, PgfItemBuf* agenda,
- bool print)
+ PgfItemBuf* index, PgfItem* meta_item,
+ PgfItemBuf* agenda)
{
- PgfItemBuf* conts =
- pgf_parsing_get_conts(before->conts_map,
- item->base->ccat, item->base->lin_idx,
- before->pool, before->pool);
- gu_buf_push(conts, PgfItem*, meta_item);
- if (gu_buf_length(conts) == 1) {
- PgfItem* copy = pgf_item_copy(item, after->pool);
- copy->base = pgf_item_base_copy(item->base, after->pool);
- copy->base->conts = conts;
- copy->outside_prob =
- meta_item->inside_prob+meta_item->outside_prob;
-
+ float prob = INFINITY;
+
+ PgfMetaChildMap* meta_child_probs =
+ meta_item->base->ccat->cnccat->abscat->meta_child_probs;
+ if (meta_child_probs == NULL)
+ return prob;
+
+ if (!gu_map_has(before->generated_cats, index)) {
+ gu_map_put(before->generated_cats, index, PgfCCat*, NULL);
+
+ size_t n_items = gu_buf_length(index);
+ for (size_t i = 0; i < n_items; i++) {
+ PgfItem *item = gu_buf_get(index, PgfItem*, i);
+
+ float meta_prob =
+ meta_item->inside_prob+
+ meta_item->outside_prob+
+ gu_map_get(meta_child_probs, item->base->ccat->cnccat->abscat, float);
+
+ PgfItemBuf* conts =
+ pgf_parsing_get_conts(before->conts_map,
+ item->base->ccat, item->base->lin_idx,
+ before->pool, before->pool);
+ if (gu_buf_length(conts) == 0) {
+ float outside_prob =
+ pgf_parsing_bu_predict(before, after,
+ item->base->conts, meta_item,
+ conts);
+
+ if (outside_prob > meta_prob)
+ outside_prob = meta_prob;
+
+ for (size_t j = i; j < n_items; j++) {
+ PgfItem *item_ = gu_buf_get(index, PgfItem*, j);
+
+ if (item->base->conts == item_->base->conts) {
+ PgfItem* copy = pgf_item_copy(item_, after->pool);
+ copy->base = pgf_item_base_copy(item_->base, after->pool);
+ copy->base->conts = conts;
+ copy->outside_prob = outside_prob;
#ifdef PGF_PARSER_DEBUG
- copy->start = before->offset;
- copy->end = before->offset;
-
- if (print) {
- GuPool* tmp_pool = gu_new_pool();
- GuOut* out = gu_file_out(stderr, tmp_pool);
- GuWriter* wtr = gu_new_utf8_writer(out, tmp_pool);
- GuExn* err = gu_exn(NULL, type, tmp_pool);
- pgf_print_item(copy, wtr, err, tmp_pool);
- gu_pool_free(tmp_pool);
- } else {
- copy->end = after->offset;
- }
+ copy->start = before->offset;
+ copy->end = (agenda == NULL)
+ ? after->offset
+ : before->offset;
#endif
- gu_buf_push(agenda, PgfItem*, copy);
+ if (agenda == NULL)
+ pgf_parsing_add_transition(before, after, after->ts->tok, copy);
+ else
+ gu_buf_push(agenda, PgfItem*, copy);
- size_t n_items = gu_buf_length(item->base->conts);
- for (size_t i = 0; i < n_items; i++) {
- PgfItem *item_ = gu_buf_get(item->base->conts, PgfItem*, i);
- pgf_parsing_bu_predict(before, after, item_, meta_item, conts, true);
- }
- } else {
- /* If it has already been completed, combine. */
+ float item_prob =
+ copy->inside_prob+copy->outside_prob;
+ if (prob > item_prob)
+ prob = item_prob;
+ }
+ }
+ } else {
+ size_t n_items = gu_buf_length(conts);
+ for (size_t i = 0; i < n_items; i++) {
+ PgfItem *item = gu_buf_get(conts, PgfItem*, i);
+
+ float item_prob =
+ item->inside_prob+item->outside_prob;
+ if (prob > item_prob)
+ prob = item_prob;
+ }
+ prob += item->inside_prob;
- /*PgfCCat* completed =
- pgf_parsing_get_completed(before, conts);
- if (completed) {
- pgf_parsing_combine(before, after, meta_item, completed, item->base->lin_idx);
- }*/
+ /* If it has already been completed, combine. */
- PgfParseState* state = after;
- while (state != NULL) {
- PgfCCat* completed =
- pgf_parsing_get_completed(state, conts);
- if (completed) {
- pgf_parsing_combine(state, state->next, meta_item, completed, item->base->lin_idx);
- }
+ /*PgfCCat* completed =
+ pgf_parsing_get_completed(before, conts);
+ if (completed) {
+ pgf_parsing_combine(before, after, meta_item, completed, item->base->lin_idx);
+ }*/
- state = state->next;
+ PgfParseState* state = after;
+ while (state != NULL) {
+ PgfCCat* completed =
+ pgf_parsing_get_completed(state, conts);
+ if (completed) {
+ pgf_parsing_combine(state, state->next, meta_item, completed, item->base->lin_idx);
+ }
+
+ state = state->next;
+ }
+ }
+
+ if (meta_prob != INFINITY)
+ gu_buf_push(conts, PgfItem*, meta_item);
}
}
+ return prob;
}
static void
@@ -1002,7 +1052,7 @@ pgf_parsing_symbol(PgfParseState* before, PgfParseState* after,
PgfSymbolCat* scat = gu_variant_data(sym);
PgfPArg* parg = gu_seq_index(item->args, PgfPArg, scat->d);
gu_assert(!parg->hypos || !parg->hypos->len);
- pgf_parsing_td_predict(before, after, item, parg->ccat, scat->r);
+ pgf_parsing_td_predict(before, after, item, parg->ccat, scat->r, 0);
break;
}
case PGF_SYMBOL_KS: {
@@ -1105,7 +1155,7 @@ pgf_parsing_symbol(PgfParseState* before, PgfParseState* after,
if (parg->ccat->fid > 0 &&
parg->ccat->fid >= before->ps->concr->total_cats)
- pgf_parsing_td_predict(before, after, item, parg->ccat, slit->r);
+ pgf_parsing_td_predict(before, after, item, parg->ccat, slit->r, 0);
else {
PgfItemBuf* conts =
pgf_parsing_get_conts(before->conts_map,
@@ -1133,7 +1183,7 @@ pgf_parsing_symbol(PgfParseState* before, PgfParseState* after,
pext->callback = callback;
pgf_parsing_production(before, parg->ccat, slit->r,
- prod, conts);
+ prod, conts, 0);
}
} else {
/* If it has already been completed, combine. */
@@ -1168,6 +1218,7 @@ pgf_parsing_symbol(PgfParseState* before, PgfParseState* after,
}
PgfParseState *meta_after = NULL;
+static PgfLiteralCallback pgf_meta_callback;
static void
pgf_parsing_item(PgfParseState* before, PgfParseState* after, PgfItem* item)
@@ -1202,7 +1253,7 @@ pgf_parsing_item(PgfParseState* before, PgfParseState* after, PgfItem* item)
case 0:
pgf_parsing_td_predict(before, after, item,
pcoerce->coerce,
- item->base->lin_idx);
+ item->base->lin_idx, 0);
break;
case 1:
pgf_parsing_complete(before, after, item, NULL);
@@ -1241,6 +1292,14 @@ pgf_parsing_item(PgfParseState* before, PgfParseState* after, PgfItem* item)
pgf_parsing_complete(before, after, item, before->meta_ep);
if (accepted && after != NULL) {
+ if (pext->callback == &pgf_meta_callback) {
+ float meta_token_prob =
+ item->base->ccat->cnccat->abscat->meta_token_prob;
+ if (meta_token_prob == INFINITY)
+ break;
+ item->inside_prob += meta_token_prob;
+ }
+
PgfSymbol prev = item->curr_sym;
PgfSymbolKS* sks = (PgfSymbolKS*)
gu_alloc_variant(PGF_SYMBOL_KS,
@@ -1265,6 +1324,48 @@ pgf_parsing_item(PgfParseState* before, PgfParseState* after, PgfItem* item)
}
}
+typedef struct { /* closure handed to gu_map_iter() together with pgf_parsing_meta_predict */
+ GuMapItor fn; /* must stay the first member: the callback casts its GuMapItor* back to this struct */
+ PgfParseState* before; /* forwarded to pgf_parsing_td_predict */
+ PgfParseState* after; /* forwarded to pgf_parsing_td_predict */
+ PgfItem* meta_item; /* passed as the item argument of pgf_parsing_td_predict */
+} PgfMetaPredictFn;
+
+/* GuMapItor callback over a meta_child_probs map: for one (child category,
+ * probability) entry, run top-down prediction for every concrete category and
+ * linearization index of the child, passing the probability as delta_prob.
+ * NOTE(review): the key is read as a PgfCId, but pgf_load_meta_child_probs()
+ * inserts PgfCat* keys -- confirm what GuMap hands to the iterator. */
+static void
+pgf_parsing_meta_predict(GuMapItor* fn, const void* key, void* value, GuExn* err)
+{
+	(void) (err);
+
+	PgfCId abscat = *((PgfCId*) key);
+	float meta_prob = *((float*) value);
+	/* Valid because `fn` is the first member of PgfMetaPredictFn. */
+	PgfMetaPredictFn* clo = (PgfMetaPredictFn*) fn;
+	PgfParseState* before = clo->before;
+	PgfParseState* after = clo->after;
+	PgfItem* meta_item = clo->meta_item;
+
+	/* No entry in the concrete grammar's cnccats map: nothing to predict. */
+	PgfCncCat* cnccat =
+		gu_map_get(before->ps->concr->cnccats, &abscat, PgfCncCat*);
+	if (cnccat == NULL)
+		return;
+
+	size_t n_cats = gu_list_length(cnccat->cats);
+	for (size_t i = 0; i < n_cats; i++) {
+		PgfCCat* ccat = gu_list_index(cnccat->cats, i);
+
+		for (size_t lin_idx = 0; lin_idx < cnccat->n_lins; lin_idx++) {
+			pgf_parsing_td_predict(before, after,
+			                       meta_item, ccat, lin_idx, meta_prob);
+		}
+	}
+}
+
static bool
pgf_match_meta(PgfConcr* concr, PgfItem *item, PgfToken tok,
PgfExprProb** out_ep, GuPool *pool)
@@ -1298,14 +1399,18 @@ pgf_match_meta(PgfConcr* concr, PgfItem *item, PgfToken tok,
PgfParseState* before =
gu_container(out_ep, PgfParseState, meta_ep);
- size_t n_items = gu_buf_length(after->ts->lexicon_idx);
- for (size_t i = 0; i < n_items; i++) {
- PgfItem* item_ =
- gu_buf_get(after->ts->lexicon_idx, PgfItem*, i);
- pgf_parsing_bu_predict(before, after,
- item_, item, after->agenda, false);
- after->ps->target = item_;
+ PgfCIdMap* meta_child_probs =
+ item->base->ccat->cnccat->abscat->meta_child_probs;
+ if (meta_child_probs != NULL) {
+ PgfMetaPredictFn clo = { { pgf_parsing_meta_predict }, before, after, item };
+ gu_map_iter(meta_child_probs, &clo.fn, NULL);
}
+/*
+ fprintf(stderr, "------------------------------------\n");
+ pgf_parsing_bu_predict(before, after,
+ after->ts->lexicon_idx, item,
+ NULL);
+ fprintf(stderr, "------------------------------------\n");*/
return false;
}
}
@@ -1651,14 +1756,14 @@ pgf_parser_init_state(PgfConcr* concr, PgfCId cat, size_t lin_idx, GuPool* pool)
PgfProduction prod =
gu_seq_get(prods, PgfProduction, i);
PgfItem* item =
- pgf_new_item(0, ccat, lin_idx, prod, conts, pool);
+ pgf_new_item(0, ccat, lin_idx, prod, conts, 0, pool);
gu_buf_heap_push(state->agenda, &pgf_item_prob_order, &item);
}
PgfItem *item =
- pgf_new_item(0, ccat, lin_idx, ps->meta_prod, conts, pool);
+ pgf_new_item(0, ccat, lin_idx, ps->meta_prod, conts, 0, pool);
item->inside_prob =
- 1000000 + ccat->cnccat->abscat->meta_prob * 1000;
+ ccat->cnccat->abscat->meta_prob;
gu_buf_heap_push(state->agenda, &pgf_item_prob_order, &item);
}
}
@@ -1896,7 +2001,7 @@ pgf_parser_bu_index(PgfConcr* concr, PgfCCat* ccat, PgfProduction prod,
pgf_parsing_get_conts(conts_map, ccat, lin_idx,
pool, tmp_pool);
PgfItem* item =
- pgf_new_item(0, ccat, lin_idx, prod, conts, pool);
+ pgf_new_item(0, ccat, lin_idx, prod, conts, 0, pool);
pgf_parser_bu_item(concr, item, conts_map, pool, tmp_pool);
}
diff --git a/src/runtime/c/pgf/pgf.h b/src/runtime/c/pgf/pgf.h
index 91659d95e..e14b4c8c8 100644
--- a/src/runtime/c/pgf/pgf.h
+++ b/src/runtime/c/pgf/pgf.h
@@ -69,6 +69,9 @@ pgf_read(GuIn* in, GuPool* pool, GuExn* err);
*/
+bool
+pgf_load_meta_child_probs(PgfPGF*, const char* fpath, GuPool* pool);
+
#include <gu/type.h>
extern GU_DECLARE_TYPE(PgfPGF, struct);
diff --git a/src/runtime/c/pgf/reader.c b/src/runtime/c/pgf/reader.c
index 1fee45f83..08cc16096 100644
--- a/src/runtime/c/pgf/reader.c
+++ b/src/runtime/c/pgf/reader.c
@@ -30,6 +30,7 @@
#include <gu/exn.h>
#include <gu/utf8.h>
#include <math.h>
+#include <stdio.h>
#define GU_LOG_ENABLE
#include <gu/log.h>
@@ -656,6 +657,8 @@ pgf_compute_meta_probs(GuMapItor* fn, const void* key, void* value, GuExn* err)
mass += cat->functions[i].prob;
}
cat->meta_prob = - log(fabs(1 - mass));
+ cat->meta_token_prob = INFINITY;
+ cat->meta_child_probs = NULL;
}
static void
@@ -936,3 +939,51 @@ pgf_read(GuIn* in, GuPool* pool, GuExn* err)
gu_return_on_exn(err, NULL);
return pgf;
}
+
+/* Load meta probabilities from `fpath`: whitespace-separated records
+ * "<cat1> <cat2> <prob>". A cat2 of "_" sets cat1's meta_token_prob; any other
+ * cat2 is stored in cat1's meta_child_probs (created in `pool`); both hold
+ * -log(prob). Returns false if the file cannot be opened or a category is
+ * unknown; reading stops quietly at the first record that does not scan. */
+bool
+pgf_load_meta_child_probs(PgfPGF* pgf, const char* fpath, GuPool* pool)
+{
+	FILE *fp = fopen(fpath, "r");
+	if (!fp)
+		return false;
+
+	bool ok = true;
+	GuPool* tmp_pool = gu_new_pool();
+
+	while (ok) {
+		char cat1_s[21];
+		char cat2_s[21];
+		float prob;
+		if (fscanf(fp, "%20s\t%20s\t%f", cat1_s, cat2_s, &prob) < 3)
+			break;
+		prob = - log(prob);
+
+		GuString cat1 = gu_str_string(cat1_s, tmp_pool);
+		PgfCat* abscat1 = gu_map_get(pgf->abstract.cats, &cat1, PgfCat*);
+		if (abscat1 == NULL) {
+			/* fail via the shared cleanup below so fp and tmp_pool are released */
+			ok = false;
+		} else if (strcmp(cat2_s, "_") == 0) {
+			abscat1->meta_token_prob = prob;
+		} else {
+			GuString cat2 = gu_str_string(cat2_s, tmp_pool);
+			PgfCat* abscat2 = gu_map_get(pgf->abstract.cats, &cat2, PgfCat*);
+			if (abscat2 == NULL) {
+				ok = false;
+			} else {
+				if (abscat1->meta_child_probs == NULL)
+					abscat1->meta_child_probs = gu_map_type_new(PgfMetaChildMap, pool);
+				gu_map_put(abscat1->meta_child_probs, abscat2, float, prob);
+			}
+		}
+	}
+
+	gu_pool_free(tmp_pool);
+	fclose(fp);
+	return ok;
+}
diff --git a/src/runtime/c/utils/pgf-translate.c b/src/runtime/c/utils/pgf-translate.c
index aae09e70d..a740d3204 100644
--- a/src/runtime/c/utils/pgf-translate.c
+++ b/src/runtime/c/utils/pgf-translate.c
@@ -87,6 +87,12 @@ int main(int argc, char* argv[]) {
goto fail_read;
}
+ if (!pgf_load_meta_child_probs(pgf, "../../../examples/PennTreebank/test2.probs", pool)) {
+ fprintf(stderr, "Loading meta child probs failed\n");
+ status = EXIT_FAILURE;
+ goto fail_read;
+ }
+
// Look up the source and destination concrete categories
PgfConcr* from_concr =
gu_map_get(pgf->concretes, &from_lang, PgfConcr*);