# Collision detection in polygons - 3D case
# Faisal Z. Qureshi
# faisal.qureshi@ontariotechu.ca
#
# Adapted from ChatGPT

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from itertools import combinations

class ConvexPolyhedron:
    """Plain data holder for a polyhedron given as vertices plus faces."""

    def __init__(self, vertices, faces):
        """
        Store the geometry of the polyhedron.

        vertices: sequence of 3D points; converted with ``np.array`` (which
                  copies the input) and kept as ``self.vertices``
        faces: list of faces, each face being a list of indices into
               ``vertices``; kept as ``self.faces`` without copying
        """
        # The face list is kept by reference, exactly as passed in.
        self.faces = faces
        # Vertices always end up as a fresh NumPy array.
        self.vertices = np.array(vertices)

# Relative tolerance (on squared lengths) below which a cross product is
# treated as zero, i.e. the two vectors are considered parallel.
_SAT_DEGENERACY_TOL = 1e-12


def _sat_face_normals(vertices, faces):
    """Return one unnormalized normal for every face that spans a plane."""
    normals = []
    for face in faces:
        ring = [int(i) for i in face]
        if len(ring) < 3:
            continue
        origin = vertices[ring[0]]
        first_edge = vertices[ring[1]] - origin
        # Use the first vertex that is not collinear with the first edge;
        # for a planar face any such triple yields the face normal.
        for index in ring[2:]:
            other_edge = vertices[index] - origin
            normal = np.cross(first_edge, other_edge)
            scale = np.dot(first_edge, first_edge) * np.dot(other_edge, other_edge)
            if np.dot(normal, normal) > _SAT_DEGENERACY_TOL * scale:
                normals.append(normal)
                break
    return normals


def _sat_edge_directions(vertices, faces):
    """Return the direction vector of every distinct edge of the faces."""
    seen = set()
    directions = []
    for face in faces:
        ring = [int(i) for i in face]
        if len(ring) < 2:
            continue
        # Pair each vertex with its successor, wrapping around to close the loop.
        for start, end in zip(ring, ring[1:] + ring[:1]):
            key = (start, end) if start < end else (end, start)
            if start == end or key in seen:
                continue
            seen.add(key)
            directions.append(vertices[end] - vertices[start])
    return directions


def polyhedron_sat_collision(poly1, poly2):
    """Check collision between two convex polyhedra using the Separating Axis Theorem.

    Two convex polyhedra are disjoint exactly when their projections onto
    some candidate axis do not overlap. The candidate axes are the face
    normals of both polyhedra and the cross products of every edge of
    ``poly1`` with every edge of ``poly2``.

    poly1, poly2: objects exposing ``vertices`` (N x 3 points) and ``faces``
                  (lists of vertex indices, ordered around each face)

    Returns True if no separating axis exists (the polyhedra overlap or
    touch), False otherwise. Polyhedra that merely touch count as colliding.
    A polyhedron without vertices never collides. If the faces yield no
    usable axis at all, no separation can be shown and True is returned.
    """
    verts1 = np.asarray(poly1.vertices, dtype=float).reshape(-1, 3)
    verts2 = np.asarray(poly2.vertices, dtype=float).reshape(-1, 3)
    if verts1.size == 0 or verts2.size == 0:
        return False

    axes = _sat_face_normals(verts1, poly1.faces)
    axes += _sat_face_normals(verts2, poly2.faces)

    edges1 = _sat_edge_directions(verts1, poly1.faces)
    edges2 = _sat_edge_directions(verts2, poly2.faces)
    if edges1 and edges2:
        e1 = np.array(edges1)
        e2 = np.array(edges2)
        # All pairwise cross products at once, flattened to (len1 * len2, 3).
        crosses = np.cross(e1[:, None, :], e2[None, :, :]).reshape(-1, 3)
        scale = np.outer((e1 * e1).sum(axis=1), (e2 * e2).sum(axis=1)).ravel()
        # Parallel edge pairs give a (near) zero vector, which is no axis.
        usable = (crosses * crosses).sum(axis=1) > _SAT_DEGENERACY_TOL * scale
        axes.extend(crosses[usable])

    if not axes:
        return True

    axis_matrix = np.array(axes)
    # Column k holds the projections of all vertices onto axis k.
    proj1 = verts1 @ axis_matrix.T
    proj2 = verts2 @ axis_matrix.T
    separated = (proj1.max(axis=0) < proj2.min(axis=0)) | (
        proj2.max(axis=0) < proj1.min(axis=0)
    )
    return not bool(separated.any())

def plot_polyhedron(ax, poly, color='cyan', alpha=0.5):
    """Draw every face of a polyhedron on a 3D axes.

    ax: 3D axes the faces are added to via ``add_collection3d``
    poly: object exposing ``vertices`` (indexable by vertex index) and
          ``faces`` (iterables of vertex indices)
    color: colour passed to each face collection
    alpha: transparency passed to each face collection
    """
    points = poly.vertices
    for index_loop in poly.faces:
        # Resolve the face's indices into actual coordinates.
        outline = [points[k] for k in index_loop]
        # One collection per face, each outlined in black.
        patch = Poly3DCollection([outline], color=color, alpha=alpha, edgecolor='black')
        ax.add_collection3d(patch)

def visualize_collision(poly1, poly2, collision):
    """Show two polyhedra in one 3D figure, colour-coded by collision status.

    poly1, poly2: polyhedra exposing ``vertices`` and ``faces``
    collision: truthy if the polyhedra collide; both are then drawn green,
               otherwise ``poly1`` is blue and ``poly2`` is red

    Opens a Matplotlib window through ``plt.show()``.
    """
    figure = plt.figure(figsize=(8, 8))
    axes3d = figure.add_subplot(111, projection='3d')

    # Pick the colour pair once, then draw both solids.
    if not collision:
        first_color, second_color = 'blue', 'red'
    else:
        first_color, second_color = 'green', 'green'
    plot_polyhedron(axes3d, poly1, color=first_color, alpha=0.5)
    plot_polyhedron(axes3d, poly2, color=second_color, alpha=0.5)

    for set_label, text in (
        (axes3d.set_xlabel, "X"),
        (axes3d.set_ylabel, "Y"),
        (axes3d.set_zlabel, "Z"),
    ):
        set_label(text)
    axes3d.set_title(f"Collision: {collision!s}")

    # Axis limits: bounding box of all vertices, padded by one unit per side.
    cloud = np.vstack([poly1.vertices, poly2.vertices])
    lower = cloud.min(axis=0) - 1
    upper = cloud.max(axis=0) + 1
    limit_setters = (axes3d.set_xlim, axes3d.set_ylim, axes3d.set_zlim)
    for axis_index, set_limits in enumerate(limit_setters):
        set_limits([lower[axis_index], upper[axis_index]])

    plt.show()

# Example Polyhedra (Cubes)
# First cube: spans [-1, 1] along every axis.
vertices1 = np.array([
    [-1, -1, -1], [1, -1, -1], [1, 1, -1], [-1, 1, -1],  # Bottom face
    [-1, -1, 1], [1, -1, 1], [1, 1, 1], [-1, 1, 1]   # Top face
])
# Each face lists its four vertex indices in order around the quad.
faces1 = [[0, 1, 2, 3], [4, 5, 6, 7], [0, 1, 5, 4], [2, 3, 7, 6], [0, 3, 7, 4], [1, 2, 6, 5]]

# Second cube: spans [2, 4] along every axis, so it does not reach the first.
vertices2 = np.array([
    [2, 2, 2], [4, 2, 2], [4, 4, 2], [2, 4, 2],  # Bottom face
    [2, 2, 4], [4, 2, 4], [4, 4, 4], [2, 4, 4]   # Top face
])
faces2 = [[0, 1, 2, 3], [4, 5, 6, 7], [0, 1, 5, 4], [2, 3, 7, 6], [0, 3, 7, 4], [1, 2, 6, 5]]

poly1 = ConvexPolyhedron(vertices1, faces1)
poly2 = ConvexPolyhedron(vertices2, faces2)

# Run the demo only when executed as a script, so importing this module
# neither runs the collision check nor opens a blocking plot window.
if __name__ == "__main__":
    # Check for collision
    collision = polyhedron_sat_collision(poly1, poly2)

    # Visualize with updated collision status
    visualize_collision(poly1, poly2, collision)