From 4d6538af00fed8ad1bed4e331726496a2f105e40 Mon Sep 17 00:00:00 2001 From: HombreLaser Date: Sat, 21 Oct 2023 01:07:38 -0600 Subject: Add more methods --- main.py | 5 ++++- neural_network.py | 65 ++++++++++++++++++++++++++++++++++++++++--------------- 2 files changed, 51 insertions(+), 19 deletions(-) diff --git a/main.py b/main.py index a44fbcb..f4b1ede 100644 --- a/main.py +++ b/main.py @@ -9,8 +9,11 @@ def main(): data = Dataset() image = data.get_image('А/5a2f3c19c27bb.png') neural_network = NeuralNetwork(LEARNING_RATE, INPUT_RESOLUTION) + neural_network.load('hidden_weights_3e1064eab32018b3.csv', + 'output_weights_2406f3eb22111fe9.csv') + # Hey! - print(neural_network.guess(image)) + print('This is it!') if __name__ == "__main__": diff --git a/neural_network.py b/neural_network.py index 1408a91..6c44ddb 100644 --- a/neural_network.py +++ b/neural_network.py @@ -1,7 +1,7 @@ import numpy as np import math from scipy.special import expit -from dataset import Dataset +from secrets import token_hex CYRILLIC_ALPHABET = ['I', 'А', 'Б', 'В', 'Г', 'Д', 'Е', 'Ë', 'Ж', 'З', 'И', 'Й', 'К', 'Л', 'М', 'Н', 'О', 'П', 'Р', 'С', @@ -39,21 +39,40 @@ class NeuralNetwork: input_image (a path). """ def guess(self, input_image: np.array) -> str: - output_layer = self._feedforward(input_image) + output_layer = self.feedforward(input_image) - return self.guessed_char(output_layer) + return self._guessed_char(output_layer) + + """ + Feedforwarding. + """ + def feedforward(self, input_layer: np.array): + hidden_layer_inputs = np.dot(self._hidden_weights, input_layer) + hidden_layer_outputs = self._get_layer_output(hidden_layer_inputs) + output_layer_inputs = np.dot(hidden_layer_outputs, + self._output_weights) + + # The output layer outputs. (Final output of the neural network). + return self._get_layer_output(output_layer_inputs) """ Save the weights to a csv file. """ - def save(self, weights_filename): - pass + def save(self): + np.savetxt(f"./hidden_weights_{token_hex(8)}.csv", + self._hidden_weights, delimiter=',') + np.savetxt(f"./output_weights_{token_hex(8)}.csv", + self._output_weights, delimiter=',') """ Load the weights from a csv file. """ - def load(self, weights_file: str): - pass + def load(self, hidden_weights_file: str, output_weights_file: str): + with open(hidden_weights_file) as hidden_weights: + self._hidden_weights = np.loadtxt(hidden_weights, delimiter=',') + + with open(output_weights_file) as output_weights: + self._output_weights = np.loadtxt(output_weights, delimiter=',') """ Get the result from a sigmoid matrix (the index with the highest chance @@ -63,22 +82,32 @@ class NeuralNetwork: return CYRILLIC_ALPHABET[np.argmax(np.transpose(output_layer))] """ - Feedforwarding. + Apply the sigmoid function to a given layer """ - def _feedforward(self, input_layer: np.array): - hidden_layer_inputs = np.dot(self._hidden_weights, input_layer) - hidden_layer_outputs = self._get_layer_output(hidden_layer_inputs) - output_layer_inputs = np.dot(hidden_layer_outputs, - self._output_weights) + def _get_layer_output(self, layer: np.array) -> np.array: + return expit(layer) - # The output layer outputs. (Final output of the neural network). - return self._get_layer_output(output_layer_inputs) + """ + Get the hidden layer and output layer error matrices. + """ + def _get_errors(self, target: str) -> tuple: + output_layer_errors = np.substract(self._get_expected_outputs(target), + self._output_layer) + # Backpropagate the errors. + hidden_layer_errors = np.dot(np.transpose(self._hidden_weights), + output_layer_errors) + + return (hidden_layer_errors, output_layer_errors) """ - Apply the sigmoid function to a given layer + Given a cyrillic letter, get the target outputs. """ - def _get_layer_output(self, layer: np.array): - return expit(layer) + def _get_expected_outputs(self, target: str) -> np.array: + index = CYRILLIC_ALPHABET.index(target) + expected_outputs = np.zeros(len(CYRILLIC_ALPHABET), dtype=np.int8) + expected_outputs[index] = 1 + + return expected_outputs """ Generate a random array via an uniform distribution. -- cgit v1.2.3