ray/rllib/agents/a3c/a3c_torch_policy.py
Last commit: e2edca45d4 by Sven Mika (2020-02-22): [RLlib] PPO torch memory leak and unnecessary torch.Tensor creation and gc'ing. (#7238)

import ray
from ray.rllib.evaluation.postprocessing import compute_advantages, \
    Postprocessing
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy_template import build_torch_policy
from ray.rllib.utils.framework import try_import_torch

# torch and nn resolve to None if torch is not installed.
torch, nn = try_import_torch()


def actor_critic_loss(policy, model, dist_class, train_batch):
    logits, _ = model.from_batch(train_batch)
    values = model.value_function()
    dist = dist_class(logits, model)
    log_probs = dist.logp(train_batch[SampleBatch.ACTIONS])
    policy.entropy = dist.entropy().mean()
    # Policy-gradient term: -sum_t A_t * log pi(a_t | s_t).
    policy.pi_err = -train_batch[Postprocessing.ADVANTAGES].dot(
        log_probs.reshape(-1))
    # Value-function error: MSE between predicted values and targets.
    policy.value_err = nn.functional.mse_loss(
        values.reshape(-1), train_batch[Postprocessing.VALUE_TARGETS])
    # Total loss: policy term plus weighted value error, minus entropy bonus.
    overall_err = sum([
        policy.pi_err,
        policy.config["vf_loss_coeff"] * policy.value_err,
        -policy.config["entropy_coeff"] * policy.entropy,
    ])
    return overall_err
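
# In symbols, actor_critic_loss above computes (a restatement of the code,
# not an addition to it):
#   L = -sum_t A_t * log pi(a_t | s_t)
#       + vf_loss_coeff * MSE(V(s_t), V_target_t)
#       - entropy_coeff * H(pi)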


def loss_and_entropy_stats(policy, train_batch):
    return {
        "policy_entropy": policy.entropy.item(),
        "policy_loss": policy.pi_err.item(),
        "vf_loss": policy.value_err.item(),
    }


def add_advantages(policy,
                   sample_batch,
                   other_agent_batches=None,
                   episode=None):
    # Bootstrap the value of the final state unless the episode
    # terminated within this batch.
    completed = sample_batch[SampleBatch.DONES][-1]
    if completed:
        last_r = 0.0
    else:
        last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1])
    return compute_advantages(sample_batch, last_r, policy.config["gamma"],
                              policy.config["lambda"],
                              policy.config["use_gae"],
                              policy.config["use_critic"])


def model_value_predictions(policy, input_dict, state_batches, model,
                            action_dist):
    return {SampleBatch.VF_PREDS: model.value_function()}


def apply_grad_clipping(policy):
    info = {}
    if policy.config["grad_clip"]:
        # clip_grad_norm_ clips in place and returns the total norm of
        # the gradients as measured before clipping.
        total_norm = nn.utils.clip_grad_norm_(policy.model.parameters(),
                                              policy.config["grad_clip"])
        info["grad_gnorm"] = total_norm
    return info


def torch_optimizer(policy, config):
    return torch.optim.Adam(policy.model.parameters(), lr=config["lr"])


class ValueNetworkMixin:
    def _value(self, obs):
        # Run a forward pass on the single observation so the model
        # caches a value prediction, then read it back.
        _ = self.model({"obs": torch.Tensor([obs]).to(self.device)}, [], [1])
        return self.model.value_function()[0]


A3CTorchPolicy = build_torch_policy(
    name="A3CTorchPolicy",
    get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
    loss_fn=actor_critic_loss,
    stats_fn=loss_and_entropy_stats,
    postprocess_fn=add_advantages,
    extra_action_out_fn=model_value_predictions,
    extra_grad_process_fn=apply_grad_clipping,
    optimizer_fn=torch_optimizer,
    mixins=[ValueNetworkMixin])
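
# Example usage (a minimal sketch; assumes the Ray 0.8.x-era trainer API,
# where the "use_pytorch" config key selects this torch policy):
#
#   import ray
#   from ray.rllib.agents.a3c import A3CTrainer
#
#   ray.init()
#   trainer = A3CTrainer(env="CartPole-v0", config={"use_pytorch": True})
#   for _ in range(3):
#       print(trainer.train()["episode_reward_mean"])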