mrpro.utils.broadcasted_rearrange

mrpro.utils.broadcasted_rearrange(tensor: Tensor, pattern: str, broadcasted_shape: Sequence[int] | None = None, *, reduce_views: bool = True, **axes_lengths: int) → Tensor

Rearrange a tensor with broadcasting.

Performs the einops rearrange or repeat operation on a tensor while preserving broadcasting.

Rearranging reorders the elements of a multidimensional tensor according to a pattern. It covers the functionality of transpose (axis permutation), reshape (view), squeeze, unsqueeze, repeat, and tile.
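For intuition, the pattern '... (phase k1) k0 -> phase ... k1 k0' used in the examples can be written with plain PyTorch operations: split the second-to-last axis with reshape, then move the new phase axis to the front with movedim. This is a minimal sketch of what the pattern means, assuming an ordinary (not broadcast) input; it is not how broadcasted_rearrange is implemented and does no broadcasting.

```python
import torch

# Input of shape (..., phase * k1, k0) = (1, 16, 1, 768, 256), with phase = 8
tensor = torch.randn(1, 16, 1, 768, 256)
phase = 8
*batch, n, k0 = tensor.shape
k1 = n // phase  # 768 // 8 = 96

# '(phase k1)': split the grouped axis; phase is the outer (slower-varying) factor
split = tensor.reshape(*batch, phase, k1, k0)  # (1, 16, 1, 8, 96, 256)

# 'phase ... k1 k0': move the phase axis to the front
result = split.movedim(-3, 0)
print(result.shape)  # torch.Size([8, 1, 16, 1, 96, 256])
```

Element (p, ..., j, k) of the result is element (..., p * k1 + j, k) of the input, which is the ordering einops uses for a grouped axis.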

If the tensor has stride-0 dimensions (for example, from torch.Tensor.expand or broadcasting), they are kept as stride-0 views where possible instead of being made contiguous, which saves memory. If reduce_views is True (the default), these stride-0 dimensions are then reduced to singleton dimensions after rearranging. Optionally, the tensor is first broadcast to broadcasted_shape before rearranging.
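Stride-0 dimensions typically come from torch.Tensor.expand or broadcasting: every index along such a dimension points to the same memory, so no data is copied. The sketch shows this and uses a hypothetical helper, reduce_stride0_views, to mimic the effect of reduce_views=True by narrowing each stride-0 dimension to length 1. The helper is an illustration only, not mrpro's implementation.

```python
import torch

# Expanding (1, 1, 1, 768, 1) to (1, 16, 1, 768, 256) creates a view:
# the broadcast dimensions get stride 0 and no new memory is allocated.
small = torch.randn(1, 1, 1, 768, 1)
expanded = small.expand(1, 16, 1, 768, 256)
print(expanded.stride())  # dimensions 1 and 4 report stride 0
print(expanded.data_ptr() == small.data_ptr())  # True: same storage


def reduce_stride0_views(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: shrink every stride-0 dimension to size 1.

    Illustrates what reduce_views=True does to the output shape;
    it is not mrpro's code.
    """
    for dim, (size, stride) in enumerate(zip(t.shape, t.stride())):
        if stride == 0 and size > 1:
            # All slices along this dimension are identical, so keep one.
            t = t.narrow(dim, 0, 1)
    return t


reduced = reduce_stride0_views(expanded)
print(reduced.shape)  # torch.Size([1, 1, 1, 768, 1])
```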

Examples

```python
>>> import torch
>>> from mrpro.utils import broadcasted_rearrange
>>> tensor = torch.randn(1, 16, 1, 768, 256)
>>> broadcasted_rearrange(tensor, '... (phase k1) k0 -> phase ... k1 k0', phase=8, reduce_views=False).shape
torch.Size([8, 1, 16, 1, 96, 256])

>>> # Behaves as if the tensor had shape (1, 16, 1, 768, 256)
>>> tensor = torch.randn(1, 1, 1, 768, 1)
>>> broadcasted_rearrange(tensor, '... (phase k1) k0 -> phase ... k1 k0',
...     broadcasted_shape=(1, 16, 1, 768, 256), phase=8, reduce_views=False).shape
torch.Size([8, 1, 16, 1, 96, 256])

>>> # Stride-0 dimensions are reduced to singleton dimensions
>>> tensor = torch.randn(1, 1, 1, 768, 1)
>>> broadcasted_rearrange(tensor, '... (phase k1) k0 -> phase ... k1 k0',
...     broadcasted_shape=(1, 16, 1, 768, 256), phase=8, reduce_views=True).shape
torch.Size([8, 1, 1, 1, 96, 1])
```
Parameters:
  • tensor (Tensor) – The input tensor to rearrange.

  • pattern (str) – The rearrange pattern. See the einops documentation for more information.

  • broadcasted_shape (Sequence[int] | None, default: None) – The shape to broadcast the tensor to before rearranging. If None, no additional broadcasting is performed.

  • reduce_views (bool, default: True) – If True, reduce stride-0 dimensions to singleton dimensions after rearranging.

  • axes_lengths (int) – Lengths of axes in the pattern, passed as keyword arguments (e.g. phase=8). See the einops documentation for more information.
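The axes_lengths keywords only need to fix what einops cannot infer from the shape. For the pattern '... (phase k1) k0 -> phase ... k1 k0', passing phase=8 lets k1 be inferred from the length of the grouped axis (768 / 8 = 96), and when broadcasted_shape is given, that shape rather than the tensor's own shape determines the result. The hypothetical helper below (not part of mrpro) works out the output shape for this one pattern with reduce_views=False.

```python
def phase_split_shape(shape, phase):
    """Hypothetical helper (not part of mrpro): output shape of
    '... (phase k1) k0 -> phase ... k1 k0' with reduce_views=False."""
    *batch, n, k0 = shape
    if n % phase != 0:
        # The grouped axis must split evenly into phase * k1.
        raise ValueError(f"axis of length {n} cannot be split into phase={phase} groups")
    k1 = n // phase  # inferred from the grouped axis, e.g. 768 // 8 = 96
    return (phase, *batch, k1, k0)


# With broadcasted_shape=(1, 16, 1, 768, 256), a (1, 1, 1, 768, 1) tensor
# behaves as if it had the broadcasted shape, so that shape is passed here.
print(phase_split_shape((1, 16, 1, 768, 256), phase=8))  # (8, 1, 16, 1, 96, 256)
```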