diff --git a/python/ray/util/sgd/torch/torch_trainer.py b/python/ray/util/sgd/torch/torch_trainer.py index 805ae37bd..9293eb548 100644 --- a/python/ray/util/sgd/torch/torch_trainer.py +++ b/python/ray/util/sgd/torch/torch_trainer.py @@ -670,6 +670,12 @@ class BaseTorchTrainable(Trainable): You may want to override this if using a custom LR scheduler. """ + if self._is_overriden("_train"): + raise DeprecationWarning( + "Trainable._train is deprecated and will be " + "removed in " + "a future version of Ray. Override Trainable.step instead.") + train_stats = self.trainer.train(max_retries=10, profile=True) validation_stats = self.trainer.validate(profile=True) stats = merge_dicts(train_stats, validation_stats)