import { DefaultProps, Layer, UpdateParameters } from '@deck.gl/core/typed';
// @ts-ignore
import { Buffer, FEATURES, Framebuffer, hasFeatures, isWebGL2, Model, Texture2D, withParameters } from '@luma.gl/core';
import { flatten, forEach, fromPairs, isEmpty, map, size, some, times } from 'lodash';

import { defaultLayerColors } from 'components/theme/theme';
import { hexToRgb } from 'utils/helpers';
import { createFullScreenMesh } from '../mesh';
import {
  additiveBlendingParametersFromGLContext,
  BLEND_WITHOUT_DEPTH_TEST,
  createSquareFromTwoTriangles,
} from './helpers';
import AGG_FS from './shaders/agg-maplibre-heatmap-fragment.glsl';
import AGG_VS from './shaders/agg-maplibre-heatmap-vertex.glsl';
import COLOR_FS from './shaders/color-pass-fragment.glsl';
import COLOR_VS from './shaders/color-pass-vertex.glsl';
import { ColorStrategy, MultiClassGaussianHeatmapLayerProps, MultiClassGaussianHeatmapLayerState } from './types';

/**
 * deck.gl layer that renders a multi-class Gaussian heatmap in two passes:
 *
 * 1. Aggregator pass: for every class, the instanced `createSquareFromTwoTriangles` geometry is drawn with
 *    additive blending into a dedicated offscreen framebuffer (one FBO per class).
 * 2. Color pass: the color model samples every class texture (`u_textures[i]`) and combines them using the
 *    per-class colors/opacities and the configured {@link ColorStrategy} (baked into the fragment shader).
 *
 * All GPU resources (models, framebuffers, their textures, attribute buffers) are created from layer state
 * and released in `finalizeState`.
 */
export default class MultiClassGaussianHeatmapLayer<DataT, ExtraProps extends {} = {}> extends Layer<
  MultiClassGaussianHeatmapLayerProps<DataT> & ExtraProps
> {
  static layerName = 'MultiClassGaussianHeatmapLayer';

  static defaultProps: DefaultProps<MultiClassGaussianHeatmapLayerProps<any>> = {
    data: [],
    getPosition: (d: any) => d.position,
    getWeight: (d: any) => d.weight ?? 1,
    getClassIndex: (d: any) => d.classId ?? 0,
    radius: 10,
    falloffDivisor: 3,
    colorStrategy: ColorStrategy.MaxDensity,
    intensity: 1,
    opacity: 1,
    parameters: BLEND_WITHOUT_DEPTH_TEST,
  };

  /**
   * Attribute buffers currently attached to each aggregator model.
   *
   * deck.gl creates a new layer instance on every render and transfers `state` to it, so a per-instance
   * field would be lost between updates. Keying by the state-owned model lets each data update free the
   * buffers uploaded by the previous one.
   */
  private static readonly attributeBuffers = new WeakMap<Model, Buffer[]>();

  state: MultiClassGaussianHeatmapLayerState = {};

  /** Creates the models, per-class framebuffers and uniform arrays for the initial props. */
  initializeState(): void {
    const { gl } = this.context;
    const { numClasses, colorStrategy } = this.props;

    // Create aggregator model + color model
    const aggregatorModel = this._createAggregatorModel(gl);
    const colorModel = this._createColorModel(gl, numClasses, colorStrategy);

    // Create one FBO per class to use them for aggregation in the aggregator pass
    const classAggregators = this._createClassAggregatorFrameBuffers(numClasses);

    const classColorsUniform: number[] = this._createColorsUniform();
    const classOpacitiesUniform = this._createOpacitiesUniform();

    this.setState({
      aggregatorModel,
      colorModel,
      classAggregators,
      classColorsUniform,
      classOpacitiesUniform,
    });
  }

  /**
   * Releases every GPU resource owned by this layer before deck.gl discards its state.
   *
   * The base implementation only deletes `state.model` / `state.models`, which this layer does not use,
   * so the models, framebuffers (and their textures) and attribute buffers are deleted explicitly here.
   */
  finalizeState(context: this['context']): void {
    const { aggregatorModel, colorModel } = this.state;
    this._releaseAttributeBuffers(aggregatorModel);
    aggregatorModel?.delete();
    colorModel?.delete();
    this._deleteClassAggregators();
    super.finalizeState(context);
  }

  /**
   * Creates one aggregation framebuffer per class, sized to the current viewport.
   *
   * Any framebuffers already in state (and their color textures) are deleted first. Float textures are used
   * when float color attachments are supported, otherwise 8-bit RGBA.
   *
   * @param numClasses - The number of classes for which to create framebuffer objects.
   * @returns An array of framebuffer objects, one per class index.
   */
  private _createClassAggregatorFrameBuffers(numClasses: number): Framebuffer[] {
    const { gl } = this.context;
    const { width, height } = this.context.viewport;

    // Float render targets keep accumulated densities from saturating at 1.0.
    // NOTE(review): additive blending into 32-bit float attachments also requires EXT_float_blend on
    // WebGL2; this check only covers renderability — confirm on target devices.
    const floatingSupported = hasFeatures(gl, FEATURES.COLOR_ATTACHMENT_FLOAT);
    const textureType = floatingSupported ? gl.FLOAT : gl.UNSIGNED_BYTE;

    // teardown old classAccumulator Frame Buffer Objects (if any), then create new ones
    this._deleteClassAggregators();
    return times(numClasses, () => this._createFbo(gl, width, height, textureType));
  }

  /** Deletes the per-class framebuffers in state together with their attached color textures. */
  private _deleteClassAggregators(): void {
    forEach(this.state.classAggregators, (fbo) => {
      // The texture is created by us and passed in as an attachment, so delete it explicitly.
      fbo.texture?.delete();
      fbo.delete();
    });
  }

  /**
   * Creates a framebuffer with a single RGBA color texture attachment.
   *
   * @param type - `gl.FLOAT` or `gl.UNSIGNED_BYTE`; selects the matching internal format and filtering.
   */
  private _createFbo(
    gl: WebGLRenderingContext | WebGL2RenderingContext,
    width: number,
    height: number,
    type: number
  ): Framebuffer {
    const isFloat = type === gl.FLOAT;
    // WebGL2 needs a sized internal format consistent with `type`: RGBA32F is only valid with FLOAT data,
    // while unsized RGBA is valid with UNSIGNED_BYTE (and with FLOAT on WebGL1).
    const format = isFloat && isWebGL2(gl) ? (gl as WebGL2RenderingContext).RGBA32F : gl.RGBA;
    // Float textures are only filterable with OES_texture_float_linear; LINEAR filtering without it makes
    // the texture incomplete, so fall back to NEAREST in that case.
    const filter = !isFloat || hasFeatures(gl, FEATURES.TEXTURE_FILTER_LINEAR_FLOAT) ? gl.LINEAR : gl.NEAREST;

    const texture = new Texture2D(gl, {
      width,
      height,
      format,
      dataFormat: gl.RGBA,
      type,
      data: null,
      mipmaps: false,
      parameters: {
        [gl.TEXTURE_MIN_FILTER]: filter,
        [gl.TEXTURE_MAG_FILTER]: filter,
      },
    });
    return new Framebuffer(gl, {
      width,
      height,
      attachments: {
        [gl.COLOR_ATTACHMENT0]: texture,
      },
    });
  }

  protected _createMesh() {
    return createFullScreenMesh(this.context.viewport.resolution);
  }

  /** Creates the instanced aggregator model, deleting any previous one (and its attribute buffers). */
  private _createAggregatorModel(gl: WebGLRenderingContext): Model {
    this._releaseAttributeBuffers(this.state.aggregatorModel);
    this.state.aggregatorModel?.delete(); // delete old model if any

    return new Model(gl, {
      vs: AGG_VS,
      fs: AGG_FS,
      geometry: createSquareFromTwoTriangles(gl),
      isInstanced: true,
    });
  }

  /**
   * Creates the color-pass model, deleting any previous one.
   *
   * The fragment shader source is generated from `numClasses` and `colorStrategy`, so the model must be
   * rebuilt whenever either of them changes.
   */
  private _createColorModel(gl: WebGLRenderingContext, numClasses: number, colorStrategy: ColorStrategy): Model {
    this.state.colorModel?.delete(); // delete old model if any

    return new Model(gl, {
      ...this.getShaders({ vs: COLOR_VS, fs: COLOR_FS(numClasses, colorStrategy) }),
      geometry: createSquareFromTwoTriangles(gl),
      isInstanced: false,
    });
  }

  /**
   * Builds the flattened per-class color array for `u_classColors`.
   *
   * Uses `classColors[i]` when provided, otherwise cycles through `defaultLayerColors`.
   */
  _createColorsUniform(): number[] {
    const { numClasses, classColors } = this.props;
    const paletteSize = size(defaultLayerColors);
    return flatten(
      times(numClasses, (layerIndex) =>
        classColors && classColors[layerIndex]
          ? classColors[layerIndex]
          : hexToRgb(defaultLayerColors[layerIndex % paletteSize])
      )
    );
  }

  /** Builds the per-class opacity array for `u_classOpacities`; missing entries default to 1. */
  _createOpacitiesUniform(): Float32Array {
    return Float32Array.from(
      times(this.props.numClasses, (layerIndex) => this.props.classOpacities?.[layerIndex] ?? 1)
    );
  }

  shouldUpdateState({ changeFlags }: UpdateParameters<this>) {
    return Boolean(
      changeFlags.viewportChanged ||
        changeFlags.dataChanged ||
        changeFlags.propsChanged ||
        changeFlags.updateTriggersChanged
    );
  }

  /**
   * Rebuilds only the GPU resources and uniforms affected by the changed props.
   *
   * - `numClasses` / `colorStrategy` → new color model (its shader depends on both).
   * - `numClasses` → new per-class framebuffers; otherwise a viewport change resizes the existing ones.
   * - `classColors` / `classOpacities` → new uniform arrays.
   * - data or accessor update triggers → new instance attributes.
   */
  updateState({ props, oldProps, changeFlags }: UpdateParameters<this>): void {
    const { gl, viewport } = this.context;
    const stateUpdates: Partial<MultiClassGaussianHeatmapLayerState> = {};

    const numClassesChanged = props.numClasses !== oldProps.numClasses;
    const colorStrategyChanged = props.colorStrategy !== oldProps.colorStrategy;

    if (numClassesChanged || colorStrategyChanged) {
      // _createColorModel already deletes the previous model.
      stateUpdates.colorModel = this._createColorModel(gl, props.numClasses, props.colorStrategy);
    }

    if (numClassesChanged) {
      // Re-create one aggregation FBO per class (old ones are deleted inside).
      stateUpdates.classAggregators = this._createClassAggregatorFrameBuffers(props.numClasses);
    } else if (changeFlags.viewportChanged && viewport) {
      // If viewport changed, resize all classAggregators
      const { width, height } = viewport;
      forEach(this.state.classAggregators, (fbo) => {
        fbo.resize({ width, height });
      });
    }

    if (numClassesChanged || props.classColors !== oldProps.classColors) {
      stateUpdates.classColorsUniform = this._createColorsUniform();
    }

    if (numClassesChanged || props.classOpacities !== oldProps.classOpacities) {
      stateUpdates.classOpacitiesUniform = this._createOpacitiesUniform();
    }

    if (!isEmpty(stateUpdates)) {
      this.setState(stateUpdates);
    }

    const accessorTriggerChanged = some(
      ['data', 'getPosition', 'getWeight', 'getClassIndex'] as const,
      (key) => changeFlags.updateTriggersChanged && changeFlags.updateTriggersChanged[key]
    );
    if (Boolean(changeFlags.dataChanged) || accessorTriggerChanged) {
      // If data changed or any of the accessor functions changed, update the attributes
      this._updateAttributes();
    }
  }

  /**
   * Uploads per-instance attributes for the aggregator model from `data`.
   *
   * For each datum, the accessors produce:
   * - `instancePositions` (2 floats: x, y from `getPosition`),
   * - `instanceWeights` (1 float from `getWeight`),
   * - `instanceClassId` (1 float from `getClassIndex`).
   *
   * All attributes use divisor 1 (one value per instance). Buffers from the previous upload are deleted
   * once the model references the new ones.
   */
  _updateAttributes(): void {
    const { gl } = this.context;
    const { data, getPosition, getWeight, getClassIndex } = this.props;
    const { aggregatorModel } = this.state;
    if (!aggregatorModel) return;

    const instanceCount = size(data);
    const instancePositions = new Float32Array(instanceCount * 2); // 2D positions
    const instanceWeights = new Float32Array(instanceCount);
    const instanceClassIds = new Float32Array(instanceCount);

    forEach(data, (obj, i) => {
      const [x, y] = getPosition!(obj);
      instancePositions[i * 2 + 0] = x;
      instancePositions[i * 2 + 1] = y;
      instanceWeights[i] = getWeight!(obj);
      instanceClassIds[i] = getClassIndex!(obj);
    });

    const positionsBuffer = new Buffer(gl, { data: instancePositions, accessor: { size: 2 } });
    const weightsBuffer = new Buffer(gl, { data: instanceWeights, accessor: { size: 1 } });
    const classIdsBuffer = new Buffer(gl, { data: instanceClassIds, accessor: { size: 1 } });

    aggregatorModel.setInstanceCount(instanceCount);
    aggregatorModel.setAttributes({
      instancePositions: [positionsBuffer, { divisor: 1 }],
      instanceWeights: [weightsBuffer, { divisor: 1 }],
      instanceClassId: [classIdsBuffer, { divisor: 1 }],
    });

    // The model now points at the new buffers, so the previous ones can be freed.
    this._releaseAttributeBuffers(aggregatorModel);
    MultiClassGaussianHeatmapLayer.attributeBuffers.set(aggregatorModel, [
      positionsBuffer,
      weightsBuffer,
      classIdsBuffer,
    ]);
  }

  /** Deletes the attribute buffers last uploaded for `model` (no-op when there are none). */
  private _releaseAttributeBuffers(model: Model | undefined): void {
    if (!model) return;
    forEach(MultiClassGaussianHeatmapLayer.attributeBuffers.get(model), (buffer) => buffer.delete());
    MultiClassGaussianHeatmapLayer.attributeBuffers.delete(model);
  }

  /**
   * Aggregates every class into its own framebuffer with additive blending.
   *
   * For each class index the FBO is bound, the viewport set to the viewport size and the color buffer
   * cleared to transparent. Only classes whose opacity (`classOpacities[i]`, default 1 — same default as
   * the color pass) is above 0 are then drawn, with the view-projection matrix, intensity, extrude factor,
   * active class index and falloff divisor as uniforms. Hidden classes are still cleared so that no stale
   * density from an earlier frame reaches the color pass.
   *
   * @remarks
   * `withParameters` scopes the blending parameters (and the viewport / clear-color / framebuffer
   * changes tracked by luma.gl) to this pass.
   */
  _runAggregatorPass(): void {
    const { gl } = this.context;
    const { width, height, zoom, viewProjectionMatrix } = this.context.viewport;
    const { aggregatorModel, classAggregators } = this.state;
    const { radius, intensity, numClasses, classOpacities, falloffDivisor } = this.props;

    // NOTE(review): min(zoom / 4, 0) makes scale 1 for every zoom >= 0, so the radius only grows when
    // zooming out below 0 — confirm this is intended.
    const scale = Math.pow(2, Math.min(zoom / 4, 0));
    const extrudeFactor = radius / scale;

    withParameters(gl, additiveBlendingParametersFromGLContext(gl), () => {
      times(numClasses, (activeClassIndex) => {
        const fbo = classAggregators[activeClassIndex];
        // Bind the framebuffer for the current class to aggregate results into its texture.
        fbo.bind();
        gl.viewport(0, 0, width, height);
        gl.clearColor(0, 0, 0, 0);
        gl.clear(gl.COLOR_BUFFER_BIT);

        const isClassVisible = (classOpacities?.[activeClassIndex] ?? 1) > 0;
        if (isClassVisible) {
          aggregatorModel.setUniforms({
            u_viewProjectionMatrix: viewProjectionMatrix,
            u_intensity: intensity,
            u_extrudeFactor: extrudeFactor,
            u_activeClass: activeClassIndex,
            u_falloffDivisor: falloffDivisor,
          });
          aggregatorModel.draw();
        }

        fbo.unbind();
      });
    });
  }

  /** Runs the aggregator pass, then draws the color pass from the per-class textures. */
  draw(): void {
    const { aggregatorModel, colorModel, classAggregators, classOpacitiesUniform, classColorsUniform } = this.state;
    if (!aggregatorModel || !colorModel || !classAggregators || !classOpacitiesUniform || !classColorsUniform) return;

    // Pass A: one aggregation pass per class, filtering out data points that don't match the active class.
    this._runAggregatorPass();

    // Pass B: colorize by sampling every class texture.
    colorModel.setUniforms({
      u_classOpacities: classOpacitiesUniform,
      u_classColors: classColorsUniform,
      // Bind each aggregator texture to its own sampler: u_textures[0], u_textures[1], ...
      ...fromPairs(map(classAggregators, (fbo, index) => [`u_textures[${index}]`, fbo.texture])),
    });

    colorModel.draw();
  }
}
