/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.classification.knn;

import java.io.IOException;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.Model;
import org.apache.flink.ml.classification.knn.KnnModelData;
import org.apache.flink.ml.classification.knn.KnnModelParams;
import org.apache.flink.ml.common.broadcast.BroadcastUtils;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

/**
 * A {@link Model} that predicts a label for every input row by majority vote over the labels of
 * the {@code k} closest samples in the model data.
 *
 * <p>Closeness is measured as {@code sqrt(||x||^2 - 2 x.y + ||y||^2)} between the input features
 * {@code x} and each model sample {@code y}. This is the Euclidean distance, assuming
 * {@code KnnModelData.featureNormSquares} holds {@code ||y||^2} for each sample (TODO confirm).
 *
 * <p>The model data is supplied via {@link #setModelData(Table...)} or {@link #load}. It is
 * broadcast to every parallel instance of the prediction operator. The output table contains all
 * input columns followed by a DOUBLE column named by {@code getPredictionCol()}.
 */
public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
    /** Name under which the model data stream is broadcast to {@link PredictLabelFunction}. */
    private static final String BROADCAST_MODEL_KEY = "broadcastModelKey";

    private final Map<Param<?>, Object> paramMap = new HashMap<>();
    private Table modelDataTable;

    /** Creates a model whose parameters are all initialized to their default values. */
    public KnnModel() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    /**
     * Sets the model data table. Only the first table is used.
     *
     * @param inputs tables whose first element holds the model data
     * @return this model, for chaining
     */
    @Override
    public KnnModel setModelData(Table ... inputs) {
        this.modelDataTable = inputs[0];
        return this;
    }

    /** Returns a one-element array holding the current model data table (may contain null). */
    @Override
    public Table[] getModelData() {
        return new Table[]{this.modelDataTable};
    }

    /**
     * Appends a prediction column to the single input table.
     *
     * @param inputs exactly one table containing the features column
     * @return a single table with all input columns followed by the DOUBLE prediction column
     * @throws IllegalArgumentException if {@code inputs} does not contain exactly one table
     * @throws NullPointerException if no model data has been set
     */
    @Override
    @SuppressWarnings("unchecked")
    public Table[] transform(Table ... inputs) {
        Preconditions.checkArgument(
                inputs.length == 1, "KnnModel expects exactly one input table, but got %s.", inputs.length);
        Preconditions.checkNotNull(this.modelDataTable, "Model data must be set before calling transform().");

        StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
        DataStream<KnnModelData> knnModel = KnnModelData.getModelDataStream(this.modelDataTable);

        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
        RowTypeInfo outputTypeInfo = new RowTypeInfo(
                ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO),
                ArrayUtils.addAll(inputTypeInfo.getFieldNames(), this.getPredictionCol()));

        // Read parameters eagerly so the graph-building lambda below does not capture `this`.
        int k = this.getK();
        String featuresCol = this.getFeaturesCol();

        DataStream<Row> output = BroadcastUtils.withBroadcastStream(
                Collections.singletonList(data),
                Collections.singletonMap(BROADCAST_MODEL_KEY, knnModel),
                inputList -> {
                    // The only non-broadcast input is `data`, which is a DataStream<Row>.
                    DataStream<Row> input = (DataStream<Row>) inputList.get(0);
                    return input.map(new PredictLabelFunction(BROADCAST_MODEL_KEY, k, featuresCol), outputTypeInfo);
                });
        return new Table[]{tEnv.fromDataStream(output)};
    }

    /** Returns the mutable map backing this model's parameter values. */
    @Override
    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    /**
     * Saves this model's parameters and model data under {@code path}.
     *
     * @throws NullPointerException if no model data has been set (checked before anything is written)
     * @throws IOException if writing the metadata fails
     */
    @Override
    public void save(String path) throws IOException {
        Preconditions.checkNotNull(this.modelDataTable, "Model data must be set before calling save().");
        ReadWriteUtils.saveMetadata(this, path);
        ReadWriteUtils.saveModelData(
                KnnModelData.getModelDataStream(this.modelDataTable), path, new KnnModelData.ModelDataEncoder());
    }

    /**
     * Loads a model previously written by {@link #save(String)}.
     *
     * @param tEnv table environment in which the model data table is created
     * @param path directory the model was saved to
     * @return the loaded model with its parameters and model data set
     * @throws IOException if reading the parameters or model data fails
     */
    public static KnnModel load(StreamTableEnvironment tEnv, String path) throws IOException {
        KnnModel model = (KnnModel) ReadWriteUtils.loadStageParam(path);
        Table modelDataTable = ReadWriteUtils.loadModelData(tEnv, path, new KnnModelData.ModelDataDecoder());
        return model.setModelData(modelDataTable);
    }

    /**
     * Appends a predicted label to each input row, using the broadcast {@link KnnModelData}.
     *
     * <p>Each instance reuses a scratch distance buffer across calls, so a single instance must not
     * be shared between threads.
     */
    private static class PredictLabelFunction extends RichMapFunction<Row, Row> {
        private final String broadcastKey;
        private final int k;
        private final String featureCol;
        /** Fetched from the broadcast variable on the first call to {@link #map}. */
        private transient KnnModelData knnModelData;
        /** Scratch buffer holding one distance per model sample; overwritten for every row. */
        private transient DenseVector distanceVector;

        /**
         * @param broadcastKey name of the broadcast variable holding the model data
         * @param k number of nearest neighbors that vote; must be positive
         * @param featureCol name of the input column holding the feature {@link Vector}
         * @throws IllegalArgumentException if {@code k <= 0}
         */
        public PredictLabelFunction(String broadcastKey, int k, String featureCol) {
            // With k <= 0 the neighbor heap stays empty and peek() would return null in the selection loop.
            Preconditions.checkArgument(k > 0, "k must be positive, but was %s.", k);
            this.broadcastKey = broadcastKey;
            this.k = k;
            this.featureCol = featureCol;
        }

        @Override
        public Row map(Row row) {
            if (this.knnModelData == null) {
                this.knnModelData =
                        (KnnModelData) this.getRuntimeContext().getBroadcastVariable(this.broadcastKey).get(0);
                this.distanceVector = new DenseVector(this.knnModelData.labels.size());
            }
            DenseVector feature = ((Vector) row.getField(this.featureCol)).toDense();
            return Row.join(row, Row.of(this.predictLabel(feature)));
        }

        /** Returns the most frequent label among the k model samples closest to {@code feature}. */
        private double predictLabel(DenseVector feature) {
            this.computeDistances(feature);
            return majorityLabel(this.findNearestNeighbors());
        }

        /**
         * Overwrites {@link #distanceVector} with the distance between {@code feature} and each model
         * sample, using {@code ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2}.
         */
        private void computeDistances(DenseVector feature) {
            double norm = BLAS.norm2(feature);
            double normSquare = norm * norm;
            // distanceVector := -2 * packedFeatures^T * feature, i.e. -2 x.y for every sample y.
            BLAS.gemv(-2.0, this.knnModelData.packedFeatures, true, feature, 0.0, this.distanceVector);
            double[] distances = this.distanceVector.values;
            double[] sampleNormSquares = this.knnModelData.featureNormSquares.values;
            int n = this.distanceVector.size();
            for (int i = 0; i < n; ++i) {
                // Cancellation can leave a true zero slightly negative. Clamp it to zero rather than
                // reflecting it with Math.abs, so exact duplicates still rank as the closest samples.
                distances[i] = Math.sqrt(Math.max(0.0, distances[i] + normSquare + sampleNormSquares[i]));
            }
        }

        /**
         * Returns (distance, label) pairs for the at most k model samples closest to the distances
         * currently in {@link #distanceVector}.
         */
        private PriorityQueue<Tuple2<Double, Double>> findNearestNeighbors() {
            // Max-heap on distance: the head is always the farthest current candidate, so every new
            // sample only needs to be compared against it.
            PriorityQueue<Tuple2<Double, Double>> heap = new PriorityQueue<>(
                    Comparator.comparingDouble((Tuple2<Double, Double> neighbor) -> neighbor.f0).reversed());
            double[] labels = this.knnModelData.labels.values;
            for (int i = 0; i < labels.length; ++i) {
                double distance = this.distanceVector.get(i);
                if (heap.size() < this.k) {
                    heap.add(Tuple2.of(distance, labels[i]));
                } else if (heap.peek().f0 > distance) {
                    // Strictly closer than the farthest candidate: replace it.
                    heap.poll();
                    heap.add(Tuple2.of(distance, labels[i]));
                }
            }
            return heap;
        }

        /**
         * Drains {@code neighbors} and returns the label with the most votes, or -1.0 if it is empty.
         * Ties go to whichever tied label a {@link HashMap} keyed on label iterates first.
         */
        private static double majorityLabel(PriorityQueue<Tuple2<Double, Double>> neighbors) {
            Map<Double, Double> labelWeights = new HashMap<>(neighbors.size());
            while (!neighbors.isEmpty()) {
                labelWeights.merge(neighbors.poll().f1, 1.0, Double::sum);
            }
            double maxWeight = 0.0;
            double predictedLabel = -1.0;
            for (Map.Entry<Double, Double> entry : labelWeights.entrySet()) {
                if (entry.getValue() > maxWeight) {
                    maxWeight = entry.getValue();
                    predictedLabel = entry.getKey();
                }
            }
            return predictedLabel;
        }
    }
}

