Neural Network Library

From SolarStrike wiki

What Neural Networks Can Do

How It Works



Including the library in your project

Before you can use the neural network library, you must tell your project to include the necessary files. To do so, you can simply use the require call, like so:


It is highly recommended to do this at the top of your main project file (probably main.lua).

Creating an instance of a neural network

Before you can create an instance of a neural network, you will need to know what your network topology looks like: how many inputs, how many outputs, how many hidden layers you would like, and how many neurons each hidden layer should have. A hidden layer will generally have a number of neurons somewhere between the number of inputs and the number of outputs. One hidden layer is usually enough for simple tasks, and two for more complex ones. For example, with 10 inputs and 2 outputs you might start with one hidden layer of 4-6 neurons and experiment to see which size gives the best results.

class NeuralNet(...)

You may pass either a table of numbers or a variable number of numeric arguments to the constructor to create a network with that topology. The first argument (or table value) is the number of inputs, the second is the number of neurons in the first hidden layer, any additional values after that add further hidden layers, and the final argument is the number of outputs.

For example:

net = NeuralNet(10, 6, 2); -- Create a neural network with 10 inputs, one hidden layer with 6 neurons, and 2 outputs.
net2 = NeuralNet(32, 24, 16, 8); -- 32 inputs, 24 neurons in hidden layer #1, 16 neurons in hidden layer #2, and 8 outputs.

local topology = {6, 3, 1}; -- Alternatively we can pass the topology as a table
net3 = NeuralNet(topology);
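The constructor's two calling styles can be mimicked in plain Lua by checking whether the first argument is a table. The helper below, toTopology, is hypothetical and not part of the library; it only sketches how a list of layer sizes maps onto inputs, hidden layers, and outputs, and it assumes at least one hidden layer, as described above.

```lua
-- Hypothetical helper, not part of the library: accepts either a table of
-- layer sizes or the same sizes as separate arguments.
local function toTopology(...)
	local sizes = { ... }
	if type(sizes[1]) == "table" then
		sizes = sizes[1]
	end

	-- Assumes at least one hidden layer: inputs, hidden..., outputs
	if #sizes < 3 then
		error("expected at least 3 layer sizes, got " .. #sizes)
	end
	for i, n in ipairs(sizes) do
		if type(n) ~= "number" or n < 1 or n % 1 ~= 0 then
			error(("layer %d must be a positive whole number, got %s"):format(i, tostring(n)))
		end
	end

	-- First value = inputs, last value = outputs, everything between = hidden layers
	local hidden = {}
	for i = 2, #sizes - 1 do
		hidden[#hidden + 1] = sizes[i]
	end
	return { inputs = sizes[1], hidden = hidden, outputs = sizes[#sizes] }
end

local t1 = toTopology(32, 24, 16, 8) -- separate arguments
local t2 = toTopology({ 6, 3, 1 })   -- one table
print(t1.inputs, #t1.hidden, t1.outputs) -- prints the values 32, 2 and 8
print(t2.inputs, t2.hidden[1], t2.outputs)
```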


NeuralNet:feed(...)

This function is used to "feed" data into the network's inputs and run the network's calculations on it. You will use it both while training and to get results once training is complete. Like the constructor, it accepts either a variable number of arguments or a table of values.

The number of values you pass in must exactly match the number of inputs of the network.

net = NeuralNet(2, 2, 1);
net:feed(0.0, 1.0); -- Feed two values, 0.0 and 1.0, into the network's inputs.
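To give a rough idea of the calculations that happen when data is fed in, here is a plain-Lua sketch of a feed-forward pass. It is an illustration built on assumptions, not the library's code: it assumes every neuron uses a sigmoid activation plus a bias, and it stores weights in a hypothetical layout where weights[l][j] holds the incoming weights of neuron j in the next layer, with the bias last. Like feed, it accepts the inputs either as separate arguments or as one table, and it rejects the wrong number of values.

```lua
-- Hypothetical feed-forward pass; NOT the library's implementation.
-- Assumed layout: weights[l][j] = incoming weights of neuron j in layer l + 1,
-- one weight per neuron in layer l, followed by a bias term.
local function sigmoid(x)
	return 1 / (1 + math.exp(-x))
end

local function feedForward(weights, numInputs, ...)
	local values = { ... }
	if type(values[1]) == "table" then
		values = values[1]
	end
	if #values ~= numInputs then
		error(("expected %d inputs, got %d"):format(numInputs, #values))
	end

	for _, layer in ipairs(weights) do
		local nextValues = {}
		for j, w in ipairs(layer) do
			local sum = w[#w] -- start from the bias
			for i, v in ipairs(values) do
				sum = sum + w[i] * v -- weighted sum of the previous layer
			end
			nextValues[j] = sigmoid(sum)
		end
		values = nextValues
	end
	return values -- activations of the output layer
end

-- A 2-2-1 network with hand-picked weights
local weights = {
	{ { 1, 1, -1 }, { -1, -1, 1 } }, -- hidden layer: 2 neurons, 2 inputs + bias each
	{ { 1, 1, -1 } },                -- output layer: 1 neuron, 2 hidden + bias
}
local out = feedForward(weights, 2, 0.5, 0.5)
print(out[1]) -- 0.5: every weighted sum here works out to exactly 0
```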

NeuralNet:backPropagation(...)

Back propagation is one method of training a neural network. With this type of training, you feed data to the network and then "back propagate" the known, expected result. Each back propagation slightly adjusts the network's neuron weights so that it better matches what it is expected to do and, therefore, gives better results. Note that back propagating once or twice accomplishes practically nothing; you will almost certainly need to provide thousands of training samples.

As with feed, you may pass the expected values either as a table or as a variable number of arguments. They must be numbers, and you must provide exactly as many values as the network has outputs.

local a = 0.25;
local b = 0.75;
local expectedResult = a * b;

net = NeuralNet(2, 2, 1); -- Our neural network is created, but is filled with randomness and won't provide any good results.
net:feed(a, b); -- Pass in our two inputs and process them, but since our network isn't trained the results are garbage
net:backPropagation(expectedResult); -- Slightly adjust network. It's at least a tiny bit smarter now!
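The core idea of back propagation can be shown on the smallest possible case: a single sigmoid neuron with no hidden layer, trained on logical AND with the delta rule. This is a hypothetical teaching sketch, not the library's algorithm; the learning rate of 0.5 and the 10000 passes over the data are assumed values chosen for the example. It echoes the point above: each update only nudges the weights slightly, so the training loop runs many thousands of times.

```lua
-- Hypothetical teaching sketch: one sigmoid neuron trained with the delta rule.
-- Not the library's implementation; learning rate and pass count are assumptions.
local function sigmoid(x)
	return 1 / (1 + math.exp(-x))
end

local w1, w2, bias = 0, 0, 0
local learningRate = 0.5

-- Training samples for logical AND: { input1, input2, expected }
local samples = { { 0, 0, 0 }, { 0, 1, 0 }, { 1, 0, 0 }, { 1, 1, 1 } }

local function predict(a, b)
	return sigmoid(w1 * a + w2 * b + bias)
end

for _ = 1, 10000 do
	for _, s in ipairs(samples) do
		local a, b, expected = s[1], s[2], s[3]
		local out = predict(a, b) -- the "feed" step
		-- Error scaled by the sigmoid's derivative, out * (1 - out)
		local delta = (expected - out) * out * (1 - out)
		-- The "back propagation" step: nudge each weight to reduce the error
		w1 = w1 + learningRate * delta * a
		w2 = w2 + learningRate * delta * b
		bias = bias + learningRate * delta
	end
end

for _, s in ipairs(samples) do
	print(s[1], s[2], math.floor(predict(s[1], s[2]) + 0.5))
end
```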



table NeuralNet:getResults()

Returns the results (values in the output layer) as a table. This is what you will use to get data out of a network after feeding data into the network.
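Since the output values are floating-point numbers, you will usually need to interpret them. The two helpers below are hypothetical and assume the results table is a plain array of numbers: roundResults rounds every output to the nearest whole number (as the xor example on this page does for its single output), and strongestOutput picks the largest output, which can be handy when each output stands for a different category.

```lua
-- Hypothetical helpers; assumes results is a plain array of numbers.
local function roundResults(results)
	local rounded = {}
	for i, v in ipairs(results) do
		rounded[i] = math.floor(v + 0.5) -- round to the nearest whole number
	end
	return rounded
end

-- Returns the index and value of the largest output (nil for an empty table)
local function strongestOutput(results)
	local bestIndex, bestValue = nil, -math.huge
	for i, v in ipairs(results) do
		if v > bestValue then
			bestIndex, bestValue = i, v
		end
	end
	return bestIndex, bestValue
end

local results = { 0.91, 0.07, 0.34 } -- made-up values standing in for getResults()
local rounded = roundResults(results)
print(rounded[1], rounded[2], rounded[3]) -- prints the values 1, 0 and 0
print(strongestOutput(results))           -- prints the values 1 and 0.91
```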



table NeuralNet:getExportTable()

Returns a table containing neuron weights for the whole network. You almost certainly won't want to have to re-train your network every time you want to use it, so this function helps with saving all necessary data to restore the network to a trained state.

You probably should just use NeuralNet:save() to save it directly to a file.



Saves the neuron weights to a file. Use this after training so that you can re-load it later and not have to start over with training your network.



Loads the neuron weights back into a network from a file that was saved with NeuralNet:save().
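Saving and restoring a network boils down to writing its weights somewhere and reading them back. The sketch below shows one way to do that in plain Lua for a nested table of numbers standing in for an export table. The file format, the helper names saveWeights and loadWeights, and the table layout are all assumptions for illustration; this is not the format NeuralNet:save() actually uses.

```lua
-- Hypothetical save/load sketch; NOT the library's file format.
-- Writes a nested table of numbers as a Lua expression and reads it back.
local function serialize(value)
	if type(value) == "number" then
		return string.format("%.17g", value) -- 17 digits round-trip a double exactly
	elseif type(value) == "table" then
		local parts = {}
		for i, v in ipairs(value) do
			parts[i] = serialize(v)
		end
		return "{" .. table.concat(parts, ",") .. "}"
	end
	error("cannot serialize a value of type " .. type(value))
end

local function saveWeights(weights, filename)
	local file = assert(io.open(filename, "w"))
	file:write("return ", serialize(weights), "\n")
	file:close()
end

local function loadWeights(filename)
	local file = assert(io.open(filename, "r"))
	local source = file:read("a")
	file:close()
	-- Text-only chunk with an empty environment: it can only build tables of numbers
	local chunk = assert(load(source, "=" .. filename, "t", {}))
	return chunk()
end

local weights = { { { 0.25, -1.5, 0.1 }, { 0.3, 0.7, -0.2 } }, { { 1.2, -0.8, 0.05 } } }
local path = os.tmpname()
saveWeights(weights, path)
local restored = loadWeights(path)
os.remove(path)
print(restored[2][1][1]) -- 1.2
```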


Calculating xor

In this example, we will create a neural network and train it to calculate xor. Xor (eXclusive Or) means that either input 1 or input 2 must be true (1), but not both at the same time:

1 xor 0 = 1 (true)
0 xor 1 = 1 (true)
1 xor 1 = 0 (false)
0 xor 0 = 0 (false)

Since Lua already provides an operator for this, a neural network is massive overkill and far more computationally expensive in comparison; however, it makes a good example.
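The truth table above is small enough to generate directly with Lua's ~ operator (Lua 5.3 or later). The sketch below collects the four cases as { a, b, a xor b } triples; cycling through such a fixed set instead of drawing random inputs is one possible variation on the training loop below, not something the library requires.

```lua
-- Build all four xor cases as { input1, input2, expected } triples
local xorSamples = {}
for a = 0, 1 do
	for b = 0, 1 do
		xorSamples[#xorSamples + 1] = { a, b, a ~ b } -- ~ is bitwise xor in Lua 5.3+
	end
end

for _, s in ipairs(xorSamples) do
	print(("%d xor %d = %d"):format(s[1], s[2], s[3]))
end
```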


function macro.init()
	net = NeuralNet(2, 2, 1);

	-- Train on 1000 random samples (if the answers are often wrong, try more iterations)
	for i = 1,1000 do
		local a = math.random(0, 1);
		local b = math.random(0, 1);

		net:feed(a, b);

		local result = a ~ b;		-- ~ is the bitwise 'xor' operation (Lua 5.3+)
		net:backPropagation(result);	-- Train on this data for the given input
	end

	-- Our network should be trained well enough, so let's see if it works
	local a = math.random(0, 1);
	local b = math.random(0, 1);
	local expected = a ~ b;

	net:feed(a, b);
	local results = net:getResults();
	local networkXor = math.floor(results[1] + 0.5); -- Since the network outputs a floating-point number, we round it
	printf("Our neural network thinks %d XOR %d is %d. We expect to get %d\n", a, b, networkXor, expected);
end