Mirror of https://github.com/vale981/ray (synced 2025-03-06 02:21:39 -05:00)

[rllib] [experimental] custom RL training pipelines (PG_pl, A2C_pl) (#7213)

parent 7bef7031c2
commit 46af992efd
11 changed files with 447 additions and 9 deletions

rllib/BUILD

@@ -66,6 +66,14 @@ py_test(
     srcs = ["agents/dqn/tests/test_dqn.py"]
 )
 
+# A2CTrainer
+py_test(
+    name = "test_a2c",
+    tags = ["agents_dir"],
+    size = "small",
+    srcs = ["agents/a3c/tests/test_a2c.py"]
+)
+
 # PGTrainer
 py_test(
     name = "test_pg",

rllib/agents/a3c/__init__.py

@@ -1,10 +1,12 @@
 from ray.rllib.agents.a3c.a3c import A3CTrainer, DEFAULT_CONFIG
 from ray.rllib.agents.a3c.a2c import A2CTrainer
+from ray.rllib.agents.a3c.a2c_pipeline import A2CPipeline
 from ray.rllib.utils import renamed_agent
 
 A2CAgent = renamed_agent(A2CTrainer)
 A3CAgent = renamed_agent(A3CTrainer)
 
 __all__ = [
-    "A2CAgent", "A3CAgent", "A2CTrainer", "A3CTrainer", "DEFAULT_CONFIG"
+    "A2CAgent", "A3CAgent", "A2CTrainer", "A3CTrainer", "DEFAULT_CONFIG",
+    "A2CPipeline"
 ]

rllib/agents/a3c/a2c_pipeline.py (new file, 38 lines)

"""Experimental pipeline-based impl; run this with --run='A2C_pl'"""

import math

from ray.rllib.agents.a3c.a2c import A2CTrainer
from ray.rllib.utils.experimental_dsl import (
    ParallelRollouts, ConcatBatches, ComputeGradients, AverageGradients,
    ApplyGradients, TrainOneStep, StandardMetricsReporting)


def training_pipeline(workers, config):
    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    if config["microbatch_size"]:
        num_microbatches = math.ceil(
            config["train_batch_size"] / config["microbatch_size"])
        # In microbatch mode, we want to compute gradients on experience
        # microbatches, average a number of these microbatches, and then apply
        # the averaged gradient in one SGD step. This conserves GPU memory,
        # allowing for extremely large experience batches to be used.
        train_op = (
            rollouts.combine(
                ConcatBatches(min_batch_size=config["microbatch_size"]))
            .for_each(ComputeGradients(workers))  # (grads, info)
            .batch(num_microbatches)  # List[(grads, info)]
            .for_each(AverageGradients())  # (avg_grads, info)
            .for_each(ApplyGradients(workers)))
    else:
        # In normal mode, we execute one SGD step per each train batch.
        train_op = rollouts \
            .combine(ConcatBatches(
                min_batch_size=config["train_batch_size"])) \
            .for_each(TrainOneStep(workers))

    return StandardMetricsReporting(train_op, workers, config)


A2CPipeline = A2CTrainer.with_updates(training_pipeline=training_pipeline)
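
The microbatch branch above trades GPU memory for extra gradient passes: with train_batch_size=200 and microbatch_size=10, gradients are computed on 20 microbatches and combined by AverageGradients before a single ApplyGradients update. A minimal usage sketch (the config values and the CartPole-v0 environment are illustrative, mirroring the tests below; assumes a working Ray installation):

    import math

    import ray
    from ray.rllib.agents.a3c.a2c_pipeline import A2CPipeline

    ray.init()
    # Illustrative values: 20 gradient microbatches are combined per update.
    config = {"train_batch_size": 200, "microbatch_size": 10, "num_workers": 1}
    print("microbatches per update:",
          math.ceil(config["train_batch_size"] / config["microbatch_size"]))

    trainer = A2CPipeline(env="CartPole-v0", config=config)
    result = trainer.train()  # returns the usual RLlib result dict
    print(result["episode_reward_mean"])
    ray.shutdown()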

rllib/agents/a3c/tests/test_a2c.py (new file, 34 lines)

import unittest

import ray
from ray.rllib.agents.a3c import a2c_pipeline


class TestA2C(unittest.TestCase):
    """Sanity tests for A2CPipeline."""

    def setUp(self):
        ray.init()

    def tearDown(self):
        ray.shutdown()

    def test_a2c_pipeline(ray_start_regular):
        trainer = a2c_pipeline.A2CPipeline(
            env="CartPole-v0", config={"min_iter_time_s": 0})
        assert isinstance(trainer.train(), dict)

    def test_a2c_pipeline_microbatch(ray_start_regular):
        trainer = a2c_pipeline.A2CPipeline(
            env="CartPole-v0",
            config={
                "min_iter_time_s": 0,
                "microbatch_size": 10
            })
        assert isinstance(trainer.train(), dict)


if __name__ == "__main__":
    import pytest
    import sys
    sys.exit(pytest.main(["-v", __file__]))

rllib/agents/pg/__init__.py

@@ -1,7 +1,10 @@
 from ray.rllib.agents.pg.pg import PGTrainer, DEFAULT_CONFIG
+from ray.rllib.agents.pg.pg_pipeline import PGPipeline
 from ray.rllib.agents.pg.pg_tf_policy import pg_tf_loss, \
     post_process_advantages
 from ray.rllib.agents.pg.pg_torch_policy import pg_torch_loss
 
-__all__ = ["PGTrainer", "pg_tf_loss", "pg_torch_loss",
-           "post_process_advantages", "DEFAULT_CONFIG"]
+__all__ = [
+    "PGTrainer", "pg_tf_loss", "pg_torch_loss", "post_process_advantages",
+    "DEFAULT_CONFIG", "PGPipeline"
+]

rllib/agents/pg/pg_pipeline.py (new file, 24 lines)

"""Experimental pipeline-based impl; run this with --run='PG_pl'"""

from ray.rllib.agents.pg.pg import PGTrainer
from ray.rllib.utils.experimental_dsl import (
    ParallelRollouts, ConcatBatches, TrainOneStep, StandardMetricsReporting)


def training_pipeline(workers, config):
    # Collects experiences in parallel from multiple RolloutWorker actors.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # Combine experience batches until we hit `train_batch_size` in size.
    # Then, train the policy on those experiences and update the workers.
    train_op = rollouts \
        .combine(ConcatBatches(
            min_batch_size=config["train_batch_size"])) \
        .for_each(TrainOneStep(workers))

    # Add on the standard episode reward, etc. metrics reporting. This returns
    # a LocalIterator[metrics_dict] representing metrics for each train step.
    return StandardMetricsReporting(train_op, workers, config)


PGPipeline = PGTrainer.with_updates(training_pipeline=training_pipeline)
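
Because PGPipeline is an ordinary Trainer produced by with_updates(), it can also be launched through Tune once the "PG_pl" registry alias added later in this diff is in place. A brief sketch (the stop criterion and num_workers value are illustrative):

    import ray
    from ray import tune

    ray.init()
    # "PG_pl" is the experimental registry alias registered below.
    tune.run(
        "PG_pl",
        stop={"training_iteration": 2},
        config={"env": "CartPole-v0", "num_workers": 1})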

rllib/agents/pg/tests/test_pg.py

@@ -3,6 +3,7 @@ import unittest
 
 import ray
 import ray.rllib.agents.pg as pg
+from ray.rllib.agents.pg import PGPipeline
 from ray.rllib.evaluation.postprocessing import Postprocessing
 from ray.rllib.models.tf.tf_action_dist import Categorical
 from ray.rllib.models.torch.torch_action_dist import TorchCategorical

@@ -11,8 +12,15 @@ from ray.rllib.utils import check, fc
 
 
 class TestPG(unittest.TestCase):
+    def setUp(self):
+        ray.init()
 
-    ray.init()
+    def tearDown(self):
+        ray.shutdown()
+
+    def test_pg_pipeline(ray_start_regular):
+        trainer = PGPipeline(env="CartPole-v0", config={"min_iter_time_s": 0})
+        assert isinstance(trainer.train(), dict)
 
     def test_pg_compilation(self):
         """Test whether a PGTrainer can be built with both frameworks."""

@@ -101,5 +109,6 @@ class TestPG(unittest.TestCase):
 
 
 if __name__ == "__main__":
-    import unittest
-    unittest.main(verbosity=1)
+    import pytest
+    import sys
+    sys.exit(pytest.main(["-v", __file__]))

rllib/agents/registry.py

@@ -100,6 +100,16 @@ def _import_marwil():
     return marwil.MARWILTrainer
 
 
+def _import_a2c_pipeline():
+    from ray.rllib.agents import a3c
+    return a3c.A2CPipeline
+
+
+def _import_pg_pipeline():
+    from ray.rllib.agents import pg
+    return pg.PGPipeline
+
+
 ALGORITHMS = {
     "SAC": _import_sac,
     "DDPG": _import_ddpg,

@@ -120,6 +130,10 @@ ALGORITHMS = {
     "APPO": _import_appo,
     "DDPPO": _import_ddppo,
     "MARWIL": _import_marwil,
+
+    # Experimental pipeline-based impls.
+    "A2C_pl": _import_a2c_pipeline,
+    "PG_pl": _import_pg_pipeline,
 }
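
With the aliases registered, the pipeline trainers resolve like any other built-in algorithm, e.g. on the command line with rllib train --run=A2C_pl as the module docstrings above suggest, or programmatically. A small sketch, assuming RLlib's standard get_agent_class() helper in ray.rllib.agents.registry:

    import ray
    from ray.rllib.agents.registry import get_agent_class

    ray.init()
    # Resolve the experimental alias to its trainer class and run one iteration.
    cls = get_agent_class("A2C_pl")
    trainer = cls(env="CartPole-v0", config={"min_iter_time_s": 0})
    print(trainer.train()["episodes_this_iter"])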

rllib/agents/trainer_template.py

@@ -22,7 +22,8 @@ def build_trainer(name,
                   after_train_result=None,
                   collect_metrics_fn=None,
                   before_evaluate_fn=None,
-                  mixins=None):
+                  mixins=None,
+                  training_pipeline=None):
     """Helper function for defining a custom trainer.
 
     Functions will be run in this order to initialize the trainer:

@@ -66,6 +67,8 @@ def build_trainer(name,
         mixins (list): list of any class mixins for the returned trainer class.
             These mixins will be applied in order and will have higher
             precedence than the Trainer class
+        training_pipeline (func): Experimental support for custom
+            training pipelines. This overrides `make_policy_optimizer`.
 
     Returns:
         a Trainer instance that uses the specified args.

@@ -100,7 +103,12 @@
             else:
                 self.workers = self._make_workers(env_creator, policy, config,
                                                   self.config["num_workers"])
-            if make_policy_optimizer:
+            self.train_pipeline = None
+            self.optimizer = None
+
+            if training_pipeline:
+                self.train_pipeline = training_pipeline(self.workers, config)
+            elif make_policy_optimizer:
                 self.optimizer = make_policy_optimizer(self.workers, config)
             else:
                 optimizer_config = dict(

@@ -113,6 +121,9 @@
 
         @override(Trainer)
         def _train(self):
+            if self.train_pipeline:
+                return self._train_pipeline()
+
             if before_train_step:
                 before_train_step(self)
             prev_steps = self.optimizer.num_steps_sampled

@@ -140,6 +151,14 @@
                 after_train_result(self, res)
             return res
 
+        def _train_pipeline(self):
+            if before_train_step:
+                before_train_step(self)
+            res = next(self.train_pipeline)
+            if after_train_result:
+                after_train_result(self, res)
+            return res
+
         @override(Trainer)
         def _before_evaluate(self):
             if before_evaluate_fn:
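
The new training_pipeline hook lets build_trainer() produce a trainer that never constructs a policy optimizer. A hypothetical sketch of a custom trainer defined this way (PGTFPolicy and DEFAULT_CONFIG are RLlib's stock PG objects, the pipeline mirrors pg_pipeline.py above, and the trainer name is made up):

    from ray.rllib.agents.pg.pg import DEFAULT_CONFIG
    from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
    from ray.rllib.agents.trainer_template import build_trainer
    from ray.rllib.utils.experimental_dsl import (
        ParallelRollouts, ConcatBatches, TrainOneStep, StandardMetricsReporting)


    def my_pipeline(workers, config):
        # Same shape as pg_pipeline.py: rollout -> concat -> train -> report.
        rollouts = ParallelRollouts(workers, mode="bulk_sync")
        train_op = rollouts \
            .combine(ConcatBatches(min_batch_size=config["train_batch_size"])) \
            .for_each(TrainOneStep(workers))
        return StandardMetricsReporting(train_op, workers, config)


    # training_pipeline takes precedence over make_policy_optimizer in the
    # generated _init()/_train() methods shown above.
    MyPipelineTrainer = build_trainer(
        name="MyPipelineTrainer",
        default_config=DEFAULT_CONFIG,
        default_policy=PGTFPolicy,
        training_pipeline=my_pipeline)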

rllib/evaluation/rollout_worker.py

@@ -5,6 +5,7 @@ import logging
 import pickle
 
 import ray
+from ray.util.iter import ParallelIteratorWorker
 from ray.rllib.env.atari_wrappers import wrap_deepmind, is_atari
 from ray.rllib.env.base_env import BaseEnv
 from ray.rllib.env.env_context import EnvContext

@@ -52,7 +53,7 @@ def get_global_worker():
 
 
 @DeveloperAPI
-class RolloutWorker(EvaluatorInterface):
+class RolloutWorker(EvaluatorInterface, ParallelIteratorWorker):
     """Common experience collection class.
 
     This class wraps a policy instance and an environment class to

@@ -241,6 +242,12 @@ class RolloutWorker(EvaluatorInterface):
         global _global_worker
         _global_worker = self
 
+        def gen_rollouts():
+            while True:
+                yield self.sample()
+
+        ParallelIteratorWorker.__init__(self, gen_rollouts, False)
+
         policy_config = policy_config or {}
         if (tf and policy_config.get("eager")
                 and not policy_config.get("no_eager_on_workers")):
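
Because RolloutWorker now also initializes ParallelIteratorWorker with a gen_rollouts() generator, remote workers can be consumed directly as a parallel iterator of sample batches, which is what ParallelRollouts below builds on. A small sketch, assuming `workers` is an already-built WorkerSet with remote workers:

    from ray.util.iter import from_actors

    # Each shard is backed by one remote RolloutWorker actor that repeatedly
    # yields self.sample() via gen_rollouts().
    it = from_actors(workers.remote_workers())
    first_batch = next(it.gather_async())  # first SampleBatch to arrive
    print(first_batch.count)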

rllib/utils/experimental_dsl.py (new file, 280 lines)

"""Experimental operators for defining distributed training pipelines.

TODO(ekl): describe the concepts."""

from typing import List, Any
import time

import ray
from ray.util.iter import from_actors, LocalIterator
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.policy.sample_batch import SampleBatch


def ParallelRollouts(workers: WorkerSet,
                     mode="bulk_sync") -> LocalIterator[SampleBatch]:
    """Operator to collect experiences in parallel from rollout workers.

    If there are no remote workers, experiences will be collected serially from
    the local worker instance instead.

    Arguments:
        workers (WorkerSet): set of rollout workers to use.
        mode (str): One of {'async', 'bulk_sync'}.
            - In 'async' mode, batches are returned as soon as they are
              computed by rollout workers with no order guarantees.
            - In 'bulk_sync' mode, we collect one batch from each worker
              and concatenate them together into a large batch to return.

    Returns:
        A local iterator over experiences collected in parallel.

    Examples:
        >>> rollouts = ParallelRollouts(workers, mode="async")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        50  # config.sample_batch_size

        >>> rollouts = ParallelRollouts(workers, mode="bulk_sync")
        >>> batch = next(rollouts)
        >>> print(batch.count)
        200  # config.sample_batch_size * config.num_workers
    """

    if not workers.remote_workers():
        # Handle the serial sampling case.
        def sampler(_):
            while True:
                yield workers.local_worker().sample()

        return LocalIterator(sampler)

    # Create a parallel iterator over generated experiences.
    rollouts = from_actors(workers.remote_workers())

    if mode == "bulk_sync":
        return rollouts \
            .batch_across_shards() \
            .for_each(lambda batches: SampleBatch.concat_samples(batches))
    elif mode == "async":
        return rollouts.gather_async()
    else:
        raise ValueError(
            "mode must be one of 'bulk_sync', 'async', got '{}'".format(mode))


def StandardMetricsReporting(train_op: LocalIterator[Any], workers: WorkerSet,
                             config: dict):
    """Operator to periodically collect and report metrics.

    Arguments:
        train_op (LocalIterator): Operator for executing training steps.
            We ignore the output values.
        workers (WorkerSet): Rollout workers to collect metrics from.
        config (dict): Trainer configuration, used to determine the frequency
            of stats reporting.

    Returns:
        A local iterator over training results.

    Examples:
        >>> train_op = ParallelRollouts(...).for_each(TrainOneStep(...))
        >>> metrics_op = StandardMetricsReporting(train_op, workers, config)
        >>> next(metrics_op)
        {"episode_reward_max": ..., "episode_reward_mean": ..., ...}
    """

    output_op = train_op \
        .filter(OncePerTimeInterval(config["min_iter_time_s"])) \
        .for_each(CollectMetrics(
            workers, min_history=config["metrics_smoothing_episodes"],
            timeout_seconds=config["collect_metrics_timeout"]))
    return output_op


class ConcatBatches:
    """Callable used to merge batches into larger batches for training.

    This should be used with the .combine() operator.

    Examples:
        >>> rollouts = ParallelRollouts(...)
        >>> rollouts = rollouts.combine(ConcatBatches(min_batch_size=10000))
        >>> print(next(rollouts).count)
        10000
    """

    def __init__(self, min_batch_size: int):
        self.min_batch_size = min_batch_size
        self.buffer = []
        self.count = 0

    def __call__(self, batch: SampleBatch) -> List[SampleBatch]:
        if not isinstance(batch, SampleBatch):
            raise ValueError("Expected type SampleBatch, got {}: {}".format(
                type(batch), batch))
        self.buffer.append(batch)
        self.count += batch.count
        if self.count >= self.min_batch_size:
            out = SampleBatch.concat_samples(self.buffer)
            self.buffer = []
            self.count = 0
            return [out]
        return []


class TrainOneStep:
    """Callable that improves the policy and updates workers.

    This should be used with the .for_each() operator.

    Examples:
        >>> rollouts = ParallelRollouts(...)
        >>> train_op = rollouts.for_each(TrainOneStep(workers))
        >>> print(next(train_op))  # This trains the policy on one batch.
        {"learner_stats": {"policy_loss": ...}}
    """

    def __init__(self, workers: WorkerSet):
        self.workers = workers

    def __call__(self, batch: SampleBatch) -> List[dict]:
        info = self.workers.local_worker().learn_on_batch(batch)
        if self.workers.remote_workers():
            weights = ray.put(self.workers.local_worker().get_weights())
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights)
        return info


class CollectMetrics:
    """Callable that collects metrics from workers.

    The metrics are smoothed over a given history window.

    This should be used with the .for_each() operator. For a higher level
    API, consider using StandardMetricsReporting instead.

    Examples:
        >>> output_op = train_op.for_each(CollectMetrics(workers))
        >>> print(next(output_op))
        {"episode_reward_max": ..., "episode_reward_mean": ..., ...}
    """

    def __init__(self, workers, min_history=100, timeout_seconds=180):
        self.workers = workers
        self.episode_history = []
        self.to_be_collected = []
        self.min_history = min_history
        self.timeout_seconds = timeout_seconds

    def __call__(self, info):
        episodes, self.to_be_collected = collect_episodes(
            self.workers.local_worker(),
            self.workers.remote_workers(),
            self.to_be_collected,
            timeout_seconds=self.timeout_seconds)
        orig_episodes = list(episodes)
        missing = self.min_history - len(episodes)
        if missing > 0:
            episodes.extend(self.episode_history[-missing:])
            assert len(episodes) <= self.min_history
        self.episode_history.extend(orig_episodes)
        self.episode_history = self.episode_history[-self.min_history:]
        res = summarize_episodes(episodes, orig_episodes)
        res.update(info=info)
        return res


class OncePerTimeInterval:
    """Callable that returns True once per given interval.

    This should be used with the .filter() operator to throttle / rate-limit
    metrics reporting. For a higher-level API, consider using
    StandardMetricsReporting instead.

    Examples:
        >>> throttled_op = train_op.filter(OncePerTimeInterval(5))
        >>> start = time.time()
        >>> next(throttled_op)
        >>> print(time.time() - start)
        5.00001  # will be greater than 5 seconds
    """

    def __init__(self, delay):
        self.delay = delay
        self.last_called = 0

    def __call__(self, item):
        now = time.time()
        if now - self.last_called > self.delay:
            self.last_called = now
            return True
        return False


class ComputeGradients:
    """Callable that computes gradients with respect to the policy loss.

    This should be used with the .for_each() operator.

    Examples:
        >>> grads_op = rollouts.for_each(ComputeGradients(workers))
        >>> print(next(grads_op))
        {"var_0": ..., ...}, {"learner_stats": ...}  # grads, learner info
    """

    def __init__(self, workers):
        self.workers = workers

    def __call__(self, samples):
        grad, info = self.workers.local_worker().compute_gradients(samples)
        return grad, info


class ApplyGradients:
    """Callable that applies gradients and updates workers.

    This should be used with the .for_each() operator.

    Examples:
        >>> apply_op = grads_op.for_each(ApplyGradients(workers))
        >>> print(next(apply_op))
        {"learner_stats": ...}  # learner info
    """

    def __init__(self, workers):
        self.workers = workers

    def __call__(self, item):
        gradients, info = item
        self.workers.local_worker().apply_gradients(gradients)
        if self.workers.remote_workers():
            weights = ray.put(self.workers.local_worker().get_weights())
            for e in self.workers.remote_workers():
                e.set_weights.remote(weights)
        return info


class AverageGradients:
    """Callable that averages the gradients in a batch.

    This should be used with the .for_each() operator after a set of gradients
    have been batched with .batch().

    Examples:
        >>> batched_grads = grads_op.batch(32)
        >>> avg_grads = batched_grads.for_each(AverageGradients())
        >>> print(next(avg_grads))
        {"var_0": ..., ...}, {"learner_stats": ...}  # avg grads, last info
    """

    def __call__(self, gradients):
        acc = None
        for grad, info in gradients:
            if acc is None:
                acc = grad
            else:
                acc = [a + b for a, b in zip(acc, grad)]
        return acc, info
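
Since every operator above is a plain callable, user-defined steps can be spliced into a pipeline the same way. A hypothetical example of a custom step used with .for_each():

    import time


    class StampTime:
        """Hypothetical custom step: tag each result with wall-clock time."""

        def __call__(self, result):
            result["timestamp"] = time.time()
            return result


    # e.g. train_op = rollouts.for_each(TrainOneStep(workers)).for_each(StampTime())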