import os, numpy as np, matplotlib.pyplot as plt

def _ensure_mnist_file(name):
    """Download and decompress one MNIST file unless it already exists.

    Shells out to ``wget`` and ``gunzip`` in the current working directory.

    Args:
        name: file name of the decompressed file (without the ``.gz`` suffix).

    Raises:
        RuntimeError: if either shell command exits with a non-zero status.
    """
    if os.path.isfile(name):
        return
    for command in (
        "wget http://yann.lecun.com/exdb/mnist/%s.gz" % name,
        "gunzip %s.gz" % name,
    ):
        # Explicit check instead of `assert`: under `python -O` an assert is
        # stripped together with the call it wraps, so nothing would run.
        if os.system(command) != 0:
            raise RuntimeError("command failed: %s" % command)


# Each file is checked on its own, so a missing labels file is fetched even
# when the images file is already present.
_ensure_mnist_file("train-images-idx3-ubyte")
_ensure_mnist_file("train-labels-idx1-ubyte")

# Read the raw bytes, dropping the first 16 bytes of the image file and the
# first 8 bytes of the label file (IDX headers); each image becomes one
# flattened row of 28 * 28 uint8 values.
images = np.fromfile("train-images-idx3-ubyte", dtype=np.uint8)[16:].reshape(60000, 28 * 28)
labels = np.fromfile("train-labels-idx1-ubyte", dtype=np.uint8)[8:]

class LinearLayer:
    """Bias-free fully connected layer updated by plain gradient descent."""

    def __init__(self, n_in, n_out):
        """Create an (n_in, n_out) weight matrix of normal samples scaled by 0.01."""
        samples = np.random.randn(n_in, n_out)
        self.W = samples * 0.01

    def forward(self, X):
        """Remember X for the backward pass and return the product X @ W."""
        self.X = X
        return np.matmul(X, self.W)

    def backward(self, g, learning_rate=0.03):
        """Apply one gradient step to W and return the gradient w.r.t. the input.

        Args:
            g: gradient of the loss with respect to this layer's output.
            learning_rate: step size used for the in-place weight update.

        Returns:
            The gradient with respect to the input stored by forward(),
            computed with the weights as they were before the update.
        """
        # The input gradient must use the pre-update weights, so it is
        # computed before W is modified in place.
        input_grad = np.matmul(g, self.W.T)
        weight_grad = np.matmul(self.X.T, g)
        self.W -= learning_rate * weight_grad
        return input_grad

class Relu:
    """Element-wise rectified linear unit that caches its input for backward()."""

    def forward(self, X):
        """Store X and return the element-wise maximum of 0 and X."""
        self.X = X
        rectified = np.maximum(0, X)
        return rectified

    def backward(self, g):
        """Scale g by a mask that is 1 where the stored input was > 0, else 0."""
        active = np.greater(self.X, 0)
        return np.multiply(g, active)

def mean_squared_error(X, Y):
    """Return the halved mean squared error between X and Y and its gradient.

    Args:
        X: array of predictions.
        Y: array of targets, broadcast-compatible with X.

    Returns:
        A ``(loss, gradient)`` tuple where ``loss`` is
        ``0.5 * mean((X - Y) ** 2)`` and ``gradient`` is
        ``(X - Y) / (X - Y).size``.
    """
    delta = X - Y
    # `delta.size` replaces `np.product(delta.shape)`: the `np.product` alias
    # was deprecated in NumPy 1.25 and removed in NumPy 2.0.
    return 0.5 * np.mean(np.square(delta)), delta / delta.size

def main():
    """Train a 784-512-256-10 ReLU network on random MNIST mini-batches.

    Runs 10000 steps of 100 randomly drawn samples each, using one-hot
    targets with the halved mean squared error, and prints the batch
    accuracy and loss after every step.
    """
    network = [
        LinearLayer(28 * 28, 512),
        Relu(),
        LinearLayer(512, 256),
        Relu(),
        LinearLayer(256, 10),
    ]
    # Row k of the identity matrix is the one-hot encoding of class k.
    one_hot = np.eye(10)

    for _step in range(10000):
        batch = np.random.randint(len(labels), size=100)
        targets = one_hot[labels[batch]]

        # Forward pass: feed the batch through every layer in order.
        activations = images[batch]
        for layer in network:
            activations = layer.forward(activations)

        loss, grad = mean_squared_error(activations, targets)

        # Backward pass: propagate the gradient from the last layer to the first.
        for layer in network[::-1]:
            grad = layer.backward(grad)

        hits = activations.argmax(axis=1) == targets.argmax(axis=1)
        accuracy = hits.mean()

        print("accuracy: %5.2f, loss: %10.5f" % (accuracy, loss))

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()