从机器学习到无人驾驶
上QQ阅读APP看书,第一时间看更新

2.6 扩展:使用tensorflow.js进行机器学习

tensorflow.js是一个机器学习的前端框架,Google也在GitHub开源了相关代码。GitHub地址:https://github.com/tensorflow/tfjs。在实现方面,TensorFlow团队使用了WebGL库对运算过程进行优化,使得tensorflow.js在学习尤其是网络扩大的时候能够有更好的性能表现。在API设计方面,框架更多地考量到了开发人员的易用性,在较为底层的API方面使用了TensorFlow Python的许多概念,而在高级抽象API方面则更多地与Keras框架保持一致。

接下来我们通过一个简单的例子了解一下tensorflow.js的魅力。

1. 类库引入

(1)script标签引入

标签的引入是最为直接的方式,引入的地址为https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.9.0。下面给大家提供一个简单的开发模板:

(2)npm引入

如果你使用了Node进行前端架构的开发,就需要包管理工具npm来引入。

   npm install @tensorflow/tfj

下面给大家提供一个简单的开发模板(ES 6):

   import * as tf from '@tensorflow/tfjs';
   // 在下面写机器学习业务代码

2. hello tfjs — 一个简单的示例

(1)代码编写
   // 定义模型:线性回归模型
     const model = tf.sequential();
     model.add(tf.layers.dense({units: 1, inputShape: [1]}));
     
     // 定义模型损失函数和梯度下降算法
     model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
     
     // 准备学习数据
     const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
     const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);
     
     //模型学习
     model.fit(xs, ys).then(() => {
       // 使用训练完成的模型进行预测
       model.predict(tf.tensor2d([5], [1, 1])).print();
     });
(2)代码分析

代码中具体做的事情是线性回归分析,步骤总结为:模型定义→模型学习→模型使用。如果大家想要深入了解线性回归分析的内容,可以参考笔者的免费课程,课程链接:https://www.imooc.com/learn/972。

(3)运行结果

在浏览器上运行,在命令行中就能看到想要的输出,如图2.21所示。

图2.21 代码运行结果

近年来,随着前端框架(React、Vue、Angular)的崛起和微信小程序的发力,前端从业人员的开发能力得到了长足的进步,人工智能时代不但给予后台通关前后台的能力,而且也给了前端业务更多的想象力,tensorflow.js就是在这样的环境下应运而生的产物。我们通过上述的入门例子对tensorflow.js有了直观的感受,如果你学习了TensorFlow的核心知识,那么上手tensorflow.js将会非常容易。