Classifying handwritten digits, seen from the traditional view of machine learning and using the MNIST dataset as an example, is really classifying points in a (28*28=784)-dimensional space into 10 separate classes, which are not necessarily linearly separable.
And the neural network we constructed (the most classical one, with a few fully-connected layers) is nothing but a fancier version of this program: https://github.com/D0048/makeyourownneuralnetwork/blob/master/better_detection/train.py
The program above basically works like this:
Imagine every picture in the training set as a 28cm*28cm steel plate, where the darker areas are higher and the whiter areas are lower, with elevations ranging from 0cm to 255cm (since the color values range from 0 to 255), or, in a way friendlier to calculation, from -127.5cm to +127.5cm. Every steel plate also carries a label for identification.
A generated image looks somewhat like this (generated with http://cpetry.github.io/NormalMap-Online/):
Label: 5
Next, let us create another plate, a soft one made of clay, specifically for the digit '5', where all the initial elevations are 0.
Then, we collect all the steel plates in the training set labeled '5', and call their total number m. Scale each of these steel plates by 1/m so that all the plates add up to one plate. For example, a pixel with elevation 125 needs to be scaled to 125/m, and a pixel with elevation -123 needs to be scaled to -123/m.
After the steel plates are processed, we press them one by one into the clay plate we prepared, which amounts to a matrix subtraction. In total, there need to be 10 clay plates, one for each digit.
At last, we pull out a random plate from the training set without reading its label, push it into each of the 10 clay plates we prepared earlier for the 10 digits, and measure the friction we meet while pushing the steel plate into the clay. The clay plate with the lowest friction while our test sample is pushed in is supposed to correspond to the actual digit represented by the sample.
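To make the analogy concrete, here is a minimal sketch of the whole procedure in numpy. The names `images` (an array of shape (n, 28, 28) with values 0 to 255) and `labels` (an array of n digit labels) are placeholders of mine, I store the averaged plate directly instead of its negative imprint, and I read "friction" as the squared difference between the sample and each clay plate; this is only one way to read the analogy, not the linked train.py.

```python
import numpy as np

def build_clay_plates(images, labels):
    """Average the 'steel plates' of each digit into one 'clay plate' per class."""
    elevations = images.astype(float) - 127.5       # shift 0..255 to -127.5..+127.5
    plates = np.zeros((10, 28, 28))
    for digit in range(10):
        matching = elevations[labels == digit]       # all plates stamped with this digit
        plates[digit] = matching.mean(axis=0)        # scale each by 1/m and add them up
    return plates

def classify(plates, image):
    """Pick the clay plate with the lowest 'friction' against a new steel plate."""
    elevation = image.astype(float) - 127.5
    # friction here: how badly the sample disagrees with each plate (squared difference)
    friction = ((plates - elevation) ** 2).sum(axis=(1, 2))
    return int(friction.argmin())
```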
However, these clay plates ("models", as I call them) do not work well at all, considering that a flat clay plate with elevation -126 gives virtually zero friction against any sample, and that the more training is applied to a model, the more likely the model is to become flat, blurry, and muddy, which makes the classification unsatisfactory.
One way to ease the problem is to invert the rest of the training samples that do not match the model, so that the highest peak becomes the lowest valley, and apply them to the model as well. However, this does not address the issue at its root and may even make the model mushier.
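In terms of the earlier sketch, that idea would amount to subtracting the average of the other nine digits' plates from each clay plate, roughly like the hypothetical `sharpen_plates` below:

```python
import numpy as np

def sharpen_plates(plates):
    """Push the inverted non-matching plates into each clay plate as well."""
    sharpened = np.zeros_like(plates)
    for digit in range(10):
        others = np.delete(plates, digit, axis=0)    # the 9 plates that do not match
        sharpened[digit] = plates[digit] - others.mean(axis=0)
    return sharpened
```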
That is where the advantages of the neural network come in, and why I called the neural network "a fancier version". Basically, a neural network allows us to assign a specific weight to each pixel, so that the white areas around the actual digit (virtually) no longer contribute to the total friction (called the error, in the earlier language), and the black pixels shared by multiple digits, instead of turning into a blurry mess as in our setup, contribute to the total friction in a smarter, more black-box way. By making use of the multi-layer structure, we get a really flexible model that lets certain combinations of pixels contribute to the final friction as a unified whole. Also, we can adopt universal methods like gradient descent to select the best weights for each neuron.
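As a rough sketch of what "selecting the best weight for each pixel" means, here is a single-layer toy version trained with full-batch gradient descent (softmax with cross-entropy), with `images` and `labels` as in the earlier sketch; it is an illustration of the idea, not the multi-layer program linked above:

```python
import numpy as np

def train_pixel_weights(images, labels, lr=0.1, epochs=20):
    """Learn one weight per pixel per digit by gradient descent (single-layer softmax)."""
    x = images.reshape(len(images), -1) / 255.0      # flatten to 784 features in 0..1
    y = np.eye(10)[labels]                           # one-hot targets
    w = np.zeros((784, 10))                          # one weight per pixel per digit
    b = np.zeros(10)
    for _ in range(epochs):
        logits = x @ w + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        p = np.exp(logits)
        p /= p.sum(axis=1, keepdims=True)            # softmax probabilities
        grad = (p - y) / len(x)                      # gradient of cross-entropy w.r.t. logits
        w -= lr * (x.T @ grad)                       # descend on every pixel weight
        b -= lr * grad.sum(axis=0)
    return w, b
```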
However, this sounds a little strange and counter-intuitive: do we really need to map everything into such a high-dimensional space just to classify 10 different digits? Neural networks seem to be somewhat of a mimic of the brain, but my brain (at least mine) does not seem to rely on almost a thousand discrete features to recognize a digit, not to mention that the size of digits in real life can vary vastly with factors like distance. Do we really need all these features to perform the classification, or can we first extract fewer but more pivotal features from the raw image?
After some consideration, I suppose this calls for "narrower" networks, while going deeper might serve as a compromise. It also suggests using multiple networks working together in a chain, with some of them trying to extract features "smartly" from the data and others making the final decision.
It was not until later that I read about the Convolutional Neural Network, which is similar to the better neural network I had in mind, as described earlier. However, it is still not what I expected: I expect a model that works more like our nervous system, where it should be resistant to scaling (current CNNs are not capable of this; a model called the spatial transformer network claims to be, to be researched), and there should not be such a huge training set required to reach good performance.
Using handwritten digits as an example, would it be possible to design a network that transforms the digits into lines, or even Bézier curves? This way, the scaling problem is resolved. Then, for every entity, we can extract far more meaningful features than just pixels: the total number of closed areas, the total number of intersections, and so on. Rather than letting the network treat a digit as an ambiguous picture (honestly, I couldn't even learn to read digits from some 28*28 pictures), we might actually teach the network, in a more fundamental way, what a digit is in the first place.
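For example, one of those features, the number of closed areas, can already be computed without any network by labeling the background regions that are not connected to the border. Below is a rough sketch using scipy; the `image` argument (a 28*28 array) and the threshold of 128 are arbitrary choices of mine:

```python
import numpy as np
from scipy import ndimage

def count_closed_areas(image, threshold=128):
    """Count enclosed holes in a digit: background regions not touching the border."""
    ink = image >= threshold                         # binarize: True where the stroke is
    background_labels, n = ndimage.label(~ink)       # label connected background regions
    border = np.concatenate([background_labels[0, :], background_labels[-1, :],
                             background_labels[:, 0], background_labels[:, -1]])
    outside = set(border[border > 0])                # background regions touching the border
    return n - len(outside)                          # the rest are enclosed holes
```

For a clean '8' this returns 2, for a '0' it returns 1, and for a '1' it returns 0, regardless of where the digit sits in the image or how large it is drawn.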
The idea above is a recent one, to be tested.
Addendum:
Now I have a somewhat better understanding of CNNs and find them really powerful. However, they still lack resistance to shifts in object size. I have the following ideas for improvement, to be tested:
1. Give the lower-level features to CNNs (like "does the sample have a handle"), which matches our intuitive understanding, but use decision trees and other traditional machine learning methods to make the higher-level decisions (like "is the sample a water bottle"), which matches our conceptual understanding of objects. This may also prevent the network from relying on irrelevant features limited to the training set.
2. Try different and irregular shapes for the receptive field, rather than just squares.
3. Use another network, such as an RNN, to adjust the learning rate.
Comments
Travis Rivera
My understanding is that neural nets already determine the key features that are important to the decision. The importance of a given feature is represented by the weight on a particular neuron/input-feature.
So no, we don't need every feature; we just need the features relevant to the decision. So some amount of pre-processing can definitely help.
Alto Clef
You are right.
In contrast, without manual optimization, a huge problem is that neural networks learn features limited to the training set (like the position of the digit) that do not apply to the testing set.
This makes regular DNNs really prone to position shifts. The CNN model, in my understanding, is more of a workaround: it manually tells the network to separate two kinds of features from the raw pixels, the actual features and their locations, as feature maps, preventing the network from being confused by features shifting between different locations.
However, this also means CNNs assume that all the features are at the same scale as the receptive field, which makes them prone to shape/size shifts.
If my understanding above is correct, maybe it is possible to design a model like CNNs that instead manually tells the network to separate three things from the raw pixels: features, locations, and sizes. This way, maybe it is possible to create an architecture resistant to size/shape/location shifts.
As far as I can tell, one possible way of doing that is, instead of treating a picture as raw pixels (as in a DNN) or raw feature maps (as in a CNN), to treat the picture as vectors or even Bézier curves, so that the features extracted, such as the number of closed areas, no longer depend on any of the aforementioned shifts. However, the actual way of doing this is still under my experimentation.
The above are just naive thoughts from a beginner in machine learning, and I can't help but want to express them. If there are any errors, and/or there already exist mature architectures that fit my description above, please let me know so I can improve myself. Thanks a lot : ).