supports non-leaf inputs for autograd.backward() function #60521

Closed

7 changes: 5 additions & 2 deletions aten/src/ATen/templates/TensorBody.h
@@ -721,8 +721,11 @@ class TORCH_API Tensor {
/// \param inputs Inputs w.r.t. which the gradient will be accumulated into
/// ``at::Tensor::grad``. All other Tensors will be ignored. If not
/// provided, the gradient is accumulated into all the leaf Tensors
/// that were used to compute the current tensor. All the provided inputs
/// must be leaf Tensors.
/// that were used to compute the current tensor.
/// When inputs are provided and a given input is not a leaf,
/// the current implementation will call its grad_fn (even though it is not strictly needed to get this gradient).
/// It is an implementation detail on which the user should not rely.
/// See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
void backward(const Tensor & gradient={}, c10::optional<bool> retain_graph=c10::nullopt, bool create_graph=false, c10::optional<TensorList> inputs=c10::nullopt) const {
// NB: Adding this wrapper to _backward here because we'd like our
// 'backwards' api to accept the 'inputs' argument optionally. Since code gen
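To illustrate the documented behavior, here is a minimal Python sketch (not part of this diff; the Square function and tensor names are purely illustrative). It requests the gradient of a non-leaf tensor via ``inputs`` and shows that, under the implementation detail noted above, the non-leaf's grad_fn may still run:

    import torch

    class Square(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            ctx.save_for_backward(inp)
            return inp * inp

        @staticmethod
        def backward(ctx, grad_out):
            # Under the current implementation this may run even though it is
            # not strictly needed to produce z.grad below.
            print("Square.backward called")
            inp, = ctx.saved_tensors
            return 2 * inp * grad_out

    x = torch.randn(3, requires_grad=True)
    z = Square.apply(x)          # non-leaf intermediate
    w = (3 * z).sum()

    w.backward(inputs=[z])       # non-leaf inputs are now accepted
    print(z.grad)                # dw/dz = 3 everywhere
    print(x.grad)                # None: x was not listed in `inputs`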
12 changes: 10 additions & 2 deletions test/cpp/api/autograd.cpp
@@ -873,8 +873,16 @@ TEST(CustomAutogradTest, BackwardWithNonLeafInputs) {
Variable x = torch::randn({5,5}, torch::requires_grad());
Variable y = torch::randn({5,5}, torch::requires_grad());
Variable z = x * x;
Variable w = z + x * y + y * y;
ASSERT_THROWS_WITH(w.backward(torch::ones({5, 5}), false, false, {z}), "is not a leaf Tensor");
Variable w = y * z + x * y + y * y;

Variable x_grad_expected = 2 * x * y + y;
Variable z_grad_expected = y;

w.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{x, z});

ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected);
ASSERT_VARIABLE_EQ(z.grad(), z_grad_expected);
ASSERT_FALSE(y.grad().defined());
}

TEST(CustomAutogradTest, BackwardWithCreateGraphWarns) {
8 changes: 3 additions & 5 deletions test/test_autograd.py
@@ -1010,16 +1010,14 @@ def test_backward_with_nonleaf_inputs(self):

out = x_nonleaf ** 2 + y * x_nonleaf + y ** 2

out.backward(torch.ones(2, 2, dtype=torch.double), create_graph=True, inputs=[x, y])
out.backward(torch.ones(2, 2, dtype=torch.double), create_graph=True, inputs=[x, y, x_nonleaf])
x_grad_expected = 2 * x + y
y_grad_expected = x + 2 * y
x_non_leaf_expected = 2 * x_nonleaf + y

self.assertEqual(y.grad, y_grad_expected)
self.assertEqual(x.grad, x_grad_expected)

self.assertRaisesRegex(RuntimeError, 'not a leaf Tensor',
lambda: out.backward(torch.ones(2, 2, dtype=torch.double),
create_graph=True, inputs=[x, y, x_nonleaf]))
self.assertEqual(x_nonleaf.grad, x_non_leaf_expected)

# backward doesn't have an allow_unused flag, so the behavior of backward
# when variable is not part of the graph is as if allow_used were true
10 changes: 8 additions & 2 deletions torch/_tensor.py
@@ -223,6 +223,13 @@ def backward(self, gradient=None, retain_graph=None, create_graph=False, inputs=
in a user-specified CUDA stream context, see
:ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.

.. note::

When ``inputs`` are provided and a given input is not a leaf,
the current implementation will call its grad_fn (though it is not strictly needed to get this gradient).
It is an implementation detail on which the user should not rely.
See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.

Args:
gradient (Tensor or None): Gradient w.r.t. the
tensor. If it is a tensor, it will be automatically converted
@@ -241,8 +248,7 @@ def backward(self, gradient=None, retain_graph=None, create_graph=False, inputs=
inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be
accumulated into ``.grad``. All other Tensors will be ignored. If not
provided, the gradient is accumulated into all the leaf Tensors that were
used to compute the attr::tensors. All the provided inputs must be leaf
Tensors.
used to compute the attr::tensors.
"""
if has_torch_function_unary(self):
return handle_torch_function(
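A short Python sketch mirroring the updated test above (illustrative only, not part of the diff, and assuming x_nonleaf is defined as something like ``x * 1``): gradients are accumulated for every tensor listed in ``inputs``, leaf or not, and for nothing else:

    import torch

    x = torch.rand(2, 2, dtype=torch.double, requires_grad=True)
    y = torch.rand(2, 2, dtype=torch.double, requires_grad=True)
    x_nonleaf = x * 1            # non-leaf with requires_grad=True

    out = x_nonleaf ** 2 + y * x_nonleaf + y ** 2
    out.backward(torch.ones(2, 2, dtype=torch.double),
                 inputs=[x, y, x_nonleaf])

    print(x.grad)                # 2 * x + y
    print(y.grad)                # x + 2 * y
    print(x_nonleaf.grad)        # 2 * x_nonleaf + y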
10 changes: 8 additions & 2 deletions torch/autograd/__init__.py
@@ -103,6 +103,13 @@ def backward(
in a user-specified CUDA stream context, see
:ref:`Stream semantics of backward passes<bwd-cuda-stream-semantics>`.

.. note::

When ``inputs`` are provided and a given input is not a leaf,
the current implementation will call its grad_fn (even though it is not strictly needed to get this gradient).
It is an implementation detail on which the user should not rely.
See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.

Args:
tensors (Sequence[Tensor] or Tensor): Tensors of which the derivative will be
computed.
@@ -121,8 +128,7 @@ def backward(
inputs (Sequence[Tensor] or Tensor, optional): Inputs w.r.t. which the gradient
will be accumulated into ``.grad``. All other Tensors will be ignored. If
not provided, the gradient is accumulated into all the leaf Tensors that
were used to compute the attr::tensors. All the provided inputs must be leaf
Tensors.
were used to compute the attr::tensors.
"""
if grad_variables is not None:
warnings.warn("'grad_variables' is deprecated. Use 'grad_tensors' instead.")
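For the functional form documented here, a hedged sketch (not part of this diff; tensor names are illustrative) of the same behavior:

    import torch

    a = torch.randn(3, requires_grad=True)
    b = a.exp()                      # non-leaf intermediate
    loss = (2 * b).sum()

    # Gradients are accumulated only into the tensors listed in `inputs`;
    # non-leaf tensors such as `b` are now accepted there.
    torch.autograd.backward([loss], inputs=[b])

    print(b.grad)                    # d loss / d b = 2 everywhere
    print(a.grad)                    # None: `a` was not listed in `inputs`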
5 changes: 1 addition & 4 deletions torch/csrc/autograd/autograd.cpp
@@ -94,10 +94,7 @@ variable_list run_backward(
grad_fn = impl::try_get_grad_accumulator(input);
}
if (accumulate_grad) {
TORCH_CHECK(
input.is_leaf(),
"One of the differentiated Tensors given as 'inputs' to backward is not a leaf Tensor"
)
input.retain_grad();
}
TORCH_CHECK(
input.requires_grad(),
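Replacing the leaf check with input.retain_grad() means that passing a non-leaf through ``inputs`` behaves roughly like the manual retain_grad() pattern, except that only the listed tensors receive gradients. A small Python sketch of that reading (illustrative, not part of the diff):

    import torch

    x = torch.randn(4, requires_grad=True)
    z = x * x                    # non-leaf
    w = z.sum()

    # Roughly equivalent to: z.retain_grad(); w.backward()
    # except that x.grad is left untouched because x is not listed in `inputs`.
    w.backward(inputs=[z])

    print(z.grad)                # dw/dz = ones(4)
    print(x.grad)                # None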
5 changes: 4 additions & 1 deletion torch/csrc/autograd/autograd.h
@@ -35,7 +35,10 @@ namespace autograd {
/// \param inputs Inputs w.r.t. which the gradient will be accumulated into
/// `at::Tensor::grad`. All other Tensors will be ignored. If not provided, the gradient
/// is accumulated into all the leaf Tensors that were used to compute param `tensors`.
/// All the provided inputs must be leaf Tensors.
/// When inputs are provided and a given input is not a leaf,
/// the current implementation will call its grad_fn (even though it is not strictly needed to get this gradient).
/// It is an implementation detail on which the user should not rely.
/// See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
TORCH_API void backward(
const variable_list& tensors,
const variable_list& grad_tensors = {},
3 changes: 1 addition & 2 deletions torch/csrc/autograd/python_engine.cpp
@@ -239,8 +239,7 @@ PyObject *THPEngine_run_backward(PyObject *self, PyObject *args, PyObject *kwarg
grad_fn = torch::autograd::impl::try_get_grad_accumulator(tensor);
}
if (accumulate_grad) {
THPUtils_assert(tensor.is_leaf(),
"One of the differentiated Tensors given as 'inputs' to backward is not a leaf Tensor");
tensor.retain_grad();
}
THPUtils_assert(tensor.requires_grad(),
"One of the differentiated Tensors does not require grad");