2.6 扩展:使用tensorflow.js进行机器学习
tensorflow.js是一个机器学习的前端框架,Google也在GitHub开源了相关代码。GitHub地址:https://github.com/tensorflow/tfjs。在实现方面,TensorFlow团队使用了WebGL库对运算过程进行优化,使得tensorflow.js在学习尤其是网络扩大的时候能够有更好的性能表现。在API设计方面,框架更多地考量到了开发人员的易用性,在较为底层的API方面使用了TensorFlow Python的许多概念,而在高级抽象API方面则更多地与Keras框架保持一致。
接下来我们通过一个简单的例子了解一下tensorflow.js的魅力。
1. 类库引入
(1)script标签引入
标签的引入是最为直接的方式,引入的地址为https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.9.0。下面给大家提供一个简单的开发模板:
(2)npm引入
如果你使用了Node进行前端架构的开发,就需要包管理工具npm来引入。
npm install @tensorflow/tfj
下面给大家提供一个简单的开发模板(ES 6):
import * as tf from '@tensorflow/tfjs'; // 在下面写机器学习业务代码
2. hello tfjs — 一个简单的示例
(1)代码编写
// 定义模型:线性回归模型 const model = tf.sequential(); model.add(tf.layers.dense({units: 1, inputShape: [1]})); // 定义模型损失函数和梯度下降算法 model.compile({loss: 'meanSquaredError', optimizer: 'sgd'}); // 准备学习数据 const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]); const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]); //模型学习 model.fit(xs, ys).then(() => { // 使用训练完成的模型进行预测 model.predict(tf.tensor2d([5], [1, 1])).print(); });
(2)代码分析
代码中具体做的事情是线性回归分析,步骤总结为:模型定义→模型学习→模型使用。如果大家想要深入了解线性回归分析的内容,可以参考笔者的免费课程,课程链接:https://www.imooc.com/learn/972。
(3)运行结果
在浏览器上运行,在命令行中就能看到想要的输出,如图2.21所示。
图2.21 代码运行结果
近年来,随着前端框架(React、Vue、Angular)的崛起和微信小程序的发力,前端从业人员的开发能力得到了长足的进步,人工智能时代不但给予后台通关前后台的能力,而且也给了前端业务更多的想象力,tensorflow.js就是在这样的环境下应运而生的产物。我们通过上述的入门例子对tensorflow.js有了直观的感受,如果你学习了TensorFlow的核心知识,那么上手tensorflow.js将会非常容易。