动手写RPC框架


本文学习自geektutu , 大部分内容摘自 7天用Go从零实现RPC框架GeeRPC | 极客兔兔 (geektutu.com),并在此基础上稍加个人的学习经历和理解

作者仓库地址:geektutu/7days-golang: 7 days golang programs from scratch (web framework Gee, distributed cache GeeCache, object relational mapping ORM framework GeeORM, rpc framework GeeRPC etc) 7天用Go动手写/从零实现系列 (github.com)

day0. 浅谈RPC框架

前几天在学 6.824 时,发现有太多内容是我完全没接触过的,然后其中涉及到 RPC 的内容又比较多,忽然想起 geektutu 出过 "七天实现 RPC 框架" 的文章,马上转坑来学习。

1. 谈谈RPC框架

RPC (Remote Procedure Call,远程过程调用) 是一种计算机通信协议,允许调用不同进程空间的程序。RPC 的客户端和服务器可以在一台机器上,也可以在不同的机器上。程序员使用时,就像调用本地程序一样,无需关注内部实现的细节。

不同应用程序间的通信方式有很多,例如浏览器和服务器间广泛用基于 HTTP协议的 Restful API。与 RPC相比,Restful API 有相对统一的标准,因而更通用,兼容性更好,支持不同的语言。HTTP 协议是基于文本的,一般具备更好的可读性。但是缺点也很明显:

  • Restful 接口要额外的定义,无论是客户端还是服务端,都需要额外的代码来处理,而 RPC 调用则更接近于直接调用。
  • 基于 HTTP 协议的 Restful 报文冗余,承载了过多无效信息,而RPC 通常使用自定义的协议格式,减少冗余报文。
  • RPC 可以采用更高效的序列化协议,将文本转为二进制传输,获得更高的性能。
  • 因为 RPC 的灵活性,所以更容易扩展和集成诸如注册中心,负载均衡等功能。

2. RPC框架需要解决什么问题

RPC 需要解决什么问题?或者换个说法,为什么要RPC 框架?

我们可以想象下两台机器上,两个程序之间要通信,那么首先,需要确定采用的传输协议是什么?如果这两个程序位于不同的机器,那么一般会选择 TCP 协议活 HTTP 协议;那如果两个程序位于相同的机器,也可以选择 Unix Socket 协议。传输协议确定后,还需要确定报文的编码格式,比如采用最常用的json 或xml,那如果报文比较大,还可能会选择 protobuf 等其他的编码方式,甚至编码之后,再进行压缩。接收端获取报文则需要相反的过程,先解压再解码。

解决了传输协议和保温编码的问题,接下来还需要解决一系列的可用性问题,例如,连接超时了怎么办?是否支持异步请求和并发?

如果服务端的实例很多,客户端并不关心这些实例的地址和部署位置,只关心自己能否获取到期待的结果,那就引出了注册中心 (registry) 和负载均衡 (load balance) 的问题。简单地说,即客户端和服务端相互不感知对方的存在,服务端启动时将自己注册到注册中心,客户端调用时,从注册中心获取到所有可用的实例,选择一个来调用。这样服务端和客户端只需要感知注册中心的存在就够了。注册中心还需要实现服务动态添加,删除,使用 "心跳机制" 确保服务处于可用状态等功能。

再进一步,假设服务端是不同的团队提供的,如果没有统一的RPC 框架,各个团队的服务提供方就需要各自实现一套消息编解码,连接池,收发线程,超时处理等 "业务之外" 的重复技术劳动,造成整体的低效。因此,"业务之外" 的这部分公共的能力,即是RPC 框架所需要具备的能力。

day1. 服务端与消息编码

  • 使用encoding/gob实现消息的编解码 (序列化与反序列化)。
  • 实现一个简易的服务端,仅接受消息,不处理,代码约200行。

消息的序列化与反序列化

一个典型的RPC 调用如下

err = client.call("Arith.Multiply", args, &reply)

客户端发送的请求包括服务名Arith,方法名Multiply,参数args三个,服务端的响应包括错误error,返回值reply 2个。我们将请求和响应中的参数和返回值抽象为 body,剩余的信息放在 header 中,那么就可以抽象出数据结构 Header:

day1/codec/codec.go

package codec

import "io"

type Header struct {
    ServiceMethod string // format "Service.Method"
    Seq           string // sequence number chosen by client
    Error         string
}
  • ServiceMethod 是服务名和方法名,通常与 Golang 中的结构体和方法相映射。
  • Seq 是请求的序号,也可以认为是某个请求的 ID,用来区分不同的请求。
  • Error 是错误信息,客户端设置为空,

我们将和消息编解码相关的代码都放到 codec 子目录中,在此之前,还需要在geerpc项目根目录下使用 go mod init geerpc 初始化项目,方便后续子 package 之间的引用。

进一步,抽象出对消息体进行编解码的接口 Codec,抽象出接口是为了实现不同的 Codec 实例:

type Codec interface {
    io.Closer
    ReadHeader(*Header) error
    ReadBody(interface{}) error
    Write(*Header, interface{}) error
}

紧接着,抽象出 Codec 的构造函数,客户端和服务端可以通过 Codec 的Type得到构造函数,从而创建 Codec 实例。这部分代码和工厂模式类似,与工厂模式不同的是,返回的是构造函数,而非实例。

type NewCodecFunc func(io.ReadWriteCloser) Codec 
type Type string

const (
	GobType  Type = "application/gob"
    JsonType Type = "application/json"
)

var NewCodecFuncMap map[Type]NewCodecFunc

func init() {
    NewCodecFuncMap = make(map[Type]NewCodecFunc)
    NewCodecFuncMap[GobType] = NewGobCodec // 初始化map,实例化一个GobCodec对象
}

我们定义了两种 Codec,GobJson,但是实际代码只实现了Gob一种,事实上,2者的实现非常接近,甚至只需把gob换成json即可。

首先定义GobCodec结构体,这个结构体由四部分构成,conn是由构建函数传入,通常是通过 TCP 或者 Unix 建立 socket 时得到的链接实例,dec 和 enc 对应 gob的 Decoder 和 Encoder,buf 是为了防止阻塞而创建的带缓冲的Writer,一般这么做都能提升性能。

day1/codec/gob.go

package codec

import (
    "bufio"
    "encoding/gob"
    "io"
    "log"
)

type GobCodec struct {
    conn io.ReadWriteCloser
    buf  *bufio.Writer
    dec  *gob.Decoder
    enc  *gob.Encoder
}

var _ Codec = (*GobCodec)(nil)
// 这里的写法的含义是,用来检测GobCodec是否实现了Codec接口,如果没有实现该接口则编译报错

func NewGobCodec(conn io.ReadWriteCloser) Codec {
    buf := bufio.NewWriter(conn)
    return &GobCodec {
        conn: conn,
        buf:  buf,
        dec:  gob.NewDecoder(conn),
        enc:  gob.NewEncoder(buf),
    }
}

接着实现ReadHeaderReadBodyWriteClose方法。

func (c *GobCodec) ReadHeader(h *Header) error {
    return c.dec.Decode(h)
}

func (c *GobCodec) ReadBody(body interface{}) error {
    return c.dec.Decode(body)
}

func (c *GobCodec) Write(h *Header, body interface{}) (err error) {
    defer func() {
        _ = c.buf.Flush() // 将缓存区内容写入文件,返回类型为error 
        if err != nil {
            _ = c.Close()
        }
    }()
    if err != c.enc.Encode(h); err != nil {
        log.Println("rpc codec: gob error encoding header:", err)
        return err
    }
    if err := c.enc.Encode(body); err != nil {
        log.Println("rpc codec: gob error encoding body:", err)
        return err
    }
    return nil
}

func (c *GobCodec) Close() error {
    return c.conn.Close() // 返回一个err,具体的Close()在io.go中有重写
}

通信过程

客户端与服务端的通信需要协商一些内容,例如 HTTP 报文,分为 header 和 body 两部分,body 的格式和长度通过 header 中的Content-TypeContent-Length指定,服务端通过解析 header 就能够知道如何从 body 中读取需要的信息。对于RPC 协议来说,这部分协商是需要自主设计的。为了提升性能,一般在报文的最开始会规划固定的字节,来协商相关的信息。比如第1个字节用来表示序列化方式,第2个字节表示压缩方式,第3-6字节表示 header 的长度,7-10字节表示body 长度。

对于 GeeRPC 来说,目前需要协商的唯一一项内容时消息的编解码方式。我们将这部分信息,放到结构体Option中承载。目前,已经进入到服务端的实现阶段了。

day1/server.go

package geerpc

const MagicNumber = 0x23bef5c

type Option struct {
    MagicNumber int        // MagicNumber marks this's a geerpc request
    CodecType   codec.Type // client may choose different Codec to encode body
}

var DefaultOption = &Option {
    MagicNumber: MagicNumber,
    CodecType:   codec.GobType,
}

一般来说,设计协商协议的这部分信息,需要设计固定的字节来传输。但是为了实现上更简单, GeeRPC 客户端固定采用 JSON 编码 Option,后续的 header 和 body 的编码方式由 Option 中的 CodeType指定,服务端首先使用 JSON 解码 Option,然后通过 Option 的 CodeType 解码剩余内容。即报文将以这样的形式发送:

| Option{MagicNumber: xxx, CodecType: xxx} | Header{ServiceMethod ...} | Body interface{} |
| <-------    固定 JSON 编码       -------> | <--------  编码方式由 CodeType决定   -------> |

在一次连接中,Option 固定在报文的最开始,Header 和 Body 可以有很多个,即报文可能是这样的。

| Option | Header1 | Body1 | Header2 | Body2 | ...

服务端的实现

通信过程已经定义清楚了,那么服务端的实现就比较直接了。

day1/server.go

// Server represents an RPC Server.
type Server struct{}

// NewServer returns a new Server.
func NewServer() *Server {
    return &Server{}
}

// DefaultServer is the default instance of *Server
var DefaultServer = NewServer()

// Acccept accepts connections on the listener and serves requests
// for each incoming connection
func (server *Server) Accept(lis net.Listener) {
    // for循环等待socket连接建立
    for {
        conn, err := lis.Accept()
        if err != nil {
            log.Println("rpc server: accept error:", err)
            return 
        }
        go server.ServeConn(conn)
    }
}

// Accept accepts connections on the listener and serves requests
// for each incoming connection
func Accept(lis net.Listener) {
    DefaultServer.Accept(lis)
}
  • 首先定义了结构体Server,没有任何的成员字段。
  • 实现了Accept方式,net.Listener作为参数,for 循环等待 socket 连接建立,并开启子协程处理,处理过程交给了ServerConn方法。
  • DefaultServer 是一个默认的Server实例,主要为了用户使用方便。

如果想启动服务,过程是很简单的,传入 listener 即可,tcp 协议和 unix 协议都支持。

lis, _ := net.Listen("tcp", ":9999")
geerpc.Accept(lis)

ServeConn的实现就和之前讨论的通信过程紧密相关了,首先使用json.NewDecoder反序列化得到 Option 实例,检查 MagicNumber 和 CodeType的值是否正确。然后根据 CodeType 得到对应的消息编解码器,接下来的处理就交给serverCodec

// ServeConn runs the serer on a single connection
// ServeConn blocks, serving the connection until the client hangs up
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
    defer func() {
        _ = conn.Close()
    }()
    if err := json.NewDecoder(conn).Decode(&opt); err != nil {
        log.Println("rpc server: options error:", err)
        return
    }
    // 检查Option的参数是否正确
    if opt.MagicNumber != MagicNumber {
        log.Printf("rpc server: invalid magic number %x", opt.MagicNumber)
        return
    }
    f := codec.NewCodecFuncMap[opt.CodecType]
    if f == nil {
        log.Printf("rpc server: invalid codec type %s", opt.CdoecType)
        return
    }
    server.serveCodec(f(conn))
}

// invalidRequest is a placeholder for response argv when error occurs
var invalidRequest = struct{}{}

// 注意这里要改serveCodec的入参
func (server *Server) serveCodec(cc codec.Codec, opt *Option) {
    sending := new(sync.Mutex) // make sure to send a complete response
    // 加入一个互斥锁避免多个回复报文交织在一起
    wg := new(sync.WaitGroup) // wait until all request are handled
    for {
        req, err := server.readRequest(cc) // 读取请求
        if err != nil {
            if req == nil {
                break // it's not possible to recover, so close the connection
            }
            req.h.Error = err.Error()
            server.sendResponse(cc, req.h, invalidRequest, sending)
            // 回复请求
            continue
        }
        wg.Add(1)
        go server.handleRequest(cc, req, sending, wg, opt.HandleTimeout)
        // 加入一个处理请求协程
        // 这里注意要新增一个超时时间
    }
    wg.Wait()
    _ = cc.Close()
}

serveCodec的过程很简单,主要包含三阶段:

  • 读取请求 readRequest
  • 处理请求 handleRequest
  • 回复请求 sendRequest

之前提到过,再一次连接中,允许收到多个请求,即多个 request header 和 request body,因此这里使用了 for 无限制地等待请求的到来,直到发生错误 (例如连接被关闭,接收到的报文有问题等),这里需要注意的点有三个:

  • handleRequest 使用了协程并发执行请求。
  • 处理请求是并发的,但是回复请求的报文必须是逐个发送的,并发容易导致多个回复报文交织在一起,客户端无法解析。在这里使用锁 (sending) 保证。
  • 尽力而为,只有在 header 解析失败时,才终止循环。
// request stores all infomation of a call
type request struct {
    h            *codec.Header // header of request
    argv, replyv reflect.Value // argv and replyv of request
    // Value also is a struct
}

func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) {
    var h codec.Header
    if err := cc.ReadHeader(&h); err != nil {
        if err != io.EOF && err != io.ErrUnexpectedEOF {
            log.Println("rpc server: read header error:", err)
        }
        return nil, err
    }
    return &h, nil
}

func (server *Server) readRequest(cc codec.Codec) (*request, error) {
    h, err := server.readRequestHeader(cc)
    if err != nil {
        return nil, err
    }
    req := &reqeust{h: h}
    // TODO: now we don't know the type of request argv
    // day1, just suppose it's string
    req.argv = reflect.New(reflect.TypeOf(""))
    if err = cc.ReadBody(req.argv.Interface()); err != nil {
        log.Println("rpc server: read argv err:", err)
    }
    return req, nil
}

func (server *Server) sendResponse(cc codec.Cdoec, h *codec.Header, body interface{}, sneding *sync.Mutex) {
    sending.Lock()
    defer sending.Unlock()
    if err := cc.Write(h, body); err != nil {
        log.Println("rpc server: write response error:", err)
    }
}

func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) {
    // TODO, should call registered rpc methods to get the right replyv
    // day1, just print argv and send a hello message
    defer wg.Done()
    log.Println(req.h, req.argv.Elem())
    req.replyv = reflect.ValueOf(fmt.Sprintf("geerpc resp %d", req.h.Seq))
    server.sendResponse(cc, req.h, req.replyv.Interface(), sending)
}

目前还不能判断 body 的类型,因此在 readRequest 和 handleRequest 中,day1 将在 body作为字符串处理。接收到请求,打印 header,并回复geerpc resp ${req.h.Seq}。这一部分后续再实现。

main 函数 (一个简易的客户端)

day1 的内容就到此为止了,在这里我们已经实现了一个消息的编解码器GobCodec,并且客户端与服务端实现了简单的协议交换 (protocol exchange),即允许客户端使用不同的编码方式。实现了服务端的雏形,建立连接,读取、处理并回复客户端的请求。

接下来,我们在 main 函数中看看如何使用刚实现的 GeeRPC。

day1/main/main.go

package main

import (
    "encoding/json"
    "geerpc"
    "geerpc/codec"
    "log"
    "net"
    "time"
)

func startServer(addr chan string) {
    // pick a free port
    l, err := net.Listen("tcp", ":0")
    if err != nil {
        log.Fatal("network error: ", err)
    }
    log.Println("start rpc server on", l.Addr())
    addr <- l.Addr().String()
    geerpc.Accept(l) // 注意这里是不是数字1,是字母l
}

func main() {
    addr := make(chan string)
    go startServer(addr)
    
    // in fact, following code is like a simple geerpc client
    conn, _ := net.Dial("tcp", <- addr)
    defer func() {
        _ = conn.Close()
    }()
    
    time.Sleep(time.Second)
    // send options
    _ = json.NewEncoder(conn).Encode(geerpc.DefaultOption)
    cc := codec.NewGobCodec(conn)
    // send request & receive response
    for i := 0; i < 5; i++ {
        h := $codec.Header {
            ServiceMethod: "Foo.Sum",
            Seq:           uint64(i),
        }
        _ = cc.Write(h, fmt.Sprintf("geerpc req %d", h.Seq))
        _ = cc.ReadHeader(h)
        var reply string
        _ = cc.ReadBody(&reply)
        log.Println("reply:", reply)
    }
}
  • startServer中使用了信道addr,确保服务端端口监听成功,客户端再发起请求。
  • 客户端首先发送Option进行协议交换,接下来发送消息头h := &codec.Header{},和消息体geerpc req ${h.Seq}
  • 最后解析服务端的相应reply,并打印出来。

执行结果如下:

start rpc server on [::]63662
&{Foo.Sum 0 } geerpc req 0
reply: geerpc resp 0
&{Foo.Sum 1 } geerpc req 1
reply: geerpc resp 1
&{Foo.Sum 2 } geerpc req 2
reply: geerpc resp 2
&{Foo.Sum 3 } geerpc req 3
reply: geerpc resp 3
&{Foo.Sum 4 } geerpc req 4
reply: geerpc resp 4

day2. 支持并发和异步的客户端

Call 的设计

net/rpc而言,一个函数需要能够被远程调用,需要满足如下五个条件:

  • the method's type is exported
  • the method is exported
  • the method has two arguments, both exported (or builtin) types
  • the method's second arguments is a pointer
  • the method has return type error

更直观一点:

func (t *T) MethodName(argType T1, replyType *T2) error

根据上述需求,首先我们封装了结构体 Call 来承载一次 RPC 调用所需要的信息。

day2/client.go

// Call represents an active RPC 
type Call struct {
    Seq           uint64
    ServiceMethod string      // format "<service>.<method>"
    Args          interface{} // arguments to the function
    Reply         interface{} // reply from the fucntion
    Error         error       // if error occurs, it will be set
    Done          chan *Call  // Strobes when call is complete
} 

func (call *Call) done() {
    call.Done <- call
}

为了支持异步调用,Call 结构体中添加了一个字段 Done,Done 的类型是chan *Call,当调用结束时,会调用call.done()通知调用方。

实现 Client

接下来,我们将实现 GeeRPC 客户端最核心的部分 Client。

// Client represents an RPC Client
// There may be multipie outstanding Calls associated
// with a single Client, and a Client may be used by
// multipie goroutines simultaneously
type Client struct {
    cc       codec.Codec
    opt      *Option
    sending  sync.Mutex // protect following
    header   codec.Header
    mu       sync.Mutex // protect following
    seq      uint64
    pending  map[uint64]*Call
    closing  bool // user has called Close
    shutdown bool // server has told us to stop
}

var _ io.Closer = (*Client)(nil)

var ErrShutdown = errors.New("connection is shut down")

// Close the connection 
func (client *Client) Close() error {
    client.mu.Lock()
    defer client.mu.Unlock()
    if client.closing {
        return ErrShutdown
    }
    client.closing = true
    return client.cc.Close()
}

// IsAvaliable return true if the client does work
func (client *Client) IsAvaliable() bool {
    client.mu.Lock()
    defer client.mu.Unlock()
    return !client.shutdown && !client.closing
}

client 的字段解析如下:

  • cc 是消息的编解码器,和服务端类似,用来序列化将要发送出去的请求,以及反序列化接收到的响应。
  • sending 是一个互斥锁,和服务端类似,为了保证请求的有序发送,即防止出现多个请求报文混淆。
  • header 是每个请求的消息头,header 只有在请求发送时才需要,而请求发送是互斥的,因此每个客户端只需要一个,声明在 Client 结构体中可以复用。
  • seq 用于给发送的请求编号,每个请求有唯一编号。
  • pending 存储未处理完的请求,键是编号,值是 Call 实例。
  • closing 和 shutdown 任意一个值置为 true,则表示 Client 处于不可用的状态,但有些许的差别,closing 是用户主动关闭的,即调用Close方法,而 shutdown 置为 true 一般是有错误发生。

紧接着,实现和 Call 相关的方法。

func (client *Client) registerCall(call *Call) (uint64, error) {
    client.mu.Lock()
    defer client.mu.Unlock()
    if client.closing || client.shutdown {
        return 0, ErrShutdown
    }
    call.Seq = client.seq
    client.pending[call.Seq] = call
    client.seq++
    return call.Seq, nil
}

func (client *Client) removeCall(seq uint64) *Call {
    client.mu.Lock()
    defer client.mu.Unlock()
    call := client.pending[seq]
    delete(client.pending, seq)
    return all
}

func (client *Client) terminateCalls(err error) {
    client.sending.Lock()
    defer client.sending.Unlock()
    client.mu.Lock()
    defer client.mu.Unlock()
    client.shutdown = true
    for _, call := range client.pending {
        call.Error = err
        call.done()
    }
}
  • registerCall :将参数 call 添加到 client.pending 中,并更新 client.seq。
  • removeCall:根据seq,从 client.pending 中移除对应的 call,并返回。
  • terminateCalls:服务端或客户端发生错误时调用,将 shutdown 设置为 true,且将错误信息通知所有 pending 状态的 call。

对一个客户端来说,接收响应、发送请求是最重要的2个功能。那么首先实现接收功能,接收到的响应有三种情况:

  • call 不存在,可能是请求没有发送完整,或者因为其他原因被取消,但是服务端仍旧处理了。
  • call 存在,但服务端处理出错,即 h.Error不为空。
  • call 存在,服务端处理正常,那么需要从 body 中读取 Reply 的值。
func (client *Client) receive() {
    var err error
    for err == nil {
        var h codec.Header
        if err = client.cc.ReadHeader(&h); err != nil {
            break
        }
        call := client.removeCall(h.Seq)
        switch {
        case call == nil:
            // it usually means that Write partially failed
            // and call was already removed
            arr := client.cc.ReadBody(nil)
        case h.Error != "":
            call.Error = fmt.Errorf(h.Error)
            err = client.cc.ReadBody(nil)
            call.done()
        default:
            err = client.cc.ReadBody(call.Reply)
            if err != nil {
                call.Error = errors.New("reading body " + err.Error())
            }
            call.done()
        }
    }
    // error occurs, so terminateCalls pending calls
    client.terminateCalls(err)
}

创建 Client 实例时,首先需要完成一开始的协议交换,即发送Option信息给服务端。协商好消息的编解码方式之后,再创建一个子协程receive()接收响应。

func NewClient(conn net.conn, opt *Option) (*Client, error) {
    f := codec.NewCodecFuncMap[opt.CodecType]
    if f == nil {
        err := fmt.Errorf("invalid codec type %s", opt.CodecType)
        log.Println("rpc client: options error: ", err)
        return nil, err
    }
    // send options with server
    if err := json.NewEncoder(conn).Encode(opt); err != nil {
        log.Println("rpc client: options error: ", err)
        _ = conn.Close()
        return nil, err
    }
    return newClientCodec(f(conn), opt), nil
}

func newClientCodec(cc codec.Codec, opt *Option) *Client {
    client := &Client {
        seq:     1, // seq starts with 1, 0 means invalid call
        cc:      cc,
        opt:     opt,
        pending: make(map[uint64]*Call)
    }
    go client.receive()
    return client
}

还需要实现Dial函数,便于用户传入服务端地址,创建 Client 实例。为了简化用户调用,通过...*Option将 Option 实现为可选参数。

func parseOptions(opts ...*Option) (*Option, error) {
    // if opts is nil or pass nil as parameter
    if len(opts) == 0 || opts[0] == nil {
        return DefaultOption, nil
    }
    if len(opts) != 1 {
        return nil, errors.New("number of options is more than 1")
    }
    opt := opts[0]
    opt.MagicNumber = DefaultOption.MagicNumber
    if opt.CodecType == "" {
        opt.CodecType = DefaultOption.CodecType
    }
    return opt, nil
}

// Dial connects to an RPC server at the specified network address
func Dial(network, address string, opts ...*Option) (client *Client, err error) {
    opt, err := parseOptions(opts...)
    if err != nil {
        return nil, err
    }
    conn, err := net.Dial(network, address)
    if err != nil {
        return nil, err
    }
    // close the connection if client is nil
    defer func() {
        if client == nil {
            _ = conn.Close()
        }
    }()
    return NewClient(conn, opt)
}

此时,GeeRPC 客户端已经具备了完整的创建连接和接受响应的能力了,最后还需要实现发送请求的能力。

func (client *Client) send(call *Call) {
    // make sure that the client will send a complete request
    client.sending.Lock()
    defer client.sending.Unlock()
    
    // register this call
    seq, err := client.registerCall(call)
    if err != nil {
        call.Error() = err
        call.done()
        return
    }
    
    // prepare request header
    client.header.ServiceMethod = call.ServiceMethod
    client.header.Seq = seq
    client.header.Error = ""
    
    // encode and send the request
    if err := client.cc.Write(&client.header, call.Args); err != nil {
        call := client.removeCall(seq)
        // call may be nil, it usually means that Write partially failed,
        // client has receive the response and handled
        if call != nil {
            call.Error = err
            call.done()
        }
    }
}

// Go invokes the function asynchronously
// It returns the Call structure representing the invocation
func (client *Client) Go(serviceMethod string, args, reply interface{}, done chan *Call) *Call {
    if done == nil {
        done = make(chan *Call, 10)
    } else if cap(done) == 0 {
        log.Panic("rpc client: done channel is unbuffered")
    }
    call := &Call {
        ServiceMethod: serviceMethod,
        Args:          args,
        Reply:         reply,
        Done:          done,
    }
    client.send(call)
    return call
}

// Call invokes the named function, waits for it to complete,
// and returns its error status
func (client *Client) Call(serviceMethod string, args, reply interface{}) error {
    call := <- client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
    return call.Error
}
  • GoCall是客户端暴露给用户的两个 RPC 服务调用接口,Go是一个异步接口,返回 call 实例。
  • Call是对Go的封装,阻塞 call.Done,等待响应返回,是一个同步接口。

至此,一个支持异步和并发的 GeeRPC 客户端已经完成。

补充

defer的运行机制为,在return之后,在函数退出之前执行。

func test() (ans int) {
    defer func() {
        fmt.Println(ans)
    }()
    return 10
}

func main() {
    test()
}

运行结果为:10。

Demo

第一天 GeeRPC 只实现了服务端,因此我们在 main 函数中手动模拟了整个通信过程,第二天中我们将 main 函数中的通信部分替换为客户端。

day2/main/main.go

startServer 没有发生变化。

func startServer(addr chan string) {
    // pick a free port
    l, err := net.Listen("tcp", ":0")
    if err != nil {
        log.Fatal("network error: ", err)
    }
    log.Println("start rpc server on", l.Addr())
    addr <- l.Addr().String()
    geerpc.Accept(l)
}

在 main 函数中使用了client.Call并发了5个 RPC 同步调用,参数和返回值类型均为 string。

func main() {
    log.SetFalgs(0)
    addr := make(chan string)
    go startServer(addr)
    client, _ = geerpc.Dial("tcp", <-addr)
    defer func() {
        _ = client.Close()
    }()
    
    time.Sleep(time.Second)
    // send request & receive response
    var wg sync.WaitGroup
    for i := 0; i < 5; i++ {
        wg.Add(1) // 每一个任务开始时,将等待组增加1
        // 开启一个并发
        go func(i int) {
            defer wg.Done()
            args := fmt.Sprintf("geerpc req %d", i)
            var reply string
            if err := client.Call("Foo.Sum", args, &reply); err != nil {
                log.Fatal("call Foo.Sum error: ", err)
            }
            log.Println("reply", reply)
        }(i)
    }
    wg.Wait() // 等待所有任务完成
}

运行结果如下 (不唯一):

start rpc server on [::]:36013
&{Foo.Sum 5} geerpc req 3
&{Foo.Sum 1} geerpc req 4
&{Foo.Sum 2} geerpc req 1
&{Foo.Sum 3} geerpc req 0
&{Foo.Sum 4} geerpc req 2
reply: geerpc resp 4
reply: geerpc resp 5
reply: geerpc resp 1
reply: geerpc resp 2
reply: geerpc resp 3

当然也有这种情况

&{Foo.Sum 1 } geerpc req 4
&{Foo.Sum 3 } geerpc req 0
&{Foo.Sum 2 } geerpc req 1
reply: geerpc resp 3
reply: geerpc resp 1
reply: geerpc resp 2
&{Foo.Sum 5 } geerpc req 3
&{Foo.Sum 4 } geerpc req 2
reply: geerpc resp 5
reply: geerpc resp 4

对于以上执行结果,加以个人的理解,添加了若干个协程,并同步调用,其中会出现延迟开启并发的现象。

day3. 服务注册

  • 通过反射实现服务注册功能。

结构体映射为服务

RPC 框架的一个基本能力是:像调用本地程序一样调用远程服务。关于如何将程序映射为服务,对于 Go 来说,这个问题就变成了如何将结构体的方法映射为服务。

net/rpc而言,一个函数需要能够被远程调用,需要满足以下五个条件:

  • the method's type is exported. - 方法所属的类型是导出的。
  • the method is exported. - 方式是导出的。
  • the method has two arguments, both expoerted (or builtin) types. - 两个入参,均为导出 or 内置类型。
  • the method's second argument is a pointer. - 第二个入参必须是一个指针。
  • the method has return type error. - 返回值为 error 类型。

更直观一些:

func (t *T)  MethodName(argType T1, replyType *T2) error

假如客户端发来一个请求,包含 ServiceMethod 和 Argv。

{
    "ServiceMethod": "T.MethodName"
 	"Argv": "001010010100..." // 序列化之后的字节流
}

通过 T.MethodName可以确定调用的是类型 T 的MethodName,如果硬编码实现这个功能,很可能是这样:

switch req.ServiceMethod {
    case "T.MethodName":
        t := new(t)
        reply := new(T2)
        var argv T1
        gob.NewDecoder(conn).Decode(&argv)
        err := t.MethodName(argv, reply)
        server.sendMessage(reply, err)
    case "Foo.Sum":
        f := new(Foo)
    	...
}

也就是说,如果使用硬编码的方式来实现结构体与服务的映射,那么每暴露一个方法,就需要编写等量的代码。那么有没有什么方法,能够将这个映射过程自动化呢?可以借助反射。

通过反射,我们能够很容易获取某个结构体的所有方法,并且能通过所有方法,获取到该方法的所有参数类型与返回值。例如:

func main() {
    var wg sync.WaitGroup
    typ := reflect.TypeOf(&wg)
    for i := 0; i < typ.NumMethod(); i++ {
        method := typ.Method(i)
        argv := make([]string, 0, method.Type.NumIn())
        returns := make([]string, 0, method.Type.NumOut())
        // j从1开始,第0个入参是wg自己
        for j := 1; j < method.Type.In(j); j++ {
            argv = append(argv, method.Type.In(j).Name())
        }
        for j := 0; j < method.Type.NumOut(); j++ {
            returns = append(returns, method.Type.Out(j).Name())
        }
        log.Printf("func (w *%s) %s(%s) %s",
           typ.Elem().Name(),
           method.Name,
           strings.Join(argv, ","),
           strings.Join(returns, ","))
    }
}

运行结果为:

func (w *WaitGroup) Add(int)
func (w *WaitGroup) Done()
func (w *WaitGroup) Wait()

通过反射实现 service

前两天我们完成了客户端和服务端,客户端相对来说功能是比较完整的,但是服务端的功能并不完整,仅仅将请求的 header 打印了出来,并没有真正地处理。那今天的主要目的是补全这部分功能。首先通过反射实现结构体与服务的映射关系,代码独立放置在service.go中。

day3/service.go

第一步,定义结构体 methodType:

type methodType struct {
    method    reflect.Method
    ArgType   reflect.Type
    ReplyType reflect.Type
    numCalls  uint64
}

func (m *methodType) NumCalls() uint64 {
    return atomic.LoadUint64(&m.numCalls)
}

func (m *methodType) newArgv() reflect.Value {
    var argv reflect.Value
    // arg may be a pointer type, or a value type
    if m.ArgType.Kind() == reflect.Ptr {
        argv = reflect.New(m.ArgType.Elem())
    } else {
        argv = reflect.New(m.ArgType).Elem()
    }
    return argv
}

func (m *methodType) newReplyv() reflect.Value {
    // reply must be a pointer type
    replyv := reflect.New(m.ReplyType.Elem())
    switch m.ReplyType.Elem().Kind() {
    case reflect.Map:
        replyv.Elem().Set(reflect.MakeMap(m.ReplyType.Elem()))
    case refelct.Slice:
        replyv.Elem().Set(reflect.MakeSlice(m.ReplyType.Elem(), 0, 0))
    }
    return replyv
}

每一个 methodType 实例包含了一个方法的完整信息。包括:

  • method:方法本身
  • ArgType:第一个参数的类型
  • ReplyType:第二个参数的类型
  • numCalls:后续统计方法调用次数时会用到

另外,我们还实现了2个方法newArgvnewReplyv,用于创建对应类型的实例。newArgv方法有一个小细节,指针类型和值类型创建实例的方法有细微区别。

第二部,定义结构体 service:

type service struct {
    name   string
    typ    reflect.Type
    rcvr   reflect.Value
    method map[string]*methodType
}

service 的定义也是非常简洁的,name 即映射的结构体的名称,比如T,比如WaitGroup;typ是结构体的类型;rcvr 即结构体的实例本身,保留 rcvr 是因为在调用时需要 rcvr 作为第0个参数;method 是 map 类型,储存映射的结构体的所有符合条件的方法。

接下来,完成构造函数newService,入参是任意需要映射为服务的结构体实例。

func newService(rcvr interface{}) *service {
    s := new(service)
    s.rcvr = reflect.ValueOf(rcvr)
    s.name = reflect.Indirect(s.rcvr).Type().Name()
    s.typ = reflect.TypeOf(rcvr)
    if !ast.IsExported(s.name) {
        log.Fatalf("rpc server: %s is not a valid service name", s.name)
    }
    s.registerMethods()
    return s
}

func (s *service) registerMethods() {
    s.method = make(map[string]*methodType)
    for i := 0; i < s.typ.NumMethod(); i++ {
        method := s.typ.Method(i)
        mType := method.Type
        if mType.NumIn() != 3 || mType.NumOut() != 1 {
            continue
        }
        if mType.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
            continue
        }
        argType, replyType := mType.In(1), mType.In(2)
        if !isExportedOrBuiltinType(argType) || !isExportedOrBuiltinType(replyType) {
            continue
        }
        s.method[method.Name] = &methodType {
            method:    method,
            ArgType:   argType,
            ReplyType: replyType,
        }
        log.Printf("rpc server: register %s.%s\n", s.name, method.Name)
    }
}

func isExportOrBuiltinType(t reflect.Type) bool {
    return ast.IsExported(t.Name()) || t.PkgPath() == ""
}

registerMethods过滤出了符合条件的方法:

  • 两个导出或内置类型的入参 (反射时为3个,第0个是自身,类似于 python 的 self,Java 中的this )
  • 返回值有且只有一个,类型为 error

最后,我们还需要实现call方法,即能够通过反射值调用方法。

func (s *service) call(m *methodType, argv, replyv reflect.Value) error {
    atomic.AddUint64(&m.numCalls, 1)
    f := m.method.Func
    returnValues := f.Call([]reflect.Value{s.rcvr, argv, replyv})
    if errInter := returnValues[0].Interface(); errInter != nil {
        return errInter.(error)
    }
    return nil
}

service 的测试用例

为了保证 service 实现的正确性,我们为 service.go 写了几个测试用例。

day3/service_test.go

定义结构体 Foo,实现2个方法,导出方法 Sum 和非导出方法 sum。

type Foo int

type Args struct { Num1, Num2 int}

func (f Foo) Sum(args Args, reply *int) error {
    *reply = args.Num1 + args.Num2
    return nil
}

// it's not a exported Method
func (f Foo) sum(args Args, reply *int) error {
    *reply = args.Num1 + args.Num2
    return nil
}
// 这里要注意,是两个不一样的函数,后面的测试中要注意写的函数名,会影响测试结果

func _assert(condition bool, msg string, v ...interface{}) {
    if !condition {
        panic(fmt.Sprintf("assertion failed: " + msg, v...))
    }
}

测试 newService 和 call 方法。

func TestNewService(t *testing.T) {
    var foo Foo
    s := newService(&foo)
    _assert(len(s.method) == 1, "wrong service Method, expect 1, but got %d", len(s.method))
    mType := s.method["Sum"]
    _assert(mType != nil, "wrong Method, Sum should't nil")
}

func TestMethodType_Call(t *testing.T) {
    var foo Foo
    s := newService(&foo)
    mType := s.method("Sum")
    
    argv := mType.newArgv()
    replyv := mType.newReplyv()
    argv.Set(reflect.ValueOf(Args{Num1: 1, Num2: 3}))
    err := s.call(mType, argv, replyv)
    _assert(err == nil && *replyv.Interface().(*int) == 4 && mType.NumCalls() == 1, "failed to call Foo.Sum")
}

这里的测试,卡了我大约2天了,开始一直没搞明白为什么注册的方法一直是 "Sum",而不是 "sum",然而,我一直在service.go里找,各种print打印相关信息,也还是找不出个所以然,其实我犯了个很低级的错误,service.go这一类是高度抽象的,一般不会有很具体的内容,问题只能出在service_test.go中,在无头绪找bug的第三天,我尝试改Sum函数,发现输出的内容变了,后面注意到导出和非导出函数,好吧,原来问题出在这,Sumsum都是 Foo 有的函数,在golang中,小写字段不可从包外访问,所以注册的是大写的Sum

集成到服务端

通过反射结构体已经映射为服务,但请求的处理还没有完成。从接收到请求到回复还差以下几个步骤:

  • 根据入参类型,将请求的 body 反序列化。
  • 调用service.call,完成方法调用。
  • 将 reply 序列化为字节流,构造响应报文,返回。

回到代码本身,补全之前在server.go中遗留的2个 TODO 任务readRequesthandleRequest即可。

在这之前,我们还需要为 Server 实现一个方法Register

day3/server.go

// Server represents an RPC Server
type Server struct {
    service sync.Map
}

// Register publishes in the server the set of methods 
func (server *Server) Register(rcvr interface{}) error {
    s := newService(rcvr)
    if _, dup := server.serviceMap.LoadOrStore(s.name, s); dup {
        return errors.New("rpc: service already defined: ", + s.name)
    }
    return nil
}

// Register publishes the receiver's methods in the DefaultServer
func Register(rcvr interface{}) error {
    return DefaultServer.Register(rcvr)
}

配套实现findService方法,即通过ServiceMethod从 serviceMap 中找到对应的 service。

func (server *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) {
    dot := strings.LastIndex(serviceMethod, ".")
    if dot < 0 {
        err := errors.New("rpc server: service/method request ill-formed: " + serviceMethod)
        return
    }
    serviceName, methodName := serviceMethod[:dot], serviceMethod[dot+1:]
    svci, ok := server.serviceMap.Load(serviceName)
    if !ok {
        err := errors.New("rpc server: can't find service " + serviceName)
        return
    }
    svc = svci.(*service)
    mtype = svc.method[methodName]
    if mtype == nil {
        err = errors.New("rpc server: can't find method " + methodName)
    }
    return
}

findService的实现看似比较繁琐,但是逻辑还是非常清晰的。因为ServiceMethod的构成是 "Service.Method",因此先将其分割成2部分,第一部分是 Service 的名称,第二部分即方法名。现在 serviceMap 中找到对应的 service 实例,再从 service 实例的 method 中,找到对应的 methodType。

准备工具已经就绪,我们首先补全 readRequest 方法:

// request stores all information of a call
type request struct {
    h            *codec.Header // header of request
    argv, replyv reflect.Value // argv and replyv of request
    mtype        *methodType
    svc          *service
}

func (server *Server) readRequest(cc codec.Codec) (*reqeust, error) {
    h, err := server.readRequest(cc)
    if err != nil {
        return nil, err
    }
    req := &request{h: h}
    req.svc, req.mtype, err = server.findService(h.ServiceMethod)
    if err != nil {
        return req, err
    }
    req.argv = req.mtype.newArgv()
    req.replyv = req.mtype.newReplyv()
    
    // make sure that argvi is a pointer, ReadBody need a pointer as parameter
    argvi := req.argv.Interface()
    if req.argv.Type().Kind() != reflect.Ptr {
        argvi = req.argv.Addr().Interface()
    }
    if err = cc.ReadBody(argvi); err != nil {
        log.Println("rpc server: read body err: ", err)
        return req, err
    }
    return req, nil
}

readRequest 方法中最重要的部分,即通过newArgv()newReplyv()两个方法创建出两个入参实例,然后通过cc.ReadBody()将请求报文反序列化为第一个入参 argv,在这里同样要注意 argv 可能是值类型,也可能是指针类型,所以处理方式有点差异。

接下来补全 handleRequest 方法:

func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup) {
    defer wg.Done()
    err := req.svc.call(req.mtype, req.argv, req.replyv)
    if err != nil {
        req.h.Error = err.Error()
        server.sendResponse(cc, req.h, invalidRequest, sending)
        return
    }
    server.sendResponse(cc, req.h, replyv.Interface(), sending)
}

相对于 readRequest,handleRequest 的实现非常简单,通过req.svc.call完成方法调用,将 replyv 传递给 sendResponse 完成序列化即可。

到这里,今天所有内容已实现完成,成功在服务端实现了服务注册与调用。

Demo

最后,修改下 main 验证成果。

day3/main/main.go

第一步,定义结构体 Foo 和方法 Sum。

package main

import (
    "geerpc"
    "log"
    "net"
    "sync"
    "time"
)

type Foo int

type Args struct{ Num1, Num2 int }

func (f Foo) Sum(args Args, reply *int) error {
    *reply = args.Num1 + args.Num2
    return nil
}

第二步,注册 Foo 到 Server 中,并启动 RPC 服务。

func startServer(addr chan string) {
    var foo Foo
    if err := geerpc.Register(&foo); err != nil {
        log.Fatal("register error: ", err)
    }
    //pick a free port
    l, err := net.Listen("tcp", ":0")
    if err != nil {
        log.Fatal("network error: ", err)
    }
    log.Println("start rpc server on", l.Addr())
    addr <- l.Addr().String()
    geerpc.Accept(l)
}

第三步,构造参数,发送 RPC 请求,并打印结果。

func main() {
    log.SetFlags(0)
    addr := make(chan string)
    go startServer(addr)
    client, _ := geerpc.Dial("tcp", <-addr)
    defer func() {
        _ = client.Close()
    }()
    
    time.Sleep(time.Second)
    // send request & receive response
    var wg sync.WaitGroup
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func(i int) {
            defer wg.Done()
            args := &Args{Num1: i, Num2: i * i}
            var reply int
            if err := client.Call("Foo.Sum", args, &reply); err != nil {
                log.Fatal("call Foo.Sum error: ", err)
            }
            log.Printf("%d + %d = %d", args.Num1, args.Num2, reply)
        }(i)
    }
    wg.Wait()
}

运行结果如下:

rpc server: register Foo.Sum
start rpc server on [::]:57509
0 + 0 = 0
2 + 4 = 6
4 + 16 = 20
3 + 9 = 12
1 + 1 = 2

day4. 超时处理

为什么要超时处理机制

超时处理是 RPC 框架一个比较基本的能力,如果缺少超时处理机制,无论是服务端还是客户端都容易因为网络或其他错误导致挂死,资源耗尽,这些问题的出现大大降低了服务的可用性。因此,我们需要在 RPC 框架中加入超时处理的能力。

纵观整个远程调用的过程,需要客户端处理超时的地方有:

  • 与服务端建立连接,导致的超时。
  • 发送请求到服务端,写报文导致的超时。
  • 等待服务端处理时,等待处理导致的潮实 (比如服务端已挂死,迟迟不响应)
  • 从服务端接收响应时,读报文导致的超时。

需要服务端处理超时的地方有:

  • 读取客户端请求报文时,读报文导致的超时。
  • 发送响应报文时,写报文导致的超时。
  • 调用映射服务的方法时,处理报文导致的超时。

GeeRPC 在3个地方添加了超时处理机制。分别是:

  • 客户端创建连接时。
  • 客户端Client.Call()整个过程导致的超时 (包含发送报文,等待处理,接收报文所有阶段)。
  • 服务端处理报文,即Server.handleRequest超时。

创建连接超时

为了实现上的简单,将超时设定放在了 Option 中。ConnectTimeout的默认值为 10s,HandleTimeout默认值为0,即不设限。

day4/server.go

type Option struct {
    MagicNumber    int // MagicNumber marks this's a geerpc request
    CodecType      codec.Type // client may choose different Codec to encode body
    ConnectTimeout time.Duration // 0 means no limit
    HandleTimeout  time.Duration
}

var DefaultOption = &Option {
    MagicNumber:    MagicNumber,
    CodecType:      codec.GobType,
    ConnectTimeout: time.Second * 10,
}

客户端连接超时,只需要为 Dial 添加一层超时处理的外壳即可。

day4/client.go

type clientResult struct {
    client *Client
    err    error
}

type newClientFunc func(conn net.Conn, opt *Option) (client *Client, err error) {
    opt, err := parseOptions(opts...)
    if err != nil {
        return nil, err
    }
    conn, err := net.DialTimeout(network, address, opt.ConnectTimeout)
    if err != nil {
        return nil, err
    }
    // close the connection if client is nil
    defer func() {
        if err != nil {
            _ = conn.Close()
        }
    }()
    ch := make(chan clientResult)
    go func() {
        client, err := f(conn, opt)
        ch <- clientResult{client: client, err: err}
    }()
    select {
        case <-time.After(opt.ConnectTimeout):
            return nil, fmt.Errorf("rpc client: connect timeout: expect within %s", opt.ConnectTimeout)
        case result := <-ch:
            return result.client, result.err
    }
}

// Dial connects to an RPC server at the specified network address
func Dial(network, address string, opts ...*Option) (*Client, error) {
    return dialTimeout(NewClient, network, address, opts...)
}

在这里实现了一个超时处理的外壳dialTimeout,这个壳将NewClient作为入参,在2个地方添加了超时处理的机制。

  1. net.Dial替换为net.DialTimeout,如果连接创建超时,将返回错误。
  2. 使用子协程执行NewClient,执行完成后则通过信道 ch 发送结果,如果time.After()信道先接收到消息,则说明NewClient执行超时,返回错误。

Client.Call 超时

Client.Call的超时处理机制,使用 context 包实现,控制权交给用户,控制更为灵活。

// Call invokes the named function, waits for it to complete,
// and returns its error status.
func (client *Client) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
    call := client.Go(serviceMethod, args, reply, make(chan *Call, 1))
    select {
    case <-ctx.Done():
        client.removeCall(call.Seq)
        return errors.New("rpc client: call failed: " + ctx.Err().Error())
    case call := <-call.Done:
        return call.Error
    }
}

用户可以使用context.WithTimeout创建具备超时检测能力的 context 对象来控制,例如:

ctx, _ := context.WithTimeout(context.Background(), time.Second)
var reply int
err := client.Call(ctx, "Foo.Sum", &Args{1, 2}, &reply)
...

服务端处理超时

这一部分的实现与客户端很接近,使用time.After()结合select + chan完成。

day4/server.go

func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) {
    defer wg.Done()
    called := make(chan struct{})
    sent := make(chan struct{})
    go func() {
        err := req.svc.call(req.mtype, req.argv, req.replyv)
        called <- struct{}{}
        if err != nil {
            req.h.Error = err.Error()
            server.sendResponse(cc, req.h, invalidRequest, sending)
            sent <- struct{}{}
            return 
        }
    }()
    
    if timeout == 0 {
        <-called
        <-sent
        // 从信道获取值,忽略结果(类似于pop())
        return
    }
    select {
    case <-time.After(timeout):
        req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout)
        server.sendResponse(cc, req.h, invalidRequest, sending)
    case <-called:
        <-sent
    }
}

这里需要确保sendResponse仅调用一次,因此将整个过程拆分为calledsent两个阶段,在这段代码中只会发生如下两种情况:

  • called 信道接收到消息,代表处理没有超时,继续执行 sendResponse
  • time.After()先于 called 接收到消息,说明处理已经超时,called 和 sent 都将被阻塞。在case <-time.After(timeout)处调用sendResponse

测试用例

day4/client_test.go

func TestClient_dialTimeout(t *testing.T) {
    t.Parallel()
    l, _ := net.Listen("tcp", ":0")
    
    f := func(conn net.Conn, opt *Option) (client *Client, err error) {
        _ = conn.Close()
        time.Sleep(time.Second * 2)
        return nil, nil
    }
    t.Run("timeout", func(t *testing.T) {
        _, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: time.Second})
        _assert(err != nil && strings.Contains(err.Error(), "connect timeout"), "expect a timeout error")
    })
    t.Run("0", func(t *testing.T) {
        _, err := dialTimeout(f, "tcp", l.Addr().String(), &Option{ConnectTimeout: 0})
        _assert(err == nil, "0 means no limit")
    }) 
}

第二个测试用例,用于测试处理超时。Bar.Timeout耗时2s,场景一:客户端设置超时时间为1s,服务端无限制;场景二,服务端设置超时时间为1s,客户端无限制。

type Bar int

func (b Bar) Timeout(argv int, reply *int) error {
    time.Sleep(time.Second * 2)
    return nil
}

func startServer(addr chan string) {
    var b Bar
    _ = Register(&b)
    // pick a free port
    l, _ := net.Listen("tcp", ":0")
    addr <- l.Addr().String()
    Accept(l)
}

func TestClient_Call(t *testing.T) {
    t.Parallel()
    addrChh := make(chan string)
    go startServer(addrCh)
    addr := <-addrCh
    time.Sleep(time.Second)
    t.Run("client timeout", func(t *testing.T) {
        client, _ := Dial("tcp", addr)
        ctx, _ := context.WithTimeout(context.Background(), time.Second)
        var reply int
        err := client.Call(ctx, "Bar.Timeout", 1, &reply)
        _assert(err != nil && strings.Contains(err.Error(), ctx.Err().Error()), "expect a timeout error")
    })
    t.Run("server handle timeout", func(t *testing.T) {
        client, _ := Dial("tcp", addr, &Option{
            HandleTimeout: time.Second,
        })
        var reply int
        err := client.Call(context.Background(), "Bar.Timeout", 1, &reply)
        _assert(err != nil && strings.Contains(err.Error(), ctx.Err().Error()), "expect a timeout error")
    })    
}

day5. 支持HTTP协议

  • 支持 HTTP 协议
  • 基于 HTTP 实现一个简单的 Debug 页面,代码约 150 行。

支持 HTTP 协议需要做什么?

Web 开发中,我们经常使用 HTTP 协议中的 HEAD、GET、POST 等方式发送请求,等待响应。但 RPC 的消息格式与标准的 HTTP 协议并不兼容,在这种情况下,就需要一个协议的转换过程。HTTP 协议的 CONNECT 方法恰好提供了这个能力,CONNECT 一般用于代理服务。

假设浏览器与服务器之间的 HTTPS 通信都是加密的,浏览器通过代理服务器发起 HTTPS 请求时,由于请求的站点地址和端口号都是加密保存在 HTTPS 请求报文头中的,代理服务器如何直到往哪里发送请求呢?为了解决这个问题,浏览器通过 HTTP 明文形式向代理服务器发送一个 CONNECT 请求告诉代理服务器目标地址和端口,代理服务器接收到这个请求后,会在对应端口和目标站点建立一个 TCP 连接,连接建立成功后返回 HTTP 200 状态码告诉浏览器与该站点的加密通道已经完成。接下来代理服务器仅需透传浏览器和服务器之间的加密数据包即可,代理服务器无需解析 HTTPS 报文。

举一个简单的例子:

  1. 浏览器向代理服务器发送 CONNECT 请求。
CONNECT jaydenchang.top:443 HTTP/1.0 
  1. 代理服务器返回 HTTP 200 状态码表示连接已经建立。
HTTP/1.0 200 Connection Established
  1. 之后浏览器和服务器开始 HTTPS 握手并交换加密数据,代理服务器只负责传输彼此的数据包,并不能读取具体数据内容 (代理服务器也可以选择安装可信根证书解密 HTTPS 报文)。

事实上,这个过程其实是通过代理服务器将 HTTP 协议转换为 HTTPS 协议的过程。对 RPC 服务端来说,需要做的事是将 HTTP 协议转换为 RPC 协议,对客户端来说,需要新增通过 HTTP CONNECT 请求创建连接的逻辑。

服务端支持 HTTP 协议

那通信过程应该是这样的:

  1. 客户端向 RPC 服务器发送 CONNECT 请求
CONNECT 10.0.0.1:9999/geerpc HTTP/1.0 
  1. RPC 服务器返回 HTTP 200 状态码表示连接建立。
HTTP/1.0 200 Connected to Gee RPC
  1. 客户端使用创建好的连接发送 RPC 报文,先发送 Option,再发送 N 个请求报文,服务端处理 RPC 请求并响应。

server.go中新增如下的方法:

day5/server.go

const (
    connected        = "200 Connected to Gee RPC"
    defaultRPCPath   = "/geerpc"
    defaultDebugPath = "/debug/geerpc"
)

// ServerHTTP implements an http.Handler that answer RPC requests
func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    if req.Method != "CONNECT" {
        w.Header().Set("Content-Type", "text/plain; charset=utf-8")
        w.WriteHeader(http.StatusMethodNotAllowed)
        _, _ = io.WriteString(w, "405 must CONNECT\n")
        return
    }
    conn, _, err := w.(http.Hijacker).Hijack()
    if err != nil {
        log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
        return
    }
    _, _ = io.WriteString(conn, "HTTP/1.0 " + connected + "\n\n")
    server.ServeConn(conn)
}

// HandleHTTP registers an HTTP handler for RPC messages on rpcPath
// It is still necessary to invoke http.Serve(), typically in a go statement
func (server *Server) HandleHTTP() {
    http.Handle(defaultRPCPath, server)
}

// HandleHTTP is a convenient approach for default server to register HTTP handlers
func HandleHTTP() {
    DefaultServer.HandleHTTP()
}

defaultDebugPath是后续 DEBUG 页面预留的地址。

在 GO 中处理 HTTP 请求是非常简单的一件事,Go 标准库中http.Handle的实现如下:

package http
// Handle registers the handler for the given pattern.
// in the DefaultServeMux.
// The documentation for ServeMux explains how patterns are matched.
func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }

第一个参数是支持统配的字符串 pattern,在这里,我们固定传入/geerpc,第二个参数是 Handler 类型,Handler 是一个接口类型,定义如下:

type Handler interface {
    ServeHTTP(w ResponseWriter, r *Request)
}

也就是说,只需要实现接口 Handler 即可作为一个 HTTP Handler 处理 HTTP 请求。接口 Handler 只定义了一个方法ServeHTTP,实现该方法即可。

客户端支持 HTTP 协议

服务端已经能够接受 CONNECT 请求,并返回了 200 状态码HTTP/1.0 200 Connected to Gee RPC,客户端要做的,发起 CONNECT 请求,检查返回状态码即可成功建立连接。

day5/client.go

// NewHTTPClient new a Client instance via HTTP as transport protocol
func NewHTTPClient(conn net.Conn, opt *Option) (*Client, err) {
    _, _ = io.WriteString(conn, fmt.Sprintf("CONNECT %s HTTP/1.0\n\n", defaultRPCPath))
    
    // Require successful HTTP reesponse
    // before switching to RPC protocol
    resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
    if err == nil && resp.Status == connected {
        return NewClient(conn, opt)
    }
    if err == nil {
        err = errors.New("unexpected HTTP response: " + resp.Status)
    }
    return nil, err
}

// DialHTTP connectd to an HTTP RPC server at the specified network address 
// listening on the default HTTP RPC path.
func DialHTTP(network, address string, opts ...*Option) (*Client, error) {
    return dialTimeout(NewHTTPClient, network, address, opts...)
}

通过 HTTP CONNECT 请求建立连接后,后续的通信过程就交给 NewClient 了。

为了简化调用,提供了一个统一入口XDial

// XDial calls different functions to connect to a RPC server
// according the first parameter rpcAddr.
// rpcAddr is a general format (protocol@addr) to represent a rpc server
// eg, [email protected]:7890, [email protected]:9999, unix@/tmp/geerpc.sock
func XDial(rpcAddr string, opts ...*Option) (*Client, error) {
    parts := strings.Split(rpcAddr, "@")
    if len(parts) != 2 {
        return nil, fmt.Errorf("rpc client err: wrong format '%s', expect protocol@addr", rpcAddr)
    }
    protocol, addr := parts[0]. parts[1];
    switch protocol {
    case "http":
        return DialHTTP("tcp", addr, opts...)
    default:
        // tcp, unix or other transport protocol
        return Dial(protocol, addr, opts...)
    }
}

添加一个测试用例试一试,这个测试用例使用了 unix 协议创建 socket 连接,适用于本机内部的通信,使用上和 TCP 协议无区别。

day5/client_test.go

func TestXDial(t *testing.T) {
    if runtime.GOOS == "linux" {
        ch := make(chan struct{})
        addr := "/tmp/geerpc.sock"
        go func() {
            _ = os.Remove(addr)
            l, err := net.Listen("unix", addr)
            if err != nil {
                t.Fatal("failed to listen unix socket")
            }
            ch <- struct{}{}
            Accept(l)
        }()
        <-ch
        _, err := XDial("unix@" + addr)
        _assert(err == nil, "failed to connect unix socket")
    }
}

实现简单的 DEBUG 页面

支持 HTTP 协议的好处在于,RPC 服务仅仅使用了监听端口的/geerpc路径,在其他路径上我们可以提供诸如日志,统计等更为丰富的功能。接下来我们在/debug/geerpc上展示服务的调用统计视图。

day5/debug.go

package geerpc

import (
    "fmt"
    "html/template"
    "net/http"
)

const debugText = `<html>
	<body>
	<title>GeeRPC Services</title>
	{{range .}}
	<hr>
	Service {{.Name}}
	<hr>
		<table>
		<th align=center>Method</th><th align=center>Calls</th>
		{{range $name, $mtype := .Method}}
			<tr>
			<td align=left font=fixed>{{$name}}({{$mtype.ArgType}}, {{$mtype.ReplyType}}) error</td>
			<td align=center>{{$mtype.NumCalls}}</td>
			</tr>
		{{end}}
		</table>
	{{end}}
	</body>
	</html>`

var debug = template.Must(template.New("RPC debug").Parse(debugText))

type debugHTTP struct {
    *Server
}

type debugService struct {
    Name   string
    Method map[string]*methodType
}

// Runs at /debug/geerpc
func (server debugHTTP) ServerHTTP(w http.ResponseWriter, req *http.Request) {
    // build a sorted version of the data
    var services []debugService
    server.serviceMap.Range(func(namei, svci interface{}) bool {
        svc := svci.(*service)
        services = append(services, debugService{
            Name:   namei.(string),
            Method: svc.method,
        })
        return true
    })
    err := debug.Execute(w, services)
    if err != nil {
        _, _ = fmt.Fprintln(w, "rpc: error executing template:", err.Error())
    }
}

在这里,我们将返回一个 HTML 报文,这个报文将展示注册所有的 service 的每一个方法的调用情况。

将 debugHTTP 实例绑定的地址/debug/geerpc

func (server *Server) HandleHTTP() {
    http.Handle(defaultRPCPath, server)
    http.Handle(defaultDebugPath, debugHTTP{server})
    log.Println("rpc server debug path:", defaultDebugPath)
}

Demo

到此,我们已经迫不及待地想看看最终的效果了。

day5/main/main.go

和之前的例子相比较,将 startServer 中的geerpc.Accept()替换为了geerpc.HandleHTTP(),端口固定为 9999。

type Foo int

type Args struct { Num1, Num2 int }

func (f Foo) Sum(args Args, reply *int) error {
    *reply = args.Num1 + args.Num2
    return nil
}

func startServer(addrCh chan string) {
    var foo Foo
    l, _ := net.Listen("tcp", "9999")
    _ = geerpc.Register(&foo)
    geerpc.HandleHTTP()
    addrCh <- l.Addr().String()
    _ = http.Serve(l, nil)
}

客户端将Dial替换为DialHTTP,其余地方没有发生改变。

func call(addrCh chan string) {
    client, _ := geerpc.DialHTTP("tcp", <-addrCh)
    defer func() { _ = client.Close() }()
    
    time.Sleep(time.Second)
    // send a request & receive response
    var wg sync.WaitGroup
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func(i int) {
            defer wg.Done()
            args := &Args{Num1: i, Num2: i * i}
            var reply int
            if err := client.Call(context.Background(), "Foo.Sum", args, &reply);err != nil {
                log.Fatal("call Foo.Sum error:", err)
            }
            log.Fatal("%d + %d = %d", args.Num1, args.Num2, reply)
        }(i)
    }
    wg.Wait()
}

func main() {
    log.SetFlags(0)
    ch := make(chan string)
    go call(ch)
    startServer(ch)
}

main 函数中,我们在最后调用startServer,服务启动后将一直等待。

运行结果如下:

main$ go run.
rpc server: register Foo.Sum
rpc server debug path: /debug/geerpc
4 + 16 = 20
3 + 9 = 12
0 + 0 = 0
2 + 4 = 6
1 + 1 = 2

服务已经启动,此时我们如果在浏览器中访问 localhost:9999/debug/geerpc,将会看到:

day6. 负载均衡

  • 通过随机选择和 Round Robin 轮询调度算法实现服务端负载均衡,约 250 行代码。

负载均衡策略

假设有多个服务实例,每个实例提供相同的功能,为了提高整个系统的吞吐量,每个实例部署在不同的机器上。客户端可以选择任意一个实例进行调用,获取想要的结果。那如何选择呢?取决了负载均衡的策略。对于 RPC 框架来说,我们可以很容易地想到这么几种策略:

  • 随机选择策略 - 从服务列表中随机选择一个。
  • 轮询算法 (Round Robin) - 依次调度不同的服务器,每次调度执行 i = (i + 1) mode n。
  • 加权轮询 (Weight Round Robin) - 在轮询算法的基础上,为每个服务实例设置一个权重,高性能的机器赋予更高的权重,也可以根据服务实例的当前的负载情况做动态的调整,例如考虑最近 5 分钟部署服务器的 CPU 、内存消耗情况。
  • 哈希 / 一致性哈希策略 - 依据请求的某些特征,计算一个 hash 值,根据 hash 值将请求发送到对应的机器,一致性 hash 还可以解决服务实例动态添加情况下,调度抖动的问题。一致性哈希的一个典型应用场景是分布式缓存服务。

服务发现

负载均衡的前提是有多个服务实例,那我们首先实现一个最基础的服务发现模块 Discovery。为了与通信部分解耦,这部分的代码统一放置在 xclient 子目录下。

定义 2 个类型:

  • SelectMode 代表不同的负载均衡策略,简单起见,GeeRPC 仅实现 Random 和 RoundRobin 两种策略。
  • Discovery 是一个接口类型,包含了服务发现所需要的最基本的接口。
    • Refresh()从注册中心更新服务列表。
    • Update(servers []string)手动更新服务列表。
    • Get(mode SelectMode)根据负载均衡策略,选择一个服务实例。
    • GetAll()返回所有的服务实例。

day6/xclient/discovery.go

package xclient

import (
    "errors"
    "math"
    "math/rand"
    "sync"
    "time"
)

type SelectMode int

const (
    RandomSelect SelectMode = iota // select randomly
    RoundRobinSelect               // select using Robbin algorithm
)

type Discovery interface {
    Refresh() error // refresh from remote registry
    Update(servers []string) error
    Get(mode SelectMode) (string, error)
    GetAll() ([]string, error)
}

紧接着,我们实现一个不需要注册中心,服务列表由手工维护的服务发现的结构体:MultiServersDiscovery

type MultiServersDiscovery struct {
    r       *rand.Rand   // generate random number
    mu      sync.RWMutex // protect following
    servers []string 
    index   int          // record the selected position for robin algorithm
}

// NewMultiServerDiscovery creates a MultiServersDiscovery instance
func NewMultiServerDiscovery(servers []string) *MultiServersDiscovery {
    d := &MultiServersDiscovery {
        servers: servers,
        r:       rand.New(rand.NewSource(time.Now().UnixNano())),
    }
    d.index = d.r.Intn(math.MaxInt32 - 1)
    return d
}
  • r 是一个产生随机数的实例,初始化时使用时间戳设定随机数种子,避免每次产生相同的随机数序列。
  • index 记录 Round Robin 算法已经轮询到的位置,为了避免每次从 0 开始,初始化时随机设定一个值。

然后,实现 Discovery 接口

var _ Discovery = (*MultiServersDiscovery)(nil)

// Refresh doesn't make sense for MultiServersDiscovery, so ignore it
func (d *MultiServersDiscovery) Refresh() error {
    return nil
}

// Update the servers of discovery dynamically if needed
func (d *MultiServersDiscovery) Update(servers []string) error {
    d.mu.Lock()
    defer d.mu.Unlock()
    d.servers = servers
    return nil
}

// Get a server according to mode
func (d *MultiServersDiscovery) Get(mode SelectMode) (string, error) {
    d.mu.Lock()
    defer d.mu.Unlock()
    n := len(d.servers)
    if n == 0 {
        return "", errors.New("rpc discovery: no available servers")
    }
    switch mode {
    case RandomSelect:
        return d.servers[d.r.Intn(n)], nil
    case RoundRobinSelect:
        s := d.servers[d.index % n] // servers could be updated, so mode n to ensure safety
        d.index = (d.index + 1) % n
        return s, nil
    default:
        return "", errors.New("rpc discovery: not supported select mode")
    }
}

// returns all servers in discovery
func (d *MultiServersDiscovery) GetAll() ([]string, error) {
    d.mu.RLock()
    defer d.mu.RUnlock()
    // return a copy of d.servers
    servers := make([]string, len(d.servers), len(d.servers))
    copy(servers, d.servers)
    return servers, nil
}

支持负载均衡的客户端

接下来,我们向用户暴露一个支持负载均衡的客户端的 XClient。

day6/xclient/xclient.go

package xclient

import (
    "context"
    . "geerpc"
    "io"
    "reflect"
    "sync"
)

type XClient struct {
    d       Discovery
    mode    SelectMode
    opt     *Option
    mu      sync.Mutex // protect following
    clients map[string]*Client
}

var _ io.Closer = (*XClient)(nil)

func NewXClient(d Discovery, mode SelectMode, opt *Option) *XClient {
    return &XClient{d: d, mode: mode, opt: opt, clients make(map[string]*Client)}
}

func (xc *XClient) Close() error {
    xc.mu.Lock()
    defer xc.mu.Unlock()
    for key, client := range xc.clients {
        // I hava no idea how to deal with error, just ignore it.
        _ = client.Close()
        delete(xc.clients, key)
    }
    return nil
}

XClient 的构造函数需要传入三个参数,服务发现实例 Discovery、负载均衡模式 SelectMode 以及协议选项 Option。为了尽量地复用已经创建好的 Socket 连接,使用 clients 保存创建成功的 Client 实例,并提供 Close 方法在结束后,关闭已经创建的连接。

接下来,实现客户端最基本的功能Call

func (xc *Client) dial(rpcAddr string) (*Client, error) {
    xc.mu.Lock()
    defer xc.mu.Unlock()
    client, ok := xc.clients[rpcAddr]
    if ok && !client.IsAvailable() {
        _ = client.Close()
        delete(xc.clients, rpcAddr)
        client = nil
    }
    if client == nil {
        var err error
        client, err = XDial(rpcAddr, xc.opt)
        if err != nil {
            return nil, err
        }
        xc.clients[rpcAddr] = client
    }
    return client, nil
}

func (xc *XClient) call(rpcAddr string, ctx context.Context, serviceMethod string, args, reply interface{}) error {
    client, err := xc.dial(rpcAddr)
    if err != nil {
        return err
    }
    return client.Call(ctx, serviceMethod, args, reply)
}

// Call invokes the named function, waits for it to complete,
// and returns its error status.
// xc will choose a proper server.
func (xc *XClient) Call(ctx context.Context, serviceMethod string, args, reply interface{}) error {
    rpcAddr, err := xc.d.Get(xc.mode)
    if err != nil {
        return err
    }
    return xc.call(rpcAddr, ctx, serviceMethod, args, reply)
}

我们将复用 Client 的能力封装在方法dial中,dial的处理逻辑如下:

  1. 检查xc.clients是否有缓存的 Client,如果有,检查是否时可用状态,如果是,则返回缓存的 Client,如果不可用,则从缓存中删除。
  2. 如果步骤 1 没有返回缓存的 Client,则说明需要创建新的 Client,缓存并返回。

另外,我们为 XClient 添加一个常用功能:Broadcast

// Broadcast invokes the named function for every server registered in discovery
func (xc *XClient) Broadcast(ctx context.Context, serviceMethod string, args, reply interface{}) error {
    servers, err := xc.d.GetAll()
    if err != nil {
        return err
    }
    var wg sync.WaitGroup
    var mu sync.Mutex // protect e and replyDone
    var e error
    replyDone := reply == nil // if reply is nil, don't need to set value
    ctx, cancel := context.WithCancel(ctx)
    for _, rpcAddr := range servers {
        wg.Add(1)
        go func(rpcAddr string) {
            defer wg.Done()
            var clonedReply interface{}
            if reply != nil {
                clonedReply = reflect.New(reflect.ValueOf(reply).Elem().Type()).Interface()                
            }
            err := xc.call(rpcAddr, ctx, serviceMethod, args, clonedReply)
            mu.Lock()
            if err != nil && e == nil {
                e = err
                cancel() // if any call failed, cancel unfinished calls
            }
            if err == nil && !replyDone {
                reflect.ValueOf(reply).Elem().Set(reflect.ValueOf(clonedReply).Elem())
                replyDone = true
            }
            mu.Unlock()
        }(rpcAddr)
    }
    wg.Wait()
    return e
}

Broadcast 将请求广播到所有的服务实例,如果任意一个实例发生错误,则返回其中一个错误;如果调用成功,则返回其中一个的结果。有以下几点需要注意:

  1. 为了提升性能,请求是并发的。
  2. 并发情况下需要使用互斥锁保证 error 和 reply 能被正确赋值。
  3. 借助context.WithCancel确保有错误发生时,快速失败。

Demo

首先,启动 RPC 服务的代码还是类似的,Sum 时正常的方法,Sleep 用于验证 XClient 的超时机制能否正常运作。

day6/main/main.go

package main

import (
    "context"
    "geerpc"
    "geerpc/xclient"
    "log"
    "net"
    "sync"
    "time"
)

type Foo int

type Args struct{ Num1, Num2 int }

func (f Foo) Sum(args Args, reply *int) error {
    time.Sleep(time.Second * time.Duration(args.Num1))
    *reply = args.Num1 + args.Num2
    return nil
}

func (f Foo) Sleep(args Args, reply *int) error {
    var foo Foo
    l, _ := net.Listen("tcp", ":0")
    server := geerpc.NewServer()
    // send request & receive response
    var wg sync.WaitGroup
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func(i int) {
            defer wg.Done()
            foo(xc, context.Background(), "call", "Foo.Sum", &Args{Num1: i, Num2: i * i})
        }(i)
    }
    wg.Wait()
}

func broadcast(addr1, addr2 string) {
    d := xclient.NewMultiServerDiscovery([]string{"tcp" + addr1, "tcp@" + addr2})
    xc := xclient.NewXClient(d, xclient.RandomSelect, nil)
    defer func() { _ = xc.Close() }()
    var wg sync.WaitGroup
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func(i int) {
            defer wg.Done()
            foo(xc, context.Background(), "broadcast", "Foo.Sum", &Args{Num1: i, Num2: i * i})
            // expect 2 - 5 timeout
            ctx, _ := context.WithTimeout(context.Background(), time.Second * 2)
            foo(xc, ctx, "broadcast", "Foo.Sleep", &Args{Num1: i, Num2: i * i})
        }(i)
    }
    wg.Wait()
}

func main() {
    log.SetFlags(0)
    ch1 := make(chan string)
    ch2 := make(chan string)
    // start two servers
    go startServer(ch1)
    go startServer(ch2)
    
    addr1 := <-ch1
    addr@ := <-ch2
    
    time.Sleep(time.Second)
    call(addr1, addr2)
    broadcast(addr1, addr2)
}

运行结果如下

*main.Foo    Sleep
rpc server: register Foo.Sleep
*main.Foo    Sum
rpc server: register Foo.Sum
*main.Foo    Sleep
rpc server: register Foo.Sleep
*main.Foo    Sum
rpc server: register Foo.Sum
call Foo.Sum success: 3 + 9 = 12
call Foo.Sum success: 4 + 16 = 20
call Foo.Sum success: 2 + 4 = 6
call Foo.Sum success: 0 + 0 = 0
call Foo.Sum success: 1 + 1 = 2
broadcast Foo.Sum success: 4 + 16 = 20
broadcast Foo.Sum success: 2 + 4 = 6
broadcast Foo.Sum success: 1 + 1 = 2
broadcast Foo.Sum success: 0 + 0 = 0
broadcast Foo.Sum success: 3 + 9 = 12
broadcast Foo.Sleep success: 0 + 0 = 0
broadcast Foo.Sleep success: 1 + 1 = 2
broadcast Foo.Sleep error: rpc client: call failed: context deadline exceeded
broadcast Foo.Sleep error: rpc client: call failed: context deadline exceeded
broadcast Foo.Sleep error: rpc client: call failed: context deadline exceeded

day7. 服务发现与注册中心

// header 一定要一致,不然一直错

  • 实现一个简单的注册中心,支持服务注册,接收心跳等功能。
  • 客户端实现基于注册中心的服务发现机制。

注册中心的位置

注册中心的位置如上图所示。注册中心的好处在于,客户端和服务端都只需要感知注册中心的存在,而无需感知对方的存在。更具体一点:

  1. 服务端启动后,向注册中心发送注册信息,注册中心得知该服务已经启动,处于可用状态。一般来说,服务端还需要定期向注册中心发送心跳,证明自己还活着。
  2. 客户端向注册中心询问,当前哪天服务是可用的,注册中心将可用的服务列表返回客户端。
  3. 客户端根据注册中心得到的服务列表,选择其中一个发起调用。

如果没有注册中心,就像 GeeRPC 第六天实现的一样,客户端需要硬编码服务端的地址,而且没有机制保证服务端是否处于可用状态。当然注册中心的功能还有很多,比如配置的动态同步,通知机制等。比较常用的注册中心有 etcd、zookeeper、consul,一般比较出名的微服务或者 RPC 框架,这些主流的注册中心都是支持的。

Gee Registry

主流的注册中心 etcd、zookeeper 等功能强大,与这类注册中心的对接代码量是比较大的,需要实现的接口很多。GeeRPC 选择自己实现一个简单的支持心跳保活的注册中心。

GeeRegistry 的代码独立放置在子目录 registry 中。

首先定义 GeeRegistry 结构体,默认超时时间设置为 5 min,也就是说,任何注册的服务超过 5 min,即视为不可用状态。

day7/registry/registry.go

type GeeRegistry struct {
    timeout time.Duration
    mu      sync.Mutex // protect following
    servers map[string]*ServerItem
}

type ServerItem struct {
    Addr string
    start time.Time
}

const (
	defaultPath    = "/geerpc/registry"
    defaultTimeout = time.Minute * 5
)

// New create a registry instance with timeout setting
func New(timeout time.Duration) *GeeRegistry {
    return &GeeRegistry {
        servers: make(map[string]*ServerItem),
        timeout: timeout,
    }
}

var DefaultGeeRegister = New(defaultTimeout)

为 GeeRegistry 实现添加服务实例和返回服务列表的方法。

  • putServer:添加服务实例,如果服务已存在,则更新 start。
  • aliveServers:返回可用的服务列表,如果存在超时的服务,则删除。
func (r *GeeRegistry) putServer(addr string) {
    r.mu.Lock()
    defer r.mu.Unlock()
    s := r.servers[addr]
    if s == nil {
        r.servers[addr] = &ServerItem{Addr: addr, start: time.Now()}
    } else {
        s.start = time.Now() // if exists, update start time to keep alive
    }
}

func (r *GeeRegistry) aliveServers() []string {
    r.mu.Lock()
    defer r.mu.Unlock()
    var alive []string
    for addr, s := range r.servers {
        if r.timeout == 0 || s.start.Add(r.timeout).After(time.Now()) {
            alive = append(alive, addr)
        } else {
            delete(r.servers, addr)
        }
        sort.Strings(alive)
        return alive
    }
}

为了实现上的简单,GeeRegistry 采用 HTTP 协议提供服务,且所有的有用信息都承载在 HTTP Header 中。

  • Get:返回所有可用的服务列表,通过自定义字段 X-Geerpc-Servers 承载。
  • Post:添加服务实例或发送心跳,通过自定义字段 X-Geerpc-Server 承载。

这里要注意,Get 和 Post 各自使用的 header 一定要一样,不然就会出现rpc discovery: no available servers的错误

func (r *GeeRegistry) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    switch req.Method {
    case "GET":
        // keep it simple, server is in req.Header
        w.Header().Set("X-Geerpc-Servers", strings.Join(r.aliveServers(), ","))
    case "POST":
        // keep it simple, server is in req.Header
        addr := req.Header.Get("X-Geerpc-Server")
        if addr == "" {
            w.WriteHeader(http.StatusInternalServerError)
            return
        }
        r.putServer(addr)
    default:
        w.WriteHeader(http.StatusInternalServerError)
    }
}

// HandleHTTP registers an HTTP handler for GeeRegistry messages on registryPath
func (r *GeeRegistry) HandleHTTP(registryPath string) {
    http.Handle(registryPath, r)
    log.Println("rpc registry path:", registryPath)
}

func HandleHTTP() {
    DefaultGeeRegister.HandleHTTP(defaultPath)
}

另外,提供 Heartbeat 方法,便于服务启动时定时向注册中心发送心跳,默认周期比注册中心设置的过期时间少 1 min。

// Heartbeat send a heartbeat message every once in a while
// it's a helper function for a server to register or send heartbeat
func Heartbeat(registry, addr string, duration time.Duration) {
    if duration == 0 {
        // make sure there is enough time to send heart beat
        // before it's removed from registry
        duration = defaultTimeout - time.Duration(1)*time.Minute
    }
    var err error
    err = sendHeartbeat(registry, addr)
    go func() {
        t := time.NewTicker(duration)
        for err == nil {
            <-t.C
            err = sendHeartbeat(registry, addr)
        }
    }()
}

func sendHeartbeat(registry, addr string) error {
    log.Println(addr, "send heart beat to registry", registry)
    httpClient := &http.Client{}
    req, _ := http.NewRequest("POST", registry, nil)
    req.Header.Set("X-Geerpc-Server", addr)
    if _, err := httpClient.Do(req); err != nil {
        log.Println("rpc server: heart beat err:", err)
        return err
    }
    return nil
}

GeeRegistryDiscovery

在 xclient 中对应实现 Discovery。

day7/xclient/discovery_gee.go

package xclient

type GeeRegistryDiscovery struct {
    *MultiServersDiscovery 
    registry   string
    timeout    time.Duration
    lastUpdate time.Time
}

const defaultUpdateTimeout = time.Second * 10

func NewGeeRegistryDiscovery(registerAddr string, timeout time.Duration) *GeeRegistryDiscovery {
    if timeout == 0 {
        timeout = defaultUpdateTimeout
    }
    d := &GeeRegistryDiscovery {
        MultiServerDiscovery: NewMultiServerDiscovery(make([]string, 0))
        registry:             registerAddr,
        timeout:              timeout,
    }
    return d
}
  • GeeRegistryDiscovery 嵌套了 MultiServersDiscovery,很多能力可以复用。
  • registry 即注册中心的地址。
  • timeout 服务列表的过期时间。
  • lastUpdate 是代表最后从注册中心更新服务列表的时间,默认 10s 过期,即 10s 之后,需要从注册中心更新新的列表。

实现 Update 和 Refresh 方法,超时重新获取的逻辑在 Refresh 中实现:

func (d *GeeRegistryDiscovery) Update(servers []string) error {
    d.mu.Lock()
    defer d.mu.Unlock()
    d.servers = servers
    d.lastUpdate = time.Now()
    return nil
}

func (d *GeeRegistryDiscovery) Refresh() error {
    d.mu.Lock()
    defer d.mu.Unlock()
    if d.lastUpdate.Add(d.timeout).After(time.Now()) {
        return nil
    }
    log.Println("rpc registry: refresh servers form registry", d.registry)
    resp, err := http.Get(d.registry)
    if err != nil {
        log.Println("rpc registry refresh err:", err)
        return err
    }
    servers := strings.Split(resp.Header.Get("X-Geerpc-Server", ","))
    d.server = make([]string, 0, len(servers))
    for _, server := range servers {
        if strings.TrimSpace(server) != "" {
            d.servers = append(d.servers, strings.TrimSpace(server))
        }
    }
    d.lastUpdate = time.Now()
    return nil
}

GetGetAllMultiServersDiscovery相似,唯一不同的在于,GeeRegistryDiscovery需要先调用 Refresh 确保服务列表没有过期。

func (d *GeeRegistryDiscovery) Get(mode SelectMode) (string, error) {
    if err := d.Refresh(); err != nil {
        return "", err
    }
    return d.MultiServersDiscovery.Get(mode)
}

func (d *GeeRegistryDiscovery) GetAll() ([]string, error) {
    if err := d.Refresh(); err != nil {
        return nil, err
    }
    return d.MultiServersDiscovery.GetAll()
}

Demo

最后,依旧通过简单的 Demo 验证今天的成果。

添加函数 startRegistry,稍微修改 startServer,添加调用注册中心的Heartbeat方法的逻辑,定期向注册中心发送心跳。

day7/main/main.go

func startRegistry(wg *sync.WaitGroup) {
    l, _ := net.Listen("tcp", ":9999")
    registry.HandleHTTP()
    wg.Done()
    _ = http.Serve(l, nil)
}

func startServer(registryAddr string, wg *sync.WaitGroup) {
    var foo Foo
    l, _ := net.Listen("tcp", ":0")
    server := geerpc.NewServer()
    _ = server.Register(&foo)
    registry.Heartbeat(registryAddr)
    wg.Done()
    server.Accept(l)
}

接下来,将 call 和 broadcast 的 MultiServersDiscovery 替换为 GeeRegistryDiscovery,不再需要硬编码服务列表。

func call(registry string) {
    d := xclient.NewGeeRegistryDiscovery(registry, 0)
    xc := xclient.NewXClient(d, xclient.RandomSelect, nil)
    defer func() { _ = xc.Close() }()
    // send request & receive response
    var wg sync.WaitGroup
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func(i int) {
            defer wg.Done()
            foo(xc, context.Background(), "call", "Foo.Sum", &Args{Num1: i, Num2: i * i})
        }(i)
    }
    wg.Wait()
}

func broadcast(registry string) {
    d := xclient.NewGeeRegistryDiscovery(registry, 0)
    xc := xclient.NewXClient(d, xclient.RandomSelect, nil)
    defer func() { _ = xc.Close() }()
    var wg sync.WaitGroup
    for i := 0; i < 5; i++ {
        wg.Add(1)
        go func(i int) {
            defer wg.Done()
            foo(xc, context.Background(), "broadcast", "Foo.Sum", &Args{Num1: i, Num2: i * i})
            // expect 2- 5 timeout
            ctx, _ := context.WithTimeout(context.Background(), time.Second*2)
            foo(xc, ctx, "broadcast", "Foo.Sleep", &Args{Num1: i, Num2: i * i})
        }(i)
    }
    wg.Wait()
}

最后在 main 函数中,将所有的逻辑串联起来,确保注册中心启动后,再启动 RPC 服务端,最后客户端远程调用。

func main() {
    log.SetFlags(0)
    registryAddr := "http://localhost:9999/geerpc/registry"
    var wg sync.WaitGroup
    wg.Add(1)
    go startRegistry(&wg)
    wg.Wait()
    
    time.Sleep(time.Second)
    wg.Add(2)
    go startServer(registryAddr, &wg)
    go startServer(registryAddr, &wg)
    wg.Wait()
    
    time.Sleep(time.Second)
    call(registryAddr)
    broadcast(registryAddr)    
}

运行结果如下:

rpc registry path: /geerpc/registry
*main.Foo    Sleep
rpc server: register Foo.Sleep
*main.Foo    Sum
rpc server: register Foo.Sum
tcp@[::]:46043 send heart beat to registry http://localhost:9999/geerpc/registry
*main.Foo    Sleep
rpc server: register Foo.Sleep
*main.Foo    Sum
rpc server: register Foo.Sum
tcp@[::]:45079 send heart beat to registry http://localhost:9999/geerpc/registry
rpc registry: refresh servers from registry http://localhost:9999/geerpc/registry
call Foo.Sum success: 2 + 4 = 6
call Foo.Sum success: 4 + 16 = 20
call Foo.Sum success: 1 + 1 = 2
call Foo.Sum success: 0 + 0 = 0
call Foo.Sum success: 3 + 9 = 12
rpc registry: refresh servers from registry http://localhost:9999/geerpc/registry
broadcast Foo.Sum success: 3 + 9 = 12
broadcast Foo.Sum success: 2 + 4 = 6
broadcast Foo.Sum success: 4 + 16 = 20
broadcast Foo.Sum success: 1 + 1 = 2
broadcast Foo.Sum success: 0 + 0 = 0
broadcast Foo.Sleep success: 0 + 0 = 0
broadcast Foo.Sleep success: 1 + 1 = 2
broadcast Foo.Sleep error: rpc client: call failed: context deadline exceeded
broadcast Foo.Sleep error: rpc client: call failed: context deadline exceeded
broadcast Foo.Sleep error: rpc client: call failed: context deadline exceeded

七天时间,参照 golang 标准库 net/rpc,实现了服务端以及支持并发的客户端,并且支持选择不同的序列化与反序列化方式;为了防止服务挂死,在其中一些关键部分添加蓝超时处理机制;支持 TCP、Unix、HTTP 等多种传输协议;支持多种负载均衡模式,最后还实现了一个简易的服务注册和发现中心。

一些想法

其实我学习这个的目的是为了尝试完成 mit 6.824,它的 lab1 里,要求使用利用 rpc 来完成客户端和服务器之间的通信,但无奈,学习 golang 的时间并不是很长,恰巧留意到 geektutu 有发过 rpc 的七天项目,那么,正好,来敲一敲练练手。并且推进这个小项目的时间也不止七天。

那么来提炼一些个人觉得比较重要的知识。

典型的 RPC 调用过程

感觉,这个项目中的 rpc 的一些定义我不是很能理解,那么根据自己在别的网站上学到的,以及在做 6.824 lab1时的一些经验,重新总结下相关 rpc 的基本结构。

server

这里的定义,结合了 6.824 lab1 的实验要求进行总结,当然下面举的例子也只是简要说明下 rpc 的组成。

type Task struct {
    FileName string
    TaskType int
}

type Coordinator struct {
    task Task
}

client

type client struct {
    clientID int
    mapf     func(string, string) []KeyValue
    reducef  func(string, []string) string
}

rpc

type TaskRequest struct {
    X int
}
// 用于获取任务的请求结构体,在 lab1 中不携带信息

type TaskReply struct {
    X int
}
// 回复任务的结构体,在 lab1 中不需要携带信息

server 中定义了一个函数 func()

func (s *Server)Func() {
    // code
}

然后在 client 里有一个 CallGetFunc(), 用于远程调用。

func CallGetFunc() {
    args := TaskRequest{}
    reply := TaskReply{}
    call("Server.Func", &args, &reply)
}

然后按照下列方式调用远程调用函数:

call("StructName.FunctionName", &args, &reply),

这样子,就完成了一个基本的 rpc 调用过程。