SOS_token = 0 EOS_token = 1 classLang: def__init__(self, name): self.name = name self.word2index = {} self.word2count = {} self.index2word = {0: "SOS", 1: "EOS"} self.n_words = 2# Count SOS and EOS
defaddSentence(self, sentence): for word in sentence.split(' '): self.addWord(word)
defaddWord(self, word): if word notin self.word2index: self.word2index[word] = self.n_words self.word2count[word] = 1 self.index2word[self.n_words] = word self.n_words += 1 else: self.word2count[word] += 1
def normalizeString(s):
    """Lowercase, trim, pad .!? with a space, and collapse everything else.

    If nothing but whitespace survives the ASCII filter (e.g. the input was
    entirely non-ASCII text), the original string is returned unchanged.
    """
    original = s
    cleaned = s.lower().strip()
    cleaned = re.sub(r"([.!?])", r" \1", cleaned)
    cleaned = re.sub(r"[^a-zA-Z.!?]+", r" ", cleaned)
    if cleaned.replace(' ', ''):
        return cleaned
    return original
def readLangs(lang1, lang2):
    """Read two parallel line-aligned files and build normalized pairs.

    `lang1` / `lang2` double as both file paths and Lang names.
    Returns (input_lang, output_lang, pairs) where each pair is
    [normalized source line, normalized target line].
    """
    print("Reading lines...")
    with open(lang1) as src:
        lines1 = src.readlines()
    with open(lang2) as tgt:
        lines2 = tgt.readlines()
    pairs = [
        [normalizeString(a.rstrip('\n')), normalizeString(b.rstrip('\n'))]
        for a, b in zip(lines1, lines2)
    ]
    input_lang = Lang(lang1)
    output_lang = Lang(lang2)
    return input_lang, output_lang, pairs
# Maximum sentence length (in space-separated tokens) kept for training.
MAX_LENGTH = 10


def filterPair(p):
    """True when both sides of pair `p` are shorter than MAX_LENGTH tokens."""
    src_len = len(p[0].split(' '))
    tgt_len = len(p[1].split(' '))
    return src_len < MAX_LENGTH and tgt_len < MAX_LENGTH


def filterPairs(pairs):
    """Keep only the pairs that pass filterPair."""
    return list(filter(filterPair, pairs))
# NOTE(review): fragment of the tutorial's train() step — the enclosing
# `def train(...)` header is not visible in this chunk, so
# use_teacher_forcing, target_length, loss, decoder_input, decoder_hidden,
# encoder_outputs, criterion and the optimizers are presumably its
# locals/arguments — confirm against the full file before editing.
if use_teacher_forcing:
    # Teacher forcing: Feed the target as the next input
    for di in range(target_length):
        decoder_output, decoder_hidden, decoder_attention = decoder(
            decoder_input, decoder_hidden, encoder_outputs)
        loss += criterion(decoder_output, target_tensor[di])
        decoder_input = target_tensor[di]  # Teacher forcing
else:
    # Without teacher forcing: use its own predictions as the next input
    for di in range(target_length):
        decoder_output, decoder_hidden, decoder_attention = decoder(
            decoder_input, decoder_hidden, encoder_outputs)
        topv, topi = decoder_output.topk(1)
        decoder_input = topi.squeeze().detach()  # detach from history as input

        loss += criterion(decoder_output, target_tensor[di])
        # Stop decoding as soon as the model emits the end-of-sentence token.
        if decoder_input.item() == EOS_token:
            break

# Backpropagate the loss accumulated over all decoded tokens.
loss.backward()

# Apply one update to both networks.
encoder_optimizer.step()
decoder_optimizer.step()

# Report the average loss per target token.
return loss.item() / target_length
import time
import math


def asMinutes(s):
    """Format a duration of `s` seconds as 'Xm Ys'."""
    minutes, seconds = divmod(s, 60)
    return '%dm %ds' % (minutes, seconds)
def timeSince(since, percent):
    """Return 'elapsed (- remaining)' given a start time and progress fraction.

    `since` is a time.time() timestamp; `percent` is the fraction of work
    completed (must be non-zero). Remaining time is a linear extrapolation.
    """
    elapsed = time.time() - since
    estimated_total = elapsed / percent
    remaining = estimated_total - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))
# NOTE(review): fragment of the tutorial's trainIters() loop — the enclosing
# `def trainIters(...)` header is not visible in this chunk; n_iters,
# learning_rate, print_every, plot_every, start, print_loss_total,
# plot_loss_total, plot_losses, tensorsFromPair and pairs are presumably
# defined by it or elsewhere in the file — confirm before editing.
encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
# Pre-sample one random training pair per iteration.
training_pairs = [tensorsFromPair(random.choice(pairs))
                  for i in range(n_iters)]
criterion = nn.NLLLoss()

for iter in range(1, n_iters + 1):
    training_pair = training_pairs[iter - 1]
    input_tensor = training_pair[0]
    target_tensor = training_pair[1]

    # One optimization step; `loss` is the per-token average for this pair.
    loss = train(input_tensor, target_tensor, encoder,
                 decoder, encoder_optimizer, decoder_optimizer, criterion)
    print_loss_total += loss
    plot_loss_total += loss

    # Periodic console progress report with a timing estimate.
    if iter % print_every == 0:
        print_loss_avg = print_loss_total / print_every
        print_loss_total = 0
        print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),
                                     iter, iter / n_iters * 100,
                                     print_loss_avg))

    # Periodic sample for the loss curve.
    if iter % plot_every == 0:
        plot_loss_avg = plot_loss_total / plot_every
        plot_losses.append(plot_loss_avg)
        plot_loss_total = 0

showPlot(plot_losses)
import matplotlib.pyplot as plt
plt.switch_backend('agg')  # headless backend: render without a display
import matplotlib.ticker as ticker
import numpy as np


def showPlot(points):
    """Plot the loss curve in `points` with a y-tick every 0.2.

    Fix: the original called `plt.figure()` immediately before
    `plt.subplots()`, creating and leaking an extra empty figure on every
    call; `subplots()` already creates its own figure.
    """
    fig, ax = plt.subplots()
    # This locator puts ticks at regular intervals.
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)