@tensorflow-models/knn-classifier
This package includes built-in TypeScript type declarations.

1.2.6 • Public • Published

KNN Classifier

This package provides a utility for creating a classifier using the K-Nearest Neighbors algorithm.

Unlike the other packages in this repository, this package doesn't provide a model with weights. Instead, it provides a utility for constructing a KNN model from the activations of another model, or from any other tensors you can associate with a class/label.
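
Because the classifier compares examples instead of learning weights, everything depends on how similar two activation vectors are. The following is a minimal plain-JavaScript sketch of one such measure, cosine similarity. That this package compares activations this way is an assumption here, and `cosineSimilarity` is a hypothetical helper, not part of the package's API.

```javascript
// Hypothetical helper, not part of @tensorflow-models/knn-classifier.
// Cosine similarity: dot(a, b) / (|a| * |b|). The result lies in [-1, 1];
// 1 means the vectors point in the same direction.
function cosineSimilarity(a, b) {
  if (a.length !== b.length) {
    throw new Error('Vectors must have the same length');
  }
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  if (normA === 0 || normB === 0) {
    return 0; // Treat an all-zero vector as unrelated to everything.
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}
```

Vectors pointing in the same direction score 1 regardless of their length, so under this assumption scaling an activation does not change which stored examples count as nearest.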

You can see example code here.

Usage example

via Script Tag
<html>
  <head>
    <!-- Load TensorFlow.js -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <!-- Load MobileNet -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script>
    <!-- Load KNN Classifier -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>
  </head>

  <body>
    <img id='class0' src='/images/class0.jpg'/>
    <img id='class1' src='/images/class1.jpg'/>
    <img id='test' src='/images/test.jpg'/>
  </body>
  <!-- Place your code in the script tag below. You can also use an external .js file -->
  <script>

    const init = async function() {
      // Create the classifier.
      const classifier = knnClassifier.create();

      // Load mobilenet.
      const mobilenetModule = await mobilenet.load();

      // Add MobileNet activations to the model repeatedly for all classes.
      const img0 = tf.browser.fromPixels(document.getElementById('class0'));
      const logits0 = mobilenetModule.infer(img0, true);
      classifier.addExample(logits0, 0);

      const img1 = tf.browser.fromPixels(document.getElementById('class1'));
      const logits1 = mobilenetModule.infer(img1, true);
      classifier.addExample(logits1, 1);

      // Make a prediction.
      const x = tf.browser.fromPixels(document.getElementById('test'));
      const xlogits = mobilenetModule.infer(x, true);
      console.log('Predictions:');
      const result = await classifier.predictClass(xlogits);
      console.log(result);
    }

    // Wait until the images have loaded before reading their pixels.
    window.addEventListener('load', init);

  </script>
</html>

via NPM
const tf = require('@tensorflow/tfjs');
const mobilenetModule = require('@tensorflow-models/mobilenet');
const knnClassifier = require('@tensorflow-models/knn-classifier');

// This code reads pixels from <img> elements, so it must run in the browser
// (for example, bundled with a tool such as webpack).
async function run() {
  // Create the classifier.
  const classifier = knnClassifier.create();

  // Load mobilenet.
  const mobilenet = await mobilenetModule.load();

  // Add MobileNet activations to the model repeatedly for all classes.
  const img0 = tf.browser.fromPixels(document.getElementById('class0'));
  const logits0 = mobilenet.infer(img0, true);
  classifier.addExample(logits0, 0);

  const img1 = tf.browser.fromPixels(document.getElementById('class1'));
  const logits1 = mobilenet.infer(img1, true);
  classifier.addExample(logits1, 1);

  // Make a prediction. predictClass returns a Promise, so await the result.
  const x = tf.browser.fromPixels(document.getElementById('test'));
  const xlogits = mobilenet.infer(x, true);
  console.log('Predictions:');
  console.log(await classifier.predictClass(xlogits));
}

run();

API

Creating a classifier

knnClassifier is the module name. When you load the package with the <script src> method, it is available automatically as a global variable.

classifier = knnClassifier.create()

Returns a KNNClassifier.

Adding examples

classifier.addExample(
  example: tf.Tensor,
  label: number|string
): void;

Args:

  • example: An example to add to the dataset, usually an activation from another model.
  • label: The label (class name) of the example.
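
Conceptually, each addExample call appends one row to a per-label collection, which getClassifierDataset later exposes as one Tensor2D per label. The sketch below models that bookkeeping with plain arrays. The `ExampleStore` class is hypothetical, and giving each new label the next class index in first-seen order is an assumption about where `classIndex` comes from.

```javascript
// Hypothetical in-memory model of the classifier's dataset: each label maps
// to an array of example vectors, one row per addExample call.
class ExampleStore {
  constructor() {
    this.rowsByLabel = new Map();       // label -> array of number[] rows
    this.classIndexByLabel = new Map(); // label -> class index (first-seen order, assumed)
  }

  addExample(vector, label) {
    // Labels may be numbers or strings; store them as strings, like object keys.
    const key = String(label);
    if (!this.rowsByLabel.has(key)) {
      this.rowsByLabel.set(key, []);
      this.classIndexByLabel.set(key, this.classIndexByLabel.size);
    }
    const rows = this.rowsByLabel.get(key);
    if (rows.length > 0 && rows[0].length !== vector.length) {
      throw new Error('All examples for a label must have the same length');
    }
    rows.push(Array.from(vector)); // copy, so later changes to `vector` don't leak in
  }

  // Mirrors the shape of getClassExampleCount(): {label: count}.
  getClassExampleCount() {
    const counts = {};
    for (const [label, rows] of this.rowsByLabel) {
      counts[label] = rows.length;
    }
    return counts;
  }
}
```

Keeping rows grouped by label makes per-class operations such as clearing a class or counting its examples a single lookup.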

Making a prediction

classifier.predictClass(
  input: tf.Tensor,
  k = 3
): Promise<{label: string, classIndex: number, confidences: {[label: string]: number}}>;

Args:

  • input: An example to make a prediction on, usually an activation from another model.
  • k: The value of k to use in K-nearest neighbors. The algorithm first finds the k examples (among those added so far) that are nearest to the input, then chooses the class that appears most often among them as the final prediction. Defaults to 3. If fewer than k examples have been added, k is set to the number of examples.
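
The two steps described for k (find the nearest examples, then take a majority vote) can be sketched in plain JavaScript as follows. This is an illustration, not the package's implementation: the helpers `cosine`, `nearestLabels`, and `majorityLabel` are hypothetical, and using cosine similarity as the nearness measure is an assumption.

```javascript
// Hypothetical helper: cosine similarity of two equal-length vectors.
function cosine(a, b) {
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return normA && normB ? dot / Math.sqrt(normA * normB) : 0;
}

// Labels of the k examples most similar to `input`, most similar first.
// `examples` is an array of {vector, label}. k is clamped to the number
// of examples when fewer than k are available.
function nearestLabels(examples, input, k = 3) {
  if (!Number.isInteger(k) || k < 1) {
    throw new Error('k must be a positive integer');
  }
  if (examples.length === 0) {
    throw new Error('Add examples before predicting');
  }
  const kVal = Math.min(k, examples.length);
  return examples
    .map(({vector, label}) => ({label, score: cosine(vector, input)}))
    .sort((x, y) => y.score - x.score)
    .slice(0, kVal)
    .map((entry) => entry.label);
}

// Majority vote over the neighbor labels. Counting in nearest-first order and
// requiring a strictly higher count means ties go to the label seen first.
function majorityLabel(labels) {
  const votes = new Map();
  for (const label of labels) {
    votes.set(label, (votes.get(label) || 0) + 1);
  }
  let best = null;
  let bestCount = 0;
  for (const [label, count] of votes) {
    if (count > bestCount) {
      best = label;
      bestCount = count;
    }
  }
  return best;
}
```

For instance, with two examples labeled 'a' at [1, 0] and [0.9, 0.1] and one labeled 'b' at [0, 1], the input [1, 0.05] gets the neighbors ['a', 'a', 'b'] and the vote picks 'a'.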

Returns an object where:

  • label: the label (class name) with the highest confidence.
  • classIndex: the 0-based index of the class (for backwards compatibility).
  • confidences: maps each label to its confidence score.
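
An object of this shape can be modeled from the labels of the k nearest neighbors. In this hypothetical sketch, a label's confidence is the fraction of those neighbors that carry it (an assumption about how the package scores labels), and `classIndex` is the label's position in a caller-supplied list of labels in first-seen order. The function name `toPrediction` is illustrative only.

```javascript
// Hypothetical sketch: build a {label, classIndex, confidences} object from
// the labels of the k nearest neighbors. `allLabels` lists every known label
// in first-seen order (assumed to define classIndex).
function toPrediction(neighborLabels, allLabels) {
  if (neighborLabels.length === 0) {
    throw new Error('Need at least one neighbor');
  }
  // Count one vote per neighbor; labels with no votes stay at 0.
  const votes = new Map(allLabels.map((label) => [label, 0]));
  for (const label of neighborLabels) {
    if (!votes.has(label)) {
      throw new Error(`Unknown label: ${label}`);
    }
    votes.set(label, votes.get(label) + 1);
  }
  // Confidence = share of the k neighbors voting for the label.
  const confidences = {};
  let bestLabel = null;
  let bestConfidence = -1;
  for (const [label, count] of votes) {
    const confidence = count / neighborLabels.length;
    confidences[label] = confidence;
    if (confidence > bestConfidence) {
      bestConfidence = confidence;
      bestLabel = label;
    }
  }
  return {
    label: bestLabel,
    classIndex: allLabels.indexOf(bestLabel),
    confidences,
  };
}
```

With this scoring the confidences always sum to 1, and a label that received no votes still appears in the map with a confidence of 0.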

Misc

Clear all examples for a class.
classifier.clearClass(label: number|string)

Args:

  • label: The label to clear all examples for.

Clear all examples from all classes
classifier.clearAllClasses()

Get the example count for each class
classifier.getClassExampleCount(): {[label: string]: number}

Returns an object that maps label name to example count for that label.
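
Because each prediction looks at only k neighbors, a class with very few examples can easily be outvoted. A hypothetical helper like the one below (not part of the package) could inspect the object returned by getClassExampleCount before calling predictClass. Using k as the minimum count per class is an illustrative choice, not a library requirement.

```javascript
// Hypothetical helper: list labels with fewer than `k` examples, given the
// {label: count} object returned by classifier.getClassExampleCount().
function underrepresentedLabels(counts, k = 3) {
  return Object.entries(counts)
    .filter(([, count]) => count < k)
    .map(([label]) => label)
    .sort();
}
```

For example, underrepresentedLabels({cat: 5, dog: 2}) returns ['dog'], which suggests adding more dog examples before relying on predictions.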

Get the full dataset, useful for saving state.
classifier.getClassifierDataset(): {[label: string]: Tensor2D}

Set the full dataset, useful for restoring state.
classifier.setClassifierDataset(dataset: {[label: string]: Tensor2D})

Args:

  • dataset: A map from each label to a Tensor2D of its examples, in the format returned by getClassifierDataset. Useful for restoring state.
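
JSON.stringify does not capture a tensor's values, so saving the dataset usually means converting each Tensor2D to plain arrays first. Below is a minimal sketch, assuming you want a JSON string (for example, to keep in localStorage). The helper names `serializeDataset` and `deserializeDataset` are hypothetical; the sketch relies only on the TensorFlow.js calls `tensor.arraySync()` and `tf.tensor2d()`, and takes `tf` as a parameter so it has no hard import.

```javascript
// Hypothetical helper: turn the {label: Tensor2D} map from
// classifier.getClassifierDataset() into a JSON string.
function serializeDataset(dataset) {
  const plain = {};
  for (const [label, tensor] of Object.entries(dataset)) {
    // arraySync() returns a nested number[][] of shape [numExamples, numFeatures].
    plain[label] = tensor.arraySync();
  }
  return JSON.stringify(plain);
}

// Hypothetical helper: rebuild a {label: Tensor2D} map suitable for
// classifier.setClassifierDataset(). `tf` is the TensorFlow.js module.
function deserializeDataset(json, tf) {
  const dataset = {};
  for (const [label, rows] of Object.entries(JSON.parse(json))) {
    dataset[label] = tf.tensor2d(rows); // shape is inferred from the nested array
  }
  return dataset;
}
```

For example, you might save with `localStorage.setItem('knnDataset', serializeDataset(classifier.getClassifierDataset()))` and later restore with `classifier.setClassifierDataset(deserializeDataset(localStorage.getItem('knnDataset'), tf))`; the storage key is arbitrary.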

Get the total number of classes
classifier.getNumClasses(): number

Dispose the classifier and all internal state

Frees the WebGL memory held by the classifier. Call this when you no longer need the classifier in your application.

classifier.dispose()

Install

npm i @tensorflow-models/knn-classifier

License

Apache-2.0

Unpacked Size

94.5 kB

Total Files

20

Collaborators

  • fengwuyao
  • linchan
  • caisq
  • pyu10055
  • annxingyuan
  • mattsoulanille
  • linazhao128
  • jinjingforever