from __future__ import division
from __future__ import print_function
from builtins import str
from builtins import range
from past.utils import old_div
# Time seven different Python sorting algorithms on lists of integers, and
# count the primitive operations (comparisons, swaps, writes) made by the
# counted variants. Each function is given the same list (a fresh copy each time).

import random  # for generating random numbers
import time    # for timing each sort function with time.clock()
import matplotlib.pyplot as plt
from pycallgraph2 import PyCallGraph
from pycallgraph2.output import GraphvizOutput
from coinor.gimpy import BinaryTree
import cProfile, pstats

DEBUG = False  # set True to check results of each sort

# Most recent elapsed wall-clock time (seconds) of each @print_timing function,
# keyed by the wrapped function's __name__.
times = {}

def print_timing(func):
    """Decorator that times every call to *func*.

    The elapsed wall-clock time (seconds, measured with time.time()) is stored
    in the module-level ``times`` dict under ``func.__name__`` and printed in
    milliseconds. Positional and keyword arguments are forwarded unchanged.
    The undecorated function stays reachable as ``wrapper.func``.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*arg, **kwargs):
        t1 = time.time()
        res = func(*arg, **kwargs)
        t2 = time.time()
        times[func.__name__] = t2 - t1
        print('%s took %0.3fms' % (func.__name__, (t2 - t1) * 1000.0))
        return res
    wrapper.func = func
    return wrapper

# Number of calls made to each @counter-decorated primitive, keyed by the
# wrapped function's __name__. An entry is set to 0 when a function is
# decorated; experiment code resets the entries between runs.
counts = {}

def counter(func):
    """Decorator that counts calls to *func* in the module-level ``counts``.

    ``counts[func.__name__]`` is (re)initialised to 0 at decoration time and
    incremented on every call. Positional and keyword arguments are forwarded
    unchanged. The undecorated function is exposed as ``wrapper.func``.
    """
    from functools import wraps

    counts[func.__name__] = 0

    @wraps(func)
    def wrapper(*arg, **kwargs):
        counts[func.__name__] += 1
        return func(*arg, **kwargs)
    wrapper.func = func
    return wrapper

# Counted primitive operations: the *_count sort variants route their
# comparisons and list writes through these so they are tallied in ``counts``.

@counter
def compare(item1, item2):
    """Return item1 < item2 (counted)."""
    return item1 < item2

@counter
def compare_eq(item1, item2):
    """Return item1 <= item2 (counted)."""
    return item1 <= item2

@counter
def swap(aList, index1, index2):
    """Exchange aList[index1] and aList[index2] in place (counted)."""
    aList[index1], aList[index2] = aList[index2], aList[index1]

@counter
def assign(aList, index, value):
    """Store value at aList[index] (counted)."""
    aList[index] = value

@counter
def append(aList, value):
    """Append value to aList (counted)."""
    aList.append(value)

@counter
def shift_right(aList, index):
    """Copy aList[index] into aList[index + 1] (counted)."""
    aList[index + 1] = aList[index]
    
# Each sort function below is wrapped with the @print_timing decorator, so every call is timed and its elapsed time recorded in `times`.
@print_timing
def adaptive_merge_sort(list2):
    """Sort list2 in place with the built-in list.sort().

    That is Python's adaptive merge sort, built into Python since version 2.3.
    """
    sort_in_place = list2.sort
    sort_in_place()

@print_timing
def bubble_sort_count(list2):
    """Bubble sort list2 in place, counting comparisons and swaps.

    After pass i the i + 1 largest elements are in their final positions, so
    each pass scans one element fewer. A pass that makes no swaps means the
    list is sorted, and the loop stops early.
    """
    for i in range(0, len(list2) - 1):
        swap_test = False
        for j in range(0, len(list2) - i - 1):
            if compare(list2[j + 1], list2[j]):
                swap(list2, j, j + 1)
                # Only an actual swap means the list may still be unsorted.
                swap_test = True
        if not swap_test:
            break

@print_timing
def bubble_sort(list2):
    """Bubble sort list2 in place.

    Each pass carries the largest remaining element to the end of the
    unsorted prefix. The outer loop stops early once a pass makes no exchange.
    """
    n = len(list2)
    for done in range(n - 1):
        exchanged = False
        for pos in range(n - done - 1):
            left, right = list2[pos], list2[pos + 1]
            if left > right:
                list2[pos], list2[pos + 1] = right, left
                exchanged = True
        if not exchanged:
            break

# selection sort
@print_timing
def selection_sort_count(list2):
    """Selection sort list2 in place, counting comparisons and swaps.

    Position i receives the first smallest element of list2[i:]. The swap is
    performed (and counted) even when that element is already at position i.
    """
    size = len(list2)
    for slot in range(size):
        smallest = slot
        for probe in range(slot + 1, size):
            if compare(list2[probe], list2[smallest]):
                smallest = probe
        swap(list2, slot, smallest)

@print_timing
def selection_sort(list2):
    """Selection sort list2 in place.

    For each position, the first smallest element of the remaining suffix is
    found (min() keeps the earliest of equal keys) and swapped into place.
    """
    size = len(list2)
    for slot in range(size):
        smallest = min(range(slot, size), key=list2.__getitem__)
        list2[slot], list2[smallest] = list2[smallest], list2[slot]
      
# insertion sort
@print_timing
def insertion_sort_count(list2):
    """Insertion sort list2 in place, counting comparisons, shifts and writes."""
    for i in range(1, len(list2)):
        current = list2[i]
        hole = i
        # Slide larger predecessors one slot right until current's spot opens.
        while hole > 0 and compare(current, list2[hole - 1]):
            shift_right(list2, hole - 1)
            hole -= 1
        assign(list2, hole, current)

@print_timing
def insertion_sort(list2):
    """Insertion sort list2 in place.

    Each element is lifted out and the strictly larger elements before it are
    shifted right, so equal elements keep their relative order.
    """
    for i in range(1, len(list2)):
        current = list2[i]
        hole = i
        while hole > 0:
            prev = list2[hole - 1]
            if not prev > current:
                break
            list2[hole] = prev
            hole -= 1
        list2[hole] = current
  
# quick sort
@print_timing
def quick_sort_count(list2, G = None, display = False):
    """Quick-sort list2 in place, counting comparisons and swaps.

    If G is given (e.g. a coinor.gimpy BinaryTree), the recursion tree is
    recorded in it starting from node 1. With display=True the tree is redrawn
    after each node is added.
    """
    root_node = 1
    quick_sort_r_count(list2, 0, len(list2) - 1, root_node, G, display)

@print_timing
def quick_sort(list2):
    """Quick-sort list2 in place (median-of-three pivot, see partition)."""
    last_index = len(list2) - 1
    quick_sort_r(list2, 0, last_index)

# quick_sort_r, recursive (used by quick_sort)
def quick_sort_r_count(list2, first, last, nodeNum, G, display):
    """Recursively quick-sort list2[first:last + 1] in place, counting ops.

    Comparisons and swaps are counted through partition_count. When G is not
    None (e.g. a coinor.gimpy BinaryTree), one node is added per partitioning
    step using heap-style numbering: the root is 1, and node n has children
    2n (left sub-range) and 2n + 1 (right sub-range).

    Args:
        list2: list to sort (modified in place).
        first, last: inclusive bounds of the sub-range to sort.
        nodeNum: tree number for this call's node.
        G: tree to record the recursion in, or None.
        display: if True, call G.display() after each node is added.
    """
    if last > first:
        pivot = partition_count(list2, first, last)

        if G is not None:
            if nodeNum == 1:
                G.add_root(nodeNum)
            elif nodeNum % 2 == 0:
                # Even numbers are left children, odd numbers right children.
                # Placing by parity keeps the drawing consistent even when a
                # sibling sub-range is too small to get a node of its own.
                G.add_left_child(nodeNum, nodeNum // 2)
            else:
                G.add_right_child(nodeNum, nodeNum // 2)
            # Nodes can be labelled, e.g. by sub-range size:
            # G.get_node(nodeNum).set('label', str(last - first + 1))
            if display:
                G.display()

        quick_sort_r_count(list2, first, pivot - 1, 2 * nodeNum,
                           G, display)
        quick_sort_r_count(list2, pivot + 1, last, 2 * nodeNum + 1,
                           G, display)

def quick_sort_r(list2, first, last):
    """Recursively quick-sort list2[first:last + 1] in place using partition."""
    if first >= last:
        return  # zero or one element: nothing to do
    split = partition(list2, first, last)
    quick_sort_r(list2, first, split - 1)
    quick_sort_r(list2, split + 1, last)

# partition (used by quick_sort_r)
def partition_count(list2, first, last):
    """Partition list2[first:last + 1] in place; return the pivot's final index.

    Counted twin of partition(): every comparison goes through compare /
    compare_eq and every exchange through swap, so they are tallied in
    ``counts``.

    The pivot is the median of list2[first], list2[sred] and list2[last]
    (sred = midpoint). On return, with p the returned index,
    list2[first:p] <= list2[p] < list2[p + 1:last + 1]; elements equal to
    the pivot end up on the left.
    """
    sred = old_div((first + last),2)
    # Median-of-three: afterwards list2[first] <= list2[sred] <= list2[last].
    if compare(list2[sred], list2 [first]):
        swap(list2, first, sred)
    if compare(list2[last], list2 [first]):
        swap(list2, first, last)
    if compare(list2[last], list2[sred]):
        swap(list2, sred, last)
    # Move the median to the front and use it as the pivot.
    swap(list2, sred, first)
    pivot = first
    i = first + 1
    j = last
  
    while True:
        # Advance i past elements <= pivot; the bound keeps it inside the range.
        while i <= last and compare_eq(list2[i], list2[pivot]):
            i += 1
        # Retreat j past elements > pivot; it stops at `first` at the latest,
        # since list2[first] holds the pivot itself.
        while j >= first and compare(list2[pivot], list2[j]):
            j -= 1
        if i >= j:
            break
        else:
            # list2[i] > pivot and list2[j] <= pivot are on the wrong sides.
            swap(list2, i, j)
    # j is the last slot holding a value <= pivot; put the pivot there.
    swap(list2, j, pivot)
    return j

def partition(list2, first, last):
    """Partition list2[first:last + 1] in place; return the pivot's final index.

    The pivot is the median of list2[first], list2[sred] and list2[last]
    (sred = midpoint). On return, with p the returned index,
    list2[first:p] <= list2[p] < list2[p + 1:last + 1]; elements equal to
    the pivot end up on the left.
    """
    sred = old_div((first + last),2)
    # Median-of-three: afterwards list2[first] <= list2[sred] <= list2[last].
    if list2[first] > list2 [sred]:
        list2[first], list2[sred] = list2[sred], list2[first]  # swap
    if list2[first] > list2 [last]:
        list2[first], list2[last] = list2[last], list2[first]  # swap
    if list2[sred] > list2[last]:
        list2[sred], list2[last] = list2[last], list2[sred]    # swap
    # Move the median to the front and use it as the pivot.
    list2 [sred], list2 [first] = list2[first], list2[sred]    # swap
    pivot = first
    i = first + 1
    j = last
  
    while True:
        # Advance i past elements <= pivot; the bound keeps it inside the range.
        while i <= last and list2[i] <= list2[pivot]:
            i += 1
        # Retreat j past elements > pivot; it stops at `first` at the latest,
        # since list2[first] holds the pivot itself.
        while j >= first and list2[j] > list2[pivot]:
            j -= 1
        if i >= j:
            break
        else:
            # list2[i] > pivot and list2[j] <= pivot are on the wrong sides.
            list2[i], list2[j] = list2[j], list2[i]  # swap
    # j is the last slot holding a value <= pivot; put the pivot there.
    list2[j], list2[pivot] = list2[pivot], list2[j]  # swap
    return j

# heap sort
@print_timing
def heap_sort_count(list2):
    """Heap sort list2 in place, counting comparisons and swaps.

    The whole list is first arranged into a max-heap rooted at index 0.
    Then the loop repeatedly swaps the root (the current maximum) to the end
    of the unsorted prefix and restores the heap property on what remains.
    The root-to-end exchange goes through the counted swap(), so it appears
    in ``counts`` like every other swap.
    """
    first = 0
    last = len(list2) - 1
    create_heap_count(list2, first, last)
    for i in range(last, first, -1):
        swap(list2, i, first)
        establish_heap_property_count(list2, first, i - 1)

@print_timing
def heap_sort(list2):
    first = 0
    last = len(list2) - 1
    create_heap(list2, first, last)
    for i in range(last, first, -1):
        list2[i], list2[first] = list2[first], list2[i]  # swap
        establish_heap_property (list2, first, i - 1)

# create heap (used by heap_sort)
def create_heap_count(list2, first, last):
    """Build a max-heap in list2 up to index last, counting operations.

    Every node from index last // 2 down to first is sifted down with
    establish_heap_property_count.
    """
    for node in range(last // 2, first - 1, -1):
        establish_heap_property_count(list2, node, last)

def create_heap(list2, first, last):
    """Build a max-heap in list2 up to index last.

    Every node from index last // 2 down to first is sifted down with
    establish_heap_property.
    """
    for node in range(last // 2, first - 1, -1):
        establish_heap_property(list2, node, last)

# establish heap property (used by create_heap and heap_sort)
def establish_heap_property_count(list2, first, last):
    """Sift list2[first] down within list2[..last] (counted max-heap repair).

    Children of index p are 2p + 1 and 2p + 2. Comparisons go through
    compare / compare_eq and exchanges through swap.
    """
    parent = first
    child = 2 * parent + 1
    while child <= last:
        right = child + 1
        # Pick the larger child; the left one wins ties.
        if right <= last and compare(list2[child], list2[right]):
            child = right
        if compare_eq(list2[child], list2[parent]):
            return
        swap(list2, parent, child)
        parent = child
        child = 2 * parent + 1

def establish_heap_property(list2, first, last):
    """Sift list2[first] down within list2[..last] to restore a max-heap.

    Children of index p are 2p + 1 and 2p + 2. The element is exchanged with
    its larger child (the left one on ties) until no child within the range
    is larger.
    """
    parent = first
    child = 2 * parent + 1
    while child <= last:
        right = child + 1
        if right <= last and list2[child] < list2[right]:
            child = right
        if list2[parent] >= list2[child]:
            return
        list2[parent], list2[child] = list2[child], list2[parent]
        parent = child
        child = 2 * parent + 1

# merge sort
@print_timing
def merge_sort_count(list2):
    """Top-down merge sort of list2 in place, counting comparisons and writes."""
    final = len(list2) - 1
    merge_sort_r_count(list2, 0, final)

@print_timing
def merge_sort(list2):
    """Top-down merge sort of list2 in place (see merge_sort_r / merge)."""
    final = len(list2) - 1
    merge_sort_r(list2, 0, final)

# merge sort recursive (used by merge_sort)
def merge_sort_r_count(list2, first, last):
    """Recursively merge-sort list2[first:last + 1] in place, counting ops."""
    if first >= last:
        return
    mid = (first + last) // 2
    merge_sort_r_count(list2, first, mid)
    merge_sort_r_count(list2, mid + 1, last)
    merge_count(list2, first, last, mid)

def merge_sort_r(list2, first, last):
    """Recursively merge-sort list2[first:last + 1] in place."""
    if first >= last:
        return
    mid = (first + last) // 2
    merge_sort_r(list2, first, mid)
    merge_sort_r(list2, mid + 1, last)
    merge(list2, first, last, mid)

# merge (used by merge_sort_r)
def merge_count(list2, first, last, sred):
    """Merge the sorted runs list2[first:sred + 1] and list2[sred + 1:last + 1].

    The result is written back in place. Every comparison (compare_eq),
    buffer append (append) and write-back (assign) is counted. On ties the
    element from the left run is taken first, so the merge is stable.
    """
    buffer = []
    left, right = first, sred + 1
    while left <= sred and right <= last:
        if compare_eq(list2[left], list2[right]):
            append(buffer, list2[left])
            left += 1
        else:
            append(buffer, list2[right])
            right += 1
    # At most one of the two runs still has elements left.
    for idx in range(left, sred + 1):
        append(buffer, list2[idx])
    for idx in range(right, last + 1):
        append(buffer, list2[idx])
    for offset, value in enumerate(buffer):
        assign(list2, first + offset, value)

def merge(list2, first, last, sred):
    """Merge the sorted runs list2[first:sred + 1] and list2[sred + 1:last + 1].

    The result is written back in place. On ties the element from the left
    run is taken first, so the merge is stable.
    """
    merged = []
    left, right = first, sred + 1
    while left <= sred and right <= last:
        if list2[left] <= list2[right]:
            merged.append(list2[left])
            left += 1
        else:
            merged.append(list2[right])
            right += 1
    # At most one of these tails is non-empty.
    merged.extend(list2[left:sred + 1])
    merged.extend(list2[right:last + 1])
    list2[first:last + 1] = merged

def fsort(A, beg = None, end = None):
    """Sort A[beg:end + 1] in place with a simple (very slow) recursive sort.

    The two end elements are ordered and the interior is sorted recursively.
    The smaller of A[beg] and A[beg + 1] is then the range minimum; it is
    moved to A[beg], and A[beg + 1:end + 1] is sorted recursively. Each call
    recurses on ranges one and two elements shorter, so the running time
    grows exponentially with the range length. Use it only on tiny inputs.

    Args:
        A: mutable sequence to sort in place.
        beg: first index of the range (default 0).
        end: last index of the range, inclusive (default len(A) - 1).
    """
    if beg is None:
        beg = 0
    if end is None:
        end = len(A) - 1
    if beg >= end:
        return
    if A[beg] > A[end]:
        A[beg], A[end] = A[end], A[beg]
    fsort(A, beg + 1, end - 1)
    # A[beg + 1] is now the interior minimum; keep the smaller one at A[beg].
    if A[beg] > A[beg + 1]:
        A[beg], A[beg + 1] = A[beg + 1], A[beg]
    fsort(A, beg + 1, end)

# test sorted list by printing the first 10 elements
def print10(list2, count=10):
    """Print the first *count* (default 10) elements of list2 on one line.

    Intended as a quick spot-check of a sorted list. Each element is followed
    by a single space, then a newline ends the line. Lists shorter than
    *count* are printed in full instead of raising IndexError.
    """
    for item in list2[:count]:
        print(item, end=' ')
    print()


# run test if script is executed
if __name__ == "__main__" :

    # --- One-off profiling of the counted sorts on a small random list ---
    aList = [random.randint(0, 100) for j in range(1000)]

    # To draw the quick-sort recursion tree with coinor.gimpy instead:
    # G = BinaryTree()
    # G.set_display_mode('matplotlib')
    # quick_sort_count(list(aList), G = G)
    # G.display()
    # G.print_nodes(order = 'post')

    # Both runs sort in place, so each gets its own copy of aList. Otherwise
    # quick_sort_count would be graphed on input that insertion sort had
    # already sorted.
    cProfile.run('insertion_sort_count(list(aList))', 'cprof.out')
    p = pstats.Stats('cprof.out')
    p.sort_stats('cumulative').print_stats(10)
    with PyCallGraph(output=GraphvizOutput()):
        quick_sort_count(list(aList))

    # --- Growth experiments: sizes list_increment, 2 * list_increment, ... ---
    list_increment = 100
    num_experiments = 30

    sizes = [list_increment * (i + 1) for i in range(num_experiments)]
    # A list of each size, with values drawn from [0, size - 1].
    lists = [[random.randint(0, size - 1) for _ in range(size)]
             for size in sizes]

    algos = [bubble_sort, selection_sort, insertion_sort,
             quick_sort, heap_sort, merge_sort]
    algos_count = [bubble_sort_count, selection_sort_count,
                   insertion_sort_count, quick_sort_count,
                   heap_sort_count, merge_sort_count]

    # Timing: run every uncounted sort on a fresh copy of each list.
    times_growth = {}
    for a in algos:
        print(a.func.__name__)
        times_growth[a.func.__name__] = []

    for i in range(num_experiments):
        expected = sorted(lists[i]) if DEBUG else None
        for a in algos:
            name = a.func.__name__
            list_copy = list(lists[i])
            a(list_copy)
            times_growth[name].append(times[name])
            times[name] = 0
            if DEBUG and list_copy != expected:
                print(name, 'did not sort correctly')

    counts_growth = {}
    for a in algos_count:
        print(a.func.__name__)
        counts_growth[a.func.__name__] = []

    plt.figure()
    plt.xlabel('Input n')
    plt.ylabel('Computation time')
    for a in algos:
        plt.plot(sizes, times_growth[a.func.__name__], label=a.func.__name__)
        print('Running times for', a.func.__name__)
        print(times_growth[a.func.__name__])
    plt.legend(bbox_to_anchor=(0.5, 1))
    plt.show()

    # Operation counting. First discard anything the profiling runs above
    # tallied, so it is not attributed to the first counted experiment.
    for op in counts:
        counts[op] = 0

    for i in range(num_experiments):
        expected = sorted(lists[i]) if DEBUG else None
        for a in algos_count:
            name = a.func.__name__
            list_copy = list(lists[i])
            a(list_copy)
            total_ops = sum(counts.values())
            for op in counts:
                if counts[op] > 0:
                    print(op, 'was called', counts[op], 'times')
                counts[op] = 0
            print('Total number of operations was', total_ops)
            counts_growth[name].append(total_ops)
            if DEBUG and list_copy != expected:
                print(name, 'did not sort correctly')

    # Use a new figure so the count curves are not drawn over the timing plot
    # if plt.show() returned without closing it.
    plt.figure()
    plt.xlabel('Input n')
    plt.ylabel('Number of operations')
    for a in algos_count:
        plt.plot(sizes, counts_growth[a.func.__name__], label=a.func.__name__)
        print('Operation counts for', a.func.__name__)
        print(counts_growth[a.func.__name__])
    plt.legend(bbox_to_anchor=(0.5, 1))
    plt.show()
