
TensorFlow.js

1. Introduction


2. Introduction to Machine Learning and Neural Networks

a. Introduction to machine learning


b. Introduction to neural networks


c. Training a neural network


3. TensorFlow.js

Official TensorFlow website

a. Introduction to TensorFlow.js


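Example demos: Emoji Scavenger Hunt, speech commands, image classification with a pre-trained model, and controlling Pac-Man with a webcam.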

b. Installing TensorFlow.js


Installing TensorFlow.js in the browser:

html
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script>

Or install it with npm or yarn:

bash
npm install @tensorflow/tfjs

bash
yarn add @tensorflow/tfjs


Installing TensorFlow.js for Node.js


c. Why use tensors?


Each element of shape is the length of the array at the corresponding nesting level.
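tf.tensor() infers the shape from the nesting of the input array. As a plain-JavaScript sketch of that rule (inferShape is a hypothetical helper, not a TensorFlow.js API, and it assumes every level of the array is rectangular):

```javascript
// Hypothetical helper: derive a tensor-style shape from a nested array.
// One entry per nesting level; each entry is the length of the array at that level.
function inferShape(value) {
  const shape = [];
  let level = value;
  while (Array.isArray(level)) {
    shape.push(level.length);
    level = level[0]; // assumes a rectangular array, so the first element is representative
  }
  return shape;
}

console.log(inferShape([1, 2, 3, 4]));             // [4]
console.log(inferShape([[1, 2, 3], [4, 5, 6]]));   // [2, 3]
console.log(inferShape([[[1], [2]], [[3], [4]]])); // [2, 2, 1]
```

A plain number has an empty shape ([]), which corresponds to a scalar tensor.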


javascript
import * as tf from '@tensorflow/tfjs';

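// Input vector and 4 x 4 weight matrix
const input = [1, 2, 3, 4];
const w = [[1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]];
const output = [0, 0, 0, 0];

// 1. Compute with nested for loops: output[i] = sum over j of w[i][j] * input[j]
for (let i = 0; i < w.length; i++) {
    for (let j = 0; j < input.length; j++) {
        output[i] += input[j] * w[i][j];
    }
}

console.log(output); // [30, 40, 50, 60]

// 2. Compute with TensorFlow.js
// Multiply the weight matrix by the input vector
tf.tensor(w).dot(tf.tensor(input)).print(); // [30, 40, 50, 60]
// Besides being more concise, tensor operations are vectorized and can be
// accelerated on the GPU, which can improve performance further.

// dot API reference: https://js.tensorflow.org/api/latest/#dot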


4. Linear Regression

a. The linear regression task


b. Preparing and visualizing the training data


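Install tfjs-vis to visualize the data (for example with npm install @tensorflow/tfjs-vis; it is imported as tfvis below).

tfjs-vis API reference: https://js.tensorflow.org/api_vis/1.5.1/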

Scatter plot:


javascript
import * as tfvis from '@tensorflow/tfjs-vis';

window.onload = async () => {
    const xs = [1, 2, 3, 4];
    const ys = [1, 3, 5, 7];

    tfvis.render.scatterplot(
        { name: 'Linear regression training set' },
        { values: xs.map((x, i) => ({ x, y: ys[i] })) },
        { xAxisDomain: [0, 5], yAxisDomain: [0, 8] }
    );

};


c. Defining the model: a network with one layer containing one neuron

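Sequential model: https://js.tensorflow.org/api/3.16.0/#sequential

Dense layer: https://js.tensorflow.org/api/3.16.0/#layers.dense

A dense layer computes output = activation(dot(input, kernel) + bias), which has the same form as the linear model y = wx + b, so the model is built with tf.layers.dense():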

javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';

window.onload = async () => {
    const xs = [1, 2, 3, 4];
    const ys = [1, 3, 5, 7];

    tfvis.render.scatterplot(
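        { name: 'Linear regression training set' },
        { values: xs.map((x, i) => ({ x, y: ys[i] })) },
        { xAxisDomain: [0, 5], yAxisDomain: [0, 8] }
    );

    // Create the model
    const model = tf.sequential();

    // Add a layer. A dense layer computes output = activation(dot(input, kernel) + bias),
    // which matches the linear model, so tf.layers.dense is used.
    // units is the number of neurons; inputShape [1] means each input is a 1-D vector of length 1.
    model.add(tf.layers.dense({ units: 1, inputShape: [1] }));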
};
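The dense layer formula can be sketched in plain JavaScript. denseForward is a hypothetical helper, not the tf.js implementation; it lays the kernel out as an inDim × units matrix. With units 1 and inputShape [1] it reduces to y = w·x + b, and the hand-picked values w = 2, b = -1 happen to fit the four training points exactly.

```javascript
// Hypothetical helper: what a dense layer computes,
// output[u] = activation(sum_i input[i] * kernel[i][u] + bias[u]).
// kernel is an inDim x units matrix; bias has one entry per unit.
function denseForward(input, kernel, bias, activation = (z) => z) {
  return bias.map((b, u) => {
    let z = b;
    for (let i = 0; i < input.length; i++) {
      z += input[i] * kernel[i][u];
    }
    return activation(z); // default: linear (no activation), as in the model above
  });
}

// units: 1, inputShape: [1] -> y = w * x + b with hand-picked w = 2, b = -1.
const denseKernel = [[2]];
const denseBias = [-1];
console.log([1, 2, 3, 4].map((x) => denseForward([x], denseKernel, denseBias)[0])); // [1, 3, 5, 7]
```

Training searches for kernel and bias values like these automatically instead of picking them by hand.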

d. Loss function: mean squared error (MSE)


Related tutorial:


javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';

window.onload = async () => {
    const xs = [1, 2, 3, 4];
    const ys = [1, 3, 5, 7];

    tfvis.render.scatterplot(
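        { name: 'Linear regression training set' },
        { values: xs.map((x, i) => ({ x, y: ys[i] })) },
        { xAxisDomain: [0, 5], yAxisDomain: [0, 8] }
    );

    // Create the model
    const model = tf.sequential();

    // Add a dense layer: output = activation(dot(input, kernel) + bias).
    // units is the number of neurons; inputShape [1] means each input is a 1-D vector of length 1.
    model.add(tf.layers.dense({ units: 1, inputShape: [1] }));

    // Pass the mean squared error (MSE) loss function to compile().
    // An optimizer is added in the next step; compile() expects one before training.
    model.compile({ loss: tf.losses.meanSquaredError });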
};
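Mean squared error averages the squared difference between each prediction and its label. A plain-JavaScript sketch for illustration (tf.losses.meanSquaredError performs the equivalent computation on tensors):

```javascript
// Mean squared error: average of (prediction - label)^2 over all samples.
function meanSquaredError(labels, predictions) {
  let sum = 0;
  for (let i = 0; i < labels.length; i++) {
    const diff = predictions[i] - labels[i];
    sum += diff * diff;
  }
  return sum / labels.length;
}

const mseLabels = [1, 3, 5, 7];
console.log(meanSquaredError(mseLabels, [1, 3, 5, 7])); // 0: the line y = 2x - 1 fits exactly
console.log(meanSquaredError(mseLabels, [0, 0, 0, 0])); // 21 = (1 + 9 + 25 + 49) / 4
```

Squaring makes large errors count much more than small ones, and the loss is 0 only when every prediction matches its label.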

e. Optimizer: stochastic gradient descent (SGD)


Iterative approach: https://developers.google.com/machine-learning/crash-course/reducing-loss/an-iterative-approach

Gradient descent: https://developers.google.com/machine-learning/crash-course/reducing-loss/gradient-descent

Learning rate: https://developers.google.com/machine-learning/crash-course/reducing-loss/learning-rate

Optimizing the learning rate: https://developers.google.com/machine-learning/crash-course/fitter/graph

Stochastic gradient descent: https://developers.google.cn/machine-learning/crash-course/reducing-loss/stochastic-gradient-descent


javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';

window.onload = async () => {
    const xs = [1, 2, 3, 4];
    const ys = [1, 3, 5, 7];

    tfvis.render.scatterplot(
        { name: 'Linear regression training set' },
        { values: xs.map((x, i) => ({ x, y: ys[i] })) },
        { xAxisDomain: [0, 5], yAxisDomain: [0, 8] }
    );

    // Create the model
    const model = tf.sequential();

    // Add a dense layer: output = activation(dot(input, kernel) + bias).
    // units is the number of neurons; inputShape [1] means each input is a 1-D vector of length 1.
    model.add(tf.layers.dense({ units: 1, inputShape: [1] }));

    // Compile with the MSE loss and a stochastic gradient descent (SGD) optimizer with learning rate 0.1
    model.compile({ loss: tf.losses.meanSquaredError, optimizer: tf.train.sgd(0.1) });
};
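As a sketch of what the optimizer does on each update, assuming the whole training set is one batch (batchSize 4 with 4 samples) and using hand-derived MSE gradients for y = w·x + b (tf.train.sgd obtains gradients automatically; gradientStep is a hypothetical helper):

```javascript
// One full-batch gradient-descent step on the MSE loss of y = w * x + b.
function gradientStep(w, b, xs, ys, learningRate) {
  const n = xs.length;
  let gradW = 0;
  let gradB = 0;
  for (let i = 0; i < n; i++) {
    const error = w * xs[i] + b - ys[i]; // prediction minus label
    gradW += (2 / n) * error * xs[i];    // dLoss/dw
    gradB += (2 / n) * error;            // dLoss/db
  }
  // Move against the gradient, scaled by the learning rate.
  return [w - learningRate * gradW, b - learningRate * gradB];
}

const sgdXs = [1, 2, 3, 4];
const sgdYs = [1, 3, 5, 7];
let sgdParams = gradientStep(0, 0, sgdXs, sgdYs, 0.1);
console.log(sgdParams); // about [2.5, 0.8] after one step from w = 0, b = 0
for (let step = 0; step < 2000; step++) {
  sgdParams = gradientStep(sgdParams[0], sgdParams[1], sgdXs, sgdYs, 0.1);
}
console.log(sgdParams); // approaches [2, -1], i.e. y = 2x - 1
```

From w = b = 0 the first gradients are dL/dw = -25 and dL/db = -8, so a step with learning rate 0.1 moves to w = 2.5, b = 0.8. With smaller batches, each step would use only part of the data, which is what makes the method "stochastic".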

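TensorFlow.js also provides other optimizers, for example tf.train.momentum, tf.train.adagrad, tf.train.rmsprop, and tf.train.adam (used later in this tutorial).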
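Adam (tf.train.adam) is used later in this tutorial. A sketch of the standard Adam update rule on the one-dimensional quadratic f(θ) = (θ - 3)², using the commonly cited hyperparameters β1 = 0.9, β2 = 0.999, ε = 1e-8 (an assumption; the tf.js defaults may differ). adamMinimize is a hypothetical helper, not tf.js code.

```javascript
// Sketch of the Adam update rule for a single parameter.
function adamMinimize(gradFn, theta0, learningRate, steps,
                      beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8) {
  let theta = theta0;
  let m = 0; // running average of gradients
  let v = 0; // running average of squared gradients
  for (let t = 1; t <= steps; t++) {
    const g = gradFn(theta);
    m = beta1 * m + (1 - beta1) * g;
    v = beta2 * v + (1 - beta2) * g * g;
    const mHat = m / (1 - Math.pow(beta1, t)); // bias correction for the zero start
    const vHat = v / (1 - Math.pow(beta2, t));
    theta -= learningRate * mHat / (Math.sqrt(vHat) + epsilon);
  }
  return theta;
}

// Gradient of f(theta) = (theta - 3)^2
const quadGrad = (theta) => 2 * (theta - 3);
console.log(adamMinimize(quadGrad, 0, 0.1, 1));   // about 0.1: the first step has size ~learningRate
console.log(adamMinimize(quadGrad, 0, 0.1, 500)); // close to the minimum at 3
```

Because Adam divides by a running estimate of the gradient's magnitude, its first step has size close to the learning rate no matter how large the gradient is.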


f. Training the model and visualizing the training process


javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';

window.onload = async () => {
    const xs = [1, 2, 3, 4];
    const ys = [1, 3, 5, 7];

    tfvis.render.scatterplot(
        { name: 'Linear regression training set' },
        { values: xs.map((x, i) => ({ x, y: ys[i] })) },
        { xAxisDomain: [0, 5], yAxisDomain: [0, 8] }
    );

    // Create the model
    const model = tf.sequential();

    // Add a dense layer: output = activation(dot(input, kernel) + bias).
    // units is the number of neurons; inputShape [1] means each input is a 1-D vector of length 1.
    model.add(tf.layers.dense({ units: 1, inputShape: [1] }));

    // Compile with the MSE loss and an SGD optimizer with learning rate 0.1
    model.compile({ loss: tf.losses.meanSquaredError, optimizer: tf.train.sgd(0.1) });

    // Convert the inputs and labels to tensors
    const inputs = tf.tensor(xs);
    const labels = tf.tensor(ys);

    // Train the model (fit is asynchronous and returns a Promise)
    await model.fit(inputs, labels, {
        // batchSize: number of samples used for each weight update
        batchSize: 4,
        epochs: 200,
        callbacks: tfvis.show.fitCallbacks(
            { name: 'Training process' },
            ['loss']
        )
    });
};

Comparing the visualized results for different training parameters:

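The vertical axis shows the loss value. The onBatchEnd chart is updated after every mini-batch; the onEpochEnd chart is updated after every epoch (one full pass over the training data).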


With epochs set to 100 and batchSize set to 1, each epoch processes the 4 training samples as 4 batches, so there are 400 batches in total.
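The batch count is plain arithmetic: each epoch splits the samples into ceil(numSamples / batchSize) batches, and each batch triggers one weight update (a smaller final batch still counts). A sketch with a hypothetical helper:

```javascript
// Number of batches (weight updates) in a training run.
function countBatches(numSamples, batchSize, epochs) {
  return epochs * Math.ceil(numSamples / batchSize);
}

console.log(countBatches(4, 1, 100)); // 400: one update per sample per epoch
console.log(countBatches(4, 4, 100)); // 100: one full-batch update per epoch
```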


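With batchSize set to 4, the curve becomes smoother and no longer jitters. With batchSize 1 the model is fed one sample at a time, so each stochastic gradient descent update depends on a single sample; early in training these updates vary a lot, which shows up as jitter.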


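The learning rate should not be too high either. With a learning rate of 0.15 instead of 0.1, the loss is still rising at the end of training, so the model cannot predict the data well.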


With a learning rate of 0.01 instead of 0.1, the accuracy is good, but learning becomes much slower.
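These observations can be checked with a little arithmetic, assuming full-batch gradient descent (batchSize 4 with 4 samples) on the MSE loss of y = w·x + b, starting from w = b = 0 rather than tf.js's random initial weights. The loss is quadratic in (w, b) with Hessian (2/n)·[[Σx², Σx], [Σx, n]] = [[15, 5], [5, 2]], whose largest eigenvalue is about 16.7. Plain gradient descent converges only when the learning rate is below 2/16.7 ≈ 0.12, which is consistent with 0.15 diverging, while 0.01 is stable but shrinks the error slowly. A sketch that computes the bound and compares the three learning rates (helper names are hypothetical):

```javascript
// Full-batch gradient descent on MSE for y = w * x + b, starting at w = b = 0.
// This does not reproduce tf.js's random weight initialization.
const lrXs = [1, 2, 3, 4];
const lrYs = [1, 3, 5, 7];

function mseAt(w, b) {
  let sum = 0;
  for (let i = 0; i < lrXs.length; i++) {
    const diff = w * lrXs[i] + b - lrYs[i];
    sum += diff * diff;
  }
  return sum / lrXs.length;
}

function trainLoss(learningRate, epochs) {
  const n = lrXs.length;
  let w = 0;
  let b = 0;
  for (let e = 0; e < epochs; e++) {
    let gradW = 0;
    let gradB = 0;
    for (let i = 0; i < n; i++) {
      const error = w * lrXs[i] + b - lrYs[i];
      gradW += (2 / n) * error * lrXs[i];
      gradB += (2 / n) * error;
    }
    w -= learningRate * gradW;
    b -= learningRate * gradB;
  }
  return mseAt(w, b);
}

// Hessian of the MSE loss: (2 / n) * [[sum x^2, sum x], [sum x, n]] = [[15, 5], [5, 2]].
const numPoints = lrXs.length;
const hXX = (2 / numPoints) * lrXs.reduce((s, x) => s + x * x, 0);
const hXB = (2 / numPoints) * lrXs.reduce((s, x) => s + x, 0);
const hBB = 2;
// Largest eigenvalue of the symmetric 2 x 2 matrix [[hXX, hXB], [hXB, hBB]].
const lambdaMax = (hXX + hBB) / 2 + Math.sqrt(((hXX - hBB) / 2) ** 2 + hXB * hXB);
console.log(2 / lambdaMax); // about 0.12: gradient descent diverges above this learning rate

for (const lr of [0.01, 0.1, 0.15]) {
  console.log(lr, trainLoss(lr, 200)); // 0.15 diverges; 0.01 has not converged yet
}
```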


g. Making predictions


javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';

window.onload = async () => {
    const xs = [1, 2, 3, 4];
    const ys = [1, 3, 5, 7];

    tfvis.render.scatterplot(
        { name: 'Linear regression training set' },
        { values: xs.map((x, i) => ({ x, y: ys[i] })) },
        { xAxisDomain: [0, 5], yAxisDomain: [0, 8] }
    );

    // Create the model
    const model = tf.sequential();

    // Add a dense layer: output = activation(dot(input, kernel) + bias).
    // units is the number of neurons; inputShape [1] means each input is a 1-D vector of length 1.
    model.add(tf.layers.dense({ units: 1, inputShape: [1] }));

    // Compile with the MSE loss and an SGD optimizer with learning rate 0.1
    model.compile({ loss: tf.losses.meanSquaredError, optimizer: tf.train.sgd(0.1) });

    // Convert the inputs and labels to tensors
    const inputs = tf.tensor(xs);
    const labels = tf.tensor(ys);

    // Train the model (fit is asynchronous and returns a Promise)
    await model.fit(inputs, labels, {
        // batchSize: number of samples used for each weight update
        batchSize: 4,
        epochs: 200,
        callbacks: tfvis.show.fitCallbacks(
            { name: 'Training process' },
            ['loss']
        )
    });

    // Predict y for x = 5
    const output = model.predict(tf.tensor([5]));
    alert(`If x is 5, the predicted y is ${output.dataSync()[0]}`);
};

5. Normalization


javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';

window.onload = async () => {
    const heights = [150, 160, 170];
    const weights = [40, 50, 60];

    tfvis.render.scatterplot(
        { name: 'Height/weight training data' },
        { values: heights.map((x, i) => ({ x, y: weights[i] })) },
        {
            xAxisDomain: [140, 180],
            yAxisDomain: [30, 70]
        }
    );

    // Min-max normalization: subtract the minimum and divide by the range
    // (170 - 150 = 20 for heights, 60 - 40 = 20 for weights), scaling both to [0, 1]
    const inputs = tf.tensor(heights).sub(150).div(20);
    const labels = tf.tensor(weights).sub(40).div(20);

    const model = tf.sequential();
    model.add(tf.layers.dense({ units: 1, inputShape: [1] }));
    model.compile({ loss: tf.losses.meanSquaredError, optimizer: tf.train.sgd(0.1) });

    await model.fit(inputs, labels, {
        batchSize: 3,
        epochs: 200,
        callbacks: tfvis.show.fitCallbacks(
            { name: 'Training process' },
            ['loss']
        )
    });

    // Inputs for prediction must be normalized in the same way as the training data
    const output = model.predict(tf.tensor([180]).sub(150).div(20));
    // The prediction is also normalized, so denormalize it back to kilograms
    alert(`If the height is 180 cm, the predicted weight is ${output.mul(20).add(40).dataSync()[0]} kg`);
};
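The same min-max scaling and its inverse, as plain-JavaScript helpers (hypothetical names) using the constants above:

```javascript
// Min-max scaling and its inverse, with the constants used above.
function normalize(value, min, range) {
  return (value - min) / range;
}

function denormalize(scaled, min, range) {
  return scaled * range + min;
}

console.log(normalize(160, 150, 20)); // 0.5
console.log(normalize(180, 150, 20)); // 1.5: inputs outside the training range land outside [0, 1]
// After scaling, the training points lie on the line scaledWeight = scaledHeight,
// so a model that learned exactly that line would predict 1.5, which maps back to:
console.log(denormalize(1.5, 40, 20)); // 70 (kg)
```

The statistics (min and range) must come from the training data and be reused unchanged for every prediction; otherwise the model sees inputs on a different scale than it was trained on.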


6. Logistic Regression


a. Loading a binary classification dataset


data.js:

javascript
export function getData(numSamples) {
    let points = [];
  
    function genGauss(cx, cy, label) {
      for (let i = 0; i < numSamples / 2; i++) {
        let x = normalRandom(cx);
        let y = normalRandom(cy);
        points.push({ x, y, label });
      }
    }
  
    genGauss(2, 2, 1);
    genGauss(-2, -2, 0);
    return points;
  }
  
  /**
   * Samples from a normal distribution (polar Box-Muller method), using
   * Math.random() as the uniform random generator.
   *
   * @param mean The mean. Default is 0.
   * @param variance The variance. Default is 1.
   */
  function normalRandom(mean = 0, variance = 1) {
    let v1, v2, s;
    do {
      v1 = 2 * Math.random() - 1;
      v2 = 2 * Math.random() - 1;
      s = v1 * v1 + v2 * v2;
    } while (s > 1);
  
    let result = Math.sqrt(-2 * Math.log(s) / s) * v1;
    return mean + Math.sqrt(variance) * result;
  }


javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data.js';

window.onload = async () => {
    // Generate 400 data points
    const data = getData(400);
    // console.log('data', data)
    tfvis.render.scatterplot(
        { name: 'Logistic regression training data' },
        {
            values: [
                data.filter(p => p.label === 1),
                data.filter(p => p.label === 0),
            ]
        }
    );
};


b. Defining the model: a single neuron with an activation function


The sigmoid function:
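A plain-JavaScript sketch of sigmoid(z) = 1 / (1 + e^(-z)), the activation the single neuron uses below:

```javascript
// sigmoid maps any real number into (0, 1),
// so the neuron's output can be read as the probability of label 1.
function sigmoid(z) {
  return 1 / (1 + Math.exp(-z));
}

console.log(sigmoid(0));  // 0.5
console.log(sigmoid(4));  // about 0.982
console.log(sigmoid(-4)); // about 0.018
```

The curve is symmetric: sigmoid(-z) = 1 - sigmoid(z), so an input of 0 sits exactly on the decision boundary at probability 0.5.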


c. Loss function: log loss

Log loss: https://developers.google.cn/machine-learning/crash-course/logistic-regression/model-training?hl=zh_cn


javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data.js';

window.onload = async () => {
    const data = getData(400);
    console.log('data', data)

    tfvis.render.scatterplot(
        { name: 'Logistic regression training data' },
        {
            values: [
                data.filter(p => p.label === 1),
                data.filter(p => p.label === 0),
            ]
        }
    );

    const model = tf.sequential();
    model.add(tf.layers.dense({
        units: 1,
        inputShape: [2],
        activation: 'sigmoid'
    }));
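    model.compile({
        // Use the log loss function
        // (an optimizer is added in the next step; compile() expects one before training)
        loss: tf.losses.logLoss,
    });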
};
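A plain-JavaScript sketch of the log loss, -[y·log(p) + (1 - y)·log(1 - p)], averaged over samples. The small epsilon that keeps log() finite is an assumption here; the exact value tf.losses.logLoss uses may differ.

```javascript
// Log loss (binary cross-entropy) averaged over samples.
// labels are 0 or 1; predictions are probabilities of label 1.
function logLoss(labels, predictions, epsilon = 1e-7) {
  let sum = 0;
  for (let i = 0; i < labels.length; i++) {
    const y = labels[i];
    const p = predictions[i];
    sum += -(y * Math.log(p + epsilon) + (1 - y) * Math.log(1 - p + epsilon));
  }
  return sum / labels.length;
}

console.log(logLoss([1, 0], [0.9, 0.1])); // about 0.105: confident and correct
console.log(logLoss([1, 0], [0.1, 0.9])); // about 2.303: confident and wrong
```

Unlike MSE, log loss grows without bound as a confident prediction approaches the wrong label, which pushes the sigmoid output strongly toward the correct class.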

d. Training the model and visualizing the training process

javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data.js';

window.onload = async () => {
    const data = getData(400);
    console.log('data', data)

    tfvis.render.scatterplot(
        { name: 'Logistic regression training data' },
        {
            values: [
                data.filter(p => p.label === 1),
                data.filter(p => p.label === 0),
            ]
        }
    );

    const model = tf.sequential();
    model.add(tf.layers.dense({
        units: 1,
        inputShape: [2],
        activation: 'sigmoid'
    }));
    model.compile({
        loss: tf.losses.logLoss,
        optimizer: tf.train.adam(0.1)
    });
 
    // Convert the training data to tensors
    const inputs = tf.tensor(data.map(p => [p.x, p.y]));
    const labels = tf.tensor(data.map(p => p.label));

    // Start training
    await model.fit(inputs, labels, {
        batchSize: 40,
        epochs: 20,
        callbacks: tfvis.show.fitCallbacks(
            { name: 'Training performance' },
            ['loss']
        )
    });
};


f. Making predictions


html
<form action="" onsubmit="predict(this);return false;">
    x: <input type="text" name="x">
    y: <input type="text" name="y">
    <button type="submit">Predict</button>
</form>
javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data.js';

window.onload = async () => {
    const data = getData(400);
    console.log('data', data)

    tfvis.render.scatterplot(
        { name: 'Logistic regression training data' },
        {
            values: [
                data.filter(p => p.label === 1),
                data.filter(p => p.label === 0),
            ]
        }
    );

    const model = tf.sequential();
    model.add(tf.layers.dense({
        units: 1,
        inputShape: [2],
        activation: 'sigmoid'
    }));
    model.compile({
        loss: tf.losses.logLoss,
        optimizer: tf.train.adam(0.1)
    });

    const inputs = tf.tensor(data.map(p => [p.x, p.y]));
    const labels = tf.tensor(data.map(p => p.label));

    await model.fit(inputs, labels, {
        batchSize: 40,
        epochs: 20,
        callbacks: tfvis.show.fitCallbacks(
            { name: 'Training performance' },
            ['loss']
        )
    });

    window.predict = (form) => {
        // The sigmoid output is the predicted probability that the point has label 1
        const pred = model.predict(tf.tensor([[form.x.value * 1, form.y.value * 1]]));
        alert(`Prediction: ${pred.dataSync()[0]}`);
    };
};


g. Walkthrough of the binary classification data generator

javascript
export function getData(numSamples) {
    let points = [];
   
    // Generate numSamples / 2 points with the given label whose x and y follow a Gaussian (normal) distribution centered at (cx, cy)
    function genGauss(cx, cy, label) {
      for (let i = 0; i < numSamples / 2; i++) {
        let x = normalRandom(cx);
        let y = normalRandom(cy);
        points.push({ x, y, label });
      }
    }
  
    genGauss(2, 2, 1);
    genGauss(-2, -2, 0);
    return points;
  }
  
  /**
   * Samples from a normal distribution (polar Box-Muller method), using
   * Math.random() as the uniform random generator.
   *
   * @param mean The mean. Default is 0.
   * @param variance The variance. Default is 1.
   */
  function normalRandom(mean = 0, variance = 1) {
    let v1, v2, s;
    do {
      v1 = 2 * Math.random() - 1;
      v2 = 2 * Math.random() - 1;
      s = v1 * v1 + v2 * v2;
    } while (s > 1);
  
    let result = Math.sqrt(-2 * Math.log(s) / s) * v1;
    return mean + Math.sqrt(variance) * result;
  }


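The original Box-Muller formula uses trigonometric functions, which are relatively slow, so an equivalent formula derived from it (the polar form) can be used instead: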

Box-Muller transform


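Here v1 and v2 correspond to u and v in the formula above, s = u² + v², and the polar form computes z = u · sqrt(-2 ln(s) / s).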

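Because the point (u, v) must lie inside the unit circle, s (the squared radius) must be less than 1: when s > 1 the pair is discarded and new values are drawn. The resulting function for generating Gaussian random values is: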

javascript
function normalRandom(mean = 0, variance = 1) {
    let v1, v2, s;
    do {
      v1 = 2 * Math.random() - 1;
      v2 = 2 * Math.random() - 1;
      s = v1 * v1 + v2 * v2;
    } while (s > 1);
  
    let result = Math.sqrt(-2 * Math.log(s) / s) * v1;
    // Scale by the standard deviation sqrt(variance) and shift by mean (mean sets the center)
    return mean + Math.sqrt(variance) * result;
  }
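For comparison, the basic (trigonometric) Box-Muller transform that the polar form avoids: with u1 in (0, 1] and u2 in [0, 1), z = sqrt(-2 ln u1) · cos(2π u2) is standard normal. The sketch below (boxMullerBasic, boxMullerPolar and sampleStats are hypothetical names) implements both and checks that each produces mean ≈ 0 and variance ≈ 1. The polar version also rejects s = 0, a case the tutorial's loop does not guard against, where log(s) / s is undefined.

```javascript
// Basic (trigonometric) Box-Muller transform.
function boxMullerBasic(mean = 0, variance = 1) {
  const u1 = 1 - Math.random(); // in (0, 1], so Math.log(u1) is finite
  const u2 = Math.random();
  const z = Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
  return mean + Math.sqrt(variance) * z;
}

// Polar form, as in normalRandom, with an extra guard against s === 0.
function boxMullerPolar(mean = 0, variance = 1) {
  let v1, v2, s;
  do {
    v1 = 2 * Math.random() - 1;
    v2 = 2 * Math.random() - 1;
    s = v1 * v1 + v2 * v2;
  } while (s > 1 || s === 0);
  return mean + Math.sqrt(variance) * Math.sqrt(-2 * Math.log(s) / s) * v1;
}

// Sample mean and variance of `count` draws from `generate`.
function sampleStats(generate, count) {
  let sum = 0;
  let sumSq = 0;
  for (let i = 0; i < count; i++) {
    const x = generate();
    sum += x;
    sumSq += x * x;
  }
  const mean = sum / count;
  return { mean, variance: sumSq / count - mean * mean };
}

console.log(sampleStats(boxMullerBasic, 100000)); // mean near 0, variance near 1
console.log(sampleStats(boxMullerPolar, 100000)); // mean near 0, variance near 1
```

The polar form trades the sin/cos call for a rejection loop that discards about 21% of candidate pairs (1 - π/4).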

7. Multi-layer Neural Networks


TensorFlow Playground: https://playground.tensorflow.org/


a. Loading the XOR dataset


data.js:

javascript
export function getData(numSamples) {
    let points = [];
  
    function genGauss(cx, cy, label) {
      for (let i = 0; i < numSamples / 2; i++) {
        let x = normalRandom(cx);
        let y = normalRandom(cy);
        points.push({ x, y, label });
      }
    }
  
    genGauss(2, 2, 0);
    genGauss(-2, -2, 0);
    genGauss(-2, 2, 1);
    genGauss(2, -2, 1);
    return points;
  }
  
  /**
   * Samples from a normal distribution (polar Box-Muller method), using
   * Math.random() as the uniform random generator.
   *
   * @param mean The mean. Default is 0.
   * @param variance The variance. Default is 1.
   */
  function normalRandom(mean = 0, variance = 1) {
    let v1, v2, s;
    do {
      v1 = 2 * Math.random() - 1;
      v2 = 2 * Math.random() - 1;
      s = v1 * v1 + v2 * v2;
    } while (s > 1);
  
    let result = Math.sqrt(-2 * Math.log(s) / s) * v1;
    return mean + Math.sqrt(variance) * result;
  }
javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data.js';

window.onload = async () => {
    const data = getData(400);

    tfvis.render.scatterplot(
        { name: 'XOR training data' },
        {
            values: [
                data.filter(p => p.label === 1),
                data.filter(p => p.label === 0),
            ]
        }
    );
};


b. Defining the model: a multi-layer neural network


javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data.js';

window.onload = async () => {
    const data = getData(400);

    tfvis.render.scatterplot(
        { name: 'XOR training data' },
        {
            values: [
                data.filter(p => p.label === 1),
                data.filter(p => p.label === 0),
            ]
        }
    );

    const model = tf.sequential();
    // Hidden layer: 4 neurons with ReLU activation
    model.add(tf.layers.dense({
        units: 4,
        inputShape: [2],
        activation: 'relu'
    }));
    // Output layer: 1 neuron with sigmoid activation (probability of label 1)
    model.add(tf.layers.dense({
        units: 1,
        activation: 'sigmoid'
    }));
};
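A single neuron draws one straight decision boundary, which cannot separate the XOR quadrants; the hidden layer is what makes it possible. As an illustration only, here are hand-picked weights (not what training will find, and not from the original tutorial) for the same 2-input, 4-ReLU-hidden-unit, 1-sigmoid-output architecture that classify the four cluster centers correctly:

```javascript
// Hand-picked (not trained) weights for the same architecture:
// 2 inputs -> 4 ReLU hidden units -> 1 sigmoid output.
const relu = (z) => Math.max(0, z);
const sigmoidFn = (z) => 1 / (1 + Math.exp(-z));

// Column u holds the weights from (x, y) to hidden unit u, so the hidden units
// compute relu(x + y), relu(-x - y), relu(x - y) and relu(y - x).
const hiddenKernel = [
  [1, -1, 1, -1], // weights from x
  [1, -1, -1, 1], // weights from y
];
const hiddenBias = [0, 0, 0, 0];
// Output pre-activation: relu(x - y) + relu(y - x) - relu(x + y) - relu(-x - y)
//                      = |x - y| - |x + y|, which is positive exactly when x * y < 0.
const outputKernel = [-1, -1, 1, 1];
const outputBias = 0;

function xorPredict(x, y) {
  const hidden = hiddenBias.map((bias, u) =>
    relu(x * hiddenKernel[0][u] + y * hiddenKernel[1][u] + bias));
  const z = hidden.reduce((acc, h, u) => acc + h * outputKernel[u], outputBias);
  return sigmoidFn(z); // probability of label 1
}

console.log(xorPredict(2, 2));  // about 0.018 -> label 0
console.log(xorPredict(-2, 2)); // about 0.982 -> label 1
```

The hidden units bend the input space so that a single linear output unit can separate the classes, which is exactly what the trained network has to learn.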

c. Training the model and making predictions


javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data.js';

window.onload = async () => {
    const data = getData(400);

    tfvis.render.scatterplot(
        { name: 'XOR training data' },
        {
            values: [
                data.filter(p => p.label === 1),
                data.filter(p => p.label === 0),
            ]
        }
    );

    const model = tf.sequential();
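    // Hidden layer: 4 neurons with ReLU activation
    model.add(tf.layers.dense({
        units: 4,
        inputShape: [2],
        activation: 'relu'
    }));
    // Output layer: 1 neuron with sigmoid activation (probability of label 1)
    model.add(tf.layers.dense({
        units: 1,
        activation: 'sigmoid'
    }));
    model.compile({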
        loss: tf.losses.logLoss,
        optimizer: tf.train.adam(0.1)
    });

    const inputs = tf.tensor(data.map(p => [p.x, p.y]));
    const labels = tf.tensor(data.map(p => p.label));

    await model.fit(inputs, labels, {
        epochs: 10,
        callbacks: tfvis.show.fitCallbacks(
            { name: 'Training performance' },
            ['loss']
        )
    });

    window.predict = (form) => {
        const pred = model.predict(tf.tensor([[form.x.value * 1, form.y.value * 1]]));
        alert(`Prediction: ${pred.dataSync()[0]}`);
    };
};


html
<form action="" onsubmit="predict(this);return false;">
    x: <input type="text" name="x">
    y: <input type="text" name="y">
    <button type="submit">Predict</button>
</form>


8. Multi-class Classification


a. Task overview, main steps, and prerequisites


b. Loading the Iris dataset (training and validation sets)


data.js:

javascript
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as tf from '@tensorflow/tfjs';

export const IRIS_CLASSES =
    ['Iris setosa', 'Iris versicolor', 'Iris virginica'];
export const IRIS_NUM_CLASSES = IRIS_CLASSES.length;

// Iris flowers data. Source:
//   https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
const IRIS_DATA = [
  [5.1, 3.5, 1.4, 0.2, 0], [4.9, 3.0, 1.4, 0.2, 0], [4.7, 3.2, 1.3, 0.2, 0],
  [4.6, 3.1, 1.5, 0.2, 0], [5.0, 3.6, 1.4, 0.2, 0], [5.4, 3.9, 1.7, 0.4, 0],
  [4.6, 3.4, 1.4, 0.3, 0], [5.0, 3.4, 1.5, 0.2, 0], [4.4, 2.9, 1.4, 0.2, 0],
  [4.9, 3.1, 1.5, 0.1, 0], [5.4, 3.7, 1.5, 0.2, 0], [4.8, 3.4, 1.6, 0.2, 0],
  [4.8, 3.0, 1.4, 0.1, 0], [4.3, 3.0, 1.1, 0.1, 0], [5.8, 4.0, 1.2, 0.2, 0],
  [5.7, 4.4, 1.5, 0.4, 0], [5.4, 3.9, 1.3, 0.4, 0], [5.1, 3.5, 1.4, 0.3, 0],
  [5.7, 3.8, 1.7, 0.3, 0], [5.1, 3.8, 1.5, 0.3, 0], [5.4, 3.4, 1.7, 0.2, 0],
  [5.1, 3.7, 1.5, 0.4, 0], [4.6, 3.6, 1.0, 0.2, 0], [5.1, 3.3, 1.7, 0.5, 0],
  [4.8, 3.4, 1.9, 0.2, 0], [5.0, 3.0, 1.6, 0.2, 0], [5.0, 3.4, 1.6, 0.4, 0],
  [5.2, 3.5, 1.5, 0.2, 0], [5.2, 3.4, 1.4, 0.2, 0], [4.7, 3.2, 1.6, 0.2, 0],
  [4.8, 3.1, 1.6, 0.2, 0], [5.4, 3.4, 1.5, 0.4, 0], [5.2, 4.1, 1.5, 0.1, 0],
  [5.5, 4.2, 1.4, 0.2, 0], [4.9, 3.1, 1.5, 0.1, 0], [5.0, 3.2, 1.2, 0.2, 0],
  [5.5, 3.5, 1.3, 0.2, 0], [4.9, 3.1, 1.5, 0.1, 0], [4.4, 3.0, 1.3, 0.2, 0],
  [5.1, 3.4, 1.5, 0.2, 0], [5.0, 3.5, 1.3, 0.3, 0], [4.5, 2.3, 1.3, 0.3, 0],
  [4.4, 3.2, 1.3, 0.2, 0], [5.0, 3.5, 1.6, 0.6, 0], [5.1, 3.8, 1.9, 0.4, 0],
  [4.8, 3.0, 1.4, 0.3, 0], [5.1, 3.8, 1.6, 0.2, 0], [4.6, 3.2, 1.4, 0.2, 0],
  [5.3, 3.7, 1.5, 0.2, 0], [5.0, 3.3, 1.4, 0.2, 0], [7.0, 3.2, 4.7, 1.4, 1],
  [6.4, 3.2, 4.5, 1.5, 1], [6.9, 3.1, 4.9, 1.5, 1], [5.5, 2.3, 4.0, 1.3, 1],
  [6.5, 2.8, 4.6, 1.5, 1], [5.7, 2.8, 4.5, 1.3, 1], [6.3, 3.3, 4.7, 1.6, 1],
  [4.9, 2.4, 3.3, 1.0, 1], [6.6, 2.9, 4.6, 1.3, 1], [5.2, 2.7, 3.9, 1.4, 1],
  [5.0, 2.0, 3.5, 1.0, 1], [5.9, 3.0, 4.2, 1.5, 1], [6.0, 2.2, 4.0, 1.0, 1],
  [6.1, 2.9, 4.7, 1.4, 1], [5.6, 2.9, 3.6, 1.3, 1], [6.7, 3.1, 4.4, 1.4, 1],
  [5.6, 3.0, 4.5, 1.5, 1], [5.8, 2.7, 4.1, 1.0, 1], [6.2, 2.2, 4.5, 1.5, 1],
  [5.6, 2.5, 3.9, 1.1, 1], [5.9, 3.2, 4.8, 1.8, 1], [6.1, 2.8, 4.0, 1.3, 1],
  [6.3, 2.5, 4.9, 1.5, 1], [6.1, 2.8, 4.7, 1.2, 1], [6.4, 2.9, 4.3, 1.3, 1],
  [6.6, 3.0, 4.4, 1.4, 1], [6.8, 2.8, 4.8, 1.4, 1], [6.7, 3.0, 5.0, 1.7, 1],
  [6.0, 2.9, 4.5, 1.5, 1], [5.7, 2.6, 3.5, 1.0, 1], [5.5, 2.4, 3.8, 1.1, 1],
  [5.5, 2.4, 3.7, 1.0, 1], [5.8, 2.7, 3.9, 1.2, 1], [6.0, 2.7, 5.1, 1.6, 1],
  [5.4, 3.0, 4.5, 1.5, 1], [6.0, 3.4, 4.5, 1.6, 1], [6.7, 3.1, 4.7, 1.5, 1],
  [6.3, 2.3, 4.4, 1.3, 1], [5.6, 3.0, 4.1, 1.3, 1], [5.5, 2.5, 4.0, 1.3, 1],
  [5.5, 2.6, 4.4, 1.2, 1], [6.1, 3.0, 4.6, 1.4, 1], [5.8, 2.6, 4.0, 1.2, 1],
  [5.0, 2.3, 3.3, 1.0, 1], [5.6, 2.7, 4.2, 1.3, 1], [5.7, 3.0, 4.2, 1.2, 1],
  [5.7, 2.9, 4.2, 1.3, 1], [6.2, 2.9, 4.3, 1.3, 1], [5.1, 2.5, 3.0, 1.1, 1],
  [5.7, 2.8, 4.1, 1.3, 1], [6.3, 3.3, 6.0, 2.5, 2], [5.8, 2.7, 5.1, 1.9, 2],
  [7.1, 3.0, 5.9, 2.1, 2], [6.3, 2.9, 5.6, 1.8, 2], [6.5, 3.0, 5.8, 2.2, 2],
  [7.6, 3.0, 6.6, 2.1, 2], [4.9, 2.5, 4.5, 1.7, 2], [7.3, 2.9, 6.3, 1.8, 2],
  [6.7, 2.5, 5.8, 1.8, 2], [7.2, 3.6, 6.1, 2.5, 2], [6.5, 3.2, 5.1, 2.0, 2],
  [6.4, 2.7, 5.3, 1.9, 2], [6.8, 3.0, 5.5, 2.1, 2], [5.7, 2.5, 5.0, 2.0, 2],
  [5.8, 2.8, 5.1, 2.4, 2], [6.4, 3.2, 5.3, 2.3, 2], [6.5, 3.0, 5.5, 1.8, 2],
  [7.7, 3.8, 6.7, 2.2, 2], [7.7, 2.6, 6.9, 2.3, 2], [6.0, 2.2, 5.0, 1.5, 2],
  [6.9, 3.2, 5.7, 2.3, 2], [5.6, 2.8, 4.9, 2.0, 2], [7.7, 2.8, 6.7, 2.0, 2],
  [6.3, 2.7, 4.9, 1.8, 2], [6.7, 3.3, 5.7, 2.1, 2], [7.2, 3.2, 6.0, 1.8, 2],
  [6.2, 2.8, 4.8, 1.8, 2], [6.1, 3.0, 4.9, 1.8, 2], [6.4, 2.8, 5.6, 2.1, 2],
  [7.2, 3.0, 5.8, 1.6, 2], [7.4, 2.8, 6.1, 1.9, 2], [7.9, 3.8, 6.4, 2.0, 2],
  [6.4, 2.8, 5.6, 2.2, 2], [6.3, 2.8, 5.1, 1.5, 2], [6.1, 2.6, 5.6, 1.4, 2],
  [7.7, 3.0, 6.1, 2.3, 2], [6.3, 3.4, 5.6, 2.4, 2], [6.4, 3.1, 5.5, 1.8, 2],
  [6.0, 3.0, 4.8, 1.8, 2], [6.9, 3.1, 5.4, 2.1, 2], [6.7, 3.1, 5.6, 2.4, 2],
  [6.9, 3.1, 5.1, 2.3, 2], [5.8, 2.7, 5.1, 1.9, 2], [6.8, 3.2, 5.9, 2.3, 2],
  [6.7, 3.3, 5.7, 2.5, 2], [6.7, 3.0, 5.2, 2.3, 2], [6.3, 2.5, 5.0, 1.9, 2],
  [6.5, 3.0, 5.2, 2.0, 2], [6.2, 3.4, 5.4, 2.3, 2], [5.9, 3.0, 5.1, 1.8, 2],
];

/**
 * Convert Iris data arrays to `tf.Tensor`s.
 *
 * @param data The Iris input feature data, an `Array` of `Array`s, each element
 *   of which is assumed to be a length-4 `Array` (for sepal length, sepal
 *   width, petal length, petal width).
 * @param targets An `Array` of numbers, with values from the set {0, 1, 2}:
 *   representing the true category of the Iris flower. Assumed to have the same
 *   array length as `data`.
 * @param testSplit Fraction of the data at the end to split as test data: a
 *   number between 0 and 1.
 * @return A length-4 `Array`, with
 *   - training data as `tf.Tensor` of shape [numTrainExamples, 4].
 *   - training one-hot labels as a `tf.Tensor` of shape [numTrainExamples, 3]
 *   - test data as `tf.Tensor` of shape [numTestExamples, 4].
 *   - test one-hot labels as a `tf.Tensor` of shape [numTestExamples, 3]
 */
function convertToTensors(data, targets, testSplit) {
  const numExamples = data.length;
  if (numExamples !== targets.length) {
    throw new Error('data and targets have different numbers of examples');
  }

  // Randomly shuffle `data` and `targets`.
  const indices = [];
  for (let i = 0; i < numExamples; ++i) {
    indices.push(i);
  }
  tf.util.shuffle(indices);

  const shuffledData = [];
  const shuffledTargets = [];
  for (let i = 0; i < numExamples; ++i) {
    shuffledData.push(data[indices[i]]);
    shuffledTargets.push(targets[indices[i]]);
  }

  // Split the data into a training set and a test set, based on `testSplit`.
  const numTestExamples = Math.round(numExamples * testSplit);
  const numTrainExamples = numExamples - numTestExamples;

  const xDims = shuffledData[0].length;

  // Create a 2D `tf.Tensor` to hold the feature data.
  const xs = tf.tensor2d(shuffledData, [numExamples, xDims]);

  // Create a 1D `tf.Tensor` to hold the labels, and convert the number label
  // from the set {0, 1, 2} into one-hot encoding (e.g., 0 --> [1, 0, 0]).
  const ys = tf.oneHot(tf.tensor1d(shuffledTargets).toInt(), IRIS_NUM_CLASSES);

  // Split the data into training and test sets, using `slice`.
  const xTrain = xs.slice([0, 0], [numTrainExamples, xDims]);
  const xTest = xs.slice([numTrainExamples, 0], [numTestExamples, xDims]);
  const yTrain = ys.slice([0, 0], [numTrainExamples, IRIS_NUM_CLASSES]);
  // Start the test labels at numTrainExamples so they line up with xTest.
  const yTest = ys.slice([numTrainExamples, 0], [numTestExamples, IRIS_NUM_CLASSES]);
  return [xTrain, yTrain, xTest, yTest];
}

/**
 * Obtains Iris data, split into training and test sets.
 *
 * @param testSplit Fraction of the data at the end to split as test data: a
 *   number between 0 and 1.
 *
 * @return A length-4 `Array`, with
 *   - training data as an `Array` of length-4 `Array` of numbers.
 *   - training labels as an `Array` of numbers, with the same length as the
 *     return training data above. Each element of the `Array` is from the set
 *     {0, 1, 2}.
 *   - test data as an `Array` of length-4 `Array` of numbers.
 *   - test labels as an `Array` of numbers, with the same length as the
 *     return test data above. Each element of the `Array` is from the set
 *     {0, 1, 2}.
 */
export function getIrisData(testSplit) {
  return tf.tidy(() => {
    const dataByClass = [];
    const targetsByClass = [];
    for (let i = 0; i < IRIS_CLASSES.length; ++i) {
      dataByClass.push([]);
      targetsByClass.push([]);
    }
    for (const example of IRIS_DATA) {
      const target = example[example.length - 1];
      const data = example.slice(0, example.length - 1);
      dataByClass[target].push(data);
      targetsByClass[target].push(target);
    }

    const xTrains = [];
    const yTrains = [];
    const xTests = [];
    const yTests = [];
    for (let i = 0; i < IRIS_CLASSES.length; ++i) {
      const [xTrain, yTrain, xTest, yTest] =
          convertToTensors(dataByClass[i], targetsByClass[i], testSplit);
      xTrains.push(xTrain);
      yTrains.push(yTrain);
      xTests.push(xTest);
      yTests.push(yTest);
    }

    const concatAxis = 0;
    return [
      tf.concat(xTrains, concatAxis), tf.concat(yTrains, concatAxis),
      tf.concat(xTests, concatAxis), tf.concat(yTests, concatAxis)
    ];
  });
}
javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
// getIrisData returns the training and validation sets; IRIS_CLASSES holds the class names for display
import { getIrisData, IRIS_CLASSES } from './data';

window.onload = async () => {
    // Use about 15% of the data as the validation set.
    // xTrain: training features, yTrain: training labels (one-hot),
    // xTest: validation features, yTest: validation labels (one-hot)
    const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15);
    xTrain.print();
    yTrain.print();
    xTest.print();
    yTest.print();

    console.log(IRIS_CLASSES);

    console.log(xTrain);
    console.log(yTrain);
    console.log(xTest);
    console.log(yTest);
};
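Two steps inside getIrisData can be checked by hand. The sketch below (oneHotEncode and splitSizes are hypothetical helpers) mirrors the one-hot encoding done with tf.oneHot and the split size computed with Math.round(numExamples * testSplit) in convertToTensors. Because each class (50 examples) is split separately, testSplit 0.15 gives round(7.5) = 8 test examples per class, i.e. 24 validation and 126 training examples in total.

```javascript
// Hypothetical helpers mirroring two steps of getIrisData / convertToTensors.
function oneHotEncode(index, numClasses) {
  const row = new Array(numClasses).fill(0);
  row[index] = 1; // e.g. class 0 -> [1, 0, 0]
  return row;
}

function splitSizes(numExamples, testSplit) {
  const numTest = Math.round(numExamples * testSplit); // same rounding as convertToTensors
  return { numTrain: numExamples - numTest, numTest };
}

console.log(oneHotEncode(2, 3)); // [0, 0, 1]
const perClass = splitSizes(50, 0.15); // { numTrain: 42, numTest: 8 }
console.log(perClass.numTrain * 3, perClass.numTest * 3); // 126 24
```

Splitting per class keeps all three species equally represented in both sets.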


c. Defining the model: a multi-layer neural network with softmax

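The sigmoid activation function is used for binary classification; the softmax activation function is used for multi-class classification.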


https://developers.google.com/machine-learning/recommendation/dnn/softmax

javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
// getIrisData returns the training and validation sets; IRIS_CLASSES holds the class names for display
import { getIrisData, IRIS_CLASSES } from './data';

window.onload = async () => {
    // Use about 15% of the data as the validation set.
    // xTrain: training features, yTrain: training labels (one-hot),
    // xTest: validation features, yTest: validation labels (one-hot)
    const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15);

    const model = tf.sequential();
    // Hidden layer: 10 neurons; inputShape is the number of features (4)
    model.add(tf.layers.dense({
        units: 10,
        inputShape: [xTrain.shape[1]],
        activation: 'sigmoid'
    }));
    // Output layer
    model.add(tf.layers.dense({
        units: 3, // One neuron per class: there are 3 Iris species
        activation: 'softmax' // Outputs 3 probabilities that sum to 1
    }));
};
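A plain-JavaScript sketch of softmax (not the tf.js implementation), which turns the output layer's scores into probabilities that sum to 1. Subtracting the maximum score first is a standard trick to avoid overflow in Math.exp and does not change the result.

```javascript
// softmax(z)_k = exp(z_k) / sum_j exp(z_j)
function softmax(logits) {
  const maxLogit = Math.max(...logits);
  const exps = logits.map((z) => Math.exp(z - maxLogit));
  const total = exps.reduce((sum, e) => sum + e, 0);
  return exps.map((e) => e / total);
}

const probs = softmax([1, 2, 3]);
console.log(probs); // about [0.090, 0.245, 0.665]
console.log(probs.reduce((s, p) => s + p, 0)); // 1, up to rounding
```

With only two classes, softmax reduces to the sigmoid of the difference between the two scores, which is why sigmoid suffices for binary classification.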

d. Training the model: cross-entropy loss and the accuracy metric

Cross-entropy loss is the multi-class version of log loss.


https://ml-cheatsheet.readthedocs.io/en/latest/

Cross-entropy loss function:


javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
// getIrisData returns the training and validation sets; IRIS_CLASSES holds the class names for display
import { getIrisData, IRIS_CLASSES } from './data';

window.onload = async () => {
    // Use about 15% of the data as the validation set.
    // xTrain: training features, yTrain: training labels (one-hot),
    // xTest: validation features, yTest: validation labels (one-hot)
    const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15);

    const model = tf.sequential();
    // Hidden layer: 10 neurons; inputShape is the number of features (4)
    model.add(tf.layers.dense({
        units: 10,
        inputShape: [xTrain.shape[1]],
        activation: 'sigmoid'
    }));
    // Output layer
    model.add(tf.layers.dense({
        units: 3, // One neuron per class: there are 3 Iris species
        activation: 'softmax' // Outputs 3 probabilities that sum to 1
    }));

    model.compile({
        // Use the categorical cross-entropy loss function
        loss: 'categoricalCrossentropy',
        optimizer: tf.train.adam(0.1),
        // Track the accuracy metric
        metrics: ['accuracy']
    });

    await model.fit(xTrain, yTrain, {
        epochs: 100,
        // Validation set
        validationData: [xTest, yTest],
        // Visualize the training process
        callbacks: tfvis.show.fitCallbacks(
            { name: 'Training performance' }, // Chart name
            ['loss', 'val_loss', 'acc', 'val_acc'], // Metrics: loss, validation loss, accuracy, validation accuracy
            { callbacks: ['onEpochEnd'] }
        )
    });
};
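The loss and metric configured above can be sketched in plain JavaScript for one-hot labels (hypothetical helpers; the epsilon clamp is an assumption). Categorical cross-entropy is -Σ_k y_k·log(p_k) averaged over samples, and accuracy is the fraction of samples whose highest-probability class matches the label.

```javascript
// Categorical cross-entropy for one-hot labels, averaged over samples.
function categoricalCrossentropy(labels, predictions, epsilon = 1e-7) {
  let total = 0;
  labels.forEach((row, i) => {
    row.forEach((y, k) => {
      total -= y * Math.log(Math.max(predictions[i][k], epsilon)); // only the true class contributes
    });
  });
  return total / labels.length;
}

// Index of the largest value in an array.
function indexOfMax(values) {
  let best = 0;
  for (let k = 1; k < values.length; k++) {
    if (values[k] > values[best]) best = k;
  }
  return best;
}

// Fraction of samples whose predicted class (largest probability) equals the label.
function accuracy(labels, predictions) {
  let correct = 0;
  labels.forEach((row, i) => {
    if (indexOfMax(row) === indexOfMax(predictions[i])) correct++;
  });
  return correct / labels.length;
}

const oneHotLabels = [[0, 0, 1], [1, 0, 0]];
const predicted = [[0.1, 0.2, 0.7], [0.3, 0.6, 0.1]];
console.log(categoricalCrossentropy(oneHotLabels, predicted)); // about 0.780 = (-ln 0.7 - ln 0.3) / 2
console.log(accuracy(oneHotLabels, predicted)); // 0.5
```

The loss is smooth and can be minimized by gradient descent, while accuracy only changes when a prediction flips class, which is why the model is trained on the loss and accuracy is tracked as a metric.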


e. Making multi-class predictions


html
<form action="" onsubmit="predict(this); return false;">
    Sepal length: <input type="text" name="a"><br>
    Sepal width: <input type="text" name="b"><br>
    Petal length: <input type="text" name="c"><br>
    Petal width: <input type="text" name="d"><br>
    <button type="submit">Predict</button>
</form>
javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
// getIrisData用于获取训练集和验证集; IRIS_CLASSES 用于输出中文类别
import { getIrisData, IRIS_CLASSES } from './data';

window.onload = async () => {
// 百分之十五的数据用于验证集
// [xTrain:训练集输入特征, yTrain:训练集所有标签(输出结果), xTest:验证集输入特征, yTest:验证集所有标签(输出结果)
    const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15);

    const model = tf.sequential();
    model.add(tf.layers.dense({
        units: 10,
        inputShape: [xTrain.shape[1]],
        activation: 'sigmoid'
    }));
    // Output layer
    model.add(tf.layers.dense({
        units: 3, // The number of units must equal the number of output classes; there are 3 iris species, so 3
        activation: 'softmax' // Outputs 3 probabilities that sum to 1, hence the softmax activation
    }));

    model.compile({
        // Categorical cross-entropy loss
        loss: 'categoricalCrossentropy',
        optimizer: tf.train.adam(0.1),
        // Track accuracy as a metric
        metrics: ['accuracy']
    });

    await model.fit(xTrain, yTrain, {
        epochs: 100,
        // Validation set
        validationData: [xTest, yTest],
        // Visualize the training process
        callbacks: tfvis.show.fitCallbacks(
            { name: 'Training performance' }, // chart name
            ['loss', 'val_loss', 'acc', 'val_acc'], // metrics: loss, validation loss, accuracy, validation accuracy
            { callbacks: ['onEpochEnd'] }
        )
    });

    window.predict = (form) => {
        const input = tf.tensor([[
            form.a.value * 1,
            form.b.value * 1,
            form.c.value * 1,
            form.d.value * 1,
        ]]);
        const pred = model.predict(input);
        // argMax(1) returns the index of the largest probability along axis 1 (the class axis)
        alert(`Prediction: ${IRIS_CLASSES[pred.argMax(1).dataSync()[0]]}`);
    };
};
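For each row of the prediction, `argMax(1)` picks the index of the largest probability, and that index then selects a name from `IRIS_CLASSES`. Below is a plain-JavaScript sketch of the same step for a single prediction. The `argMax` helper, the English class names, and the sample probabilities are illustrative assumptions, not output from the model above.

```javascript
const IRIS_CLASSES = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'];

// Return the index of the largest value (the first one wins on ties).
function argMax(values) {
    let best = 0;
    for (let i = 1; i < values.length; i++) {
        if (values[i] > values[best]) best = i;
    }
    return best;
}

// Hypothetical softmax output for one flower: one probability per class.
const probabilities = [0.05, 0.15, 0.80];
console.log(`Prediction: ${IRIS_CLASSES[argMax(probabilities)]}`); // Prediction: Iris-virginica
```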

f. A look at the source of the IRIS dataset generator

javascript
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as tf from '@tensorflow/tfjs';

export const IRIS_CLASSES =
    ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'];
export const IRIS_NUM_CLASSES = IRIS_CLASSES.length;

// Iris flowers data. Source:
//   https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
const IRIS_DATA = [
  [5.1, 3.5, 1.4, 0.2, 0], [4.9, 3.0, 1.4, 0.2, 0], [4.7, 3.2, 1.3, 0.2, 0],
  [4.6, 3.1, 1.5, 0.2, 0], [5.0, 3.6, 1.4, 0.2, 0], [5.4, 3.9, 1.7, 0.4, 0],
  [4.6, 3.4, 1.4, 0.3, 0], [5.0, 3.4, 1.5, 0.2, 0], [4.4, 2.9, 1.4, 0.2, 0],
  [4.9, 3.1, 1.5, 0.1, 0], [5.4, 3.7, 1.5, 0.2, 0], [4.8, 3.4, 1.6, 0.2, 0],
  [4.8, 3.0, 1.4, 0.1, 0], [4.3, 3.0, 1.1, 0.1, 0], [5.8, 4.0, 1.2, 0.2, 0],
  [5.7, 4.4, 1.5, 0.4, 0], [5.4, 3.9, 1.3, 0.4, 0], [5.1, 3.5, 1.4, 0.3, 0],
  [5.7, 3.8, 1.7, 0.3, 0], [5.1, 3.8, 1.5, 0.3, 0], [5.4, 3.4, 1.7, 0.2, 0],
  [5.1, 3.7, 1.5, 0.4, 0], [4.6, 3.6, 1.0, 0.2, 0], [5.1, 3.3, 1.7, 0.5, 0],
  [4.8, 3.4, 1.9, 0.2, 0], [5.0, 3.0, 1.6, 0.2, 0], [5.0, 3.4, 1.6, 0.4, 0],
  [5.2, 3.5, 1.5, 0.2, 0], [5.2, 3.4, 1.4, 0.2, 0], [4.7, 3.2, 1.6, 0.2, 0],
  [4.8, 3.1, 1.6, 0.2, 0], [5.4, 3.4, 1.5, 0.4, 0], [5.2, 4.1, 1.5, 0.1, 0],
  [5.5, 4.2, 1.4, 0.2, 0], [4.9, 3.1, 1.5, 0.1, 0], [5.0, 3.2, 1.2, 0.2, 0],
  [5.5, 3.5, 1.3, 0.2, 0], [4.9, 3.1, 1.5, 0.1, 0], [4.4, 3.0, 1.3, 0.2, 0],
  [5.1, 3.4, 1.5, 0.2, 0], [5.0, 3.5, 1.3, 0.3, 0], [4.5, 2.3, 1.3, 0.3, 0],
  [4.4, 3.2, 1.3, 0.2, 0], [5.0, 3.5, 1.6, 0.6, 0], [5.1, 3.8, 1.9, 0.4, 0],
  [4.8, 3.0, 1.4, 0.3, 0], [5.1, 3.8, 1.6, 0.2, 0], [4.6, 3.2, 1.4, 0.2, 0],
  [5.3, 3.7, 1.5, 0.2, 0], [5.0, 3.3, 1.4, 0.2, 0], [7.0, 3.2, 4.7, 1.4, 1],
  [6.4, 3.2, 4.5, 1.5, 1], [6.9, 3.1, 4.9, 1.5, 1], [5.5, 2.3, 4.0, 1.3, 1],
  [6.5, 2.8, 4.6, 1.5, 1], [5.7, 2.8, 4.5, 1.3, 1], [6.3, 3.3, 4.7, 1.6, 1],
  [4.9, 2.4, 3.3, 1.0, 1], [6.6, 2.9, 4.6, 1.3, 1], [5.2, 2.7, 3.9, 1.4, 1],
  [5.0, 2.0, 3.5, 1.0, 1], [5.9, 3.0, 4.2, 1.5, 1], [6.0, 2.2, 4.0, 1.0, 1],
  [6.1, 2.9, 4.7, 1.4, 1], [5.6, 2.9, 3.6, 1.3, 1], [6.7, 3.1, 4.4, 1.4, 1],
  [5.6, 3.0, 4.5, 1.5, 1], [5.8, 2.7, 4.1, 1.0, 1], [6.2, 2.2, 4.5, 1.5, 1],
  [5.6, 2.5, 3.9, 1.1, 1], [5.9, 3.2, 4.8, 1.8, 1], [6.1, 2.8, 4.0, 1.3, 1],
  [6.3, 2.5, 4.9, 1.5, 1], [6.1, 2.8, 4.7, 1.2, 1], [6.4, 2.9, 4.3, 1.3, 1],
  [6.6, 3.0, 4.4, 1.4, 1], [6.8, 2.8, 4.8, 1.4, 1], [6.7, 3.0, 5.0, 1.7, 1],
  [6.0, 2.9, 4.5, 1.5, 1], [5.7, 2.6, 3.5, 1.0, 1], [5.5, 2.4, 3.8, 1.1, 1],
  [5.5, 2.4, 3.7, 1.0, 1], [5.8, 2.7, 3.9, 1.2, 1], [6.0, 2.7, 5.1, 1.6, 1],
  [5.4, 3.0, 4.5, 1.5, 1], [6.0, 3.4, 4.5, 1.6, 1], [6.7, 3.1, 4.7, 1.5, 1],
  [6.3, 2.3, 4.4, 1.3, 1], [5.6, 3.0, 4.1, 1.3, 1], [5.5, 2.5, 4.0, 1.3, 1],
  [5.5, 2.6, 4.4, 1.2, 1], [6.1, 3.0, 4.6, 1.4, 1], [5.8, 2.6, 4.0, 1.2, 1],
  [5.0, 2.3, 3.3, 1.0, 1], [5.6, 2.7, 4.2, 1.3, 1], [5.7, 3.0, 4.2, 1.2, 1],
  [5.7, 2.9, 4.2, 1.3, 1], [6.2, 2.9, 4.3, 1.3, 1], [5.1, 2.5, 3.0, 1.1, 1],
  [5.7, 2.8, 4.1, 1.3, 1], [6.3, 3.3, 6.0, 2.5, 2], [5.8, 2.7, 5.1, 1.9, 2],
  [7.1, 3.0, 5.9, 2.1, 2], [6.3, 2.9, 5.6, 1.8, 2], [6.5, 3.0, 5.8, 2.2, 2],
  [7.6, 3.0, 6.6, 2.1, 2], [4.9, 2.5, 4.5, 1.7, 2], [7.3, 2.9, 6.3, 1.8, 2],
  [6.7, 2.5, 5.8, 1.8, 2], [7.2, 3.6, 6.1, 2.5, 2], [6.5, 3.2, 5.1, 2.0, 2],
  [6.4, 2.7, 5.3, 1.9, 2], [6.8, 3.0, 5.5, 2.1, 2], [5.7, 2.5, 5.0, 2.0, 2],
  [5.8, 2.8, 5.1, 2.4, 2], [6.4, 3.2, 5.3, 2.3, 2], [6.5, 3.0, 5.5, 1.8, 2],
  [7.7, 3.8, 6.7, 2.2, 2], [7.7, 2.6, 6.9, 2.3, 2], [6.0, 2.2, 5.0, 1.5, 2],
  [6.9, 3.2, 5.7, 2.3, 2], [5.6, 2.8, 4.9, 2.0, 2], [7.7, 2.8, 6.7, 2.0, 2],
  [6.3, 2.7, 4.9, 1.8, 2], [6.7, 3.3, 5.7, 2.1, 2], [7.2, 3.2, 6.0, 1.8, 2],
  [6.2, 2.8, 4.8, 1.8, 2], [6.1, 3.0, 4.9, 1.8, 2], [6.4, 2.8, 5.6, 2.1, 2],
  [7.2, 3.0, 5.8, 1.6, 2], [7.4, 2.8, 6.1, 1.9, 2], [7.9, 3.8, 6.4, 2.0, 2],
  [6.4, 2.8, 5.6, 2.2, 2], [6.3, 2.8, 5.1, 1.5, 2], [6.1, 2.6, 5.6, 1.4, 2],
  [7.7, 3.0, 6.1, 2.3, 2], [6.3, 3.4, 5.6, 2.4, 2], [6.4, 3.1, 5.5, 1.8, 2],
  [6.0, 3.0, 4.8, 1.8, 2], [6.9, 3.1, 5.4, 2.1, 2], [6.7, 3.1, 5.6, 2.4, 2],
  [6.9, 3.1, 5.1, 2.3, 2], [5.8, 2.7, 5.1, 1.9, 2], [6.8, 3.2, 5.9, 2.3, 2],
  [6.7, 3.3, 5.7, 2.5, 2], [6.7, 3.0, 5.2, 2.3, 2], [6.3, 2.5, 5.0, 1.9, 2],
  [6.5, 3.0, 5.2, 2.0, 2], [6.2, 3.4, 5.4, 2.3, 2], [5.9, 3.0, 5.1, 1.8, 2],
];

/**
 * Convert Iris data arrays (plain JavaScript arrays) to `tf.Tensor`s.
 *
 * @param data The Iris input feature data, an `Array` of `Array`s, each element
 *   of which is assumed to be a length-4 `Array` (for sepal length, sepal
 *   width, petal length, petal width).
 * @param targets An `Array` of numbers, with values from the set {0, 1, 2}:
 *   representing the true category of the Iris flower. Assumed to have the same
 *   array length as `data`.
 * @param testSplit Fraction of the data at the end to split as test data: a
 *   number between 0 and 1.
 * @return A length-4 `Array`, with
 *   - training data as `tf.Tensor` of shape [numTrainExamples, 4].
 *   - training one-hot labels as a `tf.Tensor` of shape [numTrainExamples, 3]
 *   - test data as `tf.Tensor` of shape [numTestExamples, 4].
 *   - test one-hot labels as a `tf.Tensor` of shape [numTestExamples, 3]
 */
function convertToTensors(data, targets, testSplit) {
  const numExamples = data.length;
  if (numExamples !== targets.length) {
    throw new Error('data and targets have different numbers of examples');
  }

  // Randomly shuffle `data` and `targets` (with the same permutation for both).
  const indices = [];
  for (let i = 0; i < numExamples; ++i) {
    indices.push(i);
  }
  // tf.util.shuffle shuffles the array in place (Fisher-Yates)
  tf.util.shuffle(indices);

  const shuffledData = [];
  const shuffledTargets = [];
  for (let i = 0; i < numExamples; ++i) {
    shuffledData.push(data[indices[i]]);
    shuffledTargets.push(targets[indices[i]]);
  }

  // Split the data into a training set and a test set, based on `testSplit`.
  const numTestExamples = Math.round(numExamples * testSplit);
  const numTrainExamples = numExamples - numTestExamples;

  const xDims = shuffledData[0].length;

  // Create a 2D `tf.Tensor` of shape [numExamples, xDims] to hold the feature data.
  const xs = tf.tensor2d(shuffledData, [numExamples, xDims]);

  // Create a 1D `tf.Tensor` to hold the labels, and convert the number label
  // from the set {0, 1, 2} into one-hot encoding (e.g., 0 --> [1, 0, 0]).
  // tf.oneHot turns the integer class labels into the one-hot format used for multi-class classification.
  const ys = tf.oneHot(tf.tensor1d(shuffledTargets).toInt(), IRIS_NUM_CLASSES);

  // Split the data into training and test sets, using `slice`.
  const xTrain = xs.slice([0, 0], [numTrainExamples, xDims]);
  const xTest = xs.slice([numTrainExamples, 0], [numTestExamples, xDims]);
  const yTrain = ys.slice([0, 0], [numTrainExamples, IRIS_NUM_CLASSES]);
  // Take the test labels from the same rows as xTest (the rows after the training rows).
  const yTest = ys.slice([numTrainExamples, 0], [numTestExamples, IRIS_NUM_CLASSES]);
  return [xTrain, yTrain, xTest, yTest];
}

/**
 * Obtains Iris data, split into training and test sets.
 *
 * @param testSplit Fraction of the data at the end to split as test data: a
 *   number between 0 and 1.
 *
 * @return A length-4 `Array`, with
 *   - training data as an `Array` of length-4 `Array` of numbers.
 *   - training labels as an `Array` of numbers, with the same length as the
 *     return training data above. Each element of the `Array` is from the set
 *     {0, 1, 2}.
 *   - test data as an `Array` of length-4 `Array` of numbers.
 *   - test labels as an `Array` of numbers, with the same length as the
 *     return test data above. Each element of the `Array` is from the set
 *     {0, 1, 2}.
 */
export function getIrisData(testSplit) {
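  // tf.tidy disposes the intermediate tensors created inside the callback (only the returned tensors are kept), reducing memory use.
  return tf.tidy(() => {
    // Feature rows (sepal length, sepal width, petal length, petal width) grouped by class;
    // the class (0, 1 or 2) is the last element of each IRIS_DATA row.
    const dataByClass = [];
    // Labels grouped by class; 0, 1 and 2 stand for the three species in IRIS_CLASSES
    const targetsByClass = [];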
    for (let i = 0; i < IRIS_CLASSES.length; ++i) {
      dataByClass.push([]);
      targetsByClass.push([]);
    }
    for (const example of IRIS_DATA) {
      const target = example[example.length - 1];
      const data = example.slice(0, example.length - 1);
      dataByClass[target].push(data);
      targetsByClass[target].push(target);
    }
 
    const xTrains = [];
    const yTrains = [];
    const xTests = [];
    const yTests = [];
    for (let i = 0; i < IRIS_CLASSES.length; ++i) {
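      // Convert this class's plain JS arrays into tensors and split them; splitting class by class keeps the class proportions equal in both sets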
      const [xTrain, yTrain, xTest, yTest] =
          convertToTensors(dataByClass[i], targetsByClass[i], testSplit);
      xTrains.push(xTrain);
      yTrains.push(yTrain);
      xTests.push(xTest);
      yTests.push(yTest);
    }

    const concatAxis = 0;
    return [
      tf.concat(xTrains, concatAxis), tf.concat(yTrains, concatAxis),
      tf.concat(xTests, concatAxis), tf.concat(yTests, concatAxis)
    ];
  });
}
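`getIrisData` splits each class separately and then concatenates the pieces. This way the training and test sets keep the same class proportions (a stratified split). The plain-JavaScript sketch below mirrors that logic on arrays instead of tensors. `oneHot` and `stratifiedSplit` are hypothetical helpers. Unlike the original, the sketch skips shuffling so that its result is deterministic.

```javascript
// One-hot encode an integer label, e.g. 2 -> [0, 0, 1] for 3 classes.
function oneHot(label, numClasses) {
    const v = new Array(numClasses).fill(0);
    v[label] = 1;
    return v;
}

// rows: [feature1, ..., label]; the label is the last element, as in IRIS_DATA.
function stratifiedSplit(rows, numClasses, testSplit) {
    const byClass = Array.from({ length: numClasses }, () => []);
    for (const row of rows) {
        byClass[row[row.length - 1]].push(row);
    }
    const train = [];
    const test = [];
    for (const group of byClass) {
        // Same rounding rule as convertToTensors: Math.round(numExamples * testSplit)
        const numTest = Math.round(group.length * testSplit);
        const numTrain = group.length - numTest;
        group.forEach((row, i) => {
            const example = { x: row.slice(0, -1), y: oneHot(row[row.length - 1], numClasses) };
            (i < numTrain ? train : test).push(example);
        });
    }
    return { train, test };
}

// Tiny made-up dataset: 4 examples per class, 1 feature each.
const rows = [];
for (let label = 0; label < 3; label++) {
    for (let k = 0; k < 4; k++) rows.push([label * 10 + k, label]);
}
const { train, test } = stratifiedSplit(rows, 3, 0.25);
console.log(train.length, test.length); // 9 3
console.log(test.map((e) => e.y)); // one test example per class
```

For the real `IRIS_DATA` (50 flowers per class), `getIrisData(0.15)` takes `Math.round(50 * 0.15) = 8` test examples from each class. That gives 24 test and 126 training examples in total.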

Figures (omitted): the contents of `dataByClass` and `targetsByClass`.

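9. Underfitting and Overfitting

a. Introduction to underfitting and overfitting

Data: black curve. Model: blue straight line, too simple to follow the data (underfitting; figure omitted).

Overfitting: the model is so powerful that it fits the training data too closely, and as a result it performs poorly when it meets additional or new data. (In the figure, the green line is the overfitted model.)

Loss curves of an overfitted model (figure omitted): red is the validation loss, blue is the training loss.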

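b. Loading a noisy binary classification dataset

Slightly noisy data makes overfitting easier to reproduce. An overly complex model tries to fit every training point, so it fits the noise as well. Because the resulting model is so complex, it performs worse on data outside the training set.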

javascript
// Takes the number of samples and the variance (the larger the variance, the more noise)
export function getData(numSamples, variance) {
    const points = [];

    // Generate numSamples / 2 points scattered around (cx, cy), all with the given label
    function genGauss(cx, cy, label) {
        for (let i = 0; i < numSamples / 2; i++) {
            const x = normalRandom(cx, variance);
            const y = normalRandom(cy, variance);
            points.push({ x, y, label });
        }
    }

    genGauss(2, 2, 1);
    genGauss(-2, -2, 0);
    return points;
}

/**
 * Samples from a normal (Gaussian) distribution using the Marsaglia polar
 * method, with Math.random as the underlying uniform random generator.
 *
 * @param mean The mean. Default is 0.
 * @param variance The variance. Default is 1.
 */
function normalRandom(mean = 0, variance = 1) {
    let v1, v2, s;
    do {
        v1 = 2 * Math.random() - 1;
        v2 = 2 * Math.random() - 1;
        s = v1 * v1 + v2 * v2;
    } while (s === 0 || s >= 1); // s must lie strictly inside the unit circle; s = 0 would give log(0)

    const result = Math.sqrt(-2 * Math.log(s) / s) * v1;
    return mean + Math.sqrt(variance) * result;
}
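`normalRandom` scales a standard normal sample by `Math.sqrt(variance)`. The `variance` argument of `getData` therefore controls how far the points spread from the two centers (2, 2) and (-2, -2), and so how much the two classes overlap. Below is a quick empirical check in plain JavaScript. It copies `normalRandom` so that it runs on its own. `sampleStats` is a hypothetical helper.

```javascript
// Copy of normalRandom (Marsaglia polar method) so this snippet is self-contained.
function normalRandom(mean = 0, variance = 1) {
    let v1, v2, s;
    do {
        v1 = 2 * Math.random() - 1;
        v2 = 2 * Math.random() - 1;
        s = v1 * v1 + v2 * v2;
    } while (s === 0 || s >= 1);
    return mean + Math.sqrt(variance) * Math.sqrt(-2 * Math.log(s) / s) * v1;
}

// Sample mean and (population) variance of n draws.
function sampleStats(n, mean, variance) {
    const xs = Array.from({ length: n }, () => normalRandom(mean, variance));
    const m = xs.reduce((a, b) => a + b, 0) / n;
    const v = xs.reduce((a, x) => a + (x - m) ** 2, 0) / n;
    return { mean: m, variance: v };
}

for (const variance of [1, 2, 3]) {
    // The estimates should be close to mean 2 and the requested variance; exact values vary per run.
    console.log(variance, sampleStats(20000, 2, variance));
}
```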
javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data';

window.onload = async () => {
    const data = getData(200, 2);

    tfvis.render.scatterplot(
        { name: 'Training data' },
        {
            values: [
                data.filter(p => p.label === 1),
                data.filter(p => p.label === 0),
            ]
        }
    );
};

Scatter plots of the training data with the variance set to 2, 1 and 3 (figures omitted).

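c. Demonstrating underfitting with a simple neural network

Underfitting tends to happen when a simple model is applied to a complex dataset, or when training time is too short. It shows up as both the training loss and the validation loss staying high. The remedy is to increase the model's complexity by using more neurons and more layers.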

javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from '../xor/data.js';

window.onload = async () => {
    const data = getData(200);

    tfvis.render.scatterplot(
        { name: 'Training data' },
        {
            values: [
                data.filter(p => p.label === 1),
                data.filter(p => p.label === 0),
            ]
        }
    );

    const model = tf.sequential();
    model.add(tf.layers.dense({
        units: 1,
        inputShape: [2],
        activation: 'sigmoid'
    }));
    model.compile({
        loss: tf.losses.logLoss,
        optimizer: tf.train.adam(0.1)
    });

    const inputs = tf.tensor(data.map(p => [p.x, p.y]));
    const labels = tf.tensor(data.map(p => p.label));

    await model.fit(inputs, labels, {
        validationSplit: 0.2,
        epochs: 200,
        callbacks: tfvis.show.fitCallbacks(
            { name: 'Training performance' },
            ['loss', 'val_loss'],
            { callbacks: ['onEpochEnd'] }
        )
    });
};
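Why can a single sigmoid neuron not fit this data? The sketch below assumes, as the import path `../xor/data.js` suggests, that the points follow an XOR pattern: label 1 in two opposite quadrants and 0 in the other two. Under that assumption, no straight decision boundary can separate the classes. The plain-JavaScript code trains one sigmoid neuron (logistic regression) with gradient descent on the four corner points of XOR. Its mean log loss cannot fall below ln 2 (about 0.693), which is the loss of predicting 0.5 everywhere. The dataset, learning rate and step count are illustrative choices.

```javascript
// Four XOR-style points: label 1 where x and y have the same sign, 0 otherwise.
const points = [
    { x: 1, y: 1, label: 1 },
    { x: -1, y: -1, label: 1 },
    { x: 1, y: -1, label: 0 },
    { x: -1, y: 1, label: 0 },
];

const sigmoid = (z) => 1 / (1 + Math.exp(-z));

// Mean log loss (binary cross-entropy) of one neuron with weights w1, w2 and bias b.
function logLoss(w1, w2, b) {
    let total = 0;
    for (const p of points) {
        const q = sigmoid(w1 * p.x + w2 * p.y + b);
        total += -(p.label * Math.log(q) + (1 - p.label) * Math.log(1 - q));
    }
    return total / points.length;
}

// Plain gradient descent, starting from arbitrary non-zero parameters.
let w1 = 0.5, w2 = -0.3, b = 0.2;
const learningRate = 0.5;
for (let step = 0; step < 2000; step++) {
    let g1 = 0, g2 = 0, gb = 0;
    for (const p of points) {
        const err = sigmoid(w1 * p.x + w2 * p.y + b) - p.label; // d(loss)/d(logit)
        g1 += err * p.x;
        g2 += err * p.y;
        gb += err;
    }
    w1 -= learningRate * g1 / points.length;
    w2 -= learningRate * g2 / points.length;
    b -= learningRate * gb / points.length;
}

const finalLoss = logLoss(w1, w2, b);
console.log(finalLoss, Math.LN2); // the loss approaches ln 2 but never goes below it
```

The best this single neuron can do is to predict 0.5 for every point. Training more epochs cannot fix that, so a hidden layer is required.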

d. Demonstrating overfitting with a complex neural network

The training loss drops very low, but the validation loss stays high.

javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data';

window.onload = async () => {
    const data = getData(200, 2);

    tfvis.render.scatterplot(
        { name: 'Training data' },
        {
            values: [
                data.filter(p => p.label === 1),
                data.filter(p => p.label === 0),
            ]
        }
    );

    const model = tf.sequential();
    model.add(tf.layers.dense({
        units: 10,
        inputShape: [2],
        activation: "tanh",
    }));
    model.add(tf.layers.dense({
        units: 1,
        activation: 'sigmoid'
    }));
    model.compile({
        loss: tf.losses.logLoss,
        optimizer: tf.train.adam(0.1)
    });

    const inputs = tf.tensor(data.map(p => [p.x, p.y]));
    const labels = tf.tensor(data.map(p => p.label));

    await model.fit(inputs, labels, {
        validationSplit: 0.2,
        epochs: 200,
        callbacks: tfvis.show.fitCallbacks(
            { name: 'Training performance' },
            ['loss', 'val_loss'],
            { callbacks: ['onEpochEnd'] }
        )
    });
};
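One way to see how much more capacity this model has than the single-neuron model in the underfitting demo is to count trainable parameters. A dense layer with `n` inputs and `u` units has `n * u` weights plus `u` biases. Below is a small plain-JavaScript check; the helper name `denseParams` is made up. In TensorFlow.js, `model.summary()` prints these counts for a built model.

```javascript
// Trainable parameters of a dense layer: one weight per (input, unit) pair plus one bias per unit.
function denseParams(inputDim, units) {
    return inputDim * units + units;
}

// Underfitting demo: 2 inputs -> 1 sigmoid unit.
const simpleModel = denseParams(2, 1);
// Overfitting demo: 2 inputs -> 10 tanh units -> 1 sigmoid unit.
const complexModel = denseParams(2, 10) + denseParams(10, 1);

console.log(simpleModel, complexModel); // 3 41
```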

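e. Remedies for overfitting: early stopping, weight decay, dropout

Early stopping: stop training as soon as the validation loss starts to rise noticeably.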
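The early-stopping rule fits in a small function. It tracks the best validation loss so far and stops once that loss has not improved for `patience` consecutive epochs. The plain-JavaScript sketch below uses made-up loss values. TensorFlow.js also provides a ready-made callback, `tf.callbacks.earlyStopping({ monitor: 'val_loss', patience })`, which can be passed in the `callbacks` of `model.fit`. Check the API reference for the options your version supports.

```javascript
// Returns the index of the epoch after which training would stop,
// or -1 if the patience is never exhausted.
function earlyStoppingEpoch(valLosses, patience) {
    let best = Infinity;
    let epochsWithoutImprovement = 0;
    for (let epoch = 0; epoch < valLosses.length; epoch++) {
        if (valLosses[epoch] < best) {
            best = valLosses[epoch];
            epochsWithoutImprovement = 0;
        } else {
            epochsWithoutImprovement++;
            if (epochsWithoutImprovement >= patience) {
                return epoch;
            }
        }
    }
    return -1;
}

// Made-up validation losses: they fall, bottom out at epoch 3, then rise (overfitting).
const valLosses = [0.9, 0.7, 0.55, 0.5, 0.52, 0.56, 0.61, 0.67];
console.log(earlyStoppingEpoch(valLosses, 3)); // 6
```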

javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getData } from './data';

window.onload = async () => {
    const data = getData(200, 2);

    tfvis.render.scatterplot(
        { name: 'Training data' },
        {
            values: [
                data.filter(p => p.label === 1),
                data.filter(p => p.label === 0),
            ]
        }
    );

    const model = tf.sequential();
    model.add(tf.layers.dense({
        units: 10,
        inputShape: [2],
        activation: "tanh",
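        // 1. Weight decay (L2 regularization): the L2 penalty shrinks the weights of the complex model,
        //    making it effectively simpler and therefore less prone to overfitting. Uncomment to enable.
        // kernelRegularizer: tf.regularizers.l2({ l2: 1 })
    }));

    // 2. Dropout layer: during training, a random subset of the previous layer's outputs is set to zero on each step.
    //    With rate 0.9, each of the 10 outputs is dropped with probability 0.9 (about 9 of 10 on average). Uncomment to enable.
    // model.add(tf.layers.dropout({ rate: 0.9 }));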
    
    model.add(tf.layers.dense({
        units: 1,
        activation: 'sigmoid'
    }));
    model.compile({
        loss: tf.losses.logLoss,
        optimizer: tf.train.adam(0.1)
    });

    const inputs = tf.tensor(data.map(p => [p.x, p.y]));
    const labels = tf.tensor(data.map(p => p.label));

    await model.fit(inputs, labels, {
        validationSplit: 0.2,
        epochs: 200,
        callbacks: tfvis.show.fitCallbacks(
            { name: 'Training performance' },
            ['loss', 'val_loss'],
            { callbacks: ['onEpochEnd'] }
        )
    });
};
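Both commented-out remedies come down to simple arithmetic. The L2 regularizer adds `l2 * sum(w^2)` over the layer's kernel weights to the loss, which is what `tf.regularizers.l2({ l2 })` contributes. During training, dropout sets each value to zero with probability `rate` and scales the survivors by `1 / (1 - rate)`, so the expected output stays the same. This scaling scheme is known as "inverted dropout", and Keras-style dropout layers use it. At inference time nothing is dropped. In the plain-JavaScript sketch below, `l2Penalty`, `dropout` and the injectable `random` argument are illustrative.

```javascript
// L2 penalty added to the loss for one weight matrix (flattened here).
function l2Penalty(weights, l2) {
    return l2 * weights.reduce((sum, w) => sum + w * w, 0);
}

// Inverted dropout at training time: drop each value with probability `rate`,
// and scale the kept values by 1 / (1 - rate). `random` can be replaced for testing.
function dropout(values, rate, random = Math.random) {
    const scale = 1 / (1 - rate);
    return values.map((v) => (random() < rate ? 0 : v * scale));
}

console.log(l2Penalty([1, 2, 3], 0.5)); // 7, i.e. 0.5 * (1 + 4 + 9)

const activations = new Array(10).fill(1);
console.log(dropout(activations, 0.9)); // on average about 1 of 10 values survives, each scaled to about 10
```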

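Loss curves with only L2 regularization (figure omitted).

Loss curves with only dropout (figure omitted).

Loss curves with both weight decay and dropout (figure omitted).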

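10. Recognizing handwritten digits with a convolutional neural network (CNN)

a. Introduction to the handwritten digit recognition task

b. Loading the MNIST dataset

MNIST is a dataset of images of handwritten digits. It was assembled from material collected by the US National Institute of Standards and Technology (NIST). It contains digits written by 250 different people: half were high-school students and half were employees of the Census Bureau. The dataset was collected to support algorithms that recognize handwritten digits.

In 1998, Yann LeCun et al. published the paper "Gradient-Based Learning Applied to Document Recognition". It first proposed the LeNet-5 network and used this dataset to recognize handwritten digits.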

data.js:

javascript
/**
 * @license
 * Copyright 2018 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import * as tf from '@tensorflow/tfjs';

const IMAGE_SIZE = 784;
const NUM_CLASSES = 10;
const NUM_DATASET_ELEMENTS = 65000;

const TRAIN_TEST_RATIO = 5 / 6;

const NUM_TRAIN_ELEMENTS = Math.floor(TRAIN_TEST_RATIO * NUM_DATASET_ELEMENTS);
const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;

const MNIST_IMAGES_SPRITE_PATH =
    'http://127.0.0.1:8080/mnist/mnist_images.png';
const MNIST_LABELS_PATH =
    'http://127.0.0.1:8080/mnist/mnist_labels_uint8';

/**
 * A class that fetches the sprited MNIST dataset and returns shuffled batches.
 *
 * NOTE: This will get much easier. For now, we do data fetching and
 * manipulation manually.
 */
export class MnistData {
  constructor() {
    this.shuffledTrainIndex = 0;
    this.shuffledTestIndex = 0;
  }

  async load() {
    // Make a request for the MNIST sprited image.
    const img = new Image();
    const canvas = document.createElement('canvas');
    const ctx = canvas.getContext('2d');
    const imgRequest = new Promise((resolve, reject) => {
      img.crossOrigin = '';
      img.onload = () => {
        img.width = img.naturalWidth;
        img.height = img.naturalHeight;

        const datasetBytesBuffer =
            new ArrayBuffer(NUM_DATASET_ELEMENTS * IMAGE_SIZE * 4);

        const chunkSize = 5000;
        canvas.width = img.width;
        canvas.height = chunkSize;

        for (let i = 0; i < NUM_DATASET_ELEMENTS / chunkSize; i++) {
          const datasetBytesView = new Float32Array(
              datasetBytesBuffer, i * IMAGE_SIZE * chunkSize * 4,
              IMAGE_SIZE * chunkSize);
          ctx.drawImage(
              img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
              chunkSize);

          const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);

          for (let j = 0; j < imageData.data.length / 4; j++) {
            // All channels hold an equal value since the image is grayscale, so
            // just read the red channel.
            datasetBytesView[j] = imageData.data[j * 4] / 255;
          }
        }
        this.datasetImages = new Float32Array(datasetBytesBuffer);

        resolve();
      };
      img.src = MNIST_IMAGES_SPRITE_PATH;
    });

    const labelsRequest = fetch(MNIST_LABELS_PATH);
    const [imgResponse, labelsResponse] =
        await Promise.all([imgRequest, labelsRequest]);

    this.datasetLabels = new Uint8Array(await labelsResponse.arrayBuffer());

    // Create shuffled indices into the train/test set for when we select a
    // random dataset element for training / validation.
    this.trainIndices = tf.util.createShuffledIndices(NUM_TRAIN_ELEMENTS);
    this.testIndices = tf.util.createShuffledIndices(NUM_TEST_ELEMENTS);

    // Slice the images and labels into train and test sets.
    this.trainImages =
        this.datasetImages.slice(0, IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.testImages = this.datasetImages.slice(IMAGE_SIZE * NUM_TRAIN_ELEMENTS);
    this.trainLabels =
        this.datasetLabels.slice(0, NUM_CLASSES * NUM_TRAIN_ELEMENTS);
    this.testLabels =
        this.datasetLabels.slice(NUM_CLASSES * NUM_TRAIN_ELEMENTS);
  }

  nextTrainBatch(batchSize) {
    return this.nextBatch(
        batchSize, [this.trainImages, this.trainLabels], () => {
          this.shuffledTrainIndex =
              (this.shuffledTrainIndex + 1) % this.trainIndices.length;
          return this.trainIndices[this.shuffledTrainIndex];
        });
  }

  nextTestBatch(batchSize) {
    return this.nextBatch(batchSize, [this.testImages, this.testLabels], () => {
      this.shuffledTestIndex =
          (this.shuffledTestIndex + 1) % this.testIndices.length;
      return this.testIndices[this.shuffledTestIndex];
    });
  }

  nextBatch(batchSize, data, index) {
    const batchImagesArray = new Float32Array(batchSize * IMAGE_SIZE);
    const batchLabelsArray = new Uint8Array(batchSize * NUM_CLASSES);

    for (let i = 0; i < batchSize; i++) {
      const idx = index();

      const image =
          data[0].slice(idx * IMAGE_SIZE, idx * IMAGE_SIZE + IMAGE_SIZE);
      batchImagesArray.set(image, i * IMAGE_SIZE);

      const label =
          data[1].slice(idx * NUM_CLASSES, idx * NUM_CLASSES + NUM_CLASSES);
      batchLabelsArray.set(label, i * NUM_CLASSES);
    }

    const xs = tf.tensor2d(batchImagesArray, [batchSize, IMAGE_SIZE]);
    const labels = tf.tensor2d(batchLabelsArray, [batchSize, NUM_CLASSES]);

    return {xs, labels};
  }
}
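`load()` reads the sprite through a canvas, which yields RGBA bytes. Because the image is grayscale, it keeps only the red channel and divides by 255. The labels file stores one one-hot row of `NUM_CLASSES` bytes per image, which is why the labels are sliced in multiples of `NUM_CLASSES`. Below is a plain-JavaScript sketch of both steps on a tiny made-up buffer. `rgbaToGray` and `labelAt` are hypothetical helpers, not part of `MnistData`.

```javascript
// RGBA bytes (as returned by ctx.getImageData(...).data) -> grayscale floats in [0, 1].
function rgbaToGray(rgba) {
    const gray = new Float32Array(rgba.length / 4);
    for (let j = 0; j < gray.length; j++) {
        gray[j] = rgba[j * 4] / 255; // R == G == B in a grayscale image, so read R only
    }
    return gray;
}

// Labels are stored one-hot, numClasses bytes per example. Return the digit of example idx.
function labelAt(labels, idx, numClasses = 10) {
    const row = labels.slice(idx * numClasses, (idx + 1) * numClasses);
    return row.indexOf(1);
}

// Two made-up pixels: one black and one white, both fully opaque.
const rgba = new Uint8ClampedArray([0, 0, 0, 255, 255, 255, 255, 255]);
console.log(rgbaToGray(rgba)); // values 0 and 1

// Two made-up labels: digit 3 and digit 7.
const labels = new Uint8Array(20);
labels[3] = 1;
labels[10 + 7] = 1;
console.log(labelAt(labels, 0), labelAt(labels, 1)); // 3 7
```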


script.js:

javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { MnistData } from './data';

window.onload = async () => {
    // Create the data instance
    const data = new MnistData();
    // Load the image sprite and the binary label file
    await data.load();
    // Take 20 examples from the test set
    const examples = data.nextTestBatch(20);
    const surface = tfvis.visor().surface({ name: 'Input examples' });
    for (let i = 0; i < 20; i += 1) {
        // Extract the i-th example as a 28*28*1 image tensor
        const imageTensor = tf.tidy(() => {
            return examples.xs
                .slice([i, 0], [1, 784])
                .reshape([28, 28, 1]);
        });

        // Create a canvas element
        const canvas = document.createElement('canvas');
        canvas.width = 28;
        canvas.height = 28;
        canvas.style = 'margin: 4px';
        // Draw the image tensor onto the canvas with toPixels
        await tf.browser.toPixels(imageTensor, canvas);
        // Show it on the page using the tfvis API
        surface.drawArea.appendChild(canvas);
    }
};

c. Defining the model structure: a convolutional neural network

https://setosa.io/ev/image-kernels/


https://cs231n.github.io/convolutional-networks/
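The two pages linked above explain convolution with small kernels. Below is a rough plain-JavaScript sketch with a single channel, one kernel, stride 1 and no padding; no padding matches the default `'valid'` padding of `tf.layers.conv2d`. A convolution slides the kernel over the image and takes a weighted sum at each position. Max pooling with a 2*2 window and stride 2 keeps the largest value in each block. Like most deep-learning libraries, this computes a cross-correlation: the kernel is not flipped. `conv2dValid` and `maxPool2x2` are illustrative helpers.

```javascript
// image and kernel are 2D arrays (arrays of rows). Output size: (H - k + 1) x (W - k + 1).
function conv2dValid(image, kernel) {
    const k = kernel.length;
    const outH = image.length - k + 1;
    const outW = image[0].length - k + 1;
    const out = [];
    for (let r = 0; r < outH; r++) {
        const row = [];
        for (let c = 0; c < outW; c++) {
            let sum = 0;
            for (let i = 0; i < k; i++) {
                for (let j = 0; j < k; j++) {
                    sum += image[r + i][c + j] * kernel[i][j];
                }
            }
            row.push(sum);
        }
        out.push(row);
    }
    return out;
}

// 2x2 max pooling with stride 2 (a trailing odd row or column is dropped).
function maxPool2x2(fmap) {
    const out = [];
    for (let r = 0; r + 1 < fmap.length; r += 2) {
        const row = [];
        for (let c = 0; c + 1 < fmap[0].length; c += 2) {
            row.push(Math.max(fmap[r][c], fmap[r][c + 1], fmap[r + 1][c], fmap[r + 1][c + 1]));
        }
        out.push(row);
    }
    return out;
}

// A 5x5 image with a vertical edge (dark left, bright right) and a vertical-edge kernel.
const image = [
    [0, 0, 1, 1, 1],
    [0, 0, 1, 1, 1],
    [0, 0, 1, 1, 1],
    [0, 0, 1, 1, 1],
    [0, 0, 1, 1, 1],
];
const edgeKernel = [
    [-1, 0, 1],
    [-1, 0, 1],
    [-1, 0, 1],
];
const features = conv2dValid(image, edgeKernel); // 3x3 feature map, large where the edge is
console.log(features); // [[3, 3, 0], [3, 3, 0], [3, 3, 0]]
console.log(maxPool2x2(features)); // [[3]]: the strongest response survives pooling
```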


script.js:

javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { MnistData } from './data';

window.onload = async () => {
    // Create the data instance
    const data = new MnistData();
    // Load the image sprite and the binary label file
    await data.load();
    // Take 20 examples from the test set
    const examples = data.nextTestBatch(20);
    const surface = tfvis.visor().surface({ name: 'Input examples' });
    for (let i = 0; i < 20; i += 1) {
        // Extract the i-th example as a 28*28*1 image tensor
        const imageTensor = tf.tidy(() => {
            return examples.xs
                .slice([i, 0], [1, 784])
                .reshape([28, 28, 1]);
        });

        // Create a canvas element
        const canvas = document.createElement('canvas');
        canvas.width = 28;
        canvas.height = 28;
        canvas.style = 'margin: 4px';
        // Draw the image tensor onto the canvas with toPixels
        await tf.browser.toPixels(imageTensor, canvas);
        // Show it on the page using the tfvis API
        surface.drawArea.appendChild(canvas);
    }

    // Initialize the model (a sequential model)
    const model = tf.sequential();
    // Add a 2D convolution layer
    model.add(tf.layers.conv2d({
        // Image width, height and number of channels (1 for grayscale; a color image would use 3, for RGB)
        inputShape: [28, 28, 1],
        // Kernel size 5*5; odd sizes are usually preferred because the kernel then has a center pixel
        kernelSize: 5,
        // Number of filters (output feature maps)
        filters: 8,
        // Stride of the sliding window
        strides: 1,
        // ReLU sets negative values to 0, suppressing weak features: https://en.wikipedia.org/wiki/Rectifier_(neural_networks)
        activation: 'relu',
        // Kernel initializer; variance scaling tends to help the model converge faster
        kernelInitializer: 'varianceScaling'
    }));
    // Add a max-pooling layer
    model.add(tf.layers.maxPool2d({
        poolSize: [2, 2],
        strides: [2, 2]
    }));

    // Repeat convolution + pooling so that later layers combine lower-level features
    model.add(tf.layers.conv2d({
        kernelSize: 5,
        filters: 16,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling'
    }));
    model.add(tf.layers.maxPool2d({
        poolSize: [2, 2],
        strides: [2, 2]
    }));

    // Flatten the high-dimensional feature maps into a 1D vector
    model.add(tf.layers.flatten());

    // Output layer: a dense (fully connected) layer with 10 units, one per digit 0-9
    model.add(tf.layers.dense({
        units: 10,
        activation: 'softmax',
        kernelInitializer: 'varianceScaling'
    }));
    
};
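The shapes that flow through this model can be worked out by hand. With `'valid'` padding (the `tf.layers.conv2d` default), a convolution with kernel size `k` and stride `s` maps a side of length `n` to `floor((n - k) / s) + 1`. The same formula with `k = s = 2` gives the max-pooling output. Below is a plain-JavaScript sketch of that bookkeeping; the helper names are made up. `model.summary()` in TensorFlow.js should print the same shapes.

```javascript
// Output side length for 'valid' padding.
const outSize = (n, kernel, stride) => Math.floor((n - kernel) / stride) + 1;

let side = 28;
side = outSize(side, 5, 1); // conv2d, 8 filters   -> 24x24x8
side = outSize(side, 2, 2); // maxPool2d           -> 12x12x8
side = outSize(side, 5, 1); // conv2d, 16 filters  -> 8x8x16
side = outSize(side, 2, 2); // maxPool2d           -> 4x4x16
const flattened = side * side * 16; // flatten     -> 256

// Dense output layer: 256 inputs * 10 units + 10 biases
const denseParams = flattened * 10 + 10;
console.log(side, flattened, denseParams); // 4 256 2570
```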

d. Training the model

script.js:

javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { MnistData } from './data';

window.onload = async () => {
    // Create the data instance
    const data = new MnistData();
    // Load the image sprite and the binary label file
    await data.load();
    // Take 20 examples from the test set
    const examples = data.nextTestBatch(20);
    const surface = tfvis.visor().surface({ name: 'Input examples' });
    for (let i = 0; i < 20; i += 1) {
        // Extract the i-th example as a 28*28*1 image tensor
        const imageTensor = tf.tidy(() => {
            return examples.xs
                .slice([i, 0], [1, 784])
                .reshape([28, 28, 1]);
        });

        // Create a canvas element
        const canvas = document.createElement('canvas');
        canvas.width = 28;
        canvas.height = 28;
        canvas.style = 'margin: 4px';
        // Draw the image tensor onto the canvas with toPixels
        await tf.browser.toPixels(imageTensor, canvas);
        // Show it on the page using the tfvis API
        surface.drawArea.appendChild(canvas);
    }

    // Initialize the model (a sequential model)
    const model = tf.sequential();
    // Add a 2D convolution layer
    model.add(tf.layers.conv2d({
        // Image width, height and number of channels (1 for grayscale; a color image would use 3, for RGB)
        inputShape: [28, 28, 1],
        // Kernel size 5*5; odd sizes are usually preferred because the kernel then has a center pixel
        kernelSize: 5,
        // Number of filters (output feature maps)
        filters: 8,
        // Stride of the sliding window
        strides: 1,
        // ReLU sets negative values to 0, suppressing weak features: https://en.wikipedia.org/wiki/Rectifier_(neural_networks)
        activation: 'relu',
        // Kernel initializer; variance scaling tends to help the model converge faster
        kernelInitializer: 'varianceScaling'
    }));
    // Add a max-pooling layer
    model.add(tf.layers.maxPool2d({
        poolSize: [2, 2],
        strides: [2, 2]
    }));

    // Repeat convolution + pooling so that later layers combine lower-level features
    model.add(tf.layers.conv2d({
        kernelSize: 5,
        filters: 16,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling'
    }));
    model.add(tf.layers.maxPool2d({
        poolSize: [2, 2],
        strides: [2, 2]
    }));

    // Flatten the high-dimensional feature maps into a 1D vector
    model.add(tf.layers.flatten());

    // Output layer: a dense (fully connected) layer with 10 units, one per digit 0-9
    model.add(tf.layers.dense({
        units: 10,
        activation: 'softmax',
        kernelInitializer: 'varianceScaling'
    }));

    // Compile the model
    model.compile({
        // Loss function: categorical cross-entropy
        loss: 'categoricalCrossentropy',
        // Optimizer
        optimizer: tf.train.adam(),
        // Metric: accuracy
        metrics: ['accuracy']
    });

    // Prepare the training set
    // Wrapping this in tf.tidy disposes the intermediate tensors created along the way, keeping memory use down
    const [trainXs, trainYs] = tf.tidy(() => {
        const d = data.nextTrainBatch(1000);
        return [
            d.xs.reshape([1000, 28, 28, 1]),
            d.labels
        ];
    });

    // Prepare the validation set
    const [testXs, testYs] = tf.tidy(() => {
        const d = data.nextTestBatch(200);
        return [
            d.xs.reshape([200, 28, 28, 1]),
            d.labels
        ];
    });

    // Train
    await model.fit(trainXs, trainYs, {
        // Validation data
        validationData: [testXs, testYs],
        batchSize: 500,
        epochs: 20,
        callbacks: tfvis.show.fitCallbacks(
            { name: 'Training performance' },
            ['loss', 'val_loss', 'acc', 'val_acc'],
            { callbacks: ['onEpochEnd'] }
        )
    });
};
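The training run uses 1000 training examples, out of the `NUM_TRAIN_ELEMENTS` = 54,166 available in `data.js`. With `batchSize: 500`, these 1000 examples are processed in two batches per epoch. The weights are updated once per batch, so 20 epochs mean 40 updates. The 200 validation examples are only evaluated, never trained on. Here is a tiny plain-JavaScript sketch of that arithmetic; the helper name is made up.

```javascript
// Number of gradient updates performed by model.fit: one per batch, in every epoch.
function numUpdates(numExamples, batchSize, epochs) {
    const batchesPerEpoch = Math.ceil(numExamples / batchSize); // a smaller final batch still counts
    return batchesPerEpoch * epochs;
}

console.log(numUpdates(1000, 500, 20)); // 40
console.log(numUpdates(1000, 300, 20)); // 80: batches of 300, 300, 300 and 100
```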

e. Making predictions

index.html

html
<script src="script.js"></script>
<canvas width="300" height="300" style="border: 2px solid #666;"></canvas>
<br>
<button onclick="window.clear();" style="margin: 4px;">Clear</button>
<button onclick="window.predict();" style="margin: 4px;">Predict</button>

script.js:

javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { MnistData } from './data';

window.onload = async () => {
    // Create the data instance
    const data = new MnistData();
    // Load the image sprite and the binary label file
    await data.load();
    // Take 20 examples from the test set
    const examples = data.nextTestBatch(20);
    const surface = tfvis.visor().surface({ name: 'Input examples' });
    for (let i = 0; i < 20; i += 1) {
        // Extract the i-th example as a 28*28*1 image tensor
        const imageTensor = tf.tidy(() => {
            return examples.xs
                .slice([i, 0], [1, 784])
                .reshape([28, 28, 1]);
        });

        // Create a canvas element
        const canvas = document.createElement('canvas');
        canvas.width = 28;
        canvas.height = 28;
        canvas.style = 'margin: 4px';
        // Draw the image tensor onto the canvas with toPixels
        await tf.browser.toPixels(imageTensor, canvas);
        // Show it on the page using the tfvis API
        surface.drawArea.appendChild(canvas);
    }

    // Initialize the model (a sequential model)
    const model = tf.sequential();
    // Add a 2D convolution layer
    model.add(tf.layers.conv2d({
        // Image width, height and number of channels (1 for grayscale; a color image would use 3, for RGB)
        inputShape: [28, 28, 1],
        // Kernel size 5*5; odd sizes are usually preferred because the kernel then has a center pixel
        kernelSize: 5,
        // Number of filters (output feature maps)
        filters: 8,
        // Stride of the sliding window
        strides: 1,
        // ReLU sets negative values to 0, suppressing weak features: https://en.wikipedia.org/wiki/Rectifier_(neural_networks)
        activation: 'relu',
        // Kernel initializer; variance scaling tends to help the model converge faster
        kernelInitializer: 'varianceScaling'
    }));
    // Add a max-pooling layer
    model.add(tf.layers.maxPool2d({
        poolSize: [2, 2],
        strides: [2, 2]
    }));

    // Repeat convolution + pooling so that later layers combine lower-level features
    model.add(tf.layers.conv2d({
        kernelSize: 5,
        filters: 16,
        strides: 1,
        activation: 'relu',
        kernelInitializer: 'varianceScaling'
    }));
    model.add(tf.layers.maxPool2d({
        poolSize: [2, 2],
        strides: [2, 2]
    }));

    // Flatten the high-dimensional feature maps into a 1D vector
    model.add(tf.layers.flatten());

    // Output layer: a dense (fully connected) layer with 10 units, one per digit 0-9
    model.add(tf.layers.dense({
        units: 10,
        activation: 'softmax',
        kernelInitializer: 'varianceScaling'
    }));

    // Compile the model
    model.compile({
        // Loss function: categorical cross-entropy
        loss: 'categoricalCrossentropy',
        // Optimizer
        optimizer: tf.train.adam(),
        // Metric: accuracy
        metrics: ['accuracy']
    });

    // Prepare the training set
    // Wrapping this in tf.tidy disposes the intermediate tensors created along the way, keeping memory use down
    const [trainXs, trainYs] = tf.tidy(() => {
        const d = data.nextTrainBatch(1000);
        return [
            d.xs.reshape([1000, 28, 28, 1]),
            d.labels
        ];
    });

    // Prepare the validation set
    const [testXs, testYs] = tf.tidy(() => {
        const d = data.nextTestBatch(200);
        return [
            d.xs.reshape([200, 28, 28, 1]),
            d.labels
        ];
    });

    // Train
    await model.fit(trainXs, trainYs, {
        // Validation data
        validationData: [testXs, testYs],
        batchSize: 500,
        epochs: 20,
        callbacks: tfvis.show.fitCallbacks(
            { name: 'Training performance' },
            ['loss', 'val_loss', 'acc', 'val_acc'],
            { callbacks: ['onEpochEnd'] }
        )
    });

    // Get the canvas element
    const canvas = document.querySelector('canvas');

    // Draw while the mouse moves over the canvas
    canvas.addEventListener('mousemove', (e) => {
        // Only while the left mouse button is held down
        if (e.buttons === 1) {
            const ctx = canvas.getContext('2d');
            // Brush color: white
            ctx.fillStyle = 'rgb(255,255,255)';
            // Draw a 25*25 square at the mouse position
            ctx.fillRect(e.offsetX, e.offsetY, 25, 25);
        }
    });

    window.clear = () => {
        const ctx = canvas.getContext('2d');
        // Background color: black
        ctx.fillStyle = 'rgb(0,0,0)';
        // Fill the whole 300*300 canvas
        ctx.fillRect(0, 0, 300, 300);
    };

    window.clear();

    window.predict = () => {
        const input = tf.tidy(() => {
            // tf.browser.fromPixels turns the canvas into a [300, 300, 3] tensor
            // resizeBilinear shrinks the 300*300 image to the 28*28 size the model expects
            // The drawing is white on black, so the R, G and B channels are equal: keep one channel with
            // slice([0, 0, 0], [28, 28, 1]) to get a grayscale image (keeping all 3 would be slice([0, 0, 0], [28, 28, 3]))
            // Then convert to float and normalize: pixel values range over 0-255, so divide by 255
            // Finally, the shape must match the training input, so reshape to [1, 28, 28, 1]: one 28*28 grayscale image
            return tf.image.resizeBilinear(
                tf.browser.fromPixels(canvas),
                [28, 28],
                true
            ).slice([0, 0, 0], [28, 28, 1])
            .toFloat()
            .div(255)
            .reshape([1, 28, 28, 1]);
        });
        // Run the model and take the index of the most probable digit
        const pred = model.predict(input).argMax(1);
        alert(`Prediction: ${pred.dataSync()[0]}`);
    };
};
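The arithmetic behind the `predict` preprocessing and `argMax(1)` can be reproduced without TensorFlow.js. The sketch below is a plain-JavaScript illustration, not a TensorFlow.js API: it assumes the drawing has already been scaled down and is available as RGBA bytes (4 per pixel, row-major, the layout `getImageData().data` returns), and the helpers `toModelInput` and `argMax` are hypothetical names. It keeps only the red channel, like `slice([0, 0, 0], [28, 28, 1])`, divides by 255, and picks the most probable class.

```javascript
// Hypothetical helper: turn RGBA pixel bytes (4 per pixel, row-major) into the
// flat, normalized single-channel values behind a [1, height, width, 1] tensor.
function toModelInput(rgba, width, height) {
  if (rgba.length !== width * height * 4) {
    throw new Error('expected width * height * 4 RGBA bytes');
  }
  const out = new Float32Array(width * height);
  for (let i = 0; i < width * height; i++) {
    // Keep only the red channel (first byte of each RGBA group), like slice(..., [h, w, 1]),
    // and scale 0-255 down to 0-1, like .div(255)
    out[i] = rgba[i * 4] / 255;
  }
  return out;
}

// Hypothetical helper: index of the largest value, like argMax over one row of probabilities.
function argMax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return best;
}

// Usage: a 2x2 "drawing" with one white pixel on a black background.
const pixels = new Uint8ClampedArray([
  0, 0, 0, 255,   255, 255, 255, 255,
  0, 0, 0, 255,   0, 0, 0, 255,
]);
const input = toModelInput(pixels, 2, 2);
console.log(Array.from(input)); // [ 0, 1, 0, 0 ]
console.log(argMax([0.01, 0.02, 0.9, 0.07])); // 2
```

Because the drawing is white on black, the red, green and blue channels are identical, so keeping any single one loses no information.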

11. Image Classification with a Pre-trained Model

a. Introduction to image classification with a pre-trained model


b. Loading the MobileNet model and making predictions


imagenet_classes.js:

javascript
/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

export const IMAGENET_CLASSES = {
  0: 'tench, Tinca tinca',
  1: 'goldfish, Carassius auratus',
  2: 'great white shark, white shark, man-eater, man-eating shark, ' +
      'Carcharodon carcharias',
  3: 'tiger shark, Galeocerdo cuvieri',
  4: 'hammerhead, hammerhead shark',
  5: 'electric ray, crampfish, numbfish, torpedo',
  6: 'stingray',
  7: 'cock',
  8: 'hen',
  9: 'ostrich, Struthio camelus',
  10: 'brambling, Fringilla montifringilla',
  11: 'goldfinch, Carduelis carduelis',
  12: 'house finch, linnet, Carpodacus mexicanus',
  13: 'junco, snowbird',
  14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
  15: 'robin, American robin, Turdus migratorius',
  16: 'bulbul',
  17: 'jay',
  18: 'magpie',
  19: 'chickadee',
  20: 'water ouzel, dipper',
  21: 'kite',
  22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
  23: 'vulture',
  24: 'great grey owl, great gray owl, Strix nebulosa',
  25: 'European fire salamander, Salamandra salamandra',
  26: 'common newt, Triturus vulgaris',
  27: 'eft',
  28: 'spotted salamander, Ambystoma maculatum',
  29: 'axolotl, mud puppy, Ambystoma mexicanum',
  30: 'bullfrog, Rana catesbeiana',
  31: 'tree frog, tree-frog',
  32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
  33: 'loggerhead, loggerhead turtle, Caretta caretta',
  34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
  35: 'mud turtle',
  36: 'terrapin',
  37: 'box turtle, box tortoise',
  38: 'banded gecko',
  39: 'common iguana, iguana, Iguana iguana',
  40: 'American chameleon, anole, Anolis carolinensis',
  41: 'whiptail, whiptail lizard',
  42: 'agama',
  43: 'frilled lizard, Chlamydosaurus kingi',
  44: 'alligator lizard',
  45: 'Gila monster, Heloderma suspectum',
  46: 'green lizard, Lacerta viridis',
  47: 'African chameleon, Chamaeleo chamaeleon',
  48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, ' +
      'Varanus komodoensis',
  49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
  50: 'American alligator, Alligator mississipiensis',
  51: 'triceratops',
  52: 'thunder snake, worm snake, Carphophis amoenus',
  53: 'ringneck snake, ring-necked snake, ring snake',
  54: 'hognose snake, puff adder, sand viper',
  55: 'green snake, grass snake',
  56: 'king snake, kingsnake',
  57: 'garter snake, grass snake',
  58: 'water snake',
  59: 'vine snake',
  60: 'night snake, Hypsiglena torquata',
  61: 'boa constrictor, Constrictor constrictor',
  62: 'rock python, rock snake, Python sebae',
  63: 'Indian cobra, Naja naja',
  64: 'green mamba',
  65: 'sea snake',
  66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
  67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
  68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
  69: 'trilobite',
  70: 'harvestman, daddy longlegs, Phalangium opilio',
  71: 'scorpion',
  72: 'black and gold garden spider, Argiope aurantia',
  73: 'barn spider, Araneus cavaticus',
  74: 'garden spider, Aranea diademata',
  75: 'black widow, Latrodectus mactans',
  76: 'tarantula',
  77: 'wolf spider, hunting spider',
  78: 'tick',
  79: 'centipede',
  80: 'black grouse',
  81: 'ptarmigan',
  82: 'ruffed grouse, partridge, Bonasa umbellus',
  83: 'prairie chicken, prairie grouse, prairie fowl',
  84: 'peacock',
  85: 'quail',
  86: 'partridge',
  87: 'African grey, African gray, Psittacus erithacus',
  88: 'macaw',
  89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
  90: 'lorikeet',
  91: 'coucal',
  92: 'bee eater',
  93: 'hornbill',
  94: 'hummingbird',
  95: 'jacamar',
  96: 'toucan',
  97: 'drake',
  98: 'red-breasted merganser, Mergus serrator',
  99: 'goose',
  100: 'black swan, Cygnus atratus',
  101: 'tusker',
  102: 'echidna, spiny anteater, anteater',
  103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, ' +
      'Ornithorhynchus anatinus',
  104: 'wallaby, brush kangaroo',
  105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
  106: 'wombat',
  107: 'jelly fish',
  108: 'sea anemone, anemone',
  109: 'brain coral',
  110: 'flatworm, platyhelminth',
  111: 'nematode, nematode worm, roundworm',
  112: 'conch',
  113: 'snail',
  114: 'slug',
  115: 'sea slug, nudibranch',
  116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
  117: 'chambered nautilus, pearly nautilus, nautilus',
  118: 'Dungeness crab, Cancer magister',
  119: 'rock crab, Cancer irroratus',
  120: 'fiddler crab',
  121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, ' +
      'Paralithodes camtschatica',
  122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
  123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea ' +
      'crawfish',
  124: 'crayfish, crawfish, crawdad, crawdaddy',
  125: 'hermit crab',
  126: 'isopod',
  127: 'white stork, Ciconia ciconia',
  128: 'black stork, Ciconia nigra',
  129: 'spoonbill',
  130: 'flamingo',
  131: 'little blue heron, Egretta caerulea',
  132: 'American egret, great white heron, Egretta albus',
  133: 'bittern',
  134: 'crane',
  135: 'limpkin, Aramus pictus',
  136: 'European gallinule, Porphyrio porphyrio',
  137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
  138: 'bustard',
  139: 'ruddy turnstone, Arenaria interpres',
  140: 'red-backed sandpiper, dunlin, Erolia alpina',
  141: 'redshank, Tringa totanus',
  142: 'dowitcher',
  143: 'oystercatcher, oyster catcher',
  144: 'pelican',
  145: 'king penguin, Aptenodytes patagonica',
  146: 'albatross, mollymawk',
  147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, ' +
      'Eschrichtius robustus',
  148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
  149: 'dugong, Dugong dugon',
  150: 'sea lion',
  151: 'Chihuahua',
  152: 'Japanese spaniel',
  153: 'Maltese dog, Maltese terrier, Maltese',
  154: 'Pekinese, Pekingese, Peke',
  155: 'Shih-Tzu',
  156: 'Blenheim spaniel',
  157: 'papillon',
  158: 'toy terrier',
  159: 'Rhodesian ridgeback',
  160: 'Afghan hound, Afghan',
  161: 'basset, basset hound',
  162: 'beagle',
  163: 'bloodhound, sleuthhound',
  164: 'bluetick',
  165: 'black-and-tan coonhound',
  166: 'Walker hound, Walker foxhound',
  167: 'English foxhound',
  168: 'redbone',
  169: 'borzoi, Russian wolfhound',
  170: 'Irish wolfhound',
  171: 'Italian greyhound',
  172: 'whippet',
  173: 'Ibizan hound, Ibizan Podenco',
  174: 'Norwegian elkhound, elkhound',
  175: 'otterhound, otter hound',
  176: 'Saluki, gazelle hound',
  177: 'Scottish deerhound, deerhound',
  178: 'Weimaraner',
  179: 'Staffordshire bullterrier, Staffordshire bull terrier',
  180: 'American Staffordshire terrier, Staffordshire terrier, American pit ' +
      'bull terrier, pit bull terrier',
  181: 'Bedlington terrier',
  182: 'Border terrier',
  183: 'Kerry blue terrier',
  184: 'Irish terrier',
  185: 'Norfolk terrier',
  186: 'Norwich terrier',
  187: 'Yorkshire terrier',
  188: 'wire-haired fox terrier',
  189: 'Lakeland terrier',
  190: 'Sealyham terrier, Sealyham',
  191: 'Airedale, Airedale terrier',
  192: 'cairn, cairn terrier',
  193: 'Australian terrier',
  194: 'Dandie Dinmont, Dandie Dinmont terrier',
  195: 'Boston bull, Boston terrier',
  196: 'miniature schnauzer',
  197: 'giant schnauzer',
  198: 'standard schnauzer',
  199: 'Scotch terrier, Scottish terrier, Scottie',
  200: 'Tibetan terrier, chrysanthemum dog',
  201: 'silky terrier, Sydney silky',
  202: 'soft-coated wheaten terrier',
  203: 'West Highland white terrier',
  204: 'Lhasa, Lhasa apso',
  205: 'flat-coated retriever',
  206: 'curly-coated retriever',
  207: 'golden retriever',
  208: 'Labrador retriever',
  209: 'Chesapeake Bay retriever',
  210: 'German short-haired pointer',
  211: 'vizsla, Hungarian pointer',
  212: 'English setter',
  213: 'Irish setter, red setter',
  214: 'Gordon setter',
  215: 'Brittany spaniel',
  216: 'clumber, clumber spaniel',
  217: 'English springer, English springer spaniel',
  218: 'Welsh springer spaniel',
  219: 'cocker spaniel, English cocker spaniel, cocker',
  220: 'Sussex spaniel',
  221: 'Irish water spaniel',
  222: 'kuvasz',
  223: 'schipperke',
  224: 'groenendael',
  225: 'malinois',
  226: 'briard',
  227: 'kelpie',
  228: 'komondor',
  229: 'Old English sheepdog, bobtail',
  230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
  231: 'collie',
  232: 'Border collie',
  233: 'Bouvier des Flandres, Bouviers des Flandres',
  234: 'Rottweiler',
  235: 'German shepherd, German shepherd dog, German police dog, alsatian',
  236: 'Doberman, Doberman pinscher',
  237: 'miniature pinscher',
  238: 'Greater Swiss Mountain dog',
  239: 'Bernese mountain dog',
  240: 'Appenzeller',
  241: 'EntleBucher',
  242: 'boxer',
  243: 'bull mastiff',
  244: 'Tibetan mastiff',
  245: 'French bulldog',
  246: 'Great Dane',
  247: 'Saint Bernard, St Bernard',
  248: 'Eskimo dog, husky',
  249: 'malamute, malemute, Alaskan malamute',
  250: 'Siberian husky',
  251: 'dalmatian, coach dog, carriage dog',
  252: 'affenpinscher, monkey pinscher, monkey dog',
  253: 'basenji',
  254: 'pug, pug-dog',
  255: 'Leonberg',
  256: 'Newfoundland, Newfoundland dog',
  257: 'Great Pyrenees',
  258: 'Samoyed, Samoyede',
  259: 'Pomeranian',
  260: 'chow, chow chow',
  261: 'keeshond',
  262: 'Brabancon griffon',
  263: 'Pembroke, Pembroke Welsh corgi',
  264: 'Cardigan, Cardigan Welsh corgi',
  265: 'toy poodle',
  266: 'miniature poodle',
  267: 'standard poodle',
  268: 'Mexican hairless',
  269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
  270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
  271: 'red wolf, maned wolf, Canis rufus, Canis niger',
  272: 'coyote, prairie wolf, brush wolf, Canis latrans',
  273: 'dingo, warrigal, warragal, Canis dingo',
  274: 'dhole, Cuon alpinus',
  275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
  276: 'hyena, hyaena',
  277: 'red fox, Vulpes vulpes',
  278: 'kit fox, Vulpes macrotis',
  279: 'Arctic fox, white fox, Alopex lagopus',
  280: 'grey fox, gray fox, Urocyon cinereoargenteus',
  281: 'tabby, tabby cat',
  282: 'tiger cat',
  283: 'Persian cat',
  284: 'Siamese cat, Siamese',
  285: 'Egyptian cat',
  286: 'cougar, puma, catamount, mountain lion, painter, panther, ' +
      'Felis concolor',
  287: 'lynx, catamount',
  288: 'leopard, Panthera pardus',
  289: 'snow leopard, ounce, Panthera uncia',
  290: 'jaguar, panther, Panthera onca, Felis onca',
  291: 'lion, king of beasts, Panthera leo',
  292: 'tiger, Panthera tigris',
  293: 'cheetah, chetah, Acinonyx jubatus',
  294: 'brown bear, bruin, Ursus arctos',
  295: 'American black bear, black bear, Ursus americanus, Euarctos ' +
      'americanus',
  296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
  297: 'sloth bear, Melursus ursinus, Ursus ursinus',
  298: 'mongoose',
  299: 'meerkat, mierkat',
  300: 'tiger beetle',
  301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
  302: 'ground beetle, carabid beetle',
  303: 'long-horned beetle, longicorn, longicorn beetle',
  304: 'leaf beetle, chrysomelid',
  305: 'dung beetle',
  306: 'rhinoceros beetle',
  307: 'weevil',
  308: 'fly',
  309: 'bee',
  310: 'ant, emmet, pismire',
  311: 'grasshopper, hopper',
  312: 'cricket',
  313: 'walking stick, walkingstick, stick insect',
  314: 'cockroach, roach',
  315: 'mantis, mantid',
  316: 'cicada, cicala',
  317: 'leafhopper',
  318: 'lacewing, lacewing fly',
  319: 'dragonfly, darning needle, devil\'s darning needle, sewing needle, ' +
      'snake feeder, snake doctor, mosquito hawk, skeeter hawk',
  320: 'damselfly',
  321: 'admiral',
  322: 'ringlet, ringlet butterfly',
  323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
  324: 'cabbage butterfly',
  325: 'sulphur butterfly, sulfur butterfly',
  326: 'lycaenid, lycaenid butterfly',
  327: 'starfish, sea star',
  328: 'sea urchin',
  329: 'sea cucumber, holothurian',
  330: 'wood rabbit, cottontail, cottontail rabbit',
  331: 'hare',
  332: 'Angora, Angora rabbit',
  333: 'hamster',
  334: 'porcupine, hedgehog',
  335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
  336: 'marmot',
  337: 'beaver',
  338: 'guinea pig, Cavia cobaya',
  339: 'sorrel',
  340: 'zebra',
  341: 'hog, pig, grunter, squealer, Sus scrofa',
  342: 'wild boar, boar, Sus scrofa',
  343: 'warthog',
  344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
  345: 'ox',
  346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
  347: 'bison',
  348: 'ram, tup',
  349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky ' +
      'Mountain sheep, Ovis canadensis',
  350: 'ibex, Capra ibex',
  351: 'hartebeest',
  352: 'impala, Aepyceros melampus',
  353: 'gazelle',
  354: 'Arabian camel, dromedary, Camelus dromedarius',
  355: 'llama',
  356: 'weasel',
  357: 'mink',
  358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
  359: 'black-footed ferret, ferret, Mustela nigripes',
  360: 'otter',
  361: 'skunk, polecat, wood pussy',
  362: 'badger',
  363: 'armadillo',
  364: 'three-toed sloth, ai, Bradypus tridactylus',
  365: 'orangutan, orang, orangutang, Pongo pygmaeus',
  366: 'gorilla, Gorilla gorilla',
  367: 'chimpanzee, chimp, Pan troglodytes',
  368: 'gibbon, Hylobates lar',
  369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
  370: 'guenon, guenon monkey',
  371: 'patas, hussar monkey, Erythrocebus patas',
  372: 'baboon',
  373: 'macaque',
  374: 'langur',
  375: 'colobus, colobus monkey',
  376: 'proboscis monkey, Nasalis larvatus',
  377: 'marmoset',
  378: 'capuchin, ringtail, Cebus capucinus',
  379: 'howler monkey, howler',
  380: 'titi, titi monkey',
  381: 'spider monkey, Ateles geoffroyi',
  382: 'squirrel monkey, Saimiri sciureus',
  383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
  384: 'indri, indris, Indri indri, Indri brevicaudatus',
  385: 'Indian elephant, Elephas maximus',
  386: 'African elephant, Loxodonta africana',
  387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
  388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
  389: 'barracouta, snoek',
  390: 'eel',
  391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus ' +
      'kisutch',
  392: 'rock beauty, Holocanthus tricolor',
  393: 'anemone fish',
  394: 'sturgeon',
  395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
  396: 'lionfish',
  397: 'puffer, pufferfish, blowfish, globefish',
  398: 'abacus',
  399: 'abaya',
  400: 'academic gown, academic robe, judge\'s robe',
  401: 'accordion, piano accordion, squeeze box',
  402: 'acoustic guitar',
  403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
  404: 'airliner',
  405: 'airship, dirigible',
  406: 'altar',
  407: 'ambulance',
  408: 'amphibian, amphibious vehicle',
  409: 'analog clock',
  410: 'apiary, bee house',
  411: 'apron',
  412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, ' +
      'dustbin, trash barrel, trash bin',
  413: 'assault rifle, assault gun',
  414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
  415: 'bakery, bakeshop, bakehouse',
  416: 'balance beam, beam',
  417: 'balloon',
  418: 'ballpoint, ballpoint pen, ballpen, Biro',
  419: 'Band Aid',
  420: 'banjo',
  421: 'bannister, banister, balustrade, balusters, handrail',
  422: 'barbell',
  423: 'barber chair',
  424: 'barbershop',
  425: 'barn',
  426: 'barometer',
  427: 'barrel, cask',
  428: 'barrow, garden cart, lawn cart, wheelbarrow',
  429: 'baseball',
  430: 'basketball',
  431: 'bassinet',
  432: 'bassoon',
  433: 'bathing cap, swimming cap',
  434: 'bath towel',
  435: 'bathtub, bathing tub, bath, tub',
  436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station ' +
      'waggon, waggon',
  437: 'beacon, lighthouse, beacon light, pharos',
  438: 'beaker',
  439: 'bearskin, busby, shako',
  440: 'beer bottle',
  441: 'beer glass',
  442: 'bell cote, bell cot',
  443: 'bib',
  444: 'bicycle-built-for-two, tandem bicycle, tandem',
  445: 'bikini, two-piece',
  446: 'binder, ring-binder',
  447: 'binoculars, field glasses, opera glasses',
  448: 'birdhouse',
  449: 'boathouse',
  450: 'bobsled, bobsleigh, bob',
  451: 'bolo tie, bolo, bola tie, bola',
  452: 'bonnet, poke bonnet',
  453: 'bookcase',
  454: 'bookshop, bookstore, bookstall',
  455: 'bottlecap',
  456: 'bow',
  457: 'bow tie, bow-tie, bowtie',
  458: 'brass, memorial tablet, plaque',
  459: 'brassiere, bra, bandeau',
  460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
  461: 'breastplate, aegis, egis',
  462: 'broom',
  463: 'bucket, pail',
  464: 'buckle',
  465: 'bulletproof vest',
  466: 'bullet train, bullet',
  467: 'butcher shop, meat market',
  468: 'cab, hack, taxi, taxicab',
  469: 'caldron, cauldron',
  470: 'candle, taper, wax light',
  471: 'cannon',
  472: 'canoe',
  473: 'can opener, tin opener',
  474: 'cardigan',
  475: 'car mirror',
  476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
  477: 'carpenter\'s kit, tool kit',
  478: 'carton',
  479: 'car wheel',
  480: 'cash machine, cash dispenser, automated teller machine, automatic ' +
      'teller machine, automated teller, automatic teller, ATM',
  481: 'cassette',
  482: 'cassette player',
  483: 'castle',
  484: 'catamaran',
  485: 'CD player',
  486: 'cello, violoncello',
  487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
  488: 'chain',
  489: 'chainlink fence',
  490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ' +
      'ring armour',
  491: 'chain saw, chainsaw',
  492: 'chest',
  493: 'chiffonier, commode',
  494: 'chime, bell, gong',
  495: 'china cabinet, china closet',
  496: 'Christmas stocking',
  497: 'church, church building',
  498: 'cinema, movie theater, movie theatre, movie house, picture palace',
  499: 'cleaver, meat cleaver, chopper',
  500: 'cliff dwelling',
  501: 'cloak',
  502: 'clog, geta, patten, sabot',
  503: 'cocktail shaker',
  504: 'coffee mug',
  505: 'coffeepot',
  506: 'coil, spiral, volute, whorl, helix',
  507: 'combination lock',
  508: 'computer keyboard, keypad',
  509: 'confectionery, confectionary, candy store',
  510: 'container ship, containership, container vessel',
  511: 'convertible',
  512: 'corkscrew, bottle screw',
  513: 'cornet, horn, trumpet, trump',
  514: 'cowboy boot',
  515: 'cowboy hat, ten-gallon hat',
  516: 'cradle',
  517: 'crane',
  518: 'crash helmet',
  519: 'crate',
  520: 'crib, cot',
  521: 'Crock Pot',
  522: 'croquet ball',
  523: 'crutch',
  524: 'cuirass',
  525: 'dam, dike, dyke',
  526: 'desk',
  527: 'desktop computer',
  528: 'dial telephone, dial phone',
  529: 'diaper, nappy, napkin',
  530: 'digital clock',
  531: 'digital watch',
  532: 'dining table, board',
  533: 'dishrag, dishcloth',
  534: 'dishwasher, dish washer, dishwashing machine',
  535: 'disk brake, disc brake',
  536: 'dock, dockage, docking facility',
  537: 'dogsled, dog sled, dog sleigh',
  538: 'dome',
  539: 'doormat, welcome mat',
  540: 'drilling platform, offshore rig',
  541: 'drum, membranophone, tympan',
  542: 'drumstick',
  543: 'dumbbell',
  544: 'Dutch oven',
  545: 'electric fan, blower',
  546: 'electric guitar',
  547: 'electric locomotive',
  548: 'entertainment center',
  549: 'envelope',
  550: 'espresso maker',
  551: 'face powder',
  552: 'feather boa, boa',
  553: 'file, file cabinet, filing cabinet',
  554: 'fireboat',
  555: 'fire engine, fire truck',
  556: 'fire screen, fireguard',
  557: 'flagpole, flagstaff',
  558: 'flute, transverse flute',
  559: 'folding chair',
  560: 'football helmet',
  561: 'forklift',
  562: 'fountain',
  563: 'fountain pen',
  564: 'four-poster',
  565: 'freight car',
  566: 'French horn, horn',
  567: 'frying pan, frypan, skillet',
  568: 'fur coat',
  569: 'garbage truck, dustcart',
  570: 'gasmask, respirator, gas helmet',
  571: 'gas pump, gasoline pump, petrol pump, island dispenser',
  572: 'goblet',
  573: 'go-kart',
  574: 'golf ball',
  575: 'golfcart, golf cart',
  576: 'gondola',
  577: 'gong, tam-tam',
  578: 'gown',
  579: 'grand piano, grand',
  580: 'greenhouse, nursery, glasshouse',
  581: 'grille, radiator grille',
  582: 'grocery store, grocery, food market, market',
  583: 'guillotine',
  584: 'hair slide',
  585: 'hair spray',
  586: 'half track',
  587: 'hammer',
  588: 'hamper',
  589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
  590: 'hand-held computer, hand-held microcomputer',
  591: 'handkerchief, hankie, hanky, hankey',
  592: 'hard disc, hard disk, fixed disk',
  593: 'harmonica, mouth organ, harp, mouth harp',
  594: 'harp',
  595: 'harvester, reaper',
  596: 'hatchet',
  597: 'holster',
  598: 'home theater, home theatre',
  599: 'honeycomb',
  600: 'hook, claw',
  601: 'hoopskirt, crinoline',
  602: 'horizontal bar, high bar',
  603: 'horse cart, horse-cart',
  604: 'hourglass',
  605: 'iPod',
  606: 'iron, smoothing iron',
  607: 'jack-o\'-lantern',
  608: 'jean, blue jean, denim',
  609: 'jeep, landrover',
  610: 'jersey, T-shirt, tee shirt',
  611: 'jigsaw puzzle',
  612: 'jinrikisha, ricksha, rickshaw',
  613: 'joystick',
  614: 'kimono',
  615: 'knee pad',
  616: 'knot',
  617: 'lab coat, laboratory coat',
  618: 'ladle',
  619: 'lampshade, lamp shade',
  620: 'laptop, laptop computer',
  621: 'lawn mower, mower',
  622: 'lens cap, lens cover',
  623: 'letter opener, paper knife, paperknife',
  624: 'library',
  625: 'lifeboat',
  626: 'lighter, light, igniter, ignitor',
  627: 'limousine, limo',
  628: 'liner, ocean liner',
  629: 'lipstick, lip rouge',
  630: 'Loafer',
  631: 'lotion',
  632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker ' +
      'system',
  633: 'loupe, jeweler\'s loupe',
  634: 'lumbermill, sawmill',
  635: 'magnetic compass',
  636: 'mailbag, postbag',
  637: 'mailbox, letter box',
  638: 'maillot',
  639: 'maillot, tank suit',
  640: 'manhole cover',
  641: 'maraca',
  642: 'marimba, xylophone',
  643: 'mask',
  644: 'matchstick',
  645: 'maypole',
  646: 'maze, labyrinth',
  647: 'measuring cup',
  648: 'medicine chest, medicine cabinet',
  649: 'megalith, megalithic structure',
  650: 'microphone, mike',
  651: 'microwave, microwave oven',
  652: 'military uniform',
  653: 'milk can',
  654: 'minibus',
  655: 'miniskirt, mini',
  656: 'minivan',
  657: 'missile',
  658: 'mitten',
  659: 'mixing bowl',
  660: 'mobile home, manufactured home',
  661: 'Model T',
  662: 'modem',
  663: 'monastery',
  664: 'monitor',
  665: 'moped',
  666: 'mortar',
  667: 'mortarboard',
  668: 'mosque',
  669: 'mosquito net',
  670: 'motor scooter, scooter',
  671: 'mountain bike, all-terrain bike, off-roader',
  672: 'mountain tent',
  673: 'mouse, computer mouse',
  674: 'mousetrap',
  675: 'moving van',
  676: 'muzzle',
  677: 'nail',
  678: 'neck brace',
  679: 'necklace',
  680: 'nipple',
  681: 'notebook, notebook computer',
  682: 'obelisk',
  683: 'oboe, hautboy, hautbois',
  684: 'ocarina, sweet potato',
  685: 'odometer, hodometer, mileometer, milometer',
  686: 'oil filter',
  687: 'organ, pipe organ',
  688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
  689: 'overskirt',
  690: 'oxcart',
  691: 'oxygen mask',
  692: 'packet',
  693: 'paddle, boat paddle',
  694: 'paddlewheel, paddle wheel',
  695: 'padlock',
  696: 'paintbrush',
  697: 'pajama, pyjama, pj\'s, jammies',
  698: 'palace',
  699: 'panpipe, pandean pipe, syrinx',
  700: 'paper towel',
  701: 'parachute, chute',
  702: 'parallel bars, bars',
  703: 'park bench',
  704: 'parking meter',
  705: 'passenger car, coach, carriage',
  706: 'patio, terrace',
  707: 'pay-phone, pay-station',
  708: 'pedestal, plinth, footstall',
  709: 'pencil box, pencil case',
  710: 'pencil sharpener',
  711: 'perfume, essence',
  712: 'Petri dish',
  713: 'photocopier',
  714: 'pick, plectrum, plectron',
  715: 'pickelhaube',
  716: 'picket fence, paling',
  717: 'pickup, pickup truck',
  718: 'pier',
  719: 'piggy bank, penny bank',
  720: 'pill bottle',
  721: 'pillow',
  722: 'ping-pong ball',
  723: 'pinwheel',
  724: 'pirate, pirate ship',
  725: 'pitcher, ewer',
  726: 'plane, carpenter\'s plane, woodworking plane',
  727: 'planetarium',
  728: 'plastic bag',
  729: 'plate rack',
  730: 'plow, plough',
  731: 'plunger, plumber\'s helper',
  732: 'Polaroid camera, Polaroid Land camera',
  733: 'pole',
  734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black ' +
      'Maria',
  735: 'poncho',
  736: 'pool table, billiard table, snooker table',
  737: 'pop bottle, soda bottle',
  738: 'pot, flowerpot',
  739: 'potter\'s wheel',
  740: 'power drill',
  741: 'prayer rug, prayer mat',
  742: 'printer',
  743: 'prison, prison house',
  744: 'projectile, missile',
  745: 'projector',
  746: 'puck, hockey puck',
  747: 'punching bag, punch bag, punching ball, punchball',
  748: 'purse',
  749: 'quill, quill pen',
  750: 'quilt, comforter, comfort, puff',
  751: 'racer, race car, racing car',
  752: 'racket, racquet',
  753: 'radiator',
  754: 'radio, wireless',
  755: 'radio telescope, radio reflector',
  756: 'rain barrel',
  757: 'recreational vehicle, RV, R.V.',
  758: 'reel',
  759: 'reflex camera',
  760: 'refrigerator, icebox',
  761: 'remote control, remote',
  762: 'restaurant, eating house, eating place, eatery',
  763: 'revolver, six-gun, six-shooter',
  764: 'rifle',
  765: 'rocking chair, rocker',
  766: 'rotisserie',
  767: 'rubber eraser, rubber, pencil eraser',
  768: 'rugby ball',
  769: 'rule, ruler',
  770: 'running shoe',
  771: 'safe',
  772: 'safety pin',
  773: 'saltshaker, salt shaker',
  774: 'sandal',
  775: 'sarong',
  776: 'sax, saxophone',
  777: 'scabbard',
  778: 'scale, weighing machine',
  779: 'school bus',
  780: 'schooner',
  781: 'scoreboard',
  782: 'screen, CRT screen',
  783: 'screw',
  784: 'screwdriver',
  785: 'seat belt, seatbelt',
  786: 'sewing machine',
  787: 'shield, buckler',
  788: 'shoe shop, shoe-shop, shoe store',
  789: 'shoji',
  790: 'shopping basket',
  791: 'shopping cart',
  792: 'shovel',
  793: 'shower cap',
  794: 'shower curtain',
  795: 'ski',
  796: 'ski mask',
  797: 'sleeping bag',
  798: 'slide rule, slipstick',
  799: 'sliding door',
  800: 'slot, one-armed bandit',
  801: 'snorkel',
  802: 'snowmobile',
  803: 'snowplow, snowplough',
  804: 'soap dispenser',
  805: 'soccer ball',
  806: 'sock',
  807: 'solar dish, solar collector, solar furnace',
  808: 'sombrero',
  809: 'soup bowl',
  810: 'space bar',
  811: 'space heater',
  812: 'space shuttle',
  813: 'spatula',
  814: 'speedboat',
  815: 'spider web, spider\'s web',
  816: 'spindle',
  817: 'sports car, sport car',
  818: 'spotlight, spot',
  819: 'stage',
  820: 'steam locomotive',
  821: 'steel arch bridge',
  822: 'steel drum',
  823: 'stethoscope',
  824: 'stole',
  825: 'stone wall',
  826: 'stopwatch, stop watch',
  827: 'stove',
  828: 'strainer',
  829: 'streetcar, tram, tramcar, trolley, trolley car',
  830: 'stretcher',
  831: 'studio couch, day bed',
  832: 'stupa, tope',
  833: 'submarine, pigboat, sub, U-boat',
  834: 'suit, suit of clothes',
  835: 'sundial',
  836: 'sunglass',
  837: 'sunglasses, dark glasses, shades',
  838: 'sunscreen, sunblock, sun blocker',
  839: 'suspension bridge',
  840: 'swab, swob, mop',
  841: 'sweatshirt',
  842: 'swimming trunks, bathing trunks',
  843: 'swing',
  844: 'switch, electric switch, electrical switch',
  845: 'syringe',
  846: 'table lamp',
  847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
  848: 'tape player',
  849: 'teapot',
  850: 'teddy, teddy bear',
  851: 'television, television system',
  852: 'tennis ball',
  853: 'thatch, thatched roof',
  854: 'theater curtain, theatre curtain',
  855: 'thimble',
  856: 'thresher, thrasher, threshing machine',
  857: 'throne',
  858: 'tile roof',
  859: 'toaster',
  860: 'tobacco shop, tobacconist shop, tobacconist',
  861: 'toilet seat',
  862: 'torch',
  863: 'totem pole',
  864: 'tow truck, tow car, wrecker',
  865: 'toyshop',
  866: 'tractor',
  867: 'trailer truck, tractor trailer, trucking rig, rig, articulated ' +
      'lorry, semi',
  868: 'tray',
  869: 'trench coat',
  870: 'tricycle, trike, velocipede',
  871: 'trimaran',
  872: 'tripod',
  873: 'triumphal arch',
  874: 'trolleybus, trolley coach, trackless trolley',
  875: 'trombone',
  876: 'tub, vat',
  877: 'turnstile',
  878: 'typewriter keyboard',
  879: 'umbrella',
  880: 'unicycle, monocycle',
  881: 'upright, upright piano',
  882: 'vacuum, vacuum cleaner',
  883: 'vase',
  884: 'vault',
  885: 'velvet',
  886: 'vending machine',
  887: 'vestment',
  888: 'viaduct',
  889: 'violin, fiddle',
  890: 'volleyball',
  891: 'waffle iron',
  892: 'wall clock',
  893: 'wallet, billfold, notecase, pocketbook',
  894: 'wardrobe, closet, press',
  895: 'warplane, military plane',
  896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
  897: 'washer, automatic washer, washing machine',
  898: 'water bottle',
  899: 'water jug',
  900: 'water tower',
  901: 'whiskey jug',
  902: 'whistle',
  903: 'wig',
  904: 'window screen',
  905: 'window shade',
  906: 'Windsor tie',
  907: 'wine bottle',
  908: 'wing',
  909: 'wok',
  910: 'wooden spoon',
  911: 'wool, woolen, woollen',
  912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
  913: 'wreck',
  914: 'yawl',
  915: 'yurt',
  916: 'web site, website, internet site, site',
  917: 'comic book',
  918: 'crossword puzzle, crossword',
  919: 'street sign',
  920: 'traffic light, traffic signal, stoplight',
  921: 'book jacket, dust cover, dust jacket, dust wrapper',
  922: 'menu',
  923: 'plate',
  924: 'guacamole',
  925: 'consomme',
  926: 'hot pot, hotpot',
  927: 'trifle',
  928: 'ice cream, icecream',
  929: 'ice lolly, lolly, lollipop, popsicle',
  930: 'French loaf',
  931: 'bagel, beigel',
  932: 'pretzel',
  933: 'cheeseburger',
  934: 'hotdog, hot dog, red hot',
  935: 'mashed potato',
  936: 'head cabbage',
  937: 'broccoli',
  938: 'cauliflower',
  939: 'zucchini, courgette',
  940: 'spaghetti squash',
  941: 'acorn squash',
  942: 'butternut squash',
  943: 'cucumber, cuke',
  944: 'artichoke, globe artichoke',
  945: 'bell pepper',
  946: 'cardoon',
  947: 'mushroom',
  948: 'Granny Smith',
  949: 'strawberry',
  950: 'orange',
  951: 'lemon',
  952: 'fig',
  953: 'pineapple, ananas',
  954: 'banana',
  955: 'jackfruit, jak, jack',
  956: 'custard apple',
  957: 'pomegranate',
  958: 'hay',
  959: 'carbonara',
  960: 'chocolate sauce, chocolate syrup',
  961: 'dough',
  962: 'meat loaf, meatloaf',
  963: 'pizza, pizza pie',
  964: 'potpie',
  965: 'burrito',
  966: 'red wine',
  967: 'espresso',
  968: 'cup',
  969: 'eggnog',
  970: 'alp',
  971: 'bubble',
  972: 'cliff, drop, drop-off',
  973: 'coral reef',
  974: 'geyser',
  975: 'lakeside, lakeshore',
  976: 'promontory, headland, head, foreland',
  977: 'sandbar, sand bar',
  978: 'seashore, coast, seacoast, sea-coast',
  979: 'valley, vale',
  980: 'volcano',
  981: 'ballplayer, baseball player',
  982: 'groom, bridegroom',
  983: 'scuba diver',
  984: 'rapeseed',
  985: 'daisy',
  986: 'yellow lady\'s slipper, yellow lady-slipper, Cypripedium calceolus, ' +
      'Cypripedium parviflorum',
  987: 'corn',
  988: 'acorn',
  989: 'hip, rose hip, rosehip',
  990: 'buckeye, horse chestnut, conker',
  991: 'coral fungus',
  992: 'agaric',
  993: 'gyromitra',
  994: 'stinkhorn, carrion fungus',
  995: 'earthstar',
  996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola ' +
      'frondosa',
  997: 'bolete',
  998: 'ear, spike, capitulum',
  999: 'toilet tissue, toilet paper, bathroom tissue'
};

index.html

html
<script src="script.js"></script>
<input type="file" onchange="predict(this.files[0])">

utils.js:

javascript
export function file2img(f) {
    return new Promise(resolve => {
        const reader = new FileReader();
        reader.readAsDataURL(f);
        reader.onload = (e) => {
            const img = document.createElement('img');
            img.src = e.target.result;
            img.width = 224;
            img.height = 224;
            img.onload = () => resolve(img);
        };
    });
}

script.js:

javascript
import * as tf from '@tensorflow/tfjs';
import { IMAGENET_CLASSES } from './imagenet_classes';
import { file2img } from './utils';

// Model URL (MobileNet served from a local static server)
const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';

window.onload = async () => {
    // Load the model
    const model = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
    // Run a prediction on the selected file
    window.predict = async (file) => {
        const img = await file2img(file);
        document.body.appendChild(img);
        const pred = tf.tidy(() => {
            // Convert the <img> element to a tensor and normalize it: subtracting 127.5 shifts the
            // values into -127.5..127.5, dividing by 127.5 scales them into -1..1.
            // Then reshape to a batch of one 224x224 RGB image.
            const input = tf.browser.fromPixels(img)
                .toFloat()
                .sub(255 / 2)
                .div(255 / 2)
                .reshape([1, 224, 224, 3]);
            // Feed the preprocessed tensor to the model
            return model.predict(input);
        });

        // Index of the most likely class
        const index = pred.argMax(1).dataSync()[0];
        setTimeout(() => {
            alert(`Prediction: ${IMAGENET_CLASSES[index]}`);
        }, 0);
    };
};
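
The preprocessing inside tf.tidy maps each channel value from the 0 to 255 range into [-1, 1] using (v - 127.5) / 127.5. After that, argMax(1) picks the index of the highest class score. Below is the same arithmetic in plain JavaScript, as a minimal sketch on a flat array. The names normalizePixels and argMax are hypothetical helpers, not TensorFlow.js functions.

```javascript
// Hypothetical helpers mirroring the tf.tidy preprocessing and the argMax(1) step.

// Map raw 0..255 channel values to the [-1, 1] range:
// subtract 127.5 (255 / 2), then divide by 127.5.
function normalizePixels(pixels) {
  return pixels.map(v => (v - 255 / 2) / (255 / 2));
}

// Return the index of the largest score, like tensor.argMax(1) on a single row.
function argMax(scores) {
  let best = 0;
  for (let i = 1; i < scores.length; i++) {
    if (scores[i] > scores[best]) best = i;
  }
  return best;
}

console.log(normalizePixels([0, 127.5, 255])); // [ -1, 0, 1 ]
console.log(argMax([0.1, 0.7, 0.2]));          // 1
```

A real input has 224 x 224 x 3 = 150528 values. The tensor version applies the same element-wise operations to all of them at once.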

12. Transfer-Learning Image Classifier: Brand Logo Recognition

a. Task overview: a transfer-learning image classifier for brand logos

b. Loading and visualizing the brand-logo training data

data.js:

javascript
const IMAGE_SIZE = 224;

// Load one image (cross-origin allowed) and size it to 224x224
const loadImg = (src) => {
    return new Promise((resolve, reject) => {
        const img = new Image();
        img.crossOrigin = "anonymous";
        img.src = src;
        img.width = IMAGE_SIZE;
        img.height = IMAGE_SIZE;
        img.onload = () => resolve(img);
        img.onerror = () => reject(new Error(`Failed to load ${src}`));
    });
};

// Load 30 training images per brand and build a one-hot label for each
export const getInputs = async () => {
    const loadImgs = [];
    const labels = [];
    for (let i = 0; i < 30; i += 1) {
        ['android', 'apple', 'windows'].forEach(label => {
            const src = `http://127.0.0.1:8080/brand/train/${label}-${i}.jpg`;
            const img = loadImg(src);
            loadImgs.push(img);
            // One-hot label in the order [android, apple, windows]
            labels.push([
                label === 'android' ? 1 : 0,
                label === 'apple' ? 1 : 0,
                label === 'windows' ? 1 : 0,
            ]);
        });
    }
    const inputs = await Promise.all(loadImgs);
    return {
        inputs,
        labels,
    };
};
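
The three ternary expressions in getInputs build a one-hot label: 1 at the position of the image's brand and 0 elsewhere, in the fixed order android, apple, windows. Below is a sketch of the same encoding for any class list, using a hypothetical oneHot helper that is not part of the original data.js.

```javascript
// Hypothetical helper: one-hot encode a label against a fixed class order.
const BRAND_CLASSES = ['android', 'apple', 'windows'];

function oneHot(label, classes) {
  if (!classes.includes(label)) {
    throw new Error(`Unknown label: ${label}`);
  }
  return classes.map(c => (c === label ? 1 : 0));
}

// Same loop shape as getInputs: 30 images per brand, labels interleaved.
const labels = [];
for (let i = 0; i < 30; i += 1) {
  BRAND_CLASSES.forEach(label => labels.push(oneHot(label, BRAND_CLASSES)));
}

console.log(oneHot('apple', BRAND_CLASSES)); // [ 0, 1, 0 ]
console.log(labels.length);                  // 90
```

The column order matters. The BRAND_CLASSES array used later to turn a predicted index back into a name must list the brands in this same order.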

script.js:

javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getInputs } from './data';

window.onload = async () => {
    const { inputs, labels } = await getInputs();
    // Show the training images in a tfvis visor panel
    const surface = tfvis.visor().surface({ name: 'Input examples', styles: { height: 250 } });
    inputs.forEach(img => {
        surface.drawArea.appendChild(img);
    });
};

c. Defining the model structure: truncated model + two-layer neural network

javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getInputs } from './data';

const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';
const NUM_CLASSES = 3;

window.onload = async () => {
    const { inputs, labels } = await getInputs();
    const surface = tfvis.visor().surface({ name: 'Input examples', styles: { height: 250 } });
    inputs.forEach(img => {
        surface.drawArea.appendChild(img);
    });

    // Load the pre-trained MobileNet
    const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
    // Print an overview of the model's layers
    mobilenet.summary();
    // Get an intermediate layer by name
    const layer = mobilenet.getLayer('conv_pw_13_relu');
    // Truncated model: MobileNet's input as input, the chosen layer's output as output
    const truncatedMobilenet = tf.model({
        inputs: mobilenet.inputs,
        outputs: layer.output
    });

    // New model that will hold the two-layer network
    const model = tf.sequential();
    // Flatten the truncated layer's output into a 1-D vector
    model.add(tf.layers.flatten({
        inputShape: layer.outputShape.slice(1)
    }));
    // Two-layer network: a hidden ReLU layer and a softmax output layer
    model.add(tf.layers.dense({
        units: 10,
        activation: 'relu'
    }));
    model.add(tf.layers.dense({
        units: NUM_CLASSES,
        activation: 'softmax'
    }));
};
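
Counting the new head's parameters shows what it has to learn. The flatten layer turns the truncated output into a vector whose length is the product of its dimensions. Each dense layer then has inputs x units weights plus units biases. The [7, 7, 256] shape below is an assumption: it corresponds to a 0.25-width MobileNet v1 at 224 x 224. Check the real value of layer.outputShape with mobilenet.summary().

```javascript
// Parameter count for the Flatten -> Dense(10, relu) -> Dense(3, softmax) head.
// ASSUMPTION: truncated layer output shape is [7, 7, 256] (batch dimension removed).
const truncatedShape = [7, 7, 256];
const NUM_CLASSES = 3;

// Flatten: one input per element of the feature map.
const flatSize = truncatedShape.reduce((a, b) => a * b, 1);

// Dense layer: (inputs x units) kernel weights + units biases.
const denseParams = (inputs, units) => inputs * units + units;

const hiddenParams = denseParams(flatSize, 10);
const outputParams = denseParams(10, NUM_CLASSES);

console.log(flatSize);                   // 12544
console.log(hiddenParams, outputParams); // 125450 33
```

Under this assumption, nearly all of the head's weights sit in the first dense layer. That is why its unit count matters most for a 90-image training set.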

d. Training the model with transfer learning

utils.js:

javascript
import * as tf from '@tensorflow/tfjs';

// Convert an <img> element into a normalized [1, 224, 224, 3] tensor for MobileNet
export function img2x(imgEl){
    return tf.tidy(() => {
        const input = tf.browser.fromPixels(imgEl)
            .toFloat()
            .sub(255 / 2)
            .div(255 / 2)
            .reshape([1, 224, 224, 3]);
        return input;
    });
}

script.js:

javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getInputs } from './data';
import { img2x } from './utils';

const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';
const NUM_CLASSES = 3;
const BRAND_CLASSES = ['android', 'apple', 'windows'];

window.onload = async () => {
    const { inputs, labels } = await getInputs();
    const surface = tfvis.visor().surface({ name: 'Input examples', styles: { height: 250 } });
    inputs.forEach(img => {
        surface.drawArea.appendChild(img);
    });

    // Load the pre-trained MobileNet
    const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
    // Print an overview of the model's layers
    mobilenet.summary();
    // Get an intermediate layer by name
    const layer = mobilenet.getLayer('conv_pw_13_relu');
    // Truncated model: MobileNet's input as input, the chosen layer's output as output
    const truncatedMobilenet = tf.model({
        inputs: mobilenet.inputs,
        outputs: layer.output
    });

    // New model that will hold the two-layer network
    const model = tf.sequential();
    // Flatten the truncated layer's output into a 1-D vector
    model.add(tf.layers.flatten({
        inputShape: layer.outputShape.slice(1)
    }));
    // Two-layer network: a hidden ReLU layer and a softmax output layer
    model.add(tf.layers.dense({
        units: 10,
        activation: 'relu'
    }));
    model.add(tf.layers.dense({
        units: NUM_CLASSES,
        activation: 'softmax'
    }));

    // Set the loss function and optimizer
    model.compile({ loss: 'categoricalCrossentropy', optimizer: tf.train.adam() });

    // Input data goes through the truncated model first; the new model trains on its output.
    // Convert each image to MobileNet's input format and run it through the truncated model.
    const { xs, ys } = tf.tidy(() => {
        const xs = tf.concat(inputs.map(imgEl => truncatedMobilenet.predict(img2x(imgEl))));
        const ys = tf.tensor(labels);
        return { xs, ys };
    });

    // Train the new model on the features produced by the truncated model
    await model.fit(xs, ys, {
        epochs: 20,
        callbacks: tfvis.show.fitCallbacks(
            { name: 'Training performance' },
            ['loss'],
            { callbacks: ['onEpochEnd'] }
        )
    });
};
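
The categoricalCrossentropy loss compares the softmax output with the one-hot label. For a single example it reduces to minus the log of the probability assigned to the true class. Below is a plain-JavaScript sketch of softmax and that loss. softmax and categoricalCrossentropy here are hypothetical helpers for illustration. A library implementation also guards against log(0); this sketch uses a small epsilon for the same reason.

```javascript
// Softmax: exponentiate and normalize so the outputs sum to 1.
function softmax(logits) {
  const max = Math.max(...logits); // subtract the max for numerical stability
  const exps = logits.map(z => Math.exp(z - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map(e => e / sum);
}

// Categorical cross-entropy for one example: -sum(y_i * log(p_i)).
function categoricalCrossentropy(oneHotLabel, probs, eps = 1e-7) {
  return -oneHotLabel.reduce(
    (acc, y, i) => acc + y * Math.log(Math.max(probs[i], eps)),
    0
  );
}

const probs = softmax([0, 0, 0]);                        // uniform over the 3 brands
console.log(probs);                                      // each about 0.3333
console.log(categoricalCrossentropy([0, 1, 0], probs));  // about 1.0986 (ln 3)
```

A prediction that is no better than guessing among three classes costs about ln 3, roughly 1.1. So a loss curve in the tfvis chart that starts near that value and falls toward 0 is a rough sign that the head is learning.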

e. Predicting with the transfer-learned model

index.html

html
<script src="script.js"></script>
<input type="file" onchange="predict(this.files[0])">

utils.js:

javascript
import * as tf from '@tensorflow/tfjs';

// Convert an <img> element into the input format the MobileNet model expects
export function img2x(imgEl){
    return tf.tidy(() => {
        const input = tf.browser.fromPixels(imgEl)
            .toFloat()
            .sub(255 / 2)
            .div(255 / 2)
            .reshape([1, 224, 224, 3]);
        return input;
    });
}

export function file2img(f) {
    return new Promise(resolve => {
        const reader = new FileReader();
        reader.readAsDataURL(f);
        reader.onload = (e) => {
            const img = document.createElement('img');
            img.src = e.target.result;
            img.width = 224;
            img.height = 224;
            img.onload = () => resolve(img);
        };
    });
}

script.js:

javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getInputs } from './data';
import { img2x, file2img } from './utils';

const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';
const NUM_CLASSES = 3;
const BRAND_CLASSES = ['android', 'apple', 'windows'];

window.onload = async () => {
    const { inputs, labels } = await getInputs();
    const surface = tfvis.visor().surface({ name: 'Input examples', styles: { height: 250 } });
    inputs.forEach(img => {
        surface.drawArea.appendChild(img);
    });

    // Load the pre-trained MobileNet
    const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
    // Print an overview of the model's layers
    mobilenet.summary();
    // Get an intermediate layer by name
    const layer = mobilenet.getLayer('conv_pw_13_relu');
    // Truncated model: MobileNet's input as input, the chosen layer's output as output
    const truncatedMobilenet = tf.model({
        inputs: mobilenet.inputs,
        outputs: layer.output
    });

    // New model that will hold the two-layer network
    const model = tf.sequential();
    // Flatten the truncated layer's output into a 1-D vector
    model.add(tf.layers.flatten({
        inputShape: layer.outputShape.slice(1)
    }));
    // Two-layer network: a hidden ReLU layer and a softmax output layer
    model.add(tf.layers.dense({
        units: 10,
        activation: 'relu'
    }));
    model.add(tf.layers.dense({
        units: NUM_CLASSES,
        activation: 'softmax'
    }));

    // Set the loss function and optimizer
    model.compile({ loss: 'categoricalCrossentropy', optimizer: tf.train.adam() });

    // Input data goes through the truncated model first; the new model trains on its output.
    // Convert each image to MobileNet's input format and run it through the truncated model.
    const { xs, ys } = tf.tidy(() => {
        const xs = tf.concat(inputs.map(imgEl => truncatedMobilenet.predict(img2x(imgEl))));
        const ys = tf.tensor(labels);
        return { xs, ys };
    });

    // Train the new model on the features produced by the truncated model
    await model.fit(xs, ys, {
        epochs: 20,
        callbacks: tfvis.show.fitCallbacks(
            { name: 'Training performance' },
            ['loss'],
            { callbacks: ['onEpochEnd'] }
        )
    });

    // Prediction
    window.predict = async (file) => {
        const img = await file2img(file);
        document.body.appendChild(img);
        const pred = tf.tidy(() => {
            // Preprocess the image, extract MobileNet features, then classify with the new model
            const x = img2x(img);
            const input = truncatedMobilenet.predict(x);
            return model.predict(input);
        });

        const index = pred.argMax(1).dataSync()[0];
        setTimeout(() => {
            alert(`Prediction: ${BRAND_CLASSES[index]}`);
        }, 0);
    };
};

f. Saving and loading the model

Saving the trained model files

index.html

html
<script src="script.js"></script>
<input type="file" onchange="predict(this.files[0])">
<button onclick="download()">Download model</button>

script.js:

javascript
import * as tf from '@tensorflow/tfjs';
import * as tfvis from '@tensorflow/tfjs-vis';
import { getInputs } from './data';
import { img2x, file2img } from './utils';

const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model/model.json';
const NUM_CLASSES = 3;
const BRAND_CLASSES = ['android', 'apple', 'windows'];

window.onload = async () => {
    const { inputs, labels } = await getInputs();
    const surface = tfvis.visor().surface({ name: 'Input examples', styles: { height: 250 } });
    inputs.forEach(img => {
        surface.drawArea.appendChild(img);
    });

    // Load the pre-trained MobileNet
    const mobilenet = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
    // Print an overview of the model's layers
    mobilenet.summary();
    // Get an intermediate layer by name
    const layer = mobilenet.getLayer('conv_pw_13_relu');
    // Truncated model: MobileNet's input as input, the chosen layer's output as output
    const truncatedMobilenet = tf.model({
        inputs: mobilenet.inputs,
        outputs: layer.output
    });

    // New model that will hold the two-layer network
    const model = tf.sequential();
    // Flatten the truncated layer's output into a 1-D vector
    model.add(tf.layers.flatten({
        inputShape: layer.outputShape.slice(1)
    }));
    // Two-layer network: a hidden ReLU layer and a softmax output layer
    model.add(tf.layers.dense({
        units: 10,
        activation: 'relu'
    }));
    model.add(tf.layers.dense({
        units: NUM_CLASSES,
        activation: 'softmax'
    }));

    // Set the loss function and optimizer
    model.compile({ loss: 'categoricalCrossentropy', optimizer: tf.train.adam() });

    // Input data goes through the truncated model first; the new model trains on its output.
    // Convert each image to MobileNet's input format and run it through the truncated model.
    const { xs, ys } = tf.tidy(() => {
        const xs = tf.concat(inputs.map(imgEl => truncatedMobilenet.predict(img2x(imgEl))));
        const ys = tf.tensor(labels);
        return { xs, ys };
    });

    // Train the new model on the features produced by the truncated model
    await model.fit(xs, ys, {
        epochs: 20,
        callbacks: tfvis.show.fitCallbacks(
            { name: 'Training performance' },
            ['loss'],
            { callbacks: ['onEpochEnd'] }
        )
    });

    // Prediction: preprocess the image, extract MobileNet features, then classify with the new model
    window.predict = async (file) => {
        const img = await file2img(file);
        document.body.appendChild(img);
        const pred = tf.tidy(() => {
            const x = img2x(img);
            const input = truncatedMobilenet.predict(x);
            return model.predict(input);
        });

        const index = pred.argMax(1).dataSync()[0];
        setTimeout(() => {
            alert(`Prediction: ${BRAND_CLASSES[index]}`);
        }, 0);
    };

    // Save the model: the browser downloads model.json and model.weights.bin
    window.download = async () => {
        // Call the model's save method with the downloads:// scheme
        await model.save('downloads://model');
    };
};

Loading the model files in a new application and predicting

html
<script src="script.js"></script>
<input type="file" onchange="predict(this.files[0])">
<br>

utils.js:

javascript
import * as tf from '@tensorflow/tfjs';

export function img2x(imgEl){
    return tf.tidy(() => {
        const input = tf.browser.fromPixels(imgEl)
            .toFloat()
            .sub(255 / 2)
            .div(255 / 2)
            .reshape([1, 224, 224, 3]);
        return input;
    });
}

export function file2img(f) {
    return new Promise(resolve => {
        const reader = new FileReader();
        reader.readAsDataURL(f);
        reader.onload = (e) => {
            const img = document.createElement('img');
            img.src = e.target.result;
            img.width = 224;
            img.height = 224;
            img.onload = () => resolve(img);
        };
    });
}

script.js:

javascript
import * as tf from '@tensorflow/tfjs';
import { img2x, file2img } from './utils';

const MODEL_PATH = 'http://127.0.0.1:8080';
const BRAND_CLASSES = ['android', 'apple', 'windows'];

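window.onload = async () => {
    // The saved model contains only the two-layer head, so the truncated MobileNet
    // must be rebuilt here to turn images into the features the head expects
    const mobilenet = await tf.loadLayersModel(MODEL_PATH + '/mobilenet/web_model/model.json');
    mobilenet.summary();
    const layer = mobilenet.getLayer('conv_pw_13_relu');
    const truncatedMobilenet = tf.model({
        inputs: mobilenet.inputs,
        outputs: layer.output
    });

    // Load the model saved in the previous step (its model.json and weights file
    // are served from /brand/web_model/ on the local server)
    const model = await tf.loadLayersModel(MODEL_PATH + '/brand/web_model/model.json');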

    window.predict = async (file) => {
        const img = await file2img(file);
        document.body.appendChild(img);
        const pred = tf.tidy(() => {
            const x = img2x(img);
            const input = truncatedMobilenet.predict(x);
            return model.predict(input);
        });

        const index = pred.argMax(1).dataSync()[0];
        setTimeout(() => {
            alert(`预测结果:${BRAND_CLASSES[index]}`);
        }, 0);
    };
};

13. Speech Recognition with a Pre-trained Model

a. Task overview: speech recognition with a pre-trained model

b. Loading the pre-trained speech recognition model

https://github.com/tensorflow/tfjs-models
https://github.com/tensorflow/tfjs-models/tree/master/speech-commands
https://www.tensorflow.org/datasets/catalog/speech_commands

html
<script src="script.js"></script>
<style>
    #result>div {
        float: left;
        padding: 20px;
    }
</style>
<div id="result"></div>
javascript
import * as speechCommands from '@tensorflow-models/speech-commands';

const MODEL_PATH = 'http://127.0.0.1:8080/speech';

window.onload = async () => {
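    // Create the recognizer. Arguments:
    // 1. 'BROWSER_FFT': compute the spectrogram with the browser's native Fourier transform
    // 2. vocabulary name (null here; with a custom model URL the words come from metadata.json)
    // 3. URL of the custom model
    // 4. URL of the custom metadata
    const recognizer = speechCommands.create(
        'BROWSER_FFT',
        null,
        MODEL_PATH + '/model.json',
        MODEL_PATH + '/metadata.json'
    );
    // ensureModelLoaded() resolves once the model has finished loading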
    await recognizer.ensureModelLoaded();
};

c. Running speech recognition

html
<script src="script.js"></script>
<style>
    #result>div {
        float: left;
        padding: 20px;
    }
</style>
<div id="result"></div>
javascript
import * as speechCommands from '@tensorflow-models/speech-commands';

const MODEL_PATH = 'http://127.0.0.1:8080/speech';

window.onload = async () => {
    // Create the recognizer: browser FFT, no vocabulary name (null), custom model URL, custom metadata URL
    const recognizer = speechCommands.create(
        'BROWSER_FFT',
        null,
        MODEL_PATH + '/model.json',
        MODEL_PATH + '/metadata.json'
    );
    // ensureModelLoaded() resolves once the model has finished loading
    await recognizer.ensureModelLoaded();

    // The words the model can recognize; slice(2) drops the first two labels,
    // '_background_noise_' and '_unknown_'
    const labels = recognizer.wordLabels().slice(2);
    const resultEl = document.querySelector('#result');
    resultEl.innerHTML = labels.map(l => `
        <div>${l}</div>
    `).join('');

    // Open the microphone, listen to its input and convert it into model inputs
    recognizer.listen(result => {
        // The result holds a probability score for every word
        const { scores } = result;
        // Highest score
        const maxValue = Math.max(...scores);
        // Index of the highest score, shifted by 2 to match the sliced label list
        const index = scores.indexOf(maxValue) - 2;
        // Re-render the word list and highlight the recognized word
        resultEl.innerHTML = labels.map((l, i) => `
        <div style="background: ${i === index ? 'green' : 'transparent'}">${l}</div>
        `).join('');
    }, {
        // Overlap between successive audio windows (0 to <1); a higher value takes windows more often, so recognition runs more frequently
        overlapFactor: 0.3,
        // Probability threshold: the callback only runs when the top word's score exceeds 0.9
        probabilityThreshold: 0.9
    });
};
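
The index arithmetic in the callback depends on the label order returned by wordLabels(). Its first two entries are the special _background_noise_ and _unknown_ classes, which this page drops with slice(2), so the arg-max index over scores must be shifted by 2. Below is a sketch of that mapping in plain JavaScript. The helper pickWord and the example vocabulary are hypothetical, and the threshold check only approximates what probabilityThreshold does inside listen().

```javascript
// Hypothetical helper: map a scores array to the word shown on the page.
// ASSUMPTION: allLabels[0] and allLabels[1] are '_background_noise_' and '_unknown_'.
function pickWord(scores, allLabels, threshold) {
  const maxValue = Math.max(...scores);
  if (maxValue < threshold) return null;          // roughly what probabilityThreshold filters out
  const index = scores.indexOf(maxValue) - 2;     // shift past the two special classes
  const shownLabels = allLabels.slice(2);         // the labels rendered on the page
  return index >= 0 ? shownLabels[index] : null;  // noise/unknown: highlight nothing
}

const allLabels = ['_background_noise_', '_unknown_', 'down', 'go', 'left'];
console.log(pickWord([0.02, 0.03, 0.05, 0.85, 0.05], allLabels, 0.9)); // null (below threshold)
console.log(pickWord([0.01, 0.01, 0.03, 0.93, 0.02], allLabels, 0.9)); // go
console.log(pickWord([0.95, 0.02, 0.01, 0.01, 0.01], allLabels, 0.9)); // null (background noise)
```

A negative index means the winning class was noise or unknown. In the page code this simply leaves every word unhighlighted.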

14. Transfer-Learning Speech Recognizer: Voice-Controlled Carousel

a. Task overview: a transfer-learning speech recognizer for a voice-controlled carousel

b. Collecting Chinese speech training data in the browser

html
<script src="script.js"></script>
<button onclick="collect(this)">上一张</button> <!-- spoken command: "previous" -->
<button onclick="collect(this)">下一张</button> <!-- spoken command: "next" -->
<button onclick="collect(this)">背景噪音</button> <!-- "background noise" -->
<pre id="count"></pre>
javascript
import * as speechCommands from '@tensorflow-models/speech-commands';
import * as tfvis from '@tensorflow/tfjs-vis';

const MODEL_PATH = 'http://127.0.0.1:8080';
let transferRecognizer;

window.onload = async () => {
    const recognizer = speechCommands.create(
        'BROWSER_FFT',
        null,
        MODEL_PATH + '/speech/model.json',
        MODEL_PATH + '/speech/metadata.json'
    );
    await recognizer.ensureModelLoaded();

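    // Create the transfer-learning recognizer (named '轮播图', "carousel")
    transferRecognizer = recognizer.createTransfer('轮播图');
};

// Button click: record one speech example
window.collect = async (btn) => {
    // Disable the button while recording to prevent repeated clicks
    btn.disabled = true;
    const label = btn.innerText;

    // Record one example under the button's label; '背景噪音' (background noise)
    // is stored under the reserved '_background_noise_' label
    await transferRecognizer.collectExample(
        label === '背景噪音' ? '_background_noise_' : label
    );

    btn.disabled = false;
    // Show how many examples have been recorded for each label
    document.querySelector('#count').innerHTML = JSON.stringify(transferRecognizer.countExamples(), null, 2);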
};
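
countExamples() returns an object keyed by label, with the number of recorded examples for each; this is what the pre element displays. Before calling train(), it can be worth checking that every word, including background noise, has enough samples. The sketch below is a hypothetical guard. The minimum of 8 per label is an arbitrary assumption, not a library requirement.

```javascript
// Hypothetical pre-training check on the object returned by countExamples().
// ASSUMPTION: 8 examples per label is an arbitrary minimum chosen for illustration.
const REQUIRED_LABELS = ['上一张', '下一张', '_background_noise_'];

function missingExamples(counts, required = REQUIRED_LABELS, minPerLabel = 8) {
  // Return the labels that still need recordings, with how many are missing.
  return required
    .map(label => ({ label, missing: minPerLabel - (counts[label] || 0) }))
    .filter(item => item.missing > 0);
}

const counts = { '上一张': 10, '下一张': 5 }; // example of the shape countExamples() returns
console.log(missingExamples(counts).map(m => `${m.label}: ${m.missing}`).join(', '));
// 下一张: 3, _background_noise_: 8
```

A labels-to-counts check like this could disable the training button until every entry is satisfied.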

c. Training and prediction with speech-recognition transfer learning

html
<script src="script.js"></script>
<button onclick="collect(this)">上一张</button>
<button onclick="collect(this)">下一张</button>
<button onclick="collect(this)">背景噪音</button>
<pre id="count"></pre>
<button onclick="train()">训练</button>
<br><br>
Listening on/off: <input type="checkbox" onchange="toggle(this.checked)">
javascript
import * as speechCommands from '@tensorflow-models/speech-commands';
import * as tfvis from '@tensorflow/tfjs-vis';

const MODEL_PATH = 'http://127.0.0.1:8080';
let transferRecognizer;

window.onload = async () => {
    const recognizer = speechCommands.create(
        'BROWSER_FFT',
        null,
        MODEL_PATH + '/speech/model.json',
        MODEL_PATH + '/speech/metadata.json'
    );
    await recognizer.ensureModelLoaded();

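    // Create the transfer-learning recognizer (named '轮播图', "carousel")
    transferRecognizer = recognizer.createTransfer('轮播图');
};

// Button click: record one speech example
window.collect = async (btn) => {
    // Disable the button while recording to prevent repeated clicks
    btn.disabled = true;
    const label = btn.innerText;

    // Record one example under the button's label; '背景噪音' (background noise)
    // is stored under the reserved '_background_noise_' label
    await transferRecognizer.collectExample(
        label === '背景噪音' ? '_background_noise_' : label
    );

    btn.disabled = false;
    // Show how many examples have been recorded for each label
    document.querySelector('#count').innerHTML = JSON.stringify(transferRecognizer.countExamples(), null, 2);
};

// Train the model and visualize the training progress
window.train = async () => {
    await transferRecognizer.train({
        epochs: 30,
        callback: tfvis.show.fitCallbacks(
            { name: 'Training performance' },
            ['loss', 'acc'],
            { callbacks: ['onEpochEnd'] }
        )
    });
};

// Callback for the microphone on/off checkbox
window.toggle = async (checked) => {
    // Checkbox switched on
    if (checked) {
        // Start listening with the transfer recognizer
        await transferRecognizer.listen(result => {
            // Called each time an audio window passes the probability threshold

            // Scores for every label
            const { scores } = result;
            // All labels of the transfer model
            const labels = transferRecognizer.wordLabels();
            // Index of the highest score
            const index = scores.indexOf(Math.max(...scores));
            // Log the label with the highest score
            console.log(labels[index]);
        }, {
            overlapFactor: 0, // overlap between windows (0 = none, so recognition runs less often)
            probabilityThreshold: 0.75 // probability threshold
        });
    } else {
        // Checkbox switched off: stop listening
        transferRecognizer.stopListening();
    }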
};
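
According to the speech-commands documentation, overlapFactor controls how much two successive analysis windows overlap. For example, with a 1000 ms window and an overlap factor of 0.4, a new window is taken every 600 ms. The arithmetic below compares the two settings used on this page: 0.3 in section 13 and 0 here. The hopMillis helper is hypothetical, and the 1000 ms window length is an assumption taken from that documentation example.

```javascript
// Time between successive recognitions for a given overlapFactor.
// ASSUMPTION: the model's analysis window is about 1000 ms long.
function hopMillis(windowMillis, overlapFactor) {
  if (overlapFactor < 0 || overlapFactor >= 1) {
    throw new RangeError('overlapFactor must be in [0, 1)');
  }
  return windowMillis * (1 - overlapFactor);
}

console.log(hopMillis(1000, 0.3)); // 700  (section 13)
console.log(hopMillis(1000, 0));   // 1000 (transfer recognizer above)
```

A larger overlap factor reacts faster to a spoken command but runs the model more often, which costs more CPU or GPU time.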

d. Saving and loading the speech training data

Saving the data

html
<script src="script.js"></script>
<button onclick="collect(this)">上一张</button>
<button onclick="collect(this)">下一张</button>
<button onclick="collect(this)">背景噪音</button>
<button onclick="save()">Save</button>
<pre id="count"></pre>
<button onclick="train()">Train</button>
<br><br>
Listening on/off: <input type="checkbox" onchange="toggle(this.checked)">
javascript
import * as speechCommands from '@tensorflow-models/speech-commands';
import * as tfvis from '@tensorflow/tfjs-vis';

const MODEL_PATH = 'http://127.0.0.1:8080';
let transferRecognizer;

window.onload = async () => {
    const recognizer = speechCommands.create(
        'BROWSER_FFT',
        null,
        MODEL_PATH + '/speech/model.json',
        MODEL_PATH + '/speech/metadata.json'
    );
    await recognizer.ensureModelLoaded();

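    // Create the transfer-learning recognizer (named '轮播图', "carousel")
    transferRecognizer = recognizer.createTransfer('轮播图');
};

// Button click: record one speech example
window.collect = async (btn) => {
    // Disable the button while recording to prevent repeated clicks
    btn.disabled = true;
    const label = btn.innerText;

    // Record one example under the button's label; '背景噪音' (background noise)
    // is stored under the reserved '_background_noise_' label
    await transferRecognizer.collectExample(
        label === '背景噪音' ? '_background_noise_' : label
    );

    btn.disabled = false;
    // Show how many examples have been recorded for each label
    document.querySelector('#count').innerHTML = JSON.stringify(transferRecognizer.countExamples(), null, 2);
};

// Train the model and visualize the training progress
window.train = async () => {
    await transferRecognizer.train({
        epochs: 30,
        callback: tfvis.show.fitCallbacks(
            { name: 'Training performance' },
            ['loss', 'acc'],
            { callbacks: ['onEpochEnd'] }
        )
    });
};

// Callback for the microphone on/off checkbox
window.toggle = async (checked) => {
    // Checkbox switched on
    if (checked) {
        // Start listening with the transfer recognizer
        await transferRecognizer.listen(result => {
            // Called each time an audio window passes the probability threshold

            // Scores for every label
            const { scores } = result;
            // All labels of the transfer model
            const labels = transferRecognizer.wordLabels();
            // Index of the highest score
            const index = scores.indexOf(Math.max(...scores));
            // Log the label with the highest score
            console.log(labels[index]);
        }, {
            overlapFactor: 0, // overlap between windows (0 = none, so recognition runs less often)
            probabilityThreshold: 0.75 // probability threshold
        });
    } else {
        // Checkbox switched off: stop listening
        transferRecognizer.stopListening();
    }
};

window.save = () => {
    // Serialize the collected speech examples into a binary ArrayBuffer
    const arrayBuffer = transferRecognizer.serializeExamples();
    // Wrap the binary data in a Blob and download it as a local file
    const blob = new Blob([arrayBuffer]);
    const link = document.createElement('a');
    link.href = window.URL.createObjectURL(blob);
    link.download = 'data.bin';
    link.click();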
};
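
serializeExamples() returns an ArrayBuffer, and the save handler wraps it in a Blob so the browser can download it. When the carousel page later fetches data.bin, res.arrayBuffer() should hand back the same bytes. The sketch below checks that round trip with a made-up buffer. It assumes a runtime with a global Blob (modern browsers, or Node.js 18 and later), and the roundTrip helper is hypothetical. In the browser version, you may also want to call URL.revokeObjectURL(link.href) once the download has started, to release the object URL.

```javascript
// Hypothetical round-trip check: bytes written into a Blob come back unchanged.
// ASSUMPTION: a global Blob exists (browsers, Node.js 18+).
async function roundTrip(arrayBuffer) {
  const blob = new Blob([arrayBuffer], { type: 'application/octet-stream' });
  const restored = await blob.arrayBuffer(); // same bytes a later fetch of the file should return
  return { size: blob.size, bytes: new Uint8Array(restored) };
}

const fakeExamples = new Uint8Array([1, 2, 3, 250]).buffer; // stand-in for serializeExamples()
roundTrip(fakeExamples).then(({ size, bytes }) => {
  console.log(size);              // 4
  console.log(Array.from(bytes)); // [ 1, 2, 3, 250 ]
});
```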

Loading and using the data (voice-controlled carousel)

html
<script src="script.js"></script>
Listening on/off: <input type="checkbox" onchange="toggle(this.checked)">

<style>
    .slider {
        width: 600px;
        overflow: hidden;
        margin: 10px auto;
    }
    .slider > div{
        display: flex;
        align-items: center;
    }
</style>
<div class="slider">
    <div>
        <img src="https://cdn.pixabay.com/photo/2019/10/29/15/57/vancouver-4587302__480.jpg" alt="" width="600">
        <img src="https://cdn.pixabay.com/photo/2019/10/31/07/14/coffee-4591159__480.jpg" alt="" width="600">
        <img src="https://cdn.pixabay.com/photo/2019/11/01/11/08/landscape-4593909__480.jpg" alt="" width="600">
        <img src="https://cdn.pixabay.com/photo/2019/11/02/21/45/maple-leaf-4597501__480.jpg" alt="" width="600">
        <img src="https://cdn.pixabay.com/photo/2019/11/02/03/13/in-xinjiang-4595560__480.jpg" alt="" width="600">
        <img src="https://cdn.pixabay.com/photo/2019/11/01/22/45/reschensee-4595385__480.jpg" alt="" width="600">
    </div>
</div>
javascript
import * as speechCommands from '@tensorflow-models/speech-commands';

const MODEL_PATH = 'http://127.0.0.1:8080';
let transferRecognizer;
let curIndex = 0;

window.onload = async () => {
    // Create the recognizer
    const recognizer = speechCommands.create(
        'BROWSER_FFT',
        null,
        MODEL_PATH + '/speech/model.json',
        MODEL_PATH + '/speech/metadata.json',
    );
    // Make sure the model is loaded
    await recognizer.ensureModelLoaded();
    // Create the transfer-learning recognizer
    transferRecognizer = recognizer.createTransfer('轮播图');

    // Fetch the saved speech examples
    const res = await fetch(MODEL_PATH + '/slider/data.bin');
    // Read the response as an ArrayBuffer
    const arrayBuffer = await res.arrayBuffer();
    // Load the examples into the transfer recognizer
    transferRecognizer.loadExamples(arrayBuffer);
    // Train on the loaded examples
    await transferRecognizer.train({ epochs: 30 });
    console.log('done');
};

// Listen to the microphone and predict with the trained model
window.toggle = async (checked) => {
    if (checked) {
        await transferRecognizer.listen(result => {
            const { scores } = result;
            const labels = transferRecognizer.wordLabels();
            const index = scores.indexOf(Math.max(...scores));
            // Run the carousel action for the recognized label
            window.play(labels[index]);
        }, {
            overlapFactor: 0,
            probabilityThreshold: 0.5
        });
    } else {
        transferRecognizer.stopListening();
    }
};

window.play = (label) => {
    // The div inside .slider that holds all the images
    const div = document.querySelector('.slider>div');

    // Update the current index from the command, without going past the first or last image.
    // Any other label (such as '_background_noise_') leaves the carousel where it is.
    if (label === '上一张') {
        if (curIndex === 0) { return; }
        curIndex -= 1;
    } else if (label === '下一张') {
        if (curIndex === document.querySelectorAll('img').length - 1) { return; }
        curIndex += 1;
    } else {
        return;
    }

    // Animate the transition
    div.style.transition = "transform 1s";
    // Shift the div by whole image widths according to the index to produce the carousel effect
    div.style.transform = `translateX(-${100 * curIndex}%)`;
};
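
The index update in play() can be pulled out into a pure function, which makes the edge cases easy to check. The carousel should stay put at either end and ignore any label that is not one of the two commands, such as _background_noise_. Below is a sketch with the hypothetical names nextIndex and offsetFor. The percentage in translateX is relative to the inner div's own width (600px here, one image).

```javascript
// Hypothetical pure version of the index logic in window.play.
function nextIndex(curIndex, label, imageCount) {
  if (label === '上一张') return Math.max(curIndex - 1, 0);              // "previous"
  if (label === '下一张') return Math.min(curIndex + 1, imageCount - 1); // "next"
  return curIndex; // background noise or anything else: do not move
}

// CSS transform that brings image number curIndex into view.
const offsetFor = curIndex => `translateX(-${100 * curIndex}%)`;

console.log(nextIndex(0, '上一张', 6));             // 0
console.log(nextIndex(0, '下一张', 6));             // 1
console.log(nextIndex(5, '下一张', 6));             // 5
console.log(nextIndex(2, '_background_noise_', 6)); // 2
console.log(offsetFor(2));                          // translateX(-200%)
```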

15. Converting Models Between Python and JavaScript

a. Task overview: converting models between Python and JavaScript

b. Installing the TensorFlow.js Converter

Installing conda

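The TensorFlow.js Converter depends on Python. If your machine also has other Python versions installed, or applications that depend on them, those versions may not match the one the converter needs. To let each Python application run on its own Python version without conflicts, create a separate Python environment for each one. The conda tool creates these isolated virtual Python environments. Installation help (Tsinghua University mirror): https://mirror.tuna.tsinghua.edu.cn/help/anaconda/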

Creating a Python virtual environment with conda

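Listing the installed virtual environments

Note: deleting a virtual environment

Entering (activating) and leaving the virtual environment

Checking the current Python version

Switching between virtual environments with conda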

Installing tfjs-converter in the Python virtual environment

https://github.com/tensorflow/tfjs

https://github.com/tensorflow/tfjs/tree/master/tfjs-converter

Checking that the installation succeeded

c. Converting models between Python and JavaScript

Converting a Python model to a JavaScript model

html
<script src="script.js"></script>
<input type="file" onchange="predict(this.files[0])">
javascript
export function file2img(f) {
    return new Promise(resolve => {
        const reader = new FileReader();
        reader.readAsDataURL(f);
        reader.onload = (e) => {
            const img = document.createElement('img');
            img.src = e.target.result;
            img.width = 224;
            img.height = 224;
            img.onload = () => resolve(img);
        };
    });
}
javascript
import * as tf from '@tensorflow/tfjs';
import { IMAGENET_CLASSES } from './imagenet_classes';
import { file2img } from './utils';

const MOBILENET_MODEL_PATH = 'http://127.0.0.1:8080/mobilenet/web_model2/model.json';

window.onload = async () => {
    const model = await tf.loadLayersModel(MOBILENET_MODEL_PATH);
    window.predict = async (file) => {
        const img = await file2img(file);
        document.body.appendChild(img);
        const pred = tf.tidy(() => {
            const input = tf.browser.fromPixels(img)
                .toFloat()
                .sub(255 / 2)
                .div(255 / 2)
                .reshape([1, 224, 224, 3]);
            return model.predict(input);
        });

        const index = pred.argMax(1).dataSync()[0];
        setTimeout(() => {
            alert(`预测结果:${IMAGENET_CLASSES[index]}`);
        }, 0);
    };
};
javascript
export const IMAGENET_CLASSES = {
  0: 'tench, Tinca tinca',
  1: 'goldfish, Carassius auratus',
  // ...
  // ...
  // ...
  999: 'toilet tissue, toilet paper, bathroom tissue'
}

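Converting to graph_model instead also works. In the browser, a graph model is loaded with tf.loadGraphModel rather than tf.loadLayersModel.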

Converting a JavaScript model to a Python model

e. Converting JavaScript models: sharding, quantization, and speed-up

Sharding

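tensorflowjs_converter writes the weights as one or more binary shard files. The --weight_shard_size_bytes flag caps the size of each shard, and its default is 4 MB (4194304 bytes). For a single weight group, the file count is roughly a ceiling division. The 17 MB model size below is a made-up figure, and the shardFileNames helper is hypothetical. The group1-shardNofM.bin pattern mirrors the names the converter writes, but check your own output directory.

```javascript
// Approximate shard files for one weight group (illustrative sketch).
// ASSUMPTION: 17 MB of weights and the default 4 MB shard size.
const SHARD_SIZE_BYTES = 4 * 1024 * 1024; // 4194304

function shardFileNames(totalBytes, shardSize = SHARD_SIZE_BYTES) {
  const count = Math.ceil(totalBytes / shardSize);
  // Converter-style names: group1-shard1ofN.bin, group1-shard2ofN.bin, ...
  return Array.from({ length: count }, (_, i) => `group1-shard${i + 1}of${count}.bin`);
}

const names = shardFileNames(17 * 1024 * 1024);
console.log(names.length); // 5
console.log(names[0]);     // group1-shard1of5.bin
```

Smaller shards let the browser download and cache the weights in parallel, at the cost of more HTTP requests.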

Quantization

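Quantization stores each weight in fewer bytes. With 1-byte (uint8) quantization, each weight is mapped to an integer from 0 to 255 using a per-tensor minimum and scale, and mapped back to a float when the model loads. That makes the weights about 4x smaller than float32, at the cost of a small rounding error. The sketch below shows this affine scheme in plain JavaScript. It illustrates the idea and is not the converter's code; quantizeUint8 and dequantize are hypothetical names. Depending on the converter version, the flag is --quantize_uint8 (newer) or --quantization_bytes 1 (older); check tensorflowjs_converter --help for the version you installed.

```javascript
// Affine uint8 quantization of a weight array (illustrative sketch).
function quantizeUint8(weights) {
  const min = Math.min(...weights);
  const max = Math.max(...weights);
  const scale = (max - min) / 255 || 1; // avoid dividing by zero when all weights are equal
  const q = Uint8Array.from(weights, w =>
    Math.min(255, Math.max(0, Math.round((w - min) / scale)))
  );
  return { q, min, scale };
}

// Recover approximate float weights.
function dequantize({ q, min, scale }) {
  return Array.from(q, v => v * scale + min);
}

const weights = [-0.5, -0.1, 0, 0.37, 0.5];
const packed = quantizeUint8(weights);
const restored = dequantize(packed);
const maxError = Math.max(...weights.map((w, i) => Math.abs(w - restored[i])));

console.log(packed.q.byteLength, weights.length * 4);  // 5 20  (bytes: uint8 vs float32)
console.log(maxError <= packed.scale / 2 + 1e-12);     // true
```

The worst-case error is half a quantization step, (max - min) / 510. This is why quantization works best when a tensor's weights lie in a narrow range.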

Speeding up a model by converting it to tfjs_graph_model

Original model path

Using the tfjs_graph_model model

16. Summary
