mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
Cartpole MAML + Discrete (#11028)
This commit is contained in:
parent
180c259702
commit
47b499d899
2 changed files with 43 additions and 10 deletions
|
@ -55,25 +55,29 @@ DEFAULT_CONFIG = with_common_config({
|
|||
"maml_optimizer_steps": 5,
|
||||
# Inner Adaptation Step size
|
||||
"inner_lr": 0.1,
|
||||
# Use Meta Env Template
|
||||
"use_meta_env": True,
|
||||
})
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
|
||||
# @mluo: TODO
|
||||
def set_worker_tasks(workers, use_meta_env=True):
    """Sample one task per remote worker and hand it to that worker's envs.

    The tasks are sampled by calling ``sample_tasks`` on the first
    environment of the local worker; task ``i`` is then passed to
    ``set_task`` on every environment of remote worker ``i``.

    Args:
        workers: Object exposing ``local_worker()`` and ``remote_workers()``.
        use_meta_env: When falsy, nothing is sampled or dispatched. Defaults
            to ``True`` so the one-argument call form keeps working.

    Returns:
        None. The handles returned by ``foreach_env.remote`` are not awaited
        or collected here.
    """
    if not use_meta_env:
        return

    remote_workers = workers.remote_workers()

    # The local worker's first environment acts as the task sampler.
    sampler_env = workers.local_worker().foreach_env(lambda env: env)[0]
    tasks = sampler_env.sample_tasks(len(remote_workers))

    for i, worker in enumerate(remote_workers):
        # Bind this worker's task as a default argument: the callable then
        # no longer depends on the loop variable (no late-binding surprise
        # if it is executed after the loop has advanced) and it carries only
        # its own task instead of closing over the whole task list.
        worker.foreach_env.remote(
            lambda env, task=tasks[i]: env.set_task(task))
|
||||
|
||||
|
||||
class MetaUpdate:
|
||||
def __init__(self, workers, maml_steps, metric_gen):
|
||||
def __init__(self, workers, maml_steps, metric_gen, use_meta_env):
|
||||
self.workers = workers
|
||||
self.maml_optimizer_steps = maml_steps
|
||||
self.metric_gen = metric_gen
|
||||
self.use_meta_env = use_meta_env
|
||||
|
||||
def __call__(self, data_tuple):
|
||||
# Metaupdate Step
|
||||
|
@ -91,7 +95,7 @@ class MetaUpdate:
|
|||
self.workers.sync_weights()
|
||||
|
||||
# Set worker tasks
|
||||
set_worker_tasks(self.workers)
|
||||
set_worker_tasks(self.workers, self.use_meta_env)
|
||||
|
||||
# Update KLS
|
||||
def update(pi, pi_id):
|
||||
|
@ -141,7 +145,8 @@ def execution_plan(workers, config):
|
|||
workers.sync_weights()
|
||||
|
||||
# Samples and sets worker tasks
|
||||
set_worker_tasks(workers)
|
||||
use_meta_env = config["use_meta_env"]
|
||||
set_worker_tasks(workers, use_meta_env)
|
||||
|
||||
# Metric Collector
|
||||
metric_collect = CollectMetrics(
|
||||
|
@ -191,7 +196,8 @@ def execution_plan(workers, config):
|
|||
|
||||
# Metaupdate Step
|
||||
train_op = rollouts.for_each(
|
||||
MetaUpdate(workers, config["maml_optimizer_steps"], metric_collect))
|
||||
MetaUpdate(workers, config["maml_optimizer_steps"], metric_collect,
|
||||
use_meta_env))
|
||||
return train_op
|
||||
|
||||
|
||||
|
|
27
rllib/tuned_examples/maml/cartpole-maml.yaml
Normal file
27
rllib/tuned_examples/maml/cartpole-maml.yaml
Normal file
|
@ -0,0 +1,27 @@
|
|||
# Same configs as Pendulum
|
||||
cartpole-maml:
|
||||
env: CartPole-v0
|
||||
run: MAML
|
||||
stop:
|
||||
training_iteration: 100
|
||||
config:
|
||||
horizon: 200
|
||||
rollout_fragment_length: 200
|
||||
num_envs_per_worker: 10
|
||||
inner_adaptation_steps: 1
|
||||
maml_optimizer_steps: 5
|
||||
gamma: 0.99
|
||||
lambda: 1.0
|
||||
lr: 0.001
|
||||
vf_loss_coeff: 0.5
|
||||
clip_param: 0.3
|
||||
kl_target: 0.01
|
||||
kl_coeff: 0.001
|
||||
num_workers: 20
|
||||
num_gpus: 1
|
||||
inner_lr: 0.03
|
||||
clip_actions: False
|
||||
use_meta_env: False
|
||||
model:
|
||||
fcnet_hiddens: [64, 64]
|
||||
free_log_std: True
|
Loading…
Add table
Reference in a new issue