Code Review Asked by Tommaso Belluzzo on October 27, 2021
I’m working on a multivariate cross-entropy minimization model (for more details about it, see this paper, pp. 32-33). Its purpose is to adjust a prior multivariate distribution (in this case, a Gaussian) using information on the marginals obtained from real observations.
The code at the end of the post is my current implementation. I believe the maths has been reproduced correctly, unless I missed something critical during my review. The real problem I’m struggling with is the performance of the code.
In the first part of the model, cumulative probabilities have to be computed over all the orthants of the distribution density. This process has a time complexity of $O(2^N)$, where $N$ is the number of entities included in the dataset. As long as the number of entities is less than $12$, everything is fast enough on my PC. With $20$ entities, which is my current target, the model needs to run `mvncdf` over $2^{20} = 1048576$ combinations of orthants, and this takes forever to finish.
I have already improved the code a little by replacing the main `for` loop with a `parfor` loop. I gained a huge performance boost by replacing the built-in `mvncdf` function with a user-made one.
I’m not very familiar with cross-entropy minimization models, so maybe there are math tricks I can use to simplify this calculation. Maybe the code can be vectorized even further. Any help or suggestion to improve the calculation speed is more than welcome!
% Cross-entropy minimisation driver.
% 1. Computes the prior probability of every orthant split at dts
%    (prior: zero-mean normal with covariance c).
% 2. Groups those orthants by the entities that are in their upper tail.
% 3. Solves the model's conditions (see objective) for [mu; lambda]
%    with fsolve.
clc();
clear();

% DATA
pods = [0.015; 0.02; 0.013; 0.007; 0.054; 0.034; 0.009; 0.065; 0.029; 0.205];
dts = [2.1; 2; 2.2; 2.4; 1.5; 1.8; 2.3; 1.5; 1.8; 0.8];

% Test of time complexity:
% pods = [pods; pods];
% dts = [dts; dts];

n = numel(pods);
c = eye(n);
k = 2^n;
kh = k / 2;

% G / BOUNDS FOR 1
% Row i of g1 is a 0/1 orthant indicator per entity j:
%   0 -> component j lies in (-Inf, dts(j)]
%   1 -> component j lies in [dts(j), Inf)
% bounds_1(i) is the prior probability of that orthant.
g1 = combn([0 1],n);

if isdiag(c)
    % Diagonal covariance => independent components. Each orthant
    % probability is then the product of n univariate tail probabilities,
    % which replaces 2^n calls to mvncdf2 with two vector calls to lnpr.
    % (For a diagonal c, mvncdf2's own formula reduces to this product.)
    z = dts ./ sqrt(diag(c));
    ln_lo = lnpr(-Inf(n,1),z); % log P(Z_j <= z_j)
    ln_hi = lnpr(z,Inf(n,1));  % log P(Z_j >= z_j)
    % Per row: sum over j of ln_lo(j) where g1 is 0, and ln_hi(j) where it is 1.
    bounds_1 = exp(sum(ln_lo) + (g1 * (ln_hi - ln_lo)));
else
    bounds_1 = zeros(k,1);
    parfor i = 1:k
        up = logical(g1(i,:).');
        lb = dts;
        lb(~up) = -Inf;
        ub = dts;
        ub(up) = Inf;
        bounds_1(i) = mvncdf2(c,lb,ub);
    end
end

% G / BOUNDS FOR 2:N
% g2{j} holds the rows of g1 in which entity j is in its upper tail
% (g1(:,j) == 1), in the same order as in g1. bounds_2(j,:) holds their
% probabilities. Exactly half of the 2^n orthants qualify, so each g2{j}
% is kh-by-n.
g2 = cell(n,1);
bounds_2 = zeros(n,kh);
for j = 1:n
    on_j = (g1(:,j) == 1);
    g2{j} = g1(on_j,:);
    bounds_2(j,:) = bounds_1(on_j).';
end

% SOLUTION
% Unknowns x = [mu; lambda]:
%   mu        - drives the normalisation residual (cns(1) = 1)
%   lambda(j) - drives the residual matching pods(j)
options = optimset(optimset(@fsolve),'Display','iter','TolFun',1e-08,'TolX',1e-08);
cns = [1; pods];
x0 = zeros(n + 1,1);
lm = fsolve(@(x)objective(x,n,g1,bounds_1,g2,bounds_2,cns),x0,options);
% Objective function of the model.
% Returns the residual vector that fsolve drives to zero:
%   p(1)   = exp(-1-mu) * sum_i exp(-g1(i,:)*lambda) * bounds_1(i)     - cns(1)
%   p(j+1) = exp(-1-mu) * sum_r exp(-g2{j}(r,:)*lambda) * bounds_2(j,r) - cns(j+1)
% Inputs:
%   x        - [mu; lambda], n+1 elements
%   g1       - k-by-n indicator matrix
%   bounds_1 - k probabilities, one per row of g1
%   g2       - n-element cell; g2{j} has one row per column of bounds_2
%   bounds_2 - n-by-m matrix
%   cns      - n+1 target values
% The sums are evaluated as matrix-vector products instead of scalar loops,
% because fsolve calls this function repeatedly over 2^n rows.
function p = objective(x,n,g1,bounds_1,g2,bounds_2,cns)
mu = x(1);
lambda = x(2:end);
p = zeros(n + 1,1);
p(1) = exp(-g1 * lambda).' * bounds_1(:);
for i = 1:n
    p(i + 1) = bounds_2(i,:) * exp(-g2{i} * lambda);
end
p = (exp(-1 - mu) * p) - cns;
end
% All combinations of elements.
% [m,i] = combn(v,n) returns every length-n sequence drawn, with repetition,
% from the elements of v:
%   m - one combination per row; the last column varies fastest
%   i - the matching indices into v, so that m = v(i)
% For n == 1, m = v(:) and i = (1:numel(v)).'. An empty v gives m = i = [].
function [m,i] = combn(v,n)
% Check that N is a scalar before the element-wise tests. For a non-scalar
% N those tests are non-scalar, and || would otherwise throw its own
% operand error instead of this message. Inf is rejected here too, rather
% than failing later inside ndgrid.
if ((numel(n) ~= 1) || ~isfinite(n) || (fix(n) ~= n) || (n < 1))
    error('Parameter N must be a scalar positive integer.');
end
if (isempty(v))
    m = [];
    i = [];
elseif (n == 1)
    m = v(:);
    i = (1:numel(v)).';
else
    i = combn_local(1:numel(v),n);
    m = v(i);
end
    % n-fold Cartesian product of v, one combination per row. The ndgrid
    % outputs are assigned in reverse so that the last column varies fastest.
    function y = combn_local(v,n)
        if (n > 1)
            [y{n:-1:1}] = ndgrid(v);
            y = reshape(cat(n+1,y{:}),[],n);
        else
            y = v(:);
        end
    end
end
% Multivariate normal cumulative distribution function.
% y = mvncdf2(c,lb,ub) approximates P(lb <= X <= ub) for X ~ N(0,c).
% Inputs:
%   c      - n-by-n covariance
%   lb, ub - n-element bound vectors, which may contain +/-Inf
% Method:
%   1. cholperm reorders the variables and builds a Cholesky factor.
%   2. fsolve finds the root of the system defined by gradpsi.
%   3. y = exp(sum(lnpr(...) + mu.^2/2 - x.*mu)) is evaluated at that root.
% Returns NaN when the factor has a diagonal entry below eps or fsolve
% does not solve the system.
% NOTE(review): no sampling is performed, so y is the tilting expression at
% its root rather than a Monte Carlo estimate. Confirm its accuracy is
% acceptable for non-diagonal c.
function y = mvncdf2(c,lb,ub)
persistent options;
if (isempty(options))
    options = optimset(optimset(@fsolve),'Algorithm','trust-region-dogleg','Diagnostics','off','Display','off','Jacobian','on');
end
n = size(c,1);
% The arithmetic below (and in cholperm) assumes column bound vectors.
lb = lb(:);
ub = ub(:);
[cp,lb,ub] = cholperm(n,c,lb,ub);
d = diag(cp);
if any(d < eps())
    y = NaN;
    return;
end
lb = lb ./ d;
ub = ub ./ d;
% One dimension: there are no free parameters, so fsolve would get an
% empty start point. The formula below reduces to the univariate probability.
if (n == 1)
    y = exp(lnpr(lb,ub));
    return;
end
% Scale each row to a unit diagonal, then remove the diagonal; gradpsi
% expects the factor in this form.
cp = (cp ./ repmat(d,1,n)) - eye(n);
[sol,~,exitflag] = fsolve(@(x)gradpsi(x,cp,lb,ub),zeros(2 * (n - 1),1),options);
% fsolve signals "equation solved" with any positive exitflag (1-4).
% 0 or a negative value means it stopped without solving the system.
if (exitflag <= 0)
    y = NaN;
    return;
end
% The last components of x and mu are fixed at 0 (they are not unknowns).
x = [sol(1:(n - 1)); 0];
mu = [sol(n:((2 * n) - 2)); 0];
shift = cp * x;
lb = lb - mu - shift;
ub = ub - mu - shift;
y = exp(sum(lnpr(lb,ub) + (0.5 * mu.^2) - (x .* mu)));
end
% Builds a Cholesky factor of c one column at a time while reordering the
% variables. At step j it picks, among the variables not yet placed, the one
% with the smallest log marginal probability (lnpr), conditional on the
% values z already computed for the placed variables.
% Inputs:
%   n    - dimension
%   c    - n-by-n covariance
%   l, u - lower/upper bounds
%   NOTE(review): l and u are assumed to be n-by-1 columns. Row vectors
%   would broadcast against the column cpz below; confirm callers pass
%   columns.
% Outputs:
%   cp   - lower-triangular factor of the permuted covariance
%   l, u - bounds permuted to match cp
% The permutation itself is not returned. Negative residual variances are
% clamped to eps rather than reported as an error.
function [cp,l,u] = cholperm(n,c,l,u)
s2p = sqrt(2 * pi());
cp = zeros(n,n);
% z(j): mean of the j-th standardised variable truncated to its bounds.
% Later steps use it to condition the remaining variables.
z = zeros(n,1);
for j = 1:n
j_seq = 1:(j - 1);
jn_seq = j:n;
j1n_seq = (j + 1):n;
cp_off = cp(jn_seq,j_seq);
z_off = z(j_seq);
cpz = cp_off * z_off;
% c is permuted at the end of every iteration, so its diagonal is re-read here.
d = diag(c);
% Residual standard deviations of the candidate variables j..n.
s = d(jn_seq) - sum(cp_off.^2,2);
s(s < 0) = eps();
s = sqrt(s);
lt = (l(jn_seq) - cpz) ./ s;
ut = (u(jn_seq) - cpz) ./ s;
% Already-placed variables keep Inf so that min() only picks from j..n.
p = Inf(n,1);
p(jn_seq) = lnpr(lt,ut);
[~,k] = min(p);
% Swap variable k into position j: rows/columns of c, rows of cp, bounds.
jk = [j k];
kj = [k j];
c(jk,:) = c(kj,:);
c(:,jk) = c(:,kj);
cp(jk,:) = cp(kj,:);
l(jk) = l(kj);
u(jk) = u(kj);
% Standard Cholesky update for column j of the (permuted) covariance.
s = c(j,j) - sum(cp(j,j_seq).^2);
s(s < 0) = eps();
cp(j,j) = sqrt(s);
cp(j1n_seq,j) = (c(j1n_seq,j) - (cp(j1n_seq,j_seq) * (cp(j,j_seq)).')) / cp(j,j);
cp_jj = cp(j,j);
cpz = cp(j,j_seq) * z(j_seq);
lt = (l(j) - cpz) / cp_jj;
ut = (u(j) - cpz) / cp_jj;
% (phi(lt) - phi(ut)) / P(lt < Z < ut), with the denominator applied in log
% space through w.
w = lnpr(lt,ut);
z(j) = (exp((-0.5 * lt.^2) - w) - exp((-0.5 * ut.^2) - w)) / s2p;
end
end
% Residual vector g and Jacobian j of the nonlinear system that mvncdf2
% solves with fsolve.
% Inputs:
%   y    - [x(1:d-1); mu(1:d-1)]; the last components of x and mu are fixed at 0
%   L    - d-by-d scaled factor with zero diagonal (as built in mvncdf2)
%   l, u - scaled lower/upper bounds (d elements, column vectors)
% Outputs:
%   g    - 2*(d-1)-by-1 residual
%   j    - 2*(d-1)-by-2*(d-1) Jacobian, built only when the caller requests a
%          second output (mvncdf2 sets 'Jacobian' to 'on' in fsolve)
function [g,j] = gradpsi(y,L,l,u)
d = length(u);
d_seq = 1:(d - 1);
x = zeros(d,1);
x(d_seq) = y(d_seq);
mu = zeros(d,1);
mu(d_seq) = y(d:end);
c = zeros(d,1);
c(2:d) = L(2:d,:) * x;
lt = l - mu - c;
ut = u - mu - c;
% phi(t) / P(lt < Z < ut), formed with the log probability w so that the
% ratio stays finite deep in the tails.
w = lnpr(lt,ut);
pd = sqrt(2 * pi());
pl = exp((-0.5 * lt.^2) - w) / pd;
pu = exp((-0.5 * ut.^2) - w) / pd;
p = pl - pu;
dfdx = -mu(d_seq) + (p.' * L(:,d_seq)).';
dfdm = mu - x + p;
g = [dfdx; dfdm(d_seq)];
% Skip the Jacobian work when only the residual is requested.
if (nargout < 2)
    return;
end
% An infinite bound has pl or pu equal to 0. Zero the bound itself so that
% the products below give 0 instead of Inf*0 = NaN.
lt(isinf(lt)) = 0;
ut(isinf(ut)) = 0;
dp = -p.^2 + (lt .* pl) - (ut .* pu);
dl = repmat(dp,1,d) .* L;
mx = -eye(d) + dl;
mx = mx(d_seq,d_seq);
xx = L.' * dl;
xx = xx(d_seq,d_seq);
j = [xx mx.'; mx diag(1 + dp(d_seq))];
end
% Element-wise natural log of P(a < Z < b) for Z ~ N(0,1). It stays accurate
% when both limits lie far out in the same tail.
%   a, b - arrays of the same size (assumed a <= b; entries may be +/-Inf)
%   p    - array the same size as a
function p = lnpr(a,b)
p = zeros(size(a));
in_right = a > 0;                   % both limits in the right tail
in_left = b < 0;                    % both limits in the left tail
straddle = ~(in_right | in_left);   % interval contains 0
if (any(in_right))
    log_a = lnpr_uppertail(a(in_right));
    log_b = lnpr_uppertail(b(in_right));
    p(in_right) = log_a + log1p(-exp(log_b - log_a));
end
if (any(in_left))
    log_a = lnpr_uppertail(-a(in_left));
    log_b = lnpr_uppertail(-b(in_left));
    p(in_left) = log_b + log1p(-exp(log_a - log_b));
end
if (any(straddle))
    below_a = erfc(-a(straddle) / sqrt(2)) / 2;
    above_b = erfc(b(straddle) / sqrt(2)) / 2;
    p(straddle) = log1p(-below_a - above_b);
end
end
% log P(Z > t), evaluated through erfcx to avoid underflow for large t.
function ln = lnpr_uppertail(t)
ln = (-0.5 * t.^2) - log(2) + reallog(erfcx(t / sqrt(2)));
end
Get help from others!
Recent Answers
Recent Questions
© 2024 TransWikia.com. All rights reserved. Sites we Love: PCI Database, UKBizDB, Menu Kuliner, Sharing RPP