Commit c769e352 authored by Abdelwahab HEBA's avatar Abdelwahab HEBA

Add WSJ dataloader & end-to-end ASR system with CTC-based loss function

parent a5125bad
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch import optim\n",
"import torch.nn.functional as F\n",
"import datetime\n",
"#import gc\n",
"#from warpctc_pytorch import CTCLoss\n",
"\n",
"USE_CUDA = torch.cuda.is_available()\n",
"device = torch.device(\"cuda:0\" if USE_CUDA else \"cpu\")\n",
"print(torch.cuda.get_device_name(0))\n",
"torch.multiprocessing.set_sharing_strategy('file_system')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"run WSJ_dataset.ipynb"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class EncoderRNN(nn.Module):\n",
" def __init__(self, input_size, hidden_size, output_size, n_layers=1):#, dropout=0.1):\n",
" super(EncoderRNN, self).__init__()\n",
" # Keep for reference\n",
" self.input_size = input_size\n",
" self.hidden_size = hidden_size\n",
" self.output_size = output_size\n",
" self.n_layers = n_layers\n",
" #self.dropout = dropout\n",
" # Define layers\n",
" self.gru = nn.GRU(input_size, hidden_size, n_layers,\n",
" dropout=(0 if n_layers==1 else dropout),\n",
" bidirectional=True)\n",
" self.out = nn.Linear(hidden_size*2, output_size)\n",
" \n",
" def forward(self, input_seq, hidden=None):\n",
" #packed = torch.nn.utils.rnn.pack_padded_sequence(input_seq, input_lengths)\n",
" # Forward pass through GRU\n",
" outputs, _ = self.gru(input_seq, hidden)\n",
" outputs = self.out(outputs)\n",
" outputs = F.log_softmax(outputs,dim=2)\n",
" return outputs\n",
" #return outputs, F.log_softmax(outputs, dim=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train(input_variable, lengths, target_variable, target_lengths, \n",
" encoder, encoder_optimizer, clip, ctc_loss):\n",
" # Zero gradients\n",
" encoder_optimizer.zero_grad()\n",
" \n",
" # Set device options\n",
" input_variable = input_variable.to(device)\n",
" lengths = lengths.to(device)\n",
" target_variable = target_variable.to(device)\n",
" target_lengths = target_lengths.to(device)\n",
" \n",
" ## CTC Torch\n",
" encoder_outputs = encoder(input_variable)\n",
" loss = ctc_loss(encoder_outputs, target_variable.t(), lengths, target_lengths)\n",
" \n",
" loss.backward()\n",
"\n",
" # Clip gradients: gradients are modified in place\n",
" torch.nn.utils.clip_grad_norm_(encoder.parameters(),clip)\n",
" \n",
" # Adjust model weights\n",
" encoder_optimizer.step()\n",
" return loss"
]
},
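{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal, self-contained sketch (toy shapes only, not the WSJ tensors) of the nn.CTCLoss\n",
"# contract used in train(): log_probs (T, N, C) after log_softmax, targets (N, S),\n",
"# plus per-utterance input and target lengths.\n",
"T, N, C, S = 50, 4, 30, 10\n",
"toy_log_probs = torch.randn(T, N, C).log_softmax(2)\n",
"toy_targets = torch.randint(1, C, (N, S), dtype=torch.long) # labels 1..C-1 (0 is the blank)\n",
"toy_input_lengths = torch.full((N,), T, dtype=torch.long)\n",
"toy_target_lengths = torch.randint(1, S + 1, (N,), dtype=torch.long)\n",
"print(nn.CTCLoss(reduction='mean')(toy_log_probs, toy_targets, toy_input_lengths, toy_target_lengths))"
]
},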
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def trainIters(model_name, dataloader_training, encoder, encoder_optimizer, \n",
" encoder_n_layers, save_dir, n_iteration, print_every, save_every,\n",
" clip, corpus_name, loadFilename):\n",
" # Load batches for each iteration\n",
" training_batches = [batch_sample for batch_sample in dataloader_training]\n",
" # Initializations\n",
" print('Initializing ...')\n",
" start_iteration = 1\n",
" print_loss = 0\n",
" if loadFilename:\n",
" start_iteration = checkpoint['iteration']+1\n",
" \n",
" # Training loop\n",
" print('Training ...')\n",
" ctc_loss = nn.CTCLoss(reduction='mean')\n",
" for iteration in range(start_iteration, n_iteration + 1):\n",
" # Extract fields from batch\n",
" for training_batch in training_batches:\n",
" input_variable,lengths, _, _, _,target_variable,target_lengths, _ = training_batch\n",
" \n",
" # Run a training iteration with batch\n",
" loss = train(input_variable,lengths, target_variable, target_lengths, encoder,\n",
" encoder_optimizer, clip, ctc_loss)\n",
" print_loss +=loss.item()\n",
" \n",
" # Print progress\n",
" if iteration % print_every == 0:\n",
" print_loss_avg = print_loss/print_every\n",
" print(\"Iteration:{}, Percent complete: {:.1f}%; Average loss: {:.4f}\".format(iteration,iteration/n_iteration * 100, print_loss_avg))\n",
" print_loss=0\n",
" \n",
" # Save checkpoint\n",
" if (iteration % save_every ==0):\n",
" directory = os.path.join(save_dir, model_name, corpus_name, '{}_{}'.format(encoder_n_layers,hidden_size))\n",
" if not os.path.exists(directory):\n",
" os.makedirs(directory)\n",
" torch.save({\n",
" 'iteration': iteration,\n",
" 'en' : encoder.state_dict(),\n",
" 'en_opt' : encoder_optimizer.state_dict(),\n",
" 'loss' : loss,\n",
" }, os.path.join(directory, '{}_{}.tar'.format(iteration,'checkpoint_ctc')))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Configure models\n",
"model_name = 'ctc_model'\n",
"input_size = 120\n",
"hidden_size = 300\n",
"output_size = set_train.voc.num_chars\n",
"print(output_size)\n",
"encoder_n_layers = 3\n",
"dropout = 0\n",
"\n",
"# Set checkpoint to load from; set to None if starting from scratch\n",
"loadFilename = None\n",
"checkpoint_iter = 25\n",
"save_dir='save'\n",
"corpus_name= 'WSJ'\n",
"#loadFilename = os.path.join(save_dir, model_name, corpus_name,\n",
"# '{}_{}'.format(encoder_n_layers, hidden_size),\n",
"# '{}_checkpoint_ctc.tar'.format(checkpoint_iter))\n",
"print(loadFilename)\n",
"# Load model if a loadFilename is provided\n",
"if loadFilename:\n",
" # If loading on same machine the model was trained on\n",
" checkpoint = torch.load(loadFilename)\n",
" # If loading a model trained on GPU to CPU\n",
" # checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))\n",
" encoder_sd = checkpoint['en']\n",
" encoder_optimizer_sd = checkpoint['en_opt']\n",
"print('Building encoder and decoder ...')\n",
"\n",
"# Initialize encoder & decoder models\n",
"encoder = EncoderRNN(input_size, hidden_size, output_size, encoder_n_layers)\n",
"#encoder.register_backward_hook(encoder.backward_hook)\n",
"if loadFilename:\n",
" encoder.load_state_dict(encoder_sd)\n",
"# Use appropriate device\n",
"encoder = encoder.to(device)\n",
"print('Models built and ready to go!')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"clip = 50.0\n",
"learning_rate = 0.0001\n",
"n_iteration = 50\n",
"print_every = 1\n",
"save_every = 5\n",
"save_dir = 'save'\n",
"corpus_name= 'WSJ'\n",
"# Ensure dropout layers are in train mode\n",
"encoder.train()\n",
"\n",
"# Initialize optimizers\n",
"print('Building optimizers ...')\n",
"encoder_optimizer = optim.Adam(encoder.parameters(), lr = learning_rate)\n",
"\n",
"if loadFilename:\n",
" encoder_optimizer.load_state_dict(encoder_optimizer_sd)\n",
"\n",
"# Run training iterations\n",
"time_begin=datetime.datetime.now()\n",
"trainIters(model_name, dataloader_set_train, encoder, encoder_optimizer, \n",
" encoder_n_layers,\n",
" save_dir, n_iteration, print_every, save_every, clip, corpus_name, loadFilename)\n",
"time_end=datetime.datetime.now()\n",
"print(time_end-time_begin)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Time CTC Torch: 13:20:07.989254 pour 25 iteration\n",
"# Loss it1: 3 => loss it25: 0.415"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class GreedySearchDecoder(nn.Module):\n",
" def __init__(self, encoder):\n",
" super(GreedySearchDecoder, self).__init__()\n",
" self.encoder = encoder\n",
" \n",
" def forward(self, input_seq, input_length, max_length):\n",
" # Forward input through encoder model\n",
" encoder_output = self.encoder(input_seq)\n",
" # Initialize tensors to append decoded words to\n",
" all_tokens = torch.zeros([0], device= device, dtype = torch.long)\n",
" all_scores = torch.zeros([0], device= device)\n",
" # Iteratively decode one word token at a time\n",
" for i in range(encoder_output.size(0)):\n",
" # Obtain most likely word token and its softmax score\n",
" encoder_scores, encoder_input = torch.max(encoder_output[i,:],dim=1)\n",
" # Record token and score\n",
" all_tokens = torch.cat((all_tokens,encoder_input), dim=0)\n",
" all_scores = torch.cat((all_scores,encoder_scores), dim=0)\n",
" # Return collections of word tokens and scores\n",
" return all_tokens, all_scores"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def evaluate(encoder, searcher, voc, sentence, max_length):\n",
" # Format input sentence as a batch\n",
" # Words -> indexes\n",
" input_batch = input_sentence\n",
" # Create lengths tensor\n",
" # Transpose dimensions of batch to match models' expectations\n",
" #input_batch = torch.LongTensor(indexes_batch).transpose(0,1)\n",
" \n",
" # Use appropriate device\n",
" input_batch = input_batch.to(device)\n",
" lengths =[]\n",
" # Decode sentence with searcher\n",
" tokens, scores = searcher(input_batch,lengths,max_length)\n",
" print(scores)\n",
" # indexes -> words\n",
" decoded_words = [voc.index2char[token.item()] for token in tokens if token!=0]\n",
" return decoded_words\n",
"\n",
"def evaluateInput(input_sentence, target, encoder, searcher, voc):\n",
" #input_sentence = input_sentence.squeeze(0)\n",
" #print(input_sentence.shape)\n",
" output_words = evaluate(encoder, searcher, voc, input_sentence,target.size(0))\n",
" # Format and print response sentence\n",
" output_words[:] = [x for x in output_words]\n",
" print('Transcription:', ' '.join(output_words))\n",
" Truth_words = [voc.index2char[x.item()] for x in target]\n",
" print('Truth:', ' '.join(Truth_words))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Set dropout layers to eval mode\n",
"encoder.eval()\n",
"# Initialize search module\n",
"searcher = GreedySearchDecoder(encoder)\n",
"\n",
"# Begin chatting (uncomment and run the following line to begin)\n",
"i=500\n",
"for j,batch in enumerate(dataloader_set_train):\n",
" if j == i:\n",
" feats, len_feat, targets, len_t , mask_target= batch \n",
" break\n",
" \n",
"#feats, len_feat, targets, len_t , mask_targets = iter(dataloader_set_train).__next__()\n",
"\n",
"input_sentence=feats[:,30,:].unsqueeze(1)\n",
"print(input_sentence.shape)\n",
"target=targets[:,30]\n",
"print(target.shape)\n",
"evaluateInput(input_sentence, target, encoder, searcher, dataloader_set_train.dataset.voc)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from __future__ import print_function\n",
"import torch.utils.data as data\n",
"import os\n",
"import os.path\n",
"import shutil\n",
"import errno\n",
"import torch\n",
"import kaldi_io\n",
"import sys\n",
"import csv\n",
"import string\n",
"import re"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Default word token\n",
"PAD_token = 0\n",
"SOS_token = 1\n",
"EOS_token = 2\n",
"# Default char token\n",
"BLANK_token = 0\n",
"SPACE_token = 1\n",
"\n",
"class WSJ(data.Dataset):\n",
" \"\"\" 'English data WSJ'\n",
" Args:\n",
" subset_wsj (string): Root Directory of dataset were preprocessed WSJ data with kaldi was made. train_si84 and test_93\n",
" \"\"\"\n",
" def __init__(self,path_text,path_feat,voc,transform=None):#, target_transform=None, dev_mode=False):\n",
" self.path_text = os.path.abspath(path_text)\n",
" self.path_feat = os.path.abspath(path_feat)\n",
" self.transform = transform\n",
" self.REMOVE_WORD_LIST=[ \"!EXCLAMATION-POINT\",\"\\\"CLOSE-QUOTE\",\"\\\"DOUBLE-QUOTE\",\"\\\"END-OF-QUOTE\",\"\\\"END-QUOTE\",\"\\\"IN-QUOTES\",\"\\\"QUOTE\",\"\\\"UNQUOTE\",\"#SHARP-SIGN\",\n",
" \"%PERCENT\",\"&AMPERSAND\",\"'BOUT\",\"'CAUSE\", \"'COURSE\", \"'CUSE\", \"'EM\", \"'END-INNER-QUOTE\", \"'END-QUOTE\", \"'FRISCO\", \"'GAIN\", \"'INNER-QUOTE\",\n",
" \"'KAY\", \"'M\", \"'N\",\"'QUOTE\", \"'ROUND\", \"'S\", \"'SINGLE-QUOTE\", \"'TIL\", \"'TIS\", \"'TWAS\", \"(BEGIN-PARENS\", \"(IN-PARENTHESES\", \"(LEFT-PAREN\",\n",
" \"(OPEN-PARENTHESES\", \"(PAREN\", \"(PARENS\", \"(PARENTHESES\", \")CLOSE-PAREN\", \")CLOSE_PAREN\", \")CLOSE-PARENTHESES\", \")END-PAREN\", \")END-PARENS\", \")END-PARENTHESES\",\n",
" \")END-THE-PAREN\", \")PAREN\", \")PARENS\", \")RIGHT-PAREN\", \")UN-PARENTHESES\", \",COMMA\", \"-DASH\", \"-HYPHEN\" , \"...ELLIPSIS\", \".DECIMAL\", \".DOT\",\n",
" \".FULL-STOP\", \".PERIOD\", \".POINT\", \"/SLASH\", \":COLON\", \";SEMI-COLON\", \"<NOISE>\", \"<SPOKEN_NOISE>\", \"<UNK>\", \"?QUESTION-MARK\",\"{BRACE\",\n",
" \"{LEFT-BRACE\", \"{OPEN-BRACE\", \"}CLOSE-BRACE\", \"}RIGHT-BRACE\" ]\n",
" #self.target_transform = target_transform\n",
" self.target_dict = {}\n",
" self.target_dict_unsorted = {}\n",
" self.feat_dict = {}\n",
" self.voc = voc\n",
" if not self._check_exists():\n",
" raise RuntimeError('Dataset not found.' + ' You can check the directory given.')\n",
" \n",
" # Read Feats and labels with kaldi-io-python\n",
" # Read Targets\n",
" with open(self.path_text) as file:\n",
" textreader = file.readlines()\n",
" for row in textreader:\n",
" row = row.split('\\n')\n",
" row = row[0].split(' ')\n",
" list_words=self._get_words(row[1:])\n",
" if len(list_words)!= 0 :\n",
" self.target_dict_unsorted[row[0]]=list_words\n",
" #self.feat_dict = {key:torch.from_numpy(mat) for key,mat in kaldi_io.read_mat_scp(self.path_feat)}\n",
" for key, mat in kaldi_io.read_mat_scp(self.path_feat):\n",
" if str(key) in list(self.target_dict_unsorted.keys()):\n",
" self.feat_dict[key] = torch.from_numpy(mat)\n",
" self.target_dict[key] = self.target_dict_unsorted[key]\n",
" \n",
" # Check dict and on data => \n",
" # Shell command cat data/train_si284/text | awk '{$1=\"\"; print $0}' | grep -o . | sort | uniq -c\n",
" # [\"`\",\"~\",\"<\", \">\",\"_\", \"-\", \",\", \";\", \":\", \"!\", \"?\", \"/\", \".\", \"'\", '\"', \"(\", \")\", \"{\", \"}\",\"*\", \"&\"]\n",
" def _get_words(self,row):\n",
" list_words=[]\n",
" for word in row:\n",
" if word in self.REMOVE_WORD_LIST:\n",
" continue\n",
" word_corrected = re.sub(r\"\\(|\\)|\\*|\\.|~+|:|!|\\?|;|\\\"|&|-\",'',word)\n",
" word_corrected = re.sub(r\"`\",\"'\",word_corrected)\n",
" if len(word_corrected)!=0:\n",
" list_words.append(word_corrected)\n",
" return list_words\n",
" \n",
" def _check_exists(self):\n",
" return os.path.exists(self.path_feat) and os.path.exists(self.path_text)\n",
" \n",
" def _convertTarget2Id(self,sentence):\n",
" return torch.tensor([self.voc.word2index[word] for word in sentence],dtype=torch.long)\n",
" \n",
" def _convertTarget2Phn(self,sentence):\n",
" target=[]\n",
" for word in sentence:\n",
" for phn in self.voc.dict_phn[word]:\n",
" target.append(phn)\n",
" #target.append(SPACE_token)\n",
" return torch.tensor(target,dtype=torch.uint8)\n",
" def _convertTarget2Char(self,sentence):\n",
" target=[]\n",
" for word in sentence:\n",
" # Rajouter une fonction qui le fait automatiquement par la suite\n",
" if word == 'MR':\n",
" word = 'MISTER'\n",
" if word == 'MRS':\n",
" word = 'MISIS'\n",
" if word == 'MS':\n",
" word = 'MISS'\n",
" for c in word:\n",
" target.append(self.voc.char2index[c])\n",
" target.append(SPACE_token)\n",
" return torch.tensor(target[:-1],dtype=torch.uint8)\n",
" \n",
" def __getitem__(self,index):\n",
" \"\"\"\n",
" Args:\n",
" index (int): Index\n",
" char (bool): use character based\n",
" Returns:\n",
" tuple: (feat- 2 dimensions, target, targetchar) where the target is sequence of word index\n",
" and targetchar is sequence of char index\n",
" \"\"\"\n",
" # Explore further! \n",
" if self.transform:\n",
" self.target_dict= self.transform(self.target_dict)\n",
"\n",
" ###### Get key for the index\n",
" key = list(self.feat_dict.keys())[index]\n",
" return {'key':key,\n",
" 'feat':self.feat_dict[key], \n",
" 'target':self._convertTarget2Id(self.target_dict[key]),\n",
" 'targetchar':self._convertTarget2Char(self.target_dict[key])}#,\n",
" #'targetphn':self._convertTarget2Phn(self.target_dict[key])}\n",
" \n",
" def __len__(self):\n",
" return len(self.target_dict.items())\n",
" \n",
" def collate_fn(self,data):\n",
" \"\"\"Creates mini-batch tensors from the list of tuples (features, target).\n",
" \n",
" We should build custom collate_fn rather than using default collate_fn, \n",
" because merging target (including padding) is not supported in default.\n",
" Args:\n",
" data: list of tuple (features, target). \n",
" - features: torch tensor of shape (?, N) N dimensional features; variable length.\n",
" - target: torch tensor of shape (?); variable length.\n",
" Returns:\n",
" features: torch tensor of shape (batch_size, padded_length, N).\n",
" targets: torch tensor of shape (batch_size, padded_length).\n",
" Masks: list; valid length for each padded caption.\n",
" \"\"\"\n",
" # data is already sorted by feat-to-len from kaldi\n",
" features = []\n",
" targets = []\n",
" targets_char = []\n",
" #targets_phn = []\n",
" for d in data:\n",
" features.append(d['feat'])\n",
" targets.append(d['target'])\n",
" targets_char.append(d['targetchar'])\n",
" #targets_phn.append(d['targetphn'])\n",
" #features, targets, targets_char = [d['feat'] for d in data], [d['target'] for d in data], [d['targetchar'] for d in data]\n",
" \n",
" # Merge features (from matrix of 2D tensor to 3D tensor).\n",
" #features = torch.stack(features, 0)\n",
" # Lengths features and targets\n",
" lengths_features = torch.tensor([len(feature_tensor) for feature_tensor in features],dtype=torch.int32)\n",
" lengths_target = torch.tensor([len(target) for target in targets],dtype=torch.int32)\n",
" lengths_target_char = torch.tensor([len(target) for target in targets_char], dtype = torch.int32)\n",
" #lengths_target_phn = torch.tensor([len(target) for target in targets_phn], dtype = torch.int32)\n",
" \n",
" # Generate features\n",
" final_features = torch.zeros(max(lengths_features),len(features),features[0].size(1))\n",
" #final_features = torch.zeros(max_len,len(features),features[0].size(1))\n",
" for i, feature in enumerate(features):\n",
" final_features[:feature.size(0),i,:] = feature\n",
" \n",
" # Generate Labels\n",
" # torch.int32 for CTCloss function\n",
" # torch.long for Seq2seq function\n",
" # FOR WORD\n",
" final_targets = torch.zeros(max(lengths_target), len(targets),dtype=torch.long)\n",
" mask_targets= torch.ones(max(lengths_target), len(targets),dtype=torch.uint8)\n",
" for i, target in enumerate(targets):\n",
" end = lengths_target[i]\n",
" final_targets[:end,i] = target[:]\n",
" mask_targets[end:,i] = 0\n",
" \n",
" # FOR CHAR\n",
" final_targets_char = torch.zeros(max(lengths_target_char), len(targets_char),dtype=torch.int32)\n",
" mask_targets_char = torch.ones(max(lengths_target_char), len(targets_char),dtype=torch.uint8)\n",
" for i, target in enumerate(targets_char):\n",
" end = lengths_target_char[i]\n",
" final_targets_char[:end,i] = target[:]\n",
" mask_targets_char[end:,i] = 0\n",
" \n",
" \n",
" # FOR PHN\n",
" #final_targets_phn = torch.zeros(max(lengths_target_phn), len(targets_phn),dtype=torch.int32)\n",
" #mask_targets_phn = torch.ones(max(lengths_target_phn), len(targets_phn),dtype=torch.uint8)\n",
" #for i, target in enumerate(targets_phn):\n",
" # end = lengths_target_phn[i]\n",
" # final_targets_phn[:end,i] = target[:]\n",
" # mask_targets_phn[end:,i] = 0\n",
"\n",
" return final_features, lengths_features, final_targets, lengths_target , mask_targets, final_targets_char,lengths_target_char, mask_targets_char\n",
" #, final_targets_phn, lengths_target_phn, mask_targets_phn\n",
"\n",
"\n",
" def get_loader(self,batch_size, shuffle, num_workers):\n",
" \"\"\"Returns torch.utils.data.DataLoader for custom coco dataset.\"\"\"\n",
" \n",
" # Data loader for WSJ dataset\n",
" # This will return (features,lengths_features, targets, lengths_target) for each iteration.\n",
" # features: a tensor of shape (batch_size, padded_length_feature, N dim).\n",
" # lengths_features: a list indicating valid length for each feature matrix (T, N dim). length is (batch_size).\n",
" # targets: a tensor of shape (batch_size, padded_length_targets).\n",
" # lengths: a list indicating valid length for each target. length is (batch_size).\n",
" data_loader = torch.utils.data.DataLoader(dataset=self, \n",
" batch_size=batch_size,\n",
" shuffle=shuffle,\n",
" num_workers=num_workers,\n",
" collate_fn=self.collate_fn)\n",
" return data_loader"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Voc:\n",
" def __init__(self,name,dict_file,phn2ID_file):\n",
" self.name = name\n",
" self.REMOVE_WORD_LIST=[ \"!EXCLAMATION-POINT\",\"\\\"CLOSE-QUOTE\",\"\\\"DOUBLE-QUOTE\",\"\\\"END-OF-QUOTE\",\"\\\"END-QUOTE\",\"\\\"IN-QUOTES\",\"\\\"QUOTE\",\"\\\"UNQUOTE\",\"#SHARP-SIGN\",\n",
" \"%PERCENT\",\"&AMPERSAND\",\"'BOUT\",\"'CAUSE\", \"'COURSE\", \"'CUSE\", \"'EM\", \"'END-INNER-QUOTE\", \"'END-QUOTE\", \"'FRISCO\", \"'GAIN\", \"'INNER-QUOTE\",\n",
" \"'KAY\", \"'M\", \"'N\",\"'QUOTE\", \"'ROUND\", \"'S\", \"'SINGLE-QUOTE\", \"'TIL\", \"'TIS\", \"'TWAS\", \"(BEGIN-PARENS\", \"(IN-PARENTHESES\", \"(LEFT-PAREN\",\n",
" \"(OPEN-PARENTHESES\", \"(PAREN\", \"(PARENS\", \"(PARENTHESES\", \")CLOSE-PAREN\", \")CLOSE_PAREN\", \")CLOSE-PARENTHESES\", \")END-PAREN\", \")END-PARENS\", \")END-PARENTHESES\",\n",
" \")END-THE-PAREN\", \")PAREN\", \")PARENS\", \")RIGHT-PAREN\", \")UN-PARENTHESES\", \",COMMA\", \"-DASH\", \"-HYPHEN\" , \"...ELLIPSIS\", \".DECIMAL\", \".DOT\",\n",
" \".FULL-STOP\", \".PERIOD\", \".POINT\", \"/SLASH\", \":COLON\", \";SEMI-COLON\", \"<NOISE>\", \"<SPOKEN_NOISE>\", \"<UNK>\", \"?QUESTION-MARK\",\"{BRACE\",\n",
" \"{LEFT-BRACE\", \"{OPEN-BRACE\", \"}CLOSE-BRACE\", \"}RIGHT-BRACE\" ]\n",
" self.dict_phn, self.index2phn, self.phn2index = self.load_dict(os.path.abspath(dict_file),os.path.abspath(phn2ID_file))\n",
" self.trimmed = False\n",
" # Word list\n",
" self.word2index = {\"PAD\":PAD_token, \"SOS\":SOS_token, \"EOS\":EOS_token}\n",
" self.word2count = {\"PAD\":1, \"SOS\":1, \"EOS\":1}\n",
" self.index2word = {PAD_token:\"PAD\", SOS_token:\"SOS\", EOS_token:\"EOS\"}\n",
" self.num_words = 3 # Count PAD, SOS, EOS\n",
" # Char list\n",
" self.char2index = {\"_\" : BLANK_token, \" \":SPACE_token}\n",
" self.char2count = {}\n",
" self.index2char = {BLANK_token:\"_\", SPACE_token:\" \"}\n",
" self.num_chars = 2 # Count Blank & Space\n",
" \n",
" def load_dict(self,dict_file,phn2ID_file):\n",
" # ID 0 => Blank\n",
" # Without space\n",
" idx = 2\n",
" id2phn = {0:'_'}\n",
" phn2id = {'_':0}\n",
" dict_phn = {}\n",
" ####### MAP word to ID phonemes\n",
" dict_phn={}\n",
" with open(dict_file) as file:\n",
" textreader = file.readlines()\n",
" for row in textreader:\n",
" row = row.split('\\n')\n",
" row = row[0].split(' ')\n",
" if row[0] in self.REMOVE_WORD_LIST:\n",
" continue\n",
" dict_phn[row[0]]=[int(i)-idx for i in row[1:]]\n",
" ######## MAP ID to phonemes\n",
" with open(phn2ID_file) as file:\n",
" textreader = file.readlines()\n",
" for row in textreader[2:]:\n",
" row = row.split('\\n')[0].split(' ')\n",
" id2phn[int(row[1])-idx] = row[0]\n",
" phn2id[row[0]] = int(row[1])-idx\n",
" return dict_phn, id2phn, phn2id\n",
" \n",
" def addSentence(self,WSJ_subset):\n",
" # Build word2index matrix\n",
" for k,sentence in WSJ_subset.target_dict.items():\n",
" for word in sentence:\n",
" self._addWord(word)\n",
" \n",
" def _addWord(self,word):\n",
" if word not in self.word2index.keys():\n",
" self.word2index[word] = self.num_words\n",
" self.word2count[word] = 1\n",
" self.index2word[self.num_words] = word\n",
" self.num_words += 1\n",
" else:\n",
" self.word2count[word] += 1\n",
" for c in word:\n",
" if c not in self.char2index.keys():\n",
" self.char2index[c] = self.num_chars\n",
" self.char2count[c] = 1\n",
" self.index2char[self.num_chars] = c\n",
" self.num_chars += 1\n",
" else:\n",
" self.char2count[c] += 1\n",
"\n",
" # Remove words below a certain count threshold\n",
" def trim(self, min_count):\n",