1414"""Interface for linear operators."""
1515
1616import functools
17+ from dataclasses import dataclass
18+ from typing import Tuple
19+
1720import jax
1821import jax .numpy as jnp
19- import numpy as onp
2022
21- from jaxopt .tree_util import tree_map , tree_sum , tree_mul
23+ from jaxopt .tree_util import tree_map
2224
2325
2426class DenseLinearOperator :
25-
2627 def __init__ (self , pytree ):
2728 self .pytree = pytree
2829
@@ -33,7 +34,7 @@ def matvec(self, x):
3334 return tree_map (jnp .dot , self .pytree , x )
3435
3536 def rmatvec (self , _ , y ):
36- return tree_map (lambda w ,yi : jnp .dot (w .T , yi ), self .pytree , y )
37+ return tree_map (lambda w , yi : jnp .dot (w .T , yi ), self .pytree , y )
3738
3839 def matvec_and_rmatvec (self , x , y ):
3940 return self .matvec (x ), self .rmatvec (x , y )
@@ -52,11 +53,11 @@ def col_norm(w):
5253 if not squared :
5354 col_norms = jnp .sqrt (col_norms )
5455 return col_norms
56+
5557 return tree_map (col_norm , self .pytree )
5658
5759
5860class FunctionalLinearOperator :
59-
6061 def __init__ (self , fun , params ):
6162 self .fun = functools .partial (fun , params )
6263
@@ -71,7 +72,7 @@ def rmatvec(self, x, y):
7172
7273 def matvec_and_rmatvec (self , x , y ):
7374 matvec_x , vjp = jax .vjp (self .matvec , x )
74- rmatvec_y , = vjp (y )
75+ ( rmatvec_y ,) = vjp (y )
7576 return matvec_x , rmatvec_y
7677
7778 def normal_matvec (self , x ):
@@ -85,3 +86,72 @@ def _make_linear_operator(matvec):
8586 return DenseLinearOperator
8687 else :
8788 return functools .partial (FunctionalLinearOperator , matvec )
89+
90+
91+ def block_row_matvec (block , x ):
92+ """Performs a matvec for a row of block matrices.
93+
94+ The following matvec is performed:
95+ [U1, ..., UN] * [x1, ..., xN]
96+ where U1, ..., UN are matrices and x1, ..., xN are vectors
97+ of compatible shapes.
98+ """
99+ if len (block ) != len (x ):
100+ raise ValueError (
101+ "We need as many blocks in the matrix as in the vector."
102+ )
103+ return sum (jax .tree_util .tree_map (jnp .dot , block , x ))
104+
105+
106+ # TODO(gnegiar): Extend to arbitrary block shapes.
107+ @jax .tree_util .register_pytree_node_class
108+ @dataclass
109+ class BlockLinearOperator :
110+ """Represents a linear operator defined by blocks over a block pytree.
111+
112+ Attributes:
113+ blocks: a 2x2 block matrix of the form
114+ [[A, B]
115+ [C, D]]
116+ """
117+
118+ blocks : Tuple [Tuple [jnp .array ]]
119+
120+ def __call__ (self , x ):
121+ return self .matvec (x )
122+
123+ def matvec (self , x ):
124+ """Performs the block matvec with u defined by blocks.
125+
126+ The matvec is of form:
127+ [u1, u2]
128+ [[A, B] *
129+ [C, D]]
130+
131+ """
132+ return jax .tree_util .tree_map (
133+ lambda row_of_blocks : block_row_matvec (row_of_blocks , x ),
134+ self .blocks ,
135+ is_leaf = lambda x : x is self .blocks [0 ] or x is self .blocks [1 ],
136+ )
137+
138+ def rmatvec (self , x , y ):
139+ return self .matvec_and_rmatvec (x , y )[1 ]
140+
141+ def matvec_and_rmatvec (self , x , y ):
142+ matvec_x , vjp = jax .vjp (self .matvec , x )
143+ (rmatvec_y ,) = vjp (y )
144+ return matvec_x , rmatvec_y
145+
146+ def normal_matvec (self , x ):
147+ """Computes A^T A x from matvec(x) = A x."""
148+ matvec_x , vjp = jax .vjp (self .matvec , x )
149+ return vjp (matvec_x )[0 ]
150+
151+ def tree_flatten (self ):
152+ return self .blocks , None
153+
154+ @classmethod
155+ def tree_unflatten (cls , aux_data , children ):
156+ del aux_data
157+ return cls (children )
0 commit comments