Skip to content

Implement static shape inference for AdvancedSubtensor #1566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Jul 31, 2025

Closes #1532
Closes #1565

Logic could be slightly simplified after #1564


📚 Documentation preview 📚: https://pytensor--1566.org.readthedocs.build/en/1566/

@ricardoV94
Copy link
Member Author

Failing jax test is unrelated: #1567

Copy link
Member

@lucianopaz lucianopaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! Nice trick of using add(*advanced_indices).type.shape to handle static shape inference!

assert y[bool_idx1].type.shape == (None, 5, 6)
assert y[bool_idx1, :, None:-4:-1].type.shape == (None, 5, 3)
assert y[bool_idx1, idx2].type.shape == (3, None, 6)
assert y[bool_idx1, idx1, :].type.shape == (4, 6)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this raise a runtime error if the number of true entries in bool_idx1 != 4

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, there's no way indexing can happen otherwise. If idx1 was being broadcast we may at some point optimize away the broacast (with "shape_unsafe" tag, see #1561), but we don't do that yet.

]
else:
# This could have been a basic subtensor!
indexed_shape = basic_group_shape
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What?! How?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It has no vector indices, Subtensor handles all those cases

@ricardoV94 ricardoV94 merged commit 0fd160b into pymc-devs:main Aug 1, 2025
69 of 71 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

AdvancedSubtensor doesn't manage to infer the static shape Preserve static shape after AdvancedSubtensor that only shuffles elements
2 participants