The Newton-Schulz iteration for matrix inversion

The Newton-Schulz method is well known, and proofs of its convergence are widely available online. The derivation of the iteration itself is written out less often. Here it is:

To invert a nonsingular n \times n matrix A, we seek the zero X = A^{-1} of f: \mathbb{R}^{n \times n} \rightarrow \mathbb{R}^{n \times n}, defined on invertible X as follows:
\begin{array}{l} f(X) = X^{-1} - A. \end{array}

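The derivative of f at X, applied to a matrix B, follows from differentiating X X^{-1} = I in the direction B. Since A is constant, B X^{-1} + X\, f'(X)[B] = 0, so f'(X) operates on B as follows: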
\begin{array}{l} f'(X)[B] = -X^{-1} B X^{-1}. \end{array}

We claim that the inverse of this linear map, written f'^{-1}(X), operates on a matrix B as follows:
\begin{array}{l} f'^{-1}(X)[B] = -X B X. \end{array}

To see this, notice that
\begin{array}{l} f'^{-1}(X)\Big[f'(X)[B]\Big] \\ = -X \Big[-X^{-1} B X^{-1}\Big] X \\ = B. \end{array}
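Both formulas can also be sanity-checked numerically. The sketch below is in Python with NumPy (an assumption; the post itself is pure math): it compares f'(X)[B] with a central finite difference of f, and checks that the claimed inverse map undoes it in both orders. The helper names `f`, `df` and `df_inv` are hypothetical.

```python
import numpy as np

def f(X, A):
    # f(X) = X^{-1} - A; its zero is X = A^{-1}
    return np.linalg.inv(X) - A

def df(X, B):
    # Derivative of f at X applied to B: -X^{-1} B X^{-1}
    Xi = np.linalg.inv(X)
    return -Xi @ B @ Xi

def df_inv(X, B):
    # Claimed inverse of that linear map: -X B X
    return -X @ B @ X

rng = np.random.default_rng(0)
n = 4
A = rng.standard_normal((n, n))
X = np.eye(n) + 0.1 * rng.standard_normal((n, n))  # a well-conditioned, invertible point
B = rng.standard_normal((n, n))                      # an arbitrary direction

# Central finite difference of f along B (A cancels; only X^{-1} contributes)
h = 1e-6
fd = (f(X + h * B, A) - f(X - h * B, A)) / (2 * h)

print(np.max(np.abs(fd - df(X, B))))            # small: finite-difference error only
print(np.max(np.abs(df_inv(X, df(X, B)) - B)))  # at the level of rounding error
```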

Newton's method for root finding therefore takes, at each iterate:
\begin{array}{l} X_{t+1} \\ = X_t - f'^{-1}(X_t)\Big[f(X_t)\Big] \\ = X_t - f'^{-1}(X_t)\Big[X_t^{-1} - A\Big] \\ = X_t + X_t\Big[X_t^{-1} - A\Big] X_t \\ = X_t + \Big[X_t - X_t A X_t\Big] \\ = 2 X_t - X_t A X_t \end{array}
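A minimal NumPy sketch of the resulting iteration follows. The starting point X_0 = A^T / (||A||_1 ||A||_\infty) is an assumption on my part (a standard choice, not part of the derivation above): it puts every eigenvalue of I - A X_0 in [0, 1), and since the residual obeys I - A X_{t+1} = (I - A X_t)^2, the error then shrinks quadratically. The function name and tolerance are illustrative.

```python
import numpy as np

def newton_schulz_inverse(A, max_iter=100, tol=1e-12):
    """Approximate A^{-1} with X_{t+1} = 2 X_t - X_t A X_t (illustrative sketch)."""
    A = np.asarray(A, dtype=float)
    I = np.eye(A.shape[0])
    # Assumed initialization: X_0 = A^T / (||A||_1 * ||A||_inf)
    X = A.T / (np.linalg.norm(A, 1) * np.linalg.norm(A, np.inf))
    for _ in range(max_iter):
        X = 2 * X - X @ A @ X  # the Newton step derived above
        # The residual I - A X is squared at every step; stop once it is negligible
        if np.linalg.norm(I - A @ X) < tol:
            break
    return X

A = np.array([[4.0, 1.0, 0.0],
              [1.0, 3.0, 1.0],
              [0.0, 1.0, 2.0]])
X = newton_schulz_inverse(A)
print(np.allclose(X, np.linalg.inv(A)))  # True
```

Each step costs two matrix products and no solves, which is why the iteration is attractive when matrix multiplication is cheap (e.g. on GPUs).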

Thresholding sparse matrices in Matlab

Here are the methods I tried:

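function [tA] = hard_threshold(A, t)

    tic; % Method 1: start from an empty sparse matrix, copy in the survivors
    tA = sparse(size(A,1), size(A,2));
    tA(abs(A) >= t) = A(abs(A) >= t);
    toc;

    clear tA; tic; % Method 2: zero out the small entries via a logical mask
    tA = A;
    tA(abs(tA) < t) = 0;
    toc;

    clear tA; tic; % Method 3: zero out only the stored nonzeros below t
    tA = A;
    find_A = find(A);              % linear indices of all nonzeros
    find_tA = find(abs(A) >= t);   % linear indices of the survivors
    victim_tA = setdiff(find_A, find_tA);
    tA(victim_tA) = 0;
    toc;

    fprintf('numel(A):%i nnz(A):%i nnz(tA):%i \n', numel(A), nnz(A), nnz(tA));
end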

I first tried a small sparse matrix with 100k elements and 1% density, removing about 50% of the nonzeros:

A = sprand(1e5,1,0.01); tA = hard_threshold(A, 0.5);
Elapsed time is 0.128991 seconds.
Elapsed time is 0.007644 seconds.
Elapsed time is 0.003038 seconds.
numel(A):100000 nnz(A):995 nnz(tA):489

I next repeated with 1M elements:

A = sprand(1e6,1,0.01); tA = hard_threshold(A, 0.5);
Elapsed time is 15.456836 seconds.
Elapsed time is 0.082908 seconds.
Elapsed time is 0.018396 seconds.
numel(A):1000000 nnz(A):9966 nnz(tA):5019

With 100M elements, excluding the first (slowest) method:

A = sprand(1e8,1,0.01); tA = hard_threshold(A, 0.5);
Elapsed time is 16.405617 seconds.
Elapsed time is 0.259951 seconds.
numel(A):100000000 nnz(A):994845 nnz(tA):498195

The time differential is about the same even when the thresholded matrix is much sparser than the original:

A = sprand(1e8,1,0.01); tA = hard_threshold(A, 0.95);
Elapsed time is 12.980427 seconds.
Elapsed time is 0.238180 seconds.
numel(A):100000000 nnz(A):995090 nnz(tA):49950

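The second method fails due to memory constraints for really large sparse matrices, presumably because abs(tA) < t is true for nearly every entry (all the zeros), so the mask is essentially as large as a full matrix: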

Error using <
Requested 1000000000x1 (7.5GB) array exceeds maximum array size preference. Creation of arrays greater than this limit may take a long time and cause MATLAB to become unresponsive. See array size limit or preference panel for more information.
Error in hard_threshold (line 10)
 tA(abs(tA) < t) = 0;

After excluding the second method, the third method gives:

A = sprand(1e9,1,0.01); tA = hard_threshold(A, 0.5);
Elapsed time is 1.894251 seconds.
numel(A):1000000000 nnz(A):9950069 nnz(tA):4977460

Are there any other approaches that are faster?
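For comparison, the idea behind the third method (touching only the stored nonzeros) carries over to other sparse libraries. As a hedged sketch in Python with SciPy, not benchmarked against the timings above, one can threshold the `data` array of a CSC matrix directly and then drop the explicit zeros this creates. The function name mirrors the MATLAB one but is otherwise hypothetical.

```python
import numpy as np
import scipy.sparse as sp

def hard_threshold(A, t):
    """Return a copy of sparse A with every stored entry |a| < t removed."""
    tA = sp.csc_matrix(A, copy=True)  # CSC copy; the input is left untouched
    small = np.abs(tA.data) < t       # mask over the stored values only (length nnz)
    tA.data[small] = 0.0              # work scales with nnz, not with numel
    tA.eliminate_zeros()              # remove the explicit zeros from the structure
    return tA

A = sp.csc_matrix(np.array([[0.2], [0.0], [0.7], [0.5], [-0.9]]))
tA = hard_threshold(A, 0.5)
print(A.nnz, tA.nnz)  # 4 3
```

Because the mask is built over `tA.data` rather than over the whole matrix, it never materializes anything of size numel(A), which is the failure mode of the second method.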


Transforming data to Gaussian

Transforming data to Gaussian in Matlab with the probability integral transform (empirical CDF, followed by the inverse standard normal CDF):

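n = 500;
% Skewed, bimodal test data: lognormal, plus an N(10, 3^2) offset for about half the samples
x = exp(randn(n,1)) + (randi(2,[n 1])-1).*(10+3*randn(n,1));
% Empirical CDF; dividing by n+1 keeps values below 1, so icdf never returns Inf
fhat = @(in) sum(x <= in)/(n+1);
Fhat = @(A) arrayfun(fhat,A);
y = Fhat(x);               % approximately uniform on (0,1)
z = icdf('normal',y,0,1);  % approximately standard normal
figure; subplot(1,2,1); hist(x,20); xlabel('x');
subplot(1,2,2); hist(z,20); xlabel('z');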
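The same rank-based transform can be sketched in Python with SciPy (an assumption; the post uses MATLAB). `rankdata(x, method="max")` gives, for each sample, the number of samples less than or equal to it, which is what `fhat` counts; dividing by n+1 (rather than n) keeps the largest sample from mapping to infinity, and it avoids the O(n^2) `arrayfun` loop. `norm.ppf` plays the role of `icdf('normal',...)`. The function name `gaussianize` is hypothetical.

```python
import numpy as np
from scipy.stats import norm, rankdata

def gaussianize(x):
    """Map samples to approximately standard normal scores via their empirical CDF."""
    x = np.asarray(x, dtype=float)
    # Rank with method="max" = number of samples <= each value (ties share the top rank)
    u = rankdata(x, method="max") / (x.size + 1)  # strictly inside (0, 1)
    return norm.ppf(u)                            # inverse standard normal CDF

rng = np.random.default_rng(0)
n = 500
# Same skewed, bimodal mixture as the MATLAB snippet
x = np.exp(rng.standard_normal(n)) + rng.integers(0, 2, n) * (10 + 3 * rng.standard_normal(n))
z = gaussianize(x)
```

The transform is monotone, so it preserves the ordering of the data while reshaping the marginal distribution.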


Preprocessing and cross-validation


      In general, preprocessing should be done inside the cross-validation routine. If you preprocess outside of the cross-validation algorithm (before calling crossval), you will bias the cross-validation results and likely overfit your model. The reason is that the preprocessing will be based on the ENTIRE set of data, but the validity of cross-validation REQUIRES that the preprocessing be based ONLY on specific subsets of the data. Why? Read on:
      Cross-validation splits your data into “n” subsets (let’s say 3 for simplicity). Let’s say you have 12 samples and you’re only doing mean centering as your preprocessing (again, for simplicity). Cross-validation is going to take your 12 samples and split them into 3 groups (4 samples in each group).
      In each cycle of the cross-validation, the algorithm leaves out one of those 3 groups (4 samples, the “validation set”) and does both preprocessing and model building from the remaining 8 samples (the “calibration set”). Recall that the preprocessing step here is to calculate the mean of the data and subtract it. It then applies the preprocessing and model to the 4-sample validation set and looks at the error (and repeats this for each of the 3 groups). Here, applying the preprocessing means taking the mean calculated from the 8 calibration samples and subtracting it from the 4 validation samples.
      That last part is the key to why preprocessing BEFORE crossval is bad: when preprocessing is done INSIDE cross-validation (as it should be), the mean is calculated from the 8 samples that were left in and subtracted from them, and that same 8-sample mean is also subtracted from the 4 samples left out by cross-validation. However, if you mean-center BEFORE cross-validation, the mean is calculated from all 12 samples. The result is that, even though the rules of cross-validation say that the preprocessing (mean) and model are supposed to be calculated from only the calibration set, doing the preprocessing outside of cross-validation means all samples are influencing the preprocessing (mean).
      With mean-centering, the effect isn’t as bad as it is for something like GLSW or OSC. These “multivariate filters” are far stronger preprocessing methods and operating on the entire data set can have a significant influence on the covariance (read: can have a much bigger effect of “cheating” and thus overfitting). The one time it doesn’t matter is when the preprocessing methods being done are “row-wise” only – that is, methods that operate on samples independently are not a problem. Methods like smoothing, derivatives, baselining, or normalization (other than MSC when based on the mean) operate on each sample independently and adding or removing samples from the data set has no effect on the others. In fact, to save time, our cross-validation routine recognizes when row-wise operations come first in the preprocessing sequence and does them outside of the cross-validation loop. The only time you can’t do these in advance is when another non-row-wise method happens prior to the row-wise method.
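A minimal NumPy sketch of the 12-sample, 3-fold mean-centering example follows. Contiguous folds and the function names are assumptions for illustration; a real cross-validation routine would also fit and apply the model inside the same loop.

```python
import numpy as np

def center_inside_cv(X, n_folds=3):
    """Yield (calibration, validation) blocks, each centered with the calibration mean only."""
    idx = np.arange(X.shape[0])
    for val_idx in np.array_split(idx, n_folds):  # contiguous folds (an assumption)
        cal_idx = np.setdiff1d(idx, val_idx)
        mu = X[cal_idx].mean(axis=0)              # preprocessing fitted on the calibration set
        yield X[cal_idx] - mu, X[val_idx] - mu    # the same mean is applied to the held-out samples

def center_before_cv(X):
    """The biased alternative: one mean computed from ALL samples before any split."""
    return X - X.mean(axis=0)

X = np.arange(12.0).reshape(12, 1)      # 12 samples, 1 variable
cal, val = next(center_inside_cv(X))    # first fold holds out samples 0-3
print(val.ravel())                      # [-7.5 -6.5 -5.5 -4.5]: mean of samples 4-11
print(center_before_cv(X)[:4].ravel())  # [-5.5 -4.5 -3.5 -2.5]: held-out samples leaked into the mean
```

The difference between the two printed rows is exactly the leakage described above: in the second case the held-out samples helped determine their own preprocessing.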

Resizing nested C++ STL vectors

“Multidimensional” vectors in C++ do not behave like matrices (or higher-order equivalents). Something like:

vector< vector<double> > foo(100, vector<double>(20, 0.0));

will not lay out 100*20 doubles contiguously in memory. Only the bookkeeping data for the 100 vector<double>’s (typically a few pointers each) will be laid out contiguously; each vector<double> stores its actual elements in its own block on the heap. Thus, each vector<double> can have its own size.

This can lead to hard-to-catch bugs:

 foo.resize(300, vector<double>(30, 1.0));

will leave the first 100 vector<double>’s with size 20, filled with 0.0 values, while the new 200 vector<double>’s will have size 30, filled with 1.0 values.

Sampling from Multivariate Gaussian distribution in Matlab

tl;dr: Don’t use mvnrnd in Matlab for large problems; do it manually instead.

The first improvement uses the Cholesky decomposition of the covariance, so that we only need to draw independent samples from a univariate standard normal distribution and transform them. The second improvement uses the Cholesky decomposition of the sparse inverse covariance (precision) matrix instead of the dense covariance matrix. The third improvement avoids computing any inverse, solving a sparse triangular system of equations instead.

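n = 10000;
Lambda = gallery('tridiag',n,-0.3,1,-0.3); % sparse precision (inverse covariance)

tic; % 1) mvnrnd with the dense covariance
x_mvnrnd = mvnrnd(zeros(1,n),inv(Lambda));
toc;

tic; % 2) Cholesky of Sigma: x = A*z has covariance A*A' = Sigma
z = randn(n,1); % univariate random
Sigma = inv(Lambda); % dense: the inverse of this tridiagonal matrix is full
A = chol(Sigma,'lower'); % dense lower triangle
x_fromSigma = A*z;
toc;

tic; % 3) Cholesky of Lambda = L*L': inv(L)'*inv(L) = inv(L*L') = Sigma
z = randn(n,1); % univariate random
L_Lambda = chol(Lambda,'lower'); % sparse (bidiagonal)
A_fromLambda = (inv(L_Lambda))'; % no longer sparse: the upper triangle fills in
x_fromLambda = A_fromLambda*z;
toc;

tic; % 4) same as 3), but solve L'*x = z instead of forming inv(L)
z = randn(n,1); % univariate random
L_Lambda = chol(Lambda,'lower'); % sparse (bidiagonal)
x_fromLambda = L_Lambda'\z;
toc;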


Elapsed time is 4.514641 seconds.
Elapsed time is 2.734001 seconds.
Elapsed time is 1.740317 seconds.
Elapsed time is 0.012431 seconds.
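The fourth approach translates to Python as well. SciPy has no general sparse Cholesky, so this sketch assumes the tridiagonal structure of `Lambda` and uses banded storage instead: `cholesky_banded` factors Lambda = U^T U with U upper bidiagonal (so U plays the role of `L_Lambda'`), and `solve_banded` solves U x = z in O(n). The function name and defaults (which mirror `gallery('tridiag',n,-0.3,1,-0.3)`) are illustrative.

```python
import numpy as np
from scipy.linalg import cholesky_banded, solve_banded

def sample_tridiag_precision(z, off=-0.3, diag=1.0):
    """Turn i.i.d. N(0,1) draws z (shape (n,) or (n, k)) into draws from N(0, Lambda^{-1}),
    where Lambda is tridiagonal with constant diagonal `diag` and off-diagonals `off`."""
    z = np.asarray(z, dtype=float)
    n = z.shape[0]
    # Upper banded storage: row 0 = superdiagonal (entry 0 unused), row 1 = main diagonal
    ab = np.zeros((2, n))
    ab[0, 1:] = off
    ab[1, :] = diag
    U = cholesky_banded(ab, lower=False)  # Lambda = U^T U, returned in the same banded layout
    # x = U^{-1} z has covariance U^{-1} U^{-T} = (U^T U)^{-1} = Lambda^{-1}
    return solve_banded((0, 1), U, z)

rng = np.random.default_rng(0)
x = sample_tridiag_precision(rng.standard_normal(10000))
```

Neither the covariance nor any inverse is ever formed, so memory stays O(n) for a tridiagonal precision matrix.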

Matlab: different colormaps for subplots

I often want different subplots in one Matlab figure to have different colormaps. However, colormap is a figure property, so it’s not trivial, except that it is… with these utilities:

This works for everything except colorbars:—unfreezecolors

Post-2010, Matlab refreshes colorbars with each subplot, so you’ll need this to freeze colorbars:–feb-2014-

Scaling up hierarchical clustering

There are lots of caveats with hierarchical clustering, and it’s often used to draw unjustifiable conclusions. But I’ll save that discussion for later, and it’s at least occasionally useful. 🙂 So far, I’ve mainly used it to reorder the columns/rows of a covariance or correlation matrix. For example, I recently generated synthetic sparse precision matrices, and I wanted to make sure that the corresponding covariance/correlation matrices were as I expected. 

However, linkage/dendrogram in both Matlab and SciPy are really slow. In particular, the linkage algorithms they use are O(n^3). So instead we can use fastcluster, which is O(n^2). All I had to do was replace this:

import scipy.cluster.hierarchy as sch
Z = sch.linkage(P, method="average", metric="correlation")

with this:

import fastcluster
Z = fastcluster.linkage(P, method="average", metric="correlation")


The second problem is that dendrogram does lots of unnecessary work when all I want is the reordered indices. SciPy also implements it recursively, so for large inputs it fails with “RuntimeError: maximum recursion depth exceeded”. So the second change is to replace this:

dendr = sch.dendrogram(Z,no_plot=True,distance_sort=True)
ix = dendr["leaves"]

with this (credit to a good denizen of StackOverflow):

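n = len(Z) + 1              # number of original observations (leaves)
cache = dict()              # cluster id -> list of leaf indices, left to right
for k in range(len(Z)):
    c1, c2 = int(Z[k][0]), int(Z[k][1])
    # ids < n are leaves; larger ids refer to clusters formed at earlier merges
    c1 = [c1] if c1 < n else cache.pop(c1)
    c2 = [c2] if c2 < n else cache.pop(c2)
    cache[n+k] = c1 + c2    # merge k creates cluster id n+k
ix = cache[2*len(Z)]        # the root has id n + len(Z) - 1 = 2*len(Z)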

Then it’s as simple as:

from matplotlib import pyplot
Pclust = P[ix,:]
Pclust = Pclust[:,ix]
pyplot.imshow(Pclust, interpolation="nearest")
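For a self-contained version of the whole pipeline, here is a sketch on synthetic data (three planted groups of 20 correlated variables, shuffled; an assumption for illustration). It uses SciPy's `linkage` to keep a single dependency (`fastcluster.linkage` accepts the same call), wraps the loop above in a hypothetical `leaf_order` helper, and cross-checks it against SciPy's `leaves_list`, which returns the left-to-right leaf order without plotting. This is the unsorted order; it need not match `dendrogram(..., distance_sort=True)`.

```python
import numpy as np
import scipy.cluster.hierarchy as sch

def leaf_order(Z):
    """Left-to-right leaf order of a linkage matrix, built iteratively (no recursion)."""
    n = len(Z) + 1
    cache = {}
    for k in range(len(Z)):
        c1, c2 = int(Z[k, 0]), int(Z[k, 1])
        c1 = [c1] if c1 < n else cache.pop(c1)
        c2 = [c2] if c2 < n else cache.pop(c2)
        cache[n + k] = c1 + c2
    return cache[2 * len(Z)]

rng = np.random.default_rng(0)
# Synthetic data: 3 groups of 20 variables sharing a latent signal, columns shuffled
latent = rng.standard_normal((500, 3))
data = np.repeat(latent, 20, axis=1) + 0.5 * rng.standard_normal((500, 60))
perm = rng.permutation(60)
P = np.corrcoef(data[:, perm], rowvar=False)  # shuffled 60x60 correlation matrix

Z = sch.linkage(P, method="average", metric="correlation")
ix = leaf_order(Z)
Pclust = P[np.ix_(ix, ix)]  # reorder rows and columns in one step
```

After reordering, each planted group occupies a contiguous diagonal block of `Pclust`.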