# Collision detection in polygons - 2D case
# Faisal Z. Qureshi
# faisal.qureshi@ontariotechu.ca
#
# Adapted from ChatGPT

import numpy as np
import matplotlib.pyplot as plt

class MyPoly:
    """A 2-D polygon stored as an ordered array of vertices.

    Helper for the Separating Axis Theorem (SAT). It computes the polygon's
    edge vectors, the edge normals used as candidate separating axes, and the
    interval the polygon covers when projected onto an axis.
    """

    def __init__(self, vertices):
        """Store the vertices and allocate the per-edge buffers.

        Args:
            vertices: array-like of shape (n, 2). Vertex ``i`` is joined to
                vertex ``(i + 1) % n``, so the polygon is implicitly closed.
        """
        self.vertices = np.asarray(vertices)
        n = len(self.vertices)
        # One edge (and one candidate axis) per vertex. Contents are
        # undefined until compute_edges / compute_separating_axes run.
        self.edges = np.empty((n, 2))
        self.separating_axes = np.empty((n, 2))

    def print(self):
        """Print each vertex prefixed by its index."""
        for i, vertex in enumerate(self.vertices):
            print(f'{i}: ', vertex)

    def compute_edges(self, *, verbose=True):
        """Compute edge vectors ``vertices[(i + 1) % n] - vertices[i]``.

        Args:
            verbose: if true, print each edge vector (the original debug output).

        Returns:
            The float array ``self.edges`` of shape (n, 2).
        """
        # Rolling by -1 pairs every vertex with its successor, wrapping the
        # last vertex back to the first.
        following = np.roll(self.vertices, -1, axis=0)
        self.edges = np.asarray(following - self.vertices, dtype=float).reshape(-1, 2)
        if verbose:
            for edge in self.edges:
                print(edge)
        return self.edges

    def compute_separating_axes(self, *, verbose=True):
        """Compute one candidate separating axis (an edge normal) per edge.

        Each edge ``(x, y)`` is rotated 90 degrees counter-clockwise to
        ``(-y, x)``. The axes are not normalised, which is fine for SAT
        overlap tests because both polygons are projected onto the same axis.
        Requires ``compute_edges`` to have been called first.

        Args:
            verbose: if true, print each axis (the original debug output).

        Returns:
            The float array ``self.separating_axes`` of shape (n, 2).
        """
        rot90 = np.array([[0.0, -1.0], [1.0, 0.0]])
        # Row-wise R @ e for every edge e.
        self.separating_axes = self.edges @ rot90.T
        if verbose:
            for axis in self.separating_axes:
                print(axis)
        return self.separating_axes

    def project(self, axis):
        """Project every vertex onto ``axis`` and return the covered interval.

        Args:
            axis: array-like of shape (2,).

        Returns:
            ``(min, max)`` of the dot products of the vertices with ``axis``,
            as floats.

        Raises:
            ValueError: if the polygon has no vertices (empty reduction).
        """
        projections = self.vertices @ np.asarray(axis, dtype=float)
        return float(projections.min()), float(projections.max())

def polygons_collide(poly1, poly2):
    """Check for collision between two convex polygons using SAT.

    Separating Axis Theorem: two convex polygons are disjoint if and only if
    there is an axis onto which their projections do not overlap. For
    polygons it is enough to test the normal of every edge of both shapes.

    Args:
        poly1: array-like of shape (n, 2), vertices in order around the
            polygon (either winding).
        poly2: array-like of shape (m, 2), same convention.
        Both are assumed to be convex, non-degenerate polygons (at least 3
        vertices). Concave inputs give incorrect results.

    Returns:
        True if the polygons overlap or touch. False if some edge normal
        separates them, or if either polygon has no vertices.

    Raises:
        ValueError: if a non-empty input is not an (n, 2) array of points.
    """
    verts1 = np.asarray(poly1, dtype=float)
    verts2 = np.asarray(poly2, dtype=float)
    if verts1.size == 0 or verts2.size == 0:
        return False
    for verts in (verts1, verts2):
        if verts.ndim != 2 or verts.shape[1] != 2:
            raise ValueError(
                f"expected an (n, 2) array of vertices, got shape {verts.shape}"
            )

    for verts in (verts1, verts2):
        edges = np.roll(verts, -1, axis=0) - verts
        # Rotate each edge 90 degrees, (x, y) -> (-y, x), to get its normal.
        normals = np.column_stack((-edges[:, 1], edges[:, 0]))
        for axis in normals:
            if not axis.any():
                # A repeated vertex gives a zero-length edge. Its "normal" is
                # the zero vector, which cannot separate anything.
                continue
            proj1 = verts1 @ axis
            proj2 = verts2 @ axis
            # Strict '<': projections that merely touch still count as a collision.
            if proj1.max() < proj2.min() or proj2.max() < proj1.min():
                return False
    return True



def plot_polygons(poly1, poly2, collision):
    """Draw two polygons on a fixed 0-100 square canvas and show the figure.

    Args:
        poly1: (n, 2) array of vertices, drawn in blue as "Polygon 1".
        poly2: (m, 2) array of vertices, drawn in red as "Polygon 2".
        collision: truthy if the polygons collide. It selects the title text
            ("Collision!" / "No Collision") and the title colour.
    """
    _, ax = plt.subplots(figsize=(6, 6))
    ax.set_xlim(0, 100)
    ax.set_ylim(0, 100)
    ax.set_aspect('equal')

    # Repeat the first vertex so each outline closes back on itself.
    closed1, closed2 = (np.vstack([verts, verts[0]]) for verts in (poly1, poly2))

    layers = (
        (closed1, 'b-', 'blue', "Polygon 1"),
        (closed2, 'r-', 'red', "Polygon 2"),
    )
    for outline, line_fmt, fill_color, label in layers:
        xs = outline[:, 0]
        ys = outline[:, 1]
        ax.plot(xs, ys, line_fmt, linewidth=2, label=label)
        ax.fill(xs, ys, fill_color, alpha=0.3)

    if collision:
        title, title_color = "Collision!", "green"
    else:
        title, title_color = "No Collision", "red"
    ax.set_title(title, fontsize=14, fontweight="bold", color=title_color)

    plt.legend()
    plt.grid(True)
    plt.show()

# Example polygons, as (n, 2) vertex arrays.
# polygon1 is a triangle: the square's 4th vertex [10, 30] is deliberately left out.
polygon1 = np.array([[10, 10], [30, 10], [30, 30]])
polygon2 = np.array([[25, 25], [45, 25], [45, 45], [25, 45]])  # Square overlapping polygon1
polygon3 = np.array([[50, 50], [70, 50], [70, 70], [50, 70]])  # Square clear of polygon1

if __name__ == "__main__":
    # Check and plot the collision case.
    collision1 = polygons_collide(polygon1, polygon2)
    plot_polygons(polygon1, polygon2, collision1)

    # Check and plot the non-collision case.
    collision2 = polygons_collide(polygon1, polygon3)
    plot_polygons(polygon1, polygon3, collision2)
