Golang 搭建 WebSocket 应用(八) - 完整代码

本文应该是本系列文章最后一篇了,前面留下的一些坑可能后面会再补充一下,但不在本系列文章中了。

整体架构

再来回顾一下我们的整体架构:

arch

在我们的 demo 中,包含了以下几种角色:

  1. 客户端:一般是浏览器,用于接收消息;
  2. Hub:消息中心,用于管理所有的客户端连接,以及将消息推送给客户端;
  3. 调用 /send 发送消息的应用:用于将消息发送给 Hub,然后由 Hub 将消息推送给客户端。

然后,每一个 WebSocket 连接都有一个关联的读协程和写协程, 用于读取客户端发送的消息,以及将消息推送给客户端。

目录结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
├── LICENSE  // 协议
├── Makefile // 一些常用的命令
├── README.md
├── authenticator.go // 认证器
├── authenticator_test.go // 认证器测试
├── bytes.go // 字符串和 []byte 之间转换的辅助方法
├── client.go // WebSocket 客户端
├── go.mod // 项目依赖
├── go.sum // 项目依赖
├── hub.go // 消息中心
├── main.go // 程序入口
├── message // 消息记录器
│   ├── db_logger.go
│   ├── db_logger_test.go
│   ├── log.go
│   └── stdout_logger.go
├── server.go // HTTP 服务
└── server_test.go // HTTP 接口的测试

运行

注:需要 Go 1.20 或以上版本

  1. 下载依赖:

可以使用七牛云的代理加速下载。

1
go mod tidy
  1. 启动 WebSocket 服务端:
1
go run main.go

Hub 代码

最终,我们的 Hub 代码演进成了下面这样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
// bufferSize 通道缓冲区、map 初始化大小
const bufferSize = 128

// Handler 错误处理函数
type Handler func(log message.Log, err error)

// Hub 维护了所有的客户端连接
type Hub struct {
// 注册请求
register chan *Client
// 取消注册请求
unregister chan *Client
// 记录 uid 跟 client 的对应关系
userClients map[string]*Client
// 互斥锁,保护 userClients 以及 clients 的读写
sync.RWMutex
// 消息记录器
messageLogger message.Logger
// 错误处理器
errorHandler Handler
// 验证器
authenticator Authenticator
// 等待发送的消息数量
pending atomic.Int64
}

// 默认的错误处理器
func defaultErrorHandler(log message.Log, err error) {
res, _ := json.Marshal(log)
fmt.Printf("send message: %s, error: %s\n", string(res), err.Error())
}

func newHub() *Hub {
return &Hub{
register: make(chan *Client),
unregister: make(chan *Client),
userClients: make(map[string]*Client, bufferSize),
RWMutex: sync.RWMutex{},
messageLogger: &message.StdoutMessageLogger{},
errorHandler: defaultErrorHandler,
authenticator: &JWTAuthenticator{},
}
}

// 注册、取消注册请求处理
func (h *Hub) run() {
for {
select {
case client := <-h.register:
h.Lock()
h.userClients[client.uid] = client
h.Unlock()
case client := <-h.unregister:
h.Lock()
close(client.send)
delete(h.userClients, client.uid)
h.Unlock()
}
}
}

// 返回 Hub 的当前的关键指标
func metrics(hub *Hub, w http.ResponseWriter) {
pending := hub.pending.Load()
connections := len(hub.userClients)
_, _ = w.Write([]byte(fmt.Sprintf("# HELP connections 连接数\n# TYPE connections gauge\nconnections %d\n", connections)))
_, _ = w.Write([]byte(fmt.Sprintf("# HELP pending 等待发送的消息数量\n# TYPE pending gauge\npending %d\n", pending)))
}

其中:

  • Hub 中的 registerunregister 通道用于处理客户端的注册和取消注册请求;
  • Hub 中的 userClients 用于记录 uidClient 的对应关系;
  • Hub 中的 messageLogger 用于记录消息;
  • Hub 中的 errorHandler 用于处理错误;
  • Hub 中的 authenticator 用于验证客户端的身份;
  • Hub 中的 pending 用于记录等待发送的消息数量。

目前实现存在的问题:

  • registerunregister 通道被消费的时候需要加锁,这样会导致 registerunregister 变成串行的,性能不好;
  • userClients 也是需要加锁的,这样会导致 userClients 的读写也是串行的,性能不好;

对于这两个问题,前面我们讨论过,一种可行的办法分段 map,然后对每一个 map 都有一个对应的 sync.Mutex 互斥锁来保证其读写的安全。

Client 代码

Client 比较关键的方法是:

  • writePump:负责将消息推送给客户端。
  • serveWs:处理 WebSocket 连接请求。
  • send:处理消息发送请求。

writePump

这个方法会从 send 通道中获取消息,然后推送给客户端。 推送失败会调用 errorHandler 处理错误。 推送成功会将 pending 减一。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
// writePump 负责推送消息给 WebSocket 客户端
//
// 该方法在一个独立的协程中运行,我们保证了每个连接只有一个 writer。
// Client 会从 send 请求中获取消息,然后在这个方法中推送给客户端。
func (c *Client) writePump() {
defer func() {
_ = c.conn.Close()
}()

// 从 send 通道中获取消息,然后推送给客户端
for {
messageLog, ok := <-c.send

// 设置写超时时间
_ = c.conn.SetWriteDeadline(time.Now().Add(writeWait))
// c.send 这个通道已经被关闭了
if !ok {
c.hub.pending.Add(int64(-1 * len(c.send)))
return
}

if err := c.conn.WriteMessage(websocket.TextMessage, StringToBytes(messageLog.Message)); err != nil {
c.hub.errorHandler(messageLog, err)
c.hub.pending.Add(int64(-1 * len(c.send)))
return
}

c.hub.pending.Add(int64(-1))
}
}

serveWs

serveWs 方法会处理 WebSocket 连接请求,然后将其注册到 Hub 中。 在连接的时候会对客户端进行认证,认证失败会断开连接。 最后会启动读写协程。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
// serveWs 处理 WebSocket 连接请求
func serveWs(hub *Hub, w http.ResponseWriter, r *http.Request) {
// 升级为 WebSocket 连接
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(fmt.Sprintf("upgrade error: %s", err.Error())))
return
}

// 认证失败的时候,返回错误信息,并断开连接
uid, err := hub.authenticator.Authenticate(r)
if err != nil {
_ = conn.SetWriteDeadline(time.Now().Add(time.Second))
_ = conn.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("authenticate error: %s", err.Error())))
_ = conn.Close()
return
}

// 注册 Client
client := &Client{hub: hub, conn: conn, send: make(chan message.Log, bufferSize), uid: uid}
client.conn.SetCloseHandler(closeHandler)
// register 无缓冲,下面这一行会阻塞,直到 hub.run 中的 <-h.register 语句执行
// 这样可以保证 register 成功之后才会启动读写协程
client.hub.register <- client

// 启动读写协程
go client.writePump()
go client.readPump()
}

send

send 是一个 http 接口,用于处理消息发送请求。 它会从 Hub 中获取 uid 对应的 Client,然后将消息发送给客户端。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
// send 处理消息发送请求
func send(hub *Hub, w http.ResponseWriter, r *http.Request) {
uid := r.FormValue("uid")
if uid == "" {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte("uid is required"))
return
}

// 从 hub 中获取 uid 关联的 client
hub.RLock()
client, ok := hub.userClients[uid]
hub.RUnlock()
if !ok {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(fmt.Sprintf("client not found: %s", uid)))
return
}

// 记录消息
messageLog := message.Log{Uid: uid, Message: r.FormValue("message")}
_ = hub.messageLogger.Log(messageLog)

// 发送消息
client.send <- messageLog

// 增加等待发送的消息数量
hub.pending.Add(int64(1))
}

github

完整代码可以在 github 上进行查看:https://github.com/eleven26/go-pusher