0%

我在上一篇文章中,提到了目前的认证方式存在一些问题,需要替换为一种更简单的认证方式。 但是最后发现,认证这个实在是没有办法简单化,认证本身又是另外一个不小的话题了,因此关于这一点先留个坑。

本文先讨论一下另外一个也比较重要的功能:监控。

为认证预留扩展点

虽然我们暂时不去实现更加完善的认证流程,但是我们依然可以先为其预留一个扩展点, 这样在未来我们要实现认证的时候,就不需要改动太多的代码了。

同样的,我们也可以基于 DIP 原则来实现,我们可以定义一个 Authenticator 接口:

1
2
3
4
type Authenticator interface {
// Authenticate 验证请求是否合法,第一个返回值为用户 id,第二个返回值为错误
Authenticate(r *http.Request) (string, error)
}

然后我们可以在 Hub 结构体中添加一个 authenticator 字段:

1
2
3
4
type Hub struct {
// 验证器
authenticator Authenticator
}

而对于我们目前的这种基于 jwt token 的认证方式,我们可以实现一个 JwtAuthenticator

1
2
3
4
5
6
7
8
9
var _ Authenticator = &JWTAuthenticator{}

type JWTAuthenticator struct {
}

func (J *JWTAuthenticator) Authenticate(r *http.Request) (string, error) {
jwt := NewJwt(r.FormValue("token"))
return jwt.Parse()
}

接着,我们在 newHub 中初始化这个 authenticator

1
2
3
4
5
6
func newHub() *Hub {
return &Hub{
// ... 其他代码 ...
authenticator: &JWTAuthenticator{},
}
}

这样,我们就可以在 serveWs 中使用这个 authenticator 了:

1
2
3
4
5
6
7
8
func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) {
uid, err := hub.authenticator.Authenticate(r)
if err != nil {
log.Println(fmt.Errorf("jwt parse error: %w", err))
return
}
// ... 其他代码
}

在后面我们实现了更加完善的认证流程之后,我们只需要实现一个新的 Authenticator 即可。

2023 了,应用监控怎么做

发展到今天,我们已经有了很多很好用的监控相关的东西,比如 PrometheusGrafana, 以及一些分布式链路追踪的组件,如 skywalkingjaeger 等。

但是他们各自的应用场景都不太一样,并不存在一个万能的监控工具,因此我们需要根据自己的需求来选择:

  • Prometheus:Prometheus 是一个开源的系统监控和报警工具。主要用于收集、存储和查询系统的监控数据,以便进行性能分析、故障排除和告警。
  • Grafana:Grafana 是一个开源的数据可视化和监控平台,用于创建、查询、分析和可视化时间序列数据。目前比较常见的组合就是 Prometheus + Grafana,通过 Prometheus 收集数据,然后通过 Grafana 展示数据。
  • 分布式链路追踪:常用语分布式系统的调用链路追踪,可以用于分析系统的性能瓶颈,以及分析系统的调用链路。常见的实现有 skywalkingjaeger 等。

在我们这个实例中,我们只需要实现一个简单的监控即可,因此我们可以使用 Prometheus + Grafana 的组合。

Prometheus 基本原理

但在此之前我们最好先了解一下 Prometheus 的工作原理,下面是来自 Prometheus 官网的架构图:

architecture

我们可以从两个角度来看这张图:组件、流程。

  1. 组件
  • Prometheus ServerPrometheus 服务端,主要负责数据的收集、存储、查询等。(上图中间部分)
  • AlertmanagerPrometheus 的告警组件,主要负责告警的发送。(上图右上角)
  • Prometheus web UI:可以在这个界面执行 PromQL,另外 Grafana 可以让我们以一种更直观的方式来查看指标数据(也是使用 PromQL)。(上图右下角)
  • exportersexportersPrometheus 的数据采集组件,主要负责从各个组件中采集数据,然后发送给 Prometheus Server。非常常见的如 node_exporter,也就是服务器基础指标的采集组件。除了 exporters,还有一种常见的数据采集方式是 Pushgateway,也就是将数据推送到 Pushgateway,然后由 Prometheus ServerPushgateway 中拉取数据。(也就是上图左边部分)
  1. 流程
  • 采集数据:也就是从 Pushgateway 或者 exporter 拉取一些指标数据。
  • 存储数据:Prometheus Server 会将采集到的数据存储到本地的 TSDB 中。
  • 查询数据:我们可以通过 web UI 或者 Grafana 来查看数据。

最后,我们可以在 Grafana 中看到如下图表:

grafana

通过这个图,我们就可以很直观的看到我们的系统的一些指标数据了,并且能看到这些指标随着时间的变化趋势。

Grafana 里面的图表都是一个个的 PromQL 查询出来的结果,对于常见的一些监控指标,Grafana 上可以找到很多现有的模板,直接使用即可。

Prometheus 采集的是什么数据

举一个简单的例子:对于一个运行中的系统而言,每一刻它的状态都是不太一样的,比如,可能上一秒 CPU 使用率是 10%,下一秒就变成了 100% 了, 但可能过 1 秒又降低到了 10%。当我们的系统出性能问题的时候,我们就需要去分析这些指标数据,找到问题所在。 比如排查一下出现性能问题的那个时间点,CPU 使用率是不是很高,如果是的话,那么就有可能是 CPU 导致的性能问题。

Prometheus 的作用就是帮助我们采集这些指标数据,然后存储起来,等待某天我们需要分析的时候,再去查询这些数据。 又或者监控到指标有异常的时候,可以通过 Alertmanager 来发送告警。

Prometheus 采集数据频率

Prometheus 采集数据的频率是可以配置的,我们一般配置为 1 分钟采集一次。 也就是说,每隔 1 分钟,Prometheus 才会从 exporter 拉取一次数据,然后存储起来。

应用指标数据采集

对于我们的应用而言,往往也有一些指标可以帮助我们看到应用内部的状态,比如:应用内的线程数、应用占用的内存、应用的 QPS 等等。 但是对于应用指标的监控,并没有一个统一的标准,我们需要根据自己应用的实际情况来决定采集哪些指标。

我们的消息推送系统如何做监控

应用指标

对于我们的消息推送系统而言,目前采集以下这两个重要指即可:

  1. 连接数:可以了解服务器当前负载

连接数我们可以直接通过 len(hub.clients) 来获取,非常简单。

  1. 等待推送的消息数:可以了解服务器能否及时处理消息

我们可以在 Hub 中添加一个 pending atomic.Int64 字段来记录当前等待推送的消息数,然后在 send 方法中进行更新:

1
2
3
4
func send(hub *Hub, w http.ResponseWriter, r *http.Request) {
// ... 其他代码 ...
hub.pending.Add(1)
}

同时在处理完成之后,我们也需要将其减 1,所以 writePump 也需要进行修改:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
func (c *Client) writePump() {
for {
select {
case messageLog, ok := <-c.send:
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
if !ok {
// ...
c.hub.pending.Add(int64(-1 * len(c.send)))
return
}

if err := c.conn.WriteMessage(websocket.TextMessage, []byte(messageLog.Message)); err != nil {
// ...
c.hub.pending.Add(int64(-1 * len(c.send)))
return
}
}
c.hub.pending.Add(int64(-1))
}
}

我们在 writePump 中有三个地方需要对 pending 字段做减法:连接关闭、发送出错、发送成功。

exporter 以及 Grafana 配置

现在我们知道了我们有两个比较关键的指标需要采集,那到底是如何采集的呢?

具体来说,会有以下两步:

  1. 在消息推送系统中添加一个 /metrics 接口

这个接口的作用就是将我们的指标数据暴露出来,以便 Prometheus 采集。 它返回的就是请求时的连接数和等待推送的消息数,返回的格式也有一定要求,但也不复杂,具体来说就是:

  • 一行一个指标
  • 可以返回多个指标,多行即可
  • 每个指标前一行指定其类型(TYPE
  • 每行的格式为:<指标名称>{<标签名称>=<标签值>, ...} <指标值>

下面是一个简单的例子:

1
2
3
4
# HELP http_requests_total The total number of HTTP requests.
# TYPE http_requests_total counter
http_requests_total{method="GET", endpoint="/api"} 100
http_requests_total{method="POST", endpoint="/api"} 50

在这个示例中:

  • http_requests_total 是指标名称
  • {method="GET", endpoint="/api"} 是标签集合,用于唯一标识两个不同的时间序列。
  • 10050 是样本值,表示在特定时间点上的 HTTP 请求总数。

最终,我们得到了一个如下的 /metrics 接口:

1
2
3
4
5
6
func metrics(hub *Hub, w http.ResponseWriter, r *http.Request) {
var pending = hub.pending.Load()
var connections = len(hub.clients)
w.Write([]byte(fmt.Sprintf("# HELP connections 连接数\n# TYPE connections gauge\nconnections %d\n", connections)))
w.Write([]byte(fmt.Sprintf("# HELP pending 等待发送的消息数量\n# TYPE pending gauge\npending %d\n", pending)))
}

不要忘记了在 main 中加上一个入口:

1
2
3
http.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {
metrics(hub, w, r)
})

最终,这个接口会返回如下的数据:

1
2
3
4
5
6
# HELP connections 连接数
# TYPE connections gauge
connections 0
# HELP pending 等待发送的消息数量
# TYPE pending gauge
pending 0
  1. Prometheus 中配置 exporter

我们需要在 Prometheus 配置文件中加上以下配置:

1
2
3
4
5
scrape_configs:
# 拉取我们的应用指标
- job_name: 'websocket'
static_configs:
- targets: ['192.168.2.107:8181']

注意:这里不需要在后面加上 /metrics,因为 Prometheus 默认就是去拉取 /metrics 接口的。

web UI

然后我们就可以在 Prometheusweb UI 中看到我们的指标数据了。

  1. Grafana 中配置图表

最后,我们可以在 Grafana 中配置一个图表,来展示我们的指标数据:

Grafana

这样,我们就可以看到一个等待发送的消息数量以及连接数的变化了。

总结

最后,再来简单回顾一下本文所讲内容,主要包括以下几个方面:

  • 认证方式是另外一个比较复杂的话题,但是我们依然可以为其预留出一个扩展点,先实现其他功能后再来完善。
  • 目前市面上有很多监控相关的组件,本文使用了 Prometheus 作为例子来演示如何在项目中采集应用的指标数据,以及如何通过 Grafana 来展示这些指标的变化。
  • Prometheus 中包含了 `Prometheus Serverexporters 等组件,其中 Server 是实际存储数据的地方,而 exporters 是用来采集指标数据的程序。
  • Prometheus 采集到的数据,我们可以通过 Grafana 来进行可视化展示,更加的直观。
  • 应用中,也可以暴露一个 /metrics 端口来返回应用当前的一些状态,只要遵循 Prometheus 的规范即可。

从上一篇开始,好像我们已经脱离了 WebSocket 的技术范畴了,但是我们可能也意识到了,WebSocket 技术本身并不复杂, 我们也很容易地使用它实现了一个消息推送的雏形。复杂的是,早我们使用它来实现一些功能的时候,需要考虑的非技术性的问题, 或者说非功能性的需求。

蔡超的《十年架构感悟》里面提到过一点:非功能性需求决定架构(在极客时间上可以搜索到)。

非功能性需求包括性能、伸缩性、可扩展性、可维护性等。功能性需求就是我们实际要实现的功能。

大概意思是:一个好的架构其实是由非功能性需求决定的,而不是由功能性需求决定的。 架构设计完之后,少一个功能性需求,我们很容易就能看出来,未来也可以加上去,它对你的架构不会有本质上的影响。 但如果我们忽略的是某一种非功能性需求,那么未来这可以说是一种灾难性的麻烦,很有可能你就需要重写了。 比如你架构中的数据一致性问题无法解决,或者在设计的时候没有充分考虑性能问题,这样,所有的功能性的实现其实都没有意义。

接下来做什么

其实我们在上一篇就可以结束本系列文章了,因为从某种程度上,我们已经实现了一个消息推送中心了。 但是,这种粗制滥造的方式,在真正投入使用的时候会存在很多问题的,比如:

  1. 对于消息投递,我们没有任何的记录:无法知道消息是否投递成功,也不知道消息投递失败的原因
  2. 接入麻烦:上一节我们通过 jwt 来实现认证,但是这个 jwt token 的生成和验证都是在消息推送系统中实现的;经验告诉我们,但凡你的东西复杂一点,用户都没有使用的欲望了,人性毕竟都是懒惰的
  3. 并未考虑到用户 token 失效的问题:比如用户登出系统之后,我们的消息推送系统也得断开是吧,要不然我都登出了你还给我推送消息
  4. 系统内部指标数据完全没有:比如连接数、等待连接数、等待推送的消息数等,这样如果有性能问题就不好排查了
  5. 其他:性能、伸缩性、可扩展性都存在问题

本系列文章的最终目的是要实现一个生产可用的消息推送中心,因此会继续实现这些非功能性需求。

添加消息推送日志

需求

我们的消息推送系统,需要记录每一条消息的投递情况,包括投递成功、投递失败的原因等。 一方面是为了方便排查问题,另一方面也是为了了解系统是否正常运作。 当然这些日志不会长时间保留,具体保留多长时间,我们可以加个配置留给用户决定即可。

依赖倒置原则

虽然暂时还没有实现让整个系统具有较高的扩展性,但是我们可以在代码上先让代码具有扩展性, 这样在未来我们要扩展的时候,就不需要改动太多的代码了。

我们可以先思考一下,我们下面的推送消息代码,应该如何修改来实现上述需求(假设我们的消息要存入数据库):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
func send(hub *Hub, w http.ResponseWriter, r *http.Request) {
uid := r.FormValue("uid")
// 参数错误
if uid == "" {
w.WriteHeader(http.StatusBadRequest)
return
}

// 从 hub 中获取 client
hub.Lock()
client, ok := hub.userClients[uid]
hub.Unlock()
// 尚未建立连接
if !ok {
w.WriteHeader(http.StatusBadRequest)
return
}

// 发送消息
message := r.FormValue("message")
client.send <- []byte(message)
}

func (c *Client) writePump() {
defer func() {
_ = c.conn.Close()
}()
for {
select {
case message, ok := <-c.send:
// 设置写超时时间
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
// c.send 这个通道已经被关闭了
if !ok {
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}

if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
return
}
}
}
}

我们可以暂时不考虑上面代码的实现,只是思考一下,如果我们要实现上述需求,应该如何修改代码呢?

非常容易想到的一种方法就是,在 init 函数中初始化一个全局的数据库连接, 然后在 send 方法中使用这个连接将消息存入数据库(假设我们使用的是 gorm):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
var db *gorm.DB

type Log struct {
gorm.Model
Uid string
Message string
Status int
CreatedAt time.Time
}

func init() {
var err error
db, err = gorm.Open(sqlite.Open("log.db"), &gorm.Config{})
if err != nil {
panic(err)
}
}

然后发送消息前写入数据库:

1
2
3
4
5
6
7
8
9
10
11
12
13
// 自动迁移:表不存在的时候会自动创建
db.AutoMigrate(&Log{})
// 写入日志
db.Create(&Log{
Uid: uid,
Message: r.FormValue("message"),
Status: 0,
CreatedAt: time.Now(),
})

// 发送消息
message := r.FormValue("message")
client.send <- []byte(message)

这样实现起来确实简单,但是这样的代码耦合度太高了, 高层模块依赖了底层模块,依赖于具体的实现,这样的代码是不具有扩展性的

一种更好的方式是:针对写日志这个功能,我们先建立起一个抽象模型,然后高层代码只使用这个模型,不用去考虑底层的实现。

这一点就是 SOLID 里面的 D,依赖倒置原则(Dependency Inversion Principle)。 依赖倒置原则是这样陈述的:高层模块不应依赖于低层模块,二者应依赖于抽象。抽象不应依赖于细节,细节依赖于抽象。

基于依赖倒置原则的具体实现

  1. 先建立起一个抽象模型

首先我们得有一个实体来表示消息本身(MessageLog),然后就是记录消息的抽象模型(MessageLogger):

1
2
3
4
5
6
7
8
type MessageLog struct {
Uid string
Message string
}

type MessageLogger interface {
Log(log MessageLog) error
}
  1. 实现这个抽象模型

我们依然是使用 gorm 来实现这个抽象模型:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
package main

import (
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"time"
)

var db *gorm.DB

type Log struct {
gorm.Model
Uid string
Message string
Status int
CreatedAt time.Time
}

func init() {
var err error
db, err = gorm.Open(sqlite.Open("log.db"), &gorm.Config{})
if err != nil {
panic(err)
}
}

var _ MessageLogger = &MySQLMessageLogger{}

type MySQLMessageLogger struct {
}

func (m *MySQLMessageLogger) Log(log MessageLog) error {
db.AutoMigrate(&Log{})
db.Create(&Log{
Uid: log.Uid,
Message: log.Message,
Status: 0,
CreatedAt: time.Now(),
})
return nil
}

虽然我们代码跟之前依然是一样,但是我们的代码已经具有了扩展性。

  1. 高层代码使用这个抽象模型

依赖倒置原则中说了,高层模块不应该依赖于低层模块。因此我们在 send 方法中记录消息的时候, 不应该直接使用 gorm 来写入数据库,而是使用 MessageLogger 这个抽象模型:

  1. hub 中添加 MessageLogger 字段:
1
2
3
4
type Hub struct {
// 消息日志记录器
messageLogger MessageLogger
}
  1. newHub 函数中初始化 MessageLogger
1
2
3
4
5
6
func newHub() *Hub {
return &Hub{
// ... 其他字段
messageLogger: &MySQLMessageLogger{},
}
}

虽然高层模块不能直接依赖底层实现,但是总会有一个地方是将高层和底层连接起来的,这个地方一般就是创建对象的地方, 在很多现代的框架中,它有另外一个名字:依赖注入容器。

而在本系列文章中,并没有用到什么框架、依赖注入容器,但是我们还是有一个专门的创建对象的地方,那就是 newHub 函数。 因此我们在这里将 MessageLogger 依赖注入到 Hub 中。

  1. send 方法中使用 MessageLogger

最后将原本 send 方法中的数据库操作代码替换为对抽象模型的调用即可:

1
2
messageLog := MessageLog{Uid: uid, Message: r.FormValue("message")}
_ = hub.messageLogger.Log(messageLog)

这样,我们就完成了对消息推送日志的记录。

那如何替换为另一种日志记录方式

我们现在知道了,依赖倒置原则可以指导我们设计出具有扩展性的代码,那在我们这个实例中,如何替换为另一种日志记录方式呢?

其实非常简单,比如我们现在要直接输出到控制台中,那么我们只需要实现一个 StdoutMessageLogger 即可:

1
2
3
4
5
6
7
8
9
10
var _ MessageLogger = &StdoutMessageLogger{}

type StdoutMessageLogger struct {
}

func (s *StdoutMessageLogger) Log(log MessageLog) error {
res, _ := json.Marshal(log)
fmt.Println("send message: " + string(res))
return nil
}

然后在 newHub 中将 messageLogger 替换为 &StdoutMessageLogger{} 即可:

1
2
3
4
5
6
func newHub() *Hub {
return &Hub{
// ... 其他字段
messageLogger: &StdoutMessageLogger{},
}
}

这样,我们在发送消息的时候就可以直接在控制台中看到消息了。 在实际开发中,使用 StdoutMessageLogger 更加方便我们调试代码。

我们可以发现,我们这种设计方式完美地实现了开闭原则,我们添加新的日志记录方式的时候, 不需要修改太多代码,只需要添加新的实现,然后修改 newHub 方法中的一行代码即可, 这样的代码显然更具扩展性,也更好维护。

错误处理

对于消息推送,如果推送失败,我们一般也需要知道推送失败的原因。

同样的,我们的框架本身也不应该依赖于具体的错误处理程序,而是应该使用抽象模型来实现。 从这个原则出发,我们就可以先建立一个抽象模型,然后再实现这个抽象模型:

  1. 先建立起一个抽象模型
1
2
3
4
5
6
7
// Handler 错误处理类型
type Handler func(log message.Log, err error)

type Hub struct {
// 错误处理器
errorHandler Handler
}

因为错误处理本身没有太复杂的功能,因此我们直接使用 type 关键字将其定义为一个函数类型即可。 然后在 Hub 中加上错误处理器的字段 errorHandler

  1. 实现这个抽象模型

其实也谈不上实现,因为没有定义什么 interface,我们只需要定义一个函数即可:

1
2
3
4
func defaultErrorHandler(log message.Log, err error) {
res, _ := json.Marshal(log)
fmt.Printf("send message: %s, error: %s\n", string(res), err.Error())
}

在本文的例子中,我们先定义一个输出错误信息到控制台的错误处理器。 然后,我们需要在 newHub 中初始化这个错误处理器:

1
2
3
4
5
6
func newHub() *Hub {
return &Hub{
// ... 其他字段
errorHandler: defaultErrorHandler,
}
}
  1. 高层代码使用这个抽象模型

为了方便后续处理,我们将 send 方法中的代码稍微修改了一下,将 messageLog 作为参数传入到 send 通道中了,同时将 clientsend 通道改为 chan message.Log

1
2
3
4
type Client struct {
// 接受消息的通道
send chan message2.Log
}

发送消息修改:

1
2
3
4
5
messageLog := message.Log{Uid: uid, Message: r.FormValue("message")}
_ = hub.messageLogger.Log(messageLog)

// 发送消息
client.send <- messageLog

writePump 修改:

1
2
3
if err := c.conn.WriteMessage(websocket.TextMessage, []byte(messageLog.Message)); err != nil {
return
}

最终 writePump 会演化为下面这样,错误处理:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
for {
select {
case messageLog, ok := <-c.send:
// 设置写超时时间
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
// c.send 这个通道已经被关闭了
if !ok {
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
c.hub.errorHandler(messageLog, fmt.Errorf("send channel closed"))
return
}

if err := c.conn.WriteMessage(websocket.TextMessage, []byte(messageLog.Message)); err != nil {
c.hub.errorHandler(messageLog, err)
return
}
}
}

跟之前不一样的地方是,这里会使用 c.hub.errorHandler 进行错误处理。

最终的效果是,对于后续维护而言,核心的处理流程基本上不会变动,而可能需要我们修改的地方都已经被抽象出来了: 错误处理我们可以通过修改 errorHandler 来实现,日志记录我们可以通过修改 messageLogger 来实现。

当然在实际场景中,我们可能还会有类似 onOpenonClose 之类的需求,但本文就先到此为止了,这些都是可以通过类似的方式来实现的。

总结

本人文章可能文字会比较多,但是其中都是个人在此过程中的一些思考,相比直接告诉大家怎么做,有可能知道为什么这么做更重要。

最后,简单回顾一下本文的内容:

  • 消息推送这个功能,技术上其实我们已经实现了,但是我们还得考虑很多非功能性的需求,这些非功能性的需求决定了我们的架构。
  • 依赖倒置原则可以指导我们设计出具有扩展性的代码:本文中的日志记录抽象出了一个 MessageLogger,需要的时候我们可以自行实现然后替换掉框架提供的实现。
  • 错误处理:为了方便后续维护,处理处理我们也是抽象出了一个 func 类型,实现了关注点的分离,也在一定程度上给后续的扩展提供了可能。

在上一篇文章中,我们已经搭建起了基本可用的一个 WebSocket 推送中心,但是有一个比较大的问题是, 我们并没有对进行连接的客户端进行认证,这样就会有一定的风险,如果被恶意攻击, 可能会影响我们的 WebSocket 服务器的正常运作。

本文我们就来把认证这个很关键的功能给补一下,在本文中,我们将会使用 jwt 来对我们的客户端进行认证。

什么是 jwt?

JWTJSON Web Token 的缩写,是一种用于在网络中安全传递信息的开放标准。它是一种紧凑且自包含的方式,用于在各方之间传递信息,通常用于身份验证和授权机制。

JWT 主要由三个部分组成:

  1. 头部(Header): 包含关于令牌的元数据,例如令牌的类型(typ)和签名算法(alg)。头部是一个 JSON 对象,通常会经过 Base64 编码。

  2. 载荷(Payload): 包含要传递的信息,通常包括用户身份信息以及其他声明。载荷也是一个 JSON 对象,同样经过 Base64 编码。

  3. 签名(Signature): 使用头部和载荷以及密钥生成的签名,用于验证令牌的真实性和完整性。签名是对头部和载荷的哈希值进行签名后的结果。

这三个部分通过点号(.)连接起来形成一个字符串,即 JWT。最终的 JWT 结构如下:

1
header.payload.signature

一个简单的 jwt 例子

jwt 的使用会分为两部分:生成 token,使用 token。本文中将会使用 golang-jwt/jwt 来做 jwt 的验证。

生成 token

生成 token 的操作有以下两步:

  1. 创建一个 token 对象:使用的是 jwt.NewWithClaims 方法,它第一个参数指定了签名算法,这里使用的是 HMAC,第二个参数接受一个 jwt.MapClaims,也就是上面提到的 payload
1
2
3
4
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"foo": "bar",
"nbf": time.Date(2015, 10, 10, 12, 0, 0, 0, time.UTC).Unix(),
})
  1. 利用上一步的 token 对象生成成一个 jwt 的签名字符串:使用的是 tokenSignedString 方法,它接受一个 key 作为参数,这个 key 也会用于解析这一步生成的 token 字符串。
1
2
3
// 注意:这里的 secret 在实际使用的时候需要替换为自己的 key
//(一般为一个随机字符串)
tokenString, err := token.SignedString([]byte("secret"))

我们生成 token 的操作一般发生在用户登录成功之后,这个 token 会作为用户后续发起请求的凭证。

使用 token

在用户登录成功拿到 token 之后,会使用这个 token 去从服务器获取其他资源,服务器会解析这个 token 并校验。

使用 token 的步骤如下:

  1. 解析 token:使用的是 jwt.Parse 方法,它第一个参数接受 token 字符串,第二个参数是一个函数(函数的参数就是解析出来的 token 对象,函数返回解密的 key
1
2
3
4
token, err = jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// 返回签名密钥,需要是 []byte 类型的
return []byte("secret"), nil
})
  1. 使用 token 中的 payloadpayload 也就是我们业务实际使用的数据,其他的东西只是使用 jwt token 这门技术附带的一些东西。

这里说的 payload 实际上就是 token 对象的 Claims 属性,它是一个 map,保存了我们一些业务的数据还有其他一些 jwt 本身的字段。

1
2
// claims: map[foo:bar nbf:1.4444784e+09]
fmt.Println("claims:", token.Claims)

在上面这行代码中,foo:bar 是我们的业务数据,而 nbf:1.4444784e+09jwt 本身使用的字段。nbf 表示的是在这个时间之前,这个 token 不应该被处理。

例子

我们把上面的生成 token 的代码和解析 token 的代码放到一起看看效果:

注意:实际使用中这两个操作是分开的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
func TestHmac(t *testing.T) {
// part 1:
// 创建一个新的 token 对象,指定签名方法和你想要包含的 claims
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"foo": "bar",
"nbf": time.Date(2015, 10, 10, 12, 0, 0, 0, time.UTC).Unix(),
})

// 生成签名字符串,secret 为签名密钥
tokenString, err := token.SignedString([]byte("secret"))

// eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJuYmYiOjE3Mjg1NjE2MDB9.9AxzBmYOuOnWfKUul57ATzjQ-sMzbggaoIdDjVzjm2Y, nil
fmt.Println(tokenString, err)

// part 2:
// 解析 token
token, err = jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}

// 返回签名密钥
return []byte("secret"), nil
})
if err != nil {
panic(err)
}

// header: map[alg:HS256 typ:JWT]
fmt.Println("header:", token.Header)
// claims: map[foo:bar nbf:1.4444784e+09]
fmt.Println("claims:", token.Claims)
// signature: Nv24hvNy238QMrpHvYw-BxyCp00jbsTqjVgzk81PiYA
fmt.Println("signature:", base64.RawURLEncoding.EncodeToString(token.Signature))

if claims, ok := token.Claims.(jwt.MapClaims); ok {
// bar 1.4444784e+09
fmt.Println(claims["foo"], claims["nbf"])
} else {
panic(err)
}
}

本文不是讲解 JWT 的文章,关于 JWT 的更多细节可以参考 rfc7519

在消息推送中心 demo 加上 jwt 认证

我们使用 jwt 的目的是为了杜绝一些未知来源的连接,而在我们上一篇文章的实现中, 是先建立起连接,然后再进行 "登录" 操作的,这样就会导致即使是未授权的客户端也可以先进行连接, 这样就会在认证之前就启动了两个协程。

而如果这些连接并不是正常的连接,它们只连接但是不登录,那样就会有很多僵尸连接,这显然不是我们想要的结果。

我们可以在客户端打开连接的时候就去验证客户端的 token,如果 token 校验失败,则直接断开连接, 就可以解决上述问题了,从而避免不必要的开销。

如何在建立连接的时候就认证

要实现这个很简单,我们只需要在连接的 url 后加一个 queryString 即可,如下:

1
ws = new WebSocket('ws://127.0.0.1:8181/ws?token=123')

然后在 serveWs 中通过 r.FormValue("token") 来获取客户端传递过来的 token, 再对其进行认证,认证不通过则拒绝连接。

具体来说,我们的 serveWs 会添加以下几行代码:

1
2
3
4
5
6
jwt := NewJwt(r.FormValue("token"))
err := jwt.Parse()
if err != nil {
log.Println(fmt.Errorf("jwt parse error: %w", err))
return
}

在函数入口的地方就进行验证,验证不通过则不进行连接。这里的 jwt 定义如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
type Jwt struct {
Secret string
Token string
}

func NewJwt(token string) *Jwt {
return &Jwt{
Token: token,
Secret: os.Getenv("JWT_SECRET"),
}
}

func (j *Jwt) Parse() error {
_, err := jwt.Parse(j.Token, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}

return []byte(j.Secret), nil
})

return err
}

NewJwt 中,我们从 env 中获取 JWT_SECRET,这让我们的配置可以更加灵活。

token 中加上 uid

我们知道了,在 JWT 中的 payload 是可以加入我们的自定义数据的,所以我们的 uid 其实是可以加入到 jwttoken 中的,我们只需要在用户第一次获取 token 的时候加上即可:

1
2
3
4
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"uid": "123",
"nbf": time.Date(2015, 10, 10, 12, 0, 0, 0, time.UTC).Unix(),
})

同样的,在解析 jwt token 的时候,可以从中取出这个 uid

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// 第一个返回值是 uid,第二个是 error
func (j *Jwt) Parse() (string, error) {
token, err := jwt.Parse(j.Token, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}

return []byte(j.Secret), nil
})

if err != nil {
return "", err
}

// 获取 uid
if claims, ok := token.Claims.(jwt.MapClaims); ok {
return claims["uid"].(string), nil
} else {
return "", fmt.Errorf("jwt parse error")
}
}

最终,我们的 serveWs 演化成了如下这样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) {
// 解析 jwt token,从中取得 uid
jwt := NewJwt(r.FormValue("token"))
uid, err := jwt.Parse()
if err != nil {
log.Println(fmt.Errorf("jwt parse error: %w", err))
return
}

conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println(fmt.Errorf("upgrade error: %w", err))
return
}

client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256), uid: uid}
client.hub.register <- client

go client.writePump()
go client.readPump()
}

这样一来,我们的 readPump 里面就不再需要处理登录消息了,那就暂时先把 readPump 中的逻辑去掉先:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
func (c *Client) readPump() {
defer func() {
c.hub.unregister <- c
_ = c.conn.Close()
}()
c.conn.SetReadLimit(maxMessageSize)
c.conn.SetReadDeadline(time.Time{}) // 永不超时
for {
// 从客户端接收消息
_, _, err := c.conn.ReadMessage()
if err != nil {
log.Println("readPump error: ", err)
break
}
}
}

register 里的 uid 关联

我们之前是在 readPump 方法中将 uidWebSocket 建立起关联的,但由于我们已经去掉了 readPump 中 的登录消息处理逻辑。

因此我们需要在 Hubregister 中将 uidWebSocket 建立起关联:

1
2
3
4
5
6
case client := <-h.register:
h.Lock()
h.clients[client] = true
// 建立起 uid 跟 WebSocket 的关联
h.userClients[client.uid] = client
h.Unlock()

jti

在此前使用过的一些 jwt 封装中,会有些使用 jwt 规范中的 jti 字段来传输 token 的唯一 ID, 在本文的实现中,uid 也是同样的功能。如果我们之后看到了 jti 这个字段,不要惊讶,其实这个才是规范。

测试

我们将生成 token 的代码中的 foo: bar 这个键值对修改为 uid: 123 之后,得到了如下 token

1
eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYmYiOjE0NDQ0Nzg0MDAsInVpZCI6IjEyMyJ9.9yJ-ABQGJkdnDqHo-wV-vojQFEQGt-I0dyva1w6EQ7E

我们在 ws 链接后加上这个 token 即可连接成功:

1
ws = new WebSocket('ws://127.0.0.1:8181/ws?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYmYiOjE0NDQ0Nzg0MDAsInVpZCI6IjEyMyJ9.9yJ-ABQGJkdnDqHo-wV-vojQFEQGt-I0dyva1w6EQ7E')
test

我们可以看到,ws 连接返回的状态码是 101,这就表明我们的连接成功了。

推送消息

同之前一样的,在控制台执行一个 curl 命令即可推送消息了:

1
curl "http://localhost:8181/send?uid=123&message=Hello%20World"

总结

本文我们使用 jwt 对我们的 WebSocket 进行了认证。 jwt 认证我们可以已经用过很多次了,但如果还没有用过或者还没有在 go 中用过 jwt 认证的话, 那么本文就是一个很好的入门文章。

最后,再来回顾一下本文的内容:

  • jwt 包含了三部分:头部(Header)、载荷(Payload)、签名(Signature)。
  • 在 go 中可以使用 golang-jwt/jwt 来做 jwt 的验证。
  • 创建 token 分两步:使用 jwt.NewWithClaims 创建 token 对象、使用 SignedString 生成签名字符串。
  • 使用 token 分两步:使用 jwt.Parse 解析 token、使用 token.Claims 获取 payload
  • 在建立连接的时候就进行认证,可以避免非法连接导致的开销。
  • jwt 中可以加入我们的自定义数据,比如 uid,在 jwt.NewWithClaims 中加上即可。

有了前两篇的铺垫,相信大家已经对 Golang 中 WebSocket 的使用有一定的了解了, 今天我们以一个更加真实的例子来学习如何在 Golang 中使用 WebSocket

需求背景

在实际的项目中,往往有一些任务耗时比较长,然后我们会把这些任务做异步的处理,但是又要及时给客户端反馈任务的处理进度。

对于这种场景,我们可以使用 WebSocket 来实现。其他可以使用 WebSocket 进行通知的场景还有像管理后台一些通知(比如新订单通知)等。

在本篇文章中,就是要实现一个这样的消息推送系统,具体来说,它会有以下功能:

  1. 可以给特定的用户推送:建立连接的时候,就建立起 WebSocket 连接与用户 ID 之间的关联
  2. 断开连接的时候,移除 WebSocket 连接与用户的关联,并且关闭这个 WebSocket 连接
  3. 业务系统可以通过 HTTP 接口来给特定的用户推送 WebSocket 消息:只要传递用户 ID 以及需要推送的消息即可

基础框架

下面是一个最简单版本的框架图:

arch

它包含如下几个角色:

  1. Client 客户端,也就是实际中接收消息通知的浏览器
  2. Server 服务端,在我们的例子中,服务端实际不处理业务逻辑,只处理跟客户端的消息交互:维持 WebSocket 连接,推送消息到特定的 WebSocket 连接
  3. 业务逻辑:这个实际上不属于 demo 的一部分,但是 Server 推送的数据是来自业务逻辑处理的结果

设计成这样的目的是为了将技术跟业务进行分离,业务逻辑上的变化不影响到底层技术,同样的,WebSocket 推送中心的技术上的变动也不会影响到实际的业务。

开始开发

一些结构体变动

  1. Client 结构体的变化
1
2
3
4
5
6
7
type Client struct {
hub *Hub
conn *websocket.Conn
send chan []byte
// 新增字段
uid int
}

因为我们需要建立起 WebSocket 连接与用户之间的关联,因此我们需要一个额外的字段来记录用户 ID,也就是上面的 uid 字段。

这个字段会在客户端建立连接后写入。

  1. Hub 结构体的变化
1
2
3
4
5
6
7
8
9
10
11
type Hub struct {
clients map[*Client]bool
register chan *Client
unregister chan *Client

// 记录 uid 跟 client 的对应关系
userClients map[int]*Client

// 读写锁,保护 userClients 以及 clients 的读写
sync.RWMutex
}
  1. 因为我们不再需要做广播,所以会移除 Hub 中的 broadcast 字段。

取而代之的是,我们会直接在消息推送接口中写入到 uid 对应的 Clientsend 通道。 当然我们也可以在 Hub 中另外加一个字段来记录要推送给不同 uid 的消息,但是我们的 Hubrun 方法是一个协程处理的,当需要推送的数据较多或者其中有 网络延迟的时候,会直接影响到推送给其他用户的消息。当然我们也可以改造一下 run 方法,启动多个协程来处理,不过这样比较复杂,本文会在 writePump 中处理。 (也就是建立 WebSocket 连接时的那个写操作协程)

  1. 同时为了更加快速地通过 uid 来获取对应的 WebSocket 连接,新增了一个 userClients 字段。

这是一个 map 类型的字段,keyuid,值是对应的 Client 指针。

  1. 最后新增了一个 Mutex 互斥锁

因为,在用户实际进行登录的时候需要写入 userClients 字段,而这是一个 map 类型字段,并不支持并发读写。 如果我们在接受并发连接的时候同时修改 userClients 的时候会导致 panic,因此我们使用了一个互斥锁来保证 userClients 的读写安全。

同时,clients 也是一个 map,但上一篇文章中没有使用 sync.Mutex 来保护它的读写,在并发操作的时候也是会有问题的, 所以 Mutex 同时也需要保护 clients 的读写。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
func (h *Hub) run() {
for {
select {
case client := <-h.register:
h.Lock()
h.clients[client] = true
h.Unlock()
case client := <-h.unregister:
if _, ok := h.clients[client]; ok {
h.Lock()
delete(h.userClients, client.uid)
delete(h.clients, client)
h.Unlock()
close(client.send)
}
}
}
}

最后,我们会在 Hubrun 方法中写 userClients 或者 clients 字段的时候,先获取锁,写成功的时候释放锁。

建立连接

在本篇中,将会继续沿用上一篇的代码,只是其中一些细节会有所改动。建立连接这步操作,跟上一篇的一样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
// 将 HTTP 转换为 WebSocket 连接的 Upgrader
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}

// 处理 WebSocket 连接请求
func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) {
// 升级为 WebSocket 连接
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println(err)
return
}
// 新建一个 Client
client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)}
// 注册到 Hub
client.hub.register <- client

// 推送消息的协程
go client.writePump()
// 结束消息的协程
go client.readPump()
}

接收消息

由于我们要做的只是一个推送消息的系统,所以我们只处理用户发来的登录请求,其他的消息会全部丢弃:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
func (c *Client) readPump() {
defer func() {
c.hub.unregister <- c
_ = c.conn.Close()
}()
c.conn.SetReadLimit(maxMessageSize)
c.conn.SetReadDeadline(time.Time{}) // 永不超时
for {
// 从客户端接收消息
_, message, err := c.conn.ReadMessage()
if err != nil {
log.Println("readPump error: ", err)
break
}

// 只处理登录消息
var data = make(map[string]string)
err = json.Unmarshal(message, &data)
if err != nil {
break
}

// 写入 uid 以及 Hub 的 userClients
if uid, ok := data["uid"]; ok {
c.uid = uid
c.hub.Lock()
c.hub.userClients[uid] = c
c.hub.Unlock()
}
}
}

在本文中,假设客户端的登录消息格式为 {"uid": "123456"} 这种 json 格式。

在这里也操作了 userClients 字段,同样需要使用互斥锁来保证操作的安全性。

发送消息

  1. 在我们的系统中,可以提供一个 HTTP 接口来跟业务系统进行交互:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
// 发送消息的接口
// 参数:
// 1. uid:接收消息的用户 ID
// 2. message:需要发送给这个用户的消息
http.HandleFunc("/send", func(w http.ResponseWriter, r *http.Request) {
send(hub, w, r)
})

// 发送消息的方法
func send(hub *Hub, w http.ResponseWriter, r *http.Request) {
uid := r.FormValue("uid")
// 参数错误
if uid == "" {
w.WriteHeader(http.StatusBadRequest)
return
}

// 从 hub 中获取 client
hub.Lock()
client, ok := hub.userClients[uid]
hub.Unlock()
// 尚未建立连接
if !ok {
w.WriteHeader(http.StatusBadRequest)
return
}

// 发送消息
message := r.FormValue("message")
client.send <- []byte(message)
}
  1. 实际发送消息的操作

writePump 方法中,我们会将从 /send 接收到的数据发送给对应的用户:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
// 发送消息的协程
func (c *Client) writePump() {
defer func() {
_ = c.conn.Close()
}()
for {
select {
case message, ok := <-c.send:
// 设置写超时时间
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
// 连接已经被关闭了
if !ok {
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}

// 获取一个发送消息的 Writer
w, err := c.conn.NextWriter(websocket.TextMessage)
if err != nil {
return
}
// 写入消息到 Writer
w.Write(message)

// 关闭 Writer
if err := w.Close(); err != nil {
return
}
}
}
}

在这个方法中,我们会从 c.send 这个 chan 中获取需要发送给客户端的消息,然后进行发送操作。

测试

  1. 启动 main 程序
1
go run main.go
  1. 打开一个浏览器的控制台,执行以下代码
1
2
ws = new WebSocket('ws://127.0.0.1:8181/ws')
ws.send('{"uid": "123"}')

这两行代码的作用是与 WebSocket 服务器建立连接,然后发送一个登录信息。

然后我们打开控制台的 Network -> WS -> Message 就可以看到浏览器发给服务端的消息:

login
  1. 使用 HTTP 客户端发送消息给 uid 为 123 的用户

假设我们的 WebSocket 服务器绑定的端口为 8181

打开终端,执行以下命令:

1
curl "http://localhost:8181/send?uid=123&message=Hello%20World"

然后我们可以在 Network -> WS -> Message 看到接收到了消息 Hello World

hello world

结束了

到此为止,我们已经实现了一个初步可工作的 WebSocket 应用,当然还有很多可以优化的地方, 比如:

  1. 错误处理
  2. Hub 状态目前对外部来说是一个黑盒子,我们可以加个接口返回一下 Hub 的当前状态,比如当前连接数
  3. 日志:出错的时候,日志可以帮助我们快速定位问题

这些功能会在后续继续完善,今天就到此为止了。

在上一篇文章中,我们已经了解了 gorilla/websocket 的一些基本概念和简单的用法。 接下来,我们通过一个再复杂一点的例子来了解它的实际用法。

功能

这个例子来自源码里面的 examples/chat,它包含了以下功能:

  1. 用户访问群聊页面的时候,可以发送消息给所有其他在聊天室内的用户(也就是同样打开群聊页面的用户)
  2. 所有的用户发送的消息,群聊中的所有用户都能收到(包括自己)

其基本效果如下:

chat

为了更好地理解 gorilla/websocket 的使用方式,下文在讲解的时候会去掉一些出于健壮性考虑而写的代码。

基本架构

这个 demo 的基本组件如下图:

arch
  1. Client:也就是连接到了服务端的客户端,可以有多个
  2. Hub:所有的客户端会保存到 Hub 中,同时所有的消息也会经过 Hub 来进行广播(也就是将消息发给所有连接到 Hub 的客户端)
broadcast

工作原理

Hub

Hub 的源码如下:

1
2
3
4
5
6
7
8
9
10
type Hub struct {
// 保存所有客户端
clients map[*Client]bool
// 需要广播的消息
broadcast chan []byte
// 等待连接的客户端
register chan *Client
// 等待断开的客户端
unregister chan *Client
}

Hub 的核心方法如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
func (h *Hub) run() {
for {
select {
case client := <-h.register:
// 从等待连接的客户端 chan 取一项,设置到 clients 中
h.clients[client] = true
case client := <-h.unregister:
// 断开连接:
// 1. 从 clients 移除
// 2. 关闭发送消息的 chan
if _, ok := h.clients[client]; ok {
delete(h.clients, client)
close(client.send)
}
case message := <-h.broadcast:
// 发送广播消息给每一个客户端
for client := range h.clients {
select {
// 成功写入消息到客户端的 send 通道
case client.send <- message:
default:
// 发送失败则剔除这个客户端
close(client.send)
delete(h.clients, client)
}
}
}
}
}

这个例子中使用了 chan 来做同步,这可以提高 Hub 的并发处理速度,因为不需要等待 Hubrun 方法中其他 chan 的处理。

简单来说,Hub 做了如下操作:

  1. 维护所有的客户端连接:客户端连接、断开连接等
  2. 发送广播消息

Client

Client 的源码如下:

1
2
3
4
5
6
7
8
type Client struct {
// Hub 单例
hub *Hub
// 底层的 websocket 连接
conn *websocket.Conn
// 等待发送给客户端的消息
send chan []byte
}

它包含了如下字段:

  1. Hub 单例(我们的 demo 中只有一个聊天室)
  2. conn 底层的 WebSocket 连接
  3. send 通道,这里保存了等待发送给这个客户端的数据

Client 中,是通过 readPump 这个方法来从客户端接收消息的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
func (c *Client) readPump() {
defer func() {
// 连接断开、出错等:
// 会关闭连接,从 hub 移除这个连接
c.hub.unregister <- c
c.conn.Close()
}()
// ...
for {
// 接收消息
_, message, err := c.conn.ReadMessage()
if err != nil {
// ... 错误处理
break
}
// 消息处理,最终放入 broadcast,准备发给所有其他在线的客户端
message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1))
c.hub.broadcast <- message
}
}

readPump 方法做的事情很简单,它就是接收消息,然后通过 Hubbroadcast 来发给所有在线的客户端。

而发送消息会稍微复杂一点,我们来看看 writePump 的源码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
func (c *Client) writePump() {
defer func() {
// 连接断开、出错:关闭 WebSocket 连接
c.conn.Close()
}()
for {
select {
case message, ok := <-c.send:
// 控制写超时时间
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
if !ok {
// 连接已经被 hub 关闭了
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}

// 获取用以发送消息的 Writer
w, err := c.conn.NextWriter(websocket.TextMessage)
if err != nil {
return
}
// 发送消息
w.Write(message)

n := len(c.send)
for i := 0; i < n; i++ {
w.Write(newline)
// 将接收到的信息发送出去
w.Write(<-c.send)
}

// 调用 Close 的时候,消息会被发送出去
if err := w.Close(); err != nil {
return
}
}
}
}

虽然比读操作复杂了一点,但是也还是很好理解,它做的东西也不多:

  1. 获取用以发送消息的 Writer
  2. 获取从 hub 中接收到的其他客户端的消息,发送给当前这个客户端

具体是如何工作起来的?

  1. main 函数中创建 hub 实例
  2. 通过下面这个 serveWs 来将建立 WebSocket 连接:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) {
// 将 HTTP 连接转换为 WebSocket 连接
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println(err)
return
}
// 客户端
client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)}
// 注册到 hub
client.hub.register <- client

// 发送数据到客户端的协程
go client.writePump()
// 从客户端接收数据的协程
go client.readPump()
}

serveWs 中,我们在跟客户端建立起连接后,创建了两个协程,一个是从客户端接收数据的,另一个是发送消息到客户端的。

这个 demo 的作用

这个 demo 是一个比较简单的 demo,不过也包含了我们构建 WebSocket 应用的一些关键处理逻辑,比如:

  • 使用 Hub 来维持一个低层次的连接信息
  • Client 中区分读和写的协程
  • 以及一些边界情况的处理:比如连接断开、超时等

在后续的文章中,我们会基于这些已有知识去构建一个更加完善的 WebSocket 应用,今天就到此为止了。