/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Future;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.IndexingCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.fed.UnaryFEDInstruction;
import org.apache.sysds.runtime.instructions.spark.IndexingSPInstruction;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.IndexRange;

/**
 * Federated right indexing ({@code rightIndex}) and left indexing ({@code leftIndex} /
 * {@code mapLeftIndex}) over federated matrices and frames.
 *
 * <p>The instruction string is rewritten per federated partition: the global index operands are
 * replaced by partition-local literal operands, and the rewritten instructions are sent to the
 * workers holding the partitions.
 */
public final class IndexingFEDInstruction
extends UnaryFEDInstruction {
    /** Separator between the parts of a serialized instruction string. */
    private static final String PART_SEPARATOR = "\u00b0";

    protected final CPOperand rowLower;
    protected final CPOperand rowUpper;
    protected final CPOperand colLower;
    protected final CPOperand colUpper;

    /** Creates a right-indexing instruction {@code out = in[rl:ru, cl:cu]}. */
    protected IndexingFEDInstruction(CPOperand in, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu, CPOperand out, String opcode, String istr) {
        super(FEDInstruction.FEDType.MatrixIndexing, null, in, out, opcode, istr);
        this.rowLower = rl;
        this.rowUpper = ru;
        this.colLower = cl;
        this.colUpper = cu;
    }

    /** Creates a left-indexing instruction {@code out = lhsInput; out[rl:ru, cl:cu] = rhsInput}. */
    protected IndexingFEDInstruction(CPOperand lhsInput, CPOperand rhsInput, CPOperand rl, CPOperand ru, CPOperand cl, CPOperand cu, CPOperand out, String opcode, String istr) {
        super(FEDInstruction.FEDType.MatrixIndexing, null, lhsInput, rhsInput, out, opcode, istr);
        this.rowLower = rl;
        this.rowUpper = ru;
        this.colLower = cl;
        this.colUpper = cu;
    }

    /**
     * Evaluates the four index operands and returns them as a zero-based index range.
     *
     * @param ec execution context used to resolve the (possibly variable) index operands
     * @return the range with every operand value decremented by one (the operands are one-based)
     */
    protected IndexRange getIndexRange(ExecutionContext ec) {
        // Values stay long: casting to int would silently truncate indices beyond Integer.MAX_VALUE.
        long rl = ec.getScalarInput(this.rowLower).getLongValue() - 1L;
        long ru = ec.getScalarInput(this.rowUpper).getLongValue() - 1L;
        long cl = ec.getScalarInput(this.colLower).getLongValue() - 1L;
        long cu = ec.getScalarInput(this.colUpper).getLongValue() - 1L;
        return new IndexRange(rl, ru, cl, cu);
    }

    /** Builds the federated counterpart of a CP indexing instruction. */
    public static IndexingFEDInstruction parseInstruction(IndexingCPInstruction instr) {
        return new IndexingFEDInstruction(instr.input1, instr.input2, instr.getRowLower(), instr.getRowUpper(), instr.getColLower(), instr.getColUpper(), instr.output, instr.getOpcode(), instr.getInstructionString());
    }

    /** Builds the federated counterpart of a Spark indexing instruction. */
    public static IndexingFEDInstruction parseInstruction(IndexingSPInstruction instr) {
        return new IndexingFEDInstruction(instr.input1, instr.input2, instr.getRowLower(), instr.getRowUpper(), instr.getColLower(), instr.getColUpper(), instr.output, instr.getOpcode(), instr.getInstructionString());
    }

    /**
     * Parses a serialized right- or left-indexing instruction.
     *
     * @param str serialized instruction
     * @return the parsed instruction
     * @throws DMLRuntimeException on an unknown opcode, a wrong operand count, or an indexed
     *     input that is neither a matrix nor a frame
     */
    public static IndexingFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (opcode.equalsIgnoreCase(Opcodes.RIGHT_INDEX.toString())) {
            if (parts.length != 7 && parts.length != 8) {
                throw new DMLRuntimeException("Invalid number of operands in instruction: " + str);
            }
            CPOperand in = new CPOperand(parts[1]);
            CPOperand rl = new CPOperand(parts[2]);
            CPOperand ru = new CPOperand(parts[3]);
            CPOperand cl = new CPOperand(parts[4]);
            CPOperand cu = new CPOperand(parts[5]);
            CPOperand out = new CPOperand(parts[6]);
            if (!isMatrixOrFrame(in)) {
                throw new DMLRuntimeException("Can index only on matrices, frames in federated.");
            }
            return new IndexingFEDInstruction(in, rl, ru, cl, cu, out, opcode, str);
        }
        if (opcode.equalsIgnoreCase(Opcodes.LEFT_INDEX.toString()) || opcode.equalsIgnoreCase("mapLeftIndex")) {
            if (parts.length != 8 && parts.length != 9) {
                throw new DMLRuntimeException("Invalid number of operands in instruction: " + str);
            }
            CPOperand lhsInput = new CPOperand(parts[1]);
            CPOperand rhsInput = new CPOperand(parts[2]);
            CPOperand rl = new CPOperand(parts[3]);
            CPOperand ru = new CPOperand(parts[4]);
            CPOperand cl = new CPOperand(parts[5]);
            CPOperand cu = new CPOperand(parts[6]);
            CPOperand out = new CPOperand(parts[7]);
            // The indexed (lhs) input must be federated matrix/frame data; the rhs may be
            // matrix/frame data or a scalar (see leftIndexing).
            if (!isMatrixOrFrame(lhsInput)) {
                throw new DMLRuntimeException("Can index only on matrices, frames, and lists.");
            }
            return new IndexingFEDInstruction(lhsInput, rhsInput, rl, ru, cl, cu, out, opcode, str);
        }
        throw new DMLRuntimeException("Unknown opcode while parsing a MatrixIndexingFEDInstruction: " + str);
    }

    /** Returns true if the operand's data type is MATRIX or FRAME. */
    private static boolean isMatrixOrFrame(CPOperand op) {
        Types.DataType dt = op.getDataType();
        return dt == Types.DataType.MATRIX || dt == Types.DataType.FRAME;
    }

    /** Dispatches to right indexing for the right-index opcode, otherwise to left indexing. */
    @Override
    public void processInstruction(ExecutionContext ec) {
        if (this.getOpcode().equalsIgnoreCase(Opcodes.RIGHT_INDEX.toString())) {
            this.rightIndexing(ec);
        } else {
            this.leftIndexing(ec);
        }
    }

    /**
     * Intersects the global zero-based index range with one partition and returns the overlap in
     * partition-local coordinates as {@code {rowStart, rowEnd, colStart, colEnd}}.
     * A start before the partition maps to 0; an end outside {@code [begin, end)} maps to the
     * partition's last index (end dims are treated as exclusive).
     */
    private static long[] localIndices(FederatedRange range, IndexRange ixrange) {
        long rs = range.getBeginDims()[0];
        long re = range.getEndDims()[0];
        long cs = range.getBeginDims()[1];
        long ce = range.getEndDims()[1];
        long rsn = ixrange.rowStart >= rs ? ixrange.rowStart - rs : 0L;
        long ren = ixrange.rowEnd >= rs && ixrange.rowEnd < re ? ixrange.rowEnd - rs : re - rs - 1L;
        long csn = ixrange.colStart >= cs ? ixrange.colStart - cs : 0L;
        long cen = ixrange.colEnd >= cs && ixrange.colEnd < ce ? ixrange.colEnd - cs : ce - cs - 1L;
        return new long[]{rsn, ren, csn, cen};
    }

    /**
     * Right indexing: restricts the federation map to the overlapping partitions, sends each of
     * them its index-rewritten instruction, and registers the federated output.
     */
    private void rightIndexing(ExecutionContext ec) {
        CacheableData<?> in = ec.getCacheableData(this.input1);
        IndexRange ixrange = this.getIndexRange(ec);
        FederationMap fedMap = in.getFedMapping().filter(ixrange);
        String[] instStrings = new String[fedMap.getSize()];
        List<Types.ValueType> schema = new ArrayList<>();
        int i = 0;
        for (Pair<FederatedRange, FederatedData> e : fedMap.getMap()) {
            FederatedRange range = e.getKey();
            long rs = range.getBeginDims()[0];
            long re = range.getEndDims()[0];
            long cs = range.getBeginDims()[1];
            long ce = range.getEndDims()[1];
            // Local indices must be derived before the range is rewritten below.
            long[] newIx = localIndices(range, ixrange);
            // Re-express the partition range in the coordinate space of the output.
            range.setBeginDim(0, Math.max(rs - ixrange.rowStart, 0L));
            range.setBeginDim(1, Math.max(cs - ixrange.colStart, 0L));
            range.setEndDim(0, ixrange.rowEnd >= re ? re - ixrange.rowStart : ixrange.rowEnd - ixrange.rowStart + 1L);
            range.setEndDim(1, ixrange.colEnd >= ce ? ce - ixrange.colStart : ixrange.colEnd - ixrange.colStart + 1L);
            // Instruction parts 3..6 hold the rl, ru, cl, cu operands.
            instStrings[i] = this.modifyIndices(newIx, 3, 7);
            if (this.input1.isFrame()) {
                if (in.isFederated(FTypes.FType.ROW)) {
                    // Row partitions: the slice replaces the previous one (presumably identical
                    // across row partitions).
                    schema = Arrays.asList(((FrameObject) in).getSchema((int) newIx[2], (int) newIx[3]));
                } else {
                    // Other partitionings: schema slices are concatenated in map order.
                    Collections.addAll(schema, ((FrameObject) in).getSchema((int) newIx[2], (int) newIx[3]));
                }
            }
            ++i;
        }
        long id = FederationUtils.getNextFedDataID();
        FederatedRequest tmp = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new Object[]{in.getMetaData().getDataCharacteristics(), in.getDataType()});
        Types.ExecType execType = InstructionUtils.getExecType(this.instString);
        if (execType == Types.ExecType.FED) {
            // Workers execute the rewritten instruction locally.
            execType = Types.ExecType.CP;
        }
        FederatedRequest[] fr1 = FederationUtils.callInstruction(instStrings, this.output, id, new CPOperand[]{this.input1}, new long[]{fedMap.getID()}, execType);
        fedMap.execute(this.getTID(), true, tmp);
        Future<FederatedResponse>[] ret = fedMap.execute(this.getTID(), true, fr1, new FederatedRequest[0]);
        CacheableData<?> out = ec.getCacheableData(this.output);
        if (this.input1.isFrame()) {
            ((FrameObject) out).setSchema(schema.toArray(new Types.ValueType[0]));
        }
        out.getDataCharacteristics().setDimension(fedMap.getMaxIndexInRange(0), fedMap.getMaxIndexInRange(1)).setBlocksize(in.getBlocksize()).setNonZeros(FederationUtils.sumNonZeros(ret));
        out.setFedMapping(fedMap.copyWithNewID(fr1[0].getID()));
    }

    /**
     * Left indexing: validates the target range against the lhs dimensions.
     *
     * <p>Each lhs partition then receives one of two instructions: a leftIndex instruction with
     * partition-local indices when it overlaps the target, or otherwise a copy of its data into
     * the output variable. The rhs is broadcast either as slices (matrix/frame) or as a scalar.
     */
    private void leftIndexing(ExecutionContext ec) {
        CacheableData<?> in1 = ec.getCacheableData(this.input1);
        CacheableData<?> in2 = null;
        ScalarObject scalar = null;
        IndexRange ixrange = this.getIndexRange(ec);
        if (ixrange.rowStart < 0L || ixrange.rowStart >= in1.getNumRows() || ixrange.rowEnd >= in1.getNumRows() || ixrange.colStart < 0L || ixrange.colStart >= in1.getNumColumns() || ixrange.colEnd >= in1.getNumColumns()) {
            throw new DMLRuntimeException("Invalid values for matrix indexing: [" + (ixrange.rowStart + 1L) + ":" + (ixrange.rowEnd + 1L) + "," + (ixrange.colStart + 1L) + ":" + (ixrange.colEnd + 1L) + "] must be within matrix dimensions [" + in1.getNumRows() + "," + in1.getNumColumns() + "].");
        }
        if (this.input2.getDataType() == Types.DataType.SCALAR) {
            if (!ixrange.isScalar()) {
                throw new DMLRuntimeException("Invalid index range for leftindexing with scalar: " + ixrange.toString() + ".");
            }
            scalar = ec.getScalarInput(this.input2);
        } else {
            in2 = ec.getCacheableData(this.input2);
            if (ixrange.rowEnd - ixrange.rowStart + 1L != in2.getNumRows() || ixrange.colEnd - ixrange.colStart + 1L != in2.getNumColumns()) {
                throw new DMLRuntimeException("Invalid values for matrix indexing: dimensions of the source matrix [" + in2.getNumRows() + "x" + in2.getNumColumns() + "] do not match the shape of the matrix specified by indices [" + (ixrange.rowStart + 1L) + ":" + (ixrange.rowEnd + 1L) + ", " + (ixrange.colStart + 1L) + ":" + (ixrange.colEnd + 1L) + "].");
            }
        }
        FederationMap fedMap = in1.getFedMapping();
        int numParts = fedMap.getSize();
        String[] instStrings = new String[numParts];
        // Per partition: the rhs slice it receives, or null if it only copies its data.
        int[][] sliceIxs = new int[numParts][];
        // Per partition: its range if it receives a leftIndex instruction, else null.
        FederatedRange[] ranges = new FederatedRange[numParts];
        // First partition that only copies (numParts if none).
        int cpVarInstIx = numParts;
        // First partition that receives a leftIndex instruction (numParts if none).
        int from = numParts;
        String cpVarInstString = this.createCopyInstString();
        // Next rhs row/column not yet assigned to a partition.
        int prev = 0;
        int i = 0;
        for (Pair<FederatedRange, FederatedData> e : fedMap.getMap()) {
            FederatedRange range = e.getKey();
            long rs = range.getBeginDims()[0];
            long re = range.getEndDims()[0];
            long cs = range.getBeginDims()[1];
            long ce = range.getEndDims()[1];
            long[] newIx = localIndices(range, ixrange);
            long rsn = newIx[0];
            long ren = newIx[1];
            long csn = newIx[2];
            long cen = newIx[3];
            boolean participates;
            if (in2 != null) {
                long toRow = (long) prev + ren - rsn;
                long toCol = (long) prev + cen - csn;
                if (in1.isFederated(FTypes.FType.ROW) && toRow >= 0L && toRow < in2.getNumRows() && ixrange.rowStart <= re) {
                    sliceIxs[i] = new int[]{prev, (int) toRow, 0, (int) in2.getNumColumns() - 1};
                    prev = (int) (toRow + 1L);
                    participates = true;
                } else if (in1.isFederated(FTypes.FType.COL) && toCol >= 0L && toCol < in2.getNumColumns() && ixrange.colStart <= ce) {
                    sliceIxs[i] = new int[]{0, (int) in2.getNumRows() - 1, prev, (int) toCol};
                    prev = (int) (toCol + 1L);
                    participates = true;
                } else {
                    participates = false;
                }
                if (participates) {
                    // Instruction parts 4..7 hold the rl, ru, cl, cu operands.
                    instStrings[i] = this.modifyIndices(newIx, 4, 8);
                }
            } else {
                // A scalar target lies entirely within exactly one partition.
                participates = ixrange.rowStart >= rs && ixrange.rowEnd < re && ixrange.colStart >= cs && ixrange.colEnd < ce;
                if (participates) {
                    // Part 3 is the scalar rhs operand; it is sent as a broadcast variable.
                    instStrings[i] = this.changeScalarLiteralFlag(this.modifyIndices(newIx, 4, 8), 3);
                }
            }
            if (participates) {
                ranges[i] = range;
                from = Math.min(i, from);
            } else {
                cpVarInstIx = Math.min(i, cpVarInstIx);
                instStrings[i] = cpVarInstString;
            }
            ++i;
        }
        int[][] slices = Arrays.stream(sliceIxs).filter(Objects::nonNull).toArray(int[][]::new);
        long id = FederationUtils.getNextFedDataID();
        FederatedRequest tmp = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new Object[]{new MatrixCharacteristics(-1L, -1L), in1.getDataType()});
        fedMap.execute(this.getTID(), true, tmp);
        if (in2 != null) {
            FederatedRequest[] fr1 = fedMap.broadcastSliced(in2, DMLScript.LINEAGE ? ec.getLineageItem(this.input2) : null, this.input2.isFrame(), slices);
            FederatedRequest[] fr2 = FederationUtils.callInstruction(instStrings, this.output, id, new CPOperand[]{this.input1, this.input2}, new long[]{fedMap.getID(), fr1[0].getID()}, null);
            FederatedRequest fr3 = fedMap.cleanup(this.getTID(), fr1[0].getID());
            if (slices.length == numParts) {
                fedMap.execute(this.getTID(), true, fr2, fr1, new FederatedRequest[]{fr3});
            } else {
                fedMap.execute(this.getTID(), true, ranges, fr2[cpVarInstIx], Arrays.copyOfRange(fr2, from, from + slices.length), fr1, fr3);
            }
        } else {
            FederatedRequest fr1 = fedMap.broadcast(scalar);
            FederatedRequest[] fr2 = FederationUtils.callInstruction(instStrings, this.output, id, new CPOperand[]{this.input1, this.input2}, new long[]{fedMap.getID(), fr1.getID()}, null);
            FederatedRequest fr3 = fedMap.cleanup(this.getTID(), fr1.getID());
            if (fr2.length == 1) {
                fedMap.execute(this.getTID(), true, new FederatedRequest[]{fr1, fr2[0], fr3});
            } else {
                fedMap.execute(this.getTID(), true, ranges, fr2[cpVarInstIx], fr2[from], fr1, fr3);
            }
        }
        CacheableData<?> out;
        if (this.input1.isFrame()) {
            FrameObject frameOut = ec.getFrameObject(this.output);
            frameOut.setSchema(((FrameObject) in1).getSchema());
            out = frameOut;
        } else {
            out = ec.getMatrixObject(this.output);
        }
        out.getDataCharacteristics().set(in1.getDataCharacteristics());
        out.setFedMapping(fedMap.copyWithNewID(id));
    }

    /**
     * Returns this instruction's string with parts {@code [from, to)} replaced by INT64 literal
     * operands holding {@code newIx[j - from] + 1} (zero-based to one-based).
     */
    private String modifyIndices(long[] newIx, int from, int to) {
        String[] instParts = this.instString.split(PART_SEPARATOR);
        for (int j = from; j < to; ++j) {
            instParts[j] = InstructionUtils.createLiteralOperand(String.valueOf(newIx[j - from] + 1L), Types.ValueType.INT64);
        }
        return String.join(PART_SEPARATOR, instParts);
    }

    /**
     * Turns the literal flag of operand {@code partIx} from {@code true} to {@code false}.
     * Only a trailing {@code "true"} is replaced, so operand values that contain "true" stay
     * intact. NOTE(review): assumes the literal flag is the last component of the operand
     * string, as in CPOperand's name/dataType/valueType/isLiteral layout.
     */
    private String changeScalarLiteralFlag(String inst, int partIx) {
        String[] instParts = inst.split(PART_SEPARATOR);
        String operand = instParts[partIx];
        if (operand.endsWith("true")) {
            instParts[partIx] = operand.substring(0, operand.length() - "true".length()) + "false";
        }
        return String.join(PART_SEPARATOR, instParts);
    }

    /** Builds a copy instruction from the lhs operand (part 2) to the output operand (part 8). */
    private String createCopyInstString() {
        String[] instParts = this.instString.split(PART_SEPARATOR);
        return VariableCPInstruction.prepareCopyInstruction(instParts[2], instParts[8]).toString();
    }
}

