Cartpole MAML + Discrete (#11028)

This commit is contained in:
Michael Luo 2020-10-02 03:56:34 -07:00 committed by GitHub
parent 180c259702
commit 47b499d899
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 43 additions and 10 deletions

View file

@ -55,25 +55,29 @@ DEFAULT_CONFIG = with_common_config({
"maml_optimizer_steps": 5,
# Inner Adaptation Step size
"inner_lr": 0.1,
# Whether the env follows the meta-env template (sample_tasks/set_task).
# If False, per-worker task sampling is skipped.
"use_meta_env": True,
})
# __sphinx_doc_end__
# yapf: enable
# @mluo: TODO
def set_worker_tasks(workers):
    """Sample one task per remote worker and assign it to that worker's envs.

    Tasks are drawn by calling ``sample_tasks(n)`` on the first env of the
    local worker, where ``n`` is the number of remote workers. Task ``i`` is
    then pushed to remote worker ``i`` via ``env.set_task``.

    Args:
        workers: Worker set exposing ``local_worker()`` and
            ``remote_workers()``.
    """
    remote_workers = workers.remote_workers()
    n_tasks = len(remote_workers)
    tasks = workers.local_worker().foreach_env(lambda x: x)[0].sample_tasks(
        n_tasks)
    for i, worker in enumerate(remote_workers):
        # Bind this worker's task as a default argument: a plain closure
        # over `i` is late-binding, and would also drag the whole `tasks`
        # list along with every callable.
        task = tasks[i]
        worker.foreach_env.remote(
            lambda env, task=task: env.set_task(task))
def set_worker_tasks(workers, use_meta_env):
    """Sample one task per remote worker and assign it to that worker's envs.

    Does nothing unless ``use_meta_env`` is truthy. Otherwise tasks are drawn
    by calling ``sample_tasks(n)`` on the first env of the local worker,
    where ``n`` is the number of remote workers, and task ``i`` is pushed to
    remote worker ``i`` via ``env.set_task``.

    Args:
        workers: Worker set exposing ``local_worker()`` and
            ``remote_workers()``.
        use_meta_env: Whether the envs implement ``sample_tasks`` /
            ``set_task``; when false, no task is sampled or assigned.
    """
    if not use_meta_env:
        return

    remote_workers = workers.remote_workers()
    n_tasks = len(remote_workers)
    tasks = workers.local_worker().foreach_env(lambda x: x)[
        0].sample_tasks(n_tasks)
    for i, worker in enumerate(remote_workers):
        # Bind this worker's task as a default argument: a plain closure
        # over `i` is late-binding, and would also drag the whole `tasks`
        # list along with every callable.
        task = tasks[i]
        worker.foreach_env.remote(
            lambda env, task=task: env.set_task(task))
class MetaUpdate:
def __init__(self, workers, maml_steps, metric_gen):
def __init__(self, workers, maml_steps, metric_gen, use_meta_env):
# Stores the constructor arguments on the instance for use in __call__.
# workers: worker set; __call__ syncs its weights and passes it to
# set_worker_tasks().
self.workers = workers
# maml_steps is kept under the name of the config key it is built from
# ("maml_optimizer_steps" in execution_plan).
self.maml_optimizer_steps = maml_steps
# metric_gen: the metric collector handed in by execution_plan.
self.metric_gen = metric_gen
# Forwarded to set_worker_tasks(); when false, task assignment is skipped.
self.use_meta_env = use_meta_env
def __call__(self, data_tuple):
# Metaupdate Step
@ -91,7 +95,7 @@ class MetaUpdate:
self.workers.sync_weights()
# Set worker tasks
set_worker_tasks(self.workers)
set_worker_tasks(self.workers, self.use_meta_env)
# Update KLS
def update(pi, pi_id):
@ -141,7 +145,8 @@ def execution_plan(workers, config):
workers.sync_weights()
# Samples and sets worker tasks
set_worker_tasks(workers)
use_meta_env = config["use_meta_env"]
set_worker_tasks(workers, use_meta_env)
# Metric Collector
metric_collect = CollectMetrics(
@ -191,7 +196,8 @@ def execution_plan(workers, config):
# Metaupdate Step
train_op = rollouts.for_each(
MetaUpdate(workers, config["maml_optimizer_steps"], metric_collect))
MetaUpdate(workers, config["maml_optimizer_steps"], metric_collect,
use_meta_env))
return train_op

View file

@ -0,0 +1,27 @@
# MAML on CartPole-v0 (discrete actions).
# Same configs as Pendulum.
cartpole-maml:
    env: CartPole-v0
    run: MAML
    stop:
        training_iteration: 100
    config:
        horizon: 200
        rollout_fragment_length: 200
        num_envs_per_worker: 10
        inner_adaptation_steps: 1
        maml_optimizer_steps: 5
        gamma: 0.99
        lambda: 1.0
        lr: 0.001
        vf_loss_coeff: 0.5
        clip_param: 0.3
        kl_target: 0.01
        kl_coeff: 0.001
        num_workers: 20
        num_gpus: 1
        inner_lr: 0.03
        clip_actions: False
        # Skip per-worker task sampling (sample_tasks/set_task are not called).
        use_meta_env: False
        model:
            fcnet_hiddens: [64, 64]
            free_log_std: True