[tune] Support lambda functions in hyperparameters / tune rllib multiagent support (#2568)

* update

* func

* Update registry.py

* revert
This commit is contained in:
Eric Liang 2018-08-07 16:29:21 -07:00 committed by GitHub
parent e7f76d7914
commit 64053278aa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 40 additions and 23 deletions

View file

@ -34,8 +34,8 @@ dictionary.
"trial_resources": { "cpu": 1, "gpu": 0 },
"stop": { "mean_accuracy": 100 },
"config": {
"alpha": grid_search([0.2, 0.4, 0.6]),
"beta": grid_search([1, 2]),
"alpha": tune.grid_search([0.2, 0.4, 0.6]),
"beta": tune.grid_search([1, 2]),
},
"upload_dir": "s3://your_bucket/path",
"local_dir": "~/ray_results",
@ -49,7 +49,7 @@ An example of this can be found in `async_hyperband_example.py <https://github.c
Trial Variant Generation
------------------------
In the above example, we specified a grid search over two parameters using the ``grid_search`` helper function. Ray Tune also supports sampling parameters from user-specified lambda functions, which can be used in combination with grid search.
In the above example, we specified a grid search over two parameters using the ``tune.grid_search`` helper function. Ray Tune also supports sampling parameters from user-specified lambda functions, which can be used in combination with grid search.
The following shows grid search over two nested parameters combined with random sampling from two lambda functions. Note that the value of ``beta`` depends on the value of ``alpha``, which is represented by referencing ``spec.config.alpha`` in the lambda function. This lets you specify conditional parameter distributions.
@ -59,14 +59,18 @@ The following shows grid search over two nested parameters combined with random
"alpha": lambda spec: np.random.uniform(100),
"beta": lambda spec: spec.config.alpha * np.random.normal(),
"nn_layers": [
grid_search([16, 64, 256]),
grid_search([16, 64, 256]),
tune.grid_search([16, 64, 256]),
tune.grid_search([16, 64, 256]),
],
},
"repeat": 10,
By default, each random variable and grid search point is sampled once. To take multiple random samples or repeat grid search runs, add ``repeat: N`` to the experiment config. E.g. in the above, ``"repeat": 10`` repeats the 3x3 grid search 10 times, for a total of 90 trials, each with randomly sampled values of ``alpha`` and ``beta``.
.. note::
Lambda functions will be evaluated during trial variant generation. If you need to pass a literal function in your config, use ``tune.function(...)`` to escape it.
For more information on variant generation, see `basic_variant.py <https://github.com/ray-project/ray/blob/master/python/ray/tune/suggest/basic_variant.py>`__.
Resource Allocation

View file

@ -17,10 +17,10 @@ import gym
import random
import ray
from ray.rllib.agents.pg.pg import PGAgent
from ray import tune
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.test.test_multi_agent_env import MultiCartpole
from ray.tune.logger import pretty_print
from ray.tune import run_experiments
from ray.tune.registry import register_env
parser = argparse.ArgumentParser()
@ -53,16 +53,19 @@ if __name__ == "__main__":
}
policy_ids = list(policy_graphs.keys())
agent = PGAgent(
env="multi_cartpole",
config={
"multiagent": {
"policy_graphs": policy_graphs,
"policy_mapping_fn": (
lambda agent_id: random.choice(policy_ids)),
run_experiments({
"test": {
"run": "PG",
"env": "multi_cartpole",
"stop": {
"training_iteration": args.num_iters
},
})
for i in range(args.num_iters):
print("== Iteration", i, "==")
print(pretty_print(agent.train()))
"config": {
"multiagent": {
"policy_graphs": policy_graphs,
"policy_mapping_fn": tune.function(
lambda agent_id: random.choice(policy_ids)),
},
},
}
})

View file

@ -7,9 +7,9 @@ from ray.tune.tune import run_experiments
from ray.tune.experiment import Experiment
from ray.tune.registry import register_env, register_trainable
from ray.tune.trainable import Trainable
from ray.tune.suggest import grid_search
from ray.tune.suggest import grid_search, function
__all__ = [
"Trainable", "TuneError", "grid_search", "register_env",
"register_trainable", "run_experiments", "Experiment"
"register_trainable", "run_experiments", "Experiment", "function"
]

View file

@ -2,9 +2,9 @@ from ray.tune.suggest.search import SearchAlgorithm
from ray.tune.suggest.basic_variant import BasicVariantGenerator
from ray.tune.suggest.suggestion import SuggestionAlgorithm
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.suggest.variant_generator import grid_search
from ray.tune.suggest.variant_generator import grid_search, function
__all__ = [
"SearchAlgorithm", "BasicVariantGenerator", "HyperOptSearch",
"SuggestionAlgorithm", "grid_search"
"SuggestionAlgorithm", "grid_search", "function"
]

View file

@ -51,6 +51,16 @@ def grid_search(values):
return {"grid_search": values}
class function(object):
    """Escape hatch for config values: wraps a callable so that variant
    resolution treats it as a literal value instead of expanding it."""

    def __init__(self, func):
        # Keep a reference to the wrapped callable; resolution code only
        # needs to detect this wrapper type and read ``.func``.
        self.func = func

    def __call__(self, *args, **kwargs):
        # Behave exactly like the wrapped callable when invoked.
        wrapped = self.func
        return wrapped(*args, **kwargs)
_STANDARD_IMPORTS = {
"random": random,
"np": numpy,