mrpro.utils.broadcasted_rearrange
- mrpro.utils.broadcasted_rearrange(tensor: Tensor, pattern: str, broadcasted_shape: Sequence[int] | None = None, *, reduce_views: bool = True, **axes_lengths: int) -> Tensor
Rearrange a tensor with broadcasting.
Performs the einops rearrange or repeat operation on a tensor while preserving broadcasting.
Rearranging is a smart element reordering for multidimensional tensors. It covers the functionality of transpose (axis permutation), reshape (view), squeeze, unsqueeze, repeat, and tile.
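For readers new to einops, the following plain-PyTorch sketch shows the kind of operation each of these rearrange patterns stands for. It uses neither einops nor mrpro; the einops patterns in the comments are for orientation only, and the variable names are illustrative.

```python
import torch

x = torch.randn(2, 3, 4)

# Axis permutation (transpose), einops: 'a b c -> c a b'
permuted = x.permute(2, 0, 1)  # shape (4, 2, 3)

# Merging axes (reshape), einops: 'a b c -> a (b c)'
merged = x.reshape(2, 3 * 4)  # shape (2, 12)

# Splitting an axis, einops: 'a b (c1 c2) -> a b c1 c2' with c1=2
split = x.reshape(2, 3, 2, 2)  # split[..., i, j] == x[..., i * 2 + j]

# Adding a singleton axis (unsqueeze), einops: 'a b c -> a 1 b c'
unsqueezed = x.unsqueeze(1)  # shape (2, 1, 3, 4)

# Removing it again (squeeze), einops: 'a 1 b c -> a b c'
squeezed = unsqueezed.squeeze(1)  # shape (2, 3, 4)

# Repeating along a new axis, einops.repeat: 'a b c -> a r b c' with r=5
# expand returns a view with stride 0 along r, so no data is copied
repeated = x.unsqueeze(1).expand(-1, 5, -1, -1)  # shape (2, 5, 3, 4)

# Tiling, einops.repeat: 'a b c -> a (r b) c' with r=2
# Tensor.repeat copies the data into a new contiguous tensor
tiled = x.repeat(1, 2, 1)  # shape (2, 6, 4)
```

The `expand` case is the one that matters for this function: it produces a stride-0 dimension that occupies no extra memory.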
If a tensor has stride-0 dimensions, rearranging preserves them as stride-0 where possible instead of making the tensor contiguous, thus saving memory. If `reduce_views` is True (the default), stride-0 dimensions are additionally reduced to singleton dimensions after rearranging. Optionally, the tensor is broadcast to a specified shape before rearranging.
Examples
```python
>>> import torch
>>> from mrpro.utils import broadcasted_rearrange
>>> tensor = torch.randn(1, 16, 1, 768, 256)
>>> broadcasted_rearrange(tensor, '... (phase k1) k0 -> phase ... k1 k0', phase=8, reduce_views=False).shape
torch.Size([8, 1, 16, 1, 96, 256])

>>> # Behaves as if the tensor had shape (1, 16, 1, 768, 256)
>>> tensor = torch.randn(1, 1, 1, 768, 1)
>>> broadcasted_rearrange(tensor, '... (phase k1) k0 -> phase ... k1 k0',
...     broadcasted_shape=(1, 16, 1, 768, 256), phase=8, reduce_views=False).shape
torch.Size([8, 1, 16, 1, 96, 256])

>>> # Dimensions that are stride-0 are reduced to singleton dimensions
>>> tensor = torch.randn(1, 1, 1, 768, 1)
>>> broadcasted_rearrange(tensor, '... (phase k1) k0 -> phase ... k1 k0',
...     broadcasted_shape=(1, 16, 1, 768, 256), phase=8, reduce_views=True).shape
torch.Size([8, 1, 1, 1, 96, 1])
```
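To see why the broadcasted examples need no extra memory, the sketch below reproduces the pattern `'... (phase k1) k0 -> phase ... k1 k0'` with plain PyTorch `reshape` and `movedim` on a broadcast view. The helper `split_phase_sketch` is hypothetical and covers only this one pattern; it is not how mrpro implements `broadcasted_rearrange`, which performs the general einops operation.

```python
import torch


def split_phase_sketch(tensor: torch.Tensor, phase: int) -> torch.Tensor:
    """Plain-torch version of '... (phase k1) k0 -> phase ... k1 k0' (hypothetical helper)."""
    *batch, k, k0 = tensor.shape
    if k % phase != 0:
        raise ValueError(f'axis of length {k} cannot be split into {phase} phases')
    # einops treats the leftmost name in '(phase k1)' as the outer (slowest-varying) factor
    split = tensor.reshape(*batch, phase, k // phase, k0)
    # Move the phase axis from position -3 to the front
    return split.movedim(-3, 0)


tensor = torch.randn(1, 1, 1, 768, 1)
# broadcast_to returns a view; the expanded axes (16 and 256) get stride 0
broadcast = tensor.broadcast_to((1, 16, 1, 768, 256))
result = split_phase_sketch(broadcast, phase=8)
print(result.shape)  # torch.Size([8, 1, 16, 1, 96, 256])
# The 16- and 256-long axes still have stride 0, so the result is a view of `tensor`
print(result.stride(2), result.stride(5))  # 0 0
```

Splitting a single axis is always expressible as a view, which is why `reshape` here does not copy even though the input is not contiguous.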
- Parameters:
  - tensor (Tensor) – The input tensor to rearrange.
  - pattern (str) – The rearrange pattern. See the einops documentation for more information.
  - broadcasted_shape (Sequence[int] | None, default: None) – The shape to broadcast the tensor to before rearranging. If None, no additional broadcasting is performed.
  - reduce_views (bool, default: True) – If True, reduce stride-0 dimensions to singleton dimensions after rearranging.
  - axes_lengths (int) – The lengths of the axes in the pattern. See the einops documentation for more information.
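The effect of `reduce_views=True` can be sketched with a hypothetical helper, `reduce_stride0_sketch`, which narrows every stride-0 dimension of size greater than one to a singleton. This assumes that "reducing" means keeping a single entry along such a dimension, which loses no information because every entry along a stride-0 dimension aliases the same memory; it is an illustration, not mrpro's own code.

```python
import torch


def reduce_stride0_sketch(tensor: torch.Tensor) -> torch.Tensor:
    """Replace every stride-0 dimension of size > 1 by a singleton (hypothetical helper)."""
    for dim, stride in enumerate(tensor.stride()):
        if stride == 0 and tensor.shape[dim] > 1:
            # narrow returns a view and keeps the strides, so index 0 stands in for all entries
            tensor = tensor.narrow(dim, 0, 1)
    return tensor


# A (1, 1, 1, 768, 1) tensor broadcast to (1, 16, 1, 768, 256): axes 1 and 4 get stride 0
base = torch.randn(1, 1, 1, 768, 1)
broadcast = base.broadcast_to((1, 16, 1, 768, 256))
print(broadcast.stride())  # (768, 0, 768, 1, 0)
reduced = reduce_stride0_sketch(broadcast)
print(reduced.shape)  # torch.Size([1, 1, 1, 768, 1])
```

Calling `reduced.expand(1, 16, 1, 768, 256)` recovers the broadcast values, again without copying.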