[rllib] Fix custom model metrics in multi-device case (#7640)

* fix example

* add example test

* lint
This commit is contained in:
Eric Liang 2020-03-23 12:40:22 -07:00 committed by GitHub
parent 8adc84ccb9
commit 9a590ac6a5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 9 deletions

View file

@ -54,6 +54,9 @@ class MyKerasModel(TFModelV2):
# Return the model's value output reshaped to a 1-D tensor.
# NOTE(review): self._value_out is assigned outside this view, presumably in
# forward() -- confirm it is always set before this method is called.
def value_function(self):
return tf.reshape(self._value_out, [-1])
# Report a custom metric dict with the single key "foo" fixed at 42.0.
# The __main__ section of this example asserts that this value appears in the
# training result (see check_has_custom_metric).
def metrics(self):
return {"foo": tf.constant(42.0)}
class MyKerasQModel(DistributionalQModel):
"""Custom model for DQN."""
@ -85,6 +88,9 @@ class MyKerasQModel(DistributionalQModel):
model_out = self.base_model(input_dict["obs"])
return model_out, state
# Report a custom metric dict with the single key "foo" fixed at 42.0.
# The __main__ section of this example asserts that this value appears in the
# training result (see check_has_custom_metric).
def metrics(self):
return {"foo": tf.constant(42.0)}
if __name__ == "__main__":
args = parser.parse_args()
@ -95,15 +101,33 @@ if __name__ == "__main__":
ModelCatalog.register_custom_model(
"keras_q_model", MyVisionNetwork
if args.use_vision_network else MyKerasQModel)
# Tests https://github.com/ray-project/ray/issues/7293
# Callback registered below under "on_train_result": asserts that the custom
# model metric "foo" is present in the learner info of a training result.
def check_has_custom_metric(result):
r = result["result"]["info"]["learner"]
# The learner info may be nested under a "default_policy" key; descend into
# it when present so both layouts are handled.
if "default_policy" in r:
r = r["default_policy"]
# 42 matches the constant returned by the models' metrics() methods; the full
# result dict is attached as the assertion message to aid debugging.
assert r["model"]["foo"] == 42, result
if args.run == "DQN":
extra_config = {"learning_starts": 0}
else:
extra_config = {}
tune.run(
args.run,
stop={"episode_reward_mean": args.stop},
config={
"env": "BreakoutNoFrameskip-v4"
if args.use_vision_network else "CartPole-v0",
"num_gpus": 0,
"model": {
"custom_model": "keras_q_model"
if args.run == "DQN" else "keras_model"
},
})
config=dict(
extra_config, **{
"log_level": "INFO",
"env": "BreakoutNoFrameskip-v4"
if args.use_vision_network else "CartPole-v0",
"num_gpus": 0,
"callbacks": {
"on_train_result": check_has_custom_metric,
},
"model": {
"custom_model": "keras_q_model"
if args.run == "DQN" else "keras_model"
},
}))

View file

@ -15,6 +15,8 @@ logger = logging.getLogger(__name__)
def averaged(kv):
"""Average the value lists of a dictionary.
For non-scalar values, we simply pick the first value.
Arguments:
kv (dict): dictionary with values that are lists of floats.
@ -25,6 +27,8 @@ def averaged(kv):
for k, v in kv.items():
if v[0] is not None and not isinstance(v[0], dict):
out[k] = np.mean(v)
else:
out[k] = v[0]
return out