From b99cdf4e391ccb858ffa8c772e144dc32d2569c6 Mon Sep 17 00:00:00 2001
From: Anthony Yu <13611707+anthonyhsyu@users.noreply.github.com>
Date: Sat, 5 Oct 2019 09:22:37 -0700
Subject: [PATCH] [tune] PBT + Memnn example (#5723)

* Add example file
* Move into train function
* Somewhat working example of MemNN, still has some failed trials
* Reorganize into a class
* Small fixes
* Iteration decrease and fix hyperparam_mutations
* Add example file
* Move into train function
* Somewhat working example of MemNN, still has some failed trials
* Reorganize into a class
* Small fixes
* Iteration decrease and fix hyperparam_mutations
* Some style edits
* Address PR changes without modifying learning rate
* Add configs and hyperparameter mutations
* Add tune test
* Modify import locations
* Some parameter changes for testing
* Update memnn example
* Add tensorboard support and address PR comment
* Final changes
* lint
* generator
---
 ci/jenkins_tests/run_tune_tests.sh            |   4 +
 doc/source/tune-examples.rst                  |   1 +
 python/ray/tune/examples/README.rst           |   1 +
 python/ray/tune/examples/pbt_memnn_example.py | 285 ++++++++++++++++++
 4 files changed, 291 insertions(+)
 create mode 100644 python/ray/tune/examples/pbt_memnn_example.py

diff --git a/ci/jenkins_tests/run_tune_tests.sh b/ci/jenkins_tests/run_tune_tests.sh
index a7899898d..e03ef1d49 100755
--- a/ci/jenkins_tests/run_tune_tests.sh
+++ b/ci/jenkins_tests/run_tune_tests.sh
@@ -125,6 +125,10 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE}
     python /ray/python/ray/tune/examples/skopt_example.py \
     --smoke-test
 
+$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
+    python /ray/python/ray/tune/examples/pbt_memnn_example.py \
+    --smoke-test
+
 # uncomment once statsmodels is updated.
 # $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
 #     python /ray/python/ray/tune/examples/bohb_example.py \
diff --git a/doc/source/tune-examples.rst b/doc/source/tune-examples.rst
index cb0a4c9d0..5fcd7035a 100644
--- a/doc/source/tune-examples.rst
+++ b/doc/source/tune-examples.rst
@@ -16,6 +16,7 @@ General Examples
 - `pbt_example `__: Example of using a Trainable class with PopulationBasedTraining scheduler.
 - `pbt_ppo_example `__: Example of optimizing a distributed RLlib algorithm (PPO) with the PopulationBasedTraining scheduler.
 - `logging_example `__: Example of custom loggers and custom trial directory naming.
+- `pbt_memnn_example `__: Example of training a Memory NN on bAbI with Keras using PBT.
 
 Search Algorithm Examples
 -------------------------
diff --git a/python/ray/tune/examples/README.rst b/python/ray/tune/examples/README.rst
index d2f48cba0..68b214cdb 100644
--- a/python/ray/tune/examples/README.rst
+++ b/python/ray/tune/examples/README.rst
@@ -16,6 +16,7 @@ General Examples
 - `pbt_example `__: Example of using a Trainable class with PopulationBasedTraining scheduler.
 - `pbt_ppo_example `__: Example of optimizing a distributed RLlib algorithm (PPO) with the PopulationBasedTraining scheduler.
 - `logging_example `__: Example of custom loggers and custom trial directory naming.
+- `pbt_memnn_example `__: Example of training a Memory NN on bAbI with Keras using PBT.
 
 Search Algorithm Examples
 -------------------------
diff --git a/python/ray/tune/examples/pbt_memnn_example.py b/python/ray/tune/examples/pbt_memnn_example.py
new file mode 100644
index 000000000..c9ab4cc1c
--- /dev/null
+++ b/python/ray/tune/examples/pbt_memnn_example.py
@@ -0,0 +1,285 @@
+"""Example training a memory neural net on the bAbI dataset.
+
+Uses Keras and is based on https://keras.io/examples/babi_memnn/.
+"""
+
+from __future__ import print_function
+
+from tensorflow.python.keras.models import Sequential, Model, load_model
+from tensorflow.python.keras.layers.embeddings import Embedding
+from tensorflow.python.keras.layers import (Input, Activation, Dense, Permute,
+                                            Dropout)
+from tensorflow.python.keras.layers import add, dot, concatenate
+from tensorflow.python.keras.layers import LSTM
+from tensorflow.python.keras.optimizers import RMSprop
+from tensorflow.python.keras.utils.data_utils import get_file
+from tensorflow.python.keras.preprocessing.sequence import pad_sequences
+from ray.tune import Trainable
+import argparse
+import tarfile
+import numpy as np
+import re
+
+
+def tokenize(sent):
+    """Return the tokens of a sentence including punctuation.
+
+    >>> tokenize("Bob dropped the apple. Where is the apple?")
+    ["Bob", "dropped", "the", "apple", ".", "Where", "is", "the", "apple", "?"]
+    """
+    # Use a non-optional group: an optional group can produce zero-width
+    # matches, which re.split treats differently on Python 3.7+ and would
+    # split the sentence into single characters.
+    return [x.strip() for x in re.split(r"(\W+)", sent) if x and x.strip()]
+
+
+def parse_stories(lines, only_supporting=False):
+    """Parse stories provided in the bAbI tasks format.
+
+    If only_supporting is true, only the sentences
+    that support the answer are kept.
+    """
+    data = []
+    story = []
+    for line in lines:
+        line = line.decode("utf-8").strip()
+        nid, line = line.split(" ", 1)
+        nid = int(nid)
+        if nid == 1:
+            story = []
+        if "\t" in line:
+            q, a, supporting = line.split("\t")
+            q = tokenize(q)
+            if only_supporting:
+                # Only select the related substory
+                supporting = map(int, supporting.split())
+                substory = [story[i - 1] for i in supporting]
+            else:
+                # Provide all the substories
+                substory = [x for x in story if x]
+            data.append((substory, q, a))
+            story.append("")
+        else:
+            sent = tokenize(line)
+            story.append(sent)
+    return data
+
+
+def get_stories(f, only_supporting=False, max_length=None):
+    """Given a file name, read the file,
+    retrieve the stories,
+    and then convert the sentences into a single story.
+
+    If max_length is supplied,
+    any stories longer than max_length tokens will be discarded.
+ """ + + def flatten(data): + return sum(data, []) + + data = parse_stories(f.readlines(), only_supporting=only_supporting) + data = [(flatten(story), q, answer) for story, q, answer in data + if not max_length or len(flatten(story)) < max_length] + return data + + +def vectorize_stories(word_idx, story_maxlen, query_maxlen, data): + inputs, queries, answers = [], [], [] + for story, query, answer in data: + inputs.append([word_idx[w] for w in story]) + queries.append([word_idx[w] for w in query]) + answers.append(word_idx[answer]) + return (pad_sequences(inputs, maxlen=story_maxlen), + pad_sequences(queries, maxlen=query_maxlen), np.array(answers)) + + +def read_data(): + # Get the file + try: + path = get_file( + "babi-tasks-v1-2.tar.gz", + origin="https://s3.amazonaws.com/text-datasets/" + "babi_tasks_1-20_v1-2.tar.gz") + except Exception: + print( + "Error downloading dataset, please download it manually:\n" + "$ wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2" # noqa: E501 + ".tar.gz\n" + "$ mv tasks_1-20_v1-2.tar.gz ~/.keras/datasets/babi-tasks-v1-2.tar.gz" # noqa: E501 + ) + raise + + # Choose challenge + challenges = { + # QA1 with 10,000 samples + "single_supporting_fact_10k": "tasks_1-20_v1-2/en-10k/qa1_" + "single-supporting-fact_{}.txt", + # QA2 with 10,000 samples + "two_supporting_facts_10k": "tasks_1-20_v1-2/en-10k/qa2_" + "two-supporting-facts_{}.txt", + } + challenge_type = "single_supporting_fact_10k" + challenge = challenges[challenge_type] + + with tarfile.open(path) as tar: + train_stories = get_stories(tar.extractfile(challenge.format("train"))) + test_stories = get_stories(tar.extractfile(challenge.format("test"))) + + return train_stories, test_stories + + +class MemNNModel(Trainable): + def build_model(self): + """Helper method for creating the model""" + vocab = set() + for story, q, answer in self.train_stories + self.test_stories: + vocab |= set(story + q + [answer]) + vocab = sorted(vocab) + + # Reserve 0 for masking via pad_sequences + vocab_size = len(vocab) + 1 + story_maxlen = max( + len(x) for x, _, _ in self.train_stories + self.test_stories) + query_maxlen = max( + len(x) for _, x, _ in self.train_stories + self.test_stories) + + word_idx = {c: i + 1 for i, c in enumerate(vocab)} + self.inputs_train, self.queries_train, self.answers_train = ( + vectorize_stories(word_idx, story_maxlen, query_maxlen, + self.train_stories)) + self.inputs_test, self.queries_test, self.answers_test = ( + vectorize_stories(word_idx, story_maxlen, query_maxlen, + self.test_stories)) + + # placeholders + input_sequence = Input((story_maxlen, )) + question = Input((query_maxlen, )) + + # encoders + # embed the input sequence into a sequence of vectors + input_encoder_m = Sequential() + input_encoder_m.add(Embedding(input_dim=vocab_size, output_dim=64)) + input_encoder_m.add(Dropout(self.config.get("dropout", 0.3))) + # output: (samples, story_maxlen, embedding_dim) + + # embed the input into a sequence of vectors of size query_maxlen + input_encoder_c = Sequential() + input_encoder_c.add( + Embedding(input_dim=vocab_size, output_dim=query_maxlen)) + input_encoder_c.add(Dropout(self.config.get("dropout", 0.3))) + # output: (samples, story_maxlen, query_maxlen) + + # embed the question into a sequence of vectors + question_encoder = Sequential() + question_encoder.add( + Embedding( + input_dim=vocab_size, output_dim=64, + input_length=query_maxlen)) + question_encoder.add(Dropout(self.config.get("dropout", 0.3))) + # output: (samples, query_maxlen, 
embedding_dim) + + # encode input sequence and questions (which are indices) + # to sequences of dense vectors + input_encoded_m = input_encoder_m(input_sequence) + input_encoded_c = input_encoder_c(input_sequence) + question_encoded = question_encoder(question) + + # compute a "match" between the first input vector sequence + # and the question vector sequence + # shape: `(samples, story_maxlen, query_maxlen)` + match = dot([input_encoded_m, question_encoded], axes=(2, 2)) + match = Activation("softmax")(match) + + # add the match matrix with the second input vector sequence + response = add( + [match, input_encoded_c]) # (samples, story_maxlen, query_maxlen) + response = Permute( + (2, 1))(response) # (samples, query_maxlen, story_maxlen) + + # concatenate the match matrix with the question vector sequence + answer = concatenate([response, question_encoded]) + + # the original paper uses a matrix multiplication. + # we choose to use a RNN instead. + answer = LSTM(32)(answer) # (samples, 32) + + # one regularization layer -- more would probably be needed. + answer = Dropout(self.config.get("dropout", 0.3))(answer) + answer = Dense(vocab_size)(answer) # (samples, vocab_size) + # we output a probability distribution over the vocabulary + answer = Activation("softmax")(answer) + + # build the final model + model = Model([input_sequence, question], answer) + return model + + def _setup(self, config): + self.train_stories, self.test_stories = read_data() + model = self.build_model() + rmsprop = RMSprop( + lr=self.config.get("lr", 1e-3), rho=self.config.get("rho", 0.9)) + model.compile( + optimizer=rmsprop, + loss="sparse_categorical_crossentropy", + metrics=["accuracy"]) + self.model = model + + def _train(self): + # train + self.model.fit( + [self.inputs_train, self.queries_train], + self.answers_train, + batch_size=self.config.get("batch_size", 32), + epochs=self.config.get("epochs", 1), + validation_data=([self.inputs_test, self.queries_test], + self.answers_test), + verbose=0) + _, accuracy = self.model.evaluate( + [self.inputs_train, self.queries_train], + self.answers_train, + verbose=0) + return {"mean_accuracy": accuracy} + + def _save(self, checkpoint_dir): + file_path = checkpoint_dir + "/model" + self.model.save(file_path) + return file_path + + def _restore(self, path): + # See https://stackoverflow.com/a/42763323 + del self.model + self.model = load_model(path) + + +if __name__ == "__main__": + import ray + from ray.tune import Trainable, run + from ray.tune.schedulers import PopulationBasedTraining + + parser = argparse.ArgumentParser() + parser.add_argument( + "--smoke-test", action="store_true", help="Finish quickly for testing") + args, _ = parser.parse_known_args() + ray.init() + + pbt = PopulationBasedTraining( + time_attr="training_iteration", + metric="mean_accuracy", + mode="max", + perturbation_interval=5, + hyperparam_mutations={ + "dropout": lambda: np.random.uniform(0, 1), + "lr": lambda: 10**np.random.randint(-10, 0), + "rho": lambda: np.random.uniform(0, 1) + }) + + results = run( + MemNNModel, + name="pbt_babi_memnn", + scheduler=pbt, + stop={"training_iteration": 20 if args.smoke_test else 100}, + num_samples=4, + config={ + "batch_size": 32, + "epochs": 1, + "dropout": 0.3, + "lr": 0.01, + "rho": 0.9 + })
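
Note (not part of the patch): the example above couples the MemNN model to Tune's class-based Trainable API and the PopulationBasedTraining scheduler. The sketch below strips that same pattern down to a self-contained toy Trainable so the PBT wiring (_setup/_train/_save/_restore plus hyperparam_mutations) can be exercised in seconds without TensorFlow or the bAbI download. The class name DummyTrainable and its fake "score" objective are illustrative assumptions, not part of Ray; the Tune calls mirror the ones used in pbt_memnn_example.py.

# Minimal sketch of the Trainable + PopulationBasedTraining pattern used
# by pbt_memnn_example.py, with a toy objective standing in for the MemNN.
import numpy as np
import ray
from ray.tune import Trainable, run
from ray.tune.schedulers import PopulationBasedTraining


class DummyTrainable(Trainable):
    def _setup(self, config):
        # Stand-in for read_data() + build_model() in the real example.
        self.score = 0.0

    def _train(self):
        # A noisy objective that improves faster for larger "lr", giving
        # PBT something to exploit and explore between trials.
        self.score += self.config.get("lr", 0.01) * np.random.rand()
        return {"mean_accuracy": self.score}

    def _save(self, checkpoint_dir):
        # Checkpointing is what lets PBT clone well-performing trials.
        path = checkpoint_dir + "/score"
        with open(path, "w") as f:
            f.write(str(self.score))
        return path

    def _restore(self, path):
        with open(path) as f:
            self.score = float(f.read())


if __name__ == "__main__":
    ray.init()
    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="mean_accuracy",
        mode="max",
        perturbation_interval=5,
        hyperparam_mutations={"lr": lambda: np.random.uniform(0.001, 0.1)})
    run(DummyTrainable,
        name="pbt_dummy",
        scheduler=pbt,
        stop={"training_iteration": 20},
        num_samples=4,
        config={"lr": 0.01})

The real example follows exactly this shape: the MemNN replaces the toy objective, "dropout", "lr", and "rho" are the mutated hyperparameters, and the --smoke-test flag shortens the run for CI.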