mrpro.operators.ConjugateGradientOp
- class mrpro.operators.ConjugateGradientOp
Bases: Module
Solves a linear positive semidefinite system with the conjugate gradient method.
Solves \(A x = b\), where \(A\) is a linear operator or a matrix of linear operators and \(b\) is a tensor or a tuple of tensors.
The operator is autograd differentiable using implicit differentiation. This is useful for including CG within a network [MODL], [PINQI]. If this is not needed for your application, consider using mrpro.algorithms.optimizers.cg directly.
References
[MODL] Aggarwal, H. K., et al. MoDL: Model-based deep learning architecture for inverse problems. IEEE TMI 2018, 38(2), 394-405. https://arxiv.org/abs/1712.02862
[PINQI] Zimmermann, F. F., Kolbitsch, C., Schuenke, P., & Kofler, A. PINQI: an end-to-end physics-informed approach to learned quantitative MRI reconstruction. IEEE TCI 2024. https://arxiv.org/abs/2306.11023
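As background on the implicit differentiation mentioned above, the following is a sketch of the standard textbook derivation. It assumes a real-valued, symmetric \(A\), and it is not necessarily the exact formulation used in this class. The symbols \(\theta\) (a parameter on which \(A\) and \(b\) may depend), \(\ell\) (a scalar loss), \(g\), and \(z\) are introduced here for illustration only.

```latex
\begin{aligned}
A(\theta)\,x &= b(\theta), \qquad g = \frac{\partial \ell}{\partial x}, \\
A\,\mathrm{d}x &= \mathrm{d}b - (\mathrm{d}A)\,x, \\
\mathrm{d}\ell &= g^{T}\,\mathrm{d}x
  = \left(A^{-1} g\right)^{T}\left(\mathrm{d}b - (\mathrm{d}A)\,x\right)
  \quad \text{(using } A^{T} = A\text{)}, \\
A\,z &= g \;\Rightarrow\;
  \frac{\partial \ell}{\partial b} = z, \qquad
  \frac{\partial \ell}{\partial \theta}
  = z^{T}\left(\frac{\partial b}{\partial \theta} - \frac{\partial A}{\partial \theta}\,x\right).
\end{aligned}
```

Under these assumptions the backward pass needs only one additional solve, \(A z = g\), with the same operator, and it does not store the intermediate CG iterates.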
- __init__(operator_factory: Callable[[...], LinearOperatorMatrix | LinearOperator], rhs_factory: Callable[[...], tuple[Tensor, ...]], implicit_backward: bool = True, tolerance: float = 1e-6, max_iterations: int = 100)
Initialize a conjugate gradient operator.
Both the operator and the right-hand side are given as factory functions. The arguments given to the operator when calling it are passed to the factory functions.
Example: Regularized Least Squares
Consider the regularized least squares problem: \(\min_x \|A x - y\|_2^2 + \alpha \|x - x_0\|_2^2\).
The normal equations are \((A^H A + \alpha I) x = A^H y + \alpha x_0\). This can be solved using the ConjugateGradientOp as follows (the measured data \(y\) is passed as the argument b):
.. code-block:: python

    operator_factory = lambda alpha, x0, b: A.gram + alpha
    rhs_factory = lambda alpha, x0, b: A.H(b)[0] + alpha * x0
    op = ConjugateGradientOp(operator_factory, rhs_factory)
    solution = op(alpha, x0, b)
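To make the factory pattern concrete without depending on mrpro or torch, here is a minimal sketch in plain Python. Lists stand in for tensors, and a small hypothetical real-valued matrix stands in for \(A\). The helpers `matvec`, `dot`, `cg`, and `solve` are illustrative names, not part of the library. The CG loop is the textbook variant and may differ in detail from the library's implementation.

```python
def matvec(M, v):
    """Dense matrix-vector product on plain lists."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def cg(apply_op, rhs, tolerance=1e-6, max_iterations=100):
    """Textbook CG for a symmetric positive definite operator, zero start."""
    x = [0.0] * len(rhs)
    r = list(rhs)  # residual b - A x for x = 0
    p = list(r)
    rs = dot(r, r)
    threshold = tolerance**2 * dot(rhs, rhs)  # tolerance relative to ||rhs||
    for _ in range(max_iterations):
        if rs <= threshold:
            break
        Ap = apply_op(p)
        step = rs / dot(p, Ap)
        x = [xi + step * pi for xi, pi in zip(x, p)]
        r = [ri - step * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return x


# Hypothetical forward operator A: 3 measurements, 2 unknowns (assumption).
A = [[1.0, 2.0], [0.0, 1.0], [1.0, 0.0]]
A_T = [list(col) for col in zip(*A)]


def operator_factory(alpha, x0, y):
    # Returns the action v -> (A^T A + alpha I) v.
    return lambda v: [g + alpha * vi for g, vi in zip(matvec(A_T, matvec(A, v)), v)]


def rhs_factory(alpha, x0, y):
    # Returns A^T y + alpha x0.
    return [aty + alpha * x0i for aty, x0i in zip(matvec(A_T, y), x0)]


def solve(alpha, x0, y):
    # The same call arguments are forwarded to both factories.
    return cg(operator_factory(alpha, x0, y), rhs_factory(alpha, x0, y))


alpha, x0, y = 0.5, [1.0, -1.0], [1.0, 2.0, 3.0]
solution = solve(alpha, x0, y)

# Residual of the normal equations at the computed solution.
lhs = operator_factory(alpha, x0, y)(solution)
residual = [l - r for l, r in zip(lhs, rhs_factory(alpha, x0, y))]
```

Because both factories receive the same arguments, quantities such as `alpha` can be inputs that change from call to call, for example outputs of a network, rather than values fixed at construction time.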
- Parameters:
operator_factory (Callable[..., LinearOperatorMatrix | LinearOperator]) – A factory function that returns the operator \(A\). Should return either a LinearOperatorMatrix or a LinearOperator.
rhs_factory (Callable[..., tuple[Tensor, ...]]) – A factory function that returns the right-hand side \(b\). Should return a tuple of tensors.
implicit_backward (bool, default: True) – If True, the backward pass is done using implicit differentiation. If False, the backward pass is done by unrolling the CG loop.
tolerance (float, default: 1e-6) – The tolerance for the conjugate gradient method. The tolerance is relative to the norm of the right-hand side. The same relative tolerance is used in the backward pass if using implicit differentiation.
max_iterations (int, default: 100) – The maximum number of iterations for the conjugate gradient method. The same maximum number of iterations is used in the backward pass if using implicit differentiation.
Warning: If implicit_backward is True, tolerance and max_iterations should be chosen such that the CG algorithm converges; otherwise the backward pass will be wrong.
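The warning can be illustrated with a small dependency-free sketch. It assumes a real symmetric positive definite matrix, in which case the implicit gradient of a loss with respect to \(b\) is obtained by a second CG solve with the same matrix. The matrix `M`, the loss \(\tfrac{1}{2}\|x\|^2\), and all helper names are hypothetical choices for illustration. This is not the library's code path.

```python
def matvec(M, v):
    return [sum(a * b for a, b in zip(row, v)) for row in M]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def cg(M, rhs, tolerance=1e-6, max_iterations=100):
    """Textbook CG from a zero start; stops when ||r|| <= tolerance * ||rhs||."""
    x, r, p = [0.0] * len(rhs), list(rhs), list(rhs)
    rs = dot(r, r)
    for _ in range(max_iterations):
        if rs <= tolerance**2 * dot(rhs, rhs):
            break
        Mp = matvec(M, p)
        step = rs / dot(p, Mp)
        x = [xi + step * pi for xi, pi in zip(x, p)]
        r = [ri - step * mi for ri, mi in zip(r, Mp)]
        rs, rs_old = dot(r, r), rs
        p = [ri + (rs / rs_old) * pi for ri, pi in zip(r, p)]
    return x


# Hypothetical symmetric positive definite system (assumption for illustration).
M = [[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]]
b = [1.0, 2.0, 3.0]


def loss(rhs):
    # l(b) = 0.5 * ||x||^2 with x solving M x = b.
    x = cg(M, rhs, tolerance=1e-12)
    return 0.5 * dot(x, x)


# Forward solve, then the implicit backward solve M z = dl/dx, where dl/dx = x.
x = cg(M, b, tolerance=1e-12)
grad_converged = cg(M, x, tolerance=1e-12, max_iterations=100)
grad_truncated = cg(M, x, tolerance=1e-12, max_iterations=1)  # not converged

# Central finite differences as an independent reference for dl/db.
eps = 1e-6
grad_reference = []
for i in range(len(b)):
    plus = [bi + eps * (i == j) for j, bi in enumerate(b)]
    minus = [bi - eps * (i == j) for j, bi in enumerate(b)]
    grad_reference.append((loss(plus) - loss(minus)) / (2 * eps))

error_converged = max(abs(g - h) for g, h in zip(grad_converged, grad_reference))
error_truncated = max(abs(g - h) for g, h in zip(grad_truncated, grad_reference))
```

In this sketch the converged backward solve agrees with the finite-difference reference, while the solve cut off after one iteration does not. Implicit differentiation assumes that the linear system has actually been solved. Unrolling instead differentiates whatever iterates were computed, at the cost of storing them.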