Implement static shape inference for AdvancedSubtensor #1566

Merged
190 changes: 125 additions & 65 deletions pytensor/tensor/subtensor.py
@@ -2,7 +2,7 @@
import sys
import warnings
from collections.abc import Callable, Iterable, Sequence
from itertools import chain, groupby
from itertools import chain, groupby, zip_longest
from typing import cast, overload

import numpy as np
@@ -39,7 +39,7 @@
from pytensor.tensor.blockwise import vectorize_node_fallback
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
from pytensor.tensor.math import clip
from pytensor.tensor.math import add, clip
from pytensor.tensor.shape import Reshape, Shape_i, specify_broadcastable
from pytensor.tensor.type import (
TensorType,
@@ -63,6 +63,7 @@
from pytensor.tensor.type_other import (
MakeSlice,
NoneConst,
NoneSliceConst,
NoneTypeT,
SliceConstant,
SliceType,
@@ -844,6 +845,24 @@ def as_nontensor_scalar(a: Variable) -> ps.ScalarVariable:
return ps.as_scalar(a)


def slice_static_length(slc, dim_length):
if dim_length is None:
# TODO: Some cases must be zero by definition, we could handle those
return None

entries = [None, None, None]
for i, entry in enumerate((slc.start, slc.stop, slc.step)):
if entry is None:
continue

try:
entries[i] = get_scalar_constant_value(entry)
except NotScalarConstantError:
return None

return len(range(*slice(*entries).indices(dim_length)))
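For reference, the new helper leans entirely on the Python stdlib: slice.indices(length) clips and normalizes start/stop/step against a known dimension length, and len(range(...)) turns the normalized bounds into the selected extent. A minimal sketch of the trick in plain Python:

```python
# How slice_static_length resolves a constant slice against a known length:
length = 6
assert len(range(*slice(None, None, -1).indices(length))) == 6  # x[::-1] keeps all 6
assert len(range(*slice(1, None, 2).indices(length))) == 3      # x[1::2] -> indices 1, 3, 5
assert len(range(*slice(10, 20).indices(length))) == 0          # out-of-range slice -> empty
```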


class Subtensor(COp):
"""Basic NumPy indexing operator."""

@@ -886,50 +905,15 @@ def make_node(self, x, *inputs):
)

padded = [
*get_idx_list((None, *inputs), self.idx_list),
*indices_from_subtensor(inputs, self.idx_list),
*[slice(None, None, None)] * (x.type.ndim - len(idx_list)),
]

out_shape = []

def extract_const(value):
if value is None:
return value, True
try:
value = get_scalar_constant_value(value)
return value, True
except NotScalarConstantError:
return value, False

for the_slice, length in zip(padded, x.type.shape, strict=True):
if not isinstance(the_slice, slice):
continue

if length is None:
out_shape.append(None)
continue

start = the_slice.start
stop = the_slice.stop
step = the_slice.step

is_slice_const = True

start, is_const = extract_const(start)
is_slice_const = is_slice_const and is_const

stop, is_const = extract_const(stop)
is_slice_const = is_slice_const and is_const

step, is_const = extract_const(step)
is_slice_const = is_slice_const and is_const

if not is_slice_const:
out_shape.append(None)
continue

slice_length = len(range(*slice(start, stop, step).indices(length)))
out_shape.append(slice_length)
out_shape = [
slice_static_length(slc, length)
for slc, length in zip(padded, x.type.shape, strict=True)
if isinstance(slc, slice)
]

return Apply(
self,
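The net effect of this refactor on basic indexing (an illustrative sketch, relying only on behavior exercised by the new tests below): constant slices over statically known dims now resolve to concrete lengths, while symbolic bounds or unknown dims stay None.

```python
import pytensor.tensor as pt

x = pt.tensor("x", shape=(10, None))
assert x[1::2].type.shape == (5, None)     # constant slice on a known dim -> static 5
assert x[:, 2:4].type.shape == (10, None)  # unknown dim length -> stays None

i = pt.scalar("i", dtype="int64")
assert x[i:].type.shape == (None, None)    # symbolic start -> unknown
```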
@@ -2826,36 +2810,112 @@ class AdvancedSubtensor(Op):

__props__ = ()

def make_node(self, x, *index):
def make_node(self, x, *indices):
x = as_tensor_variable(x)
index = tuple(map(as_index_variable, index))
indices = tuple(map(as_index_variable, indices))

explicit_indices = []
new_axes = []
for idx in indices:
if isinstance(idx.type, TensorType) and idx.dtype == "bool":
if idx.type.ndim == 0:
raise NotImplementedError(
"Indexing with scalar booleans not supported"
)

# We create a fake symbolic shape tuple and identify the broadcast
# dimensions from the shape result of this entire subtensor operation.
with config.change_flags(compute_test_value="off"):
fake_shape = tuple(
tensor(dtype="int64", shape=()) if s != 1 else 1 for s in x.type.shape
)
# Check that the static shapes align
axis = len(explicit_indices) - len(new_axes)
indexed_shape = x.type.shape[axis : axis + idx.type.ndim]
for j, (indexed_length, indexer_length) in enumerate(
zip(indexed_shape, idx.type.shape)
):
if (
indexed_length is not None
and indexer_length is not None
and indexed_length != indexer_length
):
raise IndexError(
f"boolean index did not match indexed tensor along axis {axis + j};"
f"size of axis is {indexed_length} but size of corresponding boolean axis is {indexer_length}"
)
# Convert boolean indices to integer with nonzero, to reason about static shape next
if isinstance(idx, Constant):
nonzero_indices = [tensor_constant(i) for i in idx.data.nonzero()]
else:
# Note: Sometimes we could infer a shape error by reasoning about the largest possible size of nonzero
# and seeing that other integer indices cannot possibly match it
nonzero_indices = idx.nonzero()
explicit_indices.extend(nonzero_indices)
else:
if isinstance(idx.type, NoneTypeT):
new_axes.append(len(explicit_indices))
explicit_indices.append(idx)

fake_index = tuple(
chain.from_iterable(
pytensor.tensor.basic.nonzero(idx)
if getattr(idx, "ndim", 0) > 0
and getattr(idx, "dtype", None) == "bool"
else (idx,)
for idx in index
)
if (len(explicit_indices) - len(new_axes)) > x.type.ndim:
raise IndexError(
f"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices) - len(new_axes)} were indexed"
)

out_shape = tuple(
i.value if isinstance(i, Constant) else None
for i in indexed_result_shape(fake_shape, fake_index)
)
# Perform basic and advanced indexing shape inference separately
basic_group_shape = []
advanced_indices = []
adv_group_axis = None
last_adv_group_axis = None
expanded_x_shape = tuple(
np.insert(np.array(x.type.shape, dtype=object), new_axes, 1)
)
for i, (idx, dim_length) in enumerate(
zip_longest(explicit_indices, expanded_x_shape, fillvalue=NoneSliceConst)
):
if isinstance(idx.type, NoneTypeT):
basic_group_shape.append(1) # New-axis
elif isinstance(idx.type, SliceType):
if isinstance(idx, Constant):
basic_group_shape.append(slice_static_length(idx.data, dim_length))
elif idx.owner is not None and isinstance(idx.owner.op, MakeSlice):
basic_group_shape.append(
slice_static_length(slice(*idx.owner.inputs), dim_length)
)
else:
# Symbolic root slice (owner is None), or slice operation we don't understand
basic_group_shape.append(None)
else: # TensorType
# Keep track of advanced group axis
if adv_group_axis is None:
# First time we see an advanced index
adv_group_axis, last_adv_group_axis = i, i
elif last_adv_group_axis == (i - 1):
# Another advanced indexing aligned with the first group
last_adv_group_axis = i
else:
# Non-consecutive advanced index: the whole advanced group is moved to the front of the result
adv_group_axis = 0
advanced_indices.append(idx)

if advanced_indices:
try:
# Use variadic add to infer static shape of advanced integer indices
advanced_group_static_shape = add(*advanced_indices).type.shape
except ValueError:
# It fails when static shapes are inconsistent
static_shapes = [idx.type.shape for idx in advanced_indices]
raise IndexError(
f"shape mismatch: indexing tensors could not be broadcast together with shapes {static_shapes}"
)
# Combine advanced and basic views
indexed_shape = [
*basic_group_shape[:adv_group_axis],
*advanced_group_static_shape,
*basic_group_shape[adv_group_axis:],
]
else:
# This could have been a basic subtensor!
indexed_shape = basic_group_shape
Member: What?! How?

Member (Author): It has no vector indices, Subtensor handles all those cases.

return Apply(
self,
(x, *index),
[tensor(dtype=x.type.dtype, shape=out_shape)],
[x, *indices],
[tensor(dtype=x.type.dtype, shape=tuple(indexed_shape))],
)
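The rules encoded above mirror NumPy's advanced-indexing semantics, reasoned about statically: a boolean mask indexes like the integer arrays returned by nonzero(), integer indices broadcast together (hence the variadic add trick), and a non-consecutive advanced group is transposed to the front of the result. A NumPy sketch of those three facts:

```python
import numpy as np

y = np.arange(4 * 5 * 6).reshape(4, 5, 6)
mask = np.array([True, False, True, True])

# 1. A boolean mask is equivalent to integer indexing with its nonzero() result
assert np.array_equal(y[mask], y[mask.nonzero()[0]])

# 2. Integer indices broadcast together; elementwise add obeys the same rule,
#    which is why add(*advanced_indices).type.shape gives the group's static shape
i = np.zeros((3, 1), dtype=int)
j = np.zeros((10,), dtype=int)
assert y[i, j].shape == (3, 10, 6)
assert (i + j).shape == (3, 10)

# 3. Consecutive advanced indices stay in place; split ones move to the front
assert y[:, i, j].shape == (4, 3, 10)
assert y[i, :, j].shape == (3, 10, 5)
```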

def R_op(self, inputs, eval_points):
5 changes: 4 additions & 1 deletion pytensor/tensor/type_other.py
@@ -114,6 +114,9 @@ def as_symbolic_slice(x, **kwargs):
return SliceConstant(slicetype, x)


NoneSliceConst = Constant(slicetype, slice(None), name="slice(None)")


class NoneTypeT(Generic):
"""
Inherit from Generic to have c code working.
@@ -137,4 +140,4 @@ def as_symbolic_None(x, **kwargs):
return NoneConst


__all__ = ["make_slice", "slicetype", "none_type_t", "NoneConst"]
__all__ = ["make_slice", "slicetype", "none_type_t", "NoneConst", "NoneSliceConst"]
4 changes: 3 additions & 1 deletion pytensor/tensor/variable.py
@@ -506,7 +506,9 @@ def includes_bool(args_el):

# Check if the number of dimensions isn't too large.
if self.ndim < index_dim_count:
raise IndexError("too many indices for array")
raise IndexError(
f"too many indices for tensor: tensor is {self.ndim}-dimensional, but {index_dim_count} were indexed"
)

# Convert an Ellipsis if provided into an appropriate number of
# slice(None).
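The reworded error is easy to hit (a sketch, assuming the usual pytensor.tensor entry point):

```python
import pytensor.tensor as pt

v = pt.vector("v")
try:
    v[0, 0]
except IndexError as err:
    print(err)  # too many indices for tensor: tensor is 1-dimensional, but 2 were indexed
```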
90 changes: 90 additions & 0 deletions tests/tensor/test_subtensor.py
@@ -1,4 +1,5 @@
import logging
import re
import sys
from io import StringIO

@@ -1847,6 +1848,95 @@ def setup_method(self):
self.ix2 = lmatrix()
self.ixr = lrow()

def test_static_shape(self):
x = tensor("x", shape=(None, None))
y = tensor("y", shape=(4, 5, 6))
idx1 = tensor("idx1", shape=(10,), dtype=int)
idx2 = tensor("idx2", shape=(3, None), dtype=int)

assert x[idx1].type.shape == (10, None)
assert x[:, idx1].type.shape == (None, 10)
assert x[idx2, :5].type.shape == (3, None, None)
assert specify_shape(x, (None, 7))[idx2, :5].type.shape == (3, None, 5)
assert specify_shape(x, (None, 3))[idx2, :5].type.shape == (3, None, 3)
assert x[idx1, idx2].type.shape == (3, 10)
assert x[idx2, idx1].type.shape == (3, 10)
assert x[None, idx1, idx2].type.shape == (1, 3, 10)
assert x[idx1, None, idx2].type.shape == (3, 10, 1)
assert x[idx1, idx2, None].type.shape == (3, 10, 1)

assert y[idx1, idx2, ::-1].type.shape == (3, 10, 6)
assert y[idx1, ::-1, idx2].type.shape == (3, 10, 5)
assert y[::-1, idx1, idx2].type.shape == (4, 3, 10)
assert y[::-1, idx1, None, idx2].type.shape == (3, 10, 4, 1)

msg = re.escape(
"shape mismatch: indexing tensors could not be broadcast together with shapes [(10,), (9,)]"
)
with pytest.raises(IndexError, match=msg):
x[idx1, idx1[1:]]

def test_static_shape_boolean(self):
y = tensor("y", shape=(4, 5, 6))
idx1 = tensor("idx1", shape=(4,), dtype=int)
idx2 = tensor("idx2", shape=(3, None), dtype=int)
bool_idx1 = tensor("bool_idx1", shape=(4,), dtype=bool)
bool_idx2 = tensor("bool_idx2", shape=(None, 5), dtype=bool)

assert y[bool_idx1].type.shape == (None, 5, 6)
assert y[bool_idx1, :, None:-4:-1].type.shape == (None, 5, 3)
assert y[bool_idx1, idx2].type.shape == (3, None, 6)
assert y[bool_idx1, idx1, :].type.shape == (4, 6)
Member: Does this raise a runtime error if the number of true entries in bool_idx1 != 4?

Member (Author): Yes, there's no way indexing can happen otherwise. If idx1 was being broadcast we may at some point optimize away the broadcast (with the "shape_unsafe" tag, see #1561), but we don't do that yet.
assert y[bool_idx1, :, idx1].type.shape == (4, 5)
assert y[bool_idx1, idx1, idx2].type.shape == (3, 4)
assert y[None, bool_idx1, None, idx2, None, idx1].type.shape == (3, 4, 1, 1, 1)

assert y[bool_idx2, :].type.shape == (None, 6)
assert y[bool_idx2, idx1].type.shape == (4,)
assert y[bool_idx2, idx2].type.shape == (3, None)

msg = re.escape(
"too many indices for tensor: tensor is 3-dimensional, but 4 were indexed"
)
with pytest.raises(IndexError, match=msg):
y[bool_idx2, bool_idx2]

# Case that could conceivably be detected as index error at definition time
bad_idx = ptb.concatenate([idx1, idx1])
assert y[bool_idx1, bad_idx].type.shape == (8, 6)

def test_static_shape_constant_boolean(self):
y = tensor("y", shape=(None, None, None))
idx1 = tensor("idx1", shape=(3,), dtype=int)
idx2 = tensor("idx2", shape=(4, None), dtype=int)

bool_idx1 = constant(np.array([True, False, True, True]), name="bool_idx1")
bool_idx2 = constant(
np.array([[True, False, True, True], [True, False, False, True]]),
name="bool_idx2",
)

assert y[bool_idx1].type.shape == (3, None, None)
assert y[bool_idx1, :, idx1].type.shape == (3, None)
assert y[bool_idx1, :, idx2].type.shape == (4, 3, None)

assert y[bool_idx2].type.shape == (5, None)
assert y[bool_idx1, idx2].type.shape == (4, 3, None)

bad_idx = ptb.concatenate([idx1, idx1])
msg = re.escape(
"shape mismatch: indexing tensors could not be broadcast together with shapes [(3,), (6,)]"
)
with pytest.raises(IndexError, match=msg):
y[bool_idx1, bad_idx]

@pytest.mark.parametrize(
"inplace",
[