mrpro.utils.smap
- mrpro.utils.smap(function: Callable[[Tensor], Tensor], tensor: Tensor, passed_dimensions: Sequence[int] | int = (-1,)) → Tensor
Apply a function to a tensor serially along multiple dimensions.
The function is applied serially, one slice at a time, without batch dimensions. Compared to torch.vmap, it works with arbitrary functions but is slower.
Parameters:
- function (Callable[[Tensor], Tensor]) – Function to apply to the tensor. It should handle as many dimensions as are selected by passed_dimensions and must not change the number of dimensions.
- tensor (Tensor) – Tensor to apply the function to.
- passed_dimensions (Sequence[int] | int, default: (-1,)) – Dimensions NOT to be batched, i.e., the dimensions that are passed to the function. Either a tuple of dimension indices (negative indices are supported) or an integer. An integer n means the last n dimensions are passed to the function.
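The semantics described above can be illustrated with a minimal sketch in plain PyTorch. This is not mrpro's implementation: smap_sketch and clip_if_large are hypothetical names, and the strategy (move the passed dimensions to the end, flatten all batch dimensions into one, call the function once per slice, stack, and undo the permutation) is an assumption about how such a serial map can be built.

```python
from __future__ import annotations

from collections.abc import Callable, Sequence

import torch


def smap_sketch(
    function: Callable[[torch.Tensor], torch.Tensor],
    tensor: torch.Tensor,
    passed_dimensions: Sequence[int] | int = (-1,),
) -> torch.Tensor:
    """Hypothetical serial map mirroring the documented behavior of smap."""
    if isinstance(passed_dimensions, int):
        # An integer n selects the last n dimensions
        passed_dimensions = tuple(range(-passed_dimensions, 0))
    # Normalize negative indices to non-negative ones
    passed = [d % tensor.ndim for d in passed_dimensions]
    batched = [d for d in range(tensor.ndim) if d not in passed]
    order = [*batched, *passed]
    # Batch dimensions go to the front, passed dimensions to the back
    permuted = tensor.permute(*order)
    batch_shape = permuted.shape[: len(batched)]
    flat = permuted.reshape(-1, *permuted.shape[len(batched):])
    # Serial application: one ordinary Python call per slice
    results = torch.stack([function(x) for x in flat])
    results = results.reshape(*batch_shape, *results.shape[1:])
    # Invert the permutation to restore the original dimension order
    inverse = [0] * tensor.ndim
    for new_position, old_position in enumerate(order):
        inverse[old_position] = new_position
    return results.permute(*inverse)


def clip_if_large(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical function with data-dependent Python control flow (.item() + if)
    peak = x.abs().max()
    if peak.item() > 1.0:
        return x / peak
    return x


data = torch.tensor([[[0.5, -0.25], [4.0, 2.0]]])  # shape (1, 2, 2)
clipped = smap_sketch(clip_if_large, data, passed_dimensions=1)

volume = torch.arange(24.0).reshape(2, 3, 4)
# Pass only dimension 1; dimensions 0 and 2 are looped over serially
flipped = smap_sketch(lambda x: x.flip(0), volume, passed_dimensions=(1,))
# An integer 2 passes the last two dimensions, i.e. one (3, 4) slice per call
centered = smap_sketch(lambda x: x - x.mean(), volume, passed_dimensions=2)
```

Because each slice is handled by a plain Python call, a function like clip_if_large that branches on tensor values works here, whereas such data-dependent control flow is generally not supported under torch.vmap. The cost is one Python-level call per slice, which is why the serial approach is slower.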