mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[rllib] Fix custom model metrics in multi-device case (#7640)
* fix example * add example test * lin
This commit is contained in:
parent
8adc84ccb9
commit
9a590ac6a5
2 changed files with 37 additions and 9 deletions
|
@ -54,6 +54,9 @@ class MyKerasModel(TFModelV2):
|
|||
def value_function(self):
|
||||
return tf.reshape(self._value_out, [-1])
|
||||
|
||||
def metrics(self):
|
||||
return {"foo": tf.constant(42.0)}
|
||||
|
||||
|
||||
class MyKerasQModel(DistributionalQModel):
|
||||
"""Custom model for DQN."""
|
||||
|
@ -85,6 +88,9 @@ class MyKerasQModel(DistributionalQModel):
|
|||
model_out = self.base_model(input_dict["obs"])
|
||||
return model_out, state
|
||||
|
||||
def metrics(self):
|
||||
return {"foo": tf.constant(42.0)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
@ -95,15 +101,33 @@ if __name__ == "__main__":
|
|||
ModelCatalog.register_custom_model(
|
||||
"keras_q_model", MyVisionNetwork
|
||||
if args.use_vision_network else MyKerasQModel)
|
||||
|
||||
# Tests https://github.com/ray-project/ray/issues/7293
|
||||
def check_has_custom_metric(result):
|
||||
r = result["result"]["info"]["learner"]
|
||||
if "default_policy" in r:
|
||||
r = r["default_policy"]
|
||||
assert r["model"]["foo"] == 42, result
|
||||
|
||||
if args.run == "DQN":
|
||||
extra_config = {"learning_starts": 0}
|
||||
else:
|
||||
extra_config = {}
|
||||
|
||||
tune.run(
|
||||
args.run,
|
||||
stop={"episode_reward_mean": args.stop},
|
||||
config={
|
||||
"env": "BreakoutNoFrameskip-v4"
|
||||
if args.use_vision_network else "CartPole-v0",
|
||||
"num_gpus": 0,
|
||||
"model": {
|
||||
"custom_model": "keras_q_model"
|
||||
if args.run == "DQN" else "keras_model"
|
||||
},
|
||||
})
|
||||
config=dict(
|
||||
extra_config, **{
|
||||
"log_level": "INFO",
|
||||
"env": "BreakoutNoFrameskip-v4"
|
||||
if args.use_vision_network else "CartPole-v0",
|
||||
"num_gpus": 0,
|
||||
"callbacks": {
|
||||
"on_train_result": check_has_custom_metric,
|
||||
},
|
||||
"model": {
|
||||
"custom_model": "keras_q_model"
|
||||
if args.run == "DQN" else "keras_model"
|
||||
},
|
||||
}))
|
||||
|
|
|
@ -15,6 +15,8 @@ logger = logging.getLogger(__name__)
|
|||
def averaged(kv):
|
||||
"""Average the value lists of a dictionary.
|
||||
|
||||
For non-scalar values, we simply pick the first value.
|
||||
|
||||
Arguments:
|
||||
kv (dict): dictionary with values that are lists of floats.
|
||||
|
||||
|
@ -25,6 +27,8 @@ def averaged(kv):
|
|||
for k, v in kv.items():
|
||||
if v[0] is not None and not isinstance(v[0], dict):
|
||||
out[k] = np.mean(v)
|
||||
else:
|
||||
out[k] = v[0]
|
||||
return out
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue