这次第三个红包笔者的考虑是让大家写一个 Modbus TCP Client 来枚举红包密码。大概计算了一下,在支付宝红包的数据规模 (10^9) 下,这个 Server 的实现要具备下列几个特点。

  1. 不同的连接需要有自己的上下文
  2. 同时支持大量连接
  3. 支持高吞吐

所以显然,现有的 Modbus Server 不能满足这个红包的需要。于是只能按照这个需求从头实现一个高性能 Modbus Server 了。

为了体验下新的语言,本次 Server 笔者用 Go 来实现。由于笔者并非专业的 Go 开发者,如果下文内容有疏漏,还请大家批评指正。

0x00:网络库的选型

Go 语言本身其实提供了网络库 net。但是作为一个 C/C++ 开发者,笔者更相信基于 epoll 的 Server 实现相比基于 goroutinenet 更适合多连接的场景。所以笔者预先进行了一番网络库选型。

通过 epoll 这个关键字,笔者找到了字节跳动基础架构团队开发的 netpoll 库。根据其官网给出的性能测试,确实在一些场景的表现比 net 更好。(其实也只是常数优化,哈哈)

netpoll 官方的多路复用echo QPS测试:

多路复用echo QPS测试

netpoll 官方的连接池性能测试,可以看到,并不是所有场景netpoll性能都最优:

连接池性能测试

以及看到它有一个“零拷贝”Buffer操作接口,更是打动了我这个 C/C++ 开发者的心。(当然,这里“零拷贝”并不能解决内核拷用户态的拷贝)

没时间对比到底在笔者这个场景到底是更适合 netpoll 还是 net,反正最后还是凭直觉用 netpoll 来实现本次的 Modbus Server 了。

netpoll 官方文档写了一个server框架:

func main() {
    network, address := "tcp", ":502"
    listener, _ := netpoll.CreateListener(network, address)

    eventLoop, _ := netpoll.NewEventLoop(
        handle,
        netpoll.WithOnPrepare(prepare),
        netpoll.WithReadTimeout(time.Second),
    )

    // start listen loop ...
    eventLoop.Serve(listener)
}

var _ netpoll.OnPrepare = prepare
var _ netpoll.OnRequest = handle
var close_cb netpoll.CloseCallback = close_conn

func prepare(connection netpoll.Connection) context.Context {
    ctx := context.Background()
    connection.AddCloseCallback(close_cb)
    // prepare ctx
    fmt.Printf("Incoming connection from %s\n", connection.RemoteAddr().String())
    return ctx
}

func close_conn(conn netpoll.Connection) error {
    // destroy ctx
    return nil
}

func handle(ctx context.Context, connection netpoll.Connection) error {
    reader := connection.Reader()
    // decode modbus packet
    defer reader.Release()
    return nil
}

0x01:实现 per connection 上下文

考虑到每个红包的玩家之间不能互相影响,以及key的取值范围需要玩家多线程求解,所以这里 Server 需要实现为每个连接持有自己独立的上下文,来隔离不同连接之间的影响。简单起见,笔者使用一个全局的 sync.Map 来存储每个连接的上下文数据。

将需要保存的上下文定义为:

var svc_ctx sync.Map

type redpack struct {
    pack uint32
    valid int16
}

type modbus_data struct {
    header_cnt uint32 
    redpack1 redpack
    redpack3 redpack
}

prepare 函数中为每个新的连接初始化上下文:

func prepare(connection netpoll.Connection) context.Context {
    ctx := context.Background()
    connection.AddCloseCallback(close_cb)
    data := new(modbus_data)
    data.redpack1.pack = valid_pack1
    data.redpack1.valid = 1
    data.redpack3.pack = 0
    data.redpack3.valid = 0
    svc_ctx.Store(connection, data)
    return ctx
}

handle 函数中,我们可以通过连接检索和使用此数据:

data_raw, _ := svc_ctx.Load(connection)
var data *modbus_data = data_raw.(*modbus_data)

最后,当连接关闭时,在 close_conn 函数中清除存储的上下文数据,避免内存泄漏:

func close_conn(conn netpoll.Connection) error {
    svc_ctx.Delete(conn)
    return nil
}

0x02:解析 Modbus TCP 报文

2022年家里装了一套新风,为了将这套新风接入智能家居,我只能硬着头皮适配了一遍这个新风唯一对外的 RS485 Modbus RTU 协议到智能家居平台。这是本次红包用 Modbus 协议的“创意”来源。

考虑到 Modbus 协议族中只有 Modbus TCP 协议才有一个特征性的协议端口(502)可以作为一些提示,所以这次笔者在 Modbus RTU over TCP 和 Modbus TCP 两个协议当中选择了后者。

虽然 Modbus 协议里面没有 RPC 这个概念,但实际上,Modbus 协议的每个 Transaction 都可以看作一个 RPC 请求。因此,实现 Modbus 的 Server 可以使用很多 RPC Server 的理念。

那么,要实现一个 RPC Server,个人认为有这样两个环节:

  1. RPC 报文封装处理:将 RPC 请求的具体字段从协议封装中反序列化出来,并将响应结果打包回去。例如一些 RPC 框架会使用 Protobuf/thrift 包装参数,在 Server 侧就要对应进行反序列化与序列化的操作;
  2. RPC 任务管理:把不同的RPC任务调度到适当的线程、进程中完成。

分开来讲讲本次 Modbus Server 在这两个环节上实现的考量:

RPC 报文封装处理

Modbus这种字段固定的报文,如果是在C/C++里,配上零拷贝buffer,可以很低开销地完成报文解析。直接把零拷贝buffer强转成协议struct,再用大端字节序读取/写入具体字段,这样可以最小化CPU开销。但是Go语言当中似乎不推荐直接进行强转指针的操作,所以笔者用binary.Read来完成报文转换的工作。

首先是解析MBAP头,并进行必要的头校验。

modbus_hdr_raw, _ := reader.ReadBinary(7)
modbus_header := modbus_hdr{}
buf := bytes.NewBuffer(modbus_hdr_raw[:])
_ = binary.Read(buf, binary.BigEndian, &modbus_header)

随后是解析 opcode。后续可以使用 switch 语句来根据 opcode 调用相应的处理函数进行进一步的解析。后续的解析同样用reader.ReadBinarybinary.Read配合进行。

opcode_raw, _ := reader.ReadBinary(1)
opcode := opcode_raw[0]
switch opcode {
case 0x3:
    modbus_read(data, reader, connection, &modbus_header)
case 0x6:
    modbus_single_write(data, reader, connection, &modbus_header)
case 0x10:
    modbus_multiple_write(data, reader, connection, &modbus_header)
default:
    modbus_error(connection, &modbus_header, int8(opcode), 0x1)
}

进行消息响应的方式也类似,以返错为例:

func modbus_respheader(writer netpoll.Writer, header *modbus_hdr, len uint16) error {
    header.Length = len + 1
    buf := new(bytes.Buffer)
    _ = binary.Write(buf, binary.BigEndian, header)
    writer.WriteBinary(buf.Bytes())
    return nil;
}

func modbus_error(conn netpoll.Connection, header *modbus_hdr, opcode int8, error_code int8) error {
    modbus_respheader(conn.Writer(), header, 2)
    conn.Writer().WriteByte(uint8(opcode) + 0x80)
    conn.Writer().WriteByte(uint8(error_code))
    conn.Writer().Flush()
    return nil;
}

RPC 任务管理

由于这次的 Server 实现都是内存操作,所以这里的任务都相当轻量。加之netpoll框架本身提供了异步机制,因此笔者在这里没有引入更复杂的线程/协程机制来执行RPC,而是简单地边解析协议边处理 Modbus 请求。以 0x3 命令 read 为例:

func modbus_read(data *modbus_data, reader netpoll.Reader, conn netpoll.Connection, header *modbus_hdr) error {
    if header.DeviceId != 1 && header.DeviceId != 3 {
        modbus_error(conn, header, 0x3, 0x2)
        return nil;
    }
    if header.Length != 6 {
        modbus_error(conn, header, 0x3, 0x3)
        return nil;
    }
    start_addr_raw, _ := reader.ReadBinary(2)
    var start_addr uint16 = 0
    buf := bytes.NewBuffer(start_addr_raw[:])
        _ = binary.Read(buf, binary.BigEndian, &start_addr)
    reg_num_raw, _ := reader.ReadBinary(2)
    var reg_num uint16 = 0
    buf = bytes.NewBuffer(reg_num_raw[:])
        _ = binary.Read(buf, binary.BigEndian, &reg_num)
    if start_addr > 0x2 || start_addr < 0 || reg_num < 0 || start_addr + reg_num > 0x3 {
        modbus_error(conn, header, 0x3, 0x2)
        return nil;
    }
    pack := &data.redpack1
    if header.DeviceId == 3 {
        pack = &data.redpack3
    }
    writer := conn.Writer()
    modbus_respheader(writer, header, 2 + reg_num * 2)
    writer.WriteByte(0x3)
    writer.WriteByte(byte(reg_num * 2))
    buf = new(bytes.Buffer)
    _ = binary.Write(buf, binary.BigEndian, pack)
    writer.WriteBinary(buf.Bytes()[start_addr * 2:(start_addr + reg_num) * 2])
    writer.Flush()
    return nil;
}

0x03 性能测试

按说应该先完成性能测试再上线给大家玩,但由于时间实在有限,所以就直接上线了。于是只能通过线上的监控数据对性能做一个评估了。

下面是@NickCao的程序执行期间的监控:

监控数据

可以看到,当服务器 5M 带宽打满后,CPU 占用率仅为 2.4%。最终服务器带宽的QoS成了整个系统的瓶颈。这个CPU开销甚至没有必要买独享型实例,使用突发实例规格甚至都可以支持整个业务。

完整代码

最后,我把完整的代码贴给大家,希望能对大家有些帮助。

package main

import (
    "context"
    "time"
    "fmt"
    "sync"
    "bytes"
    "encoding/binary"

    "github.com/cloudwego/netpoll"
)

const valid_pack1 = 25734994
const valid_pack3 = 56730894

type redpack struct {
    pack uint32
    valid int16
}

type modbus_data struct {
    header_cnt uint32 
    redpack1 redpack
    redpack3 redpack
}

type modbus_hdr struct {
    PktId uint16
    ProtocolMagic uint16
    Length uint16
    DeviceId uint8
}

var svc_ctx sync.Map

func main() {
    network, address := "tcp", ":502"
    listener, _ := netpoll.CreateListener(network, address)

    eventLoop, _ := netpoll.NewEventLoop(
        handle,
        netpoll.WithOnPrepare(prepare),
        netpoll.WithReadTimeout(time.Second),
    )

    // start listen loop ...
    eventLoop.Serve(listener)
}

var _ netpoll.OnPrepare = prepare
var _ netpoll.OnRequest = handle
var close_cb netpoll.CloseCallback = close_conn

func prepare(connection netpoll.Connection) context.Context {
    ctx := context.Background()
    connection.AddCloseCallback(close_cb)
    data := new(modbus_data)
    data.redpack1.pack = valid_pack1
    data.redpack1.valid = 1
    data.redpack3.pack = 0
    data.redpack3.valid = 0
    fmt.Printf("Incoming connection from %s\n", connection.RemoteAddr().String())
    svc_ctx.Store(connection, data)
    return ctx
}

func close_conn(conn netpoll.Connection) error {
    data, _ := svc_ctx.Load(conn)
    svc_ctx.Delete(conn)
    fmt.Printf("Connection from %s closed, with latest pkt_id = %d\n", conn.RemoteAddr().String(), data.(*modbus_data).header_cnt - 1)
    return nil
}

func modbus_respheader(writer netpoll.Writer, header *modbus_hdr, len uint16) error {
    header.Length = len + 1
    buf := new(bytes.Buffer)
    _ = binary.Write(buf, binary.BigEndian, header)
    writer.WriteBinary(buf.Bytes())
    return nil;
}

func modbus_read(data *modbus_data, reader netpoll.Reader, conn netpoll.Connection, header *modbus_hdr) error {
    if header.DeviceId != 1 && header.DeviceId != 3 {
        modbus_error(conn, header, 0x3, 0x2)
        return nil;
    }
    if header.Length != 6 {
        modbus_error(conn, header, 0x3, 0x3)
        return nil;
    }
    start_addr_raw, _ := reader.ReadBinary(2)
    var start_addr uint16 = 0
    buf := bytes.NewBuffer(start_addr_raw[:])
        _ = binary.Read(buf, binary.BigEndian, &start_addr)
    reg_num_raw, _ := reader.ReadBinary(2)
    var reg_num uint16 = 0
    buf = bytes.NewBuffer(reg_num_raw[:])
        _ = binary.Read(buf, binary.BigEndian, &reg_num)
    if start_addr > 0x2 || start_addr < 0 || reg_num < 0 || start_addr + reg_num > 0x3 {
        modbus_error(conn, header, 0x3, 0x2)
        return nil;
    }
    pack := &data.redpack1
    if header.DeviceId == 3 {
        pack = &data.redpack3
    }
    writer := conn.Writer()
    modbus_respheader(writer, header, 2 + reg_num * 2)
    writer.WriteByte(0x3)
    writer.WriteByte(byte(reg_num * 2))
    buf = new(bytes.Buffer)
    _ = binary.Write(buf, binary.BigEndian, pack)
    writer.WriteBinary(buf.Bytes()[start_addr * 2:(start_addr + reg_num) * 2])
    writer.Flush()
    return nil;
}

func modbus_single_write(data *modbus_data, reader netpoll.Reader, conn netpoll.Connection, header *modbus_hdr) error {
    if header.DeviceId != 1 && header.DeviceId != 3 {
        modbus_error(conn, header, 0x6, 0x2)
        return nil
    }
    if header.Length != 6 {
        modbus_error(conn, header, 0x6, 0x3)
        return nil
    }
    addr_raw, _ := reader.ReadBinary(2)
    var addr uint16 = 0
    buf := bytes.NewBuffer(addr_raw[:])
    _ = binary.Read(buf, binary.BigEndian, &addr)
    if addr > 0x1 || addr < 0 {
        modbus_error(conn, header, 0x6, 0x2)
        return nil
    }
    value_raw, _ := reader.ReadBinary(2)
    var value uint16 = 0
    buf = bytes.NewBuffer(value_raw[:])
    _ = binary.Read(buf, binary.BigEndian, &value)
    pack := &data.redpack1
    if header.DeviceId == 3 {
        pack = &data.redpack3
    }
    if addr == 0 {
        pack.pack &= 0xffff
        pack.pack |= uint32(value) << 16  
    } else {
        pack.pack &= 0xffff0000
        pack.pack |= uint32(value)
    }
    pack.valid = 0
    if header.DeviceId == 1 {
        if pack.pack == valid_pack1 {
            pack.valid = 1
        }
    } else {
        if pack.pack == valid_pack3 {
            pack.valid = 1
        }
    }
    modbus_respheader(conn.Writer(), header, 5)
    conn.Writer().WriteByte(0x6)
    conn.Writer().WriteBinary(addr_raw)
    conn.Writer().WriteBinary(value_raw)
    conn.Writer().Flush()
    return nil;
}

func modbus_multiple_write(data *modbus_data, reader netpoll.Reader, conn netpoll.Connection, header *modbus_hdr) error {
    if header.DeviceId != 1 && header.DeviceId != 3 {
        modbus_error(conn, header, 0x10, 0x2)
        return nil
    }
    if header.Length < 7 {
        modbus_error(conn, header, 0x10, 0x3)
        return nil
    }
    start_addr_raw, _ := reader.ReadBinary(2)
    var start_addr uint16 = 0
    buf := bytes.NewBuffer(start_addr_raw[:])
    _ = binary.Read(buf, binary.BigEndian, &start_addr)
    reg_num_raw, _ := reader.ReadBinary(2)
    var reg_num uint16 = 0
    buf = bytes.NewBuffer(reg_num_raw[:])
    _ = binary.Read(buf, binary.BigEndian, &reg_num)
    if start_addr > 0x1 || start_addr < 0 || reg_num < 0 || start_addr + reg_num > 0x2 {
        modbus_error(conn, header, 0x10, 0x2)
        return nil
    }
    if header.Length != 7 + reg_num * 2 {
        modbus_error(conn, header, 0x10, 0x3)
        return nil
    }
    len, _ := reader.ReadBinary(1)
    if uint16(len[0]) != reg_num * 2 {
        modbus_error(conn, header, 0x10, 0x3)
        return nil
    }
    pack := &data.redpack1
    if header.DeviceId == 3 {
        pack = &data.redpack3
    }
    for i := start_addr; i < reg_num; i++ {
        value_raw, _ := reader.ReadBinary(2)
        var value uint16 = 0
        buf = bytes.NewBuffer(value_raw[:])
        _ = binary.Read(buf, binary.BigEndian, &value)
        if i == 0 {
            pack.pack &= 0xffff
            pack.pack |= uint32(value) << 16
        } else {
            pack.pack &= 0xffff0000
            pack.pack |= uint32(value)
        }
    }
    pack.valid = 0
    if header.DeviceId == 1 {
        if pack.pack == valid_pack1 {
            pack.valid = 1
        }
    } else {
        if pack.pack == valid_pack3 {
            pack.valid = 1
        }
    }
    writer := conn.Writer()
    modbus_respheader(writer, header, 5)
    writer.WriteByte(0x10)
    writer.WriteBinary(start_addr_raw)
    writer.WriteBinary(reg_num_raw)
    writer.Flush()
    return nil;
}

func modbus_error(conn netpoll.Connection, header *modbus_hdr, opcode int8, error_code int8) error {
    modbus_respheader(conn.Writer(), header, 2)
    conn.Writer().WriteByte(uint8(opcode) + 0x80)
    conn.Writer().WriteByte(uint8(error_code))
    conn.Writer().Flush()
    return nil;
}

func handle(ctx context.Context, connection netpoll.Connection) error {
    reader := connection.Reader()
    data_raw, _ := svc_ctx.Load(connection)
    var data *modbus_data = data_raw.(*modbus_data)
    modbus_hdr_raw, _ := reader.ReadBinary(7)
    modbus_header := modbus_hdr{}
    buf := bytes.NewBuffer(modbus_hdr_raw[:])
    _ = binary.Read(buf, binary.BigEndian, &modbus_header)
    if modbus_header.ProtocolMagic != 0 {
        modbus_error(connection, &modbus_header, 0, 0x3)
        defer reader.Release()
        return nil
    }
    data.header_cnt = data.header_cnt + 1
    opcode_raw, _ := reader.ReadBinary(1)
    opcode := opcode_raw[0]
    switch opcode {
    case 0x3:
        modbus_read(data, reader, connection, &modbus_header)
    case 0x6:
        modbus_single_write(data, reader, connection, &modbus_header)
    case 0x10:
        modbus_multiple_write(data, reader, connection, &modbus_header)
    default:
        modbus_error(connection, &modbus_header, int8(opcode), 0x1)
    }
    defer reader.Release()
    return nil
}