[RLlib] Fix MARWIL tf policy. (#25384)

This commit is contained in:
Jun Gong 2022-06-03 01:50:36 -07:00 committed by GitHub
parent 99429b7a92
commit 1d24d6af98
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 30 additions and 15 deletions

View file

@@ -215,7 +215,7 @@ def get_marwil_tf_policy(base: type) -> type:
action_dist = dist_class(model_out, model)
value_estimates = model.value_function()
self.loss = MARWILLoss(
self._marwil_loss = MARWILLoss(
self,
value_estimates,
action_dist,
@@ -224,18 +224,18 @@ def get_marwil_tf_policy(base: type) -> type:
self.config["beta"],
)
return self.loss.total_loss
return self._marwil_loss.total_loss
@override(base)
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
stats = {
"policy_loss": self.loss.p_loss,
"total_loss": self.loss.total_loss,
"policy_loss": self._marwil_loss.p_loss,
"total_loss": self._marwil_loss.total_loss,
}
if self.config["beta"] != 0.0:
stats["moving_average_sqd_adv_norm"] = self._moving_average_sqd_adv_norm
stats["vf_explained_var"] = self.loss.explained_variance
stats["vf_loss"] = self.loss.v_loss
stats["vf_explained_var"] = self._marwil_loss.explained_variance
stats["vf_loss"] = self._marwil_loss.v_loss
return stats

View file

@@ -198,7 +198,13 @@ class TestMARWIL(unittest.TestCase):
)
else:
loss_out, v_loss, p_loss = policy.get_session().run(
[policy._loss, policy.loss.v_loss, policy.loss.p_loss],
# policy._loss is created by TFPolicy, and is basically the
# loss tensor of the static graph.
[
policy._loss,
policy._marwil_loss.v_loss,
policy._marwil_loss.p_loss,
],
feed_dict=policy._get_loss_inputs_dict(
postprocessed_batch, shuffle=False
),
@@ -212,8 +218,8 @@ class TestMARWIL(unittest.TestCase):
check(v_loss, expected_vf_loss, decimals=4)
check(p_loss, expected_pol_loss, decimals=4)
else:
check(policy.loss.v_loss, expected_vf_loss, decimals=4)
check(policy.loss.p_loss, expected_pol_loss, decimals=4)
check(policy._marwil_loss.v_loss, expected_vf_loss, decimals=4)
check(policy._marwil_loss.p_loss, expected_pol_loss, decimals=4)
check(loss_out, expected_loss, decimals=3)

View file

@@ -47,11 +47,17 @@ def _adjust_obs_actions_for_policy(json_data: dict, policy: Policy) -> dict:
Providing nested lists w/o this preprocessing step would
confuse a SampleBatch constructor.
"""
for k, v in policy.view_requirements.items():
if k not in json_data:
continue
for k, v in json_data.items():
data_col = (
policy.view_requirements[k].data_col
if k in policy.view_requirements
else ""
)
if policy.config.get("_disable_action_flattening") and (
k == SampleBatch.ACTIONS or v.data_col == SampleBatch.ACTIONS
k == SampleBatch.ACTIONS
or data_col == SampleBatch.ACTIONS
or k == SampleBatch.PREV_ACTIONS
or data_col == SampleBatch.PREV_ACTIONS
):
json_data[k] = tree.map_structure_up_to(
policy.action_space_struct,
@@ -60,7 +66,10 @@ def _adjust_obs_actions_for_policy(json_data: dict, policy: Policy) -> dict:
check_types=False,
)
elif policy.config.get("_disable_preprocessor_api") and (
k == SampleBatch.OBS or v.data_col == SampleBatch.OBS
k == SampleBatch.OBS
or data_col == SampleBatch.OBS
or k == SampleBatch.NEXT_OBS
or data_col == SampleBatch.NEXT_OBS
):
json_data[k] = tree.map_structure_up_to(
policy.observation_space_struct,

View file

@@ -79,7 +79,7 @@ class NestedActionSpacesTest(unittest.TestCase):
config["env_config"] = {
"action_space": action_space,
}
for flatten in [False, True]:
for flatten in [True, False]:
print(f"A={action_space} flatten={flatten}")
shutil.rmtree(config["output"])
config["_disable_action_flattening"] = not flatten