Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[tune] PBT + Memnn example (#5723)
* Add example file
* Move into train function
* Somewhat working example of MemNN, still has some failed trials
* Reorganize into a class
* Small fixes
* Iteration decrease and fix hyperparam_mutations
* Some style edits
* Address PR changes without modifying learning rate
* Add configs and hyperparameter mutations
* Add tune test
* Modify import locations
* Some parameter changes for testing
* Update memnn example
* Add tensorboard support and address PR comment
* Final changes
* lint
* generator
parent fb33160df8
commit b99cdf4e39

4 changed files with 291 additions and 0 deletions
@@ -125,6 +125,10 @@ $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE}
    python /ray/python/ray/tune/examples/skopt_example.py \
    --smoke-test

$SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
    python /ray/python/ray/tune/examples/memnn_example.py \
    --smoke-test

# uncomment once statsmodels is updated.
# $SUPPRESS_OUTPUT docker run --rm --shm-size=${SHM_SIZE} --memory=${MEMORY_SIZE} $DOCKER_SHA \
#     python /ray/python/ray/tune/examples/bohb_example.py \
@@ -16,6 +16,7 @@ General Examples
- `pbt_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_example.py>`__: Example of using a Trainable class with PopulationBasedTraining scheduler.
- `pbt_ppo_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_ppo_example.py>`__: Example of optimizing a distributed RLlib algorithm (PPO) with the PopulationBasedTraining scheduler.
- `logging_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/logging_example.py>`__: Example of custom loggers and custom trial directory naming.
- `pbt_memnn_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_memnn_example.py>`__: Example of training a Memory NN on bAbI with Keras using PBT.

Search Algorithm Examples
-------------------------
@@ -16,6 +16,7 @@ General Examples
- `pbt_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_example.py>`__: Example of using a Trainable class with PopulationBasedTraining scheduler.
- `pbt_ppo_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_ppo_example.py>`__: Example of optimizing a distributed RLlib algorithm (PPO) with the PopulationBasedTraining scheduler.
- `logging_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/logging_example.py>`__: Example of custom loggers and custom trial directory naming.
- `pbt_memnn_example <https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_memnn_example.py>`__: Example of training a Memory NN on bAbI with Keras using PBT.

Search Algorithm Examples
-------------------------
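The PBT entries listed above all share one shape: a Trainable subclass that reports a metric from _train(), plus a PopulationBasedTraining scheduler that periodically clones good trials and mutates their hyperparameters. Before the full MemNN example below, here is a minimal sketch of that pattern using only the Tune API calls the new file itself relies on (Trainable with _setup/_train/_save/_restore, PopulationBasedTraining, run). The toy "accuracy" dynamics and the "lr" mutation range are made up purely for illustration and are not part of this commit.

import numpy as np

import ray
from ray.tune import Trainable, run
from ray.tune.schedulers import PopulationBasedTraining


class ToyTrainable(Trainable):
    """Hypothetical trainable: accuracy creeps toward 1.0 at a rate set
    by the mutable "lr" hyperparameter."""

    def _setup(self, config):
        self.acc = 0.0

    def _train(self):
        self.acc += (1.0 - self.acc) * self.config.get("lr", 0.01)
        return {"mean_accuracy": self.acc}

    # PBT exploits trials by checkpointing and restoring them, so even a
    # toy trainable needs _save/_restore.
    def _save(self, checkpoint_dir):
        path = checkpoint_dir + "/acc"
        with open(path, "w") as f:
            f.write(str(self.acc))
        return path

    def _restore(self, path):
        with open(path) as f:
            self.acc = float(f.read())


if __name__ == "__main__":
    ray.init()
    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="mean_accuracy",
        mode="max",
        perturbation_interval=5,
        hyperparam_mutations={"lr": lambda: 10**np.random.randint(-5, 0)})
    run(ToyTrainable,
        scheduler=pbt,
        num_samples=4,
        stop={"training_iteration": 20},
        config={"lr": 0.01})

The MemNN example added by this commit follows exactly this structure, with the Keras model construction and bAbI data loading filling in _setup and _train.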
python/ray/tune/examples/pbt_memnn_example.py (new file, 285 lines)
@@ -0,0 +1,285 @@
"""Example training a memory neural net on the bAbI dataset.

References Keras and is based off of https://keras.io/examples/babi_memnn/.
"""

from __future__ import print_function

from tensorflow.python.keras.models import Sequential, Model, load_model
from tensorflow.python.keras.layers.embeddings import Embedding
from tensorflow.python.keras.layers import (Input, Activation, Dense, Permute,
                                            Dropout)
from tensorflow.python.keras.layers import add, dot, concatenate
from tensorflow.python.keras.layers import LSTM
from tensorflow.python.keras.optimizers import RMSprop
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.keras.preprocessing.sequence import pad_sequences
from ray.tune import Trainable
import argparse
import tarfile
import numpy as np
import re


def tokenize(sent):
    """Return the tokens of a sentence including punctuation.

    >>> tokenize("Bob dropped the apple. Where is the apple?")
    ["Bob", "dropped", "the", "apple", ".", "Where", "is", "the", "apple", "?"]
    """
    return [x.strip() for x in re.split(r"(\W+)?", sent) if x and x.strip()]


def parse_stories(lines, only_supporting=False):
    """Parse stories provided in the bAbi tasks format

    If only_supporting is true, only the sentences
    that support the answer are kept.
    """
    data = []
    story = []
    for line in lines:
        line = line.decode("utf-8").strip()
        nid, line = line.split(" ", 1)
        nid = int(nid)
        if nid == 1:
            story = []
        if "\t" in line:
            q, a, supporting = line.split("\t")
            q = tokenize(q)
            if only_supporting:
                # Only select the related substory
                supporting = map(int, supporting.split())
                substory = [story[i - 1] for i in supporting]
            else:
                # Provide all the substories
                substory = [x for x in story if x]
            data.append((substory, q, a))
            story.append("")
        else:
            sent = tokenize(line)
            story.append(sent)
    return data


def get_stories(f, only_supporting=False, max_length=None):
    """Given a file name, read the file,
    retrieve the stories,
    and then convert the sentences into a single story.

    If max_length is supplied,
    any stories longer than max_length tokens will be discarded.
    """

    def flatten(data):
        return sum(data, [])

    data = parse_stories(f.readlines(), only_supporting=only_supporting)
    data = [(flatten(story), q, answer) for story, q, answer in data
            if not max_length or len(flatten(story)) < max_length]
    return data


def vectorize_stories(word_idx, story_maxlen, query_maxlen, data):
    inputs, queries, answers = [], [], []
    for story, query, answer in data:
        inputs.append([word_idx[w] for w in story])
        queries.append([word_idx[w] for w in query])
        answers.append(word_idx[answer])
    return (pad_sequences(inputs, maxlen=story_maxlen),
            pad_sequences(queries, maxlen=query_maxlen), np.array(answers))


def read_data():
    # Get the file
    try:
        path = get_file(
            "babi-tasks-v1-2.tar.gz",
            origin="https://s3.amazonaws.com/text-datasets/"
            "babi_tasks_1-20_v1-2.tar.gz")
    except Exception:
        print(
            "Error downloading dataset, please download it manually:\n"
            "$ wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2"  # noqa: E501
            ".tar.gz\n"
            "$ mv tasks_1-20_v1-2.tar.gz ~/.keras/datasets/babi-tasks-v1-2.tar.gz"  # noqa: E501
        )
        raise

    # Choose challenge
    challenges = {
        # QA1 with 10,000 samples
        "single_supporting_fact_10k": "tasks_1-20_v1-2/en-10k/qa1_"
        "single-supporting-fact_{}.txt",
        # QA2 with 10,000 samples
        "two_supporting_facts_10k": "tasks_1-20_v1-2/en-10k/qa2_"
        "two-supporting-facts_{}.txt",
    }
    challenge_type = "single_supporting_fact_10k"
    challenge = challenges[challenge_type]

    with tarfile.open(path) as tar:
        train_stories = get_stories(tar.extractfile(challenge.format("train")))
        test_stories = get_stories(tar.extractfile(challenge.format("test")))

    return train_stories, test_stories


class MemNNModel(Trainable):
    def build_model(self):
        """Helper method for creating the model"""
        vocab = set()
        for story, q, answer in self.train_stories + self.test_stories:
            vocab |= set(story + q + [answer])
        vocab = sorted(vocab)

        # Reserve 0 for masking via pad_sequences
        vocab_size = len(vocab) + 1
        story_maxlen = max(
            len(x) for x, _, _ in self.train_stories + self.test_stories)
        query_maxlen = max(
            len(x) for _, x, _ in self.train_stories + self.test_stories)

        word_idx = {c: i + 1 for i, c in enumerate(vocab)}
        self.inputs_train, self.queries_train, self.answers_train = (
            vectorize_stories(word_idx, story_maxlen, query_maxlen,
                              self.train_stories))
        self.inputs_test, self.queries_test, self.answers_test = (
            vectorize_stories(word_idx, story_maxlen, query_maxlen,
                              self.test_stories))

        # placeholders
        input_sequence = Input((story_maxlen, ))
        question = Input((query_maxlen, ))

        # encoders
        # embed the input sequence into a sequence of vectors
        input_encoder_m = Sequential()
        input_encoder_m.add(Embedding(input_dim=vocab_size, output_dim=64))
        input_encoder_m.add(Dropout(self.config.get("dropout", 0.3)))
        # output: (samples, story_maxlen, embedding_dim)

        # embed the input into a sequence of vectors of size query_maxlen
        input_encoder_c = Sequential()
        input_encoder_c.add(
            Embedding(input_dim=vocab_size, output_dim=query_maxlen))
        input_encoder_c.add(Dropout(self.config.get("dropout", 0.3)))
        # output: (samples, story_maxlen, query_maxlen)

        # embed the question into a sequence of vectors
        question_encoder = Sequential()
        question_encoder.add(
            Embedding(
                input_dim=vocab_size, output_dim=64,
                input_length=query_maxlen))
        question_encoder.add(Dropout(self.config.get("dropout", 0.3)))
        # output: (samples, query_maxlen, embedding_dim)

        # encode input sequence and questions (which are indices)
        # to sequences of dense vectors
        input_encoded_m = input_encoder_m(input_sequence)
        input_encoded_c = input_encoder_c(input_sequence)
        question_encoded = question_encoder(question)

        # compute a "match" between the first input vector sequence
        # and the question vector sequence
        # shape: `(samples, story_maxlen, query_maxlen)`
        match = dot([input_encoded_m, question_encoded], axes=(2, 2))
        match = Activation("softmax")(match)

        # add the match matrix with the second input vector sequence
        response = add(
            [match, input_encoded_c])  # (samples, story_maxlen, query_maxlen)
        response = Permute(
            (2, 1))(response)  # (samples, query_maxlen, story_maxlen)

        # concatenate the match matrix with the question vector sequence
        answer = concatenate([response, question_encoded])

        # the original paper uses a matrix multiplication.
        # we choose to use a RNN instead.
        answer = LSTM(32)(answer)  # (samples, 32)

        # one regularization layer -- more would probably be needed.
        answer = Dropout(self.config.get("dropout", 0.3))(answer)
        answer = Dense(vocab_size)(answer)  # (samples, vocab_size)
        # we output a probability distribution over the vocabulary
        answer = Activation("softmax")(answer)

        # build the final model
        model = Model([input_sequence, question], answer)
        return model

    def _setup(self, config):
        self.train_stories, self.test_stories = read_data()
        model = self.build_model()
        rmsprop = RMSprop(
            lr=self.config.get("lr", 1e-3), rho=self.config.get("rho", 0.9))
        model.compile(
            optimizer=rmsprop,
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"])
        self.model = model

    def _train(self):
        # train
        self.model.fit(
            [self.inputs_train, self.queries_train],
            self.answers_train,
            batch_size=self.config.get("batch_size", 32),
            epochs=self.config.get("epochs", 1),
            validation_data=([self.inputs_test, self.queries_test],
                             self.answers_test),
            verbose=0)
        _, accuracy = self.model.evaluate(
            [self.inputs_train, self.queries_train],
            self.answers_train,
            verbose=0)
        return {"mean_accuracy": accuracy}

    def _save(self, checkpoint_dir):
        file_path = checkpoint_dir + "/model"
        self.model.save(file_path)
        return file_path

    def _restore(self, path):
        # See https://stackoverflow.com/a/42763323
        del self.model
        self.model = load_model(path)


if __name__ == "__main__":
    import ray
    from ray.tune import Trainable, run
    from ray.tune.schedulers import PopulationBasedTraining

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    args, _ = parser.parse_known_args()
    ray.init()

    pbt = PopulationBasedTraining(
        time_attr="training_iteration",
        metric="mean_accuracy",
        mode="max",
        perturbation_interval=5,
        hyperparam_mutations={
            "dropout": lambda: np.random.uniform(0, 1),
            "lr": lambda: 10**np.random.randint(-10, 0),
            "rho": lambda: np.random.uniform(0, 1)
        })

    results = run(
        MemNNModel,
        name="pbt_babi_memnn",
        scheduler=pbt,
        stop={"training_iteration": 20 if args.smoke_test else 100},
        num_samples=4,
        config={
            "batch_size": 32,
            "epochs": 1,
            "dropout": 0.3,
            "lr": 0.01,
            "rho": 0.9
        })
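Not part of the commit, but as a rough sketch of how a checkpoint written by _save above could be inspected after a run: reload the saved Keras model with load_model and re-score it on the bAbI test split using the helpers defined in this file. The checkpoint path is a placeholder; actual trial directories are typically found under Tune's results directory (e.g. ~/ray_results/pbt_babi_memnn/...), and this assumes it is executed alongside the example so read_data() and vectorize_stories() are in scope.

# Rough post-hoc evaluation sketch; paths and scoping are assumptions.
from tensorflow.python.keras.models import load_model

train_stories, test_stories = read_data()

vocab = set()
for story, q, answer in train_stories + test_stories:
    vocab |= set(story + q + [answer])
vocab = sorted(vocab)
word_idx = {c: i + 1 for i, c in enumerate(vocab)}
story_maxlen = max(len(x) for x, _, _ in train_stories + test_stories)
query_maxlen = max(len(x) for _, x, _ in train_stories + test_stories)

inputs_test, queries_test, answers_test = vectorize_stories(
    word_idx, story_maxlen, query_maxlen, test_stories)

# Placeholder path: point this at the "model" file that _save() wrote for
# whichever trial you want to inspect.
model = load_model("/path/to/trial_dir/checkpoint_dir/model")
_, accuracy = model.evaluate(
    [inputs_test, queries_test], answers_test, verbose=0)
print("test accuracy:", accuracy)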