import { RandomForestRegression as RF } from 'ml-random-forest';

/**
 * Random-forest regressor that remembers the data it was trained on, so the
 * model's mean squared error can be computed on that training set without
 * passing it again, or on any explicitly supplied evaluation set.
 */
export class FairRF extends RF {
  /** Feature rows passed to the most recent `train` call; `null` until trained. */
  trainX: null | number[][] = null;
  /** Targets passed to the most recent `train` call; `null` until trained. */
  trainY: null | number[] = null;

  /**
   * Stores the training set (by reference) for later use by `mse`, then trains
   * the underlying forest.
   *
   * @param trainX - Feature rows.
   * @param trainY - Target values, one per row of `trainX`.
   */
  train(trainX: number[][], trainY: number[]): void {
    this.trainX = trainX;
    this.trainY = trainY;
    super.train(trainX, trainY);
  }

  /**
   * Mean squared error of the model's predictions against observed targets.
   *
   * When both arguments are omitted (or `null`), the set stored by the last
   * `train` call is used. Otherwise both must be given: mixing caller-supplied
   * features with the stored training targets would compare unrelated rows.
   *
   * @param xSet - Feature rows to evaluate.
   * @param ySet - Observed targets, aligned index-by-index with `xSet`.
   * @returns The mean of `(observed - predicted) ** 2` over all rows.
   * @throws Error if only one of `xSet`/`ySet` is given, if no data is
   *   available (nothing passed and `train` never called), if the set is
   *   empty, or if the feature and target counts differ.
   */
  mse(xSet: null | number[][] = null, ySet: null | number[] = null): number {
    // Evaluating caller features against stored training targets (or vice
    // versa) would silently produce a meaningless number, so reject it.
    if ((xSet == null) !== (ySet == null)) {
      throw new Error('xSet and ySet must be provided together');
    }

    const observedX = xSet ?? this.trainX;
    const observedY = ySet ?? this.trainY;
    if (observedX == null || observedY == null) {
      throw new Error(
        'no evaluation data: pass xSet and ySet, or call train() first',
      );
    }
    if (observedX.length !== observedY.length) {
      throw new Error(
        `xSet has ${observedX.length} rows but ySet has ${observedY.length} values`,
      );
    }
    // Dividing by zero below would otherwise return NaN.
    if (observedY.length === 0) {
      throw new Error('cannot compute MSE of an empty set');
    }

    const predY = this.predict(observedX);
    let sumSq = 0;
    for (let i = 0; i < observedY.length; i++) {
      const residual = observedY[i] - predY[i];
      sumSq += residual ** 2;
    }
    return sumSq / observedY.length;
  }
}
