/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.estim;

import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.estim.MMNode;
import org.apache.sysds.hops.estim.SparsityEstimator;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.UtilFunctions;

public class EstimatorDensityMap
extends SparsityEstimator {
    private static final int BLOCK_SIZE = 256;
    private final int _b;

    /**
     * Creates a density-map estimator that uses the default block size
     * {@code BLOCK_SIZE}.
     */
    public EstimatorDensityMap() {
        this(BLOCK_SIZE);
    }

    /**
     * Creates a density-map estimator with a caller-chosen block size.
     *
     * @param blocksize block size handed to every {@link DensityMap} this
     *                  estimator creates
     */
    public EstimatorDensityMap(int blocksize) {
        _b = blocksize;
    }

    /**
     * Estimates the output of the operation represented by {@code root}.
     * <p>
     * The density maps of the left and right child are obtained first (inner
     * nodes are estimated recursively), combined according to the operation of
     * {@code root}, and the resulting map is stored on {@code root} as its
     * synopsis.
     *
     * @param root node whose output is estimated
     * @return whatever {@code root.setDataCharacteristics(...)} returns for the
     *         original output dimensions and the estimated number of non-zeros
     */
    @Override
    public DataCharacteristics estim(MMNode root) {
        DensityMap leftMap = getCachedSynopsis(root.getLeft());
        DensityMap rightMap = getCachedSynopsis(root.getRight());
        DensityMap result = estimIntern(leftMap, rightMap, root.getOp());
        root.setSynopsis(result);

        long rows = result.getNumRowsOrig();
        long cols = result.getNumColumnsOrig();
        long nnz = result.getNonZeros();
        return root.setDataCharacteristics(new MatrixCharacteristics(rows, cols, nnz));
    }

    /**
     * Estimates the output sparsity for the {@code MM} operation on the two
     * given matrices by delegating to the three-argument overload.
     *
     * @param m1 left-hand input
     * @param m2 right-hand input
     * @return estimated output sparsity
     */
    @Override
    public double estim(MatrixBlock m1, MatrixBlock m2) {
        final SparsityEstimator.OpCode matMult = SparsityEstimator.OpCode.MM;
        return estim(m1, m2, matMult);
    }

    /**
     * Estimates the output sparsity of {@code op} applied to {@code m1} and,
     * for binary operations, {@code m2}.
     * <p>
     * Operations for which {@code isExactMetadataOp} returns true are answered
     * from the inputs' data characteristics alone. All others are estimated
     * from density maps built with this estimator's block size.
     *
     * @param m1 first input
     * @param m2 second input; may be {@code null}, in which case the map of
     *           {@code m1} is used for both operands
     * @param op operation to estimate
     * @return estimated output sparsity
     */
    @Override
    public double estim(MatrixBlock m1, MatrixBlock m2, SparsityEstimator.OpCode op) {
        if (isExactMetadataOp(op)) {
            DataCharacteristics dc1 = m1.getDataCharacteristics();
            DataCharacteristics dc2 = null;
            if (m2 != null) {
                dc2 = m2.getDataCharacteristics();
            }
            return estimExactMetaData(dc1, dc2, op).getSparsity();
        }

        DensityMap leftMap = new DensityMap(m1, _b);
        DensityMap rightMap;
        if (m2 == null || m2 == m1) {
            // Same matrix (or unary operation): share a single density map.
            rightMap = leftMap;
        } else {
            rightMap = new DensityMap(m2, _b);
        }

        DensityMap result = estimIntern(leftMap, rightMap, op);
        return OptimizerUtils.getSparsity(
            result.getNumRowsOrig(), result.getNumColumnsOrig(), result.getNonZeros());
    }

    /**
     * Estimates the output sparsity of a single-input operation by delegating
     * to the three-argument overload with no second matrix.
     *
     * @param m  input matrix
     * @param op operation to estimate
     * @return estimated output sparsity
     */
    @Override
    public double estim(MatrixBlock m, SparsityEstimator.OpCode op) {
        final MatrixBlock noSecondInput = null;
        return estim(m, noSecondInput, op);
    }

    /**
     * Returns the density map attached to {@code node}, creating it if needed.
     * <p>
     * Inner nodes are always (re-)estimated via {@link #estim(MMNode)}, which
     * stores the result as the node's synopsis. Leaf nodes get a map built from
     * their data only when no synopsis is present yet.
     *
     * @param node node to look up; may be {@code null}
     * @return the node's synopsis cast to {@link DensityMap}, or {@code null}
     *         when {@code node} is {@code null}
     */
    private DensityMap getCachedSynopsis(MMNode node) {
        if (node == null) {
            return null;
        }
        if (!node.isLeaf()) {
            // Recursively estimate the sub-tree; this sets the node's synopsis.
            estim(node);
        } else if (node.getSynopsis() == null) {
            node.setSynopsis(new DensityMap(node.getData(), _b));
        }
        return (DensityMap) node.getSynopsis();
    }

    /**
     * Dispatches to the density-map estimation routine for {@code op}.
     * <p>
     * {@code RBIND} and {@code CBIND} are rejected explicitly. No append logic
     * exists for density maps, and answering them with the transpose of the
     * first input would report swapped output dimensions and ignore the second
     * input entirely.
     *
     * @param m1Map density map of the first input
     * @param m2Map density map of the second input; only read by the binary
     *              operations {@code MM}, {@code MULT} and {@code PLUS}
     * @param op    operation to estimate
     * @return density map describing the estimated output; for
     *         {@code NEQZERO} this is {@code m1Map} itself
     * @throws NotImplementedException if {@code op} has no density-map
     *                                 implementation (including {@code RBIND}
     *                                 and {@code CBIND})
     */
    public DensityMap estimIntern(DensityMap m1Map, DensityMap m2Map, SparsityEstimator.OpCode op) {
        switch (op) {
            case MM:
                return estimInternMM(m1Map, m2Map);
            case MULT:
                return estimInternMult(m1Map, m2Map);
            case PLUS:
                return estimInternPlus(m1Map, m2Map);
            case NEQZERO:
                return m1Map;
            case EQZERO:
                return estimInternEqZero(m1Map);
            case TRANS:
                return estimInternTrans(m1Map);
            case DIAG:
                return estimInternDiag(m1Map);
            case RESHAPE:
                return EstimatorDensityMap.estimInternReshape(m1Map);
            case RBIND:
            case CBIND:
                // Do not fall through to TRANS: that would yield a map with the
                // wrong shape instead of signalling the missing implementation.
                throw new NotImplementedException(
                    "Density map estimation is not implemented for " + op);
            default:
                throw new NotImplementedException();
        }
    }

    /**
     * Estimates the density map of the matrix product of the two inputs.
     * <p>
     * Both input maps are converted to sparsity form in place. For every pair
     * of blocks along the shared dimension with non-zero sparsity, the
     * contribution {@code 1 - (1 - sp1 * sp2)^len} is computed, where
     * {@code len} is the column extent of the left block. Contributions to the
     * same output block are merged with {@code a + b - a * b}.
     *
     * @param m1Map density map of the left input
     * @param m2Map density map of the right input
     * @return new density map in sparsity form with the left input's original
     *         row count and the right input's original column count
     */
    private DensityMap estimInternMM(DensityMap m1Map, DensityMap m2Map) {
        final int rowBlocks = m1Map.getNumRows();
        final int innerBlocks = m1Map.getNumColumns();
        final int colBlocks = m2Map.getNumColumns();

        MatrixBlock result = new MatrixBlock(rowBlocks, colBlocks, false);
        DenseBlock acc = result.allocateBlock().getDenseBlock();

        m1Map.toSparsity();
        m2Map.toSparsity();

        for (int i = 0; i < rowBlocks; i++) {
            for (int k = 0; k < innerBlocks; k++) {
                final int innerLen = m1Map.getColBlockize(k);
                final double leftSp = m1Map.get(i, k);
                if (leftSp != 0.0) {
                    for (int j = 0; j < colBlocks; j++) {
                        final double rightSp = m2Map.get(k, j);
                        if (rightSp != 0.0) {
                            // Sparsity of the product of one block pair.
                            final double tmp1 = 1.0 - Math.pow(1.0 - leftSp * rightSp, innerLen);
                            // Merge with what was accumulated for this output block so far.
                            final double tmp2 = acc.get(i, j);
                            acc.set(i, j, tmp1 + tmp2 - tmp1 * tmp2);
                        }
                    }
                }
            }
        }

        result.recomputeNonZeros();
        return new DensityMap(result, m1Map.getNumRowsOrig(), m2Map.getNumColumnsOrig(), _b, true);
    }

    /**
     * Estimates the density map of the element-wise multiplication of the two
     * inputs as the cell-wise product of their block sparsities.
     * <p>
     * Both input maps are converted to sparsity form in place. The output takes
     * its shape and original dimensions from {@code m1Map}.
     *
     * @param m1Map density map of the first input
     * @param m2Map density map of the second input
     * @return new density map in sparsity form
     */
    private DensityMap estimInternMult(DensityMap m1Map, DensityMap m2Map) {
        final int rowBlocks = m1Map.getNumRows();
        final int colBlocks = m1Map.getNumColumns();

        MatrixBlock result = new MatrixBlock(rowBlocks, colBlocks, false);
        DenseBlock cells = result.allocateBlock().getDenseBlock();

        m1Map.toSparsity();
        m2Map.toSparsity();

        for (int i = 0; i < rowBlocks; i++) {
            for (int j = 0; j < colBlocks; j++) {
                final double product = m1Map.get(i, j) * m2Map.get(i, j);
                cells.set(i, j, product);
            }
        }

        result.recomputeNonZeros();
        return new DensityMap(result, m1Map.getNumRowsOrig(), m1Map.getNumColumnsOrig(), _b, true);
    }

    /**
     * Estimates the density map of the element-wise addition of the two inputs
     * by combining the block sparsities with {@code a + b - a * b}.
     * <p>
     * Both input maps are converted to sparsity form in place. The output takes
     * its shape and original dimensions from {@code m1Map}.
     *
     * @param m1Map density map of the first input
     * @param m2Map density map of the second input
     * @return new density map in sparsity form
     */
    private DensityMap estimInternPlus(DensityMap m1Map, DensityMap m2Map) {
        final int rowBlocks = m1Map.getNumRows();
        final int colBlocks = m1Map.getNumColumns();

        MatrixBlock result = new MatrixBlock(rowBlocks, colBlocks, false);
        DenseBlock cells = result.allocateBlock().getDenseBlock();

        m1Map.toSparsity();
        m2Map.toSparsity();

        for (int i = 0; i < rowBlocks; i++) {
            for (int j = 0; j < colBlocks; j++) {
                final double a = m1Map.get(i, j);
                final double b = m2Map.get(i, j);
                cells.set(i, j, a + b - a * b);
            }
        }

        result.recomputeNonZeros();
        return new DensityMap(result, m1Map.getNumRowsOrig(), m1Map.getNumColumnsOrig(), _b, true);
    }

    /**
     * Estimates the density map of the transpose by transposing the map itself
     * and swapping the original dimensions.
     *
     * @param m1Map density map of the input
     * @return new density map that keeps the scaling state of {@code m1Map}
     */
    private DensityMap estimInternTrans(DensityMap m1Map) {
        MatrixBlock source = m1Map.getMap();
        MatrixBlock target = new MatrixBlock(m1Map.getNumColumns(), m1Map.getNumRows(), false);
        MatrixBlock transposed = LibMatrixReorg.transpose(source, target);
        return new DensityMap(
            transposed, m1Map.getNumColumnsOrig(), m1Map.getNumRowsOrig(), _b, m1Map._scaled);
    }

    /**
     * Estimates the density map of {@code diag} applied to an input with at
     * most one original column.
     * <p>
     * The input map is converted to non-zero counts in place before
     * {@code LibMatrixReorg.diag} is applied to it.
     *
     * @param m1Map density map of the input
     * @return new square density map sized by the input's row count
     * @throws NotImplementedException if the input has more than one original
     *                                 column
     */
    private DensityMap estimInternDiag(DensityMap m1Map) {
        final boolean hasMultipleColumns = m1Map.getNumColumnsOrig() > 1;
        if (hasMultipleColumns) {
            throw new NotImplementedException();
        }

        m1Map.toNnz();

        final int blocks = m1Map.getNumRows();
        MatrixBlock target = new MatrixBlock(blocks, blocks, false);
        MatrixBlock diagonal = LibMatrixReorg.diag(m1Map.getMap(), target);

        final int rowsOrig = m1Map.getNumRowsOrig();
        return new DensityMap(diagonal, rowsOrig, rowsOrig, _b, m1Map._scaled);
    }

    /**
     * Estimates the density map of a reshape by collapsing the input into a
     * single 1x1 map cell that holds the input's total number of non-zeros.
     * <p>
     * The block size of the result is the larger of the input's original
     * dimensions, and the result keeps the input's original dimensions.
     *
     * @param m1Map density map of the input
     * @return new single-cell density map holding non-zero counts
     */
    private static DensityMap estimInternReshape(DensityMap m1Map) {
        final long totalNnz = m1Map.getNonZeros();
        MatrixBlock singleCell = new MatrixBlock(1, 1, (double) totalNnz);

        final int rowsOrig = m1Map.getNumRowsOrig();
        final int colsOrig = m1Map.getNumColumnsOrig();
        final int blockSize = Math.max(rowsOrig, colsOrig);

        return new DensityMap(singleCell, rowsOrig, colsOrig, blockSize, false);
    }

    /**
     * Estimates the density map of the {@code == 0} comparison as the
     * complement {@code 1 - sparsity} of every block.
     * <p>
     * The input map is converted to sparsity form in place.
     *
     * @param m1Map density map of the input
     * @return new density map carrying the scaling state of {@code m1Map}
     */
    private DensityMap estimInternEqZero(DensityMap m1Map) {
        final int rowBlocks = m1Map.getNumRows();
        final int colBlocks = m1Map.getNumColumns();
        MatrixBlock complement = new MatrixBlock(rowBlocks, colBlocks, false);

        m1Map.toSparsity();

        for (int i = 0; i < rowBlocks; i++) {
            for (int j = 0; j < colBlocks; j++) {
                final double zeroFraction = 1.0 - m1Map.get(i, j);
                complement.quickSetValue(i, j, zeroFraction);
            }
        }

        return new DensityMap(
            complement, m1Map.getNumRowsOrig(), m1Map.getNumColumnsOrig(), _b, m1Map._scaled);
    }

    public static class DensityMap {
        private final MatrixBlock _map;
        private final int _rlen;
        private final int _clen;
        private final int _b;
        private boolean _scaled;

        public DensityMap(MatrixBlock in, int b) {
            this._rlen = in.getNumRows();
            this._clen = in.getNumColumns();
            this._b = b;
            this._map = this.init(in);
            this._scaled = false;
            if (!DensityMap.isPow2(this._b)) {
                System.out.println("WARN: Invalid block size: " + this._b);
            }
        }

        public DensityMap(MatrixBlock map, int rlenOrig, int clenOrig, int b, boolean scaled) {
            this._rlen = rlenOrig;
            this._clen = clenOrig;
            this._b = b;
            this._map = map;
            this._scaled = scaled;
            if (!DensityMap.isPow2(this._b)) {
                System.out.println("WARN: Invalid block size: " + this._b);
            }
        }

        /**
         * @return the underlying map block itself (not a copy)
         */
        public MatrixBlock getMap() {
            return _map;
        }

        /**
         * @return number of rows of the map (not of the original matrix)
         */
        public int getNumRows() {
            final MatrixBlock map = _map;
            return map.getNumRows();
        }

        /**
         * @return number of columns of the map (not of the original matrix)
         */
        public int getNumColumns() {
            final MatrixBlock map = _map;
            return map.getNumColumns();
        }

        /**
         * @return number of rows of the matrix this map describes
         */
        public int getNumRowsOrig() {
            return _rlen;
        }

        /**
         * @return number of columns of the matrix this map describes
         */
        public int getNumColumnsOrig() {
            return _clen;
        }

        /**
         * Returns the estimated total number of non-zeros.
         * <p>
         * Side effect: a map in sparsity form is converted back to non-zero
         * counts in place before its cells are summed.
         *
         * @return the rounded sum of all map cells
         */
        public long getNonZeros() {
            if (_scaled) {
                toNnz();
            }
            final double total = _map.sum();
            return Math.round(total);
        }

        /**
         * Returns the row extent of map row {@code r}, as computed by
         * {@code UtilFunctions.computeBlockSize} from the original row count,
         * the index {@code r + 1} and the block size.
         *
         * @param r zero-based map row index
         * @return row extent of that block
         */
        public int getRowBlockize(int r) {
            final int blockIndex = r + 1;
            return UtilFunctions.computeBlockSize(_rlen, blockIndex, _b);
        }

        /**
         * Returns the column extent of map column {@code c}, as computed by
         * {@code UtilFunctions.computeBlockSize} from the original column
         * count, the index {@code c + 1} and the block size.
         *
         * @param c zero-based map column index
         * @return column extent of that block
         */
        public int getColBlockize(int c) {
            final int blockIndex = c + 1;
            return UtilFunctions.computeBlockSize(_clen, blockIndex, _b);
        }

        /**
         * Reads one map cell.
         *
         * @param r zero-based map row index
         * @param c zero-based map column index
         * @return the stored value: a sparsity or a non-zero count, depending
         *         on the current scaling state
         */
        public double get(int r, int c) {
            final double cell = _map.quickGetValue(r, c);
            return cell;
        }

        /**
         * Converts the map in place from non-zero counts to sparsities by
         * dividing every non-zero cell by the row and column extent of its
         * block. Does nothing if the map is already in sparsity form.
         * <p>
         * Cells are accessed through the {@code MatrixBlock} accessors rather
         * than the raw dense block, so a map without an allocated dense block
         * (for example the one {@code init} returns for an empty input) is
         * handled instead of failing on a {@code null} dense block.
         * NOTE(review): relies on {@code quickGetValue} returning 0 for an
         * unallocated block - confirm against {@code MatrixBlock}.
         */
        public void toSparsity() {
            if (_scaled) {
                return;
            }
            final int rlen = _map.getNumRows();
            final int clen = _map.getNumColumns();
            for (int i = 0; i < rlen; ++i) {
                final int lrlen = getRowBlockize(i);
                for (int j = 0; j < clen; ++j) {
                    final double cval = _map.quickGetValue(i, j);
                    if (cval == 0.0) {
                        continue;
                    }
                    _map.quickSetValue(i, j, cval / (double) lrlen / (double) getColBlockize(j));
                }
            }
            _scaled = true;
        }

        /**
         * Converts the map in place from sparsities to non-zero counts by
         * multiplying every non-zero cell by the row and column extent of its
         * block. Does nothing if the map already holds counts.
         * <p>
         * Cells are accessed through the {@code MatrixBlock} accessors rather
         * than the raw dense block, so a map without an allocated dense block
         * is handled instead of failing on a {@code null} dense block.
         * NOTE(review): relies on {@code quickGetValue} returning 0 for an
         * unallocated block - confirm against {@code MatrixBlock}.
         */
        public void toNnz() {
            if (!_scaled) {
                return;
            }
            final int rlen = _map.getNumRows();
            final int clen = _map.getNumColumns();
            for (int i = 0; i < rlen; ++i) {
                final int lrlen = getRowBlockize(i);
                for (int j = 0; j < clen; ++j) {
                    final double cval = _map.quickGetValue(i, j);
                    if (cval == 0.0) {
                        continue;
                    }
                    _map.quickSetValue(i, j, cval * (double) lrlen * (double) getColBlockize(j));
                }
            }
            _scaled = false;
        }

        /**
         * Computes the map of per-block non-zero counts for {@code in}.
         * <p>
         * The map has {@code ceil(rows / b)} by {@code ceil(cols / b)} cells.
         * An empty input yields an unallocated (all-zero) map. A fully dense
         * input is filled directly with the number of cells in each block. This
         * keeps the fast path in the same unit (counts) as the counting paths
         * below, which is what the constructor assumes when it marks the map as
         * unscaled.
         *
         * @param in matrix to summarise; its dimensions must match the
         *           original dimensions already stored in this map
         * @return map block holding non-zero counts
         */
        private MatrixBlock init(MatrixBlock in) {
            final int rlen = (int) Math.ceil((double) _rlen / (double) _b);
            final int clen = (int) Math.ceil((double) _clen / (double) _b);
            MatrixBlock out = new MatrixBlock(rlen, clen, false);

            // Fast path: empty input, keep the map unallocated.
            if (in.isEmptyBlock(false)) {
                return out;
            }

            DenseBlock c = out.allocateBlock().getDenseBlock();

            // Fast path: fully dense input, every block is completely filled.
            if (in.getLength() == in.getNonZeros()) {
                for (int i = 0; i < rlen; ++i) {
                    final double lrlen = getRowBlockize(i);
                    for (int j = 0; j < clen; ++j) {
                        // count = rows in block * columns in block
                        c.set(i, j, lrlen * getColBlockize(j));
                    }
                }
                // Every map cell is non-zero, so nnz is the map's cell count.
                out.setNonZeros((long) rlen * clen);
                return out;
            }

            if (in.isInSparseFormat()) {
                // Count stored entries per block, row by row.
                SparseBlock sblock = in.getSparseBlock();
                for (int i = 0; i < in.getNumRows(); ++i) {
                    if (sblock.isEmpty(i)) {
                        continue;
                    }
                    final int alen = sblock.size(i);
                    final int apos = sblock.pos(i);
                    final int[] aix = sblock.indexes(i);
                    for (int k = apos; k < apos + alen; ++k) {
                        c.incr(i / _b, aix[k] / _b);
                    }
                }
            } else {
                // Count non-zero cells per block by scanning the whole matrix.
                for (int i = 0; i < _rlen; ++i) {
                    for (int j = 0; j < _clen; ++j) {
                        final double aval = in.quickGetValue(i, j);
                        if (aval == 0.0) {
                            continue;
                        }
                        c.incr(i / _b, j / _b);
                    }
                }
            }
            out.recomputeNonZeros();
            return out;
        }

        /**
         * Tests whether {@code value} is a positive power of two.
         * <p>
         * Uses an exact bit test instead of floating-point logarithms, so zero
         * and negative values are rejected and large powers of two are not
         * affected by rounding. Package-private so that it can be unit-tested.
         *
         * @param value candidate block size
         * @return {@code true} if {@code value} is 1, 2, 4, 8, ...
         */
        static boolean isPow2(int value) {
            // A power of two has exactly one bit set; value & (value - 1) clears it.
            return value > 0 && (value & (value - 1)) == 0;
        }
    }
}

