import math
from scipy.stats import norm
import random

# Seed Python's global `random` generator so runs are reproducible and can be
# compared across model variants ("cross-model testing"). Every stochastic step
# in this file draws from this generator: the vowel choice, the formant noise,
# the long/short decision, and the Bernoulli sampling while dreaming.
random.seed(0)

def hz_to_erb(formant: int):
    """
    Convert a frequency in Hz to the ERB-rate scale.

    Uses the same formula as Praat:
        ERB(f) = 11.17 * ln((f + 312) / (f + 14680)) + 43
    """
    ratio = (formant + 312) / (formant + 14680)
    return 11.17 * math.log(ratio) + 43

def gaussian_activation(input_ERB: float, membrane_location: float, max_height: float = 1, standard_dev: float = 0.68):
    """
    Return the height of a Gaussian bump at a point on the basilar membrane.

    The bump has peak `max_height` (default 1) and standard deviation
    `standard_dev` (default 0.68 ERB, after Boersma, Chládková & Benders 2022).
    It is centred on `input_ERB` and evaluated at `membrane_location`.
    The result is exp(-d^2 / 2) * max_height, where d is the distance in SDs.
    It is computed directly rather than by rescaling a normal pdf: this is much
    cheaper per scalar call and makes the peak exactly `max_height`.

    Raises:
        ValueError: if standard_dev is not positive.
    """
    if standard_dev <= 0:
        raise ValueError(f"standard_dev must be positive, got {standard_dev}")

    distance = (membrane_location - input_ERB) / standard_dev
    return max_height * math.exp(-0.5 * distance * distance)

def logistic_function(x:float):
    """
    Return the logistic (sigmoid) function 1 / (1 + e^-x), computed stably.

    For x <= 0 the equivalent form e^x / (1 + e^x) is used. That branch never
    overflows (e^x <= 1), and unlike 1 - 1/(1 + e^x) it does not cancel
    catastrophically: very negative inputs give tiny positive values rather
    than exactly 0.0.
    """
    if x <= 0:
        exp_x = math.exp(x)
        return exp_x / (1 + exp_x)

    return 1 / (1 + math.exp(-x))


def cosine_similarity(a:list, b:list):
    """
    Compute the cosine similarity between two equally long vectors.

    The cosine similarity is the dot product of A and B divided by the product
    of their Euclidean norms. It ranges from -1 (opposite) through 0
    (orthogonal) to 1 (same direction).

    Raises:
        ValueError: if the vectors differ in length (previously such input was
            silently truncated or failed with an IndexError).
        ZeroDivisionError: if either vector has zero magnitude (including
            empty vectors), for which the similarity is undefined.
    """
    if len(a) != len(b):
        raise ValueError(f"vectors must have the same length, got {len(a)} and {len(b)}")

    numerator = sum(x * y for x, y in zip(a, b))
    a_norm = math.sqrt(sum(x * x for x in a))
    b_norm = math.sqrt(sum(y * y for y in b))

    denominator = a_norm * b_norm
    if denominator == 0:
        raise ZeroDivisionError("cosine similarity is undefined for a zero-magnitude vector")

    return numerator / denominator

class semanticNode:
    """A unit of the semantic layer: a learnable bias plus its current excitation."""

    def __init__(self):
        # Both start at zero; the model clamps or dreams the excitation and
        # nudges the bias during (anti-)Hebbian learning.
        self.bias, self.excitation = 0, 0

class inputNode():
    """
    A unit of the auditory input layer, tuned to one position (in ERB) on the
    basilar membrane.
    """

    def __init__ (self, erb: float):
        # ERB position of this node; the model builds these as x/2, i.e. floats.
        self.erb = erb
        self.bias = 0
        self.excitation = 0

    def set_clamped_excitation(self, formant1: float, formant2: float):
        """
        Clamp this node to the summed Gaussian bumps of two formants.

        `formant1` and `formant2` are formant positions in ERB, not Hz.
        The summed bump height is rescaled by *5 - 1. With gaussian_activation's
        default peak of 1, a node far from both formants therefore gets about -1,
        and a node sitting on one formant gets about 4.
        """
        bump_sum = gaussian_activation(self.erb, formant1) + gaussian_activation(self.erb, formant2)
        self.excitation = bump_sum * 5 - 1

class middleNode:
    """
    A unit of the middle (hidden) layer.

    It holds one weight per input node, arranged as `weights[frame][channel]`,
    plus one weight per semantic node. The defaults match the model's 17 x 49
    input layer and 7 semantic nodes.
    """

    def __init__(self, n_frames: int = 17, n_channels: int = 49, n_semantic: int = 7):
        self.bias = 0
        self.excitation = 0
        # Build each row separately so rows never alias one another.
        self.weights = [[0] * n_channels for _ in range(n_frames)]
        self.semantic_weights = [0] * n_semantic

class topNode:
    """
    A unit of the top layer, holding one weight per middle-layer node
    (50 by default, matching the model's middle layer).
    """

    def __init__(self, n_middle: int = 50):
        self.bias = 0
        self.excitation = 0
        self.weights = [0] * n_middle

class model:
    """
    The network trained in Boersma, Chládková & Benders (2022).

    The bottom level consists of an auditory input layer (17 frames of 49 ERB
    nodes) and a 7-node semantic layer. Above it sit a 50-node middle layer and
    a 20-node top layer. Training alternates two phases:
    - a clamped phase: input settles upward, followed by Hebbian learning;
    - a free-running "dreaming" phase: stochastic reconstruction, followed by
      anti-Hebbian learning.
    """

    # (quality, F1 in Hz, F2 in Hz), averages from Šimáčková, Podlipský &
    # Chládková (2012). Keep this order: random.choice draws from it, so the
    # order decides which vowel a seeded run picks.
    _VOWEL_FORMANTS = (
        ("A", 800, 1500),
        ("E", 650, 1800),
        ("I", 300, 2500),
        ("O", 500, 900),
        ("U", 300, 500),
    )

    # Semantic-layer clamp values. Upper case = long vowel, lower case = short.
    # First 2 stand for number, next 2 stand for case, last 3 stand for
    # semantic meaning.
    _SEMANTIC_CLAMPS = {
        "A": (-2, 3.5, -2, 3.5, -2, -2, 3.5),
        "a": (-2, 3.5, 3.5, -2, -2, -2, 3.5),
        "E": (3.5, -2, -2, 3.5, 3.5, -2, -2),
        "e": (3.5, -2, 3.5, -2, 3.5, -2, -2),
        "I": (3.5, -2, -2, 3.5, -2, 3.5, -2),
        "i": (3.5, -2, 3.5, -2, -2, 3.5, -2),
        "O": (-2, 3.5, -2, 3.5, 3.5, -2, -2),
        "o": (-2, 3.5, 3.5, -2, 3.5, -2, -2),
        "U": (-2, 3.5, -2, 3.5, -2, 3.5, -2),
        "u": (-2, 3.5, 3.5, -2, -2, 3.5, -2),
    }

    # Learning rate eta, following Boersma, Chládková & Benders (2022).
    _ETA = 0.001

    def __init__(self):
        # 17 time frames, each with 49 input nodes at ERB 4.0, 4.5, ..., 28.0
        self.input_layer = [[inputNode(x / 2) for x in range(8, 57)] for _ in range(17)]

        self.semantic_layer = [semanticNode() for _ in range(7)]

        self.middle_layer = [middleNode() for _ in range(50)]

        self.top_layer = [topNode() for _ in range(20)]

    def run(self, vowels:int):
        """
        Train the network on `vowels` randomly drawn tokens.

        This method is the main way to call on this class. It implements the
        learning algorithm described in Boersma, Chládková & Benders (2022).
        Each token goes through four steps: clamp the input, settle, Hebbian
        update, dream, anti-Hebbian update.
        """
        for datum in range(vowels):
            # progress indicator, so a long run can be seen to be alive
            print(datum)

            self.determine_vowel()

            for _ in range(10):
                self.initial_settling()

            self.hebbian_learning()

            for _ in range(10):
                self.dreaming()

            self.antihebbian_learning()

    def obtain_results(self):
        """
        Record the middle-layer response to each standard utterance and compare the responses.

        Each of the five standard vowels is presented twice, once long and once
        short. For each presentation the method clamps the input, settles, and
        records the excitations of the MIDDLE layer. It then prints the cosine
        similarity between every pair of recorded patterns: 10 rows, ordered
        A, a, E, e, I, i, O, o, U, u, each followed by a blank line.
        """
        output_dict = {}

        for quality, f1, f2 in self._VOWEL_FORMANTS:
            for is_long in (True, False):
                self.set_input_manually(quality, f1, f2, is_long)

                for _ in range(10):
                    self.initial_settling()

                key = quality if is_long else quality.lower()
                output_dict[key] = [node.excitation for node in self.middle_layer]

        for list1 in output_dict.values():
            for list2 in output_dict.values():
                print(cosine_similarity(list1, list2))
            print("")

    @staticmethod
    def _vowel_length(quality, long_vowel):
        """Return how many of the 17 input frames a vowel of this quality/length occupies."""
        # This just happens to be a fact about Czech: /iː/ is only 1.3 times as long.
        if not long_vowel:
            return 10
        return 13 if quality == "I" else 17

    def _clamp_input(self, length):
        """Clamp the first `length` frames to the current formants and silence the rest."""
        for frame in self.input_layer[:length]:
            for node in frame:
                node.set_clamped_excitation(self.f1_erb, self.f2_erb)

        for frame in self.input_layer[length:]:
            for node in frame:
                node.excitation = 0

    def _clamp_semantics(self):
        """Lower-case `self.quality` for short vowels, then clamp the semantic layer to match."""
        if not self.long_vowel:
            self.quality = self.quality.lower()

        for node, value in zip(self.semantic_layer, self._SEMANTIC_CLAMPS[self.quality]):
            node.excitation = value

    def determine_vowel(self):
        """
        Draw a random training token and clamp it onto the input and semantic layers.

        Formant averages come from Šimáčková, Podlipský & Chládková (2012).
        Each formant is converted to ERB and jittered with Gaussian noise
        (SD 1 ERB). The vowel is made long with probability 1/2.
        """
        quality, f1, f2 = random.choice(self._VOWEL_FORMANTS)
        self.quality = quality

        # random.gauss(mu, sigma)
        self.f1_erb = random.gauss(hz_to_erb(f1), 1)
        self.f2_erb = random.gauss(hz_to_erb(f2), 1)

        # random.random() is in [0, 1), so this is a fair coin flip
        self.long_vowel = random.random() >= 0.5

        self._clamp_input(self._vowel_length(self.quality, self.long_vowel))
        self._clamp_semantics()

    def set_input_manually(self, quality:str, f1:int, f2:int, input_length:bool):
        """
        Clamp a standard (noise-free) utterance, for testing purposes.

        Quality should be one of "A", "E", "I", "O" and "U".
        F1 and F2 are the formants in Hz. input_length = True means a long
        vowel, False means a short vowel.
        """
        self.quality = quality
        self.f1_erb = hz_to_erb(f1)
        self.f2_erb = hz_to_erb(f2)
        self.long_vowel = input_length

        self._clamp_input(self._vowel_length(quality, input_length))
        self._clamp_semantics()

    def _middle_net_input(self, index):
        """Return bias + input, semantic and top-down contributions for middle node `index`."""
        node = self.middle_layer[index]
        total = node.bias

        # from below: the auditory input layer ...
        for weight_row, frame in zip(node.weights, self.input_layer):
            for weight, input_node in zip(weight_row, frame):
                total += weight * input_node.excitation

        # ... and the semantic layer (added exactly once)
        for weight, semantic_node in zip(node.semantic_weights, self.semantic_layer):
            total += weight * semantic_node.excitation

        # from above
        for upper_node in self.top_layer:
            total += upper_node.excitation * upper_node.weights[index]

        return total

    def _top_net_input(self, upper_node):
        """Return bias + bottom-up contribution of the middle layer for one top node."""
        total = upper_node.bias
        for middle_node, weight in zip(self.middle_layer, upper_node.weights):
            total += middle_node.excitation * weight
        return total

    def _dump_middle_node_inputs(self, index):
        """Print every incoming weight/activation pair of middle node `index` (debug aid)."""
        node = self.middle_layer[index]

        for layer, weight_row in enumerate(node.weights):
            for position, weight in enumerate(weight_row):
                print(layer, position, weight, self.input_layer[layer][position].excitation)

        for semantic_index, weight in enumerate(node.semantic_weights):
            print(semantic_index, weight, self.semantic_layer[semantic_index].excitation)

        for upper_node in self.top_layer:
            print(upper_node, upper_node.excitation, upper_node.weights[index])

    def initial_settling(self, test:bool = False):
        """
        Run the initial settling phase from Boersma, Chládková & Benders (2022).

        First, activity spreads from both outer layers to the middle layer. Then
        it spreads up from the middle layer to the top layer. Each new activity
        is the logistic function of (bias + sum of incoming weight * activation).

        Both layers are reset to 0 at the start of every call, and the middle
        layer is computed before the top layer. The top-down term is therefore
        always 0 here, and repeated calls give identical results.
        NOTE(review): if iterative settling was intended, reset once before the
        settling loop instead of on every call.

        `test` is unused (kept for backward compatibility; formerly a debug switch).

        Raises:
            Exception: if any net input becomes inf or nan (weights diverged).
        """
        for node in self.middle_layer:
            node.excitation = 0

        for upper_node in self.top_layer:
            upper_node.excitation = 0

        # update the middle layer first, spreading from both directions
        for index, node in enumerate(self.middle_layer):
            excitation_sum = self._middle_net_input(index)

            if not math.isfinite(excitation_sum):
                self._dump_middle_node_inputs(index)
                raise Exception(f"It went wrong at the middle layer {index}")

            node.excitation = logistic_function(excitation_sum)

        # then the top layer, spreading up from the middle
        for upper_node in self.top_layer:
            excitation_sum = self._top_net_input(upper_node)

            if not math.isfinite(excitation_sum):
                raise Exception("It went wrong at the upper layer")

            upper_node.excitation = logistic_function(excitation_sum)

    def _apply_hebbian_update(self, rate):
        """
        Add rate * (pre-synaptic * post-synaptic activity) to every weight, and
        rate * activity to every bias. Use a positive rate for Hebbian learning
        and a negative rate for anti-Hebbian learning.
        """
        for frame in self.input_layer:
            for input_node in frame:
                input_node.bias += rate * input_node.excitation

        for semantic_node in self.semantic_layer:
            semantic_node.bias += rate * semantic_node.excitation

        for node in self.middle_layer:
            node.bias += rate * node.excitation
            scaled = rate * node.excitation

            # weights between input and middle layer
            for weight_row, frame in zip(node.weights, self.input_layer):
                for position, input_node in enumerate(frame):
                    weight_row[position] += scaled * input_node.excitation

            # weights between semantic and middle layer
            for position, semantic_node in enumerate(self.semantic_layer):
                node.semantic_weights[position] += scaled * semantic_node.excitation

        for upper_node in self.top_layer:
            upper_node.bias += rate * upper_node.excitation
            scaled = rate * upper_node.excitation

            # weights between middle and top layer
            for position, middle_node in enumerate(self.middle_layer):
                upper_node.weights[position] += scaled * middle_node.excitation

    def hebbian_learning(self, test:bool = False):
        """
        Hebbian step after the clamped phase: each weight and bias grows by
        eta * the relevant activity product. eta = 0.001 follows Boersma,
        Chládková & Benders (2022).

        `test` is unused (kept for backward compatibility; formerly a debug switch).
        """
        self._apply_hebbian_update(self._ETA)

    def antihebbian_learning(self):
        """
        Anti-Hebbian step after dreaming: the same update as hebbian_learning,
        but with -eta instead of +eta.
        """
        self._apply_hebbian_update(-self._ETA)

    def dreaming(self):
        """
        One dreaming step as described in Boersma, Chládková & Benders (2022).

        1. Activity spreads down from the middle layer to the input and
           semantic layers. This is linear: bias + weighted sum, no logistic.
        2. Each top node is resampled as Bernoulli(logistic(net input from the middle)).
        3. Each middle node is resampled as Bernoulli(logistic(net input)),
           using the new input, semantic and top activities.
        """
        # spread activity down to the input layer
        for layer, frame in enumerate(self.input_layer):
            for position, input_node in enumerate(frame):
                excitation_sum = input_node.bias
                for middle_node in self.middle_layer:
                    excitation_sum += middle_node.weights[layer][position] * middle_node.excitation
                input_node.excitation = excitation_sum

        # spread activity down to the semantic layer as well
        for semantic_index, semantic_node in enumerate(self.semantic_layer):
            excitation_sum = semantic_node.bias
            for middle_node in self.middle_layer:
                excitation_sum += middle_node.semantic_weights[semantic_index] * middle_node.excitation
            semantic_node.excitation = excitation_sum

        # Bernoulli resampling of the top layer from the middle layer
        for upper_node in self.top_layer:
            bernoulli_chance = logistic_function(self._top_net_input(upper_node))
            upper_node.excitation = 0 if random.random() >= bernoulli_chance else 1

        # Bernoulli resampling of the middle layer from both directions
        for index, node in enumerate(self.middle_layer):
            bernoulli_chance = logistic_function(self._middle_net_input(index))
            node.excitation = 0 if random.random() >= bernoulli_chance else 1

# Train on 3000 randomly drawn vowel tokens, then print the pairwise cosine
# similarities of the middle-layer responses to the ten standard vowels.
# NOTE(review): this rebinds the name `model` from the class to an instance,
# and the training also runs whenever this file is imported; consider an
# `if __name__ == "__main__":` guard.
model = model()
model.run(vowels=3000)
model.obtain_results()
