# Random walks - exercise
# Faisal Z. Qureshi
# faisal.qureshi@ontariotechu.ca

import numpy as np
import random
import matplotlib.pyplot as plt


class Location:
    """A point (or a displacement) in space, stored as a NumPy array.

    Locations are never modified in place: ``move_by`` returns a new object.
    """

    def __init__(self, loc):
        self.loc = np.array(loc)

    def move_by(self, l):
        """Return a new Location equal to this one shifted by ``l.loc``."""
        return Location(np.add(self.loc, l.loc))

    def dist_from(self, o):
        """Return the Euclidean distance between this Location and ``o``."""
        # norm() casts integer arrays to float before squaring, so large
        # integer coordinates cannot overflow as (int64)**2 could.
        return np.linalg.norm(self.loc - o.loc)

    def __str__(self):
        return str(self.loc)

    def __repr__(self):
        return f"{type(self).__name__}({self.loc.tolist()})"

class Field:
    """Tracks the current Location of every walker placed on it."""

    def __init__(self):
        # Maps each walker object to its current Location.
        self.walkers = {}

    # Subclasses that keep extra state must override this (and call
    # Field.reset) so the whole field is cleared between trials.
    def reset(self):
        """Remove every walker from the field."""
        self.walkers = {}

    def add(self, walker, loc):
        """Place ``walker`` at ``loc``.

        Raises:
            ValueError: if ``walker`` is already on the field.
        """
        if walker in self.walkers:
            raise ValueError('Walker already present')
        self.walkers[walker] = loc

    def move(self, walker):
        """Move ``walker`` by the delta returned from ``walker.step()``.

        Raises:
            ValueError: if ``walker`` is not on the field.
        """
        # Look the walker up before asking it for a step, so an unknown
        # walker fails without calling step() at all.
        try:
            current = self.walkers[walker]
        except KeyError:
            raise ValueError('Cannot find this walker') from None
        step = walker.step()  # a delta Location, not an absolute position
        self.walkers[walker] = current.move_by(step)

    def get_loc(self, walker):
        """Return the current Location of ``walker``.

        Raises:
            ValueError: if ``walker`` is not on the field.
        """
        try:
            return self.walkers[walker]
        except KeyError:
            raise ValueError('Cannot find this walker') from None

class Walker:
    """A walker that takes one unit step in a randomly chosen direction."""

    def __init__(self, name):
        self.name = name

    # Hook: subclasses override this to change how a step is picked.
    def select_a_step(self, step_choices):
        """Return one element of ``step_choices`` chosen uniformly at random."""
        chosen = random.choice(step_choices)
        return chosen

    # Hook: subclasses override this to change the set of possible steps.
    def step(self):
        """Return a delta Location: up, down, right or left (in that order of choices)."""
        deltas = ([0, 1], [0, -1], [1, 0], [-1, 0])
        candidates = [Location(delta) for delta in deltas]
        return self.select_a_step(candidates)

def walk(f, w, num_steps):
    """Move walker ``w`` on field ``f`` ``num_steps`` times.

    Returns:
        A tuple ``(distance, end)``. ``distance`` is the straight-line distance
        between the starting and final Locations. ``end`` is the final
        Location's ``.loc`` array.
    """
    origin = f.get_loc(w)
    for _ in range(num_steps):
        f.move(w)
    final = f.get_loc(w)
    return origin.dist_from(final), final.loc


# Seed the RNG so every run of the script produces the same walks.
random.seed(0)

# Quick checks of Location arithmetic.
a = Location([0, 1])
b = Location([2, 4])
print('a', a)
print('b', b)
print('distance between a and b', a.dist_from(b))
print('Move a by [3,4]', a.move_by(Location([3, 4])))

# Single-step smoke test: one walker on an otherwise empty field.
f = Field()
w = Walker('zombie')
f.add(w, Location([0, 0]))
print(walk(f, w, 1))


def simulate_walks(num_steps, num_trials, field, walker):
    """Run ``num_trials`` independent walks of ``num_steps`` steps each.

    Before every trial ``field`` is reset and ``walker`` is placed at the
    origin ``[0, 0]``.

    Returns:
        distances: list with one start-to-end distance per trial.
        ending_locations: 2-D array whose first row is the origin placeholder
            ``[0, 0]``, followed by one row per trial holding that trial's
            final location. Callers drop the placeholder with ``e[1:]``.
    """
    start = Location([0, 0])
    distances = []
    # Keep the historical leading origin row; callers slice it off.
    rows = [start.loc]
    for _ in range(num_trials):
        field.reset()
        field.add(walker, start)
        dist, end = walk(field, walker, num_steps)
        distances.append(dist)
        rows.append(end)
    # Stack once instead of once per trial (which re-copied the whole array
    # every iteration). This also always yields a 2-D array, even when
    # num_trials == 0.
    return distances, np.vstack(rows)


print('--> Walker')
d, e = simulate_walks(num_steps=100, num_trials=100,
                      field=Field(), walker=Walker('zombie'))
print('average distance', np.mean(d))

# Final positions; row 0 of e is the origin placeholder, hence e[1:].
plt.figure(figsize=(7, 7))
plt.scatter(e[1:, 0], e[1:, 1])
plt.axis('equal')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Ending locations without traps')

# How far each trial ended from its starting point.
plt.figure(figsize=(7, 7))
plt.hist(d)
plt.ylabel('# Trials')
plt.xlabel('Distances')
plt.title('Distances without traps')

# TODO: implement the traps (exercise).
class FieldWithTraps(Field):
    """Field that is presumably meant to contain traps (exercise stub).

    NOTE(review): ``w``, ``h`` and ``num_traps`` are accepted but currently
    ignored. Every method simply defers to ``Field``, so until the exercise is
    completed this behaves exactly like a plain ``Field``.
    """

    def __init__(self, w, h, num_traps):
        super().__init__()

    def reset(self):
        super().reset()

    def move(self, walker):
        super().move(walker)


print('--> Walker with traps')
d, e = simulate_walks(num_steps=100, num_trials=100,
                      field=FieldWithTraps(30, 30, 300), walker=Walker('zombie'))
print('average distance', np.mean(d))

# Final positions; row 0 of e is the origin placeholder, hence e[1:].
plt.figure(figsize=(7, 7))
plt.scatter(e[1:, 0], e[1:, 1])
plt.axis('equal')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Ending locations with traps')

# How far each trial ended from its starting point.
plt.figure(figsize=(7, 7))
plt.hist(d)
plt.ylabel('# Trials')
plt.xlabel('Distances')
plt.title('Distances with traps')


# TODO: give this walker its directional bias (exercise).
class WalkerUp(Walker):
    """Walker presumably meant to drift upwards (exercise stub).

    NOTE(review): ``step`` currently just defers to ``Walker.step``, so this
    walker is unbiased until the exercise is completed.
    """

    def __init__(self, name):
        super().__init__(name)

    def step(self):
        return super().step()

print('--> Walker up with no traps')
d, e = simulate_walks(num_steps=100, num_trials=100,
                      field=Field(), walker=WalkerUp('baloon'))
print('average distance', np.mean(d))

# Final positions; row 0 of e is the origin placeholder, hence e[1:].
plt.figure(figsize=(7, 7))
plt.scatter(e[1:, 0], e[1:, 1])
plt.axis('equal')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Ending locations for walker up and no traps')

# How far each trial ended from its starting point.
plt.figure(figsize=(7, 7))
plt.hist(d)
plt.ylabel('# Trials')
plt.xlabel('Distances')
plt.title('Distances for walker up and no traps')

# Display every figure created above.
plt.show()



