[tune] PBT + MemNN example (#5723)

* Add example file

* Move into train function

* Somewhat working example of MemNN, still has some failed trials

* Reorganize into a class

* Small fixes

* Iteration decrease and fix hyperparam_mutations

* Some style edits

* Address PR changes without modifying learning rate

* Add configs and hyperparameter mutations

* Add tune test

* Modify import locations

* Some parameter changes for testing

* Update memnn example

* Add tensorboard support and address PR comment

* Final changes

* lint

* generator
Anthony Yu 2019-10-05 09:22:37 -07:00 committed by Richard Liaw
parent fb33160df8
commit b99cdf4e39
4 changed files with 291 additions and 0 deletions


@ -125,6 +125,10 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE}
python /ray/python/ray/tune/examples/skopt_example.py \
--smoke-test
$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
    python /ray/python/ray/tune/examples/pbt_memnn_example.py \
--smoke-test
# uncomment once statsmodels is updated.
# $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
# python /ray/python/ray/tune/examples/bohb_example.py \


@ -16,6 +16,7 @@ General Examples
- `pbt_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_example.py>`__: Example of using a Trainable class with PopulationBasedTraining scheduler.
- `pbt_ppo_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_ppo_example.py>`__: Example of optimizing a distributed RLlib algorithm (PPO) with the PopulationBasedTraining scheduler.
- `logging_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/logging_example.py>`__: Example of custom loggers and custom trial directory naming.
- `pbt_memnn_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_memnn_example.py>`__: Example of training a Memory NN on bAbI with Keras using PBT.
Search Algorithm Examples
-------------------------


@ -16,6 +16,7 @@ General Examples
- `pbt_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_example.py>`__: Example of using a Trainable class with PopulationBasedTraining scheduler.
- `pbt_ppo_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_ppo_example.py>`__: Example of optimizing a distributed RLlib algorithm (PPO) with the PopulationBasedTraining scheduler.
- `logging_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/logging_example.py>`__: Example of custom loggers and custom trial directory naming.
- `pbt_memnn_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_memnn_example.py>`__: Example of training a Memory NN on bAbI with Keras using PBT.
Search Algorithm Examples
-------------------------


@ -0,0 +1,285 @@
"""Example training a memory neural net on the bAbI dataset.
References Keras and is based off of https://keras.io/examples/babi_memnn/.
"""
from __future__ import print_function
from tensorflow.python.keras.models import Sequential, Model, load_model
from tensorflow.python.keras.layers.embeddings import Embedding
from tensorflow.python.keras.layers import (Input, Activation, Dense, Permute,
Dropout)
from tensorflow.python.keras.layers import add, dot, concatenate
from tensorflow.python.keras.layers import LSTM
from tensorflow.python.keras.optimizers import RMSprop
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.keras.preprocessing.sequence import pad_sequences
from ray.tune import Trainable
import argparse
import tarfile
import numpy as np
import re


def tokenize(sent):
"""Return the tokens of a sentence including punctuation.
>>> tokenize("Bob dropped the apple. Where is the apple?")
["Bob", "dropped", "the", "apple", ".", "Where", "is", "the", "apple", "?"]
"""
    # Split on runs of non-word characters, keeping the separators as tokens.
    return [x.strip() for x in re.split(r"(\W+)", sent) if x and x.strip()]


def parse_stories(lines, only_supporting=False):
    """Parse stories provided in the bAbI tasks format.
If only_supporting is true, only the sentences
that support the answer are kept.
"""
data = []
story = []
for line in lines:
line = line.decode("utf-8").strip()
nid, line = line.split(" ", 1)
nid = int(nid)
if nid == 1:
story = []
if "\t" in line:
q, a, supporting = line.split("\t")
q = tokenize(q)
if only_supporting:
# Only select the related substory
supporting = map(int, supporting.split())
substory = [story[i - 1] for i in supporting]
else:
# Provide all the substories
substory = [x for x in story if x]
data.append((substory, q, a))
story.append("")
else:
sent = tokenize(line)
story.append(sent)
return data


def get_stories(f, only_supporting=False, max_length=None):
    """Read a file-like object, retrieve the stories,
    and then convert the sentences of each story into a single token list.

    If max_length is supplied, any stories longer than max_length tokens
    are discarded.
    """
def flatten(data):
return sum(data, [])
data = parse_stories(f.readlines(), only_supporting=only_supporting)
data = [(flatten(story), q, answer) for story, q, answer in data
if not max_length or len(flatten(story)) < max_length]
return data


def vectorize_stories(word_idx, story_maxlen, query_maxlen, data):
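    """Convert tokenized stories into padded integer index arrays.

    Returns (inputs, queries, answers) with shapes
    (num_samples, story_maxlen), (num_samples, query_maxlen), and
    (num_samples,); index 0 is reserved for padding.
    """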
inputs, queries, answers = [], [], []
for story, query, answer in data:
inputs.append([word_idx[w] for w in story])
queries.append([word_idx[w] for w in query])
answers.append(word_idx[answer])
return (pad_sequences(inputs, maxlen=story_maxlen),
pad_sequences(queries, maxlen=query_maxlen), np.array(answers))


def read_data():
# Get the file
try:
path = get_file(
"babi-tasks-v1-2.tar.gz",
origin="https://s3.amazonaws.com/text-datasets/"
"babi_tasks_1-20_v1-2.tar.gz")
except Exception:
print(
"Error downloading dataset, please download it manually:\n"
"$ wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2" # noqa: E501
".tar.gz\n"
"$ mv tasks_1-20_v1-2.tar.gz ~/.keras/datasets/babi-tasks-v1-2.tar.gz" # noqa: E501
)
raise
# Choose challenge
challenges = {
# QA1 with 10,000 samples
"single_supporting_fact_10k": "tasks_1-20_v1-2/en-10k/qa1_"
"single-supporting-fact_{}.txt",
# QA2 with 10,000 samples
"two_supporting_facts_10k": "tasks_1-20_v1-2/en-10k/qa2_"
"two-supporting-facts_{}.txt",
}
challenge_type = "single_supporting_fact_10k"
challenge = challenges[challenge_type]
with tarfile.open(path) as tar:
train_stories = get_stories(tar.extractfile(challenge.format("train")))
test_stories = get_stories(tar.extractfile(challenge.format("test")))
return train_stories, test_stories
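

# MemNNModel implements Ray Tune's Trainable API: Tune calls _setup once per
# trial, then calls _train once per training_iteration and records the
# returned metrics; _save/_restore provide the checkpointing that
# PopulationBasedTraining uses to clone promising trials.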
class MemNNModel(Trainable):
def build_model(self):
"""Helper method for creating the model"""
vocab = set()
for story, q, answer in self.train_stories + self.test_stories:
vocab |= set(story + q + [answer])
vocab = sorted(vocab)
# Reserve 0 for masking via pad_sequences
vocab_size = len(vocab) + 1
story_maxlen = max(
len(x) for x, _, _ in self.train_stories + self.test_stories)
query_maxlen = max(
len(x) for _, x, _ in self.train_stories + self.test_stories)
word_idx = {c: i + 1 for i, c in enumerate(vocab)}
self.inputs_train, self.queries_train, self.answers_train = (
vectorize_stories(word_idx, story_maxlen, query_maxlen,
self.train_stories))
self.inputs_test, self.queries_test, self.answers_test = (
vectorize_stories(word_idx, story_maxlen, query_maxlen,
self.test_stories))
# placeholders
input_sequence = Input((story_maxlen, ))
question = Input((query_maxlen, ))
# encoders
# embed the input sequence into a sequence of vectors
input_encoder_m = Sequential()
input_encoder_m.add(Embedding(input_dim=vocab_size, output_dim=64))
input_encoder_m.add(Dropout(self.config.get("dropout", 0.3)))
# output: (samples, story_maxlen, embedding_dim)
# embed the input into a sequence of vectors of size query_maxlen
input_encoder_c = Sequential()
input_encoder_c.add(
Embedding(input_dim=vocab_size, output_dim=query_maxlen))
input_encoder_c.add(Dropout(self.config.get("dropout", 0.3)))
# output: (samples, story_maxlen, query_maxlen)
# embed the question into a sequence of vectors
question_encoder = Sequential()
question_encoder.add(
Embedding(
input_dim=vocab_size, output_dim=64,
input_length=query_maxlen))
question_encoder.add(Dropout(self.config.get("dropout", 0.3)))
# output: (samples, query_maxlen, embedding_dim)
# encode input sequence and questions (which are indices)
# to sequences of dense vectors
input_encoded_m = input_encoder_m(input_sequence)
input_encoded_c = input_encoder_c(input_sequence)
question_encoded = question_encoder(question)
# compute a "match" between the first input vector sequence
# and the question vector sequence
# shape: `(samples, story_maxlen, query_maxlen)`
match = dot([input_encoded_m, question_encoded], axes=(2, 2))
match = Activation("softmax")(match)
# add the match matrix with the second input vector sequence
response = add(
[match, input_encoded_c]) # (samples, story_maxlen, query_maxlen)
response = Permute(
(2, 1))(response) # (samples, query_maxlen, story_maxlen)
# concatenate the match matrix with the question vector sequence
answer = concatenate([response, question_encoded])
# the original paper uses a matrix multiplication.
# we choose to use a RNN instead.
answer = LSTM(32)(answer) # (samples, 32)
# one regularization layer -- more would probably be needed.
answer = Dropout(self.config.get("dropout", 0.3))(answer)
answer = Dense(vocab_size)(answer) # (samples, vocab_size)
# we output a probability distribution over the vocabulary
answer = Activation("softmax")(answer)
# build the final model
model = Model([input_sequence, question], answer)
return model

    def _setup(self, config):
self.train_stories, self.test_stories = read_data()
model = self.build_model()
rmsprop = RMSprop(
lr=self.config.get("lr", 1e-3), rho=self.config.get("rho", 0.9))
model.compile(
optimizer=rmsprop,
loss="sparse_categorical_crossentropy",
metrics=["accuracy"])
self.model = model

    def _train(self):
# train
self.model.fit(
[self.inputs_train, self.queries_train],
self.answers_train,
batch_size=self.config.get("batch_size", 32),
epochs=self.config.get("epochs", 1),
validation_data=([self.inputs_test, self.queries_test],
self.answers_test),
verbose=0)
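        # Report accuracy re-evaluated on the training inputs; PBT ranks and
        # perturbs trials using this "mean_accuracy" value.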
_, accuracy = self.model.evaluate(
[self.inputs_train, self.queries_train],
self.answers_train,
verbose=0)
return {"mean_accuracy": accuracy}

    def _save(self, checkpoint_dir):
file_path = checkpoint_dir + "/model"
self.model.save(file_path)
return file_path

    def _restore(self, path):
# See https://stackoverflow.com/a/42763323
del self.model
self.model = load_model(path)


if __name__ == "__main__":
import ray
from ray.tune import Trainable, run
from ray.tune.schedulers import PopulationBasedTraining
parser = argparse.ArgumentParser()
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
args, _ = parser.parse_known_args()
ray.init()
pbt = PopulationBasedTraining(
time_attr="training_iteration",
metric="mean_accuracy",
mode="max",
perturbation_interval=5,
hyperparam_mutations={
"dropout": lambda: np.random.uniform(0, 1),
"lr": lambda: 10**np.random.randint(-10, 0),
"rho": lambda: np.random.uniform(0, 1)
})
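
    # Roughly: every perturbation_interval (5) iterations, lower-performing
    # trials clone the checkpoint and config of a better trial, then resample
    # or perturb dropout, lr, and rho using the lambdas above.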
results = run(
MemNNModel,
name="pbt_babi_memnn",
scheduler=pbt,
stop={"training_iteration": 20 if args.smoke_test else 100},
num_samples=4,
config={
"batch_size": 32,
"epochs": 1,
"dropout": 0.3,
"lr": 0.01,
"rho": 0.9
})
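
    # Results are written under ~/ray_results/pbt_babi_memnn by default; you
    # can monitor progress with TensorBoard, e.g.:
    #   tensorboard --logdir ~/ray_results/pbt_babi_memnn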