summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--neural_network.py36
1 files changed, 34 insertions, 2 deletions
diff --git a/neural_network.py b/neural_network.py
index 81361e5..827b066 100644
--- a/neural_network.py
+++ b/neural_network.py
@@ -1,4 +1,4 @@
-import numpy
+import numpy as np
from utils import random_array
CYRILLIC_ALPHABET = ['I', 'А', 'Б', 'В', 'Г', 'Д', 'Е', 'Ë', 'Ж', 'З',
@@ -15,4 +15,36 @@ class NeuralNetwork:
self.hidden_layer_size = round((self.input_layer_size + self.output_layer_size) / 2)
self._hidden_weights = random_array(self.hidden_layer_size, self.input_layer_size)
self._output_weights = random_array(self.output_layer_size, self.hidden_layer_size)
-
+
+
+ """
+ Train the neural network. It loads the dataset contained in ./data,
+ converts each image into a numpy array and uses that data for training.
+ """
+ def train():
+ pass
+
+ """
+ Guess the letter contained in the image file pointed by
+ input_image (a path).
+ """
+ def guess(input_image: str) -> str:
+ pass
+
+ """
+ Save the weights to a csv file.
+ """
+ def save():
+ pass
+
+ """
+ Load the weights from a csv file.
+ """
+ def load(weights_file: str):
+ pass
+
+ """
+ Feedforwarding.
+ """
+ def _predict(input_layer: np.array):
+ pass