mrpro.operators.functionals.MSE

class mrpro.operators.functionals.MSE[source]

Bases: L2NormSquared

Functional class for the mean squared error.

__init__(target: Tensor | None | complex = None, weight: Tensor | complex = 1.0, dim: int | Sequence[int] | None = None, divide_by_n: bool = True, keepdim: bool = False) → None[source]

Initialize MSE Functional.

The MSE functional is given by \(f: \mathbb{C}^N \rightarrow [0, \infty), x \mapsto 1/N \| W (x-b)\|_2^2\), where \(b\) is the target and \(W\) is either a scalar or a tensor corresponding to a (block-) diagonal operator that is applied to the input. The division by \(N\) can be disabled by setting divide_by_n to False. For more details, see also mrpro.operators.functionals.L2NormSquared.

Parameters:
  • target (Tensor | None | complex, default: None) – target element, often the data tensor (see above)

  • weight (Tensor | complex, default: 1.0) – weight parameter (see above)

  • dim (int | Sequence[int] | None, default: None) – dimension(s) over which the functional is reduced. All other dimensions of weight * (x - target) are treated as batch dimensions.

  • divide_by_n (bool, default: True) – If True, the result is scaled by the number of elements of the dimensions of weight * (x - target) indexed by dim; the functional is thus calculated as the mean, otherwise as the sum.

  • keepdim (bool, default: False) – If True, the dimension(s) of the input indexed by dim are retained and collapsed to singletons; otherwise they are removed from the result.
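As a minimal sketch of the formula above (plain PyTorch, not mrpro's implementation; the function name is illustrative only), divide_by_n toggles between the mean and the sum of the squared weighted residual:

```python
import torch

def mse_functional(x, target=0.0, weight=1.0, divide_by_n=True):
    # Illustration of f(x) = 1/N * ||W (x - target)||_2^2;
    # not the mrpro implementation, just the mathematical definition.
    diff = weight * (x - target)
    squared = diff.abs() ** 2  # abs() also handles complex inputs
    return squared.mean() if divide_by_n else squared.sum()

x = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 1.0, 1.0])
mse_functional(x, target=b)                      # mean of [0, 1, 4] -> 5/3
mse_functional(x, target=b, divide_by_n=False)   # sum of [0, 1, 4] -> 5.0
```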

__call__(x: Tensor) → tuple[Tensor][source]

Compute the Mean Squared Error (MSE).

Calculates \(1/N \| W * (x - b) \|_2^2\), where \(W\) is weight, \(b\) is target, and N is the number of elements over which the mean is computed (if divide_by_n is True at initialization of L2NormSquared). The squared norm is computed along dimensions specified by dim.

Parameters:

x (Tensor) – Input tensor.

Returns:

The MSE. If keepdim is True, the dimensions dim are retained with size 1; otherwise, they are reduced.
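The effect of dim and keepdim can be sketched with plain tensor reductions (an illustration of the semantics described above, not mrpro code):

```python
import torch

x = torch.arange(6.0).reshape(2, 3)
b = torch.zeros(2, 3)

# Reduce only over dim=-1; dim 0 acts as a batch dimension.
per_row = ((x - b) ** 2).mean(dim=-1)                      # shape (2,), values [5/3, 50/3]
per_row_keep = ((x - b) ** 2).mean(dim=-1, keepdim=True)   # shape (2, 1), dim retained as singleton
```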

forward(x: Tensor) → tuple[Tensor][source]

Apply forward of MSE.

Note

Prefer calling the instance of the MSE as operator(x) over directly calling this method. See this PyTorch discussion.

prox(x: Tensor, sigma: Tensor | float = 1.0) → tuple[Tensor][source]

Proximal Mapping of the squared L2 Norm.

Apply the proximal mapping of the squared L2 norm.

Parameters:
  • x (Tensor) – input tensor

  • sigma (Tensor | float, default: 1.0) – scaling factor

Returns:

Proximal mapping applied to the input tensor
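For intuition, the prox of \(f(v) = \| w (v - b) \|_2^2\) with a real scalar weight has the standard closed form \((y + 2\sigma w^2 b)/(1 + 2\sigma w^2)\); mrpro's exact scaling may differ (e.g. when divide_by_n is active). A hedged sketch that also checks the minimizing property numerically:

```python
import torch

def prox_l2_squared(y, sigma, weight=1.0, target=0.0):
    # Closed-form prox of f(v) = ||weight * (v - target)||_2^2:
    # argmin_v f(v) + ||v - y||^2 / (2 * sigma)
    w2 = weight ** 2
    return (y + 2 * sigma * w2 * target) / (1 + 2 * sigma * w2)

def prox_objective(v, y, sigma, weight=1.0, target=0.0):
    return (weight * (v - target)).pow(2).sum() + (v - y).pow(2).sum() / (2 * sigma)

y = torch.tensor([3.0, -1.0])
v_star = prox_l2_squared(y, sigma=0.5)   # -> [1.5, -0.5]
# v_star should minimize the prox objective, so nearby points score worse:
for eps in (0.1, -0.1):
    assert prox_objective(v_star, y, 0.5) <= prox_objective(v_star + eps, y, 0.5)
```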

prox_convex_conj(x: Tensor, sigma: Tensor | float = 1.0) → tuple[Tensor][source]

Proximal mapping of the convex conjugate of the squared L2 norm.

Apply the proximal mapping of the convex conjugate of the squared L2 norm.

Parameters:
  • x (Tensor) – input tensor

  • sigma (Tensor | float, default: 1.0) – scaling factor

Returns:

Proximal of convex conjugate applied to the input tensor
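The conjugate prox can be derived from the prox of the functional itself via the Moreau identity \(x = \mathrm{prox}_{\sigma f^*}(x) + \sigma\, \mathrm{prox}_{f/\sigma}(x/\sigma)\). A sketch for the plain squared L2 norm (unit weight, zero target; not mrpro's exact scaling):

```python
import torch

def prox_f(y, tau):
    # prox of f(v) = ||v||_2^2 with step tau: argmin_v ||v||^2 + ||v - y||^2 / (2 * tau)
    return y / (1 + 2 * tau)

def prox_f_conj(x, sigma):
    # Moreau identity: prox_{sigma f*}(x) = x - sigma * prox_{f/sigma}(x / sigma)
    return x - sigma * prox_f(x / sigma, 1 / sigma)

x = torch.tensor([2.0, -4.0])
sigma = 0.5
# For f = ||.||^2 the conjugate prox also has the closed form x / (1 + sigma / 2),
# so both routes must agree:
assert torch.allclose(prox_f_conj(x, sigma), x / (1 + sigma / 2))
```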

__add__(other: Operator[Unpack[Tin], Tout]) → Operator[Unpack[Tin], Tout][source]
__add__(other: Tensor | complex) → Operator[Unpack[Tin], tuple[Unpack[Tin]]]

Operator addition.

Returns lambda x: self(x) + other(x) if other is an operator, and lambda x: self(x) + other * x if other is a tensor.

__matmul__(other: Operator[Unpack[Tin2], tuple[Unpack[Tin]]] | Operator[Unpack[Tin2], tuple[Tensor, ...]]) → Operator[Unpack[Tin2], Tout][source]

Operator composition.

Returns lambda x: self(other(x))

__mul__(other: Tensor | complex) → Operator[Unpack[Tin], Tout][source]

Operator multiplication with tensor.

Returns lambda x: self(x*other)
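The semantics of these three operations can be sketched with plain callables (an illustration of the lambda descriptions above, not mrpro's operator classes):

```python
# Model an operator as a callable returning one value.

def add(op, other):
    # op + other_op -> x |-> op(x) + other(x); op + tensor -> x |-> op(x) + other * x
    if callable(other):
        return lambda x: op(x) + other(x)
    return lambda x: op(x) + other * x

def matmul(op, other):
    # op @ other -> x |-> op(other(x))
    return lambda x: op(other(x))

def mul(op, other):
    # op * tensor -> x |-> op(other * x)
    return lambda x: op(other * x)

square = lambda x: x ** 2
double = lambda x: 2 * x

add(square, double)(3)     # 3**2 + 2*3 = 15
matmul(square, double)(3)  # (2*3)**2 = 36
mul(square, 2)(3)          # (2*3)**2 = 36
```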

__or__(other: ProximableFunctional) → ProximableFunctionalSeparableSum[Tensor, Tensor][source]

Create a ProximableFunctionalSeparableSum object from two proximable functionals.

Parameters:

other (ProximableFunctional) – second functional to be summed

Returns:

ProximableFunctionalSeparableSum object

__radd__(other: Tensor | complex) → Operator[Unpack[Tin], tuple[Unpack[Tin]]][source]

Operator right addition.

Returns lambda x: other * x + self(x).

__rmul__(scalar: Tensor | complex) → ProximableFunctional[source]

Multiply functional with scalar.