import moment from "moment";
import { dateFormats } from "utility/constants/constants";

/**
 * Mini-batch k-means clustering over rows (arrays) of mixed values.
 *
 * Numeric columns are mean-normalized before clustering; values that are not
 * numeric (or that parse as dates according to `dateFormats`) are passed
 * through untouched and are ignored by the Euclidean distance.
 *
 * After `fit(data)`:
 *  - `this.centroids`    holds the centroids (in normalized space),
 *  - `this.assignments`  holds one centroid index per input row,
 *  - `this.originalMean` holds the per-column means used for normalization.
 */
class MiniBatchKMeans {
  /**
   * @param {number} k - Requested number of clusters.
   * @param {number} batchSize - Number of rows sampled per iteration.
   * @param {number} maxIterations - Number of mini-batch iterations to run.
   */
  constructor(k, batchSize, maxIterations) {
    this.k = k;
    this.batchSize = batchSize;
    this.maxIterations = maxIterations;
    this.centroids = [];
    this.assignments = [];
  }

  /**
   * Clusters `data` and stores the result on the instance.
   * The input rows are never modified.
   *
   * @param {Array<Array<*>>} data - Rows to cluster.
   * @throws {Error} If `data` is not an array (raised by `mean`).
   */
  fit(data) {
    // Normalize first so that centroids are seeded in the same space the
    // points are clustered in.
    const normalData = this._normalize(data);

    if (normalData.length === 0) {
      this.centroids = [];
      this.assignments = [];
      return;
    }

    this.centroids = this._initializeCentroids(normalData);

    for (let iteration = 0; iteration < this.maxIterations; iteration++) {
      // Select a random subset (batch) of data points
      const batch = this._selectRandomBatch(normalData);

      // Assign batch points to the nearest centroid
      const assignments = this._assignPoints(batch);

      // Update centroids based on the assigned points
      this._updateCentroids(batch, assignments);
    }

    this.assignments = this._assignPoints(normalData);
  }

  /**
   * Draws up to `size` elements from `data` uniformly at random, without
   * replacement, using a partial Fisher-Yates shuffle on a copy.
   *
   * @param {Array<*>} data - Source elements (not modified).
   * @param {number} size - Maximum number of elements to draw.
   * @returns {Array<*>} The drawn elements (same references as in `data`).
   */
  _sampleWithoutReplacement(data, size) {
    const pool = data.slice();
    const picked = [];
    for (let i = 0; i < size && i < pool.length; i++) {
      const j = i + Math.floor(Math.random() * (pool.length - i));
      const chosen = pool[j];
      pool[j] = pool[i];
      pool[i] = chosen;
      picked.push(chosen);
    }
    return picked;
  }

  /**
   * Picks the initial centroids as copies of randomly chosen rows.
   * When there are fewer rows than `k`, only as many centroids as rows are
   * returned.
   *
   * @param {Array<Array<*>>} data - Rows to seed from.
   * @returns {Array<Array<*>>} Freshly copied centroid rows.
   */
  _initializeCentroids(data) {
    // Copy each row: centroids are updated in place later and must not
    // alias the rows they were seeded from.
    return this._sampleWithoutReplacement(data, this.k).map((point) =>
      Array.isArray(point) ? point.slice() : point
    );
  }

  /**
   * Draws a random batch of at most `batchSize` rows, skipping `undefined`
   * entries.
   *
   * @param {Array<Array<*>>} data - Rows to sample from.
   * @returns {Array<Array<*>>} The sampled rows.
   */
  _selectRandomBatch(data) {
    return this._sampleWithoutReplacement(data, this.batchSize).filter(
      (point) => point !== undefined
    );
  }

  /**
   * Finds, for every row, the index of the nearest centroid by Euclidean
   * distance. Index 0 is used when no centroid is strictly closer than
   * Infinity (for example when every distance is NaN).
   *
   * @param {Array<Array<*>>} data - Rows to assign.
   * @returns {number[]} One centroid index per row.
   */
  _assignPoints(data) {
    const assignments = [];
    for (let i = 0; i < data.length; i++) {
      const point = data[i];
      let minDistance = Infinity;
      let minCentroidIndex = 0;
      for (let j = 0; j < this.centroids.length; j++) {
        const distance = this._euclideanDistance(point, this.centroids[j]);
        if (distance < minDistance) {
          minDistance = distance;
          minCentroidIndex = j;
        }
      }
      assignments[i] = minCentroidIndex;
    }
    return assignments;
  }

  /**
   * Moves every centroid that received at least one row to the mean of its
   * rows. Only components of type `number` (and not NaN) are summed, which
   * matches what `_euclideanDistance` takes into account. A centroid
   * component for which no assigned row had a numeric value is left as is.
   *
   * @param {Array<Array<*>>} data - Batch rows.
   * @param {number[]} assignments - Centroid index for each batch row.
   */
  _updateCentroids(data, assignments) {
    if (data.length === 0) {
      return;
    }

    const clusterCount = this.centroids.length;
    // sums[c][j] stays undefined until a numeric value is seen for it.
    const sums = Array.from({ length: clusterCount }, () => []);
    const counts = new Array(clusterCount).fill(0);

    for (let i = 0; i < data.length; i++) {
      const point = data[i];
      const centroidIndex = assignments[i];
      for (let j = 0; j < point.length; j++) {
        const value = point[j];
        if (typeof value === 'number' && !Number.isNaN(value)) {
          sums[centroidIndex][j] = (sums[centroidIndex][j] || 0) + value;
        }
      }
      counts[centroidIndex]++;
    }

    for (let i = 0; i < clusterCount; i++) {
      const centroid = this.centroids[i];
      const count = counts[i];
      if (count === 0) {
        continue;
      }
      for (let j = 0; j < centroid.length; j++) {
        if (sums[i][j] !== undefined) {
          centroid[j] = sums[i][j] / count;
        }
      }
    }
  }

  /**
   * Euclidean distance over the keys of `point1` whose values are numbers in
   * both points; all other keys are ignored.
   *
   * @param {Object|Array} point1
   * @param {Object|Array} point2
   * @returns {number}
   */
  _euclideanDistance(point1, point2) {
    const squaredSum = Object.keys(point1).reduce((acc, key) => {
      if (typeof point1[key] === 'number' && typeof point2[key] === 'number') {
        const diff = point1[key] - point2[key];
        return acc + diff * diff;
      }
      return acc;
    }, 0);

    return Math.sqrt(squaredSum);
  }

  /**
   * Hamming distance over the keys of `point1` whose values are strings in
   * both points: the number of such keys whose values differ.
   *
   * @param {Object|Array} point1
   * @param {Object|Array} point2
   * @returns {number}
   */
  _hammingDistance(point1, point2) {
    return Object.keys(point1).reduce((acc, key) => {
      if (typeof point1[key] === 'string' && typeof point2[key] === 'string') {
        return acc + (point1[key] === point2[key] ? 0 : 1);
      }
      return acc;
    }, 0);
  }

  /**
   * Weighted average of the Euclidean (numeric) and Hamming (string)
   * distances between two points.
   *
   * @param {Object|Array} point1
   * @param {Object|Array} point2
   * @param {number} [numericWeight=1]
   * @param {number} [categoricalWeight=1]
   * @returns {number}
   */
  _overallDistance(point1, point2, numericWeight = 1, categoricalWeight = 1) {
    const numericDistance = this._euclideanDistance(point1, point2);
    const categoricalDistance = this._hammingDistance(point1, point2);

    return (
      (numericWeight * numericDistance + categoricalWeight * categoricalDistance) /
      (numericWeight + categoricalWeight)
    );
  }

  /**
   * Returns a mean-normalized copy of `points` and records the column means
   * in `this.originalMean`.
   *
   * A value is rescaled to `(value - mean) / mean` when `parseFloat` yields a
   * number for it and it is not a valid date under `dateFormats` (strict
   * parsing); every other value is copied unchanged. When a column mean is 0
   * the value is only centered (divisor 1) to avoid NaN / Infinity.
   *
   * @param {Array<Array<*>>} points - Rows to normalize (not modified).
   * @returns {Array<Array<*>>} New rows.
   * @throws {Error} If `points` is not an array (raised by `mean`).
   */
  _normalize(points) {
    const mean = this.mean(points);
    this.originalMean = mean;

    return points.map((point) => {
      const newPoint = new Array(point.length);
      for (let i = 0; i < point.length; i++) {
        // Use the same parse as mean() so both agree on what is numeric.
        const numeric = parseFloat(point[i]);
        if (!Number.isNaN(numeric) && !moment(point[i], dateFormats, true).isValid()) {
          const divisor = mean[i] === 0 ? 1 : mean[i];
          newPoint[i] = (numeric - mean[i]) / divisor;
        } else {
          newPoint[i] = point[i];
        }
      }
      return newPoint;
    });
  }

  /**
   * Column-wise mean of the values that `parseFloat` can read. The number of
   * columns is taken from the first row. A column with no readable value
   * gets a mean of 1.
   *
   * @param {Array<Array<*>>} points - Rows to average.
   * @returns {number[]} One mean per column, or [] for empty input.
   * @throws {Error} If `points` is not an array.
   */
  mean(points) {
    if (!Array.isArray(points)) {
      throw new Error('mean requires an array of data points as an argument.');
    }
    if (points.length === 0) {
      return [];
    }

    const width = points[0].length;
    const sums = new Array(width).fill(0);
    const counts = new Array(width).fill(0);

    for (let i = 0; i < points.length; i++) {
      for (let k = 0; k < width; k++) {
        const value = parseFloat(points[i][k]);
        if (!Number.isNaN(value)) {
          sums[k] += value;
          counts[k]++;
        }
      }
    }

    // Divide by the number of values actually summed, independent of which
    // row happens to come last.
    return sums.map((sum, k) => (counts[k] > 0 ? sum / counts[k] : 1));
  }
}

export default MiniBatchKMeans;