[RLlib] Attention net example script: Clarifications on how to use with Trainer.compute_action. (#14864)

Author: Sven Mika
Date: 2021-03-23 19:33:01 +01:00 (committed by GitHub)
parent 5f7ce293fe
commit 78c64ca151
2 changed files with 25 additions and 3 deletions

rllib/BUILD

@@ -1612,7 +1612,7 @@ py_test(
     tags = ["examples", "examples_A"],
     size = "medium",
     srcs = ["examples/attention_net.py"],
-    args = ["--as-test", "--stop-reward=80", "--torch"]
+    args = ["--as-test", "--stop-reward=80", "--framework torch"]
 )
 py_test(
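With this change, the Bazel target exercises the script via the new `--framework` option instead of the removed `--torch` switch. Run by hand, the equivalent invocation would look roughly like this (path assumed relative to the repo root):

    python rllib/examples/attention_net.py --as-test --stop-reward=80 --framework torch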

rllib/examples/attention_net.py

@@ -17,7 +17,7 @@ parser = argparse.ArgumentParser()
 parser.add_argument("--run", type=str, default="PPO")
 parser.add_argument("--env", type=str, default="RepeatAfterMeEnv")
 parser.add_argument("--num-cpus", type=int, default=0)
-parser.add_argument("--torch", action="store_true")
+parser.add_argument("--framework", choices=["tf", "torch"], default="tf")
 parser.add_argument("--as-test", action="store_true")
 parser.add_argument("--stop-iters", type=int, default=200)
 parser.add_argument("--stop-timesteps", type=int, default=500000)
@@ -59,7 +59,7 @@ if __name__ == "__main__":
             "attention_head_dim": 32,
             "attention_position_wise_mlp_dim": 32,
         },
-        "framework": "torch" if args.torch else "tf",
+        "framework": args.framework,
     }
 
     stop = {
@@ -68,6 +68,28 @@ if __name__ == "__main__":
         "episode_reward_mean": args.stop_reward,
     }
 
+    # To run the Trainer without tune.run, using the attention net and
+    # manual state-in handling, do the following:
+    # Example (use `config` from the above code):
+    # >> import numpy as np
+    # >> from ray.rllib.agents.ppo import PPOTrainer
+    # >>
+    # >> trainer = PPOTrainer(config)
+    # >> env = RepeatAfterMeEnv({})
+    # >> obs = env.reset()
+    # >> init_state = state = np.zeros(
+    # >>     [100, 64], np.float32)  # [attention_memory_inference, attention_dim]
+    # >> while True:
+    # >>     a, state_out, _ = trainer.compute_action(obs, [state])
+    # >>     obs, reward, done, _ = env.step(a)
+    # >>     if done:
+    # >>         obs = env.reset()
+    # >>         state = init_state
+    # >>     else:
+    # >>         state = np.concatenate([state, [state_out[0]]])[1:]
+
+    # We use tune here, which handles env and trainer creation for us.
     results = tune.run(args.run, config=config, stop=stop, verbose=2)
 
     if args.as_test:
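
For convenience, here is a fleshed-out, standalone version of the commented recipe added above. This is a sketch only: the RepeatAfterMeEnv import path, the ray.init()/shutdown calls, the reduced config, and the three-episode cap are assumptions not part of this commit; the [100, 64] memory shape and the state-in/state-out handling mirror the example in the diff.

    import numpy as np

    import ray
    from ray.rllib.agents.ppo import PPOTrainer
    # Assumed location of the example env in this Ray version.
    from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv

    if __name__ == "__main__":
        ray.init()

        # Minimal config; only attention_memory_inference=100 and
        # attention_dim=64 are taken from the commented example above,
        # all other attention_* settings keep their defaults.
        config = {
            "env": RepeatAfterMeEnv,
            "num_workers": 0,
            "framework": "tf",
            "model": {
                # Wrap the default model with an attention (GTrXL) net.
                "use_attention": True,
                "attention_memory_inference": 100,
                "attention_dim": 64,
            },
        }
        trainer = PPOTrainer(config=config)

        env = RepeatAfterMeEnv({})
        obs = env.reset()

        # One memory tensor per transformer unit (default: 1 unit), shaped
        # [attention_memory_inference, attention_dim].
        init_state = state = np.zeros([100, 64], np.float32)

        num_episodes = 0
        while num_episodes < 3:
            # Pass the current memory as state-in; with a state argument,
            # compute_action returns (action, state_outs, extra_fetches).
            action, state_out, _ = trainer.compute_action(obs, [state])
            obs, reward, done, _ = env.step(action)
            if done:
                obs = env.reset()
                state = init_state
                num_episodes += 1
            else:
                # Slide the memory window: drop the oldest row, append the
                # newly produced one.
                state = np.concatenate([state, [state_out[0]]])[1:]

        ray.shutdown()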