0%

在前面的文章中,提到过非功能性需求决定了架构。 今天我们再来考虑一下另外两个非功能性需求:性能和可用性。

前言

关于性能,其实并不是只有我们这个消息推送系统独有的问题。 对于所有的开发者而言,都多多少少会处理过性能相关的问题,比如后端为了减少数据库查询提高并发引入的缓存中间件,如 redis; 又或者如前端一次性渲染大量数据的时候,如果让用户体验更加流畅等。

本文会针对 WebSocket 应用场景下去思考一些可能出现的性能问题以及可行的解决方案。

性能

对于性能,有几个可能导致性能问题的地方:

连接数

连接数过多会导致占用的内存过多,因为对于每一个连接,我们都有两个协程,一个读协程,一个写协程; 同时我们的 Client 结构体中的 send 是一个缓冲通道,它的缓冲区大小也直接影响最终占用的内存大小。

比如,我们目前的创建 Client 实例的代码是下面这样的:

1
client := &Client{hub: hub, conn: conn, send: make(chan Log, 256), uid: uid}

我们在这里直接为 send 分配了 256 的大小,如果 Log 结构体比较大的话, 它占用的内存就会比较大了(因为最终占用内存 = 连接数 * sizeof(Log) * 256)。

在实际中,我们一般没有那么多等待发送的消息,这个其实可以设置为一个非常小的值,比如 16; 设置为一个小的值的负面影响是,当 send 塞满了 16 条 Log 的时候,发送消息的接口会阻塞:

1
2
3
4
5
6
func send(hub *Hub, w http.ResponseWriter, r *http.Request) {
// ... 其他代码
// 如果 send 满了,下面这一行会阻塞
client.send <- messageLog
hub.pending.Add(int64(1))
}

所以这个数值可能需要根据实际场景来选择一个更加合适的值。

代码本身的问题

比如,我们的代码中其实有一个很常见的性能问题,就是 string[]byte 之间直接强转:

1
2
3
4
// writePump 方法里面将 string 转 []byte
if err := c.conn.WriteMessage(websocket.TextMessage, []byte(messageLog.Message)); err != nil {
return
}

至于原因,可以去看看此前的一篇文章《深入理解 go unsafe》 的最后一小节, 简单来说,就是这个转换会产生内存分配,而内存分配会导致一定的性能损耗。而通过 unsafe 就可以实现无损的转换。

除了这个,其他地方也没啥太大的问题了,因为到目前为止,我们的代码还是非常的简单的。

互斥锁

为了保证程序的并发安全,我们在 Hub 中加了一个 sync.Mutex,也就是互斥锁。 在代码中,被 sync.MutexLock 保护的代码,在同一时刻只能有一个协程可以执行。

1
2
3
4
5
6
7
8
9
// 推送消息的接口
func send(hub *Hub, w http.ResponseWriter, r *http.Request) {
// ... 其他代码
// 从 hub 中获取 client
hub.Lock()
client, ok := hub.userClients[uid]
hub.Unlock()
// ... 其他代码
}

对于上面这种只读的操作,也就是没有对 map 进行写操作,我们依然使用了 sync.MutexLock() 来锁定临界区。 这里存在的问题是,其实我们的 hub.userClients 是支持并发读的,只是不能同时读写而已。

所以我们可以考虑将 sync.Mutex 替换为 sync.RWMutex,这样就可以实现并发读了:

1
2
3
4
5
6
7
8
9
// 推送消息的接口
func send(hub *Hub, w http.ResponseWriter, r *http.Request) {
// ... 其他代码
// 从 hub 中获取 client
hub.RLock() // 读锁
client, ok := hub.userClients[uid]
hub.RUnlock() // 释放读锁
// ... 其他代码
}

这样做的好处是,当有多个并发的 send 请求的时候,这些并发的 send 请求并不会相互阻塞; 而使用 sync.Mutex 的时候,并发的 send 请求是会相互阻塞的,也就是会导致 send 变成串行的,这样性能无疑会很差。

除此之外,我们在 Hubrun 方法中也使用了 sync.Mutex

1
2
3
4
5
case client := <-h.register:
h.Lock()
h.clients[client] = true
h.userClients[client.uid] = client
h.Unlock()

也就是说,我们将 Client 注册到 Hub 的操作也是串行的。 对于这种场景,其实也有一种解决方法就是分段 map, 也就是将 clientsuserClients 这两个 map 拆分为多个 map, 然后对于每一个 map 都有一个对应的 sync.Mutex 互斥锁来保证其读写的安全。

但如果要这样做,单单分段还不够,我们的 registerunregister 还是只有一个,对于这个问题, 我们可能需要将 registerunregister 也分段,最后在 run 方法里面起多个协程来进行处理。 这个实现起来就很复杂了。

其他

由于我们的 Hub 中还有 MessageLogger、错误处理、认证等功能, 在实际中,如果我们有将其替换为自己的实现,可能还得考虑自己的实现中可能存在的性能问题:

1
2
3
4
5
type Hub struct {
messageLogger MessageLogger
errorHandler Handler
authenticator Authenticator
}

可用性

这里主要讨论的是集群部署的情况下,应用存在的一些的问题以及可行的解决方案。关于具体部署上的细节不讨论。

要实现高可用的话,我们就得加机器了,毕竟如果只有一台服务器的话,一旦它宕机了,服务就完全挂了。

由于我们的 WebSocket 应用维持着跟客户端的连接,在单机的时候,客户端连接、推送消息都是在一台机器上的。 这种情况下并没有什么问题,因为推送消息的时候,都可以根据 uid 来找到对应的 WebSocket 连接,从而给客户端推送消息。

而在多台机器的情况下,我们的客户端可能跟不同的服务器产生连接,这个时候一个比较关键的问题是: 如何根据 uid 找到对应的 WebSocket 连接所在的机器? 如果我们推送消息的请求到达的机器上并没有消息关联的 WebSocket 连接,那么我们的消息就无法推送给客户端了。

对于这个问题,一个可行的解决方案是,将 uid 和服务器建立起关联,比如,在用户登录的时候, 就给用户返回一个 WebSocket 服务器的地址,客户端拿到这个地址之后,跟这个服务器建立起 WebSocket 连接, 然后其他应用推送消息的时候,也根据同样的算法将推送消息的请求发送到这个 WebSocket 服务器即可。

总结

最后,再简单回顾一下本文的内容:

  • 具体来说,我们的系统中会有下面几个可能的地方会导致产生性能问题:
    • 连接数:一个连接会有两个协程,另外每一个 Client 结构体也会需要一定的缓冲区来缓冲发送给客户端的消息
    • 代码上的性能问题:如 string[]byte 之间转换带来的性能损耗
    • 互斥锁:某些地方可以使用读写锁来提高读的并发量,另外一个办法就是使用分段 map 配合互斥锁
    • 系统本身预留的扩展点中,用户自行实现的代码中可能会存在性能问题
  • 要实现高可用就得将系统部署到多台机器上,这个时候需要在 uid 和服务器之间建立起某种关联,以便推送消息的时候可以成功推送给客户端。

我在上一篇文章中,提到了目前的认证方式存在一些问题,需要替换为一种更简单的认证方式。 但是最后发现,认证这个实在是没有办法简单化,认证本身又是另外一个不小的话题了,因此关于这一点先留个坑。

本文先讨论一下另外一个也比较重要的功能:监控。

为认证预留扩展点

虽然我们暂时不去实现更加完善的认证流程,但是我们依然可以先为其预留一个扩展点, 这样在未来我们要实现认证的时候,就不需要改动太多的代码了。

同样的,我们也可以基于 DIP 原则来实现,我们可以定义一个 Authenticator 接口:

1
2
3
4
type Authenticator interface {
// Authenticate 验证请求是否合法,第一个返回值为用户 id,第二个返回值为错误
Authenticate(r *http.Request) (string, error)
}

然后我们可以在 Hub 结构体中添加一个 authenticator 字段:

1
2
3
4
type Hub struct {
// 验证器
authenticator Authenticator
}

而对于我们目前的这种基于 jwt token 的认证方式,我们可以实现一个 JwtAuthenticator

1
2
3
4
5
6
7
8
9
var _ Authenticator = &JWTAuthenticator{}

type JWTAuthenticator struct {
}

func (J *JWTAuthenticator) Authenticate(r *http.Request) (string, error) {
jwt := NewJwt(r.FormValue("token"))
return jwt.Parse()
}

接着,我们在 newHub 中初始化这个 authenticator

1
2
3
4
5
6
func newHub() *Hub {
return &Hub{
// ... 其他代码 ...
authenticator: &JWTAuthenticator{},
}
}

这样,我们就可以在 serveWs 中使用这个 authenticator 了:

1
2
3
4
5
6
7
8
func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) {
uid, err := hub.authenticator.Authenticate(r)
if err != nil {
log.Println(fmt.Errorf("jwt parse error: %w", err))
return
}
// ... 其他代码
}

在后面我们实现了更加完善的认证流程之后,我们只需要实现一个新的 Authenticator 即可。

2023 了,应用监控怎么做

发展到今天,我们已经有了很多很好用的监控相关的东西,比如 PrometheusGrafana, 以及一些分布式链路追踪的组件,如 skywalkingjaeger 等。

但是他们各自的应用场景都不太一样,并不存在一个万能的监控工具,因此我们需要根据自己的需求来选择:

  • Prometheus:Prometheus 是一个开源的系统监控和报警工具。主要用于收集、存储和查询系统的监控数据,以便进行性能分析、故障排除和告警。
  • Grafana:Grafana 是一个开源的数据可视化和监控平台,用于创建、查询、分析和可视化时间序列数据。目前比较常见的组合就是 Prometheus + Grafana,通过 Prometheus 收集数据,然后通过 Grafana 展示数据。
  • 分布式链路追踪:常用语分布式系统的调用链路追踪,可以用于分析系统的性能瓶颈,以及分析系统的调用链路。常见的实现有 skywalkingjaeger 等。

在我们这个实例中,我们只需要实现一个简单的监控即可,因此我们可以使用 Prometheus + Grafana 的组合。

Prometheus 基本原理

但在此之前我们最好先了解一下 Prometheus 的工作原理,下面是来自 Prometheus 官网的架构图:

architecture

我们可以从两个角度来看这张图:组件、流程。

  1. 组件
  • Prometheus ServerPrometheus 服务端,主要负责数据的收集、存储、查询等。(上图中间部分)
  • AlertmanagerPrometheus 的告警组件,主要负责告警的发送。(上图右上角)
  • Prometheus web UI:可以在这个界面执行 PromQL,另外 Grafana 可以让我们以一种更直观的方式来查看指标数据(也是使用 PromQL)。(上图右下角)
  • exportersexportersPrometheus 的数据采集组件,主要负责从各个组件中采集数据,然后发送给 Prometheus Server。非常常见的如 node_exporter,也就是服务器基础指标的采集组件。除了 exporters,还有一种常见的数据采集方式是 Pushgateway,也就是将数据推送到 Pushgateway,然后由 Prometheus ServerPushgateway 中拉取数据。(也就是上图左边部分)
  1. 流程
  • 采集数据:也就是从 Pushgateway 或者 exporter 拉取一些指标数据。
  • 存储数据:Prometheus Server 会将采集到的数据存储到本地的 TSDB 中。
  • 查询数据:我们可以通过 web UI 或者 Grafana 来查看数据。

最后,我们可以在 Grafana 中看到如下图表:

grafana

通过这个图,我们就可以很直观的看到我们的系统的一些指标数据了,并且能看到这些指标随着时间的变化趋势。

Grafana 里面的图表都是一个个的 PromQL 查询出来的结果,对于常见的一些监控指标,Grafana 上可以找到很多现有的模板,直接使用即可。

Prometheus 采集的是什么数据

举一个简单的例子:对于一个运行中的系统而言,每一刻它的状态都是不太一样的,比如,可能上一秒 CPU 使用率是 10%,下一秒就变成了 100% 了, 但可能过 1 秒又降低到了 10%。当我们的系统出性能问题的时候,我们就需要去分析这些指标数据,找到问题所在。 比如排查一下出现性能问题的那个时间点,CPU 使用率是不是很高,如果是的话,那么就有可能是 CPU 导致的性能问题。

Prometheus 的作用就是帮助我们采集这些指标数据,然后存储起来,等待某天我们需要分析的时候,再去查询这些数据。 又或者监控到指标有异常的时候,可以通过 Alertmanager 来发送告警。

Prometheus 采集数据频率

Prometheus 采集数据的频率是可以配置的,我们一般配置为 1 分钟采集一次。 也就是说,每隔 1 分钟,Prometheus 才会从 exporter 拉取一次数据,然后存储起来。

应用指标数据采集

对于我们的应用而言,往往也有一些指标可以帮助我们看到应用内部的状态,比如:应用内的线程数、应用占用的内存、应用的 QPS 等等。 但是对于应用指标的监控,并没有一个统一的标准,我们需要根据自己应用的实际情况来决定采集哪些指标。

我们的消息推送系统如何做监控

应用指标

对于我们的消息推送系统而言,目前采集以下这两个重要指即可:

  1. 连接数:可以了解服务器当前负载

连接数我们可以直接通过 len(hub.clients) 来获取,非常简单。

  1. 等待推送的消息数:可以了解服务器能否及时处理消息

我们可以在 Hub 中添加一个 pending atomic.Int64 字段来记录当前等待推送的消息数,然后在 send 方法中进行更新:

1
2
3
4
func send(hub *Hub, w http.ResponseWriter, r *http.Request) {
// ... 其他代码 ...
hub.pending.Add(1)
}

同时在处理完成之后,我们也需要将其减 1,所以 writePump 也需要进行修改:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
func (c *Client) writePump() {
for {
select {
case messageLog, ok := <-c.send:
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
if !ok {
// ...
c.hub.pending.Add(int64(-1 * len(c.send)))
return
}

if err := c.conn.WriteMessage(websocket.TextMessage, []byte(messageLog.Message)); err != nil {
// ...
c.hub.pending.Add(int64(-1 * len(c.send)))
return
}
}
c.hub.pending.Add(int64(-1))
}
}

我们在 writePump 中有三个地方需要对 pending 字段做减法:连接关闭、发送出错、发送成功。

exporter 以及 Grafana 配置

现在我们知道了我们有两个比较关键的指标需要采集,那到底是如何采集的呢?

具体来说,会有以下两步:

  1. 在消息推送系统中添加一个 /metrics 接口

这个接口的作用就是将我们的指标数据暴露出来,以便 Prometheus 采集。 它返回的就是请求时的连接数和等待推送的消息数,返回的格式也有一定要求,但也不复杂,具体来说就是:

  • 一行一个指标
  • 可以返回多个指标,多行即可
  • 每个指标前一行指定其类型(TYPE
  • 每行的格式为:<指标名称>{<标签名称>=<标签值>, ...} <指标值>

下面是一个简单的例子:

1
2
3
4
# HELP http_requests_total The total number of HTTP requests.
# TYPE http_requests_total counter
http_requests_total{method="GET", endpoint="/api"} 100
http_requests_total{method="POST", endpoint="/api"} 50

在这个示例中:

  • http_requests_total 是指标名称
  • {method="GET", endpoint="/api"} 是标签集合,用于唯一标识两个不同的时间序列。
  • 10050 是样本值,表示在特定时间点上的 HTTP 请求总数。

最终,我们得到了一个如下的 /metrics 接口:

1
2
3
4
5
6
func metrics(hub *Hub, w http.ResponseWriter, r *http.Request) {
var pending = hub.pending.Load()
var connections = len(hub.clients)
w.Write([]byte(fmt.Sprintf("# HELP connections 连接数\n# TYPE connections gauge\nconnections %d\n", connections)))
w.Write([]byte(fmt.Sprintf("# HELP pending 等待发送的消息数量\n# TYPE pending gauge\npending %d\n", pending)))
}

不要忘记了在 main 中加上一个入口:

1
2
3
http.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) {
metrics(hub, w, r)
})

最终,这个接口会返回如下的数据:

1
2
3
4
5
6
# HELP connections 连接数
# TYPE connections gauge
connections 0
# HELP pending 等待发送的消息数量
# TYPE pending gauge
pending 0
  1. Prometheus 中配置 exporter

我们需要在 Prometheus 配置文件中加上以下配置:

1
2
3
4
5
scrape_configs:
# 拉取我们的应用指标
- job_name: 'websocket'
static_configs:
- targets: ['192.168.2.107:8181']

注意:这里不需要在后面加上 /metrics,因为 Prometheus 默认就是去拉取 /metrics 接口的。

web UI

然后我们就可以在 Prometheusweb UI 中看到我们的指标数据了。

  1. Grafana 中配置图表

最后,我们可以在 Grafana 中配置一个图表,来展示我们的指标数据:

Grafana

这样,我们就可以看到一个等待发送的消息数量以及连接数的变化了。

总结

最后,再来简单回顾一下本文所讲内容,主要包括以下几个方面:

  • 认证方式是另外一个比较复杂的话题,但是我们依然可以为其预留出一个扩展点,先实现其他功能后再来完善。
  • 目前市面上有很多监控相关的组件,本文使用了 Prometheus 作为例子来演示如何在项目中采集应用的指标数据,以及如何通过 Grafana 来展示这些指标的变化。
  • Prometheus 中包含了 `Prometheus Serverexporters 等组件,其中 Server 是实际存储数据的地方,而 exporters 是用来采集指标数据的程序。
  • Prometheus 采集到的数据,我们可以通过 Grafana 来进行可视化展示,更加的直观。
  • 应用中,也可以暴露一个 /metrics 端口来返回应用当前的一些状态,只要遵循 Prometheus 的规范即可。

从上一篇开始,好像我们已经脱离了 WebSocket 的技术范畴了,但是我们可能也意识到了,WebSocket 技术本身并不复杂, 我们也很容易地使用它实现了一个消息推送的雏形。复杂的是,早我们使用它来实现一些功能的时候,需要考虑的非技术性的问题, 或者说非功能性的需求。

蔡超的《十年架构感悟》里面提到过一点:非功能性需求决定架构(在极客时间上可以搜索到)。

非功能性需求包括性能、伸缩性、可扩展性、可维护性等。功能性需求就是我们实际要实现的功能。

大概意思是:一个好的架构其实是由非功能性需求决定的,而不是由功能性需求决定的。 架构设计完之后,少一个功能性需求,我们很容易就能看出来,未来也可以加上去,它对你的架构不会有本质上的影响。 但如果我们忽略的是某一种非功能性需求,那么未来这可以说是一种灾难性的麻烦,很有可能你就需要重写了。 比如你架构中的数据一致性问题无法解决,或者在设计的时候没有充分考虑性能问题,这样,所有的功能性的实现其实都没有意义。

接下来做什么

其实我们在上一篇就可以结束本系列文章了,因为从某种程度上,我们已经实现了一个消息推送中心了。 但是,这种粗制滥造的方式,在真正投入使用的时候会存在很多问题的,比如:

  1. 对于消息投递,我们没有任何的记录:无法知道消息是否投递成功,也不知道消息投递失败的原因
  2. 接入麻烦:上一节我们通过 jwt 来实现认证,但是这个 jwt token 的生成和验证都是在消息推送系统中实现的;经验告诉我们,但凡你的东西复杂一点,用户都没有使用的欲望了,人性毕竟都是懒惰的
  3. 并未考虑到用户 token 失效的问题:比如用户登出系统之后,我们的消息推送系统也得断开是吧,要不然我都登出了你还给我推送消息
  4. 系统内部指标数据完全没有:比如连接数、等待连接数、等待推送的消息数等,这样如果有性能问题就不好排查了
  5. 其他:性能、伸缩性、可扩展性都存在问题

本系列文章的最终目的是要实现一个生产可用的消息推送中心,因此会继续实现这些非功能性需求。

添加消息推送日志

需求

我们的消息推送系统,需要记录每一条消息的投递情况,包括投递成功、投递失败的原因等。 一方面是为了方便排查问题,另一方面也是为了了解系统是否正常运作。 当然这些日志不会长时间保留,具体保留多长时间,我们可以加个配置留给用户决定即可。

依赖倒置原则

虽然暂时还没有实现让整个系统具有较高的扩展性,但是我们可以在代码上先让代码具有扩展性, 这样在未来我们要扩展的时候,就不需要改动太多的代码了。

我们可以先思考一下,我们下面的推送消息代码,应该如何修改来实现上述需求(假设我们的消息要存入数据库):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
func send(hub *Hub, w http.ResponseWriter, r *http.Request) {
uid := r.FormValue("uid")
// 参数错误
if uid == "" {
w.WriteHeader(http.StatusBadRequest)
return
}

// 从 hub 中获取 client
hub.Lock()
client, ok := hub.userClients[uid]
hub.Unlock()
// 尚未建立连接
if !ok {
w.WriteHeader(http.StatusBadRequest)
return
}

// 发送消息
message := r.FormValue("message")
client.send <- []byte(message)
}

func (c *Client) writePump() {
defer func() {
_ = c.conn.Close()
}()
for {
select {
case message, ok := <-c.send:
// 设置写超时时间
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
// c.send 这个通道已经被关闭了
if !ok {
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}

if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
return
}
}
}
}

我们可以暂时不考虑上面代码的实现,只是思考一下,如果我们要实现上述需求,应该如何修改代码呢?

非常容易想到的一种方法就是,在 init 函数中初始化一个全局的数据库连接, 然后在 send 方法中使用这个连接将消息存入数据库(假设我们使用的是 gorm):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
var db *gorm.DB

type Log struct {
gorm.Model
Uid string
Message string
Status int
CreatedAt time.Time
}

func init() {
var err error
db, err = gorm.Open(sqlite.Open("log.db"), &gorm.Config{})
if err != nil {
panic(err)
}
}

然后发送消息前写入数据库:

1
2
3
4
5
6
7
8
9
10
11
12
13
// 自动迁移:表不存在的时候会自动创建
db.AutoMigrate(&Log{})
// 写入日志
db.Create(&Log{
Uid: uid,
Message: r.FormValue("message"),
Status: 0,
CreatedAt: time.Now(),
})

// 发送消息
message := r.FormValue("message")
client.send <- []byte(message)

这样实现起来确实简单,但是这样的代码耦合度太高了, 高层模块依赖了底层模块,依赖于具体的实现,这样的代码是不具有扩展性的

一种更好的方式是:针对写日志这个功能,我们先建立起一个抽象模型,然后高层代码只使用这个模型,不用去考虑底层的实现。

这一点就是 SOLID 里面的 D,依赖倒置原则(Dependency Inversion Principle)。 依赖倒置原则是这样陈述的:高层模块不应依赖于低层模块,二者应依赖于抽象。抽象不应依赖于细节,细节依赖于抽象。

基于依赖倒置原则的具体实现

  1. 先建立起一个抽象模型

首先我们得有一个实体来表示消息本身(MessageLog),然后就是记录消息的抽象模型(MessageLogger):

1
2
3
4
5
6
7
8
type MessageLog struct {
Uid string
Message string
}

type MessageLogger interface {
Log(log MessageLog) error
}
  1. 实现这个抽象模型

我们依然是使用 gorm 来实现这个抽象模型:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
package main

import (
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"time"
)

var db *gorm.DB

type Log struct {
gorm.Model
Uid string
Message string
Status int
CreatedAt time.Time
}

func init() {
var err error
db, err = gorm.Open(sqlite.Open("log.db"), &gorm.Config{})
if err != nil {
panic(err)
}
}

var _ MessageLogger = &MySQLMessageLogger{}

type MySQLMessageLogger struct {
}

func (m *MySQLMessageLogger) Log(log MessageLog) error {
db.AutoMigrate(&Log{})
db.Create(&Log{
Uid: log.Uid,
Message: log.Message,
Status: 0,
CreatedAt: time.Now(),
})
return nil
}

虽然我们代码跟之前依然是一样,但是我们的代码已经具有了扩展性。

  1. 高层代码使用这个抽象模型

依赖倒置原则中说了,高层模块不应该依赖于低层模块。因此我们在 send 方法中记录消息的时候, 不应该直接使用 gorm 来写入数据库,而是使用 MessageLogger 这个抽象模型:

  1. hub 中添加 MessageLogger 字段:
1
2
3
4
type Hub struct {
// 消息日志记录器
messageLogger MessageLogger
}
  1. newHub 函数中初始化 MessageLogger
1
2
3
4
5
6
func newHub() *Hub {
return &Hub{
// ... 其他字段
messageLogger: &MySQLMessageLogger{},
}
}

虽然高层模块不能直接依赖底层实现,但是总会有一个地方是将高层和底层连接起来的,这个地方一般就是创建对象的地方, 在很多现代的框架中,它有另外一个名字:依赖注入容器。

而在本系列文章中,并没有用到什么框架、依赖注入容器,但是我们还是有一个专门的创建对象的地方,那就是 newHub 函数。 因此我们在这里将 MessageLogger 依赖注入到 Hub 中。

  1. send 方法中使用 MessageLogger

最后将原本 send 方法中的数据库操作代码替换为对抽象模型的调用即可:

1
2
messageLog := MessageLog{Uid: uid, Message: r.FormValue("message")}
_ = hub.messageLogger.Log(messageLog)

这样,我们就完成了对消息推送日志的记录。

那如何替换为另一种日志记录方式

我们现在知道了,依赖倒置原则可以指导我们设计出具有扩展性的代码,那在我们这个实例中,如何替换为另一种日志记录方式呢?

其实非常简单,比如我们现在要直接输出到控制台中,那么我们只需要实现一个 StdoutMessageLogger 即可:

1
2
3
4
5
6
7
8
9
10
var _ MessageLogger = &StdoutMessageLogger{}

type StdoutMessageLogger struct {
}

func (s *StdoutMessageLogger) Log(log MessageLog) error {
res, _ := json.Marshal(log)
fmt.Println("send message: " + string(res))
return nil
}

然后在 newHub 中将 messageLogger 替换为 &StdoutMessageLogger{} 即可:

1
2
3
4
5
6
func newHub() *Hub {
return &Hub{
// ... 其他字段
messageLogger: &StdoutMessageLogger{},
}
}

这样,我们在发送消息的时候就可以直接在控制台中看到消息了。 在实际开发中,使用 StdoutMessageLogger 更加方便我们调试代码。

我们可以发现,我们这种设计方式完美地实现了开闭原则,我们添加新的日志记录方式的时候, 不需要修改太多代码,只需要添加新的实现,然后修改 newHub 方法中的一行代码即可, 这样的代码显然更具扩展性,也更好维护。

错误处理

对于消息推送,如果推送失败,我们一般也需要知道推送失败的原因。

同样的,我们的框架本身也不应该依赖于具体的错误处理程序,而是应该使用抽象模型来实现。 从这个原则出发,我们就可以先建立一个抽象模型,然后再实现这个抽象模型:

  1. 先建立起一个抽象模型
1
2
3
4
5
6
7
// Handler 错误处理类型
type Handler func(log message.Log, err error)

type Hub struct {
// 错误处理器
errorHandler Handler
}

因为错误处理本身没有太复杂的功能,因此我们直接使用 type 关键字将其定义为一个函数类型即可。 然后在 Hub 中加上错误处理器的字段 errorHandler

  1. 实现这个抽象模型

其实也谈不上实现,因为没有定义什么 interface,我们只需要定义一个函数即可:

1
2
3
4
func defaultErrorHandler(log message.Log, err error) {
res, _ := json.Marshal(log)
fmt.Printf("send message: %s, error: %s\n", string(res), err.Error())
}

在本文的例子中,我们先定义一个输出错误信息到控制台的错误处理器。 然后,我们需要在 newHub 中初始化这个错误处理器:

1
2
3
4
5
6
func newHub() *Hub {
return &Hub{
// ... 其他字段
errorHandler: defaultErrorHandler,
}
}
  1. 高层代码使用这个抽象模型

为了方便后续处理,我们将 send 方法中的代码稍微修改了一下,将 messageLog 作为参数传入到 send 通道中了,同时将 clientsend 通道改为 chan message.Log

1
2
3
4
type Client struct {
// 接受消息的通道
send chan message2.Log
}

发送消息修改:

1
2
3
4
5
messageLog := message.Log{Uid: uid, Message: r.FormValue("message")}
_ = hub.messageLogger.Log(messageLog)

// 发送消息
client.send <- messageLog

writePump 修改:

1
2
3
if err := c.conn.WriteMessage(websocket.TextMessage, []byte(messageLog.Message)); err != nil {
return
}

最终 writePump 会演化为下面这样,错误处理:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
for {
select {
case messageLog, ok := <-c.send:
// 设置写超时时间
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
// c.send 这个通道已经被关闭了
if !ok {
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
c.hub.errorHandler(messageLog, fmt.Errorf("send channel closed"))
return
}

if err := c.conn.WriteMessage(websocket.TextMessage, []byte(messageLog.Message)); err != nil {
c.hub.errorHandler(messageLog, err)
return
}
}
}

跟之前不一样的地方是,这里会使用 c.hub.errorHandler 进行错误处理。

最终的效果是,对于后续维护而言,核心的处理流程基本上不会变动,而可能需要我们修改的地方都已经被抽象出来了: 错误处理我们可以通过修改 errorHandler 来实现,日志记录我们可以通过修改 messageLogger 来实现。

当然在实际场景中,我们可能还会有类似 onOpenonClose 之类的需求,但本文就先到此为止了,这些都是可以通过类似的方式来实现的。

总结

本人文章可能文字会比较多,但是其中都是个人在此过程中的一些思考,相比直接告诉大家怎么做,有可能知道为什么这么做更重要。

最后,简单回顾一下本文的内容:

  • 消息推送这个功能,技术上其实我们已经实现了,但是我们还得考虑很多非功能性的需求,这些非功能性的需求决定了我们的架构。
  • 依赖倒置原则可以指导我们设计出具有扩展性的代码:本文中的日志记录抽象出了一个 MessageLogger,需要的时候我们可以自行实现然后替换掉框架提供的实现。
  • 错误处理:为了方便后续维护,处理处理我们也是抽象出了一个 func 类型,实现了关注点的分离,也在一定程度上给后续的扩展提供了可能。

在上一篇文章中,我们已经搭建起了基本可用的一个 WebSocket 推送中心,但是有一个比较大的问题是, 我们并没有对进行连接的客户端进行认证,这样就会有一定的风险,如果被恶意攻击, 可能会影响我们的 WebSocket 服务器的正常运作。

本文我们就来把认证这个很关键的功能给补一下,在本文中,我们将会使用 jwt 来对我们的客户端进行认证。

什么是 jwt?

JWTJSON Web Token 的缩写,是一种用于在网络中安全传递信息的开放标准。它是一种紧凑且自包含的方式,用于在各方之间传递信息,通常用于身份验证和授权机制。

JWT 主要由三个部分组成:

  1. 头部(Header): 包含关于令牌的元数据,例如令牌的类型(typ)和签名算法(alg)。头部是一个 JSON 对象,通常会经过 Base64 编码。

  2. 载荷(Payload): 包含要传递的信息,通常包括用户身份信息以及其他声明。载荷也是一个 JSON 对象,同样经过 Base64 编码。

  3. 签名(Signature): 使用头部和载荷以及密钥生成的签名,用于验证令牌的真实性和完整性。签名是对头部和载荷的哈希值进行签名后的结果。

这三个部分通过点号(.)连接起来形成一个字符串,即 JWT。最终的 JWT 结构如下:

1
header.payload.signature

一个简单的 jwt 例子

jwt 的使用会分为两部分:生成 token,使用 token。本文中将会使用 golang-jwt/jwt 来做 jwt 的验证。

生成 token

生成 token 的操作有以下两步:

  1. 创建一个 token 对象:使用的是 jwt.NewWithClaims 方法,它第一个参数指定了签名算法,这里使用的是 HMAC,第二个参数接受一个 jwt.MapClaims,也就是上面提到的 payload
1
2
3
4
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"foo": "bar",
"nbf": time.Date(2015, 10, 10, 12, 0, 0, 0, time.UTC).Unix(),
})
  1. 利用上一步的 token 对象生成成一个 jwt 的签名字符串:使用的是 tokenSignedString 方法,它接受一个 key 作为参数,这个 key 也会用于解析这一步生成的 token 字符串。
1
2
3
// 注意:这里的 secret 在实际使用的时候需要替换为自己的 key
//(一般为一个随机字符串)
tokenString, err := token.SignedString([]byte("secret"))

我们生成 token 的操作一般发生在用户登录成功之后,这个 token 会作为用户后续发起请求的凭证。

使用 token

在用户登录成功拿到 token 之后,会使用这个 token 去从服务器获取其他资源,服务器会解析这个 token 并校验。

使用 token 的步骤如下:

  1. 解析 token:使用的是 jwt.Parse 方法,它第一个参数接受 token 字符串,第二个参数是一个函数(函数的参数就是解析出来的 token 对象,函数返回解密的 key
1
2
3
4
token, err = jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// 返回签名密钥,需要是 []byte 类型的
return []byte("secret"), nil
})
  1. 使用 token 中的 payloadpayload 也就是我们业务实际使用的数据,其他的东西只是使用 jwt token 这门技术附带的一些东西。

这里说的 payload 实际上就是 token 对象的 Claims 属性,它是一个 map,保存了我们一些业务的数据还有其他一些 jwt 本身的字段。

1
2
// claims: map[foo:bar nbf:1.4444784e+09]
fmt.Println("claims:", token.Claims)

在上面这行代码中,foo:bar 是我们的业务数据,而 nbf:1.4444784e+09jwt 本身使用的字段。nbf 表示的是在这个时间之前,这个 token 不应该被处理。

例子

我们把上面的生成 token 的代码和解析 token 的代码放到一起看看效果:

注意:实际使用中这两个操作是分开的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
func TestHmac(t *testing.T) {
// part 1:
// 创建一个新的 token 对象,指定签名方法和你想要包含的 claims
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"foo": "bar",
"nbf": time.Date(2015, 10, 10, 12, 0, 0, 0, time.UTC).Unix(),
})

// 生成签名字符串,secret 为签名密钥
tokenString, err := token.SignedString([]byte("secret"))

// eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJuYmYiOjE3Mjg1NjE2MDB9.9AxzBmYOuOnWfKUul57ATzjQ-sMzbggaoIdDjVzjm2Y, nil
fmt.Println(tokenString, err)

// part 2:
// 解析 token
token, err = jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}

// 返回签名密钥
return []byte("secret"), nil
})
if err != nil {
panic(err)
}

// header: map[alg:HS256 typ:JWT]
fmt.Println("header:", token.Header)
// claims: map[foo:bar nbf:1.4444784e+09]
fmt.Println("claims:", token.Claims)
// signature: Nv24hvNy238QMrpHvYw-BxyCp00jbsTqjVgzk81PiYA
fmt.Println("signature:", base64.RawURLEncoding.EncodeToString(token.Signature))

if claims, ok := token.Claims.(jwt.MapClaims); ok {
// bar 1.4444784e+09
fmt.Println(claims["foo"], claims["nbf"])
} else {
panic(err)
}
}

本文不是讲解 JWT 的文章,关于 JWT 的更多细节可以参考 rfc7519

在消息推送中心 demo 加上 jwt 认证

我们使用 jwt 的目的是为了杜绝一些未知来源的连接,而在我们上一篇文章的实现中, 是先建立起连接,然后再进行 "登录" 操作的,这样就会导致即使是未授权的客户端也可以先进行连接, 这样就会在认证之前就启动了两个协程。

而如果这些连接并不是正常的连接,它们只连接但是不登录,那样就会有很多僵尸连接,这显然不是我们想要的结果。

我们可以在客户端打开连接的时候就去验证客户端的 token,如果 token 校验失败,则直接断开连接, 就可以解决上述问题了,从而避免不必要的开销。

如何在建立连接的时候就认证

要实现这个很简单,我们只需要在连接的 url 后加一个 queryString 即可,如下:

1
ws = new WebSocket('ws://127.0.0.1:8181/ws?token=123')

然后在 serveWs 中通过 r.FormValue("token") 来获取客户端传递过来的 token, 再对其进行认证,认证不通过则拒绝连接。

具体来说,我们的 serveWs 会添加以下几行代码:

1
2
3
4
5
6
jwt := NewJwt(r.FormValue("token"))
err := jwt.Parse()
if err != nil {
log.Println(fmt.Errorf("jwt parse error: %w", err))
return
}

在函数入口的地方就进行验证,验证不通过则不进行连接。这里的 jwt 定义如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
type Jwt struct {
Secret string
Token string
}

func NewJwt(token string) *Jwt {
return &Jwt{
Token: token,
Secret: os.Getenv("JWT_SECRET"),
}
}

func (j *Jwt) Parse() error {
_, err := jwt.Parse(j.Token, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}

return []byte(j.Secret), nil
})

return err
}

NewJwt 中,我们从 env 中获取 JWT_SECRET,这让我们的配置可以更加灵活。

token 中加上 uid

我们知道了,在 JWT 中的 payload 是可以加入我们的自定义数据的,所以我们的 uid 其实是可以加入到 jwttoken 中的,我们只需要在用户第一次获取 token 的时候加上即可:

1
2
3
4
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"uid": "123",
"nbf": time.Date(2015, 10, 10, 12, 0, 0, 0, time.UTC).Unix(),
})

同样的,在解析 jwt token 的时候,可以从中取出这个 uid

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// 第一个返回值是 uid,第二个是 error
func (j *Jwt) Parse() (string, error) {
token, err := jwt.Parse(j.Token, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}

return []byte(j.Secret), nil
})

if err != nil {
return "", err
}

// 获取 uid
if claims, ok := token.Claims.(jwt.MapClaims); ok {
return claims["uid"].(string), nil
} else {
return "", fmt.Errorf("jwt parse error")
}
}

最终,我们的 serveWs 演化成了如下这样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) {
// 解析 jwt token,从中取得 uid
jwt := NewJwt(r.FormValue("token"))
uid, err := jwt.Parse()
if err != nil {
log.Println(fmt.Errorf("jwt parse error: %w", err))
return
}

conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println(fmt.Errorf("upgrade error: %w", err))
return
}

client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256), uid: uid}
client.hub.register <- client

go client.writePump()
go client.readPump()
}

这样一来,我们的 readPump 里面就不再需要处理登录消息了,那就暂时先把 readPump 中的逻辑去掉先:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
func (c *Client) readPump() {
defer func() {
c.hub.unregister <- c
_ = c.conn.Close()
}()
c.conn.SetReadLimit(maxMessageSize)
c.conn.SetReadDeadline(time.Time{}) // 永不超时
for {
// 从客户端接收消息
_, _, err := c.conn.ReadMessage()
if err != nil {
log.Println("readPump error: ", err)
break
}
}
}

register 里的 uid 关联

我们之前是在 readPump 方法中将 uidWebSocket 建立起关联的,但由于我们已经去掉了 readPump 中 的登录消息处理逻辑。

因此我们需要在 Hubregister 中将 uidWebSocket 建立起关联:

1
2
3
4
5
6
case client := <-h.register:
h.Lock()
h.clients[client] = true
// 建立起 uid 跟 WebSocket 的关联
h.userClients[client.uid] = client
h.Unlock()

jti

在此前使用过的一些 jwt 封装中,会有些使用 jwt 规范中的 jti 字段来传输 token 的唯一 ID, 在本文的实现中,uid 也是同样的功能。如果我们之后看到了 jti 这个字段,不要惊讶,其实这个才是规范。

测试

我们将生成 token 的代码中的 foo: bar 这个键值对修改为 uid: 123 之后,得到了如下 token

1
eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYmYiOjE0NDQ0Nzg0MDAsInVpZCI6IjEyMyJ9.9yJ-ABQGJkdnDqHo-wV-vojQFEQGt-I0dyva1w6EQ7E

我们在 ws 链接后加上这个 token 即可连接成功:

1
ws = new WebSocket('ws://127.0.0.1:8181/ws?token=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJuYmYiOjE0NDQ0Nzg0MDAsInVpZCI6IjEyMyJ9.9yJ-ABQGJkdnDqHo-wV-vojQFEQGt-I0dyva1w6EQ7E')
test

我们可以看到,ws 连接返回的状态码是 101,这就表明我们的连接成功了。

推送消息

同之前一样的,在控制台执行一个 curl 命令即可推送消息了:

1
curl "http://localhost:8181/send?uid=123&message=Hello%20World"

总结

本文我们使用 jwt 对我们的 WebSocket 进行了认证。 jwt 认证我们可以已经用过很多次了,但如果还没有用过或者还没有在 go 中用过 jwt 认证的话, 那么本文就是一个很好的入门文章。

最后,再来回顾一下本文的内容:

  • jwt 包含了三部分:头部(Header)、载荷(Payload)、签名(Signature)。
  • 在 go 中可以使用 golang-jwt/jwt 来做 jwt 的验证。
  • 创建 token 分两步:使用 jwt.NewWithClaims 创建 token 对象、使用 SignedString 生成签名字符串。
  • 使用 token 分两步:使用 jwt.Parse 解析 token、使用 token.Claims 获取 payload
  • 在建立连接的时候就进行认证,可以避免非法连接导致的开销。
  • jwt 中可以加入我们的自定义数据,比如 uid,在 jwt.NewWithClaims 中加上即可。

有了前两篇的铺垫,相信大家已经对 Golang 中 WebSocket 的使用有一定的了解了, 今天我们以一个更加真实的例子来学习如何在 Golang 中使用 WebSocket

需求背景

在实际的项目中,往往有一些任务耗时比较长,然后我们会把这些任务做异步的处理,但是又要及时给客户端反馈任务的处理进度。

对于这种场景,我们可以使用 WebSocket 来实现。其他可以使用 WebSocket 进行通知的场景还有像管理后台一些通知(比如新订单通知)等。

在本篇文章中,就是要实现一个这样的消息推送系统,具体来说,它会有以下功能:

  1. 可以给特定的用户推送:建立连接的时候,就建立起 WebSocket 连接与用户 ID 之间的关联
  2. 断开连接的时候,移除 WebSocket 连接与用户的关联,并且关闭这个 WebSocket 连接
  3. 业务系统可以通过 HTTP 接口来给特定的用户推送 WebSocket 消息:只要传递用户 ID 以及需要推送的消息即可

基础框架

下面是一个最简单版本的框架图:

arch

它包含如下几个角色:

  1. Client 客户端,也就是实际中接收消息通知的浏览器
  2. Server 服务端,在我们的例子中,服务端实际不处理业务逻辑,只处理跟客户端的消息交互:维持 WebSocket 连接,推送消息到特定的 WebSocket 连接
  3. 业务逻辑:这个实际上不属于 demo 的一部分,但是 Server 推送的数据是来自业务逻辑处理的结果

设计成这样的目的是为了将技术跟业务进行分离,业务逻辑上的变化不影响到底层技术,同样的,WebSocket 推送中心的技术上的变动也不会影响到实际的业务。

开始开发

一些结构体变动

  1. Client 结构体的变化
1
2
3
4
5
6
7
type Client struct {
hub *Hub
conn *websocket.Conn
send chan []byte
// 新增字段
uid int
}

因为我们需要建立起 WebSocket 连接与用户之间的关联,因此我们需要一个额外的字段来记录用户 ID,也就是上面的 uid 字段。

这个字段会在客户端建立连接后写入。

  1. Hub 结构体的变化
1
2
3
4
5
6
7
8
9
10
11
type Hub struct {
clients map[*Client]bool
register chan *Client
unregister chan *Client

// 记录 uid 跟 client 的对应关系
userClients map[int]*Client

// 读写锁,保护 userClients 以及 clients 的读写
sync.RWMutex
}
  1. 因为我们不再需要做广播,所以会移除 Hub 中的 broadcast 字段。

取而代之的是,我们会直接在消息推送接口中写入到 uid 对应的 Clientsend 通道。 当然我们也可以在 Hub 中另外加一个字段来记录要推送给不同 uid 的消息,但是我们的 Hubrun 方法是一个协程处理的,当需要推送的数据较多或者其中有 网络延迟的时候,会直接影响到推送给其他用户的消息。当然我们也可以改造一下 run 方法,启动多个协程来处理,不过这样比较复杂,本文会在 writePump 中处理。 (也就是建立 WebSocket 连接时的那个写操作协程)

  1. 同时为了更加快速地通过 uid 来获取对应的 WebSocket 连接,新增了一个 userClients 字段。

这是一个 map 类型的字段,keyuid,值是对应的 Client 指针。

  1. 最后新增了一个 Mutex 互斥锁

因为,在用户实际进行登录的时候需要写入 userClients 字段,而这是一个 map 类型字段,并不支持并发读写。 如果我们在接受并发连接的时候同时修改 userClients 的时候会导致 panic,因此我们使用了一个互斥锁来保证 userClients 的读写安全。

同时,clients 也是一个 map,但上一篇文章中没有使用 sync.Mutex 来保护它的读写,在并发操作的时候也是会有问题的, 所以 Mutex 同时也需要保护 clients 的读写。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
func (h *Hub) run() {
for {
select {
case client := <-h.register:
h.Lock()
h.clients[client] = true
h.Unlock()
case client := <-h.unregister:
if _, ok := h.clients[client]; ok {
h.Lock()
delete(h.userClients, client.uid)
delete(h.clients, client)
h.Unlock()
close(client.send)
}
}
}
}

最后,我们会在 Hubrun 方法中写 userClients 或者 clients 字段的时候,先获取锁,写成功的时候释放锁。

建立连接

在本篇中,将会继续沿用上一篇的代码,只是其中一些细节会有所改动。建立连接这步操作,跟上一篇的一样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
// 将 HTTP 转换为 WebSocket 连接的 Upgrader
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}

// 处理 WebSocket 连接请求
func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) {
// 升级为 WebSocket 连接
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println(err)
return
}
// 新建一个 Client
client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)}
// 注册到 Hub
client.hub.register <- client

// 推送消息的协程
go client.writePump()
// 结束消息的协程
go client.readPump()
}

接收消息

由于我们要做的只是一个推送消息的系统,所以我们只处理用户发来的登录请求,其他的消息会全部丢弃:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
func (c *Client) readPump() {
defer func() {
c.hub.unregister <- c
_ = c.conn.Close()
}()
c.conn.SetReadLimit(maxMessageSize)
c.conn.SetReadDeadline(time.Time{}) // 永不超时
for {
// 从客户端接收消息
_, message, err := c.conn.ReadMessage()
if err != nil {
log.Println("readPump error: ", err)
break
}

// 只处理登录消息
var data = make(map[string]string)
err = json.Unmarshal(message, &data)
if err != nil {
break
}

// 写入 uid 以及 Hub 的 userClients
if uid, ok := data["uid"]; ok {
c.uid = uid
c.hub.Lock()
c.hub.userClients[uid] = c
c.hub.Unlock()
}
}
}

在本文中,假设客户端的登录消息格式为 {"uid": "123456"} 这种 json 格式。

在这里也操作了 userClients 字段,同样需要使用互斥锁来保证操作的安全性。

发送消息

  1. 在我们的系统中,可以提供一个 HTTP 接口来跟业务系统进行交互:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
// 发送消息的接口
// 参数:
// 1. uid:接收消息的用户 ID
// 2. message:需要发送给这个用户的消息
http.HandleFunc("/send", func(w http.ResponseWriter, r *http.Request) {
send(hub, w, r)
})

// 发送消息的方法
func send(hub *Hub, w http.ResponseWriter, r *http.Request) {
uid := r.FormValue("uid")
// 参数错误
if uid == "" {
w.WriteHeader(http.StatusBadRequest)
return
}

// 从 hub 中获取 client
hub.Lock()
client, ok := hub.userClients[uid]
hub.Unlock()
// 尚未建立连接
if !ok {
w.WriteHeader(http.StatusBadRequest)
return
}

// 发送消息
message := r.FormValue("message")
client.send <- []byte(message)
}
  1. 实际发送消息的操作

writePump 方法中,我们会将从 /send 接收到的数据发送给对应的用户:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
// 发送消息的协程
func (c *Client) writePump() {
defer func() {
_ = c.conn.Close()
}()
for {
select {
case message, ok := <-c.send:
// 设置写超时时间
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
// 连接已经被关闭了
if !ok {
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}

// 获取一个发送消息的 Writer
w, err := c.conn.NextWriter(websocket.TextMessage)
if err != nil {
return
}
// 写入消息到 Writer
w.Write(message)

// 关闭 Writer
if err := w.Close(); err != nil {
return
}
}
}
}

在这个方法中,我们会从 c.send 这个 chan 中获取需要发送给客户端的消息,然后进行发送操作。

测试

  1. 启动 main 程序
1
go run main.go
  1. 打开一个浏览器的控制台,执行以下代码
1
2
ws = new WebSocket('ws://127.0.0.1:8181/ws')
ws.send('{"uid": "123"}')

这两行代码的作用是与 WebSocket 服务器建立连接,然后发送一个登录信息。

然后我们打开控制台的 Network -> WS -> Message 就可以看到浏览器发给服务端的消息:

login
  1. 使用 HTTP 客户端发送消息给 uid 为 123 的用户

假设我们的 WebSocket 服务器绑定的端口为 8181

打开终端,执行以下命令:

1
curl "http://localhost:8181/send?uid=123&message=Hello%20World"

然后我们可以在 Network -> WS -> Message 看到接收到了消息 Hello World

hello world

结束了

到此为止,我们已经实现了一个初步可工作的 WebSocket 应用,当然还有很多可以优化的地方, 比如:

  1. 错误处理
  2. Hub 状态目前对外部来说是一个黑盒子,我们可以加个接口返回一下 Hub 的当前状态,比如当前连接数
  3. 日志:出错的时候,日志可以帮助我们快速定位问题

这些功能会在后续继续完善,今天就到此为止了。