mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[RLlib] Fix MARWIL tf policy. (#25384)
This commit is contained in:
parent
99429b7a92
commit
1d24d6af98
4 changed files with 30 additions and 15 deletions
|
@ -215,7 +215,7 @@ def get_marwil_tf_policy(base: type) -> type:
|
|||
action_dist = dist_class(model_out, model)
|
||||
value_estimates = model.value_function()
|
||||
|
||||
self.loss = MARWILLoss(
|
||||
self._marwil_loss = MARWILLoss(
|
||||
self,
|
||||
value_estimates,
|
||||
action_dist,
|
||||
|
@ -224,18 +224,18 @@ def get_marwil_tf_policy(base: type) -> type:
|
|||
self.config["beta"],
|
||||
)
|
||||
|
||||
return self.loss.total_loss
|
||||
return self._marwil_loss.total_loss
|
||||
|
||||
@override(base)
|
||||
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
|
||||
stats = {
|
||||
"policy_loss": self.loss.p_loss,
|
||||
"total_loss": self.loss.total_loss,
|
||||
"policy_loss": self._marwil_loss.p_loss,
|
||||
"total_loss": self._marwil_loss.total_loss,
|
||||
}
|
||||
if self.config["beta"] != 0.0:
|
||||
stats["moving_average_sqd_adv_norm"] = self._moving_average_sqd_adv_norm
|
||||
stats["vf_explained_var"] = self.loss.explained_variance
|
||||
stats["vf_loss"] = self.loss.v_loss
|
||||
stats["vf_explained_var"] = self._marwil_loss.explained_variance
|
||||
stats["vf_loss"] = self._marwil_loss.v_loss
|
||||
|
||||
return stats
|
||||
|
||||
|
|
|
@ -198,7 +198,13 @@ class TestMARWIL(unittest.TestCase):
|
|||
)
|
||||
else:
|
||||
loss_out, v_loss, p_loss = policy.get_session().run(
|
||||
[policy._loss, policy.loss.v_loss, policy.loss.p_loss],
|
||||
# policy._loss is created by TFPolicy, and is basically the
|
||||
# loss tensor of the static graph.
|
||||
[
|
||||
policy._loss,
|
||||
policy._marwil_loss.v_loss,
|
||||
policy._marwil_loss.p_loss,
|
||||
],
|
||||
feed_dict=policy._get_loss_inputs_dict(
|
||||
postprocessed_batch, shuffle=False
|
||||
),
|
||||
|
@ -212,8 +218,8 @@ class TestMARWIL(unittest.TestCase):
|
|||
check(v_loss, expected_vf_loss, decimals=4)
|
||||
check(p_loss, expected_pol_loss, decimals=4)
|
||||
else:
|
||||
check(policy.loss.v_loss, expected_vf_loss, decimals=4)
|
||||
check(policy.loss.p_loss, expected_pol_loss, decimals=4)
|
||||
check(policy._marwil_loss.v_loss, expected_vf_loss, decimals=4)
|
||||
check(policy._marwil_loss.p_loss, expected_pol_loss, decimals=4)
|
||||
check(loss_out, expected_loss, decimals=3)
|
||||
|
||||
|
||||
|
|
|
@ -47,11 +47,17 @@ def _adjust_obs_actions_for_policy(json_data: dict, policy: Policy) -> dict:
|
|||
Providing nested lists w/o this preprocessing step would
|
||||
confuse a SampleBatch constructor.
|
||||
"""
|
||||
for k, v in policy.view_requirements.items():
|
||||
if k not in json_data:
|
||||
continue
|
||||
for k, v in json_data.items():
|
||||
data_col = (
|
||||
policy.view_requirements[k].data_col
|
||||
if k in policy.view_requirements
|
||||
else ""
|
||||
)
|
||||
if policy.config.get("_disable_action_flattening") and (
|
||||
k == SampleBatch.ACTIONS or v.data_col == SampleBatch.ACTIONS
|
||||
k == SampleBatch.ACTIONS
|
||||
or data_col == SampleBatch.ACTIONS
|
||||
or k == SampleBatch.PREV_ACTIONS
|
||||
or data_col == SampleBatch.PREV_ACTIONS
|
||||
):
|
||||
json_data[k] = tree.map_structure_up_to(
|
||||
policy.action_space_struct,
|
||||
|
@ -60,7 +66,10 @@ def _adjust_obs_actions_for_policy(json_data: dict, policy: Policy) -> dict:
|
|||
check_types=False,
|
||||
)
|
||||
elif policy.config.get("_disable_preprocessor_api") and (
|
||||
k == SampleBatch.OBS or v.data_col == SampleBatch.OBS
|
||||
k == SampleBatch.OBS
|
||||
or data_col == SampleBatch.OBS
|
||||
or k == SampleBatch.NEXT_OBS
|
||||
or data_col == SampleBatch.NEXT_OBS
|
||||
):
|
||||
json_data[k] = tree.map_structure_up_to(
|
||||
policy.observation_space_struct,
|
||||
|
|
|
@ -79,7 +79,7 @@ class NestedActionSpacesTest(unittest.TestCase):
|
|||
config["env_config"] = {
|
||||
"action_space": action_space,
|
||||
}
|
||||
for flatten in [False, True]:
|
||||
for flatten in [True, False]:
|
||||
print(f"A={action_space} flatten={flatten}")
|
||||
shutil.rmtree(config["output"])
|
||||
config["_disable_action_flattening"] = not flatten
|
||||
|
|
Loading…
Add table
Reference in a new issue