package tpietzsch.t5align;

import Jama.Matrix;
import java.util.Iterator;
import net.imglib2.Cursor;
import net.imglib2.RandomAccess;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.Img;
import net.imglib2.img.ImgFactory;
import net.imglib2.iterator.LocalizingIntervalIterator;
import net.imglib2.realtransform.AffineTransform;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.NumericType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.util.LinAlgHelpers;
import net.imglib2.util.Util;
import net.imglib2.view.Views;
import tpietzsch.t3gradient.completed.GradientExample2;
import tpietzsch.t4realviews.completed.RealViewsExample3;

/* loaded from: input_file:tpietzsch/t5align/Align.class */
/**
 * Aligns images to a fixed template by iteratively estimating an affine transform.
 *
 * <p>Everything that depends only on the template is computed once in the constructor:
 * <ul>
 *   <li>the template gradient (via {@code GradientExample2.gradients});</li>
 *   <li>one steepest-descent image per warp parameter;</li>
 *   <li>the inverse of the Gauss-Newton Hessian built from those images.</li>
 * </ul>
 * Each call to {@link #align} then repeats {@link #alignStep} until the parameter update becomes
 * small or the iteration budget is exhausted. This layout appears to follow the
 * inverse-compositional Lucas-Kanade scheme (Baker and Matthews). NOTE(review): confirm this
 * against the conventions of {@code RealViewsExample3.computeDifference} and
 * {@code WarpFunction.getAffine}.
 *
 * <p>Instances are not safe for concurrent use. {@link #align} mutates the instance's transform
 * estimate and error image without synchronization.
 *
 * @param <T> pixel type of the template and of the images being aligned
 */
public class Align<T extends RealType<T>> {
    /** The fixed image that every input is aligned to. */
    final RandomAccessibleInterval<T> template;

    /** Warp model: supplies the parameter count, partial derivatives and parameter-to-affine mapping. */
    final WarpFunction warpFunction;

    /** Number of spatial dimensions of the template. */
    final int n;

    /** Number of warp parameters ({@code warpFunction.numParameters()}). */
    final int numParameters;

    /** Transform estimate refined in place by {@link #alignStep}; reset to identity by {@link #align}. */
    final AffineTransform currentTransform;

    /**
     * Steepest-descent images. The spatial axes match the template, and one trailing axis of
     * length {@link #numParameters} holds one image per parameter.
     */
    public final Img<T> descent;

    /** Inverse Gauss-Newton Hessian, {@code numParameters x numParameters}. */
    double[][] Hinv;

    /** Scratch image with the template's dimensions; overwritten on every {@link #alignStep}. */
    final Img<T> error;

    /**
     * Precomputes the gradient, steepest-descent images and inverse Hessian of {@code template}.
     *
     * @param template the fixed image that later inputs are aligned to
     * @param factory  allocates the gradient, steepest-descent and error images
     */
    public Align(final RandomAccessibleInterval<T> template, final ImgFactory<T> factory) {
        this.template = template;
        final T type = Util.getTypeFromInterval(template);
        this.n = template.numDimensions();
        this.warpFunction = new AffineWarp(n);
        this.numParameters = warpFunction.numParameters();
        this.currentTransform = new AffineTransform(n);

        // The work images share the template's spatial dimensions plus one trailing axis. The
        // length of that axis is set below, separately for each image.
        final long[] dims = new long[n + 1];
        for (int d = 0; d < n; ++d) {
            dims[d] = template.dimension(d);
        }

        // One gradient component per spatial dimension. The template is border-extended,
        // presumably so that the gradient is also defined at the template edges.
        dims[n] = n;
        final Img<T> gradients = factory.create(dims, type);
        GradientExample2.gradients(Views.extendBorder(template), gradients);

        // One steepest-descent image per warp parameter (zero-initialized by the factory, which
        // computeSteepestDescents relies on because it accumulates).
        dims[n] = numParameters;
        this.descent = factory.create(dims, type);
        computeSteepestDescents(gradients, warpFunction, descent);

        this.Hinv = computeInverseHessian(descent);
        this.error = factory.create(template, type);
    }

    /**
     * Accumulates the steepest-descent images {@code sd_p(x) = sum_d grad_d(x) * dW_d/dp(x)}.
     *
     * <p>Values are ADDED to {@code descent}, so it must be zero-initialized by the caller. Both
     * images are addressed relative to their own {@code min()}. Their spatial slices are traversed
     * in flat iteration order, so they must have the same spatial interval.
     *
     * @param gradients   template gradient: spatial axes plus a trailing axis of at least
     *                    {@code numDimensions() - 1} components
     * @param warpFunction supplies {@code partial(position, d, p)}. Presumably this is the
     *                    derivative of warp coordinate {@code d} with respect to parameter
     *                    {@code p} at the given position (TODO confirm).
     * @param descent     output: spatial axes plus a trailing axis of at least
     *                    {@code warpFunction.numParameters()} slices
     * @throws IllegalArgumentException if the dimensionality or trailing-axis lengths do not fit
     */
    public static <T extends NumericType<T>> void computeSteepestDescents(
            final RandomAccessibleInterval<T> gradients,
            final WarpFunction warpFunction,
            final RandomAccessibleInterval<T> descent) {
        final int nd = gradients.numDimensions() - 1;
        final int numParams = warpFunction.numParameters();
        if (descent.numDimensions() != nd + 1) {
            throw new IllegalArgumentException("descent has " + descent.numDimensions()
                    + " dimensions, expected " + (nd + 1));
        }
        if (gradients.dimension(nd) < nd) {
            throw new IllegalArgumentException("gradients has " + gradients.dimension(nd)
                    + " components, expected " + nd);
        }
        if (descent.dimension(nd) < numParams) {
            throw new IllegalArgumentException("descent has " + descent.dimension(nd)
                    + " slices, expected " + numParams);
        }

        final T tmp = Util.getTypeFromInterval(gradients).createVariable();
        for (int p = 0; p < numParams; ++p) {
            final long descentSlice = descent.min(nd) + p;
            for (int d = 0; d < nd; ++d) {
                final Cursor<T> grad = Views.flatIterable(
                        Views.hyperSlice(gradients, nd, gradients.min(nd) + d)).localizingCursor();
                for (final T sd : Views.flatIterable(Views.hyperSlice(descent, nd, descentSlice))) {
                    tmp.set(grad.next());
                    // The cursor's position is the spatial location at which the warp is differentiated.
                    tmp.mul(warpFunction.partial(grad, d, p));
                    sd.add(tmp);
                }
            }
        }
    }

    /**
     * Builds the Gauss-Newton Hessian {@code H[i][j] = sum_x sd_i(x) * sd_j(x)} and returns its inverse.
     *
     * @param descent steepest-descent images stacked along the last axis. The length of that
     *                axis is the number of parameters. Any interval min is supported.
     * @return the inverse Hessian as a square {@code double[][]}. NOTE(review): Jama's
     *         {@code inverse()} presumably throws if H is singular, e.g. for a featureless template.
     */
    public static <T extends RealType<T>> double[][] computeInverseHessian(
            final RandomAccessibleInterval<T> descent) {
        final int nd = descent.numDimensions() - 1;
        final int numParams = (int) descent.dimension(nd);

        // Visit every spatial position once: the full interval, collapsed to its first slice.
        final long[] min = new long[nd + 1];
        final long[] max = new long[nd + 1];
        descent.min(min);
        descent.max(max);
        max[nd] = min[nd];
        final LocalizingIntervalIterator position = new LocalizingIntervalIterator(min, max);

        final RandomAccess<T> access = descent.randomAccess();
        final double[] sd = new double[numParams];
        final double[][] hessian = new double[numParams][numParams];
        while (position.hasNext()) {
            position.fwd();
            access.setPosition(position);
            // Gather all parameters' steepest-descent values at this pixel.
            for (int p = 0; p < numParams; ++p) {
                sd[p] = access.get().getRealDouble();
                access.fwd(nd);
            }
            // H is symmetric: accumulate the upper triangle only and mirror it afterwards.
            for (int i = 0; i < numParams; ++i) {
                for (int j = i; j < numParams; ++j) {
                    hessian[i][j] += sd[i] * sd[j];
                }
            }
        }
        for (int i = 0; i < numParams; ++i) {
            for (int j = 0; j < i; ++j) {
                hessian[i][j] = hessian[j][i];
            }
        }
        return new Matrix(hessian).inverse().getArray();
    }

    /**
     * Estimates the transform that aligns {@code image} to the template.
     *
     * <p>The estimate starts from the identity. {@link #alignStep} is repeated until the norm of
     * a parameter update drops below {@code minParameterChange}, or until {@code maxIterations}
     * steps have run. The number of steps performed is printed to standard output.
     *
     * @param image              the image to align. It is border-extended, so it may be
     *                           sampled outside its interval.
     * @param maxIterations      upper bound on the number of steps; values {@code <= 0} run none
     * @param minParameterChange convergence threshold on the parameter-update norm
     * @return a copy of the final estimate. It is not affected by later calls to {@code align}.
     */
    public AffineTransform align(final RandomAccessibleInterval<T> image, final int maxIterations,
            final double minParameterChange) {
        currentTransform.set(new AffineTransform(n));
        int iterations = 0;
        while (iterations < maxIterations) {
            ++iterations;
            if (alignStep(image) < minParameterChange) {
                break;
            }
        }
        System.out.println("computed " + iterations + " iterations.");

        // Return a snapshot: currentTransform is reset and overwritten by the next align() call.
        final AffineTransform result = new AffineTransform(n);
        result.set(currentTransform);
        return result;
    }

    /**
     * Performs one Gauss-Newton update of {@link #currentTransform}.
     *
     * @param image the image being aligned
     * @return Euclidean norm of the parameter update that was applied
     */
    double alignStep(final RandomAccessibleInterval<T> image) {
        // Fills the error image. Presumably (image warped by currentTransform) - template;
        // TODO confirm against RealViewsExample3.computeDifference.
        RealViewsExample3.computeDifference(Views.extendBorder(image), currentTransform, template, error);

        // Project the error onto each steepest-descent image. Both are traversed in flat
        // iteration order over the template's spatial interval, so the pixels line up.
        final double[] projection = new double[numParameters];
        for (int p = 0; p < numParameters; ++p) {
            final Cursor<T> err = Views.flatIterable(error).cursor();
            double sum = 0;
            for (final T sd : Views.flatIterable(Views.hyperSlice(descent, n, p))) {
                sum += sd.getRealDouble() * err.next().getRealDouble();
            }
            projection[p] = sum;
        }

        final double[] update = new double[numParameters];
        LinAlgHelpers.mult(Hinv, projection, update);
        currentTransform.preConcatenate(warpFunction.getAffine(update));
        return LinAlgHelpers.length(update);
    }
}
