import * as Pose from "@mediapipe/pose";
import React from "react";
import { drawConnectors, drawLandmarks } from "@mediapipe/drawing_utils";
import Webcam from "react-webcam";
import { MatrixUtils } from "../Utils/MatrixUtils";
import { distanceToRoot, intToPoseLandmark, poseLandmarkToInt } from "../Utils/Constants";
import { Matrix } from "ts-matrix";
import * as math from "mathjs";

/**
 * Handles MediaPipe Pose results for two streams: the user's webcam ("current" pose)
 * and a target pose to imitate. Draws the user's landmarks onto `canvasRef` and scores
 * the current pose against the most recent target poses (lower score = closer match).
 */
class PoseResultController {
    webcamRef: React.RefObject<Webcam>;
    canvasRef: React.RefObject<HTMLCanvasElement>;
    targetCanvasRef: React.RefObject<HTMLCanvasElement>;
    // Sliding window of recent normalized target poses, oldest first. Each pose is a
    // list of [x, y, z, visibility] rows. Capped at numberOfPreviousPosesToKeep entries.
    previousNormalizedTargetPoses: number[][][];
    // True when the most recent webcam result contained pose landmarks.
    clientInFrame: boolean;

    // Number of poses to keep in memory to account for potential reaction time to keep up with the dance
    numberOfPreviousPosesToKeep: number;
    targetPose: Pose.Results | null;
    currentPose: Pose.Results | null;
    mirrorPose: Pose.Results | null;

    // Substrings of landmark labels excluded from the angle score (face, hand and heel points).
    private static readonly angleIgnoredLandmarkParts: readonly string[] = [
        "MOUTH", "HEEL", "EYE", "NOSE", "INDEX", "EAR", "PINKY",
    ];

    constructor(
        webcamRef: React.RefObject<Webcam>,
        canvasRef: React.RefObject<HTMLCanvasElement>,
        targetCanvasRef: React.RefObject<HTMLCanvasElement>,
    ) {
        this.webcamRef = webcamRef;
        this.canvasRef = canvasRef;
        this.targetCanvasRef = targetCanvasRef;
        this.targetPose = null;
        this.currentPose = null;
        this.mirrorPose = null;
        this.previousNormalizedTargetPoses = [];
        this.numberOfPreviousPosesToKeep = 10;
        this.clientInFrame = false;
    }

    /**
     * Importance-weighted Euclidean distance between two aligned poses.
     *
     * Each row is one landmark laid out as [x, y, z, visibility]. Only x/y/z enter the
     * distance, and joint i is weighted by `distanceToRoot[i]`. An earlier Python
     * prototype instead weighted each joint by the product of both visibilities. The
     * visibility column is carried along but not used here.
     *
     * Only joints present in both poses are compared.
     *
     * @returns the weighted sum of per-joint distances, or Number.MAX_SAFE_INTEGER when
     *   either pose is missing/empty or the summed distance is exactly 0.
     */
    calculateEuclideanScore(alignedMatrix1: (number | undefined)[][], alignedMatrix2: (number | undefined)[][]): number {
        if (alignedMatrix1 === undefined ||
            alignedMatrix2 === undefined ||
            alignedMatrix1[0] === undefined ||
            alignedMatrix2[0] === undefined) {
            return Number.MAX_SAFE_INTEGER;
        }

        // Bound by the shorter pose so a missing row can never be dereferenced.
        const jointCount = Math.min(alignedMatrix1.length, alignedMatrix2.length);
        let distance = 0;
        for (let i = 0; i < jointCount; i++) {
            const joint1 = alignedMatrix1[i] as number[];
            const joint2 = alignedMatrix2[i] as number[];
            const norm = Math.hypot(
                joint1[0] - joint2[0],
                joint1[1] - joint2[1],
                joint1[2] - joint2[2],
            );
            distance += distanceToRoot[i] * norm;
        }
        // NOTE(review): a total of 0 is treated as "no usable data" (same sentinel as the
        // missing-input case). This also scores an exact match as worst — confirm intent.
        if (distance === 0) {
            return Number.MAX_SAFE_INTEGER;
        }
        return distance;
    }

    /**
     * Sum, over body joints, of the geodesic angle (radians) between the corresponding
     * joint rotations of two poses.
     *
     * The geodesic distance (1/sqrt(2)) * ||log(R1^-1 R2)||_F uses the *matrix* logarithm.
     * For a 3x3 rotation it equals the rotation angle theta of R1^-1 R2, computed here
     * as acos((trace(R1^-1 R2) - 1) / 2). The cosine is clamped to [-1, 1] to absorb
     * numerical drift.
     * NOTE(review): assumes the matrices are 3x3 rotation matrices — confirm against
     * whatever builds `poseAngles`.
     *
     * Joints missing from either map, and face/hand/heel landmarks, are skipped.
     *
     * @returns the summed angle, or Number.MAX_SAFE_INTEGER when either map is missing
     *   or no joint could be compared.
     */
    calculateAngleScore(poseAngles1: {[id: number]: math.Matrix}, poseAngles2: {[id: number]: math.Matrix}): number {
        if (poseAngles1 === undefined ||
            poseAngles2 === undefined) {
            return Number.MAX_SAFE_INTEGER;
        }

        const landmarkCount = Object.keys(poseLandmarkToInt).length;
        let distance = 0;
        let comparedJoints = 0;
        for (let i = 0; i < landmarkCount; i++) {
            const rotMat1 = poseAngles1[i];
            const rotMat2 = poseAngles2[i];
            if (rotMat1 === undefined || rotMat2 === undefined) {
                continue;
            }
            const jointLabel = intToPoseLandmark[i];
            if (jointLabel === undefined ||
                PoseResultController.angleIgnoredLandmarkParts.some(part => jointLabel.includes(part))) {
                continue;
            }

            const relativeRotation = math.multiply(math.inv(rotMat1), rotMat2) as math.Matrix;
            const cosTheta = (Number(math.trace(relativeRotation)) - 1) / 2;
            distance += Math.acos(Math.min(1, Math.max(-1, cosTheta)));
            comparedJoints++;
        }
        // Nothing comparable: use the same "no data" sentinel as above rather than a
        // negative value that would rank as the best possible score.
        if (comparedJoints === 0) {
            return Number.MAX_SAFE_INTEGER;
        }
        return distance;
    }

    /**
     * Scores the latest current pose against the recent window of target poses.
     *
     * Both the current pose and its mirrored version (via MatrixUtils.mirrorPose) are
     * compared with every target pose in the window. The lowest Euclidean score wins.
     *
     * Side effect: appends the latest normalized target pose to
     * `previousNormalizedTargetPoses`, keeping at most `numberOfPreviousPosesToKeep`
     * entries (always at least one).
     *
     * @returns the best score, or Number.MAX_SAFE_INTEGER when either pose is unavailable.
     */
    public calculateTimestepEuclideanScore(): number {
        const currentPose: Matrix | null = MatrixUtils.convertPoseResultsToMatrix(this.getCurrentPose());
        const targetPose: Matrix | null = MatrixUtils.convertPoseResultsToMatrix(this.getTargetPose());
        if (!currentPose || !targetPose) {
            return Number.MAX_SAFE_INTEGER;
        }
        const mirrorPose: Matrix = MatrixUtils.mirrorPose(currentPose);

        const currentPoseValues = currentPose.values;
        const currentPoseNormalized = MatrixUtils.normalizePose(currentPose);
        const mirrorPoseNormalized = MatrixUtils.normalizePose(mirrorPose);
        const targetPoseNormalized = MatrixUtils.normalizePose(targetPose);

        // Pair each normalized landmark's x/y/z with the visibility (column 3) of the
        // matching raw landmark. Rows with no raw counterpart get visibility 0.
        const withVisibility = (normalized: typeof currentPoseNormalized, raw: typeof currentPoseValues) =>
            normalized.map((value, index) => {
                const rawRow = raw[index];
                const visibility = rawRow ? rawRow[3] : 0;
                return [value[0], value[1], value[2], visibility];
            });

        const currentPoseWithVisibility = withVisibility(currentPoseNormalized, currentPoseValues);
        const mirrorPoseWithVisibility = withVisibility(mirrorPoseNormalized, mirrorPose.values);
        const targetPoseWithVisibility = withVisibility(targetPoseNormalized, targetPose.values);

        // Append first, then keep only the newest entries. Trimming before appending let
        // the window grow to numberOfPreviousPosesToKeep + 1.
        const windowSize = Math.max(1, this.numberOfPreviousPosesToKeep);
        this.previousNormalizedTargetPoses = this.previousNormalizedTargetPoses
            .concat([targetPoseWithVisibility])
            .slice(-windowSize);

        let currentPoseScore = Number.MAX_SAFE_INTEGER;
        let mirrorPoseScore = Number.MAX_SAFE_INTEGER;
        for (const previousTargetPose of this.previousNormalizedTargetPoses) {
            currentPoseScore = Math.min(this.calculateEuclideanScore(currentPoseWithVisibility, previousTargetPose), currentPoseScore);
            mirrorPoseScore = Math.min(this.calculateEuclideanScore(mirrorPoseWithVisibility, previousTargetPose), mirrorPoseScore);
        }

        return Math.min(currentPoseScore, mirrorPoseScore);
    }

    /** Stores the latest webcam pose result. */
    public setCurrentPose(currentPose: Pose.Results): void {
        this.currentPose = currentPose;
    }

    /** Latest webcam pose result, or null if none has been recorded. */
    public getCurrentPose(): Pose.Results | null {
        return this.currentPose;
    }

    /** Stores a mirrored pose result. */
    public setMirrorCurrentPose(currentPose: Pose.Results): void {
        this.mirrorPose = currentPose;
    }

    /** Stored mirrored pose result, or null if none has been recorded. */
    public getMirrorCurrentPose(): Pose.Results | null {
        return this.mirrorPose;
    }

    /** Stores the latest target pose result. */
    public setTargetPose(targetPose: Pose.Results): void {
        this.targetPose = targetPose;
    }

    /** Latest target pose result, or null if none has been recorded. */
    public getTargetPose(): Pose.Results | null {
        return this.targetPose;
    }

    /**
     * Handles a webcam pose result.
     *
     * Updates `clientInFrame` and records the pose when landmarks are present. Then it
     * resizes `canvasRef` to the displayed video size, dims it, and draws the skeleton
     * if someone is in frame. Drawing is skipped when the canvas or its 2D context is
     * unavailable.
     */
    public onPoseResults(results: Pose.Results): void {
        this.clientInFrame = results.poseLandmarks != null;
        // Record the pose independently of rendering so scoring still works without a canvas.
        if (this.clientInFrame) {
            this.setCurrentPose(results);
        }

        const canvasElement = this.canvasRef?.current;
        if (!canvasElement) {
            return;
        }
        const canvasCtx = canvasElement.getContext("2d");
        if (!canvasCtx) {
            return;
        }

        // Match the canvas to the displayed video; 0 while no video element is available.
        const video = this.webcamRef.current?.video;
        canvasElement.width = video?.clientWidth ?? 0;
        canvasElement.height = video?.clientHeight ?? 0;

        canvasCtx.save();
        try {
            canvasCtx.clearRect(0, 0, canvasElement.width, canvasElement.height);
            canvasCtx.fillStyle = 'rgba(0, 0, 0, 0.5)';
            canvasCtx.fillRect(0, 0, canvasElement.width, canvasElement.height);

            if (!this.clientInFrame) {
                return;
            }

            canvasCtx.globalCompositeOperation = 'source-over';
            drawConnectors(canvasCtx, results.poseLandmarks, Pose.POSE_CONNECTIONS,
                            {color: '#00FF00', lineWidth: 4});
            drawLandmarks(canvasCtx, results.poseLandmarks,
                            {color: '#FF0000', lineWidth: 2});
        } finally {
            // Balance save() on every path, including the early return above.
            canvasCtx.restore();
        }
    }

    /**
     * Handles a target pose result: stores it when it contains landmarks and ignores it
     * otherwise. Target rendering is currently disabled (commented out below).
     */
    public onTargetPoseResults(results: Pose.Results): void {
        // const canvasElement = this.targetCanvasRef.current;
        // if (!canvasElement) {
        //     return;
        // }
        // const canvasCtx = canvasElement.getContext("2d");
        // const videoWidth = this.webcamRef.current?.video!.clientWidth!;
        // const videoHeight = this.webcamRef.current?.video!.clientHeight!;

        // // Set canvas width
        // this.targetCanvasRef.current!.width = videoWidth;
        // this.targetCanvasRef.current!.height = videoHeight;

        // canvasCtx!.save();
        // canvasCtx!.clearRect(0, 0, canvasElement.width, canvasElement.height);
        // canvasCtx!.fillStyle = '#000000';
        // canvasCtx!.fillRect(0, 0, canvasElement.width, canvasElement.height);

        if (!results.poseLandmarks) {
            return;
        }

        this.setTargetPose(results);

        // canvasCtx!.globalCompositeOperation = 'source-over';
        // drawConnectors(canvasCtx!, results.poseLandmarks, Pose.POSE_CONNECTIONS,
        //                 {color: '#00FF00', lineWidth: 2});
        // drawLandmarks(canvasCtx!, results.poseLandmarks,
        //                     {color: '#FF0000', lineWidth: 1});
    }

}

export default PoseResultController;
