from ray.rllib.agents.trainer import with_common_config
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.marwil.marwil_policy import MARWILTFPolicy
from ray.rllib.optimizers import SyncBatchReplayOptimizer

# yapf: disable
# __sphinx_doc_begin__
DEFAULT_CONFIG = with_common_config({
    # You should override this to point to an offline dataset (see agent.py).
    "input": "sampler",
    # Use importance sampling estimators for reward.
    "input_evaluation": ["is", "wis"],
    # Scaling of advantages in exponential terms.
    # When beta is 0, MARWIL is reduced to imitation learning.
    "beta": 1.0,
    # Balancing value estimation loss and policy optimization loss.
    "vf_coeff": 1.0,
    # Whether to calculate cumulative rewards.
    "postprocess_inputs": True,
    # Whether to rollout "complete_episodes" or "truncate_episodes".
    "batch_mode": "complete_episodes",
    # Learning rate for the Adam optimizer.
    "lr": 1e-4,
    # Number of timesteps collected for each SGD round.
    "train_batch_size": 2000,
    # Maximum number of steps to keep in the batch replay buffer.
    "replay_buffer_size": 100000,
    # Number of steps to read before learning starts.
    "learning_starts": 0,
    # === Parallelism ===
    "num_workers": 0,
})
# __sphinx_doc_end__
# yapf: enable


def make_optimizer(workers, config):
    return SyncBatchReplayOptimizer(
        workers,
        learning_starts=config["learning_starts"],
        buffer_size=config["replay_buffer_size"],
        train_batch_size=config["train_batch_size"])


def validate_config(config):
    # PyTorch check.
    if config["use_pytorch"]:
        raise ValueError(
            "MARWIL does not support PyTorch yet! Use tf instead.")


MARWILTrainer = build_trainer(
    name="MARWIL",
    default_config=DEFAULT_CONFIG,
    default_policy=MARWILTFPolicy,
    validate_config=validate_config,
    make_policy_optimizer=make_optimizer)
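

if __name__ == "__main__":
    # Illustrative sketch (not part of the library module above): one way
    # MARWILTrainer might be run on an offline dataset, as suggested by the
    # "input" config comment. The dataset path, environment name, and config
    # values here are hypothetical placeholders, not fixed requirements.
    import ray

    ray.init()
    trainer = MARWILTrainer(
        env="CartPole-v0",
        config={
            # Hypothetical directory of JSON experience batches; replace with
            # the location of your own recorded offline data.
            "input": "/tmp/cartpole-out",
            # Evaluate the learned policy with importance sampling ("is") and
            # weighted importance sampling ("wis") estimators.
            "input_evaluation": ["is", "wis"],
            # beta > 0 weights advantages exponentially; beta = 0 reduces
            # MARWIL to plain imitation learning.
            "beta": 1.0,
        })
    # Run a few training iterations and print the result dicts.
    for _ in range(3):
        print(trainer.train())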