from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.marwil.marwil_policy import MARWILTFPolicy
from ray.rllib.optimizers import SyncBatchReplayOptimizer
# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # You should override this to point to an offline dataset
    # (see trainer.py).
    "input": "sampler",
    # Use importance sampling estimators for off-policy evaluation of the
    # reward (a trajectory-level sketch follows this config).
    "input_evaluation": ["is", "wis"],

    # Scaling of advantages in the exponential term of the loss.
    # When beta is 0.0, MARWIL reduces to plain imitation learning
    # (a weighting sketch follows this config).
    "beta": 1.0,
    # Balances value estimation loss against policy optimization loss.
    "vf_coeff": 1.0,
    # Whether to calculate cumulative rewards for the input trajectories.
    "postprocess_inputs": True,
    # Whether to roll out "complete_episodes" or "truncate_episodes".
    "batch_mode": "complete_episodes",
    # Learning rate for the Adam optimizer.
    "lr": 1e-4,
    # Number of timesteps collected for each SGD round.
    "train_batch_size": 2000,
    # Maximum number of timesteps to keep in the batch replay buffer.
    "replay_buffer_size": 100000,
    # Number of timesteps to read before learning starts.
    "learning_starts": 0,
    # === Parallelism ===
    "num_workers": 0,
})
# __sphinx_doc_end__
# yapf: enable
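

# Illustrative sketch only (NOT part of RLlib's API): trajectory-level
# versions of the "is"/"wis" estimators named in "input_evaluation" above.
# Each weight is the product over one trajectory of target-policy action
# probability divided by behavior-policy action probability. RLlib's
# built-in estimators work step-wise, so this is a simplified stand-in;
# all names are hypothetical.
def _example_is_wis_estimates(returns, weights):
    import numpy as np

    g = np.asarray(returns, dtype=np.float64)
    w = np.asarray(weights, dtype=np.float64)
    # Ordinary importance sampling: unbiased, but high variance.
    is_estimate = float(np.mean(w * g))
    # Weighted importance sampling: biased, but much lower variance.
    wis_estimate = float(np.sum(w * g) / np.sum(w))
    return is_estimate, wis_estimate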
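

# Illustrative sketch only (NOT part of RLlib's API): how "beta" above
# weights MARWIL's policy loss. Each sample's log-likelihood term is scaled
# by exp(beta * advantage) (RLlib also normalizes advantages internally),
# so beta == 0.0 gives every sample weight 1.0, i.e. plain imitation
# learning, while larger beta increasingly favors high-advantage actions.
def _example_marwil_sample_weight(advantage, beta=1.0):
    import numpy as np

    # exp(0.0 * advantage) == 1.0 for any advantage -> behavior cloning.
    return np.exp(beta * advantage)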


def make_optimizer(workers, config):
    """Creates MARWIL's synchronous batch-replay policy optimizer."""
    return SyncBatchReplayOptimizer(
        workers,
        learning_starts=config["learning_starts"],
        buffer_size=config["replay_buffer_size"],
        train_batch_size=config["train_batch_size"],
    )


def validate_config(config):
    """Checks the config for settings MARWIL does not support."""
    # PyTorch check.
    if config["use_pytorch"]:
        raise ValueError(
            "MARWIL does not support PyTorch yet! Use tf instead.")


MARWILTrainer = build_trainer(
    name="MARWIL",
    default_config=DEFAULT_CONFIG,
    default_policy=MARWILTFPolicy,
    validate_config=validate_config,
    make_policy_optimizer=make_optimizer)
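

# Minimal usage sketch under stated assumptions (not part of this module):
# run the file directly to train MARWIL. The "input" path is a placeholder
# for a directory of offline JSON experience batches (e.g. written via the
# "output" option of another trainer); the env is only used for its spaces
# and for evaluation.
if __name__ == "__main__":
    import ray

    ray.init()
    trainer = MARWILTrainer(
        env="CartPole-v0",
        config={
            # Hypothetical path; point this at real offline data.
            "input": "/tmp/marwil-offline-data",
            "beta": 1.0,
        })
    for _ in range(3):
        print(trainer.train())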