mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] Support lambda functions in hyperparameters / tune rllib multiagent support (#2568)
* update * func * Update registry.py * revert
This commit is contained in:
parent
e7f76d7914
commit
64053278aa
5 changed files with 40 additions and 23 deletions
|
@ -34,8 +34,8 @@ dictionary.
|
|||
"trial_resources": { "cpu": 1, "gpu": 0 },
|
||||
"stop": { "mean_accuracy": 100 },
|
||||
"config": {
|
||||
"alpha": grid_search([0.2, 0.4, 0.6]),
|
||||
"beta": grid_search([1, 2]),
|
||||
"alpha": tune.grid_search([0.2, 0.4, 0.6]),
|
||||
"beta": tune.grid_search([1, 2]),
|
||||
},
|
||||
"upload_dir": "s3://your_bucket/path",
|
||||
"local_dir": "~/ray_results",
|
||||
|
@ -49,7 +49,7 @@ An example of this can be found in `async_hyperband_example.py <https://github.c
|
|||
Trial Variant Generation
|
||||
------------------------
|
||||
|
||||
In the above example, we specified a grid search over two parameters using the ``grid_search`` helper function. Ray Tune also supports sampling parameters from user-specified lambda functions, which can be used in combination with grid search.
|
||||
In the above example, we specified a grid search over two parameters using the ``tune.grid_search`` helper function. Ray Tune also supports sampling parameters from user-specified lambda functions, which can be used in combination with grid search.
|
||||
|
||||
The following shows grid search over two nested parameters combined with random sampling from two lambda functions. Note that the value of ``beta`` depends on the value of ``alpha``, which is represented by referencing ``spec.config.alpha`` in the lambda function. This lets you specify conditional parameter distributions.
|
||||
|
||||
|
@ -59,14 +59,18 @@ The following shows grid search over two nested parameters combined with random
|
|||
"alpha": lambda spec: np.random.uniform(100),
|
||||
"beta": lambda spec: spec.config.alpha * np.random.normal(),
|
||||
"nn_layers": [
|
||||
grid_search([16, 64, 256]),
|
||||
grid_search([16, 64, 256]),
|
||||
tune.grid_search([16, 64, 256]),
|
||||
tune.grid_search([16, 64, 256]),
|
||||
],
|
||||
},
|
||||
"repeat": 10,
|
||||
|
||||
By default, each random variable and grid search point is sampled once. To take multiple random samples or repeat grid search runs, add ``repeat: N`` to the experiment config. E.g. in the above, ``"repeat": 10`` repeats the 3x3 grid search 10 times, for a total of 90 trials, each with randomly sampled values of ``alpha`` and ``beta``.
|
||||
|
||||
.. note::
|
||||
|
||||
Lambda functions will be evaluated during trial variant generation. If you need to pass a literal function in your config, use ``tune.function(...)`` to escape it.
|
||||
|
||||
For more information on variant generation, see `basic_variant.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/suggest/basic_variant.py>`__.
|
||||
|
||||
Resource Allocation
|
||||
|
|
|
@ -17,10 +17,10 @@ import gym
|
|||
import random
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.pg.pg import PGAgent
|
||||
from ray import tune
|
||||
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
|
||||
from ray.rllib.test.test_multi_agent_env import MultiCartpole
|
||||
from ray.tune.logger import pretty_print
|
||||
from ray.tune import run_experiments
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -53,16 +53,19 @@ if __name__ == "__main__":
|
|||
}
|
||||
policy_ids = list(policy_graphs.keys())
|
||||
|
||||
agent = PGAgent(
|
||||
env="multi_cartpole",
|
||||
config={
|
||||
"multiagent": {
|
||||
"policy_graphs": policy_graphs,
|
||||
"policy_mapping_fn": (
|
||||
lambda agent_id: random.choice(policy_ids)),
|
||||
run_experiments({
|
||||
"test": {
|
||||
"run": "PG",
|
||||
"env": "multi_cartpole",
|
||||
"stop": {
|
||||
"training_iteration": args.num_iters
|
||||
},
|
||||
})
|
||||
|
||||
for i in range(args.num_iters):
|
||||
print("== Iteration", i, "==")
|
||||
print(pretty_print(agent.train()))
|
||||
"config": {
|
||||
"multiagent": {
|
||||
"policy_graphs": policy_graphs,
|
||||
"policy_mapping_fn": tune.function(
|
||||
lambda agent_id: random.choice(policy_ids)),
|
||||
},
|
||||
},
|
||||
}
|
||||
})
|
||||
|
|
|
@ -7,9 +7,9 @@ from ray.tune.tune import run_experiments
|
|||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.registry import register_env, register_trainable
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.suggest import grid_search
|
||||
from ray.tune.suggest import grid_search, function
|
||||
|
||||
__all__ = [
|
||||
"Trainable", "TuneError", "grid_search", "register_env",
|
||||
"register_trainable", "run_experiments", "Experiment"
|
||||
"register_trainable", "run_experiments", "Experiment", "function"
|
||||
]
|
||||
|
|
|
@ -2,9 +2,9 @@ from ray.tune.suggest.search import SearchAlgorithm
|
|||
from ray.tune.suggest.basic_variant import BasicVariantGenerator
|
||||
from ray.tune.suggest.suggestion import SuggestionAlgorithm
|
||||
from ray.tune.suggest.hyperopt import HyperOptSearch
|
||||
from ray.tune.suggest.variant_generator import grid_search
|
||||
from ray.tune.suggest.variant_generator import grid_search, function
|
||||
|
||||
__all__ = [
|
||||
"SearchAlgorithm", "BasicVariantGenerator", "HyperOptSearch",
|
||||
"SuggestionAlgorithm", "grid_search"
|
||||
"SuggestionAlgorithm", "grid_search", "function"
|
||||
]
|
||||
|
|
|
@ -51,6 +51,16 @@ def grid_search(values):
|
|||
return {"grid_search": values}
|
||||
|
||||
|
||||
class function(object):
|
||||
"""Wraps `func` to make sure it is not expanded during resolution."""
|
||||
|
||||
def __init__(self, func):
|
||||
self.func = func
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
|
||||
_STANDARD_IMPORTS = {
|
||||
"random": random,
|
||||
"np": numpy,
|
||||
|
|
Loading…
Add table
Reference in a new issue