# Collision detection between convex polyhedra - 3D case with a rotated polyhedron
# Faisal Z. Qureshi
# faisal.qureshi@ontariotechu.ca
#
# Adapted from ChatGPT

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from itertools import combinations
from scipy.spatial.transform import Rotation as R

class ConvexPolyhedron:
    """Polyhedron described by vertex coordinates and per-face vertex-index lists.

    Convexity is not checked here.

    Attributes:
        vertices: NumPy array copy of the ``vertices`` argument (in this file's
            data, one 3D point per row).
        faces: the ``faces`` argument, kept by reference (not copied); each entry
            lists indices into ``vertices``.
    """

    def __init__(self, vertices, faces):
        self.faces = faces
        # np.array copies by default, so later edits to the caller's array do not
        # leak into this polyhedron.
        self.vertices = np.array(vertices)

def rotate_polyhedron(poly, rotation_matrix):
    """Return a new ConvexPolyhedron holding ``poly``'s vertices rotated by ``rotation_matrix``.

    Vertices are stored one per row, so rotating every row v by the matrix
    (``rotation_matrix @ v``) is done in one step as ``vertices @ rotation_matrix.T``.
    The face list is passed through unchanged (shared, not copied), and ``poly``
    itself is not modified.
    """
    turned = poly.vertices @ rotation_matrix.T
    return ConvexPolyhedron(turned, poly.faces)

# Edge-pair cross products shorter than this (edges are unit length, so this is
# sin(angle)) come from (near-)parallel edges and are skipped as candidate axes.
_PARALLEL_EPS = 1e-12


def _face_normals(poly):
    """Return unit normals of the non-degenerate faces of ``poly`` (Newell's method)."""
    verts = np.asarray(poly.vertices, dtype=float)
    normals = []
    for face in poly.faces:
        if len(face) < 3:
            continue
        pts = verts[np.asarray(face, dtype=int)]
        # Newell's sum is translation-invariant for a closed loop; centring the
        # points just reduces floating-point cancellation.
        pts = pts - pts.mean(axis=0)
        normal = np.cross(pts, np.roll(pts, -1, axis=0)).sum(axis=0)
        length = np.linalg.norm(normal)
        if length > 0.0:
            normals.append(normal / length)
    return normals


def _edge_directions(poly):
    """Return a unit direction for every distinct edge implied by ``poly``'s face loops."""
    verts = np.asarray(poly.vertices, dtype=float)
    edges = set()
    for face in poly.faces:
        loop = [int(i) for i in face]
        # Consecutive vertices of a face (wrapping around) form its edges; each
        # edge is shared by two faces, so store it undirected to deduplicate.
        for a, b in zip(loop, loop[1:] + loop[:1]):
            if a != b:
                edges.add((min(a, b), max(a, b)))
    directions = []
    for a, b in sorted(edges):
        direction = verts[b] - verts[a]
        length = np.linalg.norm(direction)
        if length > 0.0:
            directions.append(direction / length)
    return directions


def _candidate_axes(poly1, poly2):
    """Yield the SAT candidate axes: both shapes' face normals, then edge x edge crosses."""
    yield from _face_normals(poly1)
    yield from _face_normals(poly2)
    edges2 = _edge_directions(poly2)
    for e1 in _edge_directions(poly1):
        for e2 in edges2:
            axis = np.cross(e1, e2)
            # Parallel edges give a zero cross product, which cannot separate anything.
            if np.linalg.norm(axis) > _PARALLEL_EPS:
                yield axis


def polyhedron_sat_collision(poly1, poly2):
    """Test two convex polyhedra for intersection with the Separating Axis Theorem.

    Two convex polyhedra are disjoint if and only if some axis exists on which
    their vertex projections do not overlap. For polyhedra it is enough to test:
    - the face normals of each shape, and
    - the cross product of every edge direction of ``poly1`` with every edge
      direction of ``poly2``.

    Args:
        poly1, poly2: objects with ``vertices`` (an array-like of 3D points, one
            per row) and ``faces`` (vertex-index lists, each assumed to list a
            face's vertices in boundary order).

    Returns:
        True if the shapes overlap or merely touch; False as soon as a separating
        axis is found. If no usable axis can be derived (e.g. empty face lists),
        no separation can be shown and True is returned.
    """
    return all(overlap_on_axis(poly1, poly2, axis)
               for axis in _candidate_axes(poly1, poly2))


def overlap_on_axis(poly1, poly2, axis):
    """Return True if the projections of two polyhedra onto ``axis`` overlap.

    Intervals that only touch count as overlapping. A zero-length axis cannot
    separate anything, so it is reported as overlapping. Without this guard,
    normalising it would yield NaN projections; every NaN comparison is False,
    which would signal a separation that does not exist.

    Args:
        poly1, poly2: objects with a ``vertices`` array-like of 3D points (one per row).
        axis: 3-component direction; it need not be unit length.

    Returns:
        A Python bool.
    """
    axis = np.asarray(axis, dtype=float)
    length = np.linalg.norm(axis)
    if length == 0.0:
        return True
    axis = axis / length
    proj1 = np.asarray(poly1.vertices, dtype=float) @ axis
    proj2 = np.asarray(poly2.vertices, dtype=float) @ axis
    return bool(proj1.max() >= proj2.min() and proj2.max() >= proj1.min())

def plot_polyhedron(ax, poly, color='cyan', alpha=0.5):
    """Draw every face of ``poly`` onto the 3D axes ``ax``.

    Each face becomes its own ``Poly3DCollection``, filled with ``color`` at the
    given ``alpha`` and outlined in black. The axes are modified in place and
    nothing is returned.
    """
    verts = poly.vertices
    for index_list in poly.faces:
        polygon = [verts[idx] for idx in index_list]
        patch = Poly3DCollection([polygon], color=color, alpha=alpha, edgecolor='black')
        ax.add_collection3d(patch)

def visualize(poly1, poly2, poly3, collision1, collision2):
    """Plot three polyhedra in one 3D view, colouring colliding shapes red.

    Args:
        poly1, poly2, poly3: objects with ``vertices`` and ``faces``, as drawn by
            ``plot_polyhedron``.
        collision1: collision result for the pair (poly1, poly2).
        collision2: collision result for the pair (poly1, poly3).

    Colours:
        - ``poly2`` and ``poly3`` are red when their pair collides, otherwise
          green and cyan respectively.
        - ``poly1`` belongs to both pairs, so it is red when either pair
          collides, otherwise blue.

    The figure is displayed with ``plt.show()``.
    """
    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, projection='3d')
    poly1_color = 'red' if (collision1 or collision2) else 'blue'
    plot_polyhedron(ax, poly1, color=poly1_color, alpha=0.5)
    plot_polyhedron(ax, poly2, color='red' if collision1 else 'green', alpha=0.5)
    plot_polyhedron(ax, poly3, color='red' if collision2 else 'cyan', alpha=0.5)

    # Older Matplotlib releases do not autoscale the view for add_collection3d,
    # leaving the default limits; fit an equal-sided box around every vertex so
    # all shapes are visible and no axis is stretched relative to the others.
    points = np.vstack([np.asarray(p.vertices, dtype=float) for p in (poly1, poly2, poly3)])
    lo, hi = points.min(axis=0), points.max(axis=0)
    centre = (lo + hi) / 2.0
    half = (hi - lo).max() / 2.0
    if half == 0.0:
        half = 1.0
    half *= 1.05  # small margin so faces are not clipped at the box edge
    ax.set_xlim(centre[0] - half, centre[0] + half)
    ax.set_ylim(centre[1] - half, centre[1] + half)
    ax.set_zlim(centre[2] - half, centre[2] + half)

    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    plt.show()

# Cube of half-size 1 centred at the origin. The other two shapes are the same
# cube translated, so all three share one face list.
vertices1 = np.array([[-1, -1, -1], [1, -1, -1], [1, 1, -1], [-1, 1, -1],
                      [-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1]])
faces1 = [[0, 1, 2, 3], [4, 5, 6, 7], [0, 1, 5, 4], [2, 3, 7, 6], [0, 3, 7, 4], [1, 2, 6, 5]]
vertices2 = vertices1 + 1.5   # spans [0.5, 2.5] on every axis
vertices3 = vertices1 + 4     # spans [3, 5] on every axis
faces2 = faces3 = faces1

poly1 = ConvexPolyhedron(vertices1, faces1)
poly2 = ConvexPolyhedron(vertices2, faces2)
poly3 = ConvexPolyhedron(vertices3, faces3)

# Lowercase 'xyz' in SciPy means extrinsic rotations about x, then y, then z.
euler_rotation = R.from_euler('xyz', [30, 45, 60], degrees=True)
rotation_matrix = euler_rotation.as_matrix()
rotated_poly1 = rotate_polyhedron(poly1, rotation_matrix)

# collision1: tilted cube vs. the [0.5, 2.5] cube; collision2: tilted cube vs. the [3, 5] cube.
collision1, collision2 = (polyhedron_sat_collision(rotated_poly1, other)
                          for other in (poly2, poly3))
visualize(rotated_poly1, poly2, poly3, collision1, collision2)
