summaryrefslogtreecommitdiff
path: root/dataset.py
blob: e81e7ec5aab932e38745f9f6314eb00dc958048d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import numpy as np
from PIL import Image
from pathlib import Path
from alphabet import CYRILLIC_ALPHABET
import random

"""Class to interface the training and testing data."""

# Default number of random draws made by ``Dataset.data`` per call.
# NOTE(review): presumably the total number of images under ./data — confirm.
DATASET_SIZE=15480

class Dataset:
    """Interface to the training/testing images.

    Images are read from ``<data_path>/<letter>/*.png``, one sub-directory per
    letter of ``CYRILLIC_ALPHABET``, and returned as binarized NumPy arrays.
    """

    def __init__(self, data_path='./data') -> None:
        """Create a dataset reader.

        Args:
            data_path: Root directory holding one sub-directory of ``.png``
                images per letter. Defaults to ``./data``, resolved relative
                to the current working directory.
        """
        self.data_path = Path(data_path)
        # Full paths of images already yielded by ``data``; these are skipped
        # on later draws so no image is yielded twice by this instance.
        self.already_used = set()

    def data(self, batch_size=DATASET_SIZE):
        """Yield randomly drawn, not-yet-used ``(letter, image_array)`` pairs.

        Each of the ``batch_size`` draws picks a letter uniformly from
        ``CYRILLIC_ALPHABET``, then a ``.png`` file uniformly from that
        letter's directory. Draws that land on an image already yielded by
        this instance are skipped, so a call may yield fewer than
        ``batch_size`` items.

        Args:
            batch_size: Number of random draws to make.

        Yields:
            ``(letter, image_array)`` where ``image_array`` is the output of
            ``_img_to_array`` for the drawn file.

        Raises:
            FileNotFoundError: If a drawn letter's directory has no ``.png``
                files (or does not exist).
        """
        # Directory listings are cached for the duration of this call so the
        # same folder is not re-globbed on every draw. Order is kept as glob
        # returns it, so seeded runs pick the same files as before.
        listings = {}
        for _ in range(batch_size):
            letter = random.choice(CYRILLIC_ALPHABET)
            if letter not in listings:
                listings[letter] = list((self.data_path / letter).glob('*.png'))
            candidates = listings[letter]
            if not candidates:
                raise FileNotFoundError(
                    f"No .png images found in {self.data_path / letter}")
            image_path = random.choice(candidates)

            # Key on the full path: bare file names may repeat across letters.
            if image_path in self.already_used:
                continue
            self.already_used.add(image_path)

            # Convert while the file is open, then release the handle before
            # yielding control back to the caller.
            with Image.open(image_path) as image:
                image_array = self._img_to_array(image)

            yield (letter, image_array)

    def get_image(self, path: str):
        """Load one image from the dataset as a binarized array.

        Args:
            path: Image location relative to ``self.data_path``
                (for example ``'<letter>/<file>.png'``).

        Returns:
            The array produced by ``_img_to_array``.
        """
        # The context manager closes the file once the pixels are converted.
        with Image.open(f"{self.data_path}/{path}") as image:
            return self._img_to_array(image)

    def get_random_sample(self):
        """Placeholder: not implemented yet; currently returns ``None``."""
        # TODO: implement (e.g. a single draw from ``data``) or remove.
        return None

    def _img_to_array(self, image):
        """Flatten ``image`` onto a white background and binarize it.

        Transparent regions (an alpha band, or palette transparency) become
        white. The composited RGB image is then converted to PIL mode ``'1'``
        using Pillow's default bilevel conversion.

        Args:
            image: A PIL image in any mode Pillow can convert to ``'RGBA'``.

        Returns:
            A NumPy array of the mode-``'1'`` image (boolean, height x width).
        """
        fill_color = (255, 255, 255)  # White background.
        # Normalizing to RGBA first works for RGB, L, P, LA, ... as well as
        # RGBA; deriving the background mode by stripping the last character
        # of ``image.mode`` only worked for modes ending in an alpha band.
        rgba = image.convert('RGBA')
        background = Image.new('RGB', rgba.size, fill_color)
        background.paste(rgba, mask=rgba.getchannel('A'))

        return np.asarray(background.convert(mode='1'))