Mathematica Asked on December 7, 2020
I have $d/2$-dimensional random vectors $a,b$ that are jointly Gaussian, $x=(a,b)\sim\mathcal{N}(\mu,\Sigma)$ in $d$ dimensions, and I need to solve the following equation for $X$:
$$E[ab^TXab^T]^T=Y$$
By Wick's theorem, this is equivalent to solving the following for $X$:
$$BXA + CX^TC + C\,\mathrm{Tr}(X^TC) - 2\,ba^T(b^TXa) = Y$$
where $A,B,C,a,b$ are the equal-sized blocks of the Gaussian's first and second (raw) moments; in the equation above, $a$ and $b$ denote the means:
$$(a,b)=E[x]=\mu$$
$$
\begin{pmatrix}
A & C^T\\
C & B
\end{pmatrix}
=E[xx^T]=\Sigma+\mu\mu^T
$$
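For reference, the expansion follows from the fourth-moment (Isserlis/Wick) identity for a Gaussian with mean $\mu$, written with the raw second moment $M=\Sigma+\mu\mu^T$. Each of the three pairings contains the $\mu\mu^T$ product once, hence the $-2\mu_i\mu_j\mu_k\mu_l$ correction. Summing against $X$ with $E[aa^T]=A$, $E[bb^T]=B$, $E[ba^T]=C$ (and $a,b$ on the right denoting the means) gives the terms one by one:

$$
\begin{aligned}
E[x_ix_jx_kx_l] &= M_{ij}M_{kl}+M_{ik}M_{jl}+M_{il}M_{jk}-2\mu_i\mu_j\mu_k\mu_l,\\
\big(E[ab^TXab^T]\big)_{il} &= \sum_{j,k}X_{jk}\,E[a_ib_ja_kb_l]\\
&= (C^TXC^T)_{il}+(AX^TB)_{il}+(C^T)_{il}\,\mathrm{Tr}(X^TC)-2\,a_ib_l\,(b^TXa).
\end{aligned}
$$

Transposing, and using $A=A^T$, $B=B^T$, gives $BXA+CX^TC+C\,\mathrm{Tr}(X^TC)-2\,ba^T(b^TXa)$, which is the left-hand side of the Wick-theorem equation.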
Below is a somewhat brute-force solution (apply the Sherman-Morrison formula twice to peel off the two rank-1 terms, then reduce to a T-Sylvester equation), which takes about 12 seconds for $d=1024$. There are some repeated expressions, but simplifying by hand gets quite error-prone. Can anyone see a way to speed this up?
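In operator form, with the Frobenius inner product $\langle P,Q\rangle=\mathrm{Tr}(P^TQ)$ (the `dot` helper in the code), the equation is a generalized T-Sylvester operator plus one added and one subtracted rank-1 term. These are the arguments `wicksBackward` passes to `generalizedTSylvesterRank2`:

$$
\mathcal{L}(X)+U\langle U,X\rangle-V\langle V,X\rangle=Y,\qquad
\mathcal{L}(X)=BXA+CX^TC,\quad U=C,\quad V=\sqrt{2}\,ba^T,
$$

since $\langle V,X\rangle=\sqrt{2}\,b^TXa$ and therefore $V\langle V,X\rangle=2\,ba^T(b^TXa)$.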
```mathematica
(*Expectation of expression*)
Ex[expr_] := Expectation[expr, x \[Distributed] dist];
(*split a length-d vector into its two halves*)
split[vec_] := ArrayReshape[vec, {2, Length[vec]/2}];
(*so that a\[CircleTimes]b is the outer product a b^T*)
CircleTimes = KroneckerProduct;
(*Solve AX+XB=C. Equivalent to LyapunovSolve[A,B,C] but faster/more
  stable/works when ill-posed*)
sylvester[A_, B_, C0_] :=
  Module[{da, db, DA, T, DB, U, denom, cutoff, sdiv, Y},
   {da, db} = Length /@ {A, B};
   {DA, T} = Eigensystem[A + $MachineEpsilon IdentityMatrix[da]];
   T = Transpose[T];
   {DB, U} = Eigensystem[B + $MachineEpsilon IdentityMatrix[db]];
   U = Transpose[U];
   denom = Outer[Plus, DA, DB];
   cutoff = Max@Abs[denom]*10^6*$MachineEpsilon;
   sdiv = Map[If[Abs[#] > cutoff, 1/#, #] &, denom, {2}];
   Y = Inverse[T].C0.U*sdiv;
   T.Y.Inverse[U]];
```
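The `sylvester` helper relies on diagonalizing both coefficient matrices. A minimal standalone sketch of the same idea (without the `$MachineEpsilon` shift and the small-denominator cutoff) can be checked against the built-in `LyapunovSolve`, which solves `a.x + x.b == c`. The names `am`, `bm`, `cm`, `xm` and the `3 IdentityMatrix` shift are illustrative choices that keep this small example well conditioned:

```mathematica
(* Diagonalization approach to a.x + x.b == c:
   if a = t.DiagonalMatrix[da].Inverse[t] and b = u.DiagonalMatrix[db].Inverse[u],
   then y = Inverse[t].x.u satisfies y[[i, j]] (da[[i]] + db[[j]]) == (Inverse[t].c.u)[[i, j]] *)
SeedRandom[7];
n = 4;
am = RandomReal[{-1, 1}, {n, n}] + 3 IdentityMatrix[n];
bm = RandomReal[{-1, 1}, {n, n}] + 3 IdentityMatrix[n];
cm = RandomReal[{-1, 1}, {n, n}];
(* Eigensystem returns eigenvectors as rows; transpose so they become columns *)
{da, t} = Eigensystem[am]; t = Transpose[t];
{db, u} = Eigensystem[bm]; u = Transpose[u];
(* divide elementwise by the eigenvalue sums, then transform back *)
xm = t.((Inverse[t].cm.u)/Outer[Plus, da, db]).Inverse[u];
Chop@Norm[am.xm + xm.bm - cm]              (* residual, should be round-off sized *)
Chop@Norm[xm - LyapunovSolve[am, bm, cm]]  (* distance to the built-in solution *)
```

The eigenvectors may be complex even for real input, so `xm` can carry tiny imaginary parts; `Chop` removes them from the comparison.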
```mathematica
(*Solve T-Sylvester equation AX+X^T B=C by reducing to
  Sylvester equation. Eddy's recipe from
  https://mathematica.stackexchange.com/a/207044/217*)
tsylvester[a_, b_, c_] :=
  Module[{g, ig, h, u, x},
   g = a + b\[Transpose];
   ig = Inverse[g];
   h = (c + c\[Transpose])/2;
   u = sylvester[a.ig, -ig\[Transpose].b,
     c - a.ig.h - h.ig\[Transpose].b];
   u = (u - u\[Transpose])/2;
   x = ig.(h + u);
   If[ValueQ[debug],
    Print["tsylvester error is ", Norm[a.x + x\[Transpose].b - c]]];
   x];
```
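For small sizes, a brute-force reference is handy when testing `tsylvester`: treat the entries of $X$ as unknowns, let `CoefficientArrays` extract the linear system, and solve it with `LinearSolve`. This is a sketch costing $O(n^6)$, usable only for tiny $n$; the helper name `bruteTSylvester` and the test matrices are hypothetical, not part of the original code:

```mathematica
(* Hypothetical brute-force reference for a.X + X^T.b == c: the entries of X are
   unknowns, CoefficientArrays extracts the linear system, LinearSolve solves it *)
bruteTSylvester[a_, b_, c_] :=
  Module[{n = Length[a], zz, xs, b0, b1},
   xs = Array[zz, {n, n}];
   {b0, b1} = CoefficientArrays[Flatten[a.xs + Transpose[xs].b - c], Flatten[xs]];
   (* the system reads b0 + b1.Flatten[X] == 0, with entries taken row by row *)
   ArrayReshape[LinearSolve[Normal[b1], -Normal[b0]], {n, n}]];

SeedRandom[3];
n0 = 3;
at = RandomReal[{-1, 1}, {n0, n0}] + 2 IdentityMatrix[n0];
bt = RandomReal[{-1, 1}, {n0, n0}];
ct = RandomReal[{-1, 1}, {n0, n0}];
xt = bruteTSylvester[at, bt, ct];
Chop@Norm[at.xt + Transpose[xt].bt - ct]  (* residual of the reference solution *)
(* with the question's definitions loaded, compare: Chop@Norm[tsylvester[at, bt, ct] - xt] *)
```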
```mathematica
(*Solve generalized T-Sylvester BXA+CX^T D=E by reducing to
  T-Sylvester equation*)
generalizedTSylvester[a_, b_, c_, d_, e_] :=
  tsylvester[Inverse[c].b, d.Inverse[a], Inverse[c].e.Inverse[a]];
```
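The reduction is a single two-sided multiplication. Assuming $C$ and $A$ are invertible (the code calls `Inverse` on both), multiplying $BXA+CX^TD=E$ by $C^{-1}$ on the left and by $A^{-1}$ on the right gives the T-Sylvester form handled by `tsylvester`:

$$(C^{-1}B)\,X+X^T\,(DA^{-1})=C^{-1}EA^{-1}$$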
```mathematica
(*Solve generalized T-Sylvester equation with rank-1
  correction: BXA+CX^T D+U tr(X^T U)=Y*)
generalizedTSylvesterRank1[A_, B_, C_, D_, U_, Y_] :=
  Module[{X, divAU, divAX, Y2},
   divAU = generalizedTSylvester[A, B, C, D, U];
   divAX = generalizedTSylvester[A, B, C, D, Y];
   X = divAX - dot[U, divAX]/(1 + dot[U, divAU]) divAU;
   If[ValueQ[debug], Y2 = B.X.A + C.X\[Transpose].D + U dot[U, X];
    Print["generalizedTSylvesterRank1 error: ", Norm[Y - Y2]]];
   X];
```
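The update in `generalizedTSylvesterRank1` is the Sherman-Morrison formula for the linear operator $\mathcal{L}(X)=BXA+CX^TD$ plus the rank-1 map $X\mapsto U\langle U,X\rangle$, with $\langle P,Q\rangle=\mathrm{Tr}(P^TQ)$ as computed by `dot`. Writing $X=\mathcal{L}^{-1}(Y)-\langle U,X\rangle\,\mathcal{L}^{-1}(U)$, taking the inner product with $U$ and solving for $\langle U,X\rangle$ gives the following, with `divAX` $=\mathcal{L}^{-1}(Y)$ and `divAU` $=\mathcal{L}^{-1}(U)$, provided $1+\langle U,\mathcal{L}^{-1}(U)\rangle\neq0$:

$$X=\mathcal{L}^{-1}(Y)-\frac{\langle U,\mathcal{L}^{-1}(Y)\rangle}{1+\langle U,\mathcal{L}^{-1}(U)\rangle}\,\mathcal{L}^{-1}(U)$$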
```mathematica
(*Solve generalized T-Sylvester equation with rank-2
  correction: BXA+CX^T D+U tr(X^T U)-V tr(X^T V)=Y*)
generalizedTSylvesterRank2[A_, B_, C_, D_, U_, V_, Y_] :=
  Module[{divAU, divAX, X, Y2},
   divAU = generalizedTSylvesterRank1[A, B, C, D, U, V];
   divAX = generalizedTSylvesterRank1[A, B, C, D, U, Y];
   X = divAX + dot[V, divAX]/(1 - dot[V, divAU]) divAU;
   If[ValueQ[debug],
    Y2 = B.X.A + C.X\[Transpose].D + U dot[U, X] - V dot[V, X];
    Print["generalizedTSylvesterRank2 error: ", Norm[Y - Y2]]];
   X];
```
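The rank-2 step in `generalizedTSylvesterRank2` applies the same argument to the rank-1-corrected operator $\mathcal{L}_1(X)=BXA+CX^TD+U\langle U,X\rangle$ (inverted by `generalizedTSylvesterRank1`). Because the $V$ term is subtracted, the sign in the denominator flips, and the formula requires $\langle V,\mathcal{L}_1^{-1}(V)\rangle\neq1$:

$$X=\mathcal{L}_1^{-1}(Y)+\frac{\langle V,\mathcal{L}_1^{-1}(Y)\rangle}{1-\langle V,\mathcal{L}_1^{-1}(V)\rangle}\,\mathcal{L}_1^{-1}(V)$$

Since each `generalizedTSylvesterRank1` call makes two `generalizedTSylvester` calls, one call of `generalizedTSylvesterRank2` performs four full T-Sylvester solves with the same coefficient matrices, each recomputing its own inverses and eigendecompositions.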
```mathematica
partitionMatrix[mat_, {a_, b_}] :=
  Module[{},
   Assert[a + b == Length@mat];
   Assert[a + b == Length@mat\[Transpose]];
   Internal`PartitionRagged[mat, {{a, b}, {a, b}}]];
setupProblem[d0_] := (d = d0;
   x = Array[xx, d];
   {a, b} = split[x];(*{a1,a2,...},{b1,b2,...}*)
   mu = RandomReal[{-1, 1}, {d}];
   diag = DiagonalMatrix@Table[1/k, {k, 1, d}];
   rot = RandomVariate[CircularRealMatrixDistribution[d]];
   sigma = rot.diag.rot\[Transpose];
   dist = MultinormalDistribution[mu, sigma];
   X = RandomReal[{-1, 1}, {d/2, d/2}];);
(*Frobenius inner product, i.e. Tr[mat1\[Transpose].mat2]*)
dot[mat1_, mat2_] := Total[mat1*mat2, 2];
(***Modify this to change problem size***)
SeedRandom[1];
setupProblem[1024]
{{AA, AB}, {BA, BB}} =
  partitionMatrix[sigma + Outer[Times, mu, mu], {d/2, d/2}];
{A, B} = split[mu];
wicksForward[X_] :=
  BB.X.AA + Transpose[AB.X.AB] + BA dot[X, BA] -
   2 Outer[Times, B, A] B.X.A;
wicksBackward[Y_] :=
  With[{BA0 = Sqrt[2] Outer[Times, B, A]},
   generalizedTSylvesterRank2[AA, BB, BA, BA, BA, BA0, Y]];
(*correctness check for small d*)
If[d < 8, On[Assert];
  Y1 = Ex[(a\[CircleTimes]b).X.(a\[CircleTimes]b)]\[Transpose];
  Y2 = wicksForward[X];
  Assert[Y1 == Y2]];
Y = wicksForward[X];
Norm[wicksBackward[Y] - X] // Timing (*{12.1, 0.00005}*)
```
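The correctness check in the code only runs for $d<8$ and compares machine-precision numbers. An exact variant for $d=4$ can verify the Wick expansion with no rounding at all: with rational $\mu$ and $\Sigma$, `Expectation` returns exact moments, and the right-hand side is assembled from the blocks of $M=\Sigma+\mu\mu^T$ as in the question. A sketch, where `mu0`, `sigma0`, `X0` and the other names are arbitrary illustrative choices (picked so they do not collide with the question's globals):

```mathematica
(* Exact check of the Wick expansion for d = 4 (a and b are 2-vectors) *)
mu0 = {1, -1, 2, 1/2};
sigma0 = {{2, 1, 0, 0}, {1, 2, 1, 0}, {0, 1, 2, 1}, {0, 0, 1, 2}}; (* eigenvalues 2 + 2 Cos[k Pi/5] > 0 *)
X0 = {{1, 2}, {-3, 1/2}};
xv = {x1, x2, x3, x4};
{av, bv} = {xv[[1 ;; 2]], xv[[3 ;; 4]]};
(* left-hand side: transpose of E[a b^T X a b^T], taken entry by entry *)
lhs = Transpose@Map[
    Expectation[#, xv \[Distributed] MultinormalDistribution[mu0, sigma0]] &,
    Outer[Times, av, bv].X0.Outer[Times, av, bv], {2}];
(* blocks of the raw second moment M = Sigma + mu mu^T, and the two halves of the mean *)
m2 = sigma0 + Outer[Times, mu0, mu0];
{mA, mC, mB} = {m2[[1 ;; 2, 1 ;; 2]], m2[[3 ;; 4, 1 ;; 2]], m2[[3 ;; 4, 3 ;; 4]]};
{ma, mb} = {mu0[[1 ;; 2]], mu0[[3 ;; 4]]};
(* right-hand side: B.X.A + C.X^T.C + C Tr(X^T.C) - 2 b a^T (b.X.a) *)
rhs = mB.X0.mA + mC.Transpose[X0].mC + mC Tr[Transpose[X0].mC] -
   2 (mb.X0.ma) Outer[Times, mb, ma];
lhs == rhs
```

Because everything is exact, `lhs == rhs` evaluates to `True` or `False` with no tolerance involved.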