#!/bin/python

import numpy as np
import random
from matplotlib import pyplot as plt
from matplotlib import animation
import sys


# Simulation defaults; both may be overridden from the command line.
#   n_layers -- number of media layers the ray crosses
#   n_steps  -- number of animation frames (one simulation step per frame)
n_layers, n_steps = 3, 5000

def init():
    line.set_data([], [])
    frame_text.set_text('')
    return line, frame_text,

def animate(i, fermat):
    '''
    Advance the simulation by one step and refresh the plotted ray.

    i      -- frame index passed by FuncAnimation (not used here)
    fermat -- the Fermat instance being animated

    Uses the module-level ``line``, ``frame_text`` and ``frame_template``
    globals.  Returns the artists to redraw (needed for blitting).
    '''
    fermat.step()

    n_layers_here = fermat.N
    # One x-coordinate per ray point: 0, 1, ..., N
    x_points = np.linspace(0, n_layers_here, n_layers_here + 1)
    line.set_xdata(x_points)
    line.set_ydata(fermat.y)

    label = frame_template % fermat.steps
    frame_text.set_text(label)
    return (line, frame_text)

class Fermat:
    '''
    Random-search illustration of Fermat's principle: a light ray crossing
    media with different indices of refraction is nudged, one boundary
    crossing at a time, toward a path of shorter travel time.
    '''

    def __init__(self, N, dn):
        '''
        Sets up Fermat simulation -- light ray in media with different
        indices of refraction

        N = number of media layers
        dn = change in index of refraction from one region to the next

        Layers (boundaries) are set at x-locations: 1, 2, 3, ..., N-1
        Starting x-location for the ray is 0
        Ending x-location for the ray is N
        Starting y-location for the ray is y[0]
        Ending y-location for the ray is y[N]
        Location of ray at layer i is y[i]
        Speed of light before layer i is v[i-1]
        '''
        self.dn = dn
        self.N = N
        # y coordinate of light ray, index is x coordinate.
        # The ray starts as the straight line y = x.
        self.y = np.arange(N + 1, dtype=float)
        # light speed of ray for medium starting at index value
        self.v = np.empty(N)
        self.dy = 0.1            # max change in y at each step
        self.steps = 0

        index_of_refraction = 1  # index of refraction for vacuum
        for i in range(N):
            self.v[i] = 1. / index_of_refraction
            index_of_refraction += dn

    def time(self, i, yi):
        '''
        Compute the time the ray takes to go from boundary h = i-1 to
        boundary j = i+1 when it crosses boundary i at height yi.

        h --- i --- j : speed between h and i is v[h] and speed between i and j is v[i]

        Adjacent boundaries are one unit apart in x, so each leg has length
        sqrt(1 + (change in y)**2).  Valid for 1 <= i <= N-1; y[h] and y[j]
        are read from the current ray.
        '''
        h, j = i - 1, i + 1
        leg_in = np.hypot(1., yi - self.y[h])
        leg_out = np.hypot(1., self.y[j] - yi)
        return leg_in / self.v[h] + leg_out / self.v[i]

    def step(self):
        '''
        Try one random move of a single boundary crossing and keep it only
        if it shortens the ray's travel time.  Always increments self.steps.

        The endpoints y[0] and y[N] are never moved.  With N < 2 there is
        no interior boundary, so only the step counter changes.
        '''
        if self.N >= 2:
            # Select a boundary i between 1 and N-1 at random
            i = random.randint(1, self.N - 1)

            # Time from boundary [i-1] to [i+1] through the current y[i]
            t_old = self.time(i, self.y[i])

            # Pick a new y-location on boundary i, within +/- dy of the old one
            y_new = self.y[i] + random.uniform(-self.dy, self.dy)
            t_new = self.time(i, y_new)

            # Only the legs i-1 -> i and i -> i+1 depend on y[i], so this
            # local comparison decides whether the whole path got faster.
            if t_new < t_old:
                self.y[i] = y_new

        self.steps += 1

if __name__ == '__main__':
    random.seed(0)  # reproducible sequence of trial moves

    # Optional positional arguments: [n_layers [n_steps]]; each is parsed once.
    cli_args = sys.argv[1:]
    if len(cli_args) > 0:
        n_layers = int(cli_args[0])
    if len(cli_args) > 1:
        n_steps = int(cli_args[1])
    if n_layers < 1 or n_steps < 1:
        sys.exit('usage: %s [n_layers >= 1 [n_steps >= 1]]' % sys.argv[0])

    # Setup figure.  `line`, `frame_template` and `frame_text` are module-level
    # globals read by init() and animate().
    fig = plt.figure(1)
    ax = plt.axes(xlim=(-1, n_layers+1), ylim=(-1, n_layers+1))
    plt.grid()
    line, = ax.plot([], [], '-')
    frame_template = 'step = %d'
    frame_text = ax.text(0.05, 0.9, '', transform=ax.transAxes)
    plt.title('Fermat experiment')
    plt.xlabel('x')
    plt.ylabel('y')

    # Draw the boundaries between media: vertical lines at x = 1, ..., n_layers-1
    for x_boundary in range(1, n_layers):
        plt.plot([x_boundary, x_boundary], [0, n_layers], 'r-')

    fermat = Fermat(n_layers, 0.5)
    print('v', fermat.v)
    print('y', fermat.y)

    # One simulation step per frame; `anim` stays bound until plt.show() returns.
    anim = animation.FuncAnimation(fig, animate, fargs=(fermat,), init_func=init, frames=n_steps, interval=10, blit=True, repeat=False)
    plt.show()
