build your own redis

开始

02. Introduction to Socket

Redis 是服务器/客户端系统的一个例子。多个客户端连接到单个服务器,服务器接收来自 TCP 连接的请求并发回响应。在开始套接字编程之前,我们需要学习几个 Linux 系统调用。

bind() 和 listen() 系统调用:bind() 将地址与套接字 fd 相关联,而

listen() 使我们能够接受与该地址的连接。

accept() 采用侦听 fd,当客户端与侦听地址建立连接时,accept() 返回一个表示连接套接字的 fd。以下是解释服务器典型工作流程的伪代码:

fd = socket()
bind(fd, address)
listen(fd)
while True:
    conn_fd = accept(fd)
    do_something_with(conn_fd)
    close(conn_fd)

read() 系统调用从 TCP 连接接收数据。write() 系统调用发送数据。close() 系统调用销毁 fd 引用的资源并回收 fd 编号。

我们引入了服务器端网络编程所需的系统调用。对于客户端,connect() 系统调用采用套接字 fd 和地址,并与该地址建立 TCP 连接。下面是客户端的伪代码:

fd = socket()
connect(fd, address)
do_something_with(fd)
close(fd)

03. Hello Server/Client

本章继续介绍套接字编程。我们将编写 2 个简单(不完整和损坏)的程序来演示上一章中的系统调用。第一个程序是服务器,它接受来自客户端的连接,读取一条消息,并写入一条回复。第二个程序是客户端,它连接到服务器,写入一条消息,并读取一条回复。让我们先从服务器开始。

// client.cpp
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <errno.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <netinet/ip.h>

static void die(const char *msg)
{
    int err = errno;
    fprintf(stderr, "[%d] %s\n", err, msg);
    abort();
}

int main()
{
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0)
    {
        die("socket()");
    }

    struct sockaddr_in addr = {};
    addr.sin_family = AF_INET;
    addr.sin_port = ntohs(1234);
    addr.sin_addr.s_addr = ntohl(INADDR_LOOPBACK); // 127.0.0.1
    int rv = connect(fd, (const struct sockaddr *)&addr, sizeof(addr));
    if (rv)
    {
        die("connect");
    }

    char msg[] = "hello";
    write(fd, msg, strlen(msg));

    char rbuf[64] = {};
    ssize_t n = read(fd, rbuf, sizeof(rbuf) - 1);
    if (n < 0)
    {
        die("read");
    }
    printf("server says: %s\n", rbuf);
    close(fd);
    return 0;
}
// server.cpp
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <errno.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <netinet/ip.h>

static void msg(const char *msg)
{
    fprintf(stderr, "%s\n", msg);
}

static void die(const char *msg)
{
    int err = errno;
    fprintf(stderr, "[%d] %s\n", err, msg);
    abort();
}

static void do_something(int connfd)
{
    char rbuf[64] = {};
    ssize_t n = read(connfd, rbuf, sizeof(rbuf) - 1);
    if (n < 0)
    {
        msg("read() error");
        return;
    }
    printf("client says: %s\n", rbuf);

    char wbuf[] = "world";
    write(connfd, wbuf, strlen(wbuf));
}

int main()
{
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0)
    {
        die("socket()");
    }

    // this is needed for most server applications
    int val = 1;
    setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val));

    // bind
    struct sockaddr_in addr = {};
    addr.sin_family = AF_INET;
    addr.sin_port = ntohs(1234);
    addr.sin_addr.s_addr = ntohl(0); // wildcard address 0.0.0.0
    int rv = bind(fd, (const sockaddr *)&addr, sizeof(addr));
    if (rv)
    {
        die("bind()");
    }

    // listen
    rv = listen(fd, SOMAXCONN);
    if (rv)
    {
        die("listen()");
    }

    while (true)
    {
        // accept
        struct sockaddr_in client_addr = {};
        socklen_t socklen = sizeof(client_addr);
        int connfd = accept(fd, (struct sockaddr *)&client_addr, &socklen);
        if (connfd < 0)
        {
            continue; // error
        }

        do_something(connfd);
        close(connfd);
    }

    return 0;
}
g++ -Wall -Wextra -O2 -g server.cpp -o server
g++ -Wall -Wextra -O2 -g client.cpp -o client

04. Protocol Parsing

我们的服务器将能够处理来自客户端的多个请求,为此我们需要实现某种“协议”,至少将请求与 TCP 字节流分开。拆分请求的最简单方法是在请求开始时声明请求的长度。让我们使用以下方案。

while (true) {
    // accept
    struct sockaddr_in client_addr = {};
    socklen_t socklen = sizeof(client_addr);
    int connfd = accept(fd, (struct sockaddr *)&client_addr, &socklen);
    if (connfd < 0) {
        continue; // error
    }
    // only serves one client connection at once
    while (true) {
        int32_t err = one_request(connfd);
            if (err) {

            }
        }
    close(connfd);
}

one_request 函数仅分析一个请求并回复,直到发生不良情况或客户端连接消失。我们的服务器一次只能处理一个连接,直到我们在后面的章节中介绍事件循环。

  1. read() 系统调用只返回内核中可用的任何数据,或者如果没有。它是负责处理不足数据的应用程序。read_full() 函数从内核读取,直到它正好得到 n 个字节
  2. 同样,如果内核缓冲区已满,write() 系统调用可以成功返回部分写入的数据,当 write() 返回的字节数少于我们需要的字节时,我们需要继续尝试。

为方便起见,我们添加了对最大请求大小的限制,并使用足够大的缓冲区来保存请求。字节序曾经是解析协议时的一个考虑因素,但今天它不太相关,所以我们只是 memcpy-ing 整数。

协议解析代码每个请求至少需要 2 个 read() 系统调用。可以使用“缓冲 IO”来减少系统调用的数量。也就是说:一次尽可能多地读取缓冲区,然后尝试解析来自该缓冲区的多个请求。鼓励读者尝试将其作为一种练习,因为它可能有助于理解后面的章节。

协议说明:本章中使用的协议是最简单的实用协议。大多数现实世界的协议都比这更复杂。有些使用文本而不是二进制数据。虽然文本协议具有人类可读的优点,但文本协议确实比二进制协议需要更多的解析,后者更易于编码和出错。使协议解析复杂化的另一件事是,某些协议没有直接的方法来拆分消息,这些协议可能使用分隔符,或者需要进一步解析来拆分消息。在协议中使用分隔符可能会增加另一个复杂性,当协议携带任意数据,因为数据中的分隔符需要“转义”。我们将在后面的章节中坚持使用简单的二进制协议。

// server
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <errno.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <netinet/ip.h>


static void msg(const char *msg) {
    fprintf(stderr, "%s\n", msg);
}

static void die(const char *msg) {
    int err = errno;
    fprintf(stderr, "[%d] %s\n", err, msg);
    abort();
}

const size_t k_max_msg = 4096;

static int32_t read_full(int fd, char *buf, size_t n) {
    while (n > 0) {
        ssize_t rv = read(fd, buf, n);
        if (rv <= 0) {
            return -1;  // error, or unexpected EOF
        }
        assert((size_t)rv <= n);
        n -= (size_t)rv;
        buf += rv;
    }
    return 0;
}

static int32_t write_all(int fd, const char *buf, size_t n) {
    while (n > 0) {
        ssize_t rv = write(fd, buf, n);
        if (rv <= 0) {
            return -1;  // error
        }
        assert((size_t)rv <= n);
        n -= (size_t)rv;
        buf += rv;
    }
    return 0;
}

static int32_t one_request(int connfd) {
    // 4 bytes header
    char rbuf[4 + k_max_msg + 1];
    errno = 0;
    int32_t err = read_full(connfd, rbuf, 4);
    if (err) {
        if (errno == 0) {
            msg("EOF");
        } else {
            msg("read() error");
        }
        return err;
    }

    uint32_t len = 0;
    memcpy(&len, rbuf, 4);  // assume little endian
    if (len > k_max_msg) {
        msg("too long");
        return -1;
    }

    // request body
    err = read_full(connfd, &rbuf[4], len);
    if (err) {
        msg("read() error");
        return err;
    }

    // do something
    rbuf[4 + len] = '\0';
    printf("client says: %s\n", &rbuf[4]);

    // reply using the same protocol
    const char reply[] = "world";
    char wbuf[4 + sizeof(reply)];
    len = (uint32_t)strlen(reply);
    memcpy(wbuf, &len, 4);
    memcpy(&wbuf[4], reply, len);
    return write_all(connfd, wbuf, 4 + len);
}

int main() {
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0) {
        die("socket()");
    }

    // this is needed for most server applications
    int val = 1;
    setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val));

    // bind
    struct sockaddr_in addr = {};
    addr.sin_family = AF_INET;
    addr.sin_port = ntohs(1234);
    addr.sin_addr.s_addr = ntohl(0);    // wildcard address 0.0.0.0
    int rv = bind(fd, (const sockaddr *)&addr, sizeof(addr));
    if (rv) {
        die("bind()");
    }

    // listen
    rv = listen(fd, SOMAXCONN);
    if (rv) {
        die("listen()");
    }

    while (true) {
        // accept
        struct sockaddr_in client_addr = {};
        socklen_t socklen = sizeof(client_addr);
        int connfd = accept(fd, (struct sockaddr *)&client_addr, &socklen);
        if (connfd < 0) {
            continue;   // error
        }

        while (true) {
            // here the server only serves one client connection at once
            int32_t err = one_request(connfd);
            if (err) {
                break;
            }
        }
        close(connfd);
    }

    return 0;
}
// client
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <errno.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <netinet/ip.h>


static void msg(const char *msg) {
    fprintf(stderr, "%s\n", msg);
}

static void die(const char *msg) {
    int err = errno;
    fprintf(stderr, "[%d] %s\n", err, msg);
    abort();
}

static int32_t read_full(int fd, char *buf, size_t n) {
    while (n > 0) {
        ssize_t rv = read(fd, buf, n);
        if (rv <= 0) {
            return -1;  // error, or unexpected EOF
        }
        assert((size_t)rv <= n);
        n -= (size_t)rv;
        buf += rv;
    }
    return 0;
}

static int32_t write_all(int fd, const char *buf, size_t n) {
    while (n > 0) {
        ssize_t rv = write(fd, buf, n);
        if (rv <= 0) {
            return -1;  // error
        }
        assert((size_t)rv <= n);
        n -= (size_t)rv;
        buf += rv;
    }
    return 0;
}

const size_t k_max_msg = 4096;

static int32_t query(int fd, const char *text) {
    uint32_t len = (uint32_t)strlen(text);
    if (len > k_max_msg) {
        return -1;
    }

    char wbuf[4 + k_max_msg];
    memcpy(wbuf, &len, 4);  // assume little endian
    memcpy(&wbuf[4], text, len);
    if (int32_t err = write_all(fd, wbuf, 4 + len)) {
        return err;
    }

    // 4 bytes header
    char rbuf[4 + k_max_msg + 1];
    errno = 0;
    int32_t err = read_full(fd, rbuf, 4);
    if (err) {
        if (errno == 0) {
            msg("EOF");
        } else {
            msg("read() error");
        }
        return err;
    }

    memcpy(&len, rbuf, 4);  // assume little endian
    if (len > k_max_msg) {
        msg("too long");
        return -1;
    }

    // reply body
    err = read_full(fd, &rbuf[4], len);
    if (err) {
        msg("read() error");
        return err;
    }

    // do something
    rbuf[4 + len] = '\0';
    printf("server says: %s\n", &rbuf[4]);
    return 0;
}

int main() {
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0) {
        die("socket()");
    }

    struct sockaddr_in addr = {};
    addr.sin_family = AF_INET;
    addr.sin_port = ntohs(1234);
    addr.sin_addr.s_addr = ntohl(INADDR_LOOPBACK);  // 127.0.0.1
    int rv = connect(fd, (const struct sockaddr *)&addr, sizeof(addr));
    if (rv) {
        die("connect");
    }

    // multiple requests
    int32_t err = query(fd, "hello1");
    if (err) {
        goto L_DONE;
    }
    err = query(fd, "hello2");
    if (err) {
        goto L_DONE;
    }
    err = query(fd, "hello3");
    if (err) {
        goto L_DONE;
    }

L_DONE:
    close(fd);
    return 0;
}

05. The Event Loop and Nonblocking IO

在服务器端网络编程中,有 3 种方法可以处理并发连接。它们是:分叉、多线程和事件循环。分叉为每个客户端连接创建新进程以实现并发。多线程使用线程而不是进程。事件循环使用轮询和非阻塞 IO,通常在单个线程上运行。由于进程和线程的开销,大多数现代生产级软件都使用事件循环进行联网。

我们不只是使用 fds 做一些事情(读取、写入或接受),而是使用轮询操作来告诉我们哪个 fd 可以立即操作而不会阻塞。当我们在 fd 上执行 IO 操作时,该操作应在非阻塞模式下执行。

在阻塞模式下,当内核中没有数据时,读取会阻塞调用方,在写缓冲区已满时写入块,当内核队列中没有新连接时接受块。在非阻塞模式下,这些操作要么成功而不阻塞,要么失败并显示 errno EAGAIN,这意味着“未就绪”。在轮询通知就绪情况后,必须重试因 EAGAIN 而失败的非阻塞操作。

轮询是事件循环中唯一的阻塞操作,其他所有内容都必须是非阻塞的;因此,单个线程可以处理多个并发连接。所有阻塞网络 IO API(如读取、写入和接受)都具有非阻塞模式。没有非阻塞模式的 API(例如 gethostbyname 和磁盘 IO)应在线程池中执行,这将在后面的章节中介绍。此外,计时器必须在事件循环中实现,因为我们不能在事件循环中等待睡眠。

将 fd 设置为非阻塞模式的系统调用是 fcntl:

在 Linux 上,除了轮询系统调用之外,还有 select 和 epoll。古代 select syscall 与轮询基本相同,只是最大 fd 数限制为少量,这使得它在现代应用程序中已经过时。epoll API 由 3 个系统调用组成:epoll_create、epoll_wait 和 epoll_ctl。epoll API 是有状态的,而不是提供一组 fds 作为系统调用参数,epoll_ctl 用于操作由 epoll_create 创建的 fd 集,epoll_wait 正在操作该 fd 集。

06. The Event Loop Implementation

enum {
    STATE_REQ = 0,
    STATE_RES = 1,
    STATE_END = 2, // mark the connection for deletion
};
struct Conn {
    int fd = -1;
    uint32_t state = 0; // either STATE_REQ or STATE_RES
    // buffer for reading
    size_t rbuf_size = 0;
    uint8_t rbuf[4 + k_max_msg];
    // buffer for writing
    size_t wbuf_size = 0;
    size_t wbuf_sent = 0;
    uint8_t wbuf[4 + k_max_msg];
};

我们需要缓冲区进行读取/写入,因为在非阻塞模式下,IO 操作通常会延迟。
state 用于决定如何处理连接。正在进行的连接有 2 种状态。STATE_REQ 用于读取请求,STATE_RES 用于发送响应。

事件循环的代码

int main() {
    int fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd < 0) {
        die("socket()");
    }
    // bind, listen and etc
    // code omitted...
    // a map of all client connections, keyed by fd
    std::vector<Conn *> fd2conn;
    // set the listen fd to nonblocking mode
    fd_set_nb(fd);
    // the event loop
    std::vector<struct pollfd> poll_args;
    while (true) {
        // prepare the arguments of the poll()
        poll_args.clear();
        // for convenience, the listening fd is put in the first position
        struct pollfd pfd = {fd, POLLIN, 0};
        poll_args.push_back(pfd);
        // connection fds
        for (Conn *conn : fd2conn) {
            if (!conn) {
                continue;
            }
            struct pollfd pfd = {};
            pfd.fd = conn->fd;
            pfd.events = (conn->state == STATE_REQ) ? POLLIN : POLLOUT;
            pfd.events = pfd.events | POLLERR;
            poll_args.push_back(pfd);
        }
        // poll for active fds
        // the timeout argument doesn't matter here
        int rv = poll(poll_args.data(), (nfds_t)poll_args.size(), 1000);
        if (rv < 0) {
            die("poll");
        }
        // process active connections
        for (size_t i = 1; i < poll_args.size(); ++i) {
        if (poll_args[i].revents) {
            Conn *conn = fd2conn[poll_args[i].fd];
            connection_io(conn);
            if (conn->state == STATE_END) {
                    // client closed normally, or something bad happened.
                    // destroy this connection
                    fd2conn[conn->fd] = NULL;
                    (void)close(conn->fd);
                    free(conn);
                }
            }
        }
        // try to accept a new connection if the listening fd is active
        if (poll_args[0].revents) {
            (void)accept_new_conn(fd2conn, fd);
            }
        }
        return 0;
}

事件循环中的第一件事是设置 poll。侦听 fd 使用 POLLIN 标志进行轮询。对于连接 fd,结构 Conn 的状态决定了轮询标志。在这种特殊情况下,轮询标志要么是读取(POLLIN),要么是写入(POLLOUT),而不是两者兼而有之。如果使用 epoll,事件循环中的第一件事通常是使用 epoll_ctl 更新 fd 集。

轮询还采用可用于实现计时器的超时参数,在我们的例子中,这个参数无关紧要,只需将其设置为一个大数字即可。在 poll 返回后,我们会收到通知哪个 fd 准备好阅读/写入并采取相应行动。

static void conn_put(std::vector<Conn *> &fd2conn, struct Conn *conn) {
    if (fd2conn.size() <= (size_t)conn->fd) {
        fd2conn.resize(conn->fd + 1);
    }
    fd2conn[conn->fd] = conn;
}
static int32_t accept_new_conn(std::vector<Conn *> &fd2conn, int fd) {
    // accept
    struct sockaddr_in client_addr = {};
    socklen_t socklen = sizeof(client_addr);
    int connfd = accept(fd, (struct sockaddr *)&client_addr, &socklen);
    if (connfd < 0) {
        msg("accept() error");
        return -1; // error
    }
    // set the new connection fd to nonblocking mode
    fd_set_nb(connfd);
    // creating the struct Conn
    struct Conn *conn = (struct Conn *)malloc(sizeof(struct Conn));
    if (!conn) {
        close(connfd);
        return -1;
    }
    conn->fd = connfd;
    conn->state = STATE_REQ;
    conn->rbuf_size = 0;
    conn->wbuf_size = 0;
    conn->wbuf_sent = 0;
    conn_put(fd2conn, conn);
    return 0;
}

connection_io 是客户端连接的状态机

static void connection_io(Conn *conn) {
    if (conn->state == STATE_REQ) {
        state_req(conn);
    } else if (conn->state == STATE_RES) {
        state_res(conn);
    } else {
        assert(0); // not expected
    }
}

STATE_REQ 状态用于读取

static void state_req(Conn *conn) {
    while (try_fill_buffer(conn)) {}
}
static bool try_fill_buffer(Conn *conn) {
    // try to fill the buffer
    assert(conn->rbuf_size < sizeof(conn->rbuf));
    ssize_t rv = 0;
    do {
        size_t cap = sizeof(conn->rbuf) - conn->rbuf_size;
        rv = read(conn->fd, &conn->rbuf[conn->rbuf_size], cap);
    } while (rv < 0 && errno == EINTR);
    if (rv < 0 && errno == EAGAIN) {
        // got EAGAIN, stop.
        return false;
    }
    if (rv < 0) {
        msg("read() error");
        conn->state = STATE_END;
        return false;
    }
    if (rv == 0) {
        if (conn->rbuf_size > 0) {
            msg("unexpected EOF");
        } else {
            msg("EOF");
        }
        conn->state = STATE_END;
        return false;
    }
    conn->rbuf_size += (size_t)rv;
    assert(conn->rbuf_size <= sizeof(conn->rbuf) - conn->rbuf_size);
    // Try to process requests one by one.
    // Why is there a loop? Please read the explanation of "pipelining".
    while (try_one_request(conn)) {}
    return (conn->state == STATE_REQ);
}

  转载请注明: malred-blog build your own redis

  目录