brain.js
About
brain.js is a library of Neural Networks written in JavaScript.
💡 Note: This is a continuation of the harthur/brain repository (which is not maintained anymore). For more details, check out this issue.
Table of Contents
- Examples
- Usage
- Training
- Methods
- Failing
- JSON
- Standalone Function
- Options
- Streams
- Utilities
- Neural Network Types
Examples
Here's an example showcasing how to approximate the XOR function using brain.js (more info on config here):
const brain = require('brain.js');

// provide optional config object (or undefined). Defaults shown.
const config = {
  binaryThresh: 0.5,
  hiddenLayers: [3], // array of ints for the sizes of the hidden layers in the network
  activation: 'sigmoid', // supported activation types: ['sigmoid', 'relu', 'leaky-relu', 'tanh']
  leakyReluAlpha: 0.01, // supported for activation type 'leaky-relu'
};

// create a simple feed forward neural network with backpropagation
const net = new brain.NeuralNetwork(config);

net.train([
  { input: [0, 0], output: [0] },
  { input: [0, 1], output: [1] },
  { input: [1, 0], output: [1] },
  { input: [1, 1], output: [0] },
]);

const output = net.run([1, 0]); // [0.987]
Or, with a recurrent neural network (more info on config here):
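const brain = require('brain.js');

// provide optional config object, defaults shown.
const config = {
  inputSize: 20,
  inputRange: 20,
  hiddenLayers: [20, 20],
  outputSize: 20,
  learningRate: 0.01,
  decayRate: 0.999,
};

// create a simple recurrent neural network
const net = new brain.recurrent.RNN(config);

net.train([
  { input: [0, 0], output: [0] },
  { input: [0, 1], output: [1] },
  { input: [1, 0], output: [1] },
  { input: [1, 1], output: [0] },
]);

let output = net.run([0, 0]); // [0]
output = net.run([0, 1]); // [1]
output = net.run([1, 0]); // [1]
output = net.run([1, 1]); // [0]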
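To make concrete what the feed-forward network in the first example computes when you call run(), here is a minimal sketch in plain JavaScript that does not use brain.js at all. It assumes a 2-2-1 sigmoid network whose weights were picked by hand (hypothetical values, not learned by train(), and a smaller hidden layer than the [3] shown above): one hidden neuron acts like OR, the other like NAND, and the output neuron acts like AND.

```javascript
// Illustrative only: hand-picked weights, not weights learned by brain.js.
function sigmoid(x) {
  return 1 / (1 + Math.exp(-x));
}

// Each hidden neuron is [weight for x1, weight for x2, bias].
const hiddenWeights = [
  [20, 20, -10], // fires when either input is 1 (OR)
  [-20, -20, 30], // fires unless both inputs are 1 (NAND)
];
// Output neuron is [weight for h1, weight for h2, bias]; fires when both fire (AND).
const outputWeights = [20, 20, -30];

// One forward pass: inputs -> hidden layer -> output layer.
function runXor([x1, x2]) {
  const hidden = hiddenWeights.map(([w1, w2, b]) => sigmoid(w1 * x1 + w2 * x2 + b));
  const [v1, v2, bias] = outputWeights;
  return [sigmoid(v1 * hidden[0] + v2 * hidden[1] + bias)];
}

for (const input of [[0, 0], [0, 1], [1, 0], [1, 1]]) {
  console.log(input, runXor(input)[0].toFixed(3));
}
```

A trained network does the same arithmetic; training only searches for weights that make these outputs close to the targets.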
However, there is no reason to use a neural network to figure out XOR. (-: So, here is a more involved, realistic example: Demo: training a neural network to recognize color contrast.
More Examples
You can check out this fantastic screencast, which explains how to train a simple neural network using a real world dataset: How to create a neural network in the browser using Brain.js.
- writing a children's book using a recurrent neural network & typescript version
- using cross validation with a feed forward net & typescript version
- experimental (NeuralNetwork only, but more to come!) using the GPU in a browser, or using Node with GPU fallback to CPU & typescript version
- learning math using a recurrent neural network & typescript version
- predict next number, and forecast numbers & typescript version
- using node streams & typescript version
- simple letter detection & typescript version
Usage
Node
If you have Node, you can install brain.js with npm:
npm install brain.js
Or if you prefer yarn:
yarn add brain.js
At present, the published version of brain.js is approximately 1.0.0 and features only the feed-forward neural network. All other models are in beta and are still being polished and battle-hardened. You can still download the latest, though. They are cool!
Browser
Download the latest brain.js for the browser. Training is computationally expensive, so you should try to train the network offline (or on a Worker) and use the toFunction() or toJSON() options to plug the pre-trained network into your website.
Training
Use train() to train the network with an array of training data. The network has to be trained with all the data in bulk in one call to train(). More training patterns will probably take longer to train, but will usually result in a network better at classifying new patterns.
Data format
For training with NeuralNetwork
Each training pattern should have an input and an output, both of which can be either an array of numbers from 0 to 1 or a hash of numbers from 0 to 1. For the color contrast demo it looks something like this:
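const net = new brain.NeuralNetwork();

net.train([
  { input: { r: 0.03, g: 0.7, b: 0.5 }, output: { black: 1 } },
  { input: { r: 0.16, g: 0.09, b: 0.2 }, output: { white: 1 } },
  { input: { r: 0.5, g: 0.5, b: 1.0 }, output: { white: 1 } },
]);

const output = net.run({ r: 1, g: 0.4, b: 0 }); // { white: 0.99, black: 0.002 }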
Here's another variation of the above example. (Note that input objects do not need to be similar.)
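net.train([
  { input: { r: 0.03, g: 0.7 }, output: { black: 1 } },
  { input: { r: 0.16, b: 0.2 }, output: { white: 1 } },
  { input: { r: 0.5, g: 0.5, b: 1.0 }, output: { white: 1 } },
]);

const output = net.run({ r: 1, g: 0.4, b: 0 }); // { white: 0.81, black: 0.18 }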
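Because input objects do not need to be similar, hash inputs have to be mapped onto fixed positions, with missing keys filled in. The sketch below shows one plausible way to do that; buildLookup and hashToArray are hypothetical helpers written for illustration, and filling missing keys with 0 is an assumption rather than a description of brain.js internals.

```javascript
// Hypothetical helpers (not brain.js's own API).

// Give every key seen across all hashes a stable index, in first-seen order.
function buildLookup(hashes) {
  const lookup = new Map();
  for (const hash of hashes) {
    for (const key of Object.keys(hash)) {
      if (!lookup.has(key)) {
        lookup.set(key, lookup.size);
      }
    }
  }
  return lookup;
}

// Convert one hash to a fixed-length array; keys missing from the hash stay 0,
// and keys unknown to the lookup are ignored.
function hashToArray(lookup, hash) {
  const array = new Array(lookup.size).fill(0);
  for (const [key, value] of Object.entries(hash)) {
    if (lookup.has(key)) {
      array[lookup.get(key)] = value;
    }
  }
  return array;
}

const inputs = [{ r: 0.03, g: 0.7 }, { r: 0.16, b: 0.2 }, { r: 0.5, g: 0.5, b: 1.0 }];
const lookup = buildLookup(inputs);
console.log(lookup);
console.log(inputs.map((hash) => hashToArray(lookup, hash)));
```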
For training with RNNTimeStep, LSTMTimeStep and GRUTimeStep
Each training pattern can either:
- Be an array of numbers
- Be an array of arrays of numbers
Example using an array of numbers:
const net = new brain.recurrent.LSTMTimeStep();

net.train([[1, 2, 3]]);

const output = net.run([1, 2]); // 3
Example using an array of arrays of numbers:
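const net = new brain.recurrent.LSTMTimeStep({
  inputSize: 2,
  hiddenLayers: [10],
  outputSize: 2,
});

net.train([
  [1, 3],
  [2, 2],
  [3, 1],
]);

const output = net.run([
  [1, 3],
  [2, 2],
]); // [3, 1]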
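A time-step network learns to predict the next value of a sequence from the values before it. The following sketch, using a hypothetical toNextStepPairs helper and an assumed window of 2 chosen to match the run() calls above, shows which (input, expected output) pairs the two training examples contain. It illustrates the idea only; it is not how brain.js iterates sequences internally.

```javascript
// Hypothetical helper: split one sequence into (window of previous values, next value) pairs.
function toNextStepPairs(sequence, windowSize) {
  const pairs = [];
  for (let end = windowSize; end < sequence.length; end++) {
    pairs.push({ input: sequence.slice(end - windowSize, end), output: sequence[end] });
  }
  return pairs;
}

// Array of numbers: [1, 2] should lead to 3.
console.log(toNextStepPairs([1, 2, 3], 2));
// Array of arrays: [[1, 3], [2, 2]] should lead to [3, 1].
console.log(toNextStepPairs([[1, 3], [2, 2], [3, 1]], 2));
```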
For training with RNN, LSTM and GRU
Each training pattern can either:
- Be an array of values
- Be a string
- Have an input and an output
  - Either of which can be an array of values or a string
CAUTION: When using an array of values, you can use ANY value; however, each distinct value is represented in the neural network by its own input. So the more distinct values you have, the larger your input layer. If you have hundreds, thousands, or millions of floating point values, THIS IS NOT THE RIGHT CLASS FOR THE JOB. Also, when deviating from strings, this gets into beta.
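The caution above follows from giving each distinct value its own input. This hypothetical buildVocabulary sketch (not brain.js code) counts how many inputs a data set would need under that scheme, for two short strings and for a run of distinct floating point values.

```javascript
// Hypothetical sketch: each distinct value gets its own slot in the input layer.
function buildVocabulary(trainingData) {
  const vocabulary = new Map();
  for (const sample of trainingData) {
    // Strings are iterated character by character, arrays value by value.
    for (const value of sample) {
      if (!vocabulary.has(value)) {
        vocabulary.set(value, vocabulary.size);
      }
    }
  }
  return vocabulary;
}

const textVocab = buildVocabulary(['doe, a deer', 'ray, a drop']);
console.log(textVocab.size);

// 1,000 distinct floats need 1,000 separate inputs.
const floats = Array.from({ length: 1000 }, (_, i) => i / 1000);
console.log(buildVocabulary([floats]).size);
```

Text has a small, bounded alphabet, so the input layer stays small; continuous numbers almost never repeat, so the input layer grows with the data.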
Example using direct strings:
const net = ; net; const output = net; // ', a deer, a female deer'
Example using strings with inputs and outputs:
const net = ; net; const output = net; // 'happy'
Training Options
train() takes a hash of options as its second argument:
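// data: your array of training patterns
net.train(data, {
  // default values --> expected validation
  iterations: 20000, // the maximum times to iterate the training data --> number greater than 0
  errorThresh: 0.005, // the acceptable error percentage from training data --> number between 0 and 1
  log: false, // true to use console.log, when a function is supplied it is used --> either true or a function
  logPeriod: 10, // iterations between logging out --> number greater than 0
  learningRate: 0.3, // scales with delta to affect training rate --> number between 0 and 1
  momentum: 0.1, // scales with next layer's change value --> number between 0 and 1
  callback: null, // a periodic callback that can be triggered while training --> null or function
  callbackPeriod: 10, // iterations through the training data between callback calls --> number greater than 0
});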
The network will stop training whenever one of the two criteria is met: the training error has gone below the threshold (default 0.005), or the max number of iterations (default 20000) has been reached.
By default, training will not let you know how it's doing until the end, but set log to true to get periodic updates on the current training error of the network. The training error should generally decrease with each update. The updates will be printed to the console. If you set log to a function, this function will be called with the updates instead of printing to the console.
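The stopping rules and the log option described above can be sketched as a simple loop. This is a hypothetical illustration, not brain.js's training code: trainOnce stands in for one pass over the training data, and the fake pass used in the demo simply halves its error every iteration.

```javascript
// Hypothetical sketch of the stopping criteria and periodic logging.
function trainLoop(trainOnce, { iterations = 20000, errorThresh = 0.005, log = false, logPeriod = 10 } = {}) {
  // log === true prints to the console; a function receives each update instead.
  const logger = log === true ? console.log : typeof log === 'function' ? log : null;
  let error = 1;
  let i = 0;
  // Stop when the error drops to the threshold OR the iteration cap is reached.
  while (i < iterations && error > errorThresh) {
    error = trainOnce();
    i++;
    if (logger && i % logPeriod === 0) {
      logger(`iterations: ${i}, training error: ${error}`);
    }
  }
  return { error, iterations: i };
}

// Fake training pass whose error halves every iteration (illustration only).
let fakeError = 1;
const messages = [];
const result = trainLoop(() => (fakeError /= 2), { log: (msg) => messages.push(msg), logPeriod: 2 });
console.log(result, messages);
```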
The learning rate is a parameter that influences how quickly the network trains. It's a number from 0 to 1. If the learning rate is close to 0, it will take longer to train. If the learning rate is closer to 1, it will train faster, but training may settle in a poor local minimum and the network may perform badly on new data. The default learning rate is 0.3.
The momentum is similar to the learning rate, expecting a value from 0 to 1 as well, but it is multiplied against the next level's change value. The default value is 0.1.
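To show how learning rate and momentum interact, here is a sketch of gradient descent with momentum on a single weight, minimizing (w - 2)^2. It uses the common formulation in which momentum scales the previous change; makeUpdater is a hypothetical helper, and brain.js's exact update rule may differ in detail.

```javascript
// Hypothetical sketch: one weight updated by gradient descent with momentum.
function makeUpdater(learningRate = 0.3, momentum = 0.1) {
  let previousChange = 0;
  return function update(weight, gradient) {
    // New change = step against the gradient + a fraction of the last change.
    const change = -learningRate * gradient + momentum * previousChange;
    previousChange = change;
    return weight + change;
  };
}

// Minimize f(w) = (w - 2)^2, whose gradient is 2 * (w - 2).
const update = makeUpdater();
let w = 0;
for (let step = 0; step < 50; step++) {
  w = update(w, 2 * (w - 2));
}
console.log(w.toFixed(4));
```

A larger learning rate takes bigger steps against the gradient; momentum carries part of the previous step forward, which smooths and can speed up progress.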
Any of these training options can be passed into the constructor or into the updateTrainingOptions(opts) method, and they will be saved on the network and used during training. If you save your network to JSON, these training options are saved and restored as well (except for callback and log: callback will be forgotten, and log will be restored using console.log).
A boolean property called invalidTrainOptsShouldThrow is set to true by default. While the option is true, if you enter a training option that is outside the normal range, an error will be thrown with a message about the abnormal option. When the option is set to false, no error will be thrown, but a message with the related information will still be sent to console.warn.
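A sketch of the throw-or-warn behavior, assuming the ranges suggested in the training options comments (learning rate, momentum and error threshold between 0 and 1; iterations greater than 0). validateTrainingOptions and optionRules are hypothetical names, and the message text is made up for illustration.

```javascript
// Hypothetical validator (not brain.js's actual code); ranges are assumptions.
const optionRules = {
  iterations: (v) => typeof v === 'number' && v > 0,
  errorThresh: (v) => typeof v === 'number' && v > 0 && v < 1,
  learningRate: (v) => typeof v === 'number' && v > 0 && v < 1,
  momentum: (v) => typeof v === 'number' && v > 0 && v < 1,
};

function validateTrainingOptions(options, shouldThrow = true) {
  for (const [name, value] of Object.entries(options)) {
    const isValid = optionRules[name];
    if (isValid && !isValid(value)) {
      const message = `[${name}, ${value}] is out of normal training range`;
      if (shouldThrow) {
        throw new Error(message); // behaves like invalidTrainOptsShouldThrow = true
      }
      console.warn(message); // behaves like invalidTrainOptsShouldThrow = false
    }
  }
}

validateTrainingOptions({ learningRate: 0.3, momentum: 0.1 }); // passes silently
validateTrainingOptions({ learningRate: 5 }, false); // only warns
```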
Async Training
trainAsync() takes the same arguments as train() (data and options). Instead of returning the results object from training, it returns a promise that, when resolved, will return the training results object.
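const net = new brain.NeuralNetwork();

net
  .trainAsync(data, options)
  .then((res) => {
    // do something with my trained network
  })
  .catch((error) => console.error(error));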
With multiple networks you can train in parallel like this:
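const net = new brain.NeuralNetwork();
const net2 = new brain.NeuralNetwork();

const p1 = net.trainAsync(data, options);
const p2 = net2.trainAsync(data, options);

Promise.all([p1, p2])
  .then(([res, res2]) => {
    console.log(`net trained in ${res.iterations} and net2 trained in ${res2.iterations}`);
    // do something super cool with my 2 trained networks
  })
  .catch((error) => console.error(error));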
Cross Validation
Cross Validation can provide a less fragile way of training on larger data sets. The brain.js API provides Cross Validation as shown in this example:
const crossValidate = new brain.CrossValidate(brain.NeuralNetwork, networkOptions);
crossValidate.train(data, trainingOptions, k); // note k (or KFolds) is optional
const json = crossValidate.toJSON(); // all stats in json as well as neural networks
const net = crossValidate.toNeuralNetwork(); // get top performing net out of `crossValidate`

// optionally later
const savedJson = crossValidate.toJSON();
const restoredNet = crossValidate.fromJSON(savedJson);
Use CrossValidate with these classes:
- brain.NeuralNetwork
- brain.RNNTimeStep
- brain.LSTMTimeStep
- brain.GRUTimeStep
An example of using cross validate can be found in examples/cross-validate.js
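Cross validation is built on k-fold splitting: the data is cut into k folds, and each fold serves once as the test set while the rest is used for training. This hypothetical kFolds helper sketches only the splitting step (a real run would usually shuffle first, then train and score one network per fold); it is not brain.js's CrossValidate implementation.

```javascript
// Hypothetical sketch of k-fold splitting; each sample is tested exactly once.
function kFolds(data, k) {
  const foldSize = Math.ceil(data.length / k);
  const folds = [];
  for (let i = 0; i < k; i++) {
    const start = i * foldSize;
    const testSet = data.slice(start, start + foldSize);
    // Everything outside the current fold becomes training data.
    const trainSet = [...data.slice(0, start), ...data.slice(start + foldSize)];
    folds.push({ trainSet, testSet });
  }
  return folds;
}

const samples = Array.from({ length: 10 }, (_, i) => i);
for (const { trainSet, testSet } of kFolds(samples, 4)) {
  console.log(`train on ${trainSet.length}, test on [${testSet}]`);
}
```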
Train Stream
Streams are a powerful Node tool for handling massive data spread across processes, and brain.js supports them through its API in the following way:
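const net = new brain.NeuralNetwork();
const trainStream = new brain.TrainStream({
  neuralNetwork: net,
  floodCallback: function () {
    readInputs(trainStream, data);
  },
  doneTrainingCallback: function (stats) {
    // network is done training! What next?
  },
});

// kick it off
readInputs(trainStream, data);

function readInputs(stream, data) {
  for (let i = 0; i < data.length; i++) {
    stream.write(data[i]);
  }
  // let it know we've reached the end of the inputs
  stream.endInputs();
}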
An example of using train stream can be found in examples/stream-example.js
Methods
train(trainingData) -> trainingStatus
The output of train() is a hash of information about how the training went:
{
  error: 0.0039139985510105032, // training error
  iterations: 406, // training iterations
}
run(input) -> prediction
Supported on classes:
- brain.NeuralNetwork
- brain.NeuralNetworkGPU -> All the functionality of brain.NeuralNetwork, but run on the GPU (via gpu.js in WebGL2, WebGL1, or fallback to CPU)
- brain.recurrent.RNN
- brain.recurrent.LSTM
- brain.recurrent.GRU
- brain.recurrent.RNNTimeStep
- brain.recurrent.LSTMTimeStep
- brain.recurrent.GRUTimeStep
Example:
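// json: a network previously saved with toJSON(); input: data shaped like the training inputs

// feed forward
const net = new brain.NeuralNetwork();
net.fromJSON(json);
net.run(input);

// time step
const timeStepNet = new brain.recurrent.LSTMTimeStep();
timeStepNet.fromJSON(json);
timeStepNet.run(input);

// recurrent
const recurrentNet = new brain.recurrent.LSTM();
recurrentNet.fromJSON(json);
recurrentNet.run(input);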
forecast(input, count) -> predictions
Available with the following classes. Outputs an array of predictions, each one continuing the input sequence.
- brain.recurrent.RNNTimeStep
- brain.recurrent.LSTMTimeStep
- brain.recurrent.GRUTimeStep
Example:
const net = new brain.recurrent.LSTMTimeStep();
net.fromJSON(json);
net.forecast(input, 3);
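Forecasting several steps ahead can be done by feeding each prediction back in as input. The sketch below assumes a one-step predictor; forecastWith is a hypothetical helper, and linearStep is a toy stand-in for a trained time-step network's run(), not anything brain.js provides.

```javascript
// Hypothetical sketch: autoregressive forecasting from a one-step predictor.
function forecastWith(predictNext, input, count) {
  const history = [...input]; // copy so the caller's array is not mutated
  const predictions = [];
  for (let i = 0; i < count; i++) {
    const next = predictNext(history);
    predictions.push(next);
    history.push(next); // the prediction becomes part of the next input
  }
  return predictions;
}

// Toy predictor: continue the difference between the last two values.
const linearStep = (h) => h[h.length - 1] + (h[h.length - 1] - h[h.length - 2]);
console.log(forecastWith(linearStep, [1, 2, 3], 3));
```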
toJSON() -> json
Serializes the neural network to JSON.
fromJSON(json)
Deserializes the neural network from JSON.
Failing
If the network failed to train, the error will be above the error threshold. This could happen if the training data is too noisy (most likely), the network does not have enough hidden layers or nodes to handle the complexity of the data, or it has not been trained for enough iterations.
If the training error is still something huge like 0.4 after 20000 iterations, it's a good sign that the network can't make sense of the given data.
RNN, LSTM, or GRU Output too short or too long
The net instance's maxPredictionLength property (default 100) can be set to adjust the length of the net's output.
Example:
const net = new brain.recurrent.LSTM();

// later in code, after training on a few novels, write me a new one!
net.maxPredictionLength = 1000000000; // Be careful!
net.run('Once upon a time');
JSON
Serialize or load in the state of a trained network with JSON:
const json = net.toJSON();
net.fromJSON(json);
Standalone Function
You can also get a custom standalone function from a trained network that acts just like run():
const run = net.toFunction();
const output = run({ r: 1, g: 0.4, b: 0 });
console.log(run.toString()); // copy and paste! no need to import brain.js
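The standalone function works because a trained network is just numbers that can be written into JavaScript source. As a rough illustration of that idea (not the code toFunction() actually emits), this hypothetical compileNeuron helper bakes made-up weights for one sigmoid neuron into a string and compiles it with the Function constructor.

```javascript
// Hypothetical sketch: turn fixed weights into standalone JavaScript source.
function compileNeuron(weights, bias) {
  const terms = weights.map((w, i) => `${w} * input[${i}]`).join(' + ');
  const source = `return 1 / (1 + Math.exp(-(${terms} + ${bias})));`;
  return new Function('input', source);
}

// Made-up weights for illustration (this neuron behaves like OR).
const standaloneRun = compileNeuron([20, 20], -10);
console.log(standaloneRun.toString()); // self-contained source, no library needed
console.log(standaloneRun([1, 0]));
```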
Options
NeuralNetwork() takes a hash of options:
const net = new brain.NeuralNetwork({
  activation: 'sigmoid', // activation function
  hiddenLayers: [4],
  learningRate: 0.6, // global learning rate, useful when training using streams
});
activation
This parameter lets you specify which activation function your neural network should use. There are currently four supported activation functions, sigmoid being the default:
- sigmoid
- relu
- leaky-relu
  - related option: 'leakyReluAlpha', an optional number that defaults to 0.01
- tanh
Here's a table (thanks, Wikipedia!) summarizing a plethora of activation functions: Activation Function
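For reference, here are plain-JavaScript definitions of the four activation functions listed above, written from their standard mathematical forms rather than copied from brain.js. The leaky-relu version takes the leakyReluAlpha value as a second argument, defaulting to the documented 0.01.

```javascript
// Standard forms of the supported activation functions (illustrative, not brain.js source).
const activations = {
  sigmoid: (x) => 1 / (1 + Math.exp(-x)), // squashes to (0, 1)
  relu: (x) => Math.max(0, x), // zero for negative inputs
  'leaky-relu': (x, leakyReluAlpha = 0.01) => (x > 0 ? x : leakyReluAlpha * x), // small slope below 0
  tanh: (x) => Math.tanh(x), // squashes to (-1, 1)
};

for (const [name, fn] of Object.entries(activations)) {
  console.log(name, [-2, 0, 2].map((x) => fn(x).toFixed(3)));
}
```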
hiddenLayers
You can use this to specify the number of hidden layers in the network and the size of each layer. For example, if you want two hidden layers - the first with 3 nodes and the second with 4 nodes, you'd give:
hiddenLayers: [3, 4]
By default, brain.js uses one hidden layer with size proportionate to the size of the input array.
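To see how hiddenLayers shapes the network, this sketch lists the layer sizes as [input, ...hiddenLayers, output] and counts the weights and biases of fully connected layers. countParameters is a hypothetical helper, and the one-bias-per-neuron layout is an assumption about a typical feed-forward network.

```javascript
// Hypothetical sketch: layer layout and parameter count for a fully connected network.
function countParameters(inputSize, hiddenLayers, outputSize) {
  const sizes = [inputSize, ...hiddenLayers, outputSize];
  let total = 0;
  for (let i = 1; i < sizes.length; i++) {
    // Weights from every neuron in the previous layer, plus one bias per neuron.
    total += sizes[i - 1] * sizes[i] + sizes[i];
  }
  return { sizes, total };
}

// Two inputs, hiddenLayers: [3, 4], one output.
console.log(countParameters(2, [3, 4], 1));
```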
Streams
The network now has a WriteStream. You can train the network by using pipe() to send the training data to the network.
Example
Refer to stream-example.js for an example of how to train the network with a stream.
Initialization
To train the network using a stream, you must first create the stream by calling net.createTrainStream(), which takes the following options:
- floodCallback(): the callback function to re-populate the stream. This gets called on every training iteration.
- doneTrainingCallback(info): the callback function to execute when the network is done training. The info param will contain a hash of information about how the training went:
{
  error: 0.0039139985510105032, // training error
  iterations: 406, // training iterations
}
Transform
Use a Transform to coerce the data into the correct format. You might also use a Transform stream to normalize your data on the fly.
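As an example of normalizing on the fly, here is a min-max rescaling step that maps raw readings into the 0 to 1 range NeuralNetwork expects. makeMinMaxNormalizer is hypothetical, and the bounds 0 and 255 are an assumed example (for instance, 8-bit color channels). Inside a Node object-mode Transform you would call it from transform(chunk, encoding, callback) and pass the result on.

```javascript
// Hypothetical sketch: per-sample min-max normalization with known bounds.
function makeMinMaxNormalizer(min, max) {
  const range = max - min;
  return (sample) => ({
    input: sample.input.map((v) => (v - min) / range), // rescale into [0, 1]
    output: sample.output, // targets are passed through unchanged
  });
}

const normalize = makeMinMaxNormalizer(0, 255);
console.log(normalize({ input: [0, 51, 255], output: [1] }));
```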
Utilities
likely
const { likely } = require('brain.js');
const key = likely(input, net);
For a likely example, see: simple letter detection
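likely() is used to pick the most probable output label. A minimal sketch of that idea, assuming the network returns a hash of scores: the hypothetical likelyKey helper below returns the key with the highest value, with fakeRun standing in for a trained network's run().

```javascript
// Hypothetical sketch of the idea behind likely(): argmax over an output hash.
function likelyKey(input, run) {
  const output = run(input);
  let bestKey = null;
  let bestValue = -Infinity;
  for (const [key, value] of Object.entries(output)) {
    if (value > bestValue) {
      bestKey = key;
      bestValue = value;
    }
  }
  return bestKey;
}

// Stand-in for a trained net's run(), returning fixed scores.
const fakeRun = () => ({ a: 0.1, b: 0.7, c: 0.2 });
console.log(likelyKey('some input', fakeRun));
```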
toSVG
<script src="../../src/utilities/svg.min.js"></script>
Renders the network topology of a feedforward network
document.getElementById('result').innerHTML = brain.utilities.toSVG(network, options);
For a toSVG example, see: network rendering
Neural Network Types
- brain.NeuralNetwork - Feedforward Neural Network with backpropagation
- brain.NeuralNetworkGPU - Feedforward Neural Network with backpropagation, GPU version
- brain.recurrent.RNNTimeStep - Time Step Recurrent Neural Network or "RNN"
- brain.recurrent.LSTMTimeStep - Time Step Long Short Term Memory Neural Network or "LSTM"
- brain.recurrent.GRUTimeStep - Time Step Gated Recurrent Unit or "GRU"
- brain.recurrent.RNN - Recurrent Neural Network or "RNN"
- brain.recurrent.LSTM - Long Short Term Memory Neural Network or "LSTM"
- brain.recurrent.GRU - Gated Recurrent Unit or "GRU"
Why different Neural Network Types?
Different neural nets do different things well. For example:
- A Feedforward Neural Network can classify simple things very well, but it has no memory of previous actions and has infinite variation of results.
- A Time Step Recurrent Neural Network remembers, and can predict future values.
- A Recurrent Neural Network remembers, and has a finite set of results.
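The memory mentioned above comes from carrying a hidden state from one step to the next. This single-unit sketch uses arbitrary, untrained weights (rnnStep and runSequence are hypothetical) to show that two sequences ending in the same input can leave the network in different states, something a feed-forward network looking only at the last input cannot do.

```javascript
// Hypothetical single-unit recurrent step with arbitrary illustration weights.
function rnnStep(x, hPrev, { wx = 0.5, wh = 0.9, b = 0 } = {}) {
  // The new state depends on the current input AND the previous state.
  return Math.tanh(wx * x + wh * hPrev + b);
}

function runSequence(sequence) {
  let h = 0; // initial hidden state
  for (const x of sequence) {
    h = rnnStep(x, h);
  }
  return h;
}

// Same final input, different history -> different final state.
console.log(runSequence([0, 0, 1]), runSequence([1, 1, 1]));
```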
Get Involved!
Issues
If you have an issue, either a bug or a feature you think would benefit your project, let us know and we will do our best.
Create issues here and follow the template.
Contributors
This project exists thanks to all the people who contribute. [Contribute].
Backers
Thank you to all our backers! 🙏 [Become a backer]
Sponsors
Support this project by becoming a sponsor. Your logo will show up here with a link to your website. [Become a sponsor]