[RLlib] Attention net example script: Clarifications on how to use with Trainer.compute_action. (#14864)
commit 78c64ca151
parent 5f7ce293fe
2 changed files with 25 additions and 3 deletions
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -1612,7 +1612,7 @@ py_test(
     tags = ["examples", "examples_A"],
     size = "medium",
     srcs = ["examples/attention_net.py"],
-    args = ["--as-test", "--stop-reward=80", "--torch"]
+    args = ["--as-test", "--stop-reward=80", "--framework torch"]
 )

 py_test(
--- a/rllib/examples/attention_net.py
+++ b/rllib/examples/attention_net.py
@@ -17,7 +17,7 @@ parser = argparse.ArgumentParser()
 parser.add_argument("--run", type=str, default="PPO")
 parser.add_argument("--env", type=str, default="RepeatAfterMeEnv")
 parser.add_argument("--num-cpus", type=int, default=0)
-parser.add_argument("--torch", action="store_true")
+parser.add_argument("--framework", choices=["tf", "torch"], default="tf")
 parser.add_argument("--as-test", action="store_true")
 parser.add_argument("--stop-iters", type=int, default=200)
 parser.add_argument("--stop-timesteps", type=int, default=500000)
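A minimal sketch (not part of the diff) of how the updated BUILD args parse under the new flag defined in the hunk above. Only the three flags shown in the hunks are taken from the commit; the --stop-reward default below is an assumption. Bazel applies Bourne-shell tokenization to each entry of a test's args list, so the single string "--framework torch" reaches the script as two argv tokens.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--as-test", action="store_true")
parser.add_argument("--stop-reward", type=float, default=90.0)  # default is an assumption
parser.add_argument("--framework", choices=["tf", "torch"], default="tf")

# Simulate what the py_test target passes after Bazel's shell tokenization.
args = parser.parse_args(["--as-test", "--stop-reward=80", "--framework", "torch"])
assert args.as_test and args.framework == "torch" and args.stop_reward == 80.0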
@@ -59,7 +59,7 @@ if __name__ == "__main__":
             "attention_head_dim": 32,
             "attention_position_wise_mlp_dim": 32,
         },
-        "framework": "torch" if args.torch else "tf",
+        "framework": args.framework,
     }

     stop = {
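For orientation, a hedged sketch of the config this hunk touches, assuming the `args` object produced by the parser above. Only "attention_head_dim", "attention_position_wise_mlp_dim", and "framework" are taken from the diff; the env key and the remaining attention settings ("use_attention", "attention_dim", "attention_memory_inference") are assumptions, chosen to match the 100 x 64 state shape referenced in the comment block added below.

config = {
    "env": args.env,  # assumption: the script wires the --env choice in here
    "model": {
        "use_attention": True,                  # assumption: built-in attention wrapper
        "attention_dim": 64,                    # assumption; the 64 in the comment below
        "attention_memory_inference": 100,      # assumption; the 100 in the comment below
        "attention_head_dim": 32,               # from this hunk
        "attention_position_wise_mlp_dim": 32,  # from this hunk
    },
    "framework": args.framework,  # "tf" or "torch", straight from the new flag
}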
@@ -68,6 +68,28 @@ if __name__ == "__main__":
         "episode_reward_mean": args.stop_reward,
     }

+    # To run the Trainer without tune.run, using the attention net and
+    # manual state-in handling, do the following:
+
+    # Example (use `config` from the above code):
+    # >> import numpy as np
+    # >> from ray.rllib.agents.ppo import PPOTrainer
+    # >>
+    # >> trainer = PPOTrainer(config)
+    # >> env = RepeatAfterMeEnv({})
+    # >> obs = env.reset()
+    # >> init_state = state = np.zeros(
+    # >>     [100 (attention_memory_inference), 64 (attention_dim)], np.float32)
+    # >> while True:
+    # >>     a, state_out, _ = trainer.compute_action(obs, [state])
+    # >>     obs, reward, done, _ = env.step(a)
+    # >>     if done:
+    # >>         obs = env.reset()
+    # >>         state = init_state
+    # >>     else:
+    # >>         state = np.concatenate([state, [state_out[0]]])[1:]
+
+    # We use tune here, which handles env and trainer creation for us.
     results = tune.run(args.run, config=config, stop=stop, verbose=2)

     if args.as_test:
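The doctest-style comment added above is pseudo-code (the shape annotation inside np.zeros is not valid Python). Below is a hedged, runnable sketch of the same manual-inference loop. It assumes the script's `config` dict is in scope (with its env already registered), that RepeatAfterMeEnv lives at the examples import path of the same Ray version, and that the attention memory/dim values are the 100 and 64 referenced in the comment; adapt them to your model config.

import numpy as np

import ray
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv  # assumed path

ray.init()
trainer = PPOTrainer(config=config)  # `config` as built by the script above
env = RepeatAfterMeEnv({})
obs = env.reset()

# One zero-filled memory: [attention_memory_inference=100, attention_dim=64] (assumed values).
init_state = state = np.zeros([100, 64], np.float32)

for _ in range(1000):  # bounded here; the comment above loops forever
    # With a stateful (attention) model, compute_action takes and returns the
    # recurrent state explicitly.
    action, state_out, _ = trainer.compute_action(obs, [state])
    obs, reward, done, _ = env.step(action)
    if done:
        obs = env.reset()
        state = init_state
    else:
        # Slide the memory window: append the newest timestep, drop the oldest.
        state = np.concatenate([state, [state_out[0]]])[1:]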