@magenta/sketch
Link to Documentation: tensorflow.github.io/magenta-js/sketch
This JavaScript implementation of Magenta's sketch-rnn model uses TensorFlow.js for GPU-accelerated inference. sketch-rnn
is a recurrent neural network model described in Teaching Machines to Draw and A Neural Representation of Sketch Drawings.
Examples of vector images produced by this generative model.
SketchRNN
This document is an introduction to using the Sketch RNN model in JavaScript to generate images. The SketchRNN model is trained on stroke-based vector drawings. The implementation here supports unconditional (decoder-only) generation of vector images.
For more information, please read the original model description and see the Python TensorFlow implementation.
Getting started
In the .html file, we need to include magentasketch.js. Our example sketches are built with p5.js and stored in a file such as sketch.js, so we have also included the p5 library here. Please see this minimal example:
<html>
<head>
<script src="https://cdnjs.cloudflare.com/ajax/libs/p5.js/0.7.2/p5.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@magenta/sketch"></script>
<script src="sketch.js"></script>
</head>
<body>
<div id="sketch"></div>
</body>
</html>
Generating a sketch
Below is the essence of how a sketch is generated. In addition to the original paper, a simple tutorial for understanding how RNNs can generate a set of strokes is here.
let model;
let dx, dy; // offsets of the pen strokes, in pixels
let pen_down, pen_up, pen_end; // keep track of whether pen is touching paper
let x, y; // absolute coordinates on the screen of where the pen is
let prev_pen = [1, 0, 0]; // group all p0, p1, p2 together
let rnn_state; // store the hidden states of rnn's neurons
let pdf; // store all the parameters of a mixture-density distribution
let temperature = 0.45; // controls the amount of uncertainty of the model
let line_color;
let model_loaded = false;
// loads the TensorFlow.js version of sketch-rnn model, with the "cat" model's weights.
model = new ms.SketchRNN("https://storage.googleapis.com/quickdraw-models/sketchRNN/models/cat.gen.json");
// Placeholder: add code here that ensures the model has finished loading before the code below runs.
function setup() {
  x = windowWidth / 2.0;
  y = windowHeight / 3.0;
  createCanvas(windowWidth, windowHeight);
  frameRate(60);
  // Initialize the scale factor for the model. Bigger -> large outputs.
  model.setPixelFactor(3.0);
  // Initialize pen's states to zero.
  [dx, dy, pen_down, pen_up, pen_end] = model.zeroInput(); // The pen's states.
  // Zero out the rnn's initial states.
  rnn_state = model.zeroState();
  // Define color of line.
  line_color = color(random(64, 224), random(64, 224), random(64, 224));
};
function draw() {
  // See if we finished drawing.
  if (prev_pen[2] == 1) {
    noLoop(); // Stop drawing.
    return;
  }
  // Using the previous pen states, and hidden state, get next hidden state
  // the below line takes the most CPU power, especially for large models.
  rnn_state = model.update([dx, dy, pen_down, pen_up, pen_end], rnn_state);
  // Get the parameters of the probability distribution (pdf) from hidden state.
  pdf = model.getPDF(rnn_state, temperature);
  // Sample the next pen's states from our probability distribution.
  [dx, dy, pen_down, pen_up, pen_end] = model.sample(pdf);
  // Only draw on the paper if the pen is touching the paper.
  if (prev_pen[0] == 1) {
    stroke(line_color);
    strokeWeight(3.0);
    line(x, y, x+dx, y+dy); // Draw line connecting prev point to current point.
  }
  // Update the absolute coordinates from the offsets
  x += dx;
  y += dy;
  // Update the previous pen's state to the current one we just sampled
  prev_pen = [pen_down, pen_up, pen_end];
};
Demos
Several demos in the demos directory show how to use the SketchRNN model. You can also view the hosted demos, or run the examples locally by running yarn run-demos. This command first builds the magentasketch.js library from the TypeScript source files and then launches a server; open http://127.0.0.1:8080 in your web browser to select a demo.
1) simple.html / simple.js
This demo generates a bird with the model, using the example code from the earlier section.
See the simple demo.
2) predict.html / predict.js
This demo attempts to finish a drawing given a starting set of strokes (a circle, drawn in red). In this demo, you can also select other classes, like "cat", "ant", "bus", etc. The demo dynamically loads the JSON files in the models directory and caches previously loaded models.
See the predict demo.
3) interactive_predict.html / interactive_predict.js
Same as the previous demo, but made to be interactive so the user can draw the beginning of a sketch on the canvas. Similar to the first AI experiment. Hitting restart will clear the current human-entered drawing and start from scratch.
See the interactive predict demo.
Pre-trained models
We have provided around 100 pre-trained sketch-rnn models. The model files have a .gen.json extension.
The models are located in:
https://storage.googleapis.com/quickdraw-models/sketchRNN/large_models/category.gen.json
where category is a quickdraw category such as cat, dog, or the_mona_lisa. Some models are trained on more than one category, such as catpig or crabrabbitfacepig.
For example:
https://storage.googleapis.com/quickdraw-models/sketchRNN/large_models/spider.gen.json
or
https://storage.googleapis.com/quickdraw-models/sketchRNN/large_models/the_mona_lisa.gen.json
A set of smaller models (with an LSTM node size of 512 only) is located at:
https://storage.googleapis.com/quickdraw-models/sketchRNN/models/category.gen.json
Here is a list of all the models provided:
Models | ||||
---|---|---|---|---|
alarm_clock | ambulance | angel | ant | antyoga |
backpack | barn | basket | bear | bee |
beeflower | bicycle | bird | book | brain |
bridge | bulldozer | bus | butterfly | cactus |
calendar | castle | cat | catbus | catpig |
chair | couch | crab | crabchair | crabrabbitfacepig |
cruise_ship | diving_board | dog | dogbunny | dolphin |
duck | elephant | elephantpig | eye | face |
fan | fire_hydrant | firetruck | flamingo | flower |
floweryoga | frog | frogsofa | garden | hand |
hedgeberry | hedgehog | helicopter | kangaroo | key |
lantern | lighthouse | lion | lionsheep | lobster |
map | mermaid | monapassport | monkey | mosquito |
octopus | owl | paintbrush | palm_tree | parrot |
passport | peas | penguin | pig | pigsheep |
pineapple | pool | postcard | power_outlet | rabbit |
rabbitturtle | radio | radioface | rain | rhinoceros |
rifle | roller_coaster | sandwich | scorpion | sea_turtle |
sheep | skull | snail | snowflake | speedboat |
spider | squirrel | steak | stove | strawberry |
swan | swing_set | the_mona_lisa | tiger | toothbrush |
toothpaste | tractor | trombone | truck | whale |
windmill | yoga | yogabicycle | everything |
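The URL for any category in the table can be assembled from the two patterns given earlier. The snippet below is a minimal sketch of that: `modelUrl` and `MODEL_BASE_URL` are hypothetical names introduced here for illustration and are not part of @magenta/sketch, and the name check simply assumes that category names contain only lowercase letters and underscores, as they do in the table.

```javascript
// Hypothetical helper (not part of @magenta/sketch): builds the URL of a
// pre-trained model from a category name, following the URL patterns above.
const MODEL_BASE_URL = 'https://storage.googleapis.com/quickdraw-models/sketchRNN';

function modelUrl(category, large = false) {
  // Assumption: category names are lowercase letters and underscores only.
  if (!/^[a-z_]+$/.test(category)) {
    throw new Error('Unexpected category name: ' + category);
  }
  // Large models live in large_models/, the smaller ones in models/.
  const dir = large ? 'large_models' : 'models';
  return `${MODEL_BASE_URL}/${dir}/${category}.gen.json`;
}

console.log(modelUrl('spider', true));
console.log(modelUrl('cat'));
```

The returned string can be passed to new ms.SketchRNN(...) in the same way as the hard-coded cat URL in the example code.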
Building the model
The implementation was written in TypeScript and built with the yarn tool:
yarn install to install dependencies.
yarn build to compile the TypeScript sources into JavaScript.
yarn bundle to produce a bundled version in dist/.
Train your own model
There is a small IPython notebook that shows how to quickly train a sketch-rnn model with the Python-based TensorFlow implementation and convert that model to the JSON format used by this library.
Additional Notes
Scale Factors
When training the models, all the offset data was normalized to have a standard deviation of 1.0 on the training set, after simplifying the strokes, because neural nets work best when trained on normalized data. However, the original data recorded with the QuickDraw web app was stored in pixels, so it was scaled down until the stroke offsets had a length of about 1.0 on average. Each data class therefore has its own scale_factor, usually between 60 and 120 depending on the dataset. These scale factors are stored in model.info.scale_factor. The model assumes all inputs and outputs are in pixel space, not normalized space, and does all the scaling for you. You can modify the value in the model directly, but this is not recommended. Rather than overwriting the scale_factor value, modify the pixel_factor instead, as described in the next paragraph.
If using PaperJS, it is recommended that you leave everything as it is. When using p5.js, all the recorded data looks bigger than in the original app by a factor of exactly 2, likely due to the anti-aliasing functionality of web browsers. Hence the extra scaling factor for the model, called pixel_factor. If you want to make interactive apps that receive realtime drawing data from the user and you are using PaperJS, it is best to call model.setPixelFactor(1.0). For p5.js, call model.setPixelFactor(2.0). For non-interactive applications, using a larger pixel factor will reduce the size of the generated image.
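The interaction between the two factors can be illustrated with a little arithmetic. The sketch below assumes that the effective scale applied to a normalized offset is scale_factor divided by pixel_factor, which is consistent with a larger pixel factor producing a smaller image; this formula is an assumption and has not been checked against the library source. The function name `toPixels` and the scale factor of 90 are made up for illustration and are not read from a real model.

```javascript
// Hypothetical helper (not a library function): converts a normalized
// offset to pixels.
// Assumption: effective scale = scale_factor / pixel_factor.
function toPixels(normalizedOffset, scaleFactor, pixelFactor) {
  return normalizedOffset * (scaleFactor / pixelFactor);
}

const scaleFactor = 90; // illustrative value within the 60 to 120 range

console.log(toPixels(1.0, scaleFactor, 1.0)); // 90
console.log(toPixels(1.0, scaleFactor, 2.0)); // 45
console.log(toPixels(1.0, scaleFactor, 3.0)); // 30
```

Under this assumption, doubling the pixel factor halves the length of every generated offset, while the value stored in model.info.scale_factor stays untouched.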
Line Data vs Stroke Data
Data collected by the original quickdraw app is stored in the format below, which is a list of lists of ["x", "y"] pixel points.
[[["x": 123, "y": 456], ["x": 127, "y": 454], ["x": 137, "y": 450], ["x": 147, "y": 440], ...], ...]
The first thing to do is to convert this format into the Line Data format, dropping the "x" and "y" keys. In the Line Data format, x always comes before y:
Line Data: [[[123, 456], [127, 454], [137, 450], [147, 440], ...], ...]
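This conversion can be sketched in a few lines. The snippet assumes that each recorded point is an object with numeric x and y properties (the notation above is schematic, so the exact container type is an assumption), and `pointsToLineData` is a hypothetical name, not one of the library's helpers.

```javascript
// Hypothetical helper (not part of the library).
// Assumption: each recorded point is an object such as {x: 123, y: 456}.
function pointsToLineData(recorded) {
  // Keep the nesting (drawing -> polyline -> point) and turn every point
  // into an [x, y] pair, with x always first.
  return recorded.map((polyline) => polyline.map((point) => [point.x, point.y]));
}

const recorded = [
  [{x: 123, y: 456}, {x: 127, y: 454}, {x: 137, y: 450}, {x: 147, y: 440}],
];

console.log(JSON.stringify(pointsToLineData(recorded)));
// prints [[[123,456],[127,454],[137,450],[147,440]]]
```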
The model contains helper functions to convert between these formats. The Line Data must first be simplified using simplify_lines or simplify_line (depending on whether it is a list of polylines or just a single polyline). Afterwards, the simplified line is fed into lines_to_strokes to convert it into the Stroke Data format used by the model.
In the Stroke Data format, we assume the drawing starts at the origin, and store only the offset points from the previous location. The format is 2 dimensional, rather than 3 dimensional as in the Line Data format:
Each row of the stroke will be 5 elements:
[dx, dy, p0, p1, p2]
dx, dy are the offsets in pixels from the previous point.
p0, p1, p2 are binary values; exactly one of them is 1 and the other two must be 0.
p0 = 1 means the pen stays on the paper at the next stroke.
p1 = 1 means the pen is lifted above the paper after this stroke. The next stroke will be the start of a new line.
p2 = 1 means the drawing has stopped. Stop drawing anything!
The drawing will be decomposed into a list of [dx, dy, p0, p1, p2] strokes.
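To make the row format concrete, here is a minimal sketch of a Line Data to Stroke Data conversion. It is not the library's lines_to_strokes implementation: `lineDataToStrokeData` is a hypothetical function, and it assumes that the first point of the drawing is treated as the origin, that polylines with fewer than two points are skipped, and that a final [0, 0, 0, 0, 1] row marks the end of the drawing.

```javascript
// Hypothetical sketch, not the library's lines_to_strokes implementation.
// Converts Line Data (a list of polylines of [x, y] points) into rows of
// [dx, dy, p0, p1, p2].
function lineDataToStrokeData(lines) {
  const strokes = [];
  let prev = null; // previous absolute point; null until the first point
  for (const line of lines) {
    if (line.length < 2) continue; // assumption: ignore degenerate polylines
    line.forEach(([x, y], i) => {
      if (prev !== null) {
        const isLast = i === line.length - 1;
        // p0 = 1 while the polyline continues, p1 = 1 at its last point.
        strokes.push([x - prev[0], y - prev[1], isLast ? 0 : 1, isLast ? 1 : 0, 0]);
      }
      prev = [x, y];
    });
  }
  strokes.push([0, 0, 0, 0, 1]); // assumption: p2 = 1 row ends the drawing
  return strokes;
}

const lines = [
  [[123, 456], [127, 454], [137, 450]],
  [[200, 200], [210, 205]],
];

console.log(JSON.stringify(lineDataToStrokeData(lines)));
// prints [[4,-2,1,0,0],[10,-4,0,1,0],[63,-250,1,0,0],[10,5,0,1,0],[0,0,0,0,1]]
```

Note that the third row moves the pen to the start of the second polyline. Nothing is drawn during that move because the previous row had p1 = 1.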
The mapping from Line Data to Stroke Data will lose the information about the starting position of the drawing, so you may want to record LineData[0][0] to keep this info.
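With the starting point recorded, the absolute coordinates can be recovered by accumulating the offsets. The sketch below shows one way to do this; `strokeDataToLineData` is a hypothetical function rather than a library helper, and it assumes the pen-state semantics listed above (p1 = 1 closes the current polyline, p2 = 1 ends the drawing).

```javascript
// Hypothetical sketch (not a library function): rebuilds absolute Line Data
// from Stroke Data rows, given the recorded starting point.
function strokeDataToLineData(strokes, start) {
  const lines = [];
  let [x, y] = start;
  let current = [[x, y]]; // the polyline being built
  for (const [dx, dy, , p1, p2] of strokes) {
    if (p2 === 1) break; // the drawing has stopped
    x += dx;
    y += dy;
    current.push([x, y]);
    if (p1 === 1) {
      // Pen lifted: close this polyline; the next point starts a new one.
      lines.push(current);
      current = [];
    }
  }
  if (current.length > 1) lines.push(current); // unterminated last polyline
  return lines;
}

const start = [123, 456]; // recorded from LineData[0][0]
const strokes = [
  [4, -2, 1, 0, 0],
  [10, -4, 0, 1, 0],
  [63, -250, 1, 0, 0],
  [10, 5, 0, 1, 0],
  [0, 0, 0, 0, 1],
];

console.log(JSON.stringify(strokeDataToLineData(strokes, start)));
// prints [[[123,456],[127,454],[137,450]],[[200,200],[210,205]]]
```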