Commit 76948aaf authored by Ilya Edrenkin's avatar Ilya Edrenkin
Browse files

Fix for the maxent part of RNNLM-HS. Thanks to Tomas Mikolov for the hint.


git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@4373 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
parent d7685a12
......@@ -92,17 +92,18 @@ exit 0
# this section demonstrates RNNLM-HS rescoring (commented out by default)
# the exact results might differ insignificantly due to hogwild in RNNLM-HS training that introduces indeterminism
%WER 5.92 [ 334 / 5643, 58 ins, 32 del, 244 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg/wer_14 # baseline (no rescoring)
%WER 5.42 [ 306 / 5643, 50 ins, 34 del, 222 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs100_0.3/wer_16
%WER 5.49 [ 310 / 5643, 47 ins, 36 del, 227 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs300_0.3/wer_18
%WER 5.90 [ 333 / 5643, 54 ins, 37 del, 242 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs30_0.15/wer_15
%WER 5.49 [ 310 / 5643, 45 ins, 38 del, 227 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.15/wer_18
%WER 5.49 [ 310 / 5643, 45 ins, 38 del, 227 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.15_N1000/wer_18
%WER 5.33 [ 301 / 5643, 41 ins, 41 del, 219 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.3/wer_20
%WER 5.40 [ 305 / 5643, 41 ins, 41 del, 223 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.3_N10/wer_20
%WER 5.33 [ 301 / 5643, 41 ins, 41 del, 219 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.3_N1000/wer_20
%WER 5.26 [ 297 / 5643, 44 ins, 36 del, 217 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.4/wer_18
%WER 5.25 [ 296 / 5643, 44 ins, 36 del, 216 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.4_N1000/wer_18
%WER 5.26 [ 297 / 5643, 42 ins, 39 del, 216 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.5_N1000/wer_20
%WER 5.26 [ 297 / 5643, 47 ins, 29 del, 221 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs100_0.3/wer_15
%WER 5.17 [ 292 / 5643, 46 ins, 30 del, 216 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs300_0.3/wer_16
%WER 5.64 [ 318 / 5643, 50 ins, 34 del, 234 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs30_0.15/wer_16
%WER 5.55 [ 313 / 5643, 51 ins, 32 del, 230 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.15/wer_16
%WER 5.55 [ 313 / 5643, 51 ins, 32 del, 230 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.15_N1000/wer_16
%WER 5.39 [ 304 / 5643, 50 ins, 30 del, 224 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.3/wer_15
%WER 5.42 [ 306 / 5643, 50 ins, 30 del, 226 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.3_N10/wer_15
%WER 5.39 [ 304 / 5643, 50 ins, 30 del, 224 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.3_N1000/wer_15
%WER 5.37 [ 303 / 5643, 49 ins, 29 del, 225 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.4/wer_14
%WER 5.37 [ 303 / 5643, 49 ins, 29 del, 225 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.4_N1000/wer_14
%WER 5.26 [ 297 / 5643, 45 ins, 32 del, 220 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.5_N1000/wer_15
%WER 5.14 [ 290 / 5643, 43 ins, 32 del, 215 sub ] exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.75_N1000/wer_18
%WER 14.17 [ 1167 / 8234, 222 ins, 123 del, 822 sub ] exp/tri3b/decode_tgpr_dev93/wer_17
%WER 19.37 [ 1595 / 8234, 315 ins, 153 del, 1127 sub ] exp/tri3b/decode_tgpr_dev93.si/wer_15
......
......@@ -90,3 +90,11 @@ steps/rnnlmrescore.sh --rnnlm_ver rnnlm-hs-0.1b \
--stage 7 --N 1000 --cmd "$decode_cmd" --inv-acwt 17 \
0.5 data/lang_test_bd_fg data/local/rnnlm-hs.h400.voc40k data/test_eval92 \
exp/tri3b/decode_bd_tgpr_eval92_fg $dir
dir=exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.75_N1000
rm -rf $dir
cp -r exp/tri3b/decode_bd_tgpr_eval92_fg_rnnlm-hs400_0.3_N1000 $dir
steps/rnnlmrescore.sh --rnnlm_ver rnnlm-hs-0.1b \
--stage 7 --N 1000 --cmd "$decode_cmd" --inv-acwt 17 \
0.75 data/lang_test_bd_fg data/local/rnnlm-hs.h400.voc40k data/test_eval92 \
exp/tri3b/decode_bd_tgpr_eval92_fg $dir
......@@ -70,13 +70,13 @@ local/wsj_format_data.sh || exit 1;
(
num_threads_rnnlm=8
local/wsj_train_rnnlms.sh --rnnlm_ver rnnlm-hs-0.1b --threads $num_threads_rnnlm \
--cmd "$decode_cmd -l mem_free=1G" --bptt 4 --bptt-block 10 --hidden 30 --nwords 10000 --direct 0 data/local/rnnlm-hs.h30.voc10k
--cmd "$decode_cmd -l mem_free=1G" --bptt 4 --bptt-block 10 --hidden 30 --nwords 10000 --direct 1000 data/local/rnnlm-hs.h30.voc10k
local/wsj_train_rnnlms.sh --rnnlm_ver rnnlm-hs-0.1b --threads $num_threads_rnnlm \
--cmd "$decode_cmd -l mem_free=1G" --bptt 4 --bptt-block 10 --hidden 100 --nwords 20000 --direct 0 data/local/rnnlm-hs.h100.voc20k
--cmd "$decode_cmd -l mem_free=1G" --bptt 4 --bptt-block 10 --hidden 100 --nwords 20000 --direct 1500 data/local/rnnlm-hs.h100.voc20k
local/wsj_train_rnnlms.sh --rnnlm_ver rnnlm-hs-0.1b --threads $num_threads_rnnlm \
--cmd "$decode_cmd -l mem_free=1G" --bptt 4 --bptt-block 10 --hidden 300 --nwords 30000 --direct 0 data/local/rnnlm-hs.h300.voc30k
--cmd "$decode_cmd -l mem_free=1G" --bptt 4 --bptt-block 10 --hidden 300 --nwords 30000 --direct 1500 data/local/rnnlm-hs.h300.voc30k
local/wsj_train_rnnlms.sh --rnnlm_ver rnnlm-hs-0.1b --threads $num_threads_rnnlm \
--cmd "$decode_cmd -l mem_free=1G" --bptt 4 --bptt-block 10 --hidden 400 --nwords 40000 --direct 0 data/local/rnnlm-hs.h400.voc40k
--cmd "$decode_cmd -l mem_free=1G" --bptt 4 --bptt-block 10 --hidden 400 --nwords 40000 --direct 2000 data/local/rnnlm-hs.h400.voc40k
)
) &
......
CC = gcc
#The -Ofast might not work with older versions of gcc; in that case, use -O2
CFLAGS = `$(CC) -dumpversion | awk '{if(NR==1 && $$1>"4.6") print "-lm -pthread -Ofast -march=native -Wall -funroll-loops -Wno-unused-result -std=c99 -g"; else print "-lm -pthread -O2 -march=native -Wall -funroll-loops -std=c99 -g";}'`
CFLAGS = `$(CC) -dumpversion | awk '{if(NR==1 && $$1>="4.6") print "-lm -pthread -Ofast -march=native -Wall -funroll-loops -Wno-unused-result -std=c99 -g"; else print "-lm -pthread -O2 -march=native -Wall -funroll-loops -std=c99 -g";}'`
all: rnnlm
......
......@@ -21,7 +21,7 @@
#include <assert.h>
#define MAX_STRING 100
#define MAX_SENTENCE_LENGTH 1000
#define MAX_SENTENCE_LENGTH 10000
#define MAX_CODE_LENGTH 40
const int vocab_hash_size = 30000000; // Maximum 30 * 0.7 = 21M words in the vocabulary
......@@ -444,7 +444,7 @@ void *TrainModelThread(void *id) {
sen[0] = 0; // <s> token -- beginning of sentence
int good = 1;
sentence_length = 1;
while(sentence_length <= MAX_SENTENCE_LENGTH) {
while(sentence_length < MAX_SENTENCE_LENGTH) {
word = ReadWordIndex(fi);
++word_count;
sen[sentence_length] = word;
......@@ -483,9 +483,9 @@ void *TrainModelThread(void *id) {
word = sen[target];
long long feature_hashes[MAX_NGRAM_ORDER] = {0};
if(maxent_order) {
for(int order = 0; order < maxent_order && target - order >= 1; ++order) {
for(int order = 0; order < maxent_order && target - order >= 0; ++order) {
feature_hashes[order] = PRIMES[0]*PRIMES[1];
for (int b = 1; b <= order; ++b) feature_hashes[order] += PRIMES[(order*PRIMES[b]+b) % PRIMES_SIZE]*(unsigned long long)(sen[target-1-b]+1);
for (int b = 1; b <= order; ++b) feature_hashes[order] += PRIMES[(order*PRIMES[b]+b) % PRIMES_SIZE]*(unsigned long long)(sen[target-b]+1);
feature_hashes[order] = feature_hashes[order] % (maxent_hash_size - vocab_size);
}
}
......@@ -496,7 +496,7 @@ void *TrainModelThread(void *id) {
for(int c = 0; c < layer1_size; ++c) {
f += neu1[layer1_size*(target - 1) + c] * nnet.syn1[l2 + c];
}
for(int order = 0; order < maxent_order && target - order >= 1; ++order) {
for(int order = 0; order < maxent_order && target - order >= 0; ++order) {
f += nnet.synMaxent[feature_hashes[order] + vocab[word].point[d]];
}
#ifdef DEBUG
......@@ -518,7 +518,7 @@ void *TrainModelThread(void *id) {
for(int c = 0; c < layer1_size; ++c) {
nnet.syn1[l2 + c] += g_alpha * neu1[layer1_size*(target - 1) + c] - beta * nnet.syn1[l2 + c];
}
for(int order = 0; order < maxent_order && target - order >= 1; ++order) {
for(int order = 0; order < maxent_order && target - order >= 0; ++order) {
nnet.synMaxent[feature_hashes[order] + vocab[word].point[d]] += g_maxentalpha - maxent_beta * nnet.synMaxent[feature_hashes[order] + vocab[word].point[d]];
}
}
......@@ -590,7 +590,7 @@ real EvaluateModel(char* filename, int printLoglikes) {
sen[0] = 0;
int good = 1;
sentence_length = 1;
while(sentence_length <= MAX_SENTENCE_LENGTH) {
while(sentence_length < MAX_SENTENCE_LENGTH) {
word = ReadWordIndex(fi);
sen[sentence_length] = word;
if (feof(fi) || word == 0) break;
......@@ -625,9 +625,9 @@ real EvaluateModel(char* filename, int printLoglikes) {
word = sen[target];
long long feature_hashes[MAX_NGRAM_ORDER] = {0};
if(maxent_order) {
for(int order = 0; order < maxent_order && target - order >= 1; ++order) {
for(int order = 0; order < maxent_order && target - order >= 0; ++order) {
feature_hashes[order] = PRIMES[0]*PRIMES[1];
for (int b = 1; b <= order; ++b) feature_hashes[order] += PRIMES[(order*PRIMES[b]+b) % PRIMES_SIZE]*(unsigned long long)(sen[target-1-b]+1);
for (int b = 1; b <= order; ++b) feature_hashes[order] += PRIMES[(order*PRIMES[b]+b) % PRIMES_SIZE]*(unsigned long long)(sen[target-b]+1);
feature_hashes[order] = feature_hashes[order] % (maxent_hash_size - vocab_size);
}
}
......@@ -639,7 +639,7 @@ real EvaluateModel(char* filename, int printLoglikes) {
for(int c = 0; c < layer1_size; ++c) {
f += neu1[layer1_size*(target - 1) + c] * nnet.syn1[l2 + c];
}
for(int order = 0; order < maxent_order && target - order >= 1; ++order) {
for(int order = 0; order < maxent_order && target - order >= 0; ++order) {
f += nnet.synMaxent[feature_hashes[order] + vocab[word].point[d]];
}
logprob += log10(1+(vocab[word].code[d] == 1 ? exp(f) : exp(-f)));
......@@ -939,7 +939,7 @@ int main(int argc, char **argv) {
printf("\t\tStop training iff N retries with halving learning rate have failed (default 2)\n");
printf("\t-debug <int>\n");
printf("\t\tSet the debug mode (default = 2 = more info during training)\n");
printf("\t-direct-size <int>\n");
printf("\t-direct <int>\n");
printf("\t\tSet the size of hash for maxent parameters, in millions (default 0 = maxent off)\n");
printf("\t-direct-order <int>\n");
printf("\t\tSet the order of n-gram features to be used in maxent (default 3)\n");
......@@ -974,7 +974,7 @@ int main(int argc, char **argv) {
}
if ((i = ArgPos((char *)"-threads", argc, argv)) > 0) num_threads = atoi(argv[i + 1]);
if ((i = ArgPos((char *)"-min-count", argc, argv)) > 0) min_count = atoi(argv[i + 1]);
if ((i = ArgPos((char *)"-direct-size", argc, argv)) > 0) maxent_hash_size = atoi(argv[i + 1]);
if ((i = ArgPos((char *)"-direct", argc, argv)) > 0) maxent_hash_size = atoi(argv[i + 1]);
if ((i = ArgPos((char *)"-direct-order", argc, argv)) > 0) maxent_order = atoi(argv[i + 1]);
if ((i = ArgPos((char *)"-beta1", argc, argv)) > 0) beta = atof(argv[i + 1]);
if ((i = ArgPos((char *)"-beta2", argc, argv)) > 0) maxent_beta = atof(argv[i + 1]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment