/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import java.util.concurrent.Future;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;

/**
 * Federated variant of the aggregate ternary instructions {@code tak+*} and {@code tack+*}.
 *
 * <p>Depending on how the three operands are federated, the instruction is executed directly on
 * the workers' aligned partitions, or the non-federated operands are first broadcast to the
 * workers. The partial results returned by the workers are aggregated at the coordinator.
 */
public class AggregateTernaryFEDInstruction extends ComputationFEDInstruction {
    private AggregateTernaryFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr) {
        super(FEDInstruction.FEDType.AggregateTernary, op, in1, in2, in3, out, opcode, istr);
    }

    /**
     * Parses an aggregate ternary instruction string into a federated instruction.
     *
     * @param str instruction string whose opcode is {@code tak+*} or {@code tack+*}
     *            (case-insensitive), followed by three inputs, the output and the thread count
     * @return the parsed federated instruction
     * @throws DMLRuntimeException if the opcode is not supported
     */
    public static AggregateTernaryFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("tak+*") && !opcode.equalsIgnoreCase("tack+*")) {
            throw new DMLRuntimeException("AggregateTernaryInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        InstructionUtils.checkNumFields(parts, 5);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        CPOperand out = new CPOperand(parts[4]);
        int numThreads = Integer.parseInt(parts[5]);
        AggregateTernaryOperator op = InstructionUtils.parseAggregateTernaryOperator(opcode, numThreads);
        return new AggregateTernaryFEDInstruction(op, in1, in2, in3, out, opcode, str);
    }

    /**
     * Executes the instruction on the federated workers and stores the aggregated result.
     *
     * <p>Supported operand layouts:
     * <ol>
     *   <li>all three inputs federated with mutually aligned mappings (scalar or matrix output);</li>
     *   <li>first two inputs federated and aligned, third input a scalar literal (scalar output only);</li>
     *   <li>first input federated (not as broadcast), third input a matrix; the second and third
     *       inputs are sliced along the first input's mapping and shipped to the workers
     *       (scalar output only).</li>
     * </ol>
     *
     * @param ec execution context holding the inputs and receiving the output
     * @throws DMLRuntimeException if the operand layout or output type is not supported,
     *                             or a worker request fails
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        MatrixObject mo1 = ec.getMatrixObject(input1);
        MatrixObject mo2 = ec.getMatrixObject(input2);
        // A literal third operand has no matrix object; it is later read via getScalarInput.
        MatrixObject mo3 = input3.isLiteral() ? null : ec.getMatrixObject(input3);

        if (mo3 != null && mo1.isFederated() && mo2.isFederated() && mo3.isFederated()
            && mo1.getFedMapping().isAligned(mo2.getFedMapping(), new FederationMap.AlignType[] {alignType(mo1)})
            && mo2.getFedMapping().isAligned(mo3.getFedMapping(), new FederationMap.AlignType[] {alignType(mo1)})) {
            processAllFederated(ec, mo1, mo2, mo3);
        }
        else if (mo1.isFederated() && mo2.isFederated()
            && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false) && mo3 == null) {
            processScalarThirdInput(ec, mo1, mo2);
        }
        else if (mo1.isFederatedExcept(FederationMap.FType.BROADCAST) && input3.isMatrix() && mo3 != null) {
            processBroadcastSliced(ec, mo1, mo2, mo3);
        }
        else {
            throw unsupportedLayout(mo1, mo2, mo3);
        }
    }

    /** Layout 1: every input is federated and aligned, so workers compute on local partitions. */
    private void processAllFederated(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
        FederationMap fedMap = mo1.getFedMapping();
        FederatedRequest fr1 = FederationUtils.callInstruction(getInstructionString(), output,
            new CPOperand[] {input1, input2, input3},
            new long[] {fedMap.getID(), mo2.getFedMapping().getID(), mo3.getFedMapping().getID()});
        FederatedRequest fr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
        FederatedRequest fr3 = fedMap.cleanup(getTID(), fr1.getID());
        Future<FederatedResponse>[] response = fedMap.execute(getTID(), fr1, fr2, fr3);

        if (output.getDataType().isScalar()) {
            AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator("uak+");
            ec.setScalarOutput(output.getName(), FederationUtils.aggScalar(aop, response, fedMap));
        }
        else {
            // tak+* partials are combined with uak+, tack+* partials with uack+.
            AggregateUnaryOperator aop = InstructionUtils.parseBasicAggregateUnaryOperator(isFullAggregate() ? "uak+" : "uack+");
            ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, response, fedMap));
        }
    }

    /** Layout 2: the scalar third operand is broadcast; the per-worker scalar results are summed. */
    private void processScalarThirdInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2) {
        // Validate before contacting the workers so unsupported variants do no remote work.
        if (!output.getDataType().isScalar())
            throw new DMLRuntimeException("Not Implemented Federated Ternary Variation");

        FederationMap fedMap = mo1.getFedMapping();
        FederatedRequest fr1 = fedMap.broadcast(ec.getScalarInput(input3));
        FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
            new CPOperand[] {input1, input2, input3},
            new long[] {fedMap.getID(), mo2.getFedMapping().getID(), fr1.getID()});
        FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
        FederatedRequest fr4 = fedMap.cleanup(getTID(), fr1.getID(), fr2.getID());
        Future<FederatedResponse>[] tmp = fedMap.execute(getTID(), fr1, fr2, fr3, fr4);
        ec.setScalarOutput(output.getName(), new DoubleObject(sumScalarResponses(tmp)));
    }

    /**
     * Layout 3: the second and third inputs are sliced along the first input's mapping, and each
     * worker receives its own slices; the per-worker scalar results are summed.
     */
    private void processBroadcastSliced(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
        // Validate before contacting the workers so unsupported variants do no remote work.
        if (!output.getDataType().isScalar())
            throw new DMLRuntimeException("Not Implemented Federated Ternary Variation");

        FederationMap fedMap = mo1.getFedMapping();
        // One request per worker; all slices share one variable ID, hence the use of [0].getID().
        FederatedRequest[] fr1 = fedMap.broadcastSliced(mo3, false);
        FederatedRequest[] fr2 = fedMap.broadcastSliced(mo2, false);

        // Ship each worker its own slice of mo2 first (previously every worker got fr2[0]).
        // The explicit empty array selects the per-worker-slices overload of execute.
        // NOTE(review): if FederationMap offers an overload taking two slice arrays, both
        // broadcasts could travel in a single batch instead of this extra round trip.
        awaitAll(fedMap.execute(getTID(), fr2, new FederatedRequest[0]));

        FederatedRequest fr3 = FederationUtils.callInstruction(getInstructionString(), output,
            new CPOperand[] {input1, input2, input3},
            new long[] {fedMap.getID(), fr2[0].getID(), fr1[0].getID()});
        FederatedRequest fr4 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr3.getID());
        // Remove both broadcast slices and the result variable from the workers.
        FederatedRequest fr5 = fedMap.cleanup(getTID(), fr1[0].getID(), fr2[0].getID(), fr3.getID());
        Future<FederatedResponse>[] tmp = fedMap.execute(getTID(), fr1, fr3, fr4, fr5);
        ec.setScalarOutput(output.getName(), new DoubleObject(sumScalarResponses(tmp)));
    }

    /** Returns true for the {@code tak+*} opcode, false for {@code tack+*}. */
    private boolean isFullAggregate() {
        String opcode = getOpcode();
        return opcode.equalsIgnoreCase("tak+*") || opcode.equalsIgnoreCase("fed_tak+*");
    }

    /** Alignment to require: ROW when {@code mo} is row-federated, COL otherwise. */
    private static FederationMap.AlignType alignType(MatrixObject mo) {
        return mo.isFederated(FederationMap.FType.ROW) ? FederationMap.AlignType.ROW : FederationMap.AlignType.COL;
    }

    /** Blocks until every future has completed; failures are wrapped in a DMLRuntimeException. */
    private static void awaitAll(Future<FederatedResponse>[] responses) {
        for (Future<FederatedResponse> response : responses) {
            try {
                response.get();
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new DMLRuntimeException("Federated broadcast failed with exception on TernaryFedInstruction", e);
            }
            catch (Exception e) {
                throw new DMLRuntimeException("Federated broadcast failed with exception on TernaryFedInstruction", e);
            }
        }
    }

    /** Waits for all responses and returns the sum of the first (scalar) data element of each. */
    private static double sumScalarResponses(Future<FederatedResponse>[] responses) {
        double sum = 0.0;
        for (Future<FederatedResponse> response : responses) {
            try {
                sum += ((ScalarObject) response.get().getData()[0]).getDoubleValue();
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new DMLRuntimeException("Federated Get data failed with exception on TernaryFedInstruction", e);
            }
            catch (Exception e) {
                throw new DMLRuntimeException("Federated Get data failed with exception on TernaryFedInstruction", e);
            }
        }
        return sum;
    }

    /** Builds the error for operand layouts none of the execution paths supports. */
    private static DMLRuntimeException unsupportedLayout(MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
        String msg = "Federated AggregateTernary not supported with the following federated objects: "
            + mo1.isFederated() + ":" + mo1.getFedMapping() + " " + mo2.isFederated() + ":" + mo2.getFedMapping();
        if (mo3 != null)
            msg += " " + mo3.isFederated() + ":" + mo3.getFedMapping();
        return new DMLRuntimeException(msg);
    }
}

