% Demo of pair-wise non-rigid registration for myocardial T1 rho mapping
%
% Corresponding paper: "Endogenous assessment of myocardial injury with
% single-shot model-based non-rigid motion-corrected T1 rho mapping"
% To be published in Journal of Cardiovascular Magnetic Resonance, 2021
% 
% Author: Aurelien Bustin (aurelien.bustin@ihu-liryc.fr)
% IHU Liryc, 2021


function [Ux, Uy, I_reg] = pairwise_registration_nonrigid(I, Param)
% PAIRWISE_REGISTRATION_NONRIGID  Register every frame of an image stack onto
% a reference frame with a multi-resolution, regularised non-rigid model.
%
% [Ux, Uy, I_reg] = pairwise_registration_nonrigid(I, Param)
%
% Inputs
%   I     : image stack indexed I(y, x, t), size Ny x Nx x Nt.
%   Param : struct with fields
%     .display              nonzero -> plot progress and print diagnostics
%     .dx, .dt, .V0         set the weight of temporal smoothness through the
%                           normalised velocity V0_norm = V0 / (dx/dt)
%     .ResolutionLevels     scale factor of each pyramid level, coarse to
%                           fine (assumed increasing). If the last entry is
%                           not 1, a level 1 is appended on which the
%                           displacement is only upsampled, not optimised.
%     .NbLoops              maximum number of linearised updates per level
%                           (a level stops early as soon as the residual
%                           stops decreasing)
%     .lambda               regularisation weight, multiplied at each update
%                           by the norm of the data-term gradient
%     .Interpolator         struct with fields LookupTable, nbins and W,
%                           forwarded to interp_mex
%     .RegularizerOrder     (optional, default 1) -1: reweighted total
%                           variation, 1: first-order, 2: second-order
%                           derivatives of the displacement fields
%     .ref                  (optional, default Nt) index of the reference frame
%     .UseGradientWeighting (optional) 1 -> weight the data term by the
%                           normalised gradient magnitude of the reference
%                           on every resolution level
%
% Outputs
%   Ux, Uy : displacement fields along x (columns) and y (rows), size
%            Ny x Nx x Nt, on the full-resolution grid (presumably in pixels,
%            following the interp_mex convention -- confirm).
%   I_reg  : I warped with (Ux, Uy) by interp_mex, in the input intensity scale.

doDisplay = Param.display;   % not named "display" to avoid shadowing the builtin
dx        = Param.dx;
dt        = Param.dt;
V0        = Param.V0;

I  = double(I);
Ny = size(I,1);
Nx = size(I,2);
Nt = size(I,3);

% Work on intensities normalised by the maximum; the scale is restored on
% output. A zero maximum would turn the whole stack into NaN, so skip it.
Imax = max(I(:));
if( Imax ~= 0 )
    I = I / Imax;
else
    Imax = 1;
end

% Regularization type: which derivatives of the displacement fields are
% penalised (-1: reweighted total variation, 1: 1st order, 2: 2nd order)
if( isfield(Param,'RegularizerOrder') )
    RegularizerOrder = Param.RegularizerOrder;
else
    RegularizerOrder = 1;
end

% Normalized velocity (pixel/pixel units)
V0_norm = V0 / (dx/dt);

NbLoops          = Param.NbLoops;
ResolutionLevels = Param.ResolutionLevels;

% The last level must be full resolution. If the caller did not ask for it,
% append it and only upsample the displacement there (noHighres).
noHighres = ( ResolutionLevels(end) ~= 1 );
if( noHighres )
    ResolutionLevels = [ResolutionLevels 1];
end
nbResolutionLevels = length(ResolutionLevels);

% Grid size [Ny_r Nx_r Nt_r] of every level. The same sizes are used to build
% the image pyramid and the displacement fields so their shapes always agree.
dims = zeros(nbResolutionLevels, 3);
for r = 1:nbResolutionLevels
    s         = ResolutionLevels(r);
    dims(r,:) = [ level_size(Ny, s), level_size(Nx, s), level_size(Nt, s) ];
end

method               = 'linear';
Interpolator         = Param.Interpolator;
useGradientWeighting = isfield(Param, 'UseGradientWeighting') && Param.UseGradientWeighting==1;

if(doDisplay)
    fig1 = figure; set(gcf,'position', [4 29 700 663]), colormap('gray')
    fprintf( 'Normalized velocity: %2.2f\n', V0_norm )
end

if( isfield(Param, 'ref') )
    Iref = I(:,:,Param.ref);
else
    Iref = I(:,:,end);
end

% ---- Image pyramid, built from the finest level down to the coarsest ----
Iref_r              = cell(1, nbResolutionLevels);
I_r                 = cell(1, nbResolutionLevels);
GradientWeighting_r = cell(1, nbResolutionLevels);

Iref_r{end}              = Iref;
I_r{end}                 = I;
GradientWeighting_r{end} = data_weighting( Iref, useGradientWeighting );

for r = nbResolutionLevels-1:-1:1
    Iref_r{r}              = resample_stack( Iref_r{r+1}, [dims(r,1:2) 1], method );
    I_r{r}                 = resample_stack( I_r{r+1},    dims(r,:),       method );
    GradientWeighting_r{r} = data_weighting( Iref_r{r}, useGradientWeighting );
end

fprintf('Low resolution images created for multiresolution\n')

% ---- Coarse-to-fine optimisation ----
res   = zeros(nbResolutionLevels, NbLoops);
tol   = 1.e-3;   % PCG relative tolerance
maxit = 128;     % PCG maximum number of iterations

for r = 1:nbResolutionLevels

    Ny_r = dims(r,1);
    Nx_r = dims(r,2);
    Nt_r = dims(r,3);
    N3   = Nx_r*Ny_r*Nt_r;

    if( r==1 )
        Ux = zeros(Ny_r,Nx_r,Nt_r);
        Uy = zeros(Ny_r,Nx_r,Nt_r);
        I_r_registered = I_r{r};
    else
        % Bring the best displacement of the previous level onto this grid.
        % Values are rescaled because displacements are expressed in samples
        % of each level's own grid. Edge samples are replicated (see
        % centered_samples) instead of being set to zero.
        Ux = resample_stack( BestUx, dims(r,:), 'linear' ) * Nx_r/dims(r-1,2);
        Uy = resample_stack( BestUy, dims(r,:), 'linear' ) * Ny_r/dims(r-1,1);

        if( r==nbResolutionLevels && noHighres )
            % Full resolution was not requested: keep the upsampled field.
            BestUx = Ux;
            BestUy = Uy;
            break;
        end
    end

    % Forward-difference operators scaled to full-resolution units; the
    % temporal one is further divided by the normalised velocity.
    [NablaX, NablaY, NablaT] = forward_difference_operators(Ny_r, Nx_r, Nt_r);
    NablaX = (Nx/Nx_r) * NablaX;
    NablaY = (Ny/Ny_r) * NablaY;
    NablaT = (Nt/Nt_r) * NablaT / V0_norm;

    % Block-diagonal regulariser acting on [Uy(:); Ux(:)]. It depends only on
    % the level (and, for TV, on the field at the start of the level), so it
    % is built once per level rather than at every update.
    switch RegularizerOrder
        case -1
            % Total variation regularizer, approximated by weighted quadratic
            % terms with weights beta ./ sqrt(g.^2 + beta^2) computed from the
            % field at the start of the level (uniform on the first level).
            Wxx = speye( N3, N3 );
            Wxy = speye( N3, N3 );
            Wyx = speye( N3, N3 );
            Wyy = speye( N3, N3 );
            if( r > 1 )
                gxx = NablaX*Ux(:);
                gxy = NablaY*Ux(:);
                gyx = NablaX*Uy(:);
                gyy = NablaY*Uy(:);
                beta = 1.e-2 * mean( abs([gxx(:); gxy(:); gyx(:); gyy(:)]) );
                % A constant field gives beta = 0 and 0*Inf = NaN weights;
                % keep the uniform weights in that case.
                if( beta > 0 )
                    beta2 = beta^2;
                    Wxx = spdiags( beta * (gxx.^2 + beta2).^(-0.5), 0, N3, N3 );
                    Wxy = spdiags( beta * (gxy.^2 + beta2).^(-0.5), 0, N3, N3 );
                    Wyx = spdiags( beta * (gyx.^2 + beta2).^(-0.5), 0, N3, N3 );
                    Wyy = spdiags( beta * (gyy.^2 + beta2).^(-0.5), 0, N3, N3 );
                end
            end
            Rx  = NablaX' * (Wxx * NablaX) + NablaY' * (Wxy * NablaY) + NablaT' * NablaT;
            Ry  = NablaX' * (Wyx * NablaX) + NablaY' * (Wyy * NablaY) + NablaT' * NablaT;
            Reg = blkdiag( Ry, Rx );

        case 1
            % First order derivative of displacements
            R   = NablaX' * NablaX + NablaY' * NablaY + NablaT' * NablaT;
            Reg = blkdiag( R, R );

        case 2
            % Second order derivative of displacements
            NablaX2 = NablaX^2;
            NablaY2 = NablaY^2;
            NablaT2 = NablaT^2;
            NablaXY = 2*NablaX*NablaY;
            NablaXT = 2*NablaX*NablaT;
            NablaYT = 2*NablaY*NablaT;
            R   = NablaX2'*NablaX2 + NablaY2'*NablaY2 + NablaT2'*NablaT2 ...
                + NablaXY'*NablaXY + NablaXT'*NablaXT + NablaYT'*NablaYT;
            Reg = blkdiag( R, R );

        otherwise
            error( 'pairwise_registration_nonrigid:RegularizerOrder', ...
                'RegularizerOrder must be -1, 1 or 2 (got %g).', RegularizerOrder );
    end
    clear Nabla* R Rx Ry

    mask_rep   = repmat( GradientWeighting_r{r}, [1 1 Nt_r] );
    mask2_rep  = mask_rep.^2;
    Iref_r_rep = repmat( Iref_r{r}, [1 1 Nt_r] );

    BestResidue = Inf;

    for loop = 1:NbLoops

        if( ~(r==1 && loop==1) )
            % Apply the current displacement to every frame of this level
            I_r_registered = warp_stack( I_r{r}, Ux, Uy, Interpolator );
            if(doDisplay)
                show_registration( fig1, Iref_r{r}, I_r{r}, I_r_registered, Ux, Uy, false );
            end
        end

        % Weighted residual; stop the level as soon as it no longer decreases
        % and keep the best field seen so far.
        [Ix, Iy, ~] = gradient(I_r_registered);
        Is          = ( I_r_registered - Iref_r_rep ) .* mask_rep;
        res(r,loop) = norm( Is(:) );

        if( res(r,loop) < BestResidue )
            BestResidue = res(r,loop);
            BestUx      = Ux;
            BestUy      = Uy;
        else
            break;
        end

        % Data-term normal matrix [Iy2 IxIy; IxIy Ix2] of the linearised
        % problem in the unknowns [dUy(:); dUx(:)]
        Ix2  = Ix.^2    .* mask2_rep;
        Iy2  = Iy.^2    .* mask2_rep;
        IxIy = Ix .* Iy .* mask2_rep;

        n1  = 1:N3;
        n2  = N3+1:2*N3;
        AhA = sparse( [n1 n2 n1 n2], [n1 n2 n2 n1], ...
            [ Iy2(:) ; Ix2(:) ; IxIy(:) ; IxIy(:) ], 2*N3, 2*N3, 4*N3 );

        Ah_b   = -[ Iy(:).*Is(:) ; Ix(:).*Is(:) ];
        lambda = Param.lambda * norm(Ah_b(:));

        Ah_b    = Ah_b - lambda * Reg * [ Uy(:) ; Ux(:) ];
        Hessian = AhA + lambda*Reg;
        clear AhA

        % Gauss-Seidel preconditioner: C = (D+L)*inv(D)*(D+L')
        Dinv = spdiags( 1 ./ spdiags(Hessian,0), 0, 2*N3, 2*N3 );
        C1   = tril(Hessian);          % D+L
        C2   = Dinv * triu(Hessian);   % inv(D)*(D+U)
        [u, flag, ~, iter, ~] = pcg( Hessian, Ah_b, tol, maxit, C1, C2 );

        if(doDisplay)
            fprintf( '*** Conjugate gradient: flag=%d, iter=%d\n', flag, iter(end) );
        end

        Uy = Uy + reshape( u(1:N3),     [Ny_r Nx_r Nt_r] );
        Ux = Ux + reshape( u(N3+1:end), [Ny_r Nx_r Nt_r] );

        if(doDisplay)
            fprintf( 'r=%d, loop=%d, res=%3.3f\n', r, loop, res(r,loop)./(Nt_r*norm(I_r{r}(:))) )
            fprintf('norm(Ux)=%1.3f, norm(Uy)=%1.3f\n',  norm(Ux(:)), norm(Uy(:)) )
            figure(fig1), subplot(3,nbResolutionLevels,2*nbResolutionLevels+r)
            plot( 1:loop, res(r,1:loop)./norm(I_r{r}(:)), '.-' )
            set(gca, 'xlim', [1 max(NbLoops,2)])   % limits must be increasing
            title(r)
        end

    end

end

% ---- Register the full-resolution stack with the retained displacement ----
Ux = BestUx;
Uy = BestUy;
I_r_registered = warp_stack( I, Ux, Uy, Interpolator );

if(doDisplay)
    show_registration( fig1, Iref, I, I_r_registered, Ux, Uy, true );
end

I_reg = I_r_registered * Imax;

end


function n = level_size(N, scale)
% Number of samples of a dimension of size N on a level of the given scale:
% floor(N*scale), but at least 2 (needed by interpn and finite differences)
% unless the dimension itself is smaller than 2.
n = max( floor(N*scale), min(2, N) );
end


function q = centered_samples(Nin, Nout)
% 0-based coordinates, on a grid of Nin samples, of Nout samples spaced
% Nin/Nout apart and centred on that grid. Coordinates are clamped to
% [0, Nin-1]: when upsampling, the outermost samples otherwise fall outside
% the grid and interpn would return its extrapolation value (0) there.
start = (Nin-1-(Nout-1)*Nin/Nout)/2;
q     = start + (0:Nout-1) * (Nin/Nout);
q     = min( max(q, 0), Nin-1 );
end


function V = resample_stack(U, outDims, method)
% Resample a stack U (ny x nx x nt) onto an outDims = [Ny Nx Nt] grid with
% centred sampling. When the number of frames is unchanged, frames are only
% resampled spatially, one at a time.
[ny, nx, nt] = size(U);
yq = centered_samples(ny, outDims(1));
xq = centered_samples(nx, outDims(2));

if( nt == outDims(3) )
    [Y, X]   = ndgrid( 0:ny-1, 0:nx-1 );
    [Yi, Xi] = ndgrid( yq, xq );
    V = zeros( outDims(1), outDims(2), nt );
    for t = 1:nt
        V(:,:,t) = interpn( Y,X, U(:,:,t), Yi,Xi, method, 0 );
    end
else
    [Y, X, T]    = ndgrid( 0:ny-1, 0:nx-1, 0:nt-1 );
    [Yi, Xi, Ti] = ndgrid( yq, xq, centered_samples(nt, outDims(3)) );
    V = interpn( Y,X,T, U, Yi,Xi,Ti, method, 0 );
end
end


function W = data_weighting(img, useGradient)
% Per-pixel weights of the data term. Without gradient weighting: all ones.
% Otherwise: gradient magnitude of img divided by twice its mean, clipped
% at 1. A flat image (zero mean gradient) falls back to uniform weights
% instead of producing 0/0 = NaN.
W = ones( size(img) );
if( ~useGradient )
    return
end
[gx, gy] = gradient(img);
G = (gx.^2 + gy.^2).^0.5;
m = mean( abs(G(:)) );
if( m > 0 )
    W = G ./ m / 2;
    W( W(:)>1 ) = 1;
end
end


function [NablaX, NablaY, NablaT] = forward_difference_operators(ny, nx, nz)
% Sparse forward-difference operators on a column-major ny x nx x nz grid.
% Rows for the last sample along the differentiated direction are empty,
% so no difference is taken across the boundary.
N   = nx*ny*nz;
idx = reshape( 1:N, [ny, nx, nz] );
NablaY = difference_matrix( idx(1:ny-1, :, :), 1,     N );
NablaX = difference_matrix( idx(:, 1:nx-1, :), ny,    N );
NablaT = difference_matrix( idx(:, :, 1:nz-1), ny*nx, N );
end


function D = difference_matrix(ndx, stride, N)
% N x N sparse matrix with -1 at (k,k) and +1 at (k,k+stride) for k in ndx.
ndx = ndx(:);
D = sparse( [ndx ; ndx], [ndx ; ndx+stride], ...
    [ -ones(size(ndx)) ; ones(size(ndx)) ], N, N );
end


function Iw = warp_stack(Istack, Ux, Uy, Interpolator)
% Warp each frame of Istack with the matching frame of (Ux, Uy) via
% interp_mex. The displacement is packed as U_t(:,:,1,1) = Uy and
% U_t(:,:,1,2) = Ux, the layout interp_mex is called with in this file.
[ny, nx, nt] = size(Ux);
Iw = zeros(ny, nx, nt);
for t = 1:nt
    U_t = cat( 4, Uy(:,:,t), Ux(:,:,t) );
    Iw(:,:,t) = interp_mex( Istack(:,:,t), U_t, ...
        Interpolator.LookupTable, Interpolator.nbins, Interpolator.W, 0 );
end
end


function show_registration(fig, Iref, Iunreg, Ireg, Ux, Uy, diffFromTarget)
% Plot the target, mean unregistered/registered frames, a y-t slice of the
% displacements and the mean absolute differences to the target. The colour
% limit of the difference images is the upper colour limit of the target
% (diffFromTarget true) or the largest absolute difference (false).
Iref_rep   = repmat( Iref, [1 1 size(Iunreg,3)] );
diff_unreg = Iunreg - Iref_rep;
diff_reg   = Ireg   - Iref_rep;

figure(fig)
subplot(331), cla, imagesc( Iref ), title('Target'); cl = get(gca,'clim');
subplot(332), cla, imagesc( mean( Iunreg, 3 ), cl ), title('Unregistered')
subplot(333), cla, imagesc( mean( Ireg,   3 ), cl ), title('Registered')
subplot(334), cla, imagesc( [ squeeze(Ux(:,round(end/2),:)); squeeze(Uy(:,round(end/2),:)) ] ), title('Ux, Uy (y-t view)')

if( diffFromTarget )
    dmax = cl(2);
else
    dmax = max( abs( [ diff_unreg(:) ; diff_reg(:) ] ) );
end
if( ~(dmax > 0) )
    dmax = eps;   % imagesc requires strictly increasing colour limits
end
subplot(335), cla, imagesc( mean( abs(diff_unreg), 3 ), [-dmax dmax] ), title('Unregistered difference')
subplot(336), cla, imagesc( mean( abs(diff_reg),   3 ), [-dmax dmax] ), title('Registered difference')
pause(0.01)
end

