from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.marwil.marwil_tf_policy import MARWILTFPolicy
from ray.rllib.execution.replay_ops import SimpleReplayBuffer, Replay, \
    StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts, ConcatBatches
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.train_ops import TrainOneStep
from ray.rllib.execution.metric_ops import StandardMetricsReporting

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # You should override this to point to an offline dataset (see agent.py).
    "input": "sampler",
    # Use importance sampling ("is") and weighted importance sampling ("wis")
    # estimators to evaluate the policy's reward on the offline data.
    "input_evaluation": ["is", "wis"],

    # If True, use the Generalized Advantage Estimator (GAE)
    # with a value function; see https://arxiv.org/pdf/1506.02438.pdf.
    "use_gae": True,

    # Scaling of advantages in the exponential term of the loss
    # (see the note below this config dict).
    # When beta is 0.0, MARWIL reduces to plain imitation learning.
    "beta": 1.0,
    # Balances the value-estimation loss against the policy-optimization loss.
    "vf_coeff": 1.0,
    # If specified, clip the global norm of gradients by this amount.
    "grad_clip": None,
    # Whether to calculate cumulative rewards (returns) for the input batches.
    "postprocess_inputs": True,
    # Whether to roll out "complete_episodes" or "truncate_episodes".
    "batch_mode": "complete_episodes",
    # Learning rate for the Adam optimizer.
    "lr": 1e-4,
    # Number of timesteps collected for each SGD round.
    "train_batch_size": 2000,
    # Size of the replay buffer in batches (not timesteps!).
    "replay_buffer_size": 1000,
    # Number of steps to read before learning starts.
    "learning_starts": 0,
    # === Parallelism ===
    "num_workers": 0,
})
# __sphinx_doc_end__
# yapf: enable
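# Note on `beta` (an added sketch, not from the original source): beta scales
# the advantages inside MARWIL's exponential advantage weighting. Roughly,
# the policy part of the loss is
#
#     L_pi = -E[ exp(beta * A(s, a) / c) * log pi(a|s) ],
#
# where c is a running estimate of the advantage norm. With beta = 0.0 the
# weight is 1 for every action, i.e. plain imitation learning (behavior
# cloning); larger beta favors actions with higher estimated advantage.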


def get_policy_class(config):
    # Use the Torch policy if the "torch" framework was requested; otherwise
    # return None, in which case the trainer falls back to the default
    # TF policy (MARWILTFPolicy).
    if config["framework"] == "torch":
        from ray.rllib.agents.marwil.marwil_torch_policy import \
            MARWILTorchPolicy
        return MARWILTorchPolicy


def execution_plan(workers, config):
    # Collect batches from all rollout workers (or from the offline reader,
    # if `input` points to a dataset) in bulk-synchronous fashion.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    replay_buffer = SimpleReplayBuffer(config["replay_buffer_size"])

    # Store every incoming batch in the bounded replay buffer.
    store_op = rollouts \
        .for_each(StoreToReplayBuffer(local_buffer=replay_buffer))

    # Sample batches back out of the buffer, concatenate them until at least
    # `train_batch_size` steps are collected, then do one SGD step.
    replay_op = Replay(local_buffer=replay_buffer) \
        .combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            )) \
        .for_each(TrainOneStep(workers))

    # Alternate between storing and training; only the training op's
    # results (output index 1) are emitted as iterator output.
    train_op = Concurrently(
        [store_op, replay_op], mode="round_robin", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)
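# Rough sketch (an assumption, not from this file) of how the plan above is
# consumed: build_trainer() stores the iterator returned by execution_plan()
# and each Trainer.train() call pulls one result dict from it, e.g.:
#
#     plan = execution_plan(workers, config)  # -> LocalIterator[dict]
#     result = next(plan)  # metrics for one store/replay/train round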


def validate_config(config):
    # Multi-GPU training is not supported for MARWIL yet.
    if config["num_gpus"] > 1:
        raise ValueError("`num_gpus` > 1 not yet supported for MARWIL!")


MARWILTrainer = build_trainer(
    name="MARWIL",
    default_config=DEFAULT_CONFIG,
    default_policy=MARWILTFPolicy,
    get_policy_class=get_policy_class,
    validate_config=validate_config,
    execution_plan=execution_plan)
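

if __name__ == "__main__":
    # Example usage: a minimal sketch, not part of the original module. The
    # dataset path below is a placeholder; generate offline rollout data
    # first (e.g., by running another trainer with the common `output`
    # config option set). The env here is only needed to supply the
    # observation and action spaces.
    import ray

    ray.init()
    config = DEFAULT_CONFIG.copy()
    config["input"] = "/tmp/cartpole-out"  # hypothetical offline data dir
    trainer = MARWILTrainer(config=config, env="CartPole-v0")
    for i in range(3):
        print("Iteration {}: {}".format(i, trainer.train()))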