import { Results } from "@mediapipe/pose";
import { Vector, Matrix } from 'ts-matrix';
import * as math from 'mathjs';
import { boneLength, childToParentLookup, intToPoseLandmark, parentToChildConstant, poseLandmarkToInt } from "./Constants";


export class MatrixUtils {
    /**
     * Below this value of `1 + cos(θ)`, two directions are treated as exactly
     * opposite in `rotationMatrixFromVectors`. In that regime the closed-form
     * Rodrigues expression divides by (almost) zero.
     */
    private static readonly ANTIPARALLEL_EPSILON = 1e-10;

    /**
     * Translates every landmark so that the RIGHT_HIP landmark sits at the origin.
     *
     * Only the first three columns (x, y, z) are shifted. The fourth column
     * (visibility, when the matrix comes from `convertPoseResultsToMatrix`) is
     * copied through unchanged.
     *
     * @param poseResults One row per landmark, `[x, y, z, visibility]`.
     * @returns A new matrix; the input is not modified.
     */
    public static translateHipsToOrigin(poseResults: Matrix): Matrix {
        const values = poseResults.values;
        const hip = values[poseLandmarkToInt["RIGHT_HIP"]];
        const translated = values.map(row => [row[0] - hip[0], row[1] - hip[1], row[2] - hip[2], row[3]]);
        return new Matrix(translated.length, translated[0].length, translated);
    }

    // Python reference this was ported from:
    //
    // for l1 in landmarks:
    //     l1[0] = l1[0] * -1
    //
    // visited_joints = set()
    // for joint_label in mp_pose.PoseLandmark:
    //     str_joint_label = str(joint_label)
    //     if joint_label not in visited_joints:
    //         if "LEFT" in str_joint_label:
    //             prefix = str_joint_label.split("LEFT")[0]
    //             body_part = str_joint_label.split("LEFT")[1]
    //             if "MOUTH" in prefix:
    //                 prefix = "MOUTH_"
    //                 body_part = ""
    //
    //             if body_part:
    //                 joint_to_switch = mp_pose.PoseLandmark(pose_landmark_to_int[prefix + "RIGHT" + body_part])
    //                 temp = landmarks[joint_to_switch].copy()
    //                 landmarks[joint_to_switch] = landmarks[joint_label]
    //                 landmarks[joint_label] = temp
    //                 visited_joints.add(joint_to_switch)
    //                 visited_joints.add(joint_label)

    /**
     * Mirrors a pose left-to-right.
     *
     * First every landmark's x coordinate is negated. Then each row whose label
     * contains "LEFT" is swapped with the row whose label has "LEFT" replaced by
     * "RIGHT" (e.g. LEFT_EYE_INNER <-> RIGHT_EYE_INNER, MOUTH_LEFT <-> MOUTH_RIGHT).
     *
     * Note: in the Python reference above (and the earlier port), the MOUTH
     * special case cleared `body_part`. That meant MOUTH_LEFT / MOUTH_RIGHT were
     * never swapped, unlike every other left/right pair. They are swapped here.
     * Rows whose label is unknown or has no RIGHT counterpart are left in place.
     *
     * @param poseResults One row per landmark; only column 0 (x) is modified.
     * @returns A new matrix; the input is not modified.
     */
    public static mirrorPose(poseResults: Matrix): Matrix {
        // Copy each row so the caller's matrix is untouched, reflecting x as we go.
        const values = poseResults.values.map(row => {
            const copy = [...row];
            copy[0] = copy[0] * -1;
            return copy;
        });

        for (let i = 0; i < values.length; i++) {
            const label: unknown = intToPoseLandmark[i];
            if (typeof label !== "string" || !label.includes("LEFT")) {
                continue;
            }
            const partner = poseLandmarkToInt[label.replace("LEFT", "RIGHT")];
            if (partner === undefined || partner === i || partner >= values.length) {
                continue;
            }
            // Only LEFT labels trigger a swap, so each pair is swapped exactly once.
            const temp = values[partner];
            values[partner] = values[i];
            values[i] = temp;
        }

        return new Matrix(values.length, values[0].length, values);
    }

    /**
     * Converts MediaPipe pose results into a matrix with one row per world landmark:
     * `[x, y, z, visibility]`.
     *
     * A missing `visibility` is stored as 0 rather than `undefined`.
     *
     * @param poseResults MediaPipe results, or `null`.
     * @returns The landmark matrix, or `null` when there are no results or no
     *   world landmarks (e.g. no pose was detected in the frame).
     */
    public static convertPoseResultsToMatrix(poseResults: Results | null): Matrix | null {
        const landmarks = poseResults?.poseWorldLandmarks;
        // Without this guard an absent/empty landmark list crashed on `rows[0].length`.
        if (!landmarks || landmarks.length === 0) {
            return null;
        }

        const rows: number[][] = landmarks.map(joint => [joint.x, joint.y, joint.z, joint.visibility ?? 0]);
        return new Matrix(rows.length, rows[0].length, rows);
    }

    // Python reference this was ported from:
    //
    // def rotation_matrix_from_vectors(vec1, vec2):
    //     """ Find the rotation matrix that aligns vec1 to vec2
    //     :param vec1: A 3d "source" vector
    //     :param vec2: A 3d "destination" vector
    //     :return mat: A transform matrix (3x3) which when applied to vec1, aligns it with vec2.
    //     """
    //     a, b = (vec1 / np.linalg.norm(vec1)).reshape(3), (vec2 / np.linalg.norm(vec2)).reshape(3)
    //     v = np.cross(a, b)
    //     c = np.dot(a, b)
    //     s = np.linalg.norm(v)
    //     kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
    //     rotation_matrix = np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2))
    //     return rotation_matrix

    /**
     * Returns the 3x3 rotation matrix that rotates the direction of `vec1` onto
     * the direction of `vec2` (Rodrigues' formula). Only the first three
     * components of each vector are used, and the inputs are not modified.
     *
     * Degenerate inputs, which produced NaN matrices in the Python reference:
     * - `vec1` and `vec2` parallel: returns the identity.
     * - `vec1` and `vec2` opposite: returns a 180° rotation about an axis
     *   perpendicular to `vec1`.
     * - either vector zero-length or non-finite: no rotation is defined, so the
     *   identity is returned.
     *
     * @returns A mathjs 3x3 matrix.
     */
    public static rotationMatrixFromVectors(vec1: Vector, vec2: Vector): math.MathType {
        const a = MatrixUtils.unit3(vec1);
        const b = MatrixUtils.unit3(vec2);
        if (a === null || b === null) {
            return math.identity(3);
        }

        // v = a × b is the rotation axis scaled by sin(θ); cos = a · b.
        const v = MatrixUtils.cross3(a, b);
        const cos = a[0] * b[0] + a[1] * b[1] + a[2] * b[2];

        if (1 + cos < MatrixUtils.ANTIPARALLEL_EPSILON) {
            // θ ≈ 180°: v vanishes and 1 / (1 + cos) blows up; any perpendicular axis works.
            return MatrixUtils.halfTurnPerpendicularTo(a);
        }

        const kMatrix = math.matrix([
            [0, -v[2], v[1]],
            [v[2], 0, -v[0]],
            [-v[1], v[0], 0],
        ]);
        // For unit vectors sin² = 1 - cos², so (1 - cos) / sin² == 1 / (1 + cos).
        // The simplified form stays finite when a ∥ b (the original was 0 / 0 there).
        const scale = 1 / (1 + cos);
        return math.add(
            math.add(math.identity(3), kMatrix),
            math.multiply(math.multiply(kMatrix, kMatrix), scale),
        );
    }

    /** Unit-length copy of the first three components of `vec`, or `null` if it has no direction. */
    private static unit3(vec: Vector): number[] | null {
        const components = [vec.at(0), vec.at(1), vec.at(2)];
        const length = MatrixUtils.norm(components);
        if (!(length > 0 && isFinite(length))) {
            return null;
        }
        return components.map(c => c / length);
    }

    /** Cross product of two 3-component arrays. */
    private static cross3(a: number[], b: number[]): number[] {
        return [
            a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0],
        ];
    }

    /** 180° rotation `2uuᵀ - I` about a unit axis `u` perpendicular to the unit vector `a`. */
    private static halfTurnPerpendicularTo(a: number[]): math.Matrix {
        // Cross with the x axis, or the y axis when a is close to x, so the
        // cross product never becomes (near) zero.
        const helper = Math.abs(a[0]) < 0.9 ? [1, 0, 0] : [0, 1, 0];
        const perpendicular = MatrixUtils.cross3(a, helper);
        const length = MatrixUtils.norm(perpendicular);
        const u = perpendicular.map(c => c / length);
        return math.matrix(u.map((ur, r) => u.map((uc, c) => 2 * ur * uc - (r === c ? 1 : 0))));
    }

    // Python reference this was ported from:
    //
    // def calculate_angle_parent(landmark_array):
    //     unit_vector = np.array([0, 0, 1])
    //     parent_to_child_rotation = dict()
    //     for (k, v) in parent_to_child.items():
    //         parent = landmark_array[k]
    //         children = landmark_array[v]
    //         vectors = children - parent
    //         for i, vector in zip(v, vectors):
    //             parent_to_child_rotation[i] = rotation_matrix_from_vectors(unit_vector, vector[:3])
    //     return parent_to_child_rotation

    /**
     * Computes, for every (parent, child) edge in `parentToChildConstant`, the
     * rotation that maps the +z unit vector onto the bone direction
     * `child - parent` (x, y, z columns only).
     *
     * @param poseResults One row per landmark; must contain every index referenced
     *   by `parentToChildConstant`.
     * @returns Rotation matrices keyed by child landmark index.
     */
    public static calculateAngleFromParent(poseResults: Matrix): { [id: number] : math.Matrix } {
        const unitVector = new Vector([0, 0, 1]);
        const poseValues = poseResults.values;
        const parentToChildRotation: { [id: number] : math.Matrix } = {};

        for (const parentKey in parentToChildConstant) {
            // for..in yields string keys; convert once for array indexing.
            const parent = poseValues[Number(parentKey)];
            for (const childIndex of parentToChildConstant[parentKey]) {
                const child = poseValues[childIndex];
                const bone = new Vector([child[0] - parent[0], child[1] - parent[1], child[2] - parent[2]]);
                parentToChildRotation[childIndex] = MatrixUtils.rotationMatrixFromVectors(unitVector, bone) as math.Matrix;
            }
        }
        return parentToChildRotation;
    }

    /** Euclidean length of the first three components of `a`. */
    public static norm(a: number[]): number {
        return (a[0] ** 2 + a[1] ** 2 + a[2] ** 2) ** 0.5
    }

    // Python reference (recreate_pose) this was derived from. The TypeScript
    // version differs: it roots the traversal at a synthetic FAKE_HIP joint
    // instead of joint 24, and scales each bone by `boneLength`.
    //
    // def recreate_pose(landmark_array):
    //     original_angles = calculate_angle_parent(landmark_array)
    //     rotated_pose = np.zeros((33, 3))
    //     unit_vector = np.array([[0, 0, 1]])
    //     parents_list = [24]
    //     visited = set()
    //     while len(parents_list) != 0:
    //         child = parents_list[0]
    //         parents_list = parents_list[1:]
    //         if child in parent_to_child:
    //             children = parent_to_child[child]
    //             parents_list.extend(children)
    //
    //         if child in child_to_parent:
    //             parent = child_to_parent[child]
    //             child_vector = np.expand_dims(rotated_pose[parent], -1) + np.matmul(original_angles[child], unit_vector.T)
    //             rotated_pose[child] = np.squeeze(child_vector, axis=-1)
    //
    //     return rotated_pose

    /**
     * Rebuilds the pose from per-bone rotations and the reference bone lengths in
     * `boneLength`, so every reconstructed bone has its `boneLength` length.
     *
     * Steps:
     * 1. Translate all landmarks so the midpoint of LEFT_HIP and RIGHT_HIP is the
     *    origin, then append a synthetic root joint at `[0, 0, 0]` (that midpoint).
     * 2. Compute per-bone rotations with `calculateAngleFromParent`.
     * 3. Re-express both hip-bone rotations relative to the LEFT_HIP rotation.
     * 4. Breadth-first from FAKE_HIP, place each child at
     *    `parent + R_child · (0, 0, boneLength[child])`.
     *
     * NOTE(review): this assumes `poseLandmarkToInt.FAKE_HIP` equals the index of
     * the appended root row (33 for MediaPipe's 33 landmarks). TODO confirm.
     * NOTE(review): step 3 only adjusts the two hip bones. Descendant bones keep
     * their original world-frame rotations. Confirm this is intended.
     *
     * @param poseResults One row per landmark, at least `[x, y, z]`.
     * @returns `rows + 1` rows of `[x, y, z]`. Joints not reached from FAKE_HIP
     *   through `parentToChildConstant` stay at `[0, 0, 0]`.
     */
    public static normalizePose(poseResults: Matrix): number[][] {
        const leftHip = poseLandmarkToInt["LEFT_HIP"];
        const rightHip = poseLandmarkToInt["RIGHT_HIP"];
        const fakeHip = poseLandmarkToInt["FAKE_HIP"];
        const landmarks = poseResults.values;

        // Hip midpoint (x, y, z only).
        const hipMidpoint = [0, 1, 2].map(k => (landmarks[rightHip][k] + landmarks[leftHip][k]) / 2);

        // Subtract (the previous code added) so the hip midpoint lands on the origin;
        // the visibility column is dropped.
        const centered = landmarks.map(row => [row[0] - hipMidpoint[0], row[1] - hipMidpoint[1], row[2] - hipMidpoint[2]]);
        // Synthetic root joint at the origin, i.e. exactly between the hips.
        centered.push([0, 0, 0]);

        const centeredMatrix = new Matrix(centered.length, 3, centered);
        const originalAngles = MatrixUtils.calculateAngleFromParent(centeredMatrix);

        // Use the cached inverse for both hips so RIGHT_HIP is relative to the
        // original (not the already-updated) LEFT_HIP rotation.
        const inverseLeftHip = math.inv(originalAngles[leftHip]);
        originalAngles[leftHip] = math.multiply(inverseLeftHip, originalAngles[leftHip]) as math.Matrix;
        originalAngles[rightHip] = math.multiply(inverseLeftHip, originalAngles[rightHip]) as math.Matrix;

        const rotatedPose: number[][] = centered.map(() => [0, 0, 0]);
        const unitVector = math.matrix([0, 0, 1]);
        const queue: number[] = [fakeHip];
        const visited = new Set<number>();

        // Index-based queue avoids re-copying the array on every dequeue.
        for (let head = 0; head < queue.length; head++) {
            const joint = queue[head];
            if (visited.has(joint)) {
                continue;
            }
            visited.add(joint);

            const children = parentToChildConstant[joint];
            if (children) {
                queue.push(...children);
            }

            if (joint in childToParentLookup) {
                const parentPosition = rotatedPose[childToParentLookup[joint]];
                const bone = math.multiply(
                    originalAngles[joint],
                    math.multiply(unitVector, boneLength[joint]),
                ) as math.Matrix;
                rotatedPose[joint] = [0, 1, 2].map(k => (bone.get([k]) as number) + parentPosition[k]);
            }
        }

        return rotatedPose;
    }
}

// Default export kept for callers using `import MatrixUtils from ...`.
export { MatrixUtils as default };