Go 实战训练营

07 第一周:Web 框架之 Server 与路由树

1. Web 框架概览:学习路线

2. Web 框架概览:Beego 框架分析


package beego

import "github.com/beego/beego/v2/server/web"

// UserController groups the user-related actions served through beego.
// Embedding web.Controller gives the actions access to the request context
// through the Ctx field.
type UserController struct {
    web.Controller
}

// GetUser writes a fixed greeting as the response body.
func (uc *UserController) GetUser() {
    uc.Ctx.WriteString("你好,我是malred")
}

// CreateUser decodes the JSON request body into a User and echoes it back as
// JSON. If decoding fails, the error text is written to the response instead.
// NOTE(review): reading the body this way appears to depend on
// web.BConfig.CopyRequestBody being enabled. Confirm this against the beego docs.
func (uc *UserController) CreateUser() {
    u := new(User)
    if err := uc.Ctx.BindJSON(u); err != nil {
        uc.Ctx.WriteString(err.Error())
        return
    }
    // The encoding error is explicitly discarded.
    _ = uc.Ctx.JSONResp(u)
}

// User is the payload that CreateUser accepts and returns.
type User struct {
    Name string
}
package beego

import (
    "github.com/beego/beego/v2/server/web"
    "testing"
)

// TestUserController mounts UserController on /user and maps GET to
// GetUser. It then starts beego on :8081. This is a manual demo, not an
// assertion-based test.
func TestUserController(t *testing.T) {
    // This setting (CopyRequestBody) exists only in beego.
    web.BConfig.CopyRequestBody = true

    web.Router("/user", &UserController{}, "get:GetUser")
    web.Run(":8081")
}



作者认为可以把响应(response)相关的内容放入 Output,把请求(request)相关的内容放入 Input。


不同 server 之间是隔离的

3. Web 框架概览:Gin 框架分析

package gin

import (
    "github.com/gin-gonic/gin"
)

// UserController holds gin handlers for user endpoints. It carries no state.
type UserController struct{}

// GetUser responds with status 200 and the plain-text body "hello world".
func (uc *UserController) GetUser(ctx *gin.Context) {
    const body = "hello world"
    ctx.String(200, body)
}
package gin

import (
    "net/http"
    "testing"

    "github.com/gin-gonic/gin"
)

// TestUserController_GetUser wires a few gin routes and starts the engine on
// :8082. This is a manual demo, not an assertion-based test.
func TestUserController_GetUser(t *testing.T) {
    engine := gin.Default()
    uc := &UserController{}

    // A method value can be registered directly as a handler.
    engine.GET("/user", uc.GetUser)
    engine.POST("/user", func(c *gin.Context) {
        c.String(http.StatusOK, "hello %s", "world")
    })
    // Placeholder. The intent is to read a file and write it back as the response.
    engine.GET("/static", func(c *gin.Context) {})

    // Run's error is discarded in this demo.
    _ = engine.Run(":8082")
}

Handler 接口是核心。作者认为没必要把静态资源(static)处理放进核心接口,而是可以作为一个独立的实现来提供。








4. Web 框架概览:Iris 框架分析

// v12版本需要go1.21
package iris

import (
    "github.com/kataras/iris/v12"
    "testing"
)

// TestHelloWorld registers a GET / handler that renders a small HTML
// greeting, then starts iris on :8083.
func TestHelloWorld(t *testing.T) {
    app := iris.New()

    greet := func(ctx iris.Context) {
        _, _ = ctx.HTML("hello <strong>%s</strong>!", "world")
    }
    app.Get("/", greet)

    _ = app.Listen(":8083")
}






5. Web 框架概览:Echo 框架分析与对比总结

package echo

import (
    "github.com/labstack/echo/v4"
    "github.com/labstack/echo/v4/middleware"
    "net/http"
    "testing"
)

// TestHelloWorld builds an echo instance with logging and panic-recovery
// middleware, mounts hello on GET /, and serves on :8084.
func TestHelloWorld(t *testing.T) {
    e := echo.New()

    // Middleware runs in registration order: Logger first, then Recover.
    e.Use(middleware.Logger(), middleware.Recover())

    e.GET("/", hello)

    // Whatever e.Start returns is handed to the logger's Fatal.
    e.Logger.Fatal(e.Start(":8084"))
}

// hello is an echo handler that replies 200 with a plain-text greeting.
func hello(c echo.Context) error {
    return c.String(http.StatusOK, "hello world")
}

作者认为:如果需要隔离,就再创建一个 echo 实例,而不是在同一个 echo 实例里同时放 namespace(分组)和普通路由。




6. Server 详解与面试要点







既然有路由管理功能,就得有路由注册的功能







// context.go
package web

import "net/http"

// Context bundles the per-request values that handlers work with.
type Context struct {
    // Req is the incoming HTTP request.
    Req  *http.Request
    // Resp is where the handler writes its response.
    Resp http.ResponseWriter
}
// server_test.go
package web

import (
    "fmt"
    "net/http"
    _ "net/http"
    "testing"
)

// TestServer sketches how HTTPServer is meant to be used: register routes,
// then start listening on :8081.
// NOTE(review): at this stage HTTPServer.AddRoute still panics with
// "implement me", so the first registration aborts the test.
func TestServer(t *testing.T) {
    srv := &HTTPServer{} // to be replaced by a constructor later

    srv.AddRoute(http.MethodGet, "/user", func(ctx Context) {
        fmt.Println("user")
    })

    // Several handlers can be chained by wrapping them in one. This is why
    // an AddRoutes(method, path, handlers ...HandleFunc) API is unnecessary.
    first := func(ctx Context) { fmt.Println("1") }
    second := func(ctx Context) { fmt.Println("2") }
    srv.AddRoute(http.MethodGet, "/user", func(ctx Context) {
        first(ctx)
        second(ctx)
    })

    // Get exists on the concrete *HTTPServer only, not on the Server interface.
    srv.Get("/user", func(ctx Context) {})

    // Option 1: hand the server to net/http. Because HTTPServer is an
    // http.Handler, it integrates directly with the standard library:
    //     http.ListenAndServe(":8081", srv)
    //     http.ListenAndServeTLS(":443", "", "", srv)

    // Option 2: let the server manage its own listener.
    srv.Start(":8081")
}
// server.go
package web

import (
    "net"
    "net/http"
)

// HandleFunc is the signature of user business logic. It receives the
// per-request Context.
type HandleFunc func(ctx Context)

// Compile-time assertion that *HTTPServer implements Server.
var _ Server = &HTTPServer{}

// Server is the framework's core abstraction. Because it embeds
// http.Handler, any Server can be passed to net/http directly.
type Server interface {
    http.Handler
    // Start serves on addr.
    Start(addr string) error
    // StartHttp() error

    // AddRoute registers a route.
    // method: the HTTP method
    // path: the route path
    // handleFunc: the business logic to run when the route matches
    AddRoute(method string, path string, handleFunc HandleFunc)
    // A multi-handler variant is unnecessary: users can compose several
    // handlers into one and call AddRoute.
    // AddRoutes(method string, path string, handleFunc ...HandleFunc)
}

// An HTTPS server could decorate HTTPServer by embedding it:
// type HTTPSServer struct {
//     HTTPServer // decorates HTTPServer
// }

// HTTPServer is the net/http-based implementation of Server.
type HTTPServer struct {
    // addr string // could also be supplied at construction instead of via Start
}

// ServeHTTP implements http.Handler and is the entry point for every request.
// It wraps the request/response pair in a Context and hands it to serve,
// which does the route lookup.
func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    h.serve(&Context{Resp: w, Req: req})
}

// serve is where the matching route will be looked up and its handler run.
// It is still empty at this stage.
func (h *HTTPServer) serve(ctx *Context) {}

// Start listens on addr over TCP and serves requests with h.
// Owning the listener, rather than calling http.ListenAndServe, leaves a
// point between "listening" and "serving" for lifecycle callbacks. Examples
// include registering this instance with an admin service or running other
// start-up prerequisites.
func (h *HTTPServer) Start(addr string) error {
    listener, err := net.Listen("tcp", addr)
    if err != nil {
        return err
    }
    // An explicit http.Server{} could be constructed here instead.
    return http.Serve(listener, h)
}

// AddRoute will insert handleFunc into the route tree. It is still a stub.
func (h *HTTPServer) AddRoute(method string, path string, handleFunc HandleFunc) {
    panic("implement me")
}

// The Server interface stays minimal (only AddRoute). The verb helpers
// below are conveniences offered by this concrete implementation.

// Get registers handleFunc for GET requests to path.
func (h *HTTPServer) Get(path string, handleFunc HandleFunc) {
    h.AddRoute(http.MethodGet, path, handleFunc)
}

// Post registers handleFunc for POST requests to path.
func (h *HTTPServer) Post(path string, handleFunc HandleFunc) {
    h.AddRoute(http.MethodPost, path, handleFunc)
}

// Put registers handleFunc for PUT requests to path.
func (h *HTTPServer) Put(path string, handleFunc HandleFunc) {
    h.AddRoute(http.MethodPut, path, handleFunc)
}

// Patch registers handleFunc for PATCH requests to path.
func (h *HTTPServer) Patch(path string, handleFunc HandleFunc) {
    h.AddRoute(http.MethodPatch, path, handleFunc)
}

// Delete registers handleFunc for DELETE requests to path.
func (h *HTTPServer) Delete(path string, handleFunc HandleFunc) {
    h.AddRoute(http.MethodDelete, path, handleFunc)
}

// A variadic AddRoutes is intentionally not provided:
// func (h *HTTPServer) AddRoutes(method string, path string, handleFunc ...HandleFunc) {
//     panic("implement me")
// }

// StartHttp is the shortcut alternative to Start. It delegates both
// listening and serving to http.ListenAndServe.
func (h *HTTPServer) StartHttp(addr string) error {
    return http.ListenAndServe(addr, h)
}

7. 路由树:Beego、Gin、Echo 实现与设计总结







gin 路由树(radix tree)的复杂之处在于:插入路由时要计算最长公共前缀,然后据此拆分节点、重构树。

8. 路由树:全静态匹配


// router.go
package web

// router supports operations on the route trees.
// It represents a route forest: one tree per HTTP method.
type router struct {
    // As in Beego and Gin, each HTTP method gets its own tree:
    // one for GET, one for POST, ...

    // HTTP method => root node of that method's tree
    trees map[string]*node
}

// An alternative (unused) wrapper type for a single tree:
// type tree struct {
//     root *node
// }

// newRouter returns a router whose method => tree map is allocated but empty.
func newRouter() *router {
    r := &router{}
    r.trees = make(map[string]*node)
    return r
}

// AddRoute will register handleFunc under method and path. It is still a stub.
func (r *router) AddRoute(method string, path string, handleFunc HandleFunc) {
    panic("implement me")
}

// node is one segment of a route tree.
type node struct {
    path string
    // path segment => child node
    children map[string]*node
    // The user's business logic registered on this node. It is nil if none.
    handler HandleFunc
}
// server.go
package web

import (
    "net"
    "net/http"
)

// HandleFunc is the signature of user business logic. It receives the
// per-request Context.
type HandleFunc func(ctx Context)

// Compile-time assertion that *HTTPServer implements Server.
var _ Server = &HTTPServer{}

// Server is the framework's core abstraction. Because it embeds
// http.Handler, any Server can be passed to net/http directly.
type Server interface {
    http.Handler
    // Start serves on addr.
    Start(addr string) error
    // StartHttp() error

    // AddRoute registers a route.
    // method: the HTTP method
    // path: the route path
    // handleFunc: the business logic to run when the route matches
    AddRoute(method string, path string, handleFunc HandleFunc)
    // A multi-handler variant is unnecessary: users can compose several
    // handlers into one and call AddRoute.
    // AddRoutes(method string, path string, handleFunc ...HandleFunc)
}

// An HTTPS server could decorate HTTPServer by embedding it:
// type HTTPSServer struct {
//     HTTPServer // decorates HTTPServer
// }

// HTTPServer is the net/http-based implementation of Server.
type HTTPServer struct {
    // addr string // could also be supplied at construction instead of via Start
    // router
    *router // router's AddRoute is promoted, so HTTPServer still satisfies Server
}

// NewHTTPServer returns an HTTPServer backed by a fresh, empty router.
func NewHTTPServer() *HTTPServer {
    s := &HTTPServer{}
    s.router = newRouter()
    return s
}

// func (h *HTTPServer) AddRoute(method string, path string, handleFunc HandleFunc) {
//     // 注册到路由树
//     panic("implement me")
// }

// ...

9. 路由树:TDD 起步


// router.go
package web

import "strings"

// ...

// AddRoute registers handleFunc for method and path in that method's tree.
// path must be non-empty and must start with '/'. The path "/" attaches the
// handler to the root node itself. Any other path is split on '/' and
// walked down the tree, creating missing nodes along the way.
func (r *router) AddRoute(method string, path string, handleFunc HandleFunc) {
    if path == "" {
        panic("web: 路径不能为空字符串")
    }
    if path[0] != '/' {
        // Without this check, path[1:] below would silently drop a real
        // character ("user/home" -> "ser/home").
        panic("web: 路径必须以 / 开头")
    }

    root, ok := r.trees[method]
    if !ok {
        // First route for this method: create its tree.
        root = &node{path: "/"}
        r.trees[method] = root
    }

    // "/" has no segments after the leading slash. Splitting "" would yield
    // [""] and create a bogus child with an empty path, so the handler goes
    // on the root directly.
    if path == "/" {
        root.handler = handleFunc
        return
    }

    // "/user/home" -> ["user", "home"]
    for _, seg := range strings.Split(path[1:], "/") {
        root = root.childOrCreate(seg)
    }
    // root now points at the node for the full path.
    root.handler = handleFunc
}

// childOrCreate returns n's child keyed by seg. If it does not exist yet,
// the child is created, along with the children map when needed.
func (n *node) childOrCreate(seg string) *node {
    // Looking up a key in a nil map is safe and simply reports "missing".
    if child, ok := n.children[seg]; ok {
        return child
    }
    if n.children == nil {
        n.children = make(map[string]*node)
    }
    child := &node{path: seg}
    n.children[seg] = child
    return child
}

// node's fields are elided in this excerpt.
type node struct {
    // ...
}
// router_test.go
package web

import (
    "fmt"
    "net/http"
    "reflect"
    "testing"

    "github.com/stretchr/testify/assert"
)

// TestRouter_AddRoute registers a single GET route. It checks that the
// resulting tree is "/" -> "user" -> "home", with the handler attached only
// to the leaf.
func TestRouter_AddRoute(t *testing.T) {
    var mockHandler HandleFunc = func(ctx Context) {}

    // 1. Build the router from the route table.
    routes := []struct {
        method string
        path   string
    }{
        {method: http.MethodGet, path: "/user/home"},
    }
    r := newRouter()
    for _, rt := range routes {
        r.AddRoute(rt.method, rt.path, mockHandler)
    }

    // 2. Compare against the expected tree.
    want := &router{
        trees: map[string]*node{
            http.MethodGet: {
                path: "/",
                children: map[string]*node{
                    "user": {
                        path: "user",
                        children: map[string]*node{
                            "home": {path: "home", handler: mockHandler},
                        },
                    },
                },
            },
        },
    }
    msg, ok := want.equal(r)
    assert.True(t, ok, msg)
}

// equal reports whether r and y contain the same per-method trees.
// On a mismatch it returns a description of the first difference found.
func (r *router) equal(y *router) (string, bool) {
    // The loop below only walks r's trees. Without this length check, a
    // method registered only in y would go unnoticed.
    if len(r.trees) != len(y.trees) {
        return fmt.Sprintf("http method 数量不相等: %d != %d", len(r.trees), len(y.trees)), false
    }
    for method, root := range r.trees {
        dst, ok := y.trees[method]
        if !ok {
            return fmt.Sprintf("找不到对应的 http method: %s", method), false
        }
        if msg, ok := root.equal(dst); !ok {
            return msg, false
        }
    }
    return "", true
}

// equal reports whether the subtree rooted at n matches the one rooted at y.
// Matching means the same path, the same handler, and recursively equal
// static children. On a mismatch it returns a description of the first
// difference found.
func (n *node) equal(y *node) (string, bool) {
    // Handle nil nodes explicitly instead of dereferencing them.
    if n == nil || y == nil {
        if n == nil && y == nil {
            return "", true
        }
        return "节点不匹配: 一方为 nil", false
    }
    if n.path != y.path {
        return fmt.Sprintf("节点路径不匹配: %q != %q", n.path, y.path), false
    }
    if len(n.children) != len(y.children) {
        return fmt.Sprintf("子节点数量不相等: %q", n.path), false
    }

    // Func values cannot be compared with ==, so they are compared via
    // reflect. NOTE(review): this compares reflect.Value headers, which the
    // reflect docs discourage in general. For func values it acts as an
    // identity check (same func value, or both nil). Confirm this still
    // holds if the handler type changes.
    if reflect.ValueOf(n.handler) != reflect.ValueOf(y.handler) {
        return fmt.Sprintf("handler 不相等: %q", n.path), false
    }

    for path, c := range n.children {
        dst, ok := y.children[path]
        if !ok {
            return fmt.Sprintf("子节点 %s 不存在", path), false
        }
        if msg, ok := c.equal(dst); !ok {
            return msg, false
        }
    }
    return "", true
}

10. 路由树:静态匹配测试用例

// router_test.go
package web

import (
    "fmt"
    "net/http"
    "reflect"
    "testing"

    "github.com/stretchr/testify/assert"
)

// TestRouter_addRoute checks the tree built from a table of static routes.
// It then checks the panics for malformed and duplicate paths.
// PanicsWithValue is used so that each case asserts the exact panic
// message; assert.Panics only checks that some panic occurred.
func TestRouter_addRoute(t *testing.T) {
    // 1. Build the route tree.
    testRoutes := []struct {
        method string
        path   string
    }{
        {method: http.MethodGet, path: "/"},
        {method: http.MethodGet, path: "/user"},
        {method: http.MethodGet, path: "/user/home"},
        {method: http.MethodGet, path: "/order/detail"},
        {method: http.MethodPost, path: "/order/create"},
        {method: http.MethodPost, path: "/login"},
        // Malformed paths panic on registration, so they are kept out of
        // this table and checked with the panic assertions below:
        // {method: http.MethodPost, path: "login"},
        // {method: http.MethodPost, path: "login////"},
        // {method: http.MethodPost, path: "//login//a//b"},
    }

    var mockHandler HandleFunc = func(ctx Context) {}
    r := newRouter()
    for _, route := range testRoutes {
        r.addRoute(route.method, route.path, mockHandler)
    }

    // 2. Verify the route tree.
    wantRouter := &router{
        trees: map[string]*node{
            http.MethodGet: {
                path:    "/",
                handler: mockHandler,
                children: map[string]*node{
                    "user": {
                        path:    "user",
                        handler: mockHandler,
                        children: map[string]*node{
                            "home": {path: "home", handler: mockHandler},
                        },
                    },
                    "order": {
                        path: "order",
                        children: map[string]*node{
                            "detail": {path: "detail", handler: mockHandler},
                        },
                    },
                },
            },
            http.MethodPost: {
                path: "/",
                children: map[string]*node{
                    "order": {
                        path: "order",
                        children: map[string]*node{
                            "create": {path: "create", handler: mockHandler},
                        },
                    },
                    "login": {path: "login", handler: mockHandler},
                },
            },
        },
    }
    msg, ok := wantRouter.equal(r)
    assert.True(t, ok, msg)

    // 3. Malformed paths.
    r = newRouter()
    assert.PanicsWithValue(t, "web: 路径不能为空字符串", func() {
        r.addRoute(http.MethodGet, "", mockHandler)
    })
    assert.PanicsWithValue(t, "web: 路径不能以 / 结尾", func() {
        r.addRoute(http.MethodGet, "/a/b/c/", mockHandler)
    })
    assert.PanicsWithValue(t, "web: 路径必须以 / 开头", func() {
        r.addRoute(http.MethodGet, "a/b", mockHandler)
    })
    // The path must start with '/'. Otherwise the leading-slash check
    // fires first and the consecutive-slash check is never exercised.
    assert.PanicsWithValue(t, "web: 不能有连续的 /", func() {
        r.addRoute(http.MethodGet, "/a//b", mockHandler)
    })

    // 4. Duplicate registrations.
    r = newRouter()
    r.addRoute(http.MethodGet, "/", mockHandler)
    assert.PanicsWithValue(t, "web: 路由冲突, 重复注册: [/]", func() {
        r.addRoute(http.MethodGet, "/", mockHandler)
    })

    r = newRouter()
    r.addRoute(http.MethodGet, "/user/home", mockHandler)
    assert.PanicsWithValue(t, "web: 路由冲突, 重复注册: [/user/home]", func() {
        r.addRoute(http.MethodGet, "/user/home", mockHandler)
    })
}

// equal compares two routers tree by tree (body elided in this excerpt).
func (r *router) equal(y *router) (string, bool) {
    // ...
}

// equal compares two nodes recursively (body elided in this excerpt).
func (n *node) equal(y *node) (string, bool) {
    // ...
}
package web

import (
    "fmt"
    "strings"
)

// router supports operations on the route trees.
// It represents a route forest: one tree per HTTP method.
type router struct {
    // As in Beego and Gin, each HTTP method gets its own tree:
    // one for GET, one for POST, ...

    // HTTP method => root node of that method's tree
    trees map[string]*node
}

// An alternative (unused) wrapper type for a single tree:
// type tree struct {
//     root *node
// }

// newRouter creates an empty router (body elided in this excerpt).
func newRouter() *router {
    // ...
}

// addRoute registers handleFunc for method and path.
//
// Constraints on path (a violation panics, because a bad route is a
// programming error):
//   - it must not be empty and must start with '/';
//   - it must not end with '/', except for the root path "/" itself;
//   - it must not contain consecutive slashes ("//").
//
// Registering the same method and path twice also panics.
// All validation runs before the tree is touched, so a rejected path
// leaves the router unchanged. No method tree or half-built nodes are
// created for it.
func (r *router) addRoute(method string, path string, handleFunc HandleFunc) {
    if path == "" {
        panic("web: 路径不能为空字符串")
    }
    if path[0] != '/' {
        panic("web: 路径必须以 / 开头")
    }
    if path != "/" && path[len(path)-1] == '/' {
        panic("web: 路径不能以 / 结尾")
    }

    // "/user/home" -> ["user", "home"]. An empty segment means "//".
    // Every segment is checked before any node is created.
    var segs []string
    if path != "/" {
        segs = strings.Split(path[1:], "/")
        for _, seg := range segs {
            if seg == "" {
                panic("web: 不能有连续的 /")
            }
        }
    }

    root, ok := r.trees[method]
    if !ok {
        // First route for this method: create its tree.
        root = &node{path: "/"}
        r.trees[method] = root
    }

    // Walk down, creating missing nodes. For "/" segs is empty, so the
    // root itself receives the handler.
    cur := root
    for _, seg := range segs {
        cur = cur.childOrCreate(seg)
    }

    // A non-nil handler means this exact route was registered already.
    if cur.handler != nil {
        panic(fmt.Sprintf("web: 路由冲突, 重复注册: [%s]", path))
    }
    cur.handler = handleFunc
}

// childOrCreate returns the child for seg, creating it if missing
// (body elided in this excerpt).
func (n *node) childOrCreate(seg string) *node {
    // ...
}

// node's fields are elided in this excerpt.
type node struct {
    // ...
}

将 AddRoute 首字母小写(改为 addRoute)后,它就成了未导出方法,外部只能通过我们暴露的 Get、Post 等方法注册路由,这样就不用担心用户给 method 传入乱七八糟的字符串。

handler 为 nil 表示该节点上还没有注册业务逻辑(用户一般不会主动注册 nil handler),因此可以用 handler != nil 来判断路由是否重复注册。

11. 路由树:静态匹配之路由查找



// router_test.go
package web

import (
    "fmt"
    "net/http"
    "reflect"
    "testing"

    "github.com/stretchr/testify/assert"
)

func TestRouter_findRoute(t *testing.T) {
    testRoutes := []struct {
        method string
        path   string
    }{
        {
            method: http.MethodGet,
            path:   "/",
        },
        {
            method: http.MethodDelete,
            path:   "/",
        },
        {
            method: http.MethodGet,
            path:   "/user",
        },
        {
            method: http.MethodGet,
            path:   "/user/home",
        },
        {
            method: http.MethodGet,
            path:   "/order/detail",
        },
        {
            method: http.MethodPost,
            path:   "/order/create",
        },
        {
            method: http.MethodPost,
            path:   "/login",
        },
    }

    r := newRouter()
    var mockHandler HandleFunc = func(ctx Context) {}
    for _, route := range testRoutes {
        r.addRoute(route.method, route.path, mockHandler)
    }

    testCases := []struct {
        name string

        method string
        path   string

        wantFound bool
        wantNode  *node
    }{
        {
            // method不存在
            name:   "method not found",
            method: http.MethodOptions,
            path:   "/order/detail",
        },
        {
            // 完全命中
            name:      "order detail",
            method:    http.MethodGet,
            path:      "/order/detail",
            wantFound: true,
            wantNode: &node{
                handler: mockHandler,
                path:    "detail",
            },
        },
        {
            // 命中了,但是没有handler
            name:      "order",
            method:    http.MethodGet,
            path:      "/order",
            wantFound: true,
            wantNode: &node{
                path: "order",
                // order/detail注册的child -> detail
                children: map[string]*node{
                    "detail": &node{
                        handler: mockHandler,
                        path:    "detail",
                    },
                },
            },
        },
        {
            // 根节点
            name:      "root",
            method:    http.MethodDelete,
            path:      "/",
            wantFound: true,
            wantNode: &node{
                path:    "/",
                handler: mockHandler,
            },
        },
        {
            name:   "path not found",
            method: http.MethodGet,
            path:   "/404",
        },
    }

    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            n, found := r.findRoute(tc.method, tc.path)
            assert.Equal(t, tc.wantFound, found)
            if !found {
                return
            }
            msg, ok := tc.wantNode.equal(n)
            assert.True(t, ok, msg)
        })
    }
}

// ...
// router.go
// findRoute walks the tree registered for method along the segments of path.
// It reports false when the method has no tree or when some segment has no
// matching child. A true result only means the node exists. Its handler may
// still be nil (e.g. "/order" when only "/order/detail" is registered).
func (r *router) findRoute(method string, path string) (*node, bool) {
    cur, ok := r.trees[method]
    switch {
    case !ok:
        return nil, false
    case path == "/":
        return cur, true
    }

    // Leading and trailing slashes are ignored, so "/user/" finds "/user".
    for _, seg := range strings.Split(strings.Trim(path, "/"), "/") {
        next, found := cur.childOf(seg)
        if !found {
            return nil, false
        }
        cur = next
    }
    return cur, true
    // Alternative: return cur, cur.handler != nil
}

// childOf returns n's static child registered under path, if there is one.
func (n *node) childOf(path string) (*node, bool) {
    // Indexing a nil map is safe in Go and simply reports "not found".
    child, ok := n.children[path]
    return child, ok
}

12. 路由树:静态匹配之集成 Server




// server.go
// serve dispatches ctx to the handler registered for the request's method
// and URL path. It answers 404 "NOT FOUND" when no route matches or when
// the matched node has no handler.
func (h *HTTPServer) serve(ctx *Context) {
    n, ok := h.findRoute(ctx.Req.Method, ctx.Req.URL.Path)
    if ok && n.handler != nil {
        n.handler(ctx)
        return
    }
    ctx.Resp.WriteHeader(http.StatusNotFound)
    // A failed write cannot be reported anywhere here, so it is dropped.
    _, _ = ctx.Resp.Write([]byte("NOT FOUND"))
}

13. 路由树:通配符匹配之路由注册

// router_test.go
package web

import (
    "fmt"
    "net/http"
    "reflect"
    "testing"

    "github.com/stretchr/testify/assert"
)

// TestRouter_addRoute (wildcard stage) checks that "/order/*" is stored as the
// starChild of "order". Parts of the route table and the final comparison
// are elided in this excerpt.
func TestRouter_addRoute(t *testing.T) {
    // 1. Build the route tree.
    testRoutes := []struct {
        method string
        path   string
    }{
        // {
        //     method: http.MethodGet,
        //     path:   "/*",
        // },
        // {
        //     method: http.MethodGet,
        //     path:   "/*/*",
        // },
        // {
        //     method: http.MethodGet,
        //     path:   "/*/abc/*",
        // },
        // {
        //     method: http.MethodGet,
        //     path:   "/*/abc",
        // },
        // ...
        {
            method: http.MethodGet,
            path:   "/order/*",
        },
    }

    var mockHandler HandleFunc = func(ctx *Context) {}
    r := newRouter()
    for _, route := range testRoutes {
        r.addRoute(route.method, route.path, mockHandler)
    }

    // 2. Verify the route tree.
    // The expected tree also contains routes from the elided part of the
    // table above (e.g. /user/home, /order/detail).
    wantRouter := &router{
        trees: map[string]*node{
            http.MethodGet: &node{
                path:    "/",
                handler: mockHandler,
                children: map[string]*node{
                    "user": &node{
                        path:    "user",
                        handler: mockHandler,
                        children: map[string]*node{
                            "home": &node{
                                path:    "home",
                                handler: mockHandler,
                            },
                        },
                    },
                    "order": &node{
                        path: "order",
                        children: map[string]*node{
                            "detail": &node{
                                path:    "detail",
                                handler: mockHandler,
                            },
                        },
                        starChild: &node{
                            path:    "*",
                            handler: mockHandler,
                        },
                    },
                },
            },
            http.MethodPost: &node{
                // ...
            },
        },
    }
    // Presumably followed by the wantRouter.equal(r) assertion (elided).
    // ...
}

// equal reports whether the subtree rooted at n matches the one rooted at y.
// It compares the path, the handler, the static children and the wildcard
// child. On a mismatch it returns a description of the first difference
// found.
func (n *node) equal(y *node) (string, bool) {
    // Either side may be nil when comparing optional children such as
    // starChild. Handle that instead of dereferencing nil.
    if n == nil || y == nil {
        if n == nil && y == nil {
            return "", true
        }
        return "节点不匹配: 一方为 nil", false
    }
    if n.path != y.path {
        return fmt.Sprintf("节点路径不匹配: %q != %q", n.path, y.path), false
    }
    if len(n.children) != len(y.children) {
        return fmt.Sprintf("子节点数量不相等: %q", n.path), false
    }

    // The wildcard child is compared in both directions. The nil guard above
    // catches a starChild that is present on only one side.
    if msg, ok := n.starChild.equal(y.starChild); !ok {
        return msg, false
    }

    // Func values cannot be compared with ==, so they are compared via
    // reflect. NOTE(review): this compares reflect.Value headers, which acts
    // as an identity check for func values (same func value, or both nil).
    if reflect.ValueOf(n.handler) != reflect.ValueOf(y.handler) {
        return fmt.Sprintf("handler 不相等: %q", n.path), false
    }

    for path, c := range n.children {
        dst, ok := y.children[path]
        if !ok {
            return fmt.Sprintf("子节点 %s 不存在", path), false
        }
        if msg, ok := c.equal(dst); !ok {
            return msg, false
        }
    }
    return "", true
}
// router.go
// childOrCreate returns the child of n that matches seg, creating it if it
// does not exist yet. The segment "*" maps to the single wildcard child.
// Any other segment maps to a static child keyed by the segment text.
func (n *node) childOrCreate(seg string) *node {
    if seg == "*" {
        // Reuse an existing wildcard child. Replacing it would drop every
        // route already registered beneath it (e.g. "/a/*/b" followed by
        // "/a/*/c"). It would also let "/a/*" be registered twice without
        // a conflict panic.
        if n.starChild == nil {
            n.starChild = &node{path: seg}
        }
        return n.starChild
    }

    if n.children == nil {
        n.children = map[string]*node{}
    }
    res, ok := n.children[seg]
    if !ok {
        res = &node{path: seg}
        n.children[seg] = res
    }
    return res
}

// node is one segment of a route tree.
type node struct {
    path string
    // Static children, matched by exact segment text.
    // path segment => child node
    children map[string]*node
    // Wildcard child, registered with a "*" segment.
    starChild *node
    // The user's business logic registered on this node. It is nil if none.
    handler HandleFunc
}

14. 路由树:通配符匹配之路由查找与测试

// router.go
// childOf returns the child of n that matches path. A static child wins.
// Otherwise the wildcard child is used, if one is registered.
func (n *node) childOf(path string) (*node, bool) {
    // Indexing a nil children map is safe and reports "not found".
    if child, ok := n.children[path]; ok {
        return child, true
    }
    return n.starChild, n.starChild != nil
}
// router_test.go
// TestRouter_findRoute (wildcard stage) checks that a static match wins over
// the wildcard, and that an unknown segment falls back to "*". Parts of the
// route table, the test cases and the loop are elided in this excerpt.
func TestRouter_findRoute(t *testing.T) {
    testRoutes := []struct {
        method string
        path   string
    }{
        // ...
        {
            method: http.MethodGet,
            path:   "/order/detail",
        },
        {
            method: http.MethodGet,
            path:   "/order/*",
        },
    }

    r := newRouter()
    var mockHandler HandleFunc = func(ctx *Context) {}
    for _, route := range testRoutes {
        r.addRoute(route.method, route.path, mockHandler)
    }

    testCases := []struct {
        name string

        method string
        path   string

        wantFound bool
        wantNode  *node
    }{
        {
            // Exact static match.
            name:      "order detail",
            method:    http.MethodGet,
            path:      "/order/detail",
            wantFound: true,
            wantNode: &node{
                handler: mockHandler,
                path:    "detail",
            },
        },
        {
            // No static "abc" node exists, so the "*" child matches.
            name:      "order star",
            method:    http.MethodGet,
            path:      "/order/abc",
            wantFound: true,
            wantNode: &node{
                handler: mockHandler,
                path:    "*",
            },
        },
        // ...
    }

    // ...
}
// server_test.go
package web

import (
    "fmt"
    _ "net/http"
    "testing"
)

// TestServer starts a server on :8081 with a static route and a wildcard
// route under /order. For example, GET /order/abc is answered by the
// wildcard handler. This is a manual demo, not an assertion-based test.
func TestServer(t *testing.T) {
    srv := NewHTTPServer()

    srv.Get("/order/detail", func(ctx *Context) {
        ctx.Resp.Write([]byte("hello order detail"))
    })
    srv.Get("/order/*", func(ctx *Context) {
        msg := fmt.Sprintf("hello, %s", ctx.Req.URL.Path)
        ctx.Resp.Write([]byte(msg))
    })

    srv.Start(":8081")
}

15. 路由树:参数路径之基本注册和查找

// router_test.go
package web

import (
    "fmt"
    "net/http"
    "reflect"
    "testing"

    "github.com/stretchr/testify/assert"
)

// TestRouter_addRoute (path-parameter stage) checks that "/order/detail/:id"
// is stored as the paramChild of "detail". Parts of the route table and the
// final comparison are elided in this excerpt.
func TestRouter_addRoute(t *testing.T) {
    // 1. Build the route tree.
    testRoutes := []struct {
        method string
        path   string
    }{
        // ...
        // Path parameters.
        {
            method: http.MethodGet,
            path:   "/order/detail/:id",
        },
    }

    var mockHandler HandleFunc = func(ctx *Context) {}
    r := newRouter()
    for _, route := range testRoutes {
        r.addRoute(route.method, route.path, mockHandler)
    }

    // 2. Verify the route tree.
    wantRouter := &router{
        trees: map[string]*node{
            http.MethodGet: &node{
                path:    "/",
                handler: mockHandler,
                children: map[string]*node{
                    // ...
                    "order": &node{
                        path: "order",
                        children: map[string]*node{
                            "detail": &node{
                                path:    "detail",
                                handler: mockHandler,
                                paramChild: &node{
                                    path:    ":id",
                                    handler: mockHandler,
                                },
                            },
                        },
                        starChild: &node{
                            path:    "*",
                            handler: mockHandler,
                        },
                    },
                },
            },
            http.MethodPost: &node{
                // ...
            },
        },
    }
    // NOTE(review): for this case to be meaningful, the node comparison
    // must also compare paramChild. Confirm that the equal helper does.
    // ...
}

// TestRouter_findRoute (path-parameter stage) checks that "/login/wuif"
// resolves to the ":username" parameter node. Other routes, cases and the
// loop are elided in this excerpt.
func TestRouter_findRoute(t *testing.T) {
    testRoutes := []struct {
        method string
        path   string
    }{
        // ...
        {
            method: http.MethodPost,
            path:   "/login/:username",
        },
    }

    r := newRouter()
    var mockHandler HandleFunc = func(ctx *Context) {}
    for _, route := range testRoutes {
        r.addRoute(route.method, route.path, mockHandler)
    }

    testCases := []struct {
        name string

        method string
        path   string

        wantFound bool
        wantNode  *node
    }{
        // ...
        {
            // Path-parameter match.
            name:      "login username",
            method:    http.MethodPost,
            path:      "/login/wuif",
            wantFound: true,
            wantNode:  &node{path: ":username", handler: mockHandler},
        },
    }

    // ...
}

// ...
// router.go

// node is one segment of a route tree.
type node struct {
    path string
    // Static children, matched by exact segment text.
    // path segment => child node
    children map[string]*node
    // Wildcard child, registered with a "*" segment.
    starChild *node
    // Path-parameter child, registered with a ":name" segment.
    paramChild *node
    // The user's business logic registered on this node. It is nil if none.
    handler HandleFunc
}

// childOf returns the child of n that matches path. Priority order:
// static child, then path-parameter child, then wildcard child.
func (n *node) childOf(path string) (*node, bool) {
    // Indexing a nil children map is safe and reports "not found".
    if child, ok := n.children[path]; ok {
        return child, true
    }
    switch {
    case n.paramChild != nil:
        return n.paramChild, true
    default:
        return n.starChild, n.starChild != nil
    }
}

// childOrCreate returns the child of n that matches seg, creating it if
// needed. A ":name" segment maps to the path-parameter child, "*" maps to
// the wildcard child, and anything else maps to a static child keyed by seg.
// seg is assumed non-empty (the router rejects empty segments earlier).
func (n *node) childOrCreate(seg string) *node {
    if seg[0] == ':' {
        // Reuse an existing parameter child. Replacing it would drop every
        // route already registered beneath it.
        // NOTE: a differently named parameter (":name" after ":id") also
        // lands on the existing node; that conflict is not detected here.
        if n.paramChild == nil {
            n.paramChild = &node{path: seg}
        }
        return n.paramChild
    }
    if seg == "*" {
        if n.starChild == nil {
            n.starChild = &node{path: seg}
        }
        return n.starChild
    }

    if n.children == nil {
        n.children = map[string]*node{}
    }
    res, ok := n.children[seg]
    if !ok {
        res = &node{path: seg}
        n.children[seg] = res
    }
    return res
}

16. 路由树:参数路径之校验

// router_test.go
    // A node may have a wildcard child or a path-parameter child, never both.
    // PanicsWithValue also checks the panic message; Panics does not.
    r = newRouter()
    r.addRoute(http.MethodGet, "/a/*", mockHandler)
    assert.PanicsWithValue(t, "web: 不允许同时注册路径参数和通配符路径, 已经注册通配符路径", func() {
        r.addRoute(http.MethodGet, "/a/:b", mockHandler)
    })

    r = newRouter()
    r.addRoute(http.MethodGet, "/a/:b", mockHandler)
    assert.PanicsWithValue(t, "web: 不允许同时注册路径参数和通配符路径, 已经注册路径参数", func() {
        r.addRoute(http.MethodGet, "/a/*", mockHandler)
    })
// router.go
// childOrCreate returns the child of n that matches seg, creating it if
// needed:
//   - ":name" selects the single path-parameter child;
//   - "*" selects the single wildcard child;
//   - anything else selects a static child keyed by seg.
//
// Existing parameter and wildcard children are reused, so routes already
// registered beneath them are kept. It panics when a parameter child and a
// wildcard child would coexist under the same node. It also panics when a
// parameter child with a different name already exists (e.g. ":id" vs
// ":name"), which would otherwise silently replace the first route.
func (n *node) childOrCreate(seg string) *node {
    if strings.HasPrefix(seg, ":") {
        if n.starChild != nil {
            panic("web: 不允许同时注册路径参数和通配符路径, 已经注册通配符路径")
        }
        if n.paramChild == nil {
            n.paramChild = &node{path: seg}
        } else if n.paramChild.path != seg {
            panic(fmt.Sprintf("web: 路由冲突, 参数路由冲突, 已有 %s, 新注册 %s", n.paramChild.path, seg))
        }
        return n.paramChild
    }

    if seg == "*" {
        if n.paramChild != nil {
            panic("web: 不允许同时注册路径参数和通配符路径, 已经注册路径参数")
        }
        if n.starChild == nil {
            n.starChild = &node{path: seg}
        }
        return n.starChild
    }

    if n.children == nil {
        n.children = map[string]*node{}
    }
    res, ok := n.children[seg]
    if !ok {
        res = &node{path: seg}
        n.children[seg] = res
    }
    return res
}

17. 路由树:参数路径之参数值

用户希望直接获取路径参数,因此我们把命中的 node 和解析出的路径参数一起封装到 matchInfo 对象中返回。

// router.go
// matchInfo is the result of a route lookup.
type matchInfo struct {
    // n is the matched node.
    n          *node
    // pathParams maps parameter names (without the leading ':') to the
    // values taken from the request path. It is nil when no parameter matched.
    pathParams map[string]string
}

// findRoute looks up the node for method and path. It also collects the
// values of any path parameters matched along the way; for example,
// "/login/wuif" against "/login/:username" yields {"username": "wuif"}.
// The code before the split is elided in this excerpt.
func (r *router) findRoute(method string, path string) (*matchInfo, bool) {
    // ...
    // Split on "/".
    segs := strings.Split(path, "/")
    // Allocated lazily, so it stays nil for routes without parameters.
    var pathParams map[string]string
    for _, seg := range segs {
        child, paramChild, found := root.childOf(seg)
        if !found {
            return nil, false
        }
        // Matched a path-parameter child.
        if paramChild {
            if pathParams == nil {
                pathParams = make(map[string]string)
            }
            // child.path has the form ":id"; strip the ':' to get the key.
            pathParams[child.path[1:]] = seg
        }
        root = child
    }
    // The node exists, but its handler may still be nil.
    return &matchInfo{n: root, pathParams: pathParams}, true
    // return root, root.handler != nil
}

// childOf finds the child of n that matches path. Priority order: static
// child, then path-parameter child, then wildcard child.
// Results: (1) the child found, (2) whether that child is a path-parameter
// node, (3) whether anything matched.
func (n *node) childOf(path string) (*node, bool, bool) {
    // Indexing a nil children map is safe and reports "not found".
    if child, ok := n.children[path]; ok {
        return child, false, true
    }
    if n.paramChild != nil {
        return n.paramChild, true, true
    }
    return n.starChild, false, n.starChild != nil
}
// router_test.go
// TestRouter_findRoute registers a set of routes (list elided in these notes)
// and checks, per case, whether findRoute reports a match. When it does, the
// test also checks the captured path parameters and the returned node
// (compared with node.equal, defined elsewhere).
func TestRouter_findRoute(t *testing.T) {
    testRoutes := []struct {
        method string
        path   string
    }{
        // ... (route list elided in these notes)
    }

    r := newRouter()
    // Every route shares one no-op handler; the cases only check tree shape.
    var mockHandler HandleFunc = func(ctx *Context) {}
    for _, route := range testRoutes {
        r.addRoute(route.method, route.path, mockHandler)
    }

    testCases := []struct {
        name string

        method string
        path   string

        // wantFound is the expected second result of findRoute; info is
        // only compared when it is true.
        wantFound bool
        info      *matchInfo
    }{
        {
            // No tree is registered for this method.
            name:   "method not found",
            method: http.MethodOptions,
            path:   "/order/detail",
        },
        {
            // Exact static match.
            name:      "order detail",
            method:    http.MethodGet,
            path:      "/order/detail",
            wantFound: true,
            info: &matchInfo{
                n: &node{
                    handler: mockHandler,
                    path:    "detail",
                },
            },
        },
        {
            // "/order/abc" has no static child, so it lands on "*".
            name:      "order star",
            method:    http.MethodGet,
            path:      "/order/abc",
            wantFound: true,
            info: &matchInfo{
                n: &node{
                    handler: mockHandler,
                    path:    "*",
                },
            },
        },
        {
            // The node exists but has no handler of its own.
            name:      "order",
            method:    http.MethodGet,
            path:      "/order",
            wantFound: true,
            info: &matchInfo{
                n: &node{
                    path: "order",
                    // Child created by registering /order/detail.
                    children: map[string]*node{
                        "detail": &node{
                            handler: mockHandler,
                            path:    "detail",
                        },
                    },
                },
            },
        },
        {
            // The root node.
            name:      "root",
            method:    http.MethodDelete,
            path:      "/",
            wantFound: true,
            info: &matchInfo{
                n: &node{
                    path:    "/",
                    handler: mockHandler,
                },
            },
        },
        {
            name:   "path not found",
            method: http.MethodGet,
            path:   "/404",
        },
        {
            // Path-parameter match: ":username" captures "wuif".
            name:      "login username",
            method:    http.MethodPost,
            path:      "/login/wuif",
            wantFound: true,
            info: &matchInfo{
                n:          &node{path: ":username", handler: mockHandler},
                pathParams: map[string]string{"username": "wuif"},
            },
        },
    }

    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            info, found := r.findRoute(tc.method, tc.path)
            assert.Equal(t, tc.wantFound, found)
            if !found {
                return
            }
            assert.Equal(t, tc.info.pathParams, info.pathParams)
            // equal (defined elsewhere) appears to return a mismatch
            // description plus a bool — confirm against its definition.
            msg, ok := tc.info.n.equal(info.n)
            assert.True(t, ok, msg)
        })
    }
}
// context.go
package web

import "net/http"

// Context carries the per-request state handed to every HandleFunc.
type Context struct {
    Req        *http.Request
    Resp       http.ResponseWriter
    // PathParams holds the values captured for ":name" route segments, keyed
    // by name without the ':'; serve fills it in after routing.
    PathParams map[string]string
}
// server.go
// serve routes the request and runs the matched handler. When no node matches,
// or the node has no handler, it answers 404 "NOT FOUND" directly.
func (h *HTTPServer) serve(ctx *Context) {
    info, matched := h.findRoute(ctx.Req.Method, ctx.Req.URL.Path)
    if matched && info.n.handler != nil {
        ctx.PathParams = info.pathParams
        info.n.handler(ctx)
        return
    }
    // No route: write the 404 straight to the ResponseWriter.
    ctx.Resp.WriteHeader(404)
    _, _ = ctx.Resp.Write([]byte("NOT FOUND"))
}

18. 路由树总结与面试要点





如果注册的路由非常多,按最长公共前缀压缩的前缀树(radix tree,gin 的做法)查找性能会更高一些,但实现要复杂得多


20. 第一周作业:实现一棵路由树[选学]












21. 第一周路由树作业讲解[选学]

参考实现给节点设置了 nodeType;查找路由时,如果其他类型的子节点都匹配不上,还会再额外判断一次是否存在通配符节点







正则路由:形如 `:name(正则表达式)`,即带正则约束的路径参数

08 第二周:Web 框架之 Context 与 AOP 方案

context

1. Context 简介







`ParseForm` 可以多次调用,但只会解析一次(结果缓存在 `Req.Form` / `Req.PostForm` 中)

2. Context:Beego Context 设计分析

3. Context:Gin Context 设计分析

4. Context:Echo 和 Iris 的 Context 设计分析

5. Context:处理输入输出总结

6. Context:处理输入之 Body 输入

如果引入小众需求会影响框架的复杂度,或者导致性能等受挫,就不要考虑

7. Context:处理输入之表单输入

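`ParseForm` 会把 URL 查询参数和(POST/PUT/PATCH 请求的)`application/x-www-form-urlencoded` 请求体一起解析进 `Req.Form`,请求体部分另存于 `Req.PostForm`;它不处理 `multipart/form-data`,那需要 `ParseMultipartForm`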

// context.go
// FormValue returns the first value of the form field key, parsing the
// request form first. A missing key yields "" with a nil error; an error is
// returned only when parsing fails.
func (c *Context) FormValue(key string) (string, error) {
    if err := c.Req.ParseForm(); err != nil {
        return "", err
    }
    // ParseForm has populated c.Req.Form, so read the parsed values directly.
    return c.Req.Form.Get(key), nil
}

8. Context:处理输入之查询参数、路径参数和 StringValue


// QueryValue returns the first value of the URL query parameter key, or an
// error when the key is absent.
//
// URL.Query re-parses RawQuery on every call (ParseForm, by contrast, parses
// only once), so the parsed values are cached on the Context.
func (c *Context) QueryValue(key string) (string, error) {
    if c.queryValues == nil {
        c.queryValues = c.Req.URL.Query()
    }
    if vals, ok := c.queryValues[key]; ok {
        return vals[0], nil
    }
    return "", errors.New("web: key 不存在")
}

// QueryValueV1 is QueryValue wrapped in a StringValue so that lookup and
// conversion chain in one call, e.g. ctx.QueryValueV1("page").AsInt64().
// It shares QueryValue's cache of parsed query values.
func (c *Context) QueryValueV1(key string) StringValue {
    val, err := c.QueryValue(key)
    return StringValue{val: val, err: err}
}

// PathValue returns the value captured for the ":key" route segment, or an
// error when the matched route has no such parameter.
func (c *Context) PathValue(key string) (string, error) {
    if val, ok := c.PathParams[key]; ok {
        return val, nil
    }
    return "", errors.New("web: key 不存在")
}
// PathValueV1 is PathValue wrapped in a StringValue, e.g.
// ctx.PathValueV1("id").AsInt64().
func (c *Context) PathValueV1(key string) StringValue {
    val, err := c.PathValue(key)
    return StringValue{val: val, err: err}
}

// StringValue pairs a raw request value with the error, if any, produced
// while looking it up, so lookup and conversion can be chained:
//
//     id, err := ctx.PathValueV1("id").AsInt64()
//
// A lookup error always takes precedence over the conversion.
type StringValue struct {
    val string
    err error
}

// AsString returns the raw value. val is unexported, so this is the only way
// for code outside package web to read it.
func (s StringValue) AsString() (string, error) {
    if s.err != nil {
        return "", s.err
    }
    return s.val, nil
}

// AsInt64 parses the value as a base-10 int64.
func (s StringValue) AsInt64() (int64, error) {
    if s.err != nil {
        return 0, s.err
    }
    return strconv.ParseInt(s.val, 10, 64)
}

// AsUint64 parses the value as a base-10 uint64.
func (s StringValue) AsUint64() (uint64, error) {
    if s.err != nil {
        return 0, s.err
    }
    return strconv.ParseUint(s.val, 10, 64)
}

// AsFloat64 parses the value as a float64.
func (s StringValue) AsFloat64() (float64, error) {
    if s.err != nil {
        return 0, s.err
    }
    return strconv.ParseFloat(s.val, 64)
}
// server_test.go
    // V1 style: lookup and int conversion in one chained call.
    h.Post("/values/v1/:id", func(ctx *Context) {
        // Much more convenient: one error check covers both steps.
        id, err := ctx.PathValueV1("id").AsInt64()
        if err != nil {
            ctx.Resp.WriteHeader(400)
            ctx.Resp.Write([]byte("id 输入错误"))
            return
        }
        ctx.Resp.Write([]byte(fmt.Sprintf("hello, %d", id)))
    })

    // Original style: fetch the string, then convert it separately; each
    // step needs its own error branch.
    h.Post("/values/:id", func(ctx *Context) {
        idStr, err := ctx.PathValue("id")
        if err != nil {
            ctx.Resp.WriteHeader(400)
            ctx.Resp.Write([]byte("id 输入错误"))
            return
        }
        id, err := strconv.ParseInt(idStr, 10, 64)
        if err != nil {
            ctx.Resp.WriteHeader(400)
            ctx.Resp.Write([]byte("id 输入错误"))
            return
        }
        ctx.Resp.Write([]byte(fmt.Sprintf("hello, %d", id)))
    })

9. Context:处理输出

10. Context 总结与面试要点

用户自己通过装饰器模式来实现线程安全

go的泛型功能比较有限

aop

11. AOP 简介与不同框架设计概览


12. AOP 设计方案:Middleware

// middleware.go
package web

// Middleware is a functional chain of responsibility (the "onion" model). It
// receives the next HandleFunc and returns a new one that can run code
// before and/or after calling next, or not call next at all.
type Middleware func(next HandleFunc) HandleFunc

// AOP hooks in web frameworks go by different names across languages:
// Middleware, Handler, Chain, Filter, Filter-Chain,
// Interceptor, Wrapper.

// Alternative designs considered but not adopted:

// An interface form of Middleware.
//type MiddlewareV1 interface {
//    Invoke(next HandleFunc) HandleFunc
//}
//
// An interceptor with explicit before / after / around hooks.
//type Interceptor interface {
//    Before(ctx *Context)
//    After(ctx *Context)
//    Surround(ctx *Context)
//}

// A flat list of handlers run in order.
//type Chain []HandleFunc
//
// A chain whose handlers return false to stop the remaining handlers.
//type HandleFuncV1 func(ctx *Context) (next bool)
//
//type ChainV1 struct {
//    handlers []HandleFuncV1
//}
//
//func (c ChainV1) Run(ctx *Context) {
//    for _, h := range c.handlers {
//        next := h(ctx)
//        // Returning false aborts the chain here.
//        if !next {
//            return
//        }
//    }
//}
// server.go
// HTTPServer is the framework's http.Handler. It owns the route tree and the
// server-wide middleware chain.
type HTTPServer struct {
    // addr string // could be passed at construction instead of to Start; either works
    // router
    // Embedding router promotes its addRoute method onto HTTPServer, which is
    // why code calling addRoute on the server still compiles.
    router

    // mdls wrap every request; ServeHTTP makes mdls[0] the outermost.
    mdls []Middleware
}

// ServeHTTP implements http.Handler and is the framework's entry point for
// every request.
//
// The chain is assembled back to front. h.serve (routing plus the matched
// handler) is the innermost step, and each middleware wraps what has been
// built so far. As a result h.mdls[0] ends up outermost and runs first.
// NOTE(review): the chain is rebuilt on every request; it could be built
// once at startup.
func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    ctx := &Context{
        Req:  req,
        Resp: w,
    }

    var chain HandleFunc = h.serve
    for i := len(h.mdls); i > 0; i-- {
        chain = h.mdls[i-1](chain)
    }
    chain(ctx)
}

无侵入式比较体现代码功底,但是有可能牺牲一部分性能

别的方式, 链式扩散成网状

13. Middleware:AccessLog

![go:build](1ea4c724.png)

// web/accesslog/middleware.go
package accesslog

import (
    "log"

    "github.com/goccy/go-json"
    "web_framework/web"
)

// MiddlewareBuilder builds an access-log middleware. Optionally configure it
// with LogFunc, then call Build.
type MiddlewareBuilder struct {
    // logFunc receives one JSON-encoded accessLog entry per request.
    logFunc func(log string)
}

// LogFunc sets the sink for access-log lines and returns m for chaining.
func (m *MiddlewareBuilder) LogFunc(fn func(log string)) *MiddlewareBuilder {
    m.logFunc = fn
    return m
}

// Build returns the access-log middleware.
//
// Without a configured LogFunc, lines go to the standard library logger.
// Previously a nil function was called on every request, which panicked.
func (m MiddlewareBuilder) Build() web.Middleware {
    logFunc := m.logFunc
    if logFunc == nil {
        logFunc = func(line string) {
            log.Println(line)
        }
    }
    return func(next web.HandleFunc) web.HandleFunc {
        return func(ctx *web.Context) {
            // Deferred so the entry is written after next has run (that is
            // when MatchedRoute is known), and still written if next panics.
            defer func() {
                l := accessLog{
                    Host:       ctx.Req.Host,
                    Route:      ctx.MatchedRoute,
                    HTTPMethod: ctx.Req.Method,
                    Path:       ctx.Req.URL.Path,
                }
                // accessLog has only string fields, so Marshal cannot fail.
                data, _ := json.Marshal(l)
                logFunc(string(data))
            }()
            next(ctx)
        }
    }
}

// accessLog is the JSON shape of one access-log entry.
type accessLog struct {
    Host string `json:"host,omitempty"`
    // Route is the matched route pattern, e.g. "/a/b/*".
    Route      string `json:"route,omitempty"`
    HTTPMethod string `json:"http_method,omitempty"`
    Path       string `json:"path,omitempty"`
}
// router.go
// addRoute registers handleFunc for method+path (partial listing). Besides the
// handler, the terminal node now records the full route pattern in
// node.route. serve copies it into Context.MatchedRoute for middlewares such
// as AccessLog.
func (r *router) addRoute(method string, path string, handleFunc HandleFunc) {
    // ... (validation and root lookup elided in these notes)

    // The root path "/" is handled specially.
    if path == "/" {
        // ...
        root.handler = handleFunc
        root.route = "/"
        return
    }

    // ... (segment walk elided; `root` presumably ends on the last segment's node)
    root.handler = handleFunc
    root.route = path
}

// node is one segment of the route tree (only the fields relevant here are shown).
type node struct {
    // route is the full pattern registered for this node, e.g. "/a/b/*".
    // addRoute only sets it on the node a route ends at.
    route string
    // path is this node's own segment, e.g. "*" for /a/*.
    path string
    // ...
}
// middleware_e2e_test.go
//go:build e2e

package accesslog

import (
    "fmt"
    "testing"
    "web_framework/web"
)

// TestMiddlewareBuilderE2E is a manual check, compiled only with the e2e
// build tag. It serves GET /a/b/* behind the access-log middleware, printing
// each entry, and then calls Start(":8081"), presumably blocking until the
// process is stopped.
func TestMiddlewareBuilderE2E(t *testing.T) {
    builder := MiddlewareBuilder{}
    mdls := builder.LogFunc(func(log string) {
        fmt.Println(log)
    }).Build()
    server := web.NewHTTPServer(web.ServerWithMiddleware(mdls))
    server.Get("/a/b/*", func(ctx *web.Context) {
        ctx.Resp.Write([]byte("hello, it's me"))
    })
    server.Start(":8081")
}
// middleware_test.go
package accesslog

import (
    "fmt"
    "net/http"
    "testing"
    "web_framework/web"
)

// TestMiddlewareBuilder drives ServeHTTP directly with a POST /a/b/c request
// that matches the /a/b/* route, so the access-log entry is printed without
// starting a server.
//
// NOTE(review): passing a nil ResponseWriter only works while nothing writes
// to ctx.Resp. Once responses are flushed through flushResp, this needs
// httptest.NewRecorder() instead.
func TestMiddlewareBuilder(t *testing.T) {
    builder := MiddlewareBuilder{}
    mdls := builder.LogFunc(func(log string) {
        fmt.Println(log)
    }).Build()
    server := web.NewHTTPServer(web.ServerWithMiddleware(mdls))
    server.Post("/a/b/*", func(ctx *web.Context) {
        fmt.Println("hello, it's me")
    })
    req, err := http.NewRequest(http.MethodPost, "/a/b/c", nil)
    if err != nil {
        t.Fatal(err)
    }
    server.ServeHTTP(nil, req)
}
// server.go
// serve routes the request, publishes the path parameters and the matched
// route pattern on ctx, then runs the handler.
func (h *HTTPServer) serve(ctx *Context) {
    info, ok := h.findRoute(ctx.Req.Method, ctx.Req.URL.Path)
    if !ok || info.n.handler == nil {
        // ... (404 handling elided in these notes)
        // NOTE(review): the elided branch must end with `return`; otherwise
        // info may be nil here and the lines below dereference it.
    }
    ctx.PathParams = info.pathParams
    // Read by middlewares (access log, tracing, metrics) after next returns.
    ctx.MatchedRoute = info.n.route
    info.n.handler(ctx)
}
// context.go
// Context (partial listing): only the field added in this step is shown.
type Context struct {
    // ...

    // MatchedRoute is the route pattern (e.g. "/a/b/*") that served the
    // request. serve sets it only when a handler matched.
    MatchedRoute string

    //cookieSameSite http.SameSite
}

14. Middleware:Trace 简介和 OpenTelemetry

配置有两种提供方式:一种是用户直接把数据传进来,框架不关心数据从哪里来(编程式接口);另一种是用户传入配置文件路径,由框架去读取和解析——后者会让框架与配置格式强耦合

// web/middlewares/opentelemetry/middleware.go
package opentelemetry

import (
    "go.opentelemetry.io/otel"
    "go.opentelemetry.io/otel/attribute"
    "go.opentelemetry.io/otel/propagation"
    "go.opentelemetry.io/otel/trace"
    "web_framework/web"
)

// instrumentationName names this middleware as the instrumentation library
// when a Tracer is obtained from the global TracerProvider.
const instrumentationName = "web_framework/middlewares/opentelemetry"

// MiddlewareBuilder builds an OpenTelemetry tracing middleware.
type MiddlewareBuilder struct {
    // Tracer creates the per-request spans. When nil, Build uses a tracer
    // from the global TracerProvider.
    Tracer trace.Tracer
}

// Considered but not used; the zero value plus the optional Tracer field suffices:
//func NewMiddlewareBuilder(tracer trace.Tracer) *MiddlewareBuilder {
//    return &MiddlewareBuilder{tracer}
//}

// Build returns a middleware that wraps each request in a server span.
//
// The span continues any trace propagated in the client's headers. The
// request's context is replaced with one carrying the span, so spans that
// handlers start from ctx.Req.Context() become its children. Once next
// returns, the span is renamed to the matched route. When nothing matched,
// the placeholder "unknown" is kept instead of an empty name.
func (m MiddlewareBuilder) Build() web.Middleware {
    if m.Tracer == nil {
        m.Tracer = otel.GetTracerProvider().Tracer(instrumentationName)
    }
    return func(next web.HandleFunc) web.HandleFunc {
        return func(ctx *web.Context) {
            reqCtx := ctx.Req.Context()
            // Join the client's trace if its headers carry one.
            reqCtx = otel.GetTextMapPropagator().Extract(reqCtx, propagation.HeaderCarrier(ctx.Req.Header))

            // The route is only known after routing, so start with a
            // placeholder name and rename below.
            reqCtx, span := m.Tracer.Start(reqCtx, "unknown")
            defer span.End()

            // For server requests URL is parsed from the request line and
            // normally has no scheme, so fall back to the TLS state.
            scheme := ctx.Req.URL.Scheme
            if scheme == "" {
                scheme = "http"
                if ctx.Req.TLS != nil {
                    scheme = "https"
                }
            }
            // Add more attributes as the business needs.
            span.SetAttributes(
                attribute.String("http.method", ctx.Req.Method),
                attribute.String("http.url", ctx.Req.URL.String()),
                attribute.String("http.schema", scheme),
                attribute.String("http.host", ctx.Req.Host),
            )

            // Hand the span's context downstream. WithContext makes a shallow
            // copy of the request; that cost is unavoidable for propagation.
            ctx.Req = ctx.Req.WithContext(reqCtx)

            next(ctx)

            // MatchedRoute is only set once next has routed the request.
            if ctx.MatchedRoute != "" {
                span.SetName(ctx.MatchedRoute)
            }
        }
    }
}

15. Middleware:OpenTelemetry 测试


// middleware_e2e_test.go
//go:build e2e

package opentelemetry

import (
    "go.opentelemetry.io/otel"
    "go.opentelemetry.io/otel/exporters/zipkin"
    "go.opentelemetry.io/otel/sdk/resource"
    sdktrace "go.opentelemetry.io/otel/sdk/trace"
    semcove "go.opentelemetry.io/otel/semconv/v1.10.0"
    "log"
    "os"
    "testing"
    "time"
    "web_framework/web"
)

// TestMiddlewareBuilder_Build is a manual end-to-end check (e2e build tag).
// It serves GET /user behind the tracing middleware and builds a small tree
// of nested spans inside the handler. The spans go to Zipkin via initZipkin,
// and the test then calls Start(":8081").
func TestMiddlewareBuilder_Build(t *testing.T) {
    // Obtained before initZipkin installs the SDK provider; this relies on
    // otel's global provider delegating to the provider set later.
    tracer := otel.GetTracerProvider().Tracer(instrumentationName)
    builder := MiddlewareBuilder{
        Tracer: tracer,
    }
    server := web.NewHTTPServer(web.ServerWithMiddleware(builder.Build()))

    server.Get("/user", func(ctx *web.Context) {
        // With a middleware that stores its span in ctx.Req (the later
        // version), this becomes a child of the request span.
        c, span := tracer.Start(ctx.Req.Context(), "first_layer")
        defer span.End()

        //defer func() {
        //    // only has a value once next has run
        //    span.SetName(ctx.MatchedRoute)
        //    // status code
        //    span.SetAttributes(attribute.Int("http.status", ctx.RespStatusCode))
        //    span.End()
        //}()

        // second_layer with two sequential children.
        secondC, second := tracer.Start(c, "second_layer")
        time.Sleep(time.Second)
        _, third1 := tracer.Start(secondC, "third_layer_1")
        time.Sleep(100 * time.Millisecond)
        third1.End()
        _, third2 := tracer.Start(secondC, "third_layer_2")
        time.Sleep(300 * time.Millisecond)
        third2.End()
        second.End()

        // Started from the request context, so a sibling of first_layer.
        _, first := tracer.Start(ctx.Req.Context(), "first_layer_1")
        defer first.End()
        //ctx.Resp.Write([]byte("hello,world"))
        time.Sleep(100 * time.Millisecond)
        ctx.RespJson(202, User{
            Name: "Tom",
        })
    })

    initZipkin(t)

    // The address needs the leading colon (":8081", not "8081").
    server.Start(":8081")
}

// User is the JSON payload returned by the /user handler.
type User struct {
    Name string `json:"name"`
}

// initZipkin installs a global SDK TracerProvider that batches spans and
// exports them to the Zipkin collector at tyy:9411, tagged with the service
// name "opentelemetry-demo". It fails the test if the exporter cannot be
// created.
//
// NOTE(review): the provider is never shut down, so spans still buffered in
// the batcher are lost at exit. The collector host is specific to the
// author's environment.
func initZipkin(t *testing.T) {
    exporter, err := zipkin.New(
        "http://tyy:9411/api/v2/spans",
        zipkin.WithLogger(log.New(os.Stderr, "opentelemetry-demo", log.Lmicroseconds|log.Ldate)),
    )
    if err != nil {
        t.Fatal(err)
    }
    batcher := sdktrace.NewBatchSpanProcessor(exporter)
    tp := sdktrace.NewTracerProvider(
        sdktrace.WithSpanProcessor(batcher),
        sdktrace.WithResource(resource.NewWithAttributes(
            semcove.SchemaURL,
            semcove.ServiceNameKey.String("opentelemetry-demo"),
        )),
    )
    otel.SetTracerProvider(tp)
}
// server.go
// ServeHTTP (partial listing). After assembling the user middleware chain, it
// adds one outermost wrapper. That wrapper flushes RespData / RespStatusCode
// to the ResponseWriter only after every other middleware has returned.
func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    // ... (Context creation and chain assembly elided; `root` is the assembled chain)
    // Execution below runs front to back.

    // Finally flush RespData and RespStatusCode to the response.
    var m Middleware = func(next HandleFunc) HandleFunc {
        return func(ctx *Context) {
            // next fills in RespData and RespStatusCode.
            next(ctx)
            h.flushResp(ctx)
        }
    }

    root = m(root)

    root(ctx)
    // Routing and the matched handler run inside root.
    //h.serve(ctx)
}

// flushResp writes the buffered status code and body from ctx to the real
// ResponseWriter. A zero status code leaves the status to net/http, which
// sends 200 on the first write. Write failures are logged rather than
// returned, since nothing upstream could act on them.
func (h *HTTPServer) flushResp(ctx *Context) {
    if code := ctx.RespStatusCode; code != 0 {
        ctx.Resp.WriteHeader(code)
    }
    written, err := ctx.Resp.Write(ctx.RespData)
    if err == nil && written == len(ctx.RespData) {
        return
    }
    h.log("写入响应失败 %v\n", err)
}

// serve routes the request and runs the matched handler. Results are only
// buffered in ctx (RespData / RespStatusCode); flushResp writes them out.
func (h *HTTPServer) serve(ctx *Context) {
    info, found := h.findRoute(ctx.Req.Method, ctx.Req.URL.Path)
    if found && info.n.handler != nil {
        ctx.PathParams = info.pathParams
        ctx.MatchedRoute = info.n.route
        info.n.handler(ctx)
        return
    }
    // Nothing to run: buffer a 404 for flushResp.
    ctx.RespStatusCode = 404
    ctx.RespData = []byte("NOT FOUND")
}

// NewHTTPServer creates a server with an empty route tree and a logger that
// prints through fmt.Printf, then applies opts in order. An option may
// therefore replace the default logger.
func NewHTTPServer(opts ...HTTPServerOption) *HTTPServer {
    server := &HTTPServer{router: newRouter()}
    server.log = func(msg string, args ...any) {
        fmt.Printf(msg, args...)
    }
    for i := range opts {
        opts[i](server)
    }
    return server
}

// HTTPServer is the framework's http.Handler.
type HTTPServer struct {
    // addr string // could also be passed at construction instead of to Start
    // router
    router // embedded so its addRoute method is promoted onto HTTPServer

    // mdls are the server-wide middlewares; mdls[0] is the outermost.
    mdls []Middleware

    // log reports framework-internal problems such as failed response
    // writes; NewHTTPServer defaults it to fmt.Printf.
    log func(msg string, args ...any)
}
// RespJson serializes val as JSON and buffers it, with status code, in
// RespData / RespStatusCode. Nothing is written until flushResp runs.
//
// It also sets Content-Type to application/json unless the handler already
// chose one, so clients do not have to sniff the body. If json.Marshal fails,
// the error is returned and ctx is left unchanged.
func (c *Context) RespJson(code int, val any) error {
    data, err := json.Marshal(val)
    if err != nil {
        return err
    }
    // Headers must be set before flushResp calls WriteHeader. Resp can be nil
    // when tests call ServeHTTP directly.
    if c.Resp != nil && c.Resp.Header().Get("Content-Type") == "" {
        c.Resp.Header().Set("Content-Type", "application/json")
    }
    c.RespData = data
    c.RespStatusCode = code
    return nil
}

// Context carries per-request state through the middlewares and the handler.
type Context struct {
    Req *http.Request
    // Resp is the raw ResponseWriter. Writing to it directly bypasses
    // RespData and RespStatusCode, so some middlewares (error pages,
    // metrics, tracing) will not see the response.
    Resp       http.ResponseWriter
    // PathParams holds ":name" segment values keyed by name.
    PathParams map[string]string

    // RespData and RespStatusCode buffer the response so middlewares can
    // inspect or replace it; flushResp writes them out at the very end.
    RespData       []byte
    RespStatusCode int

    // queryValues caches the parsed URL query (see QueryValue).
    queryValues url.Values

    // MatchedRoute is the route pattern that served the request, or "".
    MatchedRoute string

    //cookieSameSite http.SameSite
}
// web/middlewares/opentelemetry/middleware.go
package opentelemetry

import (
    "go.opentelemetry.io/otel"
    "go.opentelemetry.io/otel/attribute"
    "go.opentelemetry.io/otel/propagation"
    "go.opentelemetry.io/otel/trace"
    "web_framework/web"
)

// instrumentationName names this middleware as the instrumentation library
// when a Tracer is obtained from the global TracerProvider.
const instrumentationName = "web_framework/middlewares/opentelemetry"

// MiddlewareBuilder builds an OpenTelemetry tracing middleware.
type MiddlewareBuilder struct {
    // Tracer creates the per-request spans. When nil, Build uses a tracer
    // from the global TracerProvider.
    Tracer trace.Tracer
}

// Considered but not used; the zero value plus the optional Tracer field suffices:
//func NewMiddlewareBuilder(tracer trace.Tracer) *MiddlewareBuilder {
//    return &MiddlewareBuilder{tracer}
//}

// Build returns a middleware that wraps each request in a server span.
//
// The span continues any trace propagated in the client's headers and is
// stored in the request's context, so handler spans become its children.
// Once next returns, the span is renamed to the matched route and tagged with
// the buffered status code. When nothing matched, the placeholder name
// "unknown" is kept instead of an empty name.
func (m MiddlewareBuilder) Build() web.Middleware {
    if m.Tracer == nil {
        m.Tracer = otel.GetTracerProvider().Tracer(instrumentationName)
    }
    return func(next web.HandleFunc) web.HandleFunc {
        return func(ctx *web.Context) {
            reqCtx := ctx.Req.Context()
            // Join the client's trace if its headers carry one.
            reqCtx = otel.GetTextMapPropagator().Extract(reqCtx, propagation.HeaderCarrier(ctx.Req.Header))

            // The route is only known after routing; rename the span below.
            reqCtx, span := m.Tracer.Start(reqCtx, "unknown")
            defer span.End()

            // For server requests URL normally has no scheme (it is parsed
            // from the request line), so fall back to the TLS state.
            scheme := ctx.Req.URL.Scheme
            if scheme == "" {
                scheme = "http"
                if ctx.Req.TLS != nil {
                    scheme = "https"
                }
            }
            // Add more attributes as the business needs.
            span.SetAttributes(
                attribute.String("http.method", ctx.Req.Method),
                attribute.String("http.url", ctx.Req.URL.String()),
                attribute.String("http.schema", scheme),
                attribute.String("http.host", ctx.Req.Host),
            )

            // Store the span's context on the request. WithContext copies the
            // request, which costs a little but is required for propagation.
            ctx.Req = ctx.Req.WithContext(reqCtx)

            next(ctx)

            // Only known once next has run.
            if ctx.MatchedRoute != "" {
                span.SetName(ctx.MatchedRoute)
            }
            // 0 when the handler wrote to ctx.Resp directly.
            span.SetAttributes(attribute.Int("http.status", ctx.RespStatusCode))
        }
    }
}

16. Middleware:OpenTelemetry 总结


17. Prometheus 详解

"这可不是我们公司哒,要是我们公司的,今天我贴出来,明天我就被炒掉啦"
Namespace、Subsystem、Name 这些用来拼接指标名的字符串不能包含 `-`,要换成 `_`(Prometheus 指标名只允许字母、数字、`_` 和 `:`)

18. Middleware:Prometheus

// web/middlewares/prometheus/middleware_test.go
//go:build e2e

package prometheus

import (
    "github.com/prometheus/client_golang/prometheus/promhttp"
    "math/rand"
    "net/http"
    "testing"
    "time"
    "web_framework/web"
)

// TestMiddlewareBuilder_Build is a manual end-to-end check (e2e build tag).
// GET /user sleeps a random 1–1000 ms so the latency summary has a spread.
// Metrics are exposed on :8082/metrics and the app itself on :8081.
func TestMiddlewareBuilder_Build(t *testing.T) {
    // The metric is named malred_web_http_response; names may not contain '-'.
    builder := MiddlewareBuilder{
        Namespace: "malred",
        Subsystem: "web",
        Name:      "http_response",
    }
    server := web.NewHTTPServer(web.ServerWithMiddleware(builder.Build()))

    server.Get("/user", func(ctx *web.Context) {
        val := rand.Intn(1000) + 1
        time.Sleep(time.Duration(val) * time.Millisecond)
        ctx.RespJson(202, User{Name: "Tom"})
    })

    // Expose the default registry on a separate port for Prometheus to scrape.
    // NOTE(review): the ListenAndServe error is discarded.
    go func() {
        http.Handle("/metrics", promhttp.Handler())
        http.ListenAndServe(":8082", nil)
    }()

    server.Start(":8081")
}

// User is the JSON payload returned by /user.
type User struct {
    Name string `json:"name"`
}
// web/middlewares/prometheus/middleware.go
package prometheus

import (
    "github.com/prometheus/client_golang/prometheus"
    "strconv"
    "time"
    "web_framework/web"
)

// MiddlewareBuilder configures a Prometheus middleware that records request
// latency (in milliseconds) as a summary labelled by route pattern, HTTP
// method and status. A builder keeps it easy to extend. Namespace, Subsystem
// and Name form the metric name and must be valid Prometheus identifiers
// (letters, digits, '_' and ':'; no '-').
type MiddlewareBuilder struct {
    Namespace string
    Subsystem string
    Name      string
    Help      string
}

// Build creates the summary vector, registers it with the default registry
// and returns a middleware that observes into it.
//
// Calling Build again with the same options reuses the collector that is
// already registered; previously MustRegister panicked. Any other
// registration error still panics, as MustRegister would.
func (m MiddlewareBuilder) Build() web.Middleware {
    vector := prometheus.NewSummaryVec(prometheus.SummaryOpts{
        Namespace: m.Namespace,
        Subsystem: m.Subsystem,
        Name:      m.Name,
        Help:      m.Help,
        Objectives: map[float64]float64{
            0.5:   0.01,
            0.75:  0.01,
            0.90:  0.01,
            0.99:  0.001,
            0.999: 0.0001,
        },
    }, []string{"pattern", "method", "status"})
    if err := prometheus.Register(vector); err != nil {
        are, ok := err.(prometheus.AlreadyRegisteredError)
        if !ok {
            panic(err)
        }
        existing, ok := are.ExistingCollector.(*prometheus.SummaryVec)
        if !ok {
            panic(err)
        }
        vector = existing
    }
    return func(next web.HandleFunc) web.HandleFunc {
        return func(ctx *web.Context) {
            startTime := time.Now()
            defer func() {
                // Milliseconds() truncates, which records every sub-ms
                // request as 0; keep the fractional part instead.
                duration := float64(time.Since(startTime).Microseconds()) / 1000
                pattern := ctx.MatchedRoute
                if pattern == "" {
                    // No route matched (e.g. 404): avoid an empty label.
                    pattern = "unknown"
                }
                // Observe records the response time.
                vector.WithLabelValues(pattern, ctx.Req.Method,
                    strconv.Itoa(ctx.RespStatusCode)).
                    Observe(duration)
            }()
            next(ctx)
        }
    }
}

19. Middleware 例子:错误页面

// middleware_test.go
package errhdl

import (
    "net/http"
    "testing"
    "web_framework/web"
)

// TestNewMiddlewareBuilder is a manual check. The server has no routes, so
// every request gets a 404, which the error-page middleware replaces with the
// HTML registered for 404. It then calls Start(":8081").
//
// NOTE(review): this file has no e2e build tag, so a plain `go test` run
// reaches Start and presumably never returns.
func TestNewMiddlewareBuilder(t *testing.T) {
    builder := NewMiddlewareBuilder()
    builder.AddCode(http.StatusNotFound, []byte(`
        <html>
            <body>
                <h1>NOT FOUND</h1>
            </body>
        </html>
    `)).
        AddCode(http.StatusBadRequest, []byte(`
        <html>
            <body>
                <h1>BAD REQUEST</h1>
            </body>
        </html>
    `))
    server := web.NewHTTPServer(web.ServerWithMiddleware(builder.Build()))
    server.Start(":8081")
}
// web/middlewares/errhdl/middleware.go
package errhdl

import "web_framework/web"

// MiddlewareBuilder builds a middleware that replaces the response body for
// selected status codes with a fixed page.
type MiddlewareBuilder struct {
    // resp maps status code to replacement body. Pages are fixed bytes, so
    // this design cannot render dynamic content.
    resp map[int][]byte
}

// NewMiddlewareBuilder returns a builder with no pages registered.
func NewMiddlewareBuilder() *MiddlewareBuilder {
    return &MiddlewareBuilder{
        resp: map[int][]byte{},
    }
}

// AddCode registers data as the body sent whenever the response status is
// status, and returns m for chaining. It also works on a zero-value
// MiddlewareBuilder; previously that panicked writing to a nil map.
// Register pages before calling Build.
func (m *MiddlewareBuilder) AddCode(status int, data []byte) *MiddlewareBuilder {
    if m.resp == nil {
        m.resp = map[int][]byte{}
    }
    m.resp[status] = data
    return m
}

// Build returns the middleware. After next has run, it swaps RespData for the
// registered page matching RespStatusCode, if any; the status is kept.
func (m MiddlewareBuilder) Build() web.Middleware {
    return func(next web.HandleFunc) web.HandleFunc {
        return func(ctx *web.Context) {
            next(ctx)
            if page, ok := m.resp[ctx.RespStatusCode]; ok {
                // Overwrite the buffered body.
                ctx.RespData = page
            }
        }
    }
}

20. Middleware 例子:从 panic 中恢复

// middleware_test.go
package recover

import (
    "fmt"
    "testing"
    "web_framework/web"
)

// TestMiddlewareBuilder_Build is a manual check. GET /user always panics, and
// the recover middleware turns that into a 500 with body "你panic了" while
// logging the request URL. It then calls Start(":8081").
//
// NOTE(review): this file has no e2e build tag, so a plain `go test` run
// reaches Start and presumably never returns.
func TestMiddlewareBuilder_Build(t *testing.T) {
    builder := MiddlewareBuilder{
        Data:       []byte("你panic了"),
        StatusCode: 500,
        Log: func(ctx *web.Context) {
            fmt.Printf("panic 路径: %s", ctx.Req.URL.String())
        },
    }
    server := web.NewHTTPServer(web.ServerWithMiddleware(builder.Build()))
    server.Get("/user", func(ctx *web.Context) {
        panic("发生panic")
    })
    server.Start(":8081")
}
// web/middlewares/recover/middleware.go
package recover

import "web_framework/web"

// MiddlewareBuilder builds a middleware that recovers from panics raised
// further down the chain and turns them into a fixed response.
type MiddlewareBuilder struct {
    // StatusCode and Data form the response buffered after a panic.
    StatusCode int
    Data       []byte
    // Log, if non-nil, is called with the request's Context after a panic
    // has been recovered. Other signatures considered:
    //Log func(err any)
    Log func(ctx *web.Context)
    //Log func(stack string)
}

// Build returns the recovery middleware. It only catches panics raised on
// the goroutine running next.
func (m MiddlewareBuilder) Build() web.Middleware {
    return func(next web.HandleFunc) web.HandleFunc {
        return func(ctx *web.Context) {
            defer func() {
                if err := recover(); err != nil {
                    ctx.RespData = m.Data
                    ctx.RespStatusCode = m.StatusCode
                    // Calling a nil Log would panic inside this deferred
                    // function and undo the recovery just performed.
                    if m.Log != nil {
                        m.Log(ctx)
                    }
                }
            }()
            next(ctx)
        }
    }
}

21. Middleware 总结和面试

第二周作业

09 第三周:Web 框架之页面渲染、文件处理与 Session

页面渲染

1. 页面渲染:模板引擎接口定义


2. 页面渲染:Template 语法

text/template 在中间件和工具开发中很常用,典型场景是代码生成,比如 gRPC 代码生成、Google 的依赖注入工具 wire;由于 Go 不支持运行时动态生成类型,动态代理一般也通过代码生成来实现


package template_demo

import (
    "bytes"
    "fmt"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
    "html/template"
    "testing"
)

// TestHelloWorld executes a template against a struct: {{.Name}} reads the
// exported field Name.
func TestHelloWorld(t *testing.T) {
    type User struct {
        Name string
    }
    tpl := template.New("hello-world")
    tpl, err := tpl.Parse(`Hello, {{.Name}}`)
    require.NoError(t, err)
    buffer := &bytes.Buffer{}
    err = tpl.Execute(buffer, User{Name: "Tom"})
    require.NoError(t, err)
    assert.Equal(t, `Hello, Tom`, buffer.String())
}

// TestMapData shows that the same {{.Name}} syntax also looks up a map key.
func TestMapData(t *testing.T) {
    tpl := template.New("hello-world")
    tpl, err := tpl.Parse(`Hello, {{.Name}}`)
    require.NoError(t, err)
    buffer := &bytes.Buffer{}
    err = tpl.Execute(buffer, map[string]string{"Name": "Tom"})
    require.NoError(t, err)
    assert.Equal(t, `Hello, Tom`, buffer.String())
}

// TestSlice indexes a slice with the built-in index function.
func TestSlice(t *testing.T) {
    tpl := template.New("hello-world")
    tpl, err := tpl.Parse(`Hello, {{index . 0}}`)
    require.NoError(t, err)
    buffer := &bytes.Buffer{}
    err = tpl.Execute(buffer, []string{"Tom"})
    require.NoError(t, err)
    assert.Equal(t, `Hello, Tom`, buffer.String())
}

// TestBasic renders the data value itself with {{.}}.
func TestBasic(t *testing.T) {
    tpl := template.New("hello-world")
    tpl, err := tpl.Parse(`Hello, {{.}}`)
    require.NoError(t, err)
    buffer := &bytes.Buffer{}
    err = tpl.Execute(buffer, 123)
    require.NoError(t, err)
    assert.Equal(t, `Hello, 123`, buffer.String())
}

// TestInvoke calls built-in functions (printf, len) and a method on the data
// value with arguments. The template is a raw string, so "\n" stays a literal
// backslash-n in both the template and the expected output.
func TestInvoke(t *testing.T) {
    tpl := template.New("hello-world")
    tpl, err := tpl.Parse(`打印数字: {{printf "%.2f" 1.2345}}\n切片长度: {{len .Slice}}\n{{.Hello "Tom" "Jerry"}}`)
    require.NoError(t, err)
    buffer := &bytes.Buffer{}
    err = tpl.Execute(buffer, FnCall{Slice: []string{"a", "b"}})
    require.NoError(t, err)
    assert.Equal(t, `打印数字: 1.23\n切片长度: 2\nHello, Tom-Jerry`, buffer.String())
}

// FnCall is the data value for TestInvoke and the range tests.
type FnCall struct {
    Slice []string
}

// Hello returns "Hello, first-last"; templates call it as {{.Hello "Tom" "Jerry"}}.
func (f FnCall) Hello(first string, last string) string {
    return fmt.Sprintf("Hello, %s-%s", first, last)
}

// TestFor ranges over a slice with index and element variables. "{{-" trims
// the whitespace, including newlines, before an action. That is what keeps
// the template's leading newline and the newline before {{.}} out of the
// output.
func TestFor(t *testing.T) {
    tpl := template.New("hello-world")
    tpl, err := tpl.Parse(`
{{- range $idx, $ele := .Slice}}
{{- .}}
{{$idx}}-{{$ele}}
{{end}}`)
    require.NoError(t, err)
    buffer := &bytes.Buffer{}
    err = tpl.Execute(buffer, FnCall{Slice: []string{"a", "b"}})
    require.NoError(t, err)
    assert.Equal(t, `a
0-a
b
1-b
`, buffer.String())
}

// TestForLoop shows the usual way to loop a fixed number of times: range over
// a slice of that length (make([]int, 10)) and print the index. Every action
// starts with "{{-", so all whitespace between iterations is trimmed.
func TestForLoop(t *testing.T) {
    tpl := template.New("hello-world")
    tpl, err := tpl.Parse(`
{{- range $idx, $ele := .}} 
{{- $idx}}
{{- end}}`)
    require.NoError(t, err)
    buffer := &bytes.Buffer{}
    err = tpl.Execute(buffer, make([]int, 10))
    require.NoError(t, err)
    assert.Equal(t, `0123456789`, buffer.String())
}

// TestIfElse shows if / else if / else together with the and, gt and le
// comparison functions. The "-" trim markers on both sides of each action
// remove the surrounding newlines, so only the selected label is emitted.
// The label text now uses "(6, 18]", matching the "(0, 6]" notation;
// previously it read "(6. 18]".
func TestIfElse(t *testing.T) {
    type User struct {
        Age int
    }
    tpl := template.New("hello-world")
    tpl, err := tpl.Parse(`
{{- if and (gt .Age 0) (le .Age 6) -}}
儿童: (0, 6]
{{- else if and (gt .Age 6) (le .Age 18) -}}
少年: (6, 18]
{{- else -}}
成人: >18
{{- end -}}
`)
    require.NoError(t, err)
    buffer := &bytes.Buffer{}
    err = tpl.Execute(buffer, User{Age: 14})
    require.NoError(t, err)
    assert.Equal(t, `少年: (6, 18]`, buffer.String())
}

3. 页面渲染:GoTemplateEngin 实现、面试要点总结

模板引擎涉及编译原理(解释执行)

// context.go
// Context (partial listing): adds the template engine used by Render.
type Context struct {
    // ...

    // tplEngine is copied from HTTPServer.tplEngine for every request. It is
    // nil when the server was built without ServerWithTemplateEngine.
    tplEngine TemplateEngine
    //cookieSameSite http.SameSite
}

// Render executes template tplName with data through the server's
// TemplateEngine and buffers the output in RespData with status 200.
//
// On failure the status is set to 500 and the error is returned. RespData is
// left untouched, so a half-rendered page is never sent. When no
// TemplateEngine is configured, Render returns an error; previously it
// panicked on a nil interface.
func (c *Context) Render(tplName string, data any) error {
    if c.tplEngine == nil {
        c.RespStatusCode = http.StatusInternalServerError
        return errors.New("web: 未配置模板引擎")
    }
    page, err := c.tplEngine.Render(c.Req.Context(), tplName, data)
    if err != nil {
        c.RespStatusCode = http.StatusInternalServerError
        return err
    }
    c.RespData = page
    c.RespStatusCode = http.StatusOK
    return nil
}
// server.go
// HTTPServer is the framework's http.Handler.
type HTTPServer struct {
    // addr string // could also be passed at construction instead of to Start
    // router
    router // embedded so its addRoute method is promoted onto HTTPServer

    // mdls are the server-wide middlewares; mdls[0] is the outermost.
    mdls []Middleware

    // log reports framework-internal problems such as failed response writes.
    log func(msg string, args ...any)

    // tplEngine is handed to every request's Context for Render; set it with
    // ServerWithTemplateEngine.
    tplEngine TemplateEngine
}
// ServerWithTemplateEngine returns an option that installs engine as the
// TemplateEngine given to every request's Context.
func ServerWithTemplateEngine(engine TemplateEngine) HTTPServerOption {
    return func(s *HTTPServer) {
        s.tplEngine = engine
    }
}

// ServeHTTP implements http.Handler and is the entry point for every request
// (partial listing). It now also hands the server's template engine to the
// Context.
func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    // Framework code.
    ctx := &Context{
        Req:       req,
        Resp:      w,
        tplEngine: h.tplEngine,
    }
    // ... (middleware chain assembly and execution elided in these notes)
}
// template_e2e_test.go
//go:build e2e

package web

import (
    "github.com/stretchr/testify/require"
    "html/template"
    "log"
    "testing"
)

// TestLoginPage is a manual end-to-end check (e2e build tag). It parses every
// testdata/tpls/*.gohtml file; ParseGlob names each template after its file's
// base name. It serves GET /login by rendering login.gohtml and then calls
// Start(":8081"). Render errors are only logged.
func TestLoginPage(t *testing.T) {
    tpl, err := template.ParseGlob("testdata/tpls/*.gohtml")
    require.NoError(t, err)
    engine := &GoTemplateEngine{
        T: tpl,
    }
    h := NewHTTPServer(ServerWithTemplateEngine(engine))
    h.Get("/login", func(ctx *Context) {
        err := ctx.Render("login.gohtml", nil)
        if err != nil {
            log.Println(err)
        }
    })
    h.Start(":8081")
}
// template.go
package web

import (
    "bytes"
    "context"
    "html/template"
)

// TemplateEngine renders named templates into bytes. Context.Render calls it
// and buffers the result in RespData.
type TemplateEngine interface {
    // Render renders a page.
    // tplName is the template's name; templates are looked up by name.
    // data is the data the page is rendered with.
    // ctx is the request context (Context.Render passes c.Req.Context()).
    Render(ctx context.Context, tplName string, data any) ([]byte, error)

    // Alternative: render the page and write it into a writer.
    //Render(ctx context.Context, tplName string, data any,writer io.Writer) (  error)

    // This would also work, but couples the engine to the web Context type.
    //Render(ctx Context)
}

// GoTemplateEngine adapts the standard html/template package to
// TemplateEngine. T must contain every template that may be rendered, e.g.
// the result of template.ParseGlob.
type GoTemplateEngine struct {
    T *template.Template
}

// Render executes the template named tplName from g.T with data and returns
// the output; ctx is currently unused. On error it returns nil instead of
// the partially rendered output, so a caller cannot send half a page by
// accident.
func (g *GoTemplateEngine) Render(ctx context.Context, tplName string, data any) ([]byte, error) {
    var buf bytes.Buffer
    if err := g.T.ExecuteTemplate(&buf, tplName, data); err != nil {
        return nil, err
    }
    return buf.Bytes(), nil
}

文件处理

4. 文件处理:文件基本操作

package file_demo

import (
    "fmt"
    "github.com/stretchr/testify/require"
    "os"
    "testing"
)

// TestFile demonstrates basic os file operations on testdata/:
//   - reading from a file opened with os.Open;
//   - showing that writing to that read-only handle fails;
//   - appending with os.OpenFile;
//   - creating a file with os.Create.
//
// NOTE(review): files are closed by hand, so a failing require leaks the
// handle; `defer f.Close()` after each open would be safer. f.Read returns
// io.EOF (failing NoError) if my_file.txt is empty.
func TestFile(t *testing.T) {
    // Print the working directory: go test runs in the package directory,
    // which is why the relative "testdata/..." paths resolve.
    // Example: D:\go\projs\web_framework\web\file_demo <nil>
    fmt.Println(os.Getwd())

    f, err := os.Open("testdata/my_file.txt")
    require.NoError(t, err)
    data := make([]byte, 64)
    n, err := f.Read(data)
    require.NoError(t, err)
    // Number of bytes read (at most 64).
    fmt.Println(n)

    n, err = f.WriteString("hello")
    // os.Open opens read-only, so the write fails ("bad file descriptor").
    fmt.Println(err)
    // The write must fail.
    require.Error(t, err) // "Access is denied." on Windows
    // Number of bytes written.
    fmt.Println(n)
    f.Close()

    // O_APPEND|O_WRONLY: writes go to the end of the existing file.
    // NOTE(review): perm only matters together with O_CREATE; os.ModeAppend
    // is a mode bit rather than a permission (0o644 would be conventional).
    f, err = os.OpenFile("testdata/my_file.txt", os.O_APPEND|os.O_WRONLY, os.ModeAppend)
    require.NoError(t, err)
    n, err = f.WriteString("hello")
    require.NoError(t, err)
    fmt.Println(n)
    f.Close()

    // os.Create creates the file, or truncates it if it already exists.
    f, err = os.Create("testdata/my_file_copy.txt")
    require.NoError(t, err)
    n, err = f.WriteString("hello, world")
    require.NoError(t, err)
    fmt.Println(n)
    f.Close()
}

5. 文件处理:文件上传

// file.go
package web

import (
    "io"
    "mime/multipart"
    "net/http"
    "os"
    "strings"
)

// FileUploader builds a HandleFunc that saves one uploaded multipart file.
type FileUploader struct {
    // FileField is the multipart form field holding the file; Handle
    // defaults it to "file".
    FileField string
    // DataPathFunc returns the destination path for the uploaded file.
    // Leaving naming to the caller spares the framework from handling name
    // clashes itself (a UUID-based name would be one way to do that).
    DataPathFunc func(file *multipart.FileHeader) string
}

func (u FileUploader) Handle() HandleFunc {
    if u.FileField==""{
        u.FileField="file"
    }
    if u.DataPathFunc==nil{
        // 设置默认值
        u.DataPathFunc= func(file *multipart.FileHeader) string {
            return filepath.Join("testdata", "upload", file.Filename)
        }
    }
    return func(ctx *Context) {
        // 上传文件的逻辑
        // 1. 读取文件内容
        /*
            // A FileHeader describes a file part of a multipart request.
            type FileHeader struct {
                Filename string
                Header   textproto.MIMEHeader
                Size     int64

                content   []byte
                tmpfile   string
                tmpoff    int64
                tmpshared bool
            }
        */
        file, fileHeader, err := ctx.Req.FormFile(u.FileField)
        if err != nil {
            ctx.RespStatusCode = 500
            ctx.RespData = []byte("上传失败 " + err.Error())
            return
        }
        defer file.Close()
        // 2. 计算目标路径
        // 将目标路径计算的逻辑交给用户
        dst := u.DataPathFunc(fileHeader)
        // 可以尝试把dst里不存在的目录都创建,防止报错
        index := strings.LastIndex(dst, "\\")
        err = os.MkdirAll(dst[:index], os.ModePerm)
        if err != nil {
            ctx.RespStatusCode = http.StatusInternalServerError
            ctx.RespData = []byte("上传失败 " + err.Error())
            return
        }

        // os.O_WRONLY 可写
        // os.O_TRUNC 存在就清空
        // os.O_CREATE 不存在就创建
        dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0o666)
        if err != nil {
            ctx.RespStatusCode = http.StatusInternalServerError
            ctx.RespData = []byte("上传失败 " + err.Error())
            return
        }
        defer dstFile.Close()
        // 3. 保存文件
        // buf(参数3)指定每次传输多大一段, 会影响性能
        // 要考虑复用, 如果传nil, 有默认提供
        _, err = io.CopyBuffer(dstFile, file, nil)
        if err != nil {
            ctx.RespStatusCode = http.StatusInternalServerError
            ctx.RespData = []byte("上传失败 " + err.Error())
            return
        }
        // 4. 返回响应
        ctx.RespStatusCode = http.StatusOK
        ctx.RespData = []byte("上传成功")
    }
}
// file_e2e_test.go
//go:build e2e

package web

import (
    "github.com/stretchr/testify/require"
    "html/template"
    "log"
    "mime/multipart"
    "path/filepath"
    "testing"
)

// TestFileUpload starts a server on :8081 that serves the upload form on
// GET /upload and stores submitted files under testdata/upload on POST /upload.
func TestFileUpload(t *testing.T) {
    tpls, err := template.ParseGlob("testdata/tpls/*.gohtml")
    require.NoError(t, err)
    server := NewHTTPServer(ServerWithTemplateEngine(&GoTemplateEngine{T: tpls}))

    // GET renders the HTML form.
    server.Get("/upload", func(ctx *Context) {
        if renderErr := ctx.Render("upload.gohtml", nil); renderErr != nil {
            log.Println(renderErr)
        }
    })

    // POST receives the file; FileField must equal the form input's name attribute.
    uploader := FileUploader{
        FileField: "myfile",
        DataPathFunc: func(fh *multipart.FileHeader) string {
            return filepath.Join("testdata", "upload", fh.Filename)
        },
    }
    server.Post("/upload", uploader.Handle())
    server.Start(":8081")
}
// testdata/tpls/upload.gohtml
<html>
<body>
<form action="/upload" method="post" enctype="multipart/form-data">
    <input type="file" name="myfile">
    <button type="submit">上传</button>
</form>
</body>
</html>

6. 文件处理:文件下载

// file_e2e_test.go
// TestFileDownload starts a server on :8081 that serves files from
// testdata/download via GET /download?file=<name>.
func TestFileDownload(t *testing.T) {
    server := NewHTTPServer()
    downloader := FileDownloader{Dir: filepath.Join("testdata", "download")}
    server.Get("/download", downloader.Handle())
    server.Start(":8081")
}
// file.go
// FileDownloader serves files from Dir as attachments; the file is selected
// by the "file" query parameter (GET ...?file=<name>).
type FileDownloader struct {
    // Dir is the directory downloads are served from.
    Dir string
}

// Handle returns a HandleFunc serving Dir/<file> as a binary attachment.
// It responds 400 when the parameter is missing or would leave Dir, and 404
// when the target is not a regular file.
func (d FileDownloader) Handle() HandleFunc {
    return func(ctx *Context) {
        // xxx?file=xxx
        filename, err := ctx.QueryValue("file")
        if err != nil {
            ctx.RespStatusCode = http.StatusBadRequest
            ctx.RespData = []byte("找不到目标文件")
            return
        }
        // filepath.Clean alone does not confine the path: Clean("../x") is
        // still "../x". IsLocal rejects absolute paths, "" and anything that
        // climbs out of the directory it is evaluated in.
        filename = filepath.Clean(filename)
        if !filepath.IsLocal(filename) {
            ctx.RespStatusCode = http.StatusBadRequest
            ctx.RespData = []byte("找不到目标文件")
            return
        }
        dst := filepath.Join(d.Dir, filename)

        // http.ServeFile would render a listing for a directory (e.g.
        // file=.), so only regular files are served.
        info, err := os.Stat(dst)
        if err != nil || !info.Mode().IsRegular() {
            ctx.RespStatusCode = http.StatusNotFound
            ctx.RespData = []byte("找不到目标文件")
            return
        }
        fn := filepath.Base(dst)

        header := ctx.Resp.Header()
        header.Set("Content-Disposition", "attachment;filename="+fn)
        header.Set("Content-Description", "File Transfer")
        // Generic binary content.
        header.Set("Content-Type", "application/octet-stream")
        header.Set("Content-Transfer-Encoding", "binary")
        // Ask clients not to cache the response.
        header.Set("Expires", "0")
        header.Set("Cache-Control", "must-revalidate")
        header.Set("Pragma", "public")

        http.ServeFile(ctx.Resp, ctx.Req, dst)
    }
}

现在的文件下载没有缓存, 而且没有对大文件的处理(分片[需要前后端配合])

7. 文件处理:静态资源处理、面试要点总结


// StaticResourceHandlerOption customizes a StaticResourceHandler during construction.
type StaticResourceHandlerOption func(handler *StaticResourceHandler)

// StaticResourceHandler serves files from a local directory and keeps small
// files in an in-memory LRU cache.
type StaticResourceHandler struct {
    // dir is the directory files are served from.
    dir string
    // Maps a file extension (without the leading dot) to its Content-Type.
    // Setting the type explicitly keeps browsers that cannot infer it from
    // failing to display the file.
    extensionContextTypeMap map[string]string
    cache                   *lru.Cache // file contents, keyed by the file's path
    maxSize                 int        // files larger than this many bytes are not cached
}

// NewStaticResourceHandler creates a handler serving files from dir. By
// default it caches up to 1000 files of at most 10 MB each and knows the
// Content-Type of a few image and PDF extensions; opts may override these.
func NewStaticResourceHandler(dir string, opts ...StaticResourceHandlerOption) (*StaticResourceHandler, error) {
    // lru.New takes the maximum number of entries to keep, not a byte size.
    cache, err := lru.New(1000)
    if err != nil {
        return nil, err
    }
    res := &StaticResourceHandler{
        dir:     dir,
        cache:   cache,
        maxSize: 1024 * 1024 * 10, // 10 MB
        // Keys are extensions without the leading dot.
        extensionContextTypeMap: map[string]string{
            "jpeg": "image/jpeg",
            // "image/jpg" is not a registered MIME type; .jpg files are JPEG.
            "jpg": "image/jpeg",
            "jpe": "image/jpeg",
            "png": "image/png",
            // PDF is a document format, registered as application/pdf.
            "pdf": "application/pdf",
        },
    }
    for _, opt := range opts {
        opt(res)
    }
    return res, nil
}

// StaticWithMaxFileSize sets the largest file size, in bytes, that is kept in
// the cache; larger files are read from disk on every request.
func StaticWithMaxFileSize(maxSize int) StaticResourceHandlerOption {
    setMaxSize := func(h *StaticResourceHandler) {
        h.maxSize = maxSize
    }
    return setMaxSize
}

// StaticWithCache replaces the handler's default LRU cache with c.
func StaticWithCache(c *lru.Cache) StaticResourceHandlerOption {
    useCache := func(h *StaticResourceHandler) {
        h.cache = c
    }
    return useCache
}

// StaticWithMoreExtension registers extra extension → Content-Type mappings.
// Keys are extensions without the leading dot (e.g. "webp"); an entry for an
// extension that is already known replaces the existing type.
func StaticWithMoreExtension(extMap map[string]string) StaticResourceHandlerOption {
    return func(handler *StaticResourceHandler) {
        for extension, mimeType := range extMap {
            handler.extensionContextTypeMap[extension] = mimeType
        }
    }
}

// Handle serves the file named by the ":file" path parameter from s.dir.
// Files no larger than s.maxSize are cached by path and later served from the
// cache with a "Web-Msg: from cache" header. Content-Type comes from
// extensionContextTypeMap, falling back to content sniffing for extensions
// that are not in the map.
func (s *StaticResourceHandler) Handle(ctx *Context) {
    // 1. Resolve the target file name.
    file, err := ctx.PathValue("file")
    if err != nil {
        ctx.RespStatusCode = http.StatusBadRequest
        ctx.RespData = []byte("请求路径错误")
        return
    }
    // The name comes straight from the URL: reject absolute paths and ".."
    // so a request cannot read files outside s.dir.
    if !filepath.IsLocal(file) {
        ctx.RespStatusCode = http.StatusBadRequest
        ctx.RespData = []byte("请求路径错误")
        return
    }
    dst := filepath.Join(s.dir, file)

    // Extension without the dot (jpg/txt/...). filepath.Ext returns "" for
    // names without one, where an unconditional [1:] would panic.
    ext := filepath.Ext(dst)
    if ext != "" {
        ext = ext[1:]
    }
    header := ctx.Resp.Header()

    respond := func(data []byte) {
        contentType, ok := s.extensionContextTypeMap[ext]
        if !ok {
            // Unknown extension: sniff the content rather than send an
            // empty Content-Type.
            contentType = http.DetectContentType(data)
        }
        header.Set("Content-Type", contentType)
        header.Set("Content-Length", strconv.Itoa(len(data)))
        ctx.RespData = data
        ctx.RespStatusCode = http.StatusOK
    }

    // 2. Serve from the cache when possible. The checked assertion guards
    //    against a cache supplied via StaticWithCache holding other values.
    if cached, ok := s.cache.Get(dst); ok {
        if data, ok := cached.([]byte); ok {
            header.Set("Web-Msg", "from cache")
            respond(data)
            return
        }
    }

    // 3. Read the file from disk.
    data, err := ioutil.ReadFile(dst)
    if err != nil {
        ctx.RespStatusCode = http.StatusInternalServerError
        ctx.RespData = []byte("服务器系统错误")
        return
    }

    // Only small files are cached.
    if len(data) <= s.maxSize {
        s.cache.Add(dst, data)
    }

    // 4. Respond.
    respond(data)
}
// test
// TestStaticResourceHandler starts a server on :8081 serving testdata/static,
// e.g. localhost:8081/static/xxx.jpg.
func TestStaticResourceHandler(t *testing.T) {
    server := NewHTTPServer()
    handler, err := NewStaticResourceHandler(filepath.Join("testdata", "static"))
    require.NoError(t, err)
    server.Get("/static/:file", handler.Handle)
    server.Start(":8081")
}

Session

8. Session:概念与不同框架的 Session 设计分析


flash很少用,讲师举例: 跳转页面有时携带临时的敏感数据可以考虑
讲师推荐: 直接用gorilla
gorilla 缓存 session,是为了避免大量 session 每次都写入磁盘而影响性能

核心接口(单一职责)不需要很多方法

9. Session:接口设计


// session/types.go
package session

import (
    "context"
    "net/http"
)

// Store manages the lifecycle of sessions themselves.
// Open design questions: who chooses the session id (here the caller passes
// it in rather than the store generating it), and whether expiration belongs
// in the interface.
type Store interface {
    // Generate creates a new session identified by id.
    Generate(ctx context.Context, id string) (Session, error)
    // Refresh extends the lifetime of the session identified by id.
    Refresh(ctx context.Context, id string) error
    // Remove deletes the session identified by id.
    Remove(ctx context.Context, id string) error
    // Get loads the session identified by id.
    Get(ctx context.Context, id string) (Session, error)
}

// Session holds the per-user key/value data of one session.
type Session interface {
    Get(ctx context.Context, key string) (any, error)
    Set(ctx context.Context, key string, val string) error
    ID() string
}

// Propagator moves the session id between server and client.
type Propagator interface {
    // Inject writes the session id into the response.
    Inject(id string, writer http.ResponseWriter) error
    // Extract reads the session id from the request.
    Extract(req *http.Request) (string, error)
    // Remove tells the client to drop the session id. It returns an error so
    // callers can report a failed logout.
    Remove(writer http.ResponseWriter) error
}

10. Session:用户使用示例和 Manager 设计

// types_test.go
package session

import (
    "net/http"
    "testing"
    "web_framework/web"
)

// TestSession wires a simple login check into a server on :8081:
// - every path except /login requires a valid session;
// - /login creates a session;
// - /logout removes it;
// - /user returns the stored nickname.
//
// NOTE: m is the zero Manager, so its Propagator and Store are nil and the
// first call through them panics; configure both before relying on this server.
func TestSession(t *testing.T) {
    var m Manager
    server := web.NewHTTPServer(web.ServerWithMiddleware(func(next web.HandleFunc) web.HandleFunc {
        return func(ctx *web.Context) {
            if ctx.Req.URL.Path == "/login" {
                // Let the user log in.
                next(ctx)
                return
            }
            _, err := m.GetSession(ctx)
            if err != nil {
                ctx.RespStatusCode = http.StatusUnauthorized
                ctx.RespData = []byte("请重新登录")
                return
            }
            // Extend the session's expiration.
            _ = m.RefreshSession(ctx)
            next(ctx)
        }
    }))

    server.Post("/login", func(ctx *web.Context) {
        // Username/password verification would happen before the session is created.
        sess, err := m.InitSession(ctx)
        if err != nil {
            ctx.RespStatusCode = http.StatusInternalServerError
            ctx.RespData = []byte("登录失败")
            return
        }
        err = sess.Set(ctx.Req.Context(), "nickname", "xiaoming")
        if err != nil {
            ctx.RespStatusCode = http.StatusInternalServerError
            ctx.RespData = []byte("登录失败")
            return
        }
        ctx.RespStatusCode = http.StatusOK
        ctx.RespData = []byte("登录成功")
    })

    server.Post("/logout", func(ctx *web.Context) {
        // Clear the session data.
        err := m.RemoveSession(ctx)
        if err != nil {
            ctx.RespStatusCode = http.StatusInternalServerError
            ctx.RespData = []byte("退出失败")
            return
        }
        ctx.RespStatusCode = http.StatusOK
        ctx.RespData = []byte("退出成功")
    })

    server.Get("/user", func(ctx *web.Context) {
        sess, err := m.GetSession(ctx)
        if err != nil {
            ctx.RespStatusCode = http.StatusUnauthorized
            ctx.RespData = []byte("请重新登录")
            return
        }
        val, err := sess.Get(ctx.Req.Context(), "nickname")
        if err != nil {
            ctx.RespStatusCode = http.StatusInternalServerError
            ctx.RespData = []byte("获取用户信息失败")
            return
        }
        nickname, ok := val.(string)
        if !ok {
            ctx.RespStatusCode = http.StatusInternalServerError
            ctx.RespData = []byte("获取用户信息失败")
            return
        }
        ctx.RespStatusCode = http.StatusOK
        ctx.RespData = []byte(nickname)
    })
    server.Start(":8081")
}
// types.go
package session

import (
    "context"
    "net/http"
)

// Store manages the lifecycle of sessions themselves.
type Store interface {
    // Open questions: who chooses the session id (here the caller passes it
    // in rather than the store generating it), and whether expiration should
    // be part of the interface.
    Generate(ctx context.Context, id string) (Session, error)
    Refresh(ctx context.Context, id string) error
    Remove(ctx context.Context, id string) error
    Get(ctx context.Context, id string) (Session, error)
}

// Session holds the per-user key/value data of one session.
type Session interface {
    Get(ctx context.Context, key string) (any, error)
    Set(ctx context.Context, key string, val string) error
    ID() string
}

// Propagator moves the session id between server and client: Inject writes it
// into the response, Extract reads it from the request, Remove tells the
// client to drop it.
type Propagator interface {
    Inject(id string, writer http.ResponseWriter) error
    Extract(req *http.Request) (string, error)
    Remove(writer http.ResponseWriter) error
}
// manager.go
package session

import (
    "github.com/google/uuid"
    "web_framework/web"
)

// Manager combines a Propagator (how the session id travels between client
// and server) with a Store (where session data lives) and offers
// request-level session operations.
type Manager struct {
    Propagator
    Store
}

// GetSession extracts the session id from the request and loads the session.
func (m *Manager) GetSession(ctx *web.Context) (Session, error) {
    sessID, err := m.Extract(ctx.Req)
    if err != nil {
        return nil, err
    }
    return m.Get(ctx.Req.Context(), sessID)
}

// InitSession creates a session under a fresh UUID and sends the id to the
// client. If the id cannot be sent, no client can ever reach the new session,
// so it is removed (best effort) and only the error is returned.
func (m *Manager) InitSession(ctx *web.Context) (Session, error) {
    id := uuid.New().String()
    sess, err := m.Generate(ctx.Req.Context(), id)
    if err != nil {
        return nil, err
    }
    if err = m.Inject(id, ctx.Resp); err != nil {
        // Best-effort cleanup; the Inject error is the one worth reporting.
        _ = m.Store.Remove(ctx.Req.Context(), id)
        return nil, err
    }
    return sess, nil
}

// RemoveSession deletes the current session from the Store and tells the
// client to drop its id. Store and Propagator both define Remove, so the
// embedded field is named explicitly.
func (m *Manager) RemoveSession(ctx *web.Context) error {
    sess, err := m.GetSession(ctx)
    if err != nil {
        return err
    }
    if err = m.Store.Remove(ctx.Req.Context(), sess.ID()); err != nil {
        return err
    }
    return m.Propagator.Remove(ctx.Resp)
}

// RefreshSession extends the lifetime of the current session. It assumes the
// session id does not change on refresh.
func (m *Manager) RefreshSession(ctx *web.Context) error {
    sess, err := m.GetSession(ctx)
    if err != nil {
        return err
    }
    return m.Refresh(ctx.Req.Context(), sess.ID())
}

11. Session:web.Context 缓存 Session

// context.go
// Context (excerpt): only the field added for session caching is shown; the
// other fields are elided.
type Context struct {
    // ...

    // UserValues is scratch space carried on the Context; session.Manager
    // caches the resolved Session here so later lookups skip the Store.
    UserValues map[string]any
}
// session/manager.go
// Manager combines a Propagator and a Store, caching the resolved Session in
// the web.Context so repeated lookups do not hit the Store again.
type Manager struct {
    Propagator
    Store
    CtxSessKey string // key under which the Session is cached in ctx.UserValues
}

// GetSession returns the session cached in ctx under m.CtxSessKey, or else
// extracts the id from the request, loads the session and caches it.
func (m *Manager) GetSession(ctx *web.Context) (Session, error) {
    // Reading a nil map is safe, and the checked assertion ignores a foreign
    // value stored under the same key instead of panicking.
    if sess, ok := ctx.UserValues[m.CtxSessKey].(Session); ok {
        return sess, nil
    }
    sessID, err := m.Extract(ctx.Req)
    if err != nil {
        return nil, err
    }

    sess, err := m.Get(ctx.Req.Context(), sessID)
    if err != nil {
        return nil, err
    }
    if ctx.UserValues == nil {
        ctx.UserValues = make(map[string]any, 1)
    }
    ctx.UserValues[m.CtxSessKey] = sess
    return sess, nil
}

12. Session:基于内存的实现

// momery/session.go
package momery

import (
    "context"
    "errors"
    "fmt"
    cache "github.com/patrickmn/go-cache"
    "sync"
    "time"
    "web_framework/web/session"
)

var (
    errKeyNotFound     = errors.New("session: 找不到key")
    errSessionNotFound = errors.New("session: 找不到session")
)

type Store struct {
    mutex      sync.RWMutex
    sessions   *cache.Cache
    expiration time.Duration
}

func NewStore(expiration time.Duration) *Store {
    return &Store{
        // 缓存过期时间,检查间隔
        sessions:   cache.New(expiration, time.Second),
        expiration: expiration,
    }
}

// Generate creates an empty session under id and stores it with the store's
// expiration. The methods use pointer receivers: Store holds a sync.RWMutex,
// and a value receiver would lock a fresh copy, protecting nothing.
func (s *Store) Generate(ctx context.Context, id string) (session.Session, error) {
    // Keep a concurrent Refresh from interleaving with generation.
    s.mutex.Lock()
    defer s.mutex.Unlock()
    sess := &Session{
        id: id,
    }
    s.sessions.Set(id, sess, s.expiration)
    return sess, nil
}

// Refresh resets the expiration of the session stored under id. Lookup and
// re-insert happen under one lock, so they are atomic with respect to the
// other Store methods.
func (s *Store) Refresh(ctx context.Context, id string) error {
    s.mutex.Lock()
    defer s.mutex.Unlock()
    val, ok := s.sessions.Get(id)
    if !ok {
        return fmt.Errorf("session: 该 id 对应的 session 不存在 %s", id)
    }
    s.sessions.Set(id, val, s.expiration)
    return nil
}

// Remove deletes the session stored under id; removing a missing id is not an error.
func (s *Store) Remove(ctx context.Context, id string) error {
    s.mutex.Lock()
    defer s.mutex.Unlock()
    s.sessions.Delete(id)
    return nil
}

// Get returns the session stored under id, or errSessionNotFound.
func (s *Store) Get(ctx context.Context, id string) (session.Session, error) {
    s.mutex.RLock()
    defer s.mutex.RUnlock()
    sess, ok := s.sessions.Get(id)
    if !ok {
        return nil, errSessionNotFound
    }
    return sess.(*Session), nil
}

// Session is an in-memory session. Its methods use pointer receivers because
// values is a sync.Map, which must not be copied: a value receiver operates on
// a copy, so Set's writes would never be seen by later Get calls.
type Session struct {
    id         string
    expiration time.Duration // NOTE(review): not read anywhere in this file

    // A mutex-guarded map (mutex sync.RWMutex; values map[string]any) would
    // give finer control; sync.Map keeps the code short.
    values sync.Map
}

// Get returns the value stored under key, or errKeyNotFound.
func (s *Session) Get(ctx context.Context, key string) (any, error) {
    val, ok := s.values.Load(key)
    if !ok {
        //return nil, fmt.Errorf("%w, key %s", errKeyNotFound,key)
        return nil, errKeyNotFound
    }
    return val, nil
}

// Set stores val under key. ctx is unused here; it is part of the interface
// for implementations that talk to external systems.
func (s *Session) Set(ctx context.Context, key string, val string) error {
    s.values.Store(key, val)
    return nil
}

// ID returns the session id. id is only assigned when the Session is built,
// so reading it needs no synchronization.
func (s *Session) ID() string {
    return s.id
}

13. Session:基于 Redis 的实现

// session_e2e_test.go
//go:build e2e
package redis

import (
    "context"
    "github.com/redis/go-redis/v9"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
    "testing"
)

// TestStore_Generate creates a session in a local Redis, writes one key and
// reads it back. Requires Redis on localhost:6379.
func TestStore_Generate(t *testing.T) {
    store := newStore()
    ctx := context.Background()
    const sessID = "sess_test_id"

    sess, err := store.Generate(ctx, sessID)
    require.NoError(t, err)
    defer store.Remove(ctx, sessID)

    require.NoError(t, sess.Set(ctx, "key1", "123"))

    got, err := sess.Get(ctx, "key1")
    require.NoError(t, err)
    assert.Equal(t, "123", got)
}

// newStore returns a Store backed by the Redis server on localhost:6379.
func newStore() *Store {
    client := redis.NewClient(&redis.Options{
        Addr: "localhost:6379",
        // Password: "redis",
    })
    return NewStore(client)
}
// redis/session.go
package redis

import (
    "context"
    "errors"
    "fmt"
    "github.com/redis/go-redis/v9"
    "time"
    "web_framework/web/session"
)

var (
    // errSessionNotFound is returned when the hash for a session id does not exist.
    errSessionNotFound = errors.New("session: id 对应的 session 不存在")
)

// StoreOption configures a Store created by NewStore.
type StoreOption func(store *Store)

// Store keeps each session in one Redis hash:
//
//    <prefix>-<session id>: { key: value, ... }
//
// i.e. conceptually a map[string]map[string]string.
type Store struct {
    prefix     string
    client     redis.Cmdable
    expiration time.Duration
}

// NewStore returns a Store using client, with prefix "sessid" and a
// 15-minute session lifetime unless overridden by opts.
func NewStore(client redis.Cmdable, opts ...StoreOption) *Store {
    res := &Store{
        expiration: time.Minute * 15,
        client:     client,
        prefix:     "sessid",
    }

    for _, opt := range opts {
        opt(res)
    }

    return res
}

// StoreWithPrefix sets the prefix prepended to every session key.
func StoreWithPrefix(prefix string) StoreOption {
    return func(store *Store) {
        store.prefix = prefix
    }
}

// StoreWithExpiration sets how long a session lives after it is generated or
// refreshed.
func StoreWithExpiration(expiration time.Duration) StoreOption {
    return func(store *Store) {
        store.expiration = expiration
    }
}

// Generate creates the hash for session id and gives it the store's
// expiration. HSET and EXPIRE run in one MULTI/EXEC transaction, so the key
// can never be left without a TTL (e.g. if the process dies in between). The
// hash starts with a single placeholder field (id → id), presumably so the
// key exists before any real value is set.
func (s Store) Generate(ctx context.Context, id string) (session.Session, error) {
    key := redisKey(s.prefix, id)
    _, err := s.client.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
        pipe.HSet(ctx, key, id, id)
        pipe.Expire(ctx, key, s.expiration)
        return nil
    })
    if err != nil {
        return nil, err
    }
    return &Session{
        id:     id,
        key:    key,
        client: s.client,
    }, nil
}

// Refresh resets the TTL of session id, or returns errSessionNotFound when
// the key no longer exists (EXPIRE reports false).
func (s Store) Refresh(ctx context.Context, id string) error {
    key := redisKey(s.prefix, id)
    ok, err := s.client.Expire(ctx, key, s.expiration).Result()
    if err != nil {
        return err
    }
    if !ok {
        return errSessionNotFound
    }
    return nil
}

// Remove deletes session id. Deleting a missing session is not an error; DEL
// would return 0 deleted keys in that case.
func (s Store) Remove(ctx context.Context, id string) error {
    key := redisKey(s.prefix, id)
    _, err := s.client.Del(ctx, key).Result()
    return err
}

// Get returns a handle to session id after checking that its hash exists.
// No stored values are prefetched; the other options would be prefetching
// only hot values, or all of them.
func (s Store) Get(ctx context.Context, id string) (session.Session, error) {
    key := redisKey(s.prefix, id)
    cnt, err := s.client.Exists(ctx, key).Result()
    if err != nil {
        return nil, err
    }
    if cnt != 1 {
        return nil, errSessionNotFound
    }
    return &Session{
        id:     id,
        key:    key,
        client: s.client,
    }, nil
}

// Session is a handle to one session stored as a Redis hash.
type Session struct {
    prefix string // NOTE(review): never set or read in this file
    key    string // Redis key of the hash, as built by redisKey(prefix, id)
    id     string
    client redis.Cmdable
}

// Get returns the value stored under key. A missing field or an expired
// session surfaces as the error from HGET (redis.Nil); the value is nil on
// any error.
func (s Session) Get(ctx context.Context, key string) (any, error) {
    val, err := s.client.HGet(ctx, s.key, key).Result()
    if err != nil {
        return nil, err
    }
    return val, nil
}

// Set stores val under key, but only while the session hash still exists;
// otherwise it returns errSessionNotFound. The existence check and the write
// run atomically inside one Lua script.
func (s Session) Set(ctx context.Context, key string, val string) error {
    // KEYS[1] => the session hash key; ARGV[1], ARGV[2] => field and value.
    // EXISTS returns an integer, and in Lua every number (0 included) is
    // truthy, so the result must be compared with 1 explicitly; otherwise
    // HSET would recreate an expired session without a TTL.
    const lua = `
if redis.call("exists", KEYS[1]) == 1
then
    return redis.call("hset", KEYS[1], ARGV[1], ARGV[2])
else
    return -1
end
`
    res, err := s.client.Eval(ctx, lua, []string{s.key}, key, val).Int()
    if err != nil {
        return err
    }
    if res < 0 {
        return errSessionNotFound
    }
    return nil
}

// ID returns the session id.
func (s Session) ID() string {
    return s.id
}

// redisKey builds the Redis key of a session hash: "<prefix>-<id>".
func redisKey(prefix, id string) string {
    // Sprint inserts no spaces between string operands, so this is plain concatenation.
    return fmt.Sprint(prefix, "-", id)
}
// cookie/propagator
package cookie

import "net/http"

// Option configures a Propagator created by NewPropagator.
type Option func(p *Propagator)

// Propagator carries the session id in an HTTP cookie.
type Propagator struct {
    cookieName   string
    cookieOption func(c *http.Cookie)
}

// NewPropagator returns a Propagator using the cookie "sessid", customized by
// opts. opts is variadic, so existing NewPropagator() calls still compile.
func NewPropagator(opts ...Option) *Propagator {
    p := &Propagator{
        cookieName:   "sessid",
        cookieOption: func(c *http.Cookie) {},
    }
    for _, opt := range opts {
        opt(p)
    }
    return p
}

// WithCookieName sets the name of the session cookie.
func WithCookieName(name string) Option {
    return func(p *Propagator) {
        p.cookieName = name
    }
}

// WithCookieOption installs a hook that adjusts every cookie the Propagator
// writes. Recommended attributes to set: Domain, Path, Secure, SameSite,
// HttpOnly. A nil hook is ignored.
func WithCookieOption(opt func(c *http.Cookie)) Option {
    return func(p *Propagator) {
        if opt != nil {
            p.cookieOption = opt
        }
    }
}

// Inject sets the session cookie carrying id on the response.
func (p *Propagator) Inject(id string, writer http.ResponseWriter) error {
    c := &http.Cookie{
        Name:  p.cookieName,
        Value: id,
    }
    // Let the user customize the cookie.
    p.cookieOption(c)
    http.SetCookie(writer, c)
    return nil
}

// Extract returns the session id from the request's cookie, or the error from
// req.Cookie (http.ErrNoCookie) when the cookie is absent.
func (p *Propagator) Extract(req *http.Request) (string, error) {
    cookie, err := req.Cookie(p.cookieName)
    if err != nil {
        return "", err
    }
    return cookie.Value, nil
}

// Remove tells the client to delete the session cookie. The same cookieOption
// used by Inject is applied first, so Path/Domain match the cookie being
// deleted: a Set-Cookie only replaces a stored cookie with the same name,
// domain and path.
func (p *Propagator) Remove(writer http.ResponseWriter) error {
    c := &http.Cookie{
        Name: p.cookieName,
    }
    p.cookieOption(c)
    c.Value = ""
    c.MaxAge = -1 // sent as Max-Age=0: expire immediately
    http.SetCookie(writer, c)
    return nil
}

15. Session:测试与面试要点总结

// test/types_test.go
package test

import (
    "net/http"
    "testing"
    "time"
    "web_framework/web"
    "web_framework/web/session"
    "web_framework/web/session/cookie"
    "web_framework/web/session/memory"
)

// TestSession runs a server on :8081 with cookie-propagated, in-memory
// sessions:
// - every path except /login requires a valid session;
// - /login creates a session storing a nickname;
// - /logout removes it;
// - /user returns the stored nickname.
func TestSession(t *testing.T) {
    m := &session.Manager{
        Propagator: cookie.NewPropagator(),
        Store:      memory.NewStore(time.Minute * 15),
        CtxSessKey: "sessKey",
    }
    server := web.NewHTTPServer(web.ServerWithMiddleware(func(next web.HandleFunc) web.HandleFunc {
        return func(ctx *web.Context) {
            if ctx.Req.URL.Path == "/login" {
                // Let the user log in.
                next(ctx)
                return
            }
            _, err := m.GetSession(ctx)
            if err != nil {
                ctx.RespStatusCode = http.StatusUnauthorized
                ctx.RespData = []byte("请重新登录")
                return
            }
            // Extend the session's expiration.
            _ = m.RefreshSession(ctx)
            next(ctx)
        }
    }))

    server.Post("/login", func(ctx *web.Context) {
        // Username/password verification would happen before the session is created.
        sess, err := m.InitSession(ctx)
        if err != nil {
            ctx.RespStatusCode = http.StatusInternalServerError
            ctx.RespData = []byte("登录失败")
            return
        }
        err = sess.Set(ctx.Req.Context(), "nickname", "xiaoming")
        if err != nil {
            ctx.RespStatusCode = http.StatusInternalServerError
            ctx.RespData = []byte("登录失败")
            return
        }
        ctx.RespStatusCode = http.StatusOK
        ctx.RespData = []byte("登录成功")
    })

    server.Post("/logout", func(ctx *web.Context) {
        // Clear the session data.
        err := m.RemoveSession(ctx)
        if err != nil {
            ctx.RespStatusCode = http.StatusInternalServerError
            ctx.RespData = []byte("退出失败")
            return
        }
        ctx.RespStatusCode = http.StatusOK
        ctx.RespData = []byte("退出成功")
    })

    server.Get("/user", func(ctx *web.Context) {
        sess, err := m.GetSession(ctx)
        if err != nil {
            ctx.RespStatusCode = http.StatusUnauthorized
            ctx.RespData = []byte("请重新登录")
            return
        }
        // An unchecked val.(string) would panic when Get fails and val is nil.
        val, err := sess.Get(ctx.Req.Context(), "nickname")
        if err != nil {
            ctx.RespStatusCode = http.StatusInternalServerError
            ctx.RespData = []byte("获取用户信息失败")
            return
        }
        nickname, ok := val.(string)
        if !ok {
            ctx.RespStatusCode = http.StatusInternalServerError
            ctx.RespData = []byte("获取用户信息失败")
            return
        }
        ctx.RespStatusCode = http.StatusOK
        ctx.RespData = []byte(nickname)
    })
    server.Start(":8081")
}

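出 bug 了:memory.Session 的方法使用值接收者,每次调用操作的都是结构体(含 sync.Map)的副本,Set 写入的值随副本丢失;/user 中 Get 取不到 nickname,val 为 nil,val.(string) 直接 panic。把 Session(以及含 sync.RWMutex 的 Store)的方法改为指针接收者即可修复。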

10 第四周:ORM 框架之 SELECT 与元数据

ORM概览

1. ORM 学习路线图

2. ORM 框架概览:Beego ORM 分析

3. ORM 框架概览:GORM 和 Ent 分析


go的语言特性限制了动态代理,一般go中都用代码生成来实现

4. ORM 框架总结和面试要点

select

5. SELECT:Beego、GORM、Ent 的 SQL构造分析

可读性差,可扩展性差

6. SELECT:核心接口定义

这种的query是一次性的,可以用泛型,而前一种的Orm要多个不同表复用,就不好用泛型

// orm/types.go
package orm

// context 包是Go 1.7 引入的标准库,
// 主要用于在goroutine 之间传递取消信号、超时时间、截止时间以及一些共享的值等。
import (
    "context"
    "database/sql"
)

// Querier is implemented by SELECT builders; Get returns one row, GetMulti all matching rows.
type Querier[T any] interface {
    // Returning T or *T would both work. Pointers are more convenient for
    // reflection and AOP-style hooks, but may cause values to escape to the heap.
    Get(ctx context.Context) (*T, error)
    GetMulti(ctx context.Context) ([]*T, error)
}

// Executor is implemented by INSERT, DELETE and UPDATE builders.
type Executor interface {
    Exec(ctx context.Context) (sql.Result, error)
}

// QueryBuilder renders a statement into SQL text plus its arguments.
type QueryBuilder interface {
    Build() (*Query, error)
}

// Query is a built statement: SQL with placeholders, and the values bound to them.
type Query struct {
    SQL  string
    Args []any
}

7. SELECT:SELECT 语句规范、Selector 定义、FROM 语句实现

sql规范,可以有人遵守有人不遵守,所以不同数据库可能sql是不同的

// select_test.go
package orm

import (
    "database/sql"
    "github.com/stretchr/testify/assert"
    "testing"
)

// TestSelector_Build checks the FROM clause: the default table name derived
// from the model type, and explicit names passed through Table verbatim.
func TestSelector_Build(t *testing.T) {
    type buildCase struct {
        name      string
        builder   QueryBuilder
        wantQuery *Query
        wantErr   error
    }

    cases := []buildCase{
        {
            name:      "not from",
            builder:   &Selector[TestModel]{},
            wantQuery: &Query{SQL: "SELECT * FROM `TestModel`;"},
        },
        {
            name:      "from",
            builder:   (&Selector[TestModel]{}).Table("`test_model`"),
            wantQuery: &Query{SQL: "SELECT * FROM `test_model`;"},
        },
        {
            name:      "empty from",
            builder:   (&Selector[TestModel]{}),
            wantQuery: &Query{SQL: "SELECT * FROM `TestModel`;"},
        },
        {
            // An empty table name falls back to the model's type name.
            name:      "empty table",
            builder:   (&Selector[TestModel]{}).Table(""),
            wantQuery: &Query{SQL: "SELECT * FROM `TestModel`;"},
        },
        {
            // database.table is written through unchanged.
            name:      "db table",
            builder:   (&Selector[TestModel]{}).Table("`mybatis`.`test`"),
            wantQuery: &Query{SQL: "SELECT * FROM `mybatis`.`test`;"},
        },
    }

    for _, c := range cases {
        t.Run(c.name, func(t *testing.T) {
            got, err := c.builder.Build()
            assert.Equal(t, c.wantErr, err)
            if err == nil {
                assert.Equal(t, c.wantQuery, got)
            }
        })
    }
}

// TestModel is the model the SELECT tests build queries for; when no table is
// set, Build uses this type's name, TestModel, as the table name.
type TestModel struct {
    Id        int64
    FirstName string
    Age       int8
    LastName  *sql.NullString
}
// select.go
package orm

import (
    "context"
    "reflect"
    "strings"
)

// Selector builds a SELECT statement for the model type T.
type Selector[T any] struct {
    // table overrides the table name; when empty, Build uses T's type name.
    table string
}

// Build renders "SELECT * FROM <table>;". A table set through Table is
// written verbatim (the caller supplies any backquotes); otherwise T's type
// name is used, wrapped in backquotes.
func (s *Selector[T]) Build() (*Query, error) {
    from := s.table
    if from == "" {
        var zero T
        from = "`" + reflect.TypeOf(zero).Name() + "`"
    }
    var b strings.Builder
    b.WriteString("SELECT * FROM ")
    b.WriteString(from)
    b.WriteString(";")
    return &Query{SQL: b.String()}, nil
}

// Table sets the name written after FROM and returns s for chaining.
// (An earlier draft called this From.)
func (s *Selector[T]) Table(table string) *Selector[T] {
    s.table = table
    return s
}

// Get is not implemented yet.
func (s *Selector[T]) Get(ctx context.Context) (*T, error) {
    panic("implement me")
}

// GetMulti is not implemented yet.
func (s *Selector[T]) GetMulti(ctx context.Context) ([]*T, error) {
    panic("implement me")
}

8. SELECT:WHERE 语句、Expression 抽象和面试要点

把切片传给不定参数时要用 `...` 解包(如 `Where(ps...)`),这里容易踩坑,可以用 golangci-lint 等静态检查工具发现。


// select.go
package orm

import (
    "context"
    "fmt"
    "reflect"
    "strings"
)

// Selector builds a SELECT statement with an optional WHERE clause for model T.
type Selector[T any] struct {
    table string
    where []Predicate
    sb    *strings.Builder
    args  []any
}

// Build renders the statement. Predicates passed to Where are joined with
// AND; every value becomes a "?" placeholder with its value appended to Args.
// Build can be called repeatedly: each call starts from an empty builder and
// an empty argument list.
func (s *Selector[T]) Build() (*Query, error) {
    s.sb = &strings.Builder{} // the builder is a pointer field, so it must be initialized
    // Reset the arguments; otherwise a second Build would append every value again.
    s.args = nil
    sb := s.sb
    sb.WriteString("SELECT * FROM ")
    if s.table == "" {
        // Derive the table name from the model type.
        var t T
        typ := reflect.TypeOf(t)
        sb.WriteByte('`')
        sb.WriteString(typ.Name())
        sb.WriteByte('`')
    } else {
        // Written verbatim; the caller supplies any quoting (e.g. `db`.`table`).
        sb.WriteString(s.table)
    }

    if len(s.where) > 0 {
        sb.WriteString(" WHERE ")
        p := s.where[0]
        for i := 1; i < len(s.where); i++ {
            p = p.And(s.where[i])
        }
        if err := s.buildExpression(p); err != nil {
            return nil, err
        }
    }

    sb.WriteByte(';')
    return &Query{
        SQL:  sb.String(),
        Args: s.args,
    }, nil
}

// buildExpression writes expr into s.sb. Operands that are themselves
// predicates are parenthesized; a nil operand (the missing left side of NOT)
// writes nothing.
func (s *Selector[T]) buildExpression(expr Expression) error {
    switch exp := expr.(type) {
    case nil:
    case Predicate:
        _, ok := exp.left.(Predicate)
        if ok {
            s.sb.WriteByte('(')
        }
        if err := s.buildExpression(exp.left); err != nil {
            return err
        }
        if ok {
            s.sb.WriteByte(')')
        }

        s.sb.WriteByte(' ')
        s.sb.WriteString(exp.op.String())
        s.sb.WriteByte(' ')

        _, ok = exp.right.(Predicate)
        if ok {
            s.sb.WriteByte('(')
        }
        if err := s.buildExpression(exp.right); err != nil {
            return err
        }
        if ok {
            s.sb.WriteByte(')')
        }
    case Column:
        s.sb.WriteByte('`')
        s.sb.WriteString(exp.name)
        s.sb.WriteByte('`')
    case value:
        // Values are never inlined into the SQL; they become placeholders.
        s.sb.WriteByte('?')
        s.addArg(exp.val)
    default:
        return fmt.Errorf("orm: 不支持的表达式类型 %v", expr)
    }
    return nil
}

// addArg appends a placeholder value, allocating a small slice on first use.
func (s *Selector[T]) addArg(val any) *Selector[T] {
    if s.args == nil {
        s.args = make([]any, 0, 8)
    }
    s.args = append(s.args, val)
    return s
}

// Where sets the WHERE predicates, replacing any set earlier; Build joins
// them with AND.
func (s *Selector[T]) Where(ps ...Predicate) *Selector[T] {
    s.where = ps
    return s
}
// predicate.go
package orm

// op is a SQL operator. It is a defined type rather than an alias
// (`type op = string`), so an arbitrary string is not accepted where an op is.
type op string

const (
    opEq  op = "="
    opNot op = "NOT"
    opAnd op = "AND"
    opOr  op = "OR"
)

// String returns the operator's SQL text.
func (o op) String() string {
    return string(o)
}

// Expression is a marker interface for anything that can appear in a WHERE
// clause: predicates, columns and values.
type Expression interface {
    expr()
}

// Predicate is the expression `left op right`. For the unary NOT, left is nil.
type Predicate struct {
    left  Expression
    op    op
    right Expression
}

func (Predicate) expr() {}

// Column refers to a column by name.
//
// Function-style APIs were considered too: Eq("id", 12), Eq(sub, "id", 12),
// Eq(sub.id, 12), Eq("sub.id", 12). The method style C("id").Eq(12) (and
// later sub.C("id").Eq(12)) was chosen.
type Column struct {
    name string
}

// C returns a reference to the column name.
func C(name string) Column {
    return Column{name: name}
}

// Eq builds `c = arg`, e.g. C("id").Eq(12); arg becomes a placeholder value.
func (c Column) Eq(arg any) Predicate {
    return Predicate{
        left:  c,
        op:    opEq,
        right: value{val: arg},
    }
}

func (Column) expr() {}

// Not negates p, e.g. Not(C("name").Eq("Tom")).
func Not(p Predicate) Predicate {
    return Predicate{
        op:    opNot,
        right: p,
    }
}

// And combines p with right, e.g. C("id").Eq(12).And(C("name").Eq("Tom")).
func (p Predicate) And(right Predicate) Predicate {
    return Predicate{
        left:  p,
        op:    opAnd,
        right: right,
    }
}

// Or combines p with right, e.g. C("id").Eq(12).Or(C("name").Eq("Tom")).
func (p Predicate) Or(right Predicate) Predicate {
    return Predicate{
        left:  p,
        op:    opOr,
        right: right,
    }
}

// value is a literal operand; it is rendered as a placeholder, never inlined.
type value struct {
    val any
}

func (value) expr() {}
// select_test.go
package orm

import (
    "database/sql"
    "github.com/stretchr/testify/assert"
    "testing"
)

// TestSelector_Build checks the SQL text and arguments produced by
// Selector.Build for the supported WHERE forms.
func TestSelector_Build(t *testing.T) {
    testCases := []struct {
        name string

        builder QueryBuilder

        wantQuery *Query
        wantErr   error
    }{
        // ...
        {
            name:    "where",
            builder: (&Selector[TestModel]{}).Where(C("Age").Eq(18)),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `TestModel` WHERE `Age` = ?;",
                Args: []any{18},
            },
        },
        {
            name:    "empty where",
            builder: (&Selector[TestModel]{}).Where(),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `TestModel`;",
                Args: nil,
            },
        },
        {
            name:    "not",
            builder: (&Selector[TestModel]{}).Where(Not(C("Age").Eq(18))),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `TestModel` WHERE  NOT (`Age` = ?);",
                Args: []any{18},
            },
        },
        {
            name:    "and",
            builder: (&Selector[TestModel]{}).Where(C("Age").Eq(18).And(C("FirstName").Eq("Tom"))),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `TestModel` WHERE (`Age` = ?) AND (`FirstName` = ?);",
                Args: []any{18, "Tom"},
            },
        },
        {
            name:    "or",
            builder: (&Selector[TestModel]{}).Where(C("Age").Eq(18).Or(C("FirstName").Eq("Tom"))),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `TestModel` WHERE (`Age` = ?) OR (`FirstName` = ?);",
                Args: []any{18, "Tom"},
            },
        },
    }

    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            q, err := tc.builder.Build()
            assert.Equal(t, tc.wantErr, err)
            if err != nil {
                return
            }
            assert.Equal(t, tc.wantQuery, q)
        })
    }
}

// TestModel is the fixture entity used by the Selector tests above.
type TestModel struct {
    Id        int64
    FirstName string
    Age       int8
    LastName  *sql.NullString
}

元数据

9. 元数据简介



10. 元数据:反射-读字段

构建中间件的必备技术:反射、unsafe、抽象语法树(AST)、模板(用于代码生成等)。


遇到指针时,需要先调用 Elem() 取出指针指向的元素,再读取其字段。

// field_test.go
package reflect

import (
    "errors"
    "github.com/stretchr/testify/assert"
    "testing"
)

// TestIterateFields checks IterateFields on structs, pointers and
// unsupported inputs.
func TestIterateFields(t *testing.T) {
    type User struct {
        Name string
        age  int
    }

    testCases := []struct {
        name string

        entity any

        wantErr error
        wantRes map[string]any
    }{
        {
            name: "struct",
            entity: User{
                Name: "Tom",
                age:  18,
            },
            wantRes: map[string]any{
                "Name": "Tom",
                // Unexported fields cannot be read via reflection; they are
                // reported as the zero value.
                "age": 0,
            },
        },
        {
            name: "pointer",
            entity: &User{
                Name: "Tom",
                age:  18,
            },
            wantRes: map[string]any{
                "Name": "Tom",
                "age":  0,
            },
        },
        {
            name:    "basic type",
            entity:  18,
            wantErr: errors.New("不支持类型"),
        },
        {
            name: "multiple pointer",
            entity: func() **User {
                res := &User{
                    Name: "Tom",
                    age:  18,
                }
                return &res
            }(),
            wantRes: map[string]any{
                "Name": "Tom",
                "age":  0,
            },
        },
        {
            name:    "nil",
            entity:  nil,
            wantErr: errors.New("不支持 nil"),
        },
        {
            // A typed nil pointer is a non-nil interface value, so it passes
            // the `entity == nil` check and must be rejected separately.
            name:    "nil pointer",
            entity:  (*User)(nil),
            wantErr: errors.New("不支持零值"),
        },
    }

    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            res, err := IterateFields(tc.entity)
            assert.Equal(t, tc.wantErr, err)
            if err != nil {
                return
            }
            assert.Equal(t, tc.wantRes, res)
        })
    }
}
// fields.go
package reflect

import (
    "errors"
    "reflect"
)

// IterateFields returns a map from field name to field value for a struct
// or a (possibly multi-level) pointer to a struct.
//
// Exported fields carry their actual value. Unexported fields cannot be read
// through reflection, so they are reported as the zero value of their type.
//
// Errors (all public functions in this package return an error):
//   - "不支持 nil": entity is nil
//   - "不支持零值": a pointer at some level of indirection is nil
//   - "不支持类型": entity does not resolve to a struct
func IterateFields(entity any) (map[string]any, error) {
    if entity == nil {
        return nil, errors.New("不支持 nil")
    }
    typ := reflect.TypeOf(entity)
    val := reflect.ValueOf(entity)

    // Unwrap every level of pointer. A nil pointer at any level has no
    // struct behind it, and calling Field on it would panic. A zero struct
    // value, on the other hand, is perfectly valid to iterate.
    for typ.Kind() == reflect.Pointer {
        if val.IsNil() {
            return nil, errors.New("不支持零值")
        }
        typ = typ.Elem()
        val = val.Elem()
    }

    if typ.Kind() != reflect.Struct {
        return nil, errors.New("不支持类型")
    }

    numFields := typ.NumField()
    res := make(map[string]any, numFields)
    for i := 0; i < numFields; i++ {
        fieldType := typ.Field(i)
        if !fieldType.IsExported() {
            // Interface() panics on unexported fields; report the zero value.
            res[fieldType.Name] = reflect.Zero(fieldType.Type).Interface()
            continue
        }
        res[fieldType.Name] = val.Field(i).Interface()
    }

    return res, nil
}

11. 元数据:反射-写字段

// fields_test.go
// TestSetField checks that SetField only modifies exported fields reached
// through a pointer, and that the entity is left untouched on failure.
func TestSetField(t *testing.T) {
    testCases := []struct {
        name string

        entity   any
        field    string
        newValue any

        wantErr error
        // the entity as expected after the call
        wantEntity any
    }{
        {
            // A struct passed by value is not addressable, so its fields
            // cannot be set.
            name:       "struct",
            entity:     User{Name: "Tom"},
            field:      "Name",
            newValue:   "Jerry",
            wantErr:    errors.New("不可修改字段"),
            wantEntity: User{Name: "Tom"},
        },
        {
            name:       "pointer",
            entity:     &User{Name: "Tom"},
            field:      "Name",
            newValue:   "Jerry",
            wantEntity: &User{Name: "Jerry"},
        },
        {
            // Unexported fields are never settable through reflection.
            name:       "pointer unexported",
            entity:     &User{age: 18},
            field:      "age",
            newValue:   16,
            wantErr:    errors.New("不可修改字段"),
            wantEntity: &User{age: 18},
        },
    }
    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            err := SetField(tc.entity, tc.field, tc.newValue)
            assert.Equal(t, tc.wantErr, err)
            // Check the entity on both paths: a failed call must not modify it.
            assert.Equal(t, tc.wantEntity, tc.entity)
        })
    }
}

func TestSet(t *testing.T) {
    var i = 0
    // panic: unaddressable value
    //reflect.ValueOf(i).Set(reflect.ValueOf(12))
    ptr := &i
    reflect.ValueOf(ptr).Elem().Set(reflect.ValueOf(12))
    assert.Equal(t, 12, i)
}
// fields.go
// SetField assigns newValue to the exported field named field of the struct
// that entity points to (through any number of pointers).
//
// Errors:
//   - "不支持 nil": entity is nil
//   - "不支持零值": a pointer at some level of indirection is nil
//   - "不支持类型": entity does not resolve to a struct
//   - "不可修改字段": the field does not exist, is unexported, or the struct
//     was passed by value (not addressable)
//   - "类型不匹配": newValue is nil or not assignable to the field's type
//
// The entity is left unmodified whenever an error is returned.
func SetField(entity any, field string, newValue any) error {
    val := reflect.ValueOf(entity)
    if !val.IsValid() {
        return errors.New("不支持 nil")
    }
    // Walk down to the struct, refusing nil pointers instead of panicking.
    for val.Kind() == reflect.Pointer {
        if val.IsNil() {
            return errors.New("不支持零值")
        }
        val = val.Elem()
    }
    if val.Kind() != reflect.Struct {
        return errors.New("不支持类型")
    }

    fieldVal := val.FieldByName(field)
    // CanSet is false for an unknown field (zero Value), for unexported
    // fields, and for fields of a struct that was passed by value.
    if !fieldVal.CanSet() {
        return errors.New("不可修改字段")
    }

    // Value.Set panics on an invalid or non-assignable value; report it instead.
    newVal := reflect.ValueOf(newValue)
    if !newVal.IsValid() || !newVal.Type().AssignableTo(fieldVal.Type()) {
        return errors.New("类型不匹配")
    }
    fieldVal.Set(newVal)

    return nil
}

12. 元数据:反射-方法

// func_call_test.go
package reflect

import (
    "github.com/stretchr/testify/assert"
    "my_framework/orm/reflect/types"
    "reflect"
    "testing"
)

// TestIterateFunc checks which methods IterateFunc reports for a struct
// value versus a pointer, together with their signatures and results.
func TestIterateFunc(t *testing.T) {
    testCases := []struct {
        name string

        entity  any
        wantRes map[string]FuncInfo
        wantErr error
    }{
        {
            name:   "struct",
            entity: types.NewUser("Tom", 18),
            wantRes: map[string]FuncInfo{
                // A struct value only exposes methods with a value receiver.
                "GetAge": {
                    Name:        "GetAge",
                    InputTypes:  []reflect.Type{reflect.TypeOf(types.User{})},
                    OutputTypes: []reflect.Type{reflect.TypeOf(0)},
                    Result:      []any{18},
                },
                // ChangeName has a pointer receiver, so it is not in the
                // method set of types.User and is not reported here.
            },
        },
        {
            name:   "pointer",
            entity: types.NewUserPtr("Tom", 18),
            wantRes: map[string]FuncInfo{
                "GetAge": {
                    Name:        "GetAge",
                    InputTypes:  []reflect.Type{reflect.TypeOf(&types.User{})},
                    OutputTypes: []reflect.Type{reflect.TypeOf(0)},
                    Result:      []any{18},
                },
                // The method set of *types.User includes the pointer-receiver
                // method ChangeName.
                "ChangeName": {
                    Name:        "ChangeName",
                    InputTypes:  []reflect.Type{reflect.TypeOf(&types.User{}), reflect.TypeOf("")},
                    OutputTypes: []reflect.Type{},
                    Result:      []any{},
                },
            },
        },
    }

    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            res, err := IterateFunc(tc.entity)
            assert.Equal(t, tc.wantErr, err)
            if err != nil {
                return
            }
            assert.Equal(t, tc.wantRes, res)
        })
    }
}
// func_call.go
package reflect

import "reflect"

func IterateFunc(entity any) (map[string]FuncInfo, error) {
    typ := reflect.TypeOf(entity)
    numMethod := typ.NumMethod()
    res := make(map[string]FuncInfo, numMethod)
    for i := 0; i < numMethod; i++ {
        method := typ.Method(i)
        fn := method.Func

        numIn := fn.Type().NumIn()
        input := make([]reflect.Type, 0, numIn)
        inputVals := make([]reflect.Value, 0, numIn)
        // 方法的第一个参数是该方法所属结构体
        input = append(input, reflect.TypeOf(entity))
        inputVals = append(inputVals, reflect.ValueOf(entity))
        for i := 1; i < numIn; i++ {
            fnInType := fn.Type().In(i)
            input = append(input, fnInType)
            inputVals = append(inputVals, reflect.Zero(fnInType))
        }

        numOut := fn.Type().NumOut()
        output := make([]reflect.Type, 0, numOut)
        for i := 0; i < numOut; i++ {
            output = append(output, fn.Type().Out(i))
        }

        resVals := fn.Call(inputVals)
        result := make([]any, 0, len(resVals))
        for _, v := range resVals {
            result = append(result, v.Interface())
        }

        res[method.Name] = FuncInfo{
            Name:        method.Name,
            InputTypes:  input,
            OutputTypes: output,
            Result:      result,
        }
    }
    return res, nil
}

type FuncInfo struct {
    Name        string
    InputTypes  []reflect.Type
    OutputTypes []reflect.Type

    Result []any
}
// orm/types/user.go
package types

import "fmt"

// User is a fixture type for the reflection method tests. It has one
// exported and one unexported field, plus value-receiver, pointer-receiver
// and unexported methods.
type User struct {
    Name string
    age  int
}

// NewUser returns a User value.
func NewUser(name string, age int) User {
    return User{Name: name, age: age}
}

// NewUserPtr returns a pointer to a newly built User.
func NewUserPtr(name string, age int) *User {
    u := NewUser(name, age)
    return &u
}

// GetAge has a value receiver, so it is in the method set of both User and *User.
func (u User) GetAge() int {
    return u.age
}

// ChangeName has a pointer receiver, so it is only in the method set of *User.
func (u *User) ChangeName(newName string) {
    u.Name = newName
}

// private is unexported and therefore not listed by reflect.Type.Method.
func (u User) private() {
    fmt.Println("private")
}

13. 元数据:反射-遍历

// iterate_test.go
package reflect

import (
    "github.com/stretchr/testify/assert"
    "testing"
)

func TestIterateArray(t *testing.T) {
    testCases := []struct {
        name   string
        entity any

        wantVals []any
        wantErr  error
    }{
        {
            name:     "arr",
            entity:   [3]int{1, 2, 3},
            wantVals: []any{1, 2, 3},
        },
        {
            name:     "slice",
            entity:   []int{1, 2, 3},
            wantVals: []any{1, 2, 3},
        },
    }

    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            vals, err := IterateArrayOrSlice(tc.entity)
            assert.Equal(t, tc.wantErr, err)
            if err != nil {
                return
            }
            assert.Equal(t, tc.wantVals, vals)
        })
    }
}

// TestIterateMap checks that IterateMap returns every key/value pair of a
// map. Map iteration order is randomized, so the order is not asserted.
func TestIterateMap(t *testing.T) {
    tests := []struct {
        name   string
        entity any

        wantKeys   []any
        wantValues []any
        wantErr    error
    }{
        {
            name: "map",
            entity: map[string]string{
                "A": "a",
                "B": "b",
            },
            wantKeys:   []any{"A", "B"},
            wantValues: []any{"a", "b"},
        },
        {
            name:   "empty map",
            entity: map[string]string{},
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            keys, values, err := IterateMap(tt.entity)
            assert.Equal(t, tt.wantErr, err)
            if err != nil {
                return
            }
            // Comparing keys and values as ordered slices would be flaky.
            // Rebuild the pairs and compare them as maps instead; this also
            // checks that keys[i] and values[i] come from the same entry.
            if !assert.Len(t, values, len(keys)) {
                return
            }
            want := make(map[any]any, len(tt.wantKeys))
            for i, k := range tt.wantKeys {
                want[k] = tt.wantValues[i]
            }
            got := make(map[any]any, len(keys))
            for i, k := range keys {
                got[k] = values[i]
            }
            assert.Equal(t, want, got)
        })
    }
}
// iterate.go
package reflect

import (
    "errors"
    "reflect"
)

// IterateArrayOrSlice returns the elements of an array or slice, in index order.
// It returns an error ("不支持类型") for any other input, including nil,
// instead of panicking in reflect.Value.Len.
func IterateArrayOrSlice(entity any) ([]any, error) {
    val := reflect.ValueOf(entity)
    // Kind is Invalid for a nil entity, so this check also covers nil.
    if k := val.Kind(); k != reflect.Array && k != reflect.Slice {
        return nil, errors.New("不支持类型")
    }
    n := val.Len()
    res := make([]any, 0, n)
    for i := 0; i < n; i++ {
        res = append(res, val.Index(i).Interface())
    }
    return res, nil
}

// IterateMap returns the keys and values of a map as two parallel slices:
// keys[i] and values[i] belong to the same entry. The order follows Go's
// randomized map iteration. Any non-map input, including nil, yields an
// error ("不支持类型") instead of a panic.
func IterateMap(entity any) ([]any, []any, error) {
    val := reflect.ValueOf(entity)
    // Kind is Invalid for a nil entity, so this check also covers nil.
    if val.Kind() != reflect.Map {
        return nil, nil, errors.New("不支持类型")
    }
    resKeys := make([]any, 0, val.Len())
    resValues := make([]any, 0, val.Len())

    // Approach 1: MapRange yields key and value together.
    itr := val.MapRange()
    for itr.Next() {
        resKeys = append(resKeys, itr.Key().Interface())
        resValues = append(resValues, itr.Value().Interface())
    }

    // Approach 2 (equivalent): MapKeys, then MapIndex for each key.
    //keys := val.MapKeys()
    //for _, key := range keys {
    //    v := val.MapIndex(key)
    //    resKeys = append(resKeys, key.Interface())
    //    resValues = append(resValues, v.Interface())
    //}

    return resKeys, resValues, nil
}

14. 元数据:反射的开源实例、面试要点总结


15. 元数据:反射解析模型


// orm/internal/errs/error.go
// Packages under internal/ can only be imported by code rooted at orm/, so framework users cannot depend on them directly.
package errs

import (
    "errors"
    "fmt"
)

var (
    // ErrPointerOnly is returned when an entity is not a pointer to a struct.
    ErrPointerOnly = errors.New("orm: 只支持指向结构体的一级指针")
)

// NewErrUnsupportedExpression reports an Expression type that the SQL
// builder does not know how to render.
func NewErrUnsupportedExpression(expr any) error {
    return fmt.Errorf("orm: 不支持的表达式类型 %v", expr)
}

// NewErrUnknownField reports a Go field name that does not exist on the
// model. The SQL builders return it when C("...") names an unknown field.
func NewErrUnknownField(name string) error {
    return fmt.Errorf("orm: 未知字段 %s", name)
}
// model_test.go
package orm

import (
    "github.com/stretchr/testify/assert"
    "my_framework/orm/internal/errs"
    "testing"
)

// Test_parseModel checks that parseModel accepts exactly a single-level
// pointer to a struct and maps its names to snake_case.
func Test_parseModel(t *testing.T) {
    tests := []struct {
        name string

        entity    any
        wantModel *model
        wantErr   error
    }{
        {
            // A struct value is rejected: only a pointer to a struct is accepted.
            name:    "struct",
            entity:  TestModel{},
            wantErr: errs.ErrPointerOnly,
        },
        {
            name:   "pointer",
            entity: &TestModel{},
            wantModel: &model{
                tableName: "test_model",
                fields: map[string]*field{
                    "Id": {
                        colName: "id",
                    },
                    "FirstName": {
                        colName: "first_name",
                    },
                    "LastName": {
                        colName: "last_name",
                    },
                    "Age": {
                        colName: "age",
                    },
                },
            },
        },
        {
            // A pointer is accepted only when it points to a struct.
            name:    "pointer to non-struct",
            entity:  new(int),
            wantErr: errs.ErrPointerOnly,
        },
        {
            // Only one level of indirection is supported.
            name: "multiple pointer",
            entity: func() **TestModel {
                res := &TestModel{}
                return &res
            }(),
            wantErr: errs.ErrPointerOnly,
        },
    }
    for _, tc := range tests {
        t.Run(tc.name, func(t *testing.T) {
            m, err := parseModel(tc.entity)
            assert.Equal(t, tc.wantErr, err)
            if err != nil {
                return
            }
            assert.Equal(t, tc.wantModel, m)
        })
    }
}
// model.go
package orm

import (
    "my_framework/orm/internal/errs"
    "reflect"
    "unicode"
)

// model holds the metadata that parseModel extracts from an entity type:
// tableName is the snake_case form of the struct name, and fields maps each
// Go field name to its column metadata.
type model struct {
    tableName string
    fields    map[string]*field
}

// field holds per-field metadata.
type field struct {
    // colName is the database column name (snake_case of the Go field name).
    colName string
}

// parseModel builds the metadata for entity, which must be a single-level
// pointer to a struct (e.g. &User{} or new(T)). The table name and every
// column name are the snake_case form of the struct name and its field names.
//
// It returns errs.ErrPointerOnly for nil, non-pointer values, pointers to
// non-structs and multi-level pointers.
func parseModel(entity any) (*model, error) {
    typ := reflect.TypeOf(entity)
    // TypeOf(nil) is nil; check it before calling Kind to avoid a panic.
    if typ == nil || typ.Kind() != reflect.Pointer || typ.Elem().Kind() != reflect.Struct {
        return nil, errs.ErrPointerOnly
    }
    // The guard above ensures exactly one pointer level; unwrap it.
    typ = typ.Elem()

    numField := typ.NumField()
    fieldMap := make(map[string]*field, numField)
    for i := 0; i < numField; i++ {
        fd := typ.Field(i)
        fieldMap[fd.Name] = &field{
            colName: underscoreName(fd.Name),
        }
    }
    return &model{
        tableName: underscoreName(typ.Name()),
        fields:    fieldMap,
    }, nil
}

// underscoreName 驼峰命名转下划线
func underscoreName(tableName string) string {
    var buf []byte
    for i, v := range tableName {
        if unicode.IsUpper(v) {
            if i != 0 {
                buf = append(buf, '_')
            }
            buf = append(buf, byte(unicode.ToLower(v)))
        } else {
            buf = append(buf, byte(v))
        }
    }
    return string(buf)
}

16. 元数据:利用元数据改造 Selector、元数据阶段总结

// select.go
package orm

import (
    "context"
    "my_framework/orm/internal/errs"
    "strings"
)

// Selector builds a SELECT statement for the table mapped from T.
//
// Fields:
//   - table: table name set by the caller; written verbatim when non-empty
//   - model: metadata parsed from T at the start of Build
//   - where: predicates for the WHERE clause
//   - sb: output buffer, re-created on every Build
//   - args: placeholder arguments collected while building
type Selector[T any] struct {
    table string
    model *model
    where []Predicate
    sb    *strings.Builder
    args  []any
}

// Build renders the SELECT statement for T.
//
// When no table was set, the table name parsed from T is wrapped in
// back-quotes. Otherwise s.table is written verbatim, so callers must quote
// it themselves (e.g. "`db`.`table`").
// NOTE(review): the WHERE rendering is elided ("// ...") in this snapshot.
// NOTE(review): s.args is not reset here, so calling Build twice on the same
// Selector would keep the previous call's arguments. Confirm against the
// full version.
func (s *Selector[T]) Build() (*Query, error) {
    //var sb strings.Builder
    s.sb = &strings.Builder{} // the pointer field must be initialised before use
    var err error
    s.model, err = parseModel(new(T)) // new(T) yields a *T, which parseModel requires
    if err != nil {
        return nil, err
    }
    sb := s.sb
    sb.WriteString("SELECT * FROM ")
    if s.table == "" {
        sb.WriteByte('`')
        sb.WriteString(s.model.tableName)
        sb.WriteByte('`')
    } else {
        //segs := strings.Split(s.table, ".")
        //sb.WriteByte('`')
        sb.WriteString(s.table)
        //sb.WriteByte('`')
    }

    if len(s.where) > 0 {
        // ...
    }

    sb.WriteByte(';')
    return &Query{
        SQL:  sb.String(),
        Args: s.args,
    }, nil
}

// buildExpression renders expr into s.sb. A Column is written as its
// back-quoted column name, looked up by Go field name in s.model. A value is
// written as a "?" placeholder and its argument is appended via addArg.
// A nil expression writes nothing. It returns an error for an unknown field
// or an unsupported Expression type.
// NOTE(review): the Predicate case is elided ("// ...") in this snapshot.
func (s *Selector[T]) buildExpression(expr Expression) error {
    switch exp := expr.(type) {
    case nil:
    case Predicate:
        // ...
    case Column:
        fd, ok := s.model.fields[exp.name]
        // unknown Go field name
        if !ok {
            return errs.NewErrUnknownField(exp.name)
        }
        s.sb.WriteByte('`')
        s.sb.WriteString(fd.colName)
        s.sb.WriteByte('`')
    //
    case value:
        s.sb.WriteByte('?')
        s.addArg(exp.val)
        //
    default:
        return errs.NewErrUnsupportedExpression(expr)
    }
    return nil
}

// ...
// select_test.go
package orm

import (
    "database/sql"
    "github.com/stretchr/testify/assert"
    "my_framework/orm/internal/errs"
    "testing"
)

// TestSelector_Build checks the SQL text and arguments produced by
// Selector.Build, including table naming, WHERE forms and unknown columns.
func TestSelector_Build(t *testing.T) {
    testCases := []struct {
        name string

        builder QueryBuilder

        wantQuery *Query
        wantErr   error
    }{
        {
            name:    "not from",
            builder: &Selector[TestModel]{},
            wantQuery: &Query{
                SQL:  "SELECT * FROM `test_model`;",
                Args: nil,
            },
        },
        {
            name:    "from",
            builder: (&Selector[TestModel]{}).Table("`test_model`"),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `test_model`;",
                Args: nil,
            },
        },
        {
            name:    "empty from",
            builder: (&Selector[TestModel]{}),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `test_model`;",
                Args: nil,
            },
        },
        {
            name:    "empty table",
            builder: (&Selector[TestModel]{}).Table(""),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `test_model`;",
                Args: nil,
            },
        },
        {
            name:    "db table",
            builder: (&Selector[TestModel]{}).Table("`mybatis`.`test`"),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `mybatis`.`test`;",
                Args: nil,
            },
        },
        {
            name:    "where",
            builder: (&Selector[TestModel]{}).Where(C("Age").Eq(18)),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `test_model` WHERE `age` = ?;",
                Args: []any{18},
            },
        },
        {
            name:    "empty where",
            builder: (&Selector[TestModel]{}).Where(),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `test_model`;",
                Args: nil,
            },
        },
        {
            name:    "not",
            builder: (&Selector[TestModel]{}).Where(Not(C("Age").Eq(18))),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `test_model` WHERE  NOT (`age` = ?);",
                Args: []any{18},
            },
        },
        {
            name:    "and",
            builder: (&Selector[TestModel]{}).Where(C("Age").Eq(18).And(C("FirstName").Eq("Tom"))),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `test_model` WHERE (`age` = ?) AND (`first_name` = ?);",
                Args: []any{18, "Tom"},
            },
        },
        {
            name:    "or",
            builder: (&Selector[TestModel]{}).Where(C("Age").Eq(18).Or(C("FirstName").Eq("Tom"))),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `test_model` WHERE (`age` = ?) OR (`first_name` = ?);",
                Args: []any{18, "Tom"},
            },
        },
        {
            name:    "invalid column",
            builder: (&Selector[TestModel]{}).Where(C("Age").Eq(18).Or(C("XXXX").Eq("Tom"))),
            wantErr: errs.NewErrUnknownField("XXXX"),
        },
    }

    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            q, err := tc.builder.Build()
            assert.Equal(t, tc.wantErr, err)
            if err != nil {
                return
            }
            assert.Equal(t, tc.wantQuery, q)
        })
    }
}

// TestModel is the fixture entity for the ORM tests. parseModel maps it to
// table `test_model` with columns id, first_name, age and last_name.
type TestModel struct {
    Id        int64
    FirstName string
    Age       int8
    LastName  *sql.NullString
}

作业

18.第四周作业:DELETE 语句[选学]

19.第四周 DELETE 作业讲解[选学]

// builder.go
package orm

import (
    "my_framework/orm/internal/errs"
    "strings"
)

// builder holds the state shared by SQL statement builders such as Deleter,
// which embeds it:
//   - the output buffer;
//   - the parsed model used to map Go field names to column names;
//   - the placeholder arguments collected so far.
type builder struct {
    sb    *strings.Builder
    model *model
    args  []any
}

// buildExpression renders expr into b.sb.
//
// How each node is written:
//   - Predicate: "left op right". An operand that is itself a Predicate is
//     wrapped in parentheses.
//   - Column: the back-quoted column name.
//   - value: a "?" placeholder; its argument is appended to b.args.
//   - nil: nothing. So Not(p), which has no left operand, renders as
//     " NOT (p)" with a leading space.
//
// It returns an error for a Column that is not a field of b.model, and for
// any unsupported Expression type.
func (b *builder) buildExpression(expr Expression) error {
    switch exp := expr.(type) {
    case nil:
        // Nothing to render.
    case Predicate:
        _, ok := exp.left.(Predicate)
        if ok {
            b.sb.WriteByte('(')
        }
        if err := b.buildExpression(exp.left); err != nil {
            return err
        }
        if ok {
            b.sb.WriteByte(')')
        }

        b.sb.WriteByte(' ')
        b.sb.WriteString(exp.op.String())
        b.sb.WriteByte(' ')

        _, ok = exp.right.(Predicate)
        if ok {
            b.sb.WriteByte('(')
        }
        if err := b.buildExpression(exp.right); err != nil {
            return err
        }
        if ok {
            b.sb.WriteByte(')')
        }
    case Column:
        fd, ok := b.model.fields[exp.name]
        // The caller referenced a Go field that the model does not have.
        if !ok {
            return errs.NewErrUnknownField(exp.name)
        }
        b.sb.WriteByte('`')
        b.sb.WriteString(fd.colName)
        b.sb.WriteByte('`')
    case value:
        b.sb.WriteByte('?')
        b.addArg(exp.val)
    default:
        return errs.NewErrUnsupportedExpression(expr)
    }
    return nil
}

// addArg records val as the argument for the next "?" placeholder.
// The slice is allocated lazily with a small initial capacity.
func (b *builder) addArg(val any) *builder {
    if b.args == nil {
        b.args = make([]any, 0, 8)
    }
    b.args = append(b.args, val)
    return b
}

// buildPredicates joins ps with AND and renders the result.
// An empty slice renders nothing, instead of panicking on ps[0].
func (b *builder) buildPredicates(ps []Predicate) error {
    if len(ps) == 0 {
        return nil
    }
    p := ps[0]
    for _, next := range ps[1:] {
        p = p.And(next)
    }
    return b.buildExpression(p)
}
// delete.go
package orm

import (
    "strings"
)

// Deleter builds a DELETE statement for the table mapped from T.
type Deleter[T any] struct {
    builder
    where []Predicate
    table string
}

// Build assembles the DELETE statement and its placeholder arguments.
//
// The table defaults to the back-quoted table name parsed from T. A name set
// via From is written verbatim, so the caller must quote it (e.g.
// "`db`.`table`"). Predicates set via Where are joined with AND.
func (d *Deleter[T]) Build() (*Query, error) {
    // Fresh buffer and arguments on every call, so building the same
    // Deleter twice neither accumulates SQL nor duplicates arguments.
    d.sb = &strings.Builder{}
    d.args = nil
    var err error
    d.model, err = parseModel(new(T)) // new(T) yields a *T, which parseModel requires
    if err != nil {
        return nil, err
    }
    sb := d.sb
    // The single-table syntax is `DELETE FROM tbl`; `DELETE * FROM` is invalid SQL.
    sb.WriteString("DELETE FROM ")
    if d.table == "" {
        sb.WriteByte('`')
        sb.WriteString(d.model.tableName)
        sb.WriteByte('`')
    } else {
        sb.WriteString(d.table)
    }

    if len(d.where) > 0 {
        sb.WriteString(" WHERE ")
        if err := d.buildPredicates(d.where); err != nil {
            return nil, err
        }
    }

    sb.WriteByte(';')
    return &Query{
        SQL:  sb.String(),
        Args: d.args,
    }, nil
}

// From sets the table to delete from. The name is written verbatim; an
// empty string falls back to the table name parsed from T.
func (d *Deleter[T]) From(table string) *Deleter[T] {
    d.table = table
    return d
}

// Where sets the WHERE predicates, replacing any set earlier. Multiple
// predicates are joined with AND when the statement is built.
func (d *Deleter[T]) Where(predicate ...Predicate) *Deleter[T] {
    d.where = predicate
    return d
}
// delete_test.go
package orm

import (
    "github.com/stretchr/testify/assert"
    "my_framework/orm/internal/errs"
    "testing"
)

// TestDeleter_Build checks the SQL text and arguments produced by Deleter.Build.
func TestDeleter_Build(t *testing.T) {
    testCases := []struct {
        name string

        builder QueryBuilder

        wantQuery *Query
        wantErr   error
    }{
        {
            name:    "struct",
            builder: &Deleter[TestModel]{},
            wantQuery: &Query{
                SQL:  "DELETE FROM `test_model`;",
                Args: nil,
            },
        },
        {
            name:    "from",
            builder: (&Deleter[TestModel]{}).From("`test_model`"),
            wantQuery: &Query{
                SQL:  "DELETE FROM `test_model`;",
                Args: nil,
            },
        },
        {
            name:    "where",
            builder: (&Deleter[TestModel]{}).Where(C("Age").Eq(18)),
            wantQuery: &Query{
                SQL:  "DELETE FROM `test_model` WHERE `age` = ?;",
                Args: []any{18},
            },
        },
        {
            name:    "invalid column",
            builder: (&Deleter[TestModel]{}).Where(C("XXXX").Eq(18)),
            wantErr: errs.NewErrUnknownField("XXXX"),
        },
    }

    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            q, err := tc.builder.Build()
            assert.Equal(t, tc.wantErr, err)
            if err != nil {
                return
            }
            assert.Equal(t, tc.wantQuery, q)
        })
    }
}

11 第五周:ORM 框架之元数据、SQL 编程与结果集处理

元数据

1. 元数据:注册中心


超时配置很重要

// select.go
// Selector builds SELECT statements for the table mapped from T. It embeds
// the shared builder state and keeps a reference to the DB it belongs to.
type Selector[T any] struct {
    builder
    table string
    where []Predicate
    db    *DB
}

// NewSelector returns an empty Selector bound to db.
func NewSelector[T any](db *DB) *Selector[T] {
    s := new(Selector[T])
    s.db = db
    return s
}
// select_test.go
package orm

import (
    "database/sql"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
    "my_framework/orm/internal/errs"
    "testing"
)

// TestSelector_Build checks the SQL text and arguments produced by
// Selector.Build for a Selector created through NewSelector.
func TestSelector_Build(t *testing.T) {
    db, err := NewDB()
    require.NoError(t, err)
    testCases := []struct {
        name string

        builder QueryBuilder

        wantQuery *Query
        wantErr   error
    }{
        {
            name:    "not from",
            builder: NewSelector[TestModel](db),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `test_model`;",
                Args: nil,
            },
        },
        {
            name:    "from",
            builder: (NewSelector[TestModel](db)).Table("`test_model`"),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `test_model`;",
                Args: nil,
            },
        },
        {
            name:    "empty from",
            builder: (NewSelector[TestModel](db)),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `test_model`;",
                Args: nil,
            },
        },
        {
            name:    "empty table",
            builder: (NewSelector[TestModel](db)).Table(""),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `test_model`;",
                Args: nil,
            },
        },
        {
            name:    "db table",
            builder: (NewSelector[TestModel](db)).Table("`mybatis`.`test`"),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `mybatis`.`test`;",
                Args: nil,
            },
        },
        {
            name:    "where",
            builder: (NewSelector[TestModel](db)).Where(C("Age").Eq(18)),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `test_model` WHERE `age` = ?;",
                Args: []any{18},
            },
        },
        {
            name:    "empty where",
            builder: (NewSelector[TestModel](db)).Where(),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `test_model`;",
                Args: nil,
            },
        },
        {
            name:    "not",
            builder: (NewSelector[TestModel](db)).Where(Not(C("Age").Eq(18))),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `test_model` WHERE  NOT (`age` = ?);",
                Args: []any{18},
            },
        },
        {
            name:    "and",
            builder: (NewSelector[TestModel](db)).Where(C("Age").Eq(18).And(C("FirstName").Eq("Tom"))),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `test_model` WHERE (`age` = ?) AND (`first_name` = ?);",
                Args: []any{18, "Tom"},
            },
        },
        {
            name:    "or",
            builder: (NewSelector[TestModel](db)).Where(C("Age").Eq(18).Or(C("FirstName").Eq("Tom"))),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `test_model` WHERE (`age` = ?) OR (`first_name` = ?);",
                Args: []any{18, "Tom"},
            },
        },
        {
            name:    "invalid column",
            builder: (NewSelector[TestModel](db)).Where(C("Age").Eq(18).Or(C("XXXX").Eq("Tom"))),
            wantErr: errs.NewErrUnknownField("XXXX"),
        },
    }

    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            q, err := tc.builder.Build()
            assert.Equal(t, tc.wantErr, err)
            if err != nil {
                return
            }
            assert.Equal(t, tc.wantQuery, q)
        })
    }
}

// TestModel is the fixture entity for the ORM tests. The registry maps it
// to table `test_model` with columns id, first_name, age and last_name.
type TestModel struct {
    Id        int64
    FirstName string
    Age       int8
    LastName  *sql.NullString
}
// db.go
package orm

// DBOption customises a DB while it is being constructed.
type DBOption func(db *DB)

// DB is the ORM entry point; it owns the model metadata registry.
type DB struct {
    r *registry
}

// NewDB creates a DB with a fresh registry and then applies opts in order.
// In the current implementation the returned error is always nil.
func NewDB(opts ...DBOption) (*DB, error) {
    db := &DB{r: newRegistry()}
    for _, apply := range opts {
        apply(db)
    }
    return db, nil
}

// MustNewDB is like NewDB but panics if construction fails.
func MustNewDB(opts ...DBOption) *DB {
    db, err := NewDB(opts...)
    if err != nil {
        panic(err)
    }
    return db
}
// model_test.go
package orm

import (
    "github.com/stretchr/testify/assert"
    "my_framework/orm/internal/errs"
    "reflect"
    "testing"
)

func Test_parseModel(t *testing.T) {
    tests := []struct {
        name string

        entity    any
        wantModel *model
        wantErr   error
    }{
        {
            name:   "struct",
            entity: TestModel{}, 
            wantErr: errs.ErrPointerOnly,
        },
        {
            name:   "pointer",
            entity: &TestModel{},
            wantModel: &model{
                tableName: "test_model",
                fields: map[string]*field{
                    "Id": {
                        colName: "id",
                    },
                    "FirstName": {
                        colName: "first_name",
                    },
                    "LastName": {
                        colName: "last_name",
                    },
                    "Age": {
                        colName: "age",
                    },
                },
            },
        },
    }

    r := &registry{}
    for _, tc := range tests {
        t.Run(tc.name, func(t *testing.T) {
            m, err := r.parseModel(tc.entity)
            assert.Equal(t, tc.wantErr, err)
            if err != nil {
                return
            }
            assert.Equal(t, tc.wantModel, m)
        })
    }
}

// TestRegistry_Get verifies that get parses an entity on first use and
// caches the resulting model keyed by the entity's reflect.Type.
func TestRegistry_Get(t *testing.T) {
    cases := []struct {
        name string

        entity    any
        wantModel *model
        wantErr   error
        cacheSize int
    }{
        {
            name:   "pointer",
            entity: &TestModel{},
            wantModel: &model{
                tableName: "test_model",
                fields: map[string]*field{
                    "Id":        {colName: "id"},
                    "FirstName": {colName: "first_name"},
                    "LastName":  {colName: "last_name"},
                    "Age":       {colName: "age"},
                },
            },
            cacheSize: 1,
        },
    }

    reg := newRegistry()
    for _, c := range cases {
        t.Run(c.name, func(t *testing.T) {
            got, err := reg.get(c.entity)
            assert.Equal(t, c.wantErr, err)
            if err != nil {
                return
            }
            assert.Equal(t, c.wantModel, got)
            // only the number of cached entries is checked
            assert.Equal(t, c.cacheSize, len(reg.models))

            cached, found := reg.models[reflect.TypeOf(c.entity)]
            assert.True(t, found)
            assert.Equal(t, c.wantModel, cached)
        })
    }
}
// model.go
// registry is the metadata registry: it caches one parsed *model per entity
// type. This version uses a plain map, which is not safe for concurrent
// read/write access.
type registry struct {
    models map[reflect.Type]*model
}

// newRegistry returns an empty registry (capacity hint: 64 models).
func newRegistry() *registry {
    return &registry{models: make(map[reflect.Type]*model, 64)}
}

// get returns the cached model for val's type, parsing and caching it on the
// first request. On a parse error nothing is cached and the error is returned.
func (r *registry) get(val any) (*model, error) {
    key := reflect.TypeOf(val)
    if cached, hit := r.models[key]; hit {
        return cached, nil
    }
    parsed, err := r.parseModel(val)
    if err != nil {
        return nil, err
    }
    r.models[key] = parsed
    return parsed, nil
}

2. 元数据:注册中心并发问题

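Go 原生 map 只支持并发读;一旦出现并发读写,运行时会直接 fatal error(concurrent map read and map write)。因此下面把 registry.models 换成支持并发访问的 sync.Map(Load/Store)。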

// model.go
// get returns the model for val's type, parsing and caching it on the first
// request.
// NOTE(review): r.models is presumably a sync.Map in this version (it is
// used through Load/Store); the struct definition is not shown here.
// LoadOrStore makes goroutines that miss the cache at the same time all
// return the same *model instance instead of each returning its own copy.
func (r *registry) get(val any) (*model, error) {
    typ := reflect.TypeOf(val)
    if cached, ok := r.models.Load(typ); ok {
        return cached.(*model), nil
    }
    parsed, err := r.parseModel(val)
    if err != nil {
        return nil, err
    }
    // Another goroutine may have stored a model meanwhile; keep the first.
    actual, _ := r.models.LoadOrStore(typ, parsed)
    return actual.(*model), nil
}

3. 元数据:标签自定义列名


// model_test.go
package orm

import (
    "github.com/stretchr/testify/assert"
    "my_framework/orm/internal/errs"
    "reflect"
    "testing"
)

// Test_parseModel checks that parseModel rejects a non-pointer entity and
// derives snake_case table and column names for a pointer-to-struct entity.
func Test_parseModel(t *testing.T) {
    cases := []struct {
        name string

        entity    any
        wantModel *model
        wantErr   error
    }{
        {
            // a struct value, not a pointer, must be rejected
            name:    "struct",
            entity:  TestModel{},
            wantErr: errs.ErrPointerOnly,
        },
        {
            name:   "pointer",
            entity: &TestModel{},
            wantModel: &model{
                tableName: "test_model",
                fields: map[string]*field{
                    "Id":        {colName: "id"},
                    "FirstName": {colName: "first_name"},
                    "LastName":  {colName: "last_name"},
                    "Age":       {colName: "age"},
                },
            },
        },
    }

    // parseModel never touches the models cache, so a zero registry suffices.
    reg := &registry{}
    for _, c := range cases {
        t.Run(c.name, func(t *testing.T) {
            got, err := reg.parseModel(c.entity)
            assert.Equal(t, c.wantErr, err)
            if err == nil {
                assert.Equal(t, c.wantModel, got)
            }
        })
    }
}

// TestRegistry_Get covers tag-driven column names:
//   - a `column=` tag with a value;
//   - an empty value, which falls back to the snake_case field name;
//   - a bare key without '=', which is an error;
//   - an unknown key, which is ignored.
// It also checks that each parsed model is cached under the entity's
// reflect.Type.
func TestRegistry_Get(t *testing.T) {
    cases := []struct {
        name string

        entity    any
        wantModel *model
        wantErr   error
    }{
        {
            name:   "pointer",
            entity: &TestModel{},
            wantModel: &model{
                tableName: "test_model",
                fields: map[string]*field{
                    "Id":        {colName: "id"},
                    "FirstName": {colName: "first_name"},
                    "LastName":  {colName: "last_name"},
                    "Age":       {colName: "age"},
                },
            },
        },
        {
            name: "tag",
            entity: func() any {
                type TagTable struct {
                    FirstName string `orm:"column=first_name_t"`
                }
                return &TagTable{}
            }(),
            wantModel: &model{
                tableName: "tag_table",
                fields:    map[string]*field{"FirstName": {colName: "first_name_t"}},
            },
        },
        {
            // empty value: use the default column name
            name: "empty tag",
            entity: func() any {
                type TagTable struct {
                    FirstName string `orm:"column="`
                }
                return &TagTable{}
            }(),
            wantModel: &model{
                tableName: "tag_table",
                fields:    map[string]*field{"FirstName": {colName: "first_name"}},
            },
        },
        {
            // key without '=' is malformed
            name: "column only",
            entity: func() any {
                type TagTable struct {
                    FirstName string `orm:"column"`
                }
                return &TagTable{}
            }(),
            wantErr: errs.NewErrInvalidTagContent("column"),
        },
        {
            // unknown keys are ignored
            name: "invalid tag",
            entity: func() any {
                type TagTable struct {
                    FirstName string `orm:"abc=abc"`
                }
                return &TagTable{}
            }(),
            wantModel: &model{
                tableName: "tag_table",
                fields:    map[string]*field{"FirstName": {colName: "first_name"}},
            },
        },
    }

    reg := newRegistry()
    for _, c := range cases {
        t.Run(c.name, func(t *testing.T) {
            got, err := reg.get(c.entity)
            assert.Equal(t, c.wantErr, err)
            if err != nil {
                return
            }
            assert.Equal(t, c.wantModel, got)

            cached, found := reg.models.Load(reflect.TypeOf(c.entity))
            assert.True(t, found)
            assert.Equal(t, c.wantModel, cached)
        })
    }
}
// model.go
// parseModel builds the metadata for entity, which must be a pointer to a
// struct (otherwise errs.ErrPointerOnly).
//
// Each column name comes from the tagKeyColumn entry of the field's orm tag
// when that entry is non-empty; otherwise it is the snake_case field name.
// The table name is the snake_case struct name. Tag parse errors are
// returned as-is.
func (r *registry) parseModel(entity any) (*model, error) {
    ptrType := reflect.TypeOf(entity)
    if ptrType.Kind() != reflect.Ptr || ptrType.Elem().Kind() != reflect.Struct {
        return nil, errs.ErrPointerOnly
    }
    // Exactly one pointer level is accepted, so one Elem() reaches the struct.
    structType := ptrType.Elem()

    n := structType.NumField()
    fields := make(map[string]*field, n)
    for i := 0; i < n; i++ {
        sf := structType.Field(i)
        tags, err := r.parseTag(sf.Tag)
        if err != nil {
            return nil, err
        }
        col := tags[tagKeyColumn]
        if col == "" {
            // no column override: derive it from the Go field name
            col = underscoreName(sf.Name)
        }
        fields[sf.Name] = &field{colName: col}
    }
    return &model{
        tableName: underscoreName(structType.Name()),
        fields:    fields,
    }, nil
}

// User illustrates the orm tag syntax handled by parseTag: comma-separated
// key=value pairs. parseModel only reads the tagKeyColumn entry (presumably
// "column"); other keys such as xxx are kept in the parsed map but unused.
type User struct {
    ID uint64 `orm:"column=id,xxx=bbb"`
}

// parseTag parses the `orm` struct tag into key/value pairs.
//
// The expected format is `orm:"key1=value1,key2=value2"`. A field without an
// orm tag yields an empty map. Whitespace around segments, keys and values
// is trimmed. Empty segments are skipped, so `orm:""` or a trailing comma is
// no longer an error. Every other segment must be exactly key=value;
// otherwise errs.NewErrInvalidTagContent(segment) is returned.
func (r *registry) parseTag(tag reflect.StructTag) (map[string]string, error) {
    ormTag, ok := tag.Lookup("orm")
    if !ok {
        return map[string]string{}, nil
    }
    // orm:"xxx=xxx,xxx=xxx"
    pairs := strings.Split(ormTag, ",")
    res := make(map[string]string, len(pairs))
    for _, pair := range pairs {
        pair = strings.TrimSpace(pair)
        if pair == "" {
            // Split("", ",") returns [""]; treat empty segments as absent.
            continue
        }
        segs := strings.Split(pair, "=")
        if len(segs) != 2 {
            return nil, errs.NewErrInvalidTagContent(pair)
        }
        res[strings.TrimSpace(segs[0])] = strings.TrimSpace(segs[1])
    }
    return res, nil
}

4. 元数据:接口自定义表名

// model_test.go
// TestRegistry_Get covers table names supplied through the TableName
// interface:
//   - a value receiver;
//   - a pointer receiver;
//   - an empty name, which falls back to the snake_case struct name.
// Each parsed model must also be cached under the entity's reflect.Type.
func TestRegistry_Get(t *testing.T) {
    cases := []struct {
        name string

        entity    any
        wantModel *model
        wantErr   error
    }{
        // ... (earlier cases elided)
        {
            name:   "custom table name",
            entity: &CustomTableName{},
            wantModel: &model{
                tableName: "custom_table_name_t",
                fields:    map[string]*field{"FirstName": {colName: "first_name"}},
            },
        },
        {
            name:   "custom table name ptr",
            entity: &CustomTableNamePtr{},
            wantModel: &model{
                tableName: "custom_table_name_ptr_t",
                fields:    map[string]*field{"FirstName": {colName: "first_name"}},
            },
        },
        {
            name:   "custom table name empty ptr",
            entity: &CustomTableNameEmpty{},
            wantModel: &model{
                tableName: "custom_table_name_empty",
                fields:    map[string]*field{"FirstName": {colName: "first_name"}},
            },
        },
    }

    reg := newRegistry()
    for _, c := range cases {
        t.Run(c.name, func(t *testing.T) {
            got, err := reg.get(c.entity)
            assert.Equal(t, c.wantErr, err)
            if err != nil {
                return
            }
            assert.Equal(t, c.wantModel, got)

            cached, found := reg.models.Load(reflect.TypeOf(c.entity))
            assert.True(t, found)
            assert.Equal(t, c.wantModel, cached)
        })
    }
}

// CustomTableName overrides its table name with a value-receiver method, so
// both CustomTableName and *CustomTableName satisfy the TableName interface.
type CustomTableName struct {
    FirstName string
}

// TableName returns the fixed table name for CustomTableName.
func (CustomTableName) TableName() string { return "custom_table_name_t" }

// CustomTableNamePtr overrides its table name with a pointer-receiver
// method, so only *CustomTableNamePtr satisfies the interface.
type CustomTableNamePtr struct {
    FirstName string
}

// TableName returns the fixed table name for *CustomTableNamePtr.
func (*CustomTableNamePtr) TableName() string { return "custom_table_name_ptr_t" }

// CustomTableNameEmpty returns an empty table name, which makes the registry
// fall back to the snake_case struct name.
type CustomTableNameEmpty struct {
    FirstName string
}

// TableName returns "" to request the default table name.
func (CustomTableNameEmpty) TableName() string { return "" }
// model.go
// parseModel (excerpt) shows only the table-name resolution. The elided
// part ("// ...") produces typ and fieldMap, presumably via the pointer
// check and field loop of the previous version.
//
// If entity implements TableName and returns a non-empty string, that name
// is used; otherwise the snake_case struct name is the default.
// The assertion is made on entity itself, which the pointer check requires
// to be a pointer. Pointer method sets include value-receiver methods, so
// both value- and pointer-receiver TableName implementations are honoured.
func (r *registry) parseModel(entity any) (*model, error) {
    // ...

    var tableName string
    if tbl, ok := entity.(TableName); ok {
        tableName = tbl.TableName()
    }
    if tableName == "" {
        tableName = underscoreName(typ.Name())
    }

    return &model{
        tableName: tableName,
        fields:    fieldMap,
    }, nil
}
// types.go
// TableName lets an entity override the table name the registry would
// otherwise derive from its struct name. Returning "" keeps the default.
type TableName interface {
    TableName() string
}

5. 元数据:编程方式自定义表名和列名

// model.go
// Get returns the cached model for val's type. On the first request it
// registers val with no options; Register stores the new model in r.models.
func (r *registry) Get(val any) (*Model, error) {
    if cached, ok := r.models.Load(reflect.TypeOf(val)); ok {
        return cached.(*Model), nil
    }
    return r.Register(val)
}

// Register parses entity (which must be a pointer to a struct) into a *Model.
// It then applies opts in order, caches the result under entity's
// reflect.Type, and returns it. An error from tag parsing or from any option
// aborts registration, and nothing is cached.
func (r *registry) Register(entity any, opts ...ModelOpt) (*Model, error) {
    ptrType := reflect.TypeOf(entity)
    if ptrType.Kind() != reflect.Ptr || ptrType.Elem().Kind() != reflect.Struct {
        return nil, errs.ErrPointerOnly
    }
    // Only a single pointer level is accepted, so one Elem() reaches the struct.
    structType := ptrType.Elem()

    n := structType.NumField()
    fields := make(map[string]*Field, n)
    for i := 0; i < n; i++ {
        sf := structType.Field(i)
        tags, err := r.parseTag(sf.Tag)
        if err != nil {
            return nil, err
        }
        col := tags[tagKeyColumn]
        if col == "" {
            // no column override: derive it from the Go field name
            col = underscoreName(sf.Name)
        }
        fields[sf.Name] = &Field{colName: col}
    }

    tableName := ""
    if named, ok := entity.(TableName); ok {
        tableName = named.TableName()
    }
    if tableName == "" {
        tableName = underscoreName(structType.Name())
    }

    m := &Model{tableName: tableName, fields: fields}
    for _, opt := range opts {
        if err := opt(m); err != nil {
            return nil, err
        }
    }

    r.models.Store(ptrType, m)
    return m, nil
}
// model_test.go
// TestModelWithTableName checks that the ModelWithTableName option overrides
// the table name derived from the struct.
func TestModelWithTableName(t *testing.T) {
    const want = "test_model_ttt"
    m, err := newRegistry().Register(&TestModel{}, ModelWithTableName(want))
    require.NoError(t, err)
    assert.Equal(t, want, m.tableName)
}

// TestModelWithColName checks that ModelWithColumnName renames the column of
// an existing field, and reports an unknown-field error for a missing one.
func TestModelWithColName(t *testing.T) {
    cases := []struct {
        name string

        field   string
        colName string

        wantColName string
        wantErr     error
    }{
        {name: "column name", field: "FirstName", colName: "first_name_ccc", wantColName: "first_name_ccc"},
        {name: "invalid column name", field: "XXX", colName: "first_name_ccc", wantErr: errs.NewErrUnknownField("XXX")},
    }

    for _, c := range cases {
        t.Run(c.name, func(t *testing.T) {
            // a fresh registry per case so registrations do not interfere
            m, err := newRegistry().Register(&TestModel{}, ModelWithColumnName(c.field, c.colName))
            assert.Equal(t, c.wantErr, err)
            if err != nil {
                return
            }
            fd, found := m.fields[c.field]
            require.True(t, found)
            assert.Equal(t, c.wantColName, fd.colName)
        })
    }
}

6. 元数据:总结与面试要点

sql编程

7. SQL 编程:增删改查


// crud_test.go
package sql_demo

import (
    "context"
    "database/sql"
    "log"
    "time"

    //_ "github.com/mattn/go-sqlite3"
    _ "github.com/go-sql-driver/mysql"
    "github.com/stretchr/testify/require"
    "testing"
)

// TestDB walks through basic CRUD with database/sql against a local MySQL.
// It creates the table, inserts a row, then reads rows back with
// QueryRowContext and QueryContext. The test is skipped when the database
// cannot be reached.
func TestDB(t *testing.T) {
    db, err := sql.Open("mysql", "root:123456@tcp(127.0.0.1:3307)/mybatis")
    require.NoError(t, err)
    defer db.Close()
    // sql.Open does not connect; Ping is the first real round trip.
    if err = db.Ping(); err != nil {
        t.Skipf("database unavailable: %v", err)
    }
    // The DB is usable from here on. sql.OpenDB is the alternative when you
    // already have a driver.Connector.
    ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    // Defer immediately: require.* stops the test via FailNow, so a cancel
    // placed at the end of the function would be skipped on failure.
    defer cancel()
    // Everything except SELECT goes through ExecContext.
    //_, err = db.ExecContext(ctx, `DROP TABLE IF EXISTS test_model;`)

    _, err = db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS test_model(
    id int AUTO_INCREMENT PRIMARY KEY, 
    first_name TEXT NOT NULL,
    age int,
    last_name TEXT NOT NULL
)
`)
    require.NoError(t, err)

    // INSERT: '?' placeholders make the driver bind the arguments, which
    // prevents SQL injection.
    //res, err := db.ExecContext(ctx, "INSERT INTO test_model(`id`,`first_name`,`age`,`last_name`) VALUES (?,?,?,?)",
    res, err := db.ExecContext(ctx, "INSERT INTO test_model(`first_name`,`age`,`last_name`) VALUES (?,?,?)",
        //1, "Tom", 18, "Jerry")
        "Tom", 18, "Jerry")
    require.NoError(t, err)

    affected, err := res.RowsAffected()
    require.NoError(t, err)
    log.Println("受影响行数: ", affected)

    lastId, err := res.LastInsertId()
    require.NoError(t, err)
    // Only meaningful for an AUTO_INCREMENT key; otherwise (or if the
    // insert failed) it is 0.
    log.Println("最后插入的id: ", lastId)

    // SELECT a single row.
    row := db.QueryRowContext(ctx, "SELECT * FROM `test_model` WHERE `id`=?", 1)
    require.NoError(t, row.Err())
    tm := TestModel{}
    // Scan targets follow the column order of SELECT *; list the columns
    // explicitly if you do not want to depend on the table layout.
    err = row.Scan(&tm.Id, &tm.FirstName, &tm.Age, &tm.LastName)
    require.NoError(t, err)
    log.Println(tm)

    // No matching row: sql.ErrNoRows is reported by Scan, not by row.Err().
    row = db.QueryRowContext(ctx, "SELECT * FROM `test_model` WHERE `id`=?", 100)
    err = row.Scan(&tm.Id, &tm.FirstName, &tm.Age, &tm.LastName)
    log.Println(err)
    // require.Error(t, sql.ErrNoRows, err) only asserted that sql.ErrNoRows
    // itself is non-nil; ErrorIs checks the error actually returned.
    require.ErrorIs(t, err, sql.ErrNoRows)

    rows, err := db.QueryContext(ctx, "SELECT * FROM `test_model` WHERE `id`>?", 1)
    require.NoError(t, err)
    // Rows holds a connection until closed.
    defer rows.Close()
    for rows.Next() {
        tm = TestModel{}
        err = rows.Scan(&tm.Id, &tm.FirstName, &tm.Age, &tm.LastName)
        require.NoError(t, err)
        log.Println(tm)
    }
    // Next returns false both at the end and on error; Err tells them apart.
    require.NoError(t, rows.Err())
}

// TestModel mirrors the test_model table used by the database/sql demos.
// LastName is a *sql.NullString so that the column may hold SQL NULL.
type TestModel struct {
    Id        int64
    FirstName string
    Age       int8
    LastName  *sql.NullString
}

8. SQL 编程:Valuer 和 Scanner 接口

// json_test.go
package sql_demo

import (
    "github.com/stretchr/testify/assert"
    "testing"
)

// User is the payload type wrapped by JsonColumn in these tests.
type User struct {
    Name string
}

// TestJsonColumn_Value checks that a valid column marshals its value to JSON
// bytes, and that an invalid (NULL) column yields a nil driver.Value.
func TestJsonColumn_Value(t *testing.T) {
    valid := JsonColumn[User]{Valid: true, Val: User{Name: "Tom"}}
    got, err := valid.Value()
    assert.Nil(t, err)
    assert.Equal(t, []byte(`{"Name":"Tom"}`), got)

    var null JsonColumn[User]
    got, err = null.Value()
    assert.Nil(t, err)
    assert.Nil(t, got)
}

// TestJsonColumn_Scan checks Scan with a NULL source and with JSON supplied
// as a string or as []byte.
func TestJsonColumn_Scan(t *testing.T) {
    const tomJSON = `{"Name":"Tom"}`
    cases := []struct {
        name    string
        src     any
        wantErr error
        wantVal User
        valid   bool
    }{
        {name: "nil"},
        {name: "string", src: tomJSON, wantVal: User{Name: "Tom"}, valid: true},
        {name: "bytes", src: []byte(tomJSON), wantVal: User{Name: "Tom"}, valid: true},
    }

    for _, c := range cases {
        t.Run(c.name, func(t *testing.T) {
            col := new(JsonColumn[User])
            err := col.Scan(c.src)
            assert.Equal(t, c.wantErr, err)
            if err == nil {
                assert.Equal(t, c.wantVal, col.Val)
                assert.Equal(t, c.valid, col.Valid)
            }
        })
    }
}
// json.go
package sql_demo

import (
    "database/sql/driver"
    "errors"
    "github.com/goccy/go-json"
)

// JsonColumn stores a value of type T as a JSON document in one column.
// JsonColumn implements driver.Valuer, and *JsonColumn implements
// sql.Scanner.
type JsonColumn[T any] struct {
    Val T
    // Valid is false when the column is SQL NULL.
    Valid bool
}

// Value implements driver.Valuer. An invalid (NULL) column is written as
// nil; otherwise Val is marshalled to JSON bytes.
func (j JsonColumn[T]) Value() (driver.Value, error) {
    if !j.Valid {
        return nil, nil
    }
    return json.Marshal(j.Val)
}

// Scan implements sql.Scanner. It accepts a JSON document as a string or
// []byte, and nil for SQL NULL. Any other source type is an error.
//
// The receiver is fully reset on every successful call, so a JsonColumn
// reused across rows never carries data from a previous row:
//   - NULL clears Val and Valid;
//   - JSON is decoded into a fresh T instead of being merged into the old
//     Val.
// If decoding fails, the receiver is left unchanged and the error is
// returned.
func (j *JsonColumn[T]) Scan(src any) error {
    var bs []byte
    switch data := src.(type) {
    case string:
        // An empty string is not valid JSON and yields a decode error.
        bs = []byte(data)
    case []byte:
        bs = data
    case nil:
        // SQL NULL: reset so no stale value survives from an earlier Scan.
        var zero T
        j.Val, j.Valid = zero, false
        return nil
    default:
        return errors.New("不支持类型")
    }
    var v T
    if err := json.Unmarshal(bs, &v); err != nil {
        return err
    }
    j.Val, j.Valid = v, true
    return nil
}

同样的 Valuer/Scanner 思路也可以用来实现字段的透明加解密:加解密在应用代码中完成,不会给数据库增加计算压力(如果交给数据库来加解密,就会增加数据库的负担)。

9. SQL 编程:事务与隔离级别


// crud_test.go
// TestTranslation demonstrates a database transaction. The name is kept
// as-is for compatibility; the test is about transactions, not translation.
// It begins a tx, INSERTs a row inside it, and then either Commits or, if
// the INSERT fails, Rolls back and fails the test. It is skipped when the
// database cannot be reached.
func TestTranslation(t *testing.T) {
    db, err := sql.Open("mysql", "root:123456@tcp(127.0.0.1:3307)/mybatis")
    require.NoError(t, err)
    defer db.Close()
    if err = db.Ping(); err != nil {
        t.Skipf("database unavailable: %v", err)
    }
    ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    defer cancel()

    // A zero TxOptions uses the driver/database default isolation level.
    tx, err := db.BeginTx(ctx, &sql.TxOptions{})
    require.NoError(t, err)

    res, err := tx.ExecContext(ctx, "INSERT INTO test_model(`first_name`,`age`,`last_name`) VALUES (?,?,?)",
        //1, "Tom", 18, "Jerry")
        "Tom", 18, "Jerry")
    if err != nil {
        // Roll back and stop here. Continuing would Commit an aborted tx
        // and call RowsAffected on a nil Result (panic).
        if rbErr := tx.Rollback(); rbErr != nil {
            log.Println(rbErr)
        }
        t.Fatalf("insert failed, transaction rolled back: %v", err)
    }

    // Commit the transaction.
    err = tx.Commit()
    require.NoError(t, err)

    affected, err := res.RowsAffected()
    require.NoError(t, err)
    log.Println("受影响行数: ", affected)

    lastId, err := res.LastInsertId()
    require.NoError(t, err)
    // Only meaningful for an AUTO_INCREMENT key; otherwise (or if the
    // insert failed) it is 0.
    log.Println("最后插入的id: ", lastId)
}

10. SQL 编程:Prepare Statement

// crud_test.go
// TestPS demonstrates a prepared statement: prepare once, then execute it
// with arguments. It is skipped when the database cannot be reached.
func TestPS(t *testing.T) {
    db, err := sql.Open("mysql", "root:123456@tcp(127.0.0.1:3307)/mybatis")
    require.NoError(t, err)
    defer db.Close()
    if err = db.Ping(); err != nil {
        t.Skipf("database unavailable: %v", err)
    }
    ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    defer cancel()
    stmt, err := db.PrepareContext(ctx, "SELECT * FROM  `test_model` WHERE `id`=?")
    require.NoError(t, err)
    // A prepared statement keeps resources until closed. In an application
    // it is typically closed at shutdown; deciding when it is safe to close
    // (it may still be in use) needs e.g. reference counting or one-shot
    // statements. Deferring here also covers early exits from require.*.
    defer stmt.Close()

    rows, err := stmt.QueryContext(ctx, 1)
    require.NoError(t, err)
    defer rows.Close()
    for rows.Next() {
        tm := TestModel{}
        err = rows.Scan(&tm.Id, &tm.FirstName, &tm.Age, &tm.LastName)
        require.NoError(t, err)
        log.Println(tm)
    }
    require.NoError(t, rows.Err())

    // An IN clause with a varying number of values cannot reuse one
    // statement. Preparing too many statements can hit the database's limit
    // and fail.
    //stmt, err  = db.PrepareContext(ctx, "SELECT * FROM  `test_model` WHERE `id` in (?,?)")
    //stmt, err  = db.PrepareContext(ctx, "SELECT * FROM  `test_model` WHERE `id` in (?,?,?)")
}

11. SQL 编程:sqlmock 入门、SQL 编程面试要点

// sqlmock_test.go
package sql_demo

import (
    "context"
    "errors"
    "github.com/DATA-DOG/go-sqlmock"
    "github.com/stretchr/testify/require"
    "log"
    "testing"
)

// TestSQLMock shows the basics of go-sqlmock. Each expectation is a regular
// expression, and expectations must be registered in the same order the
// queries are executed.
func TestSQLMock(t *testing.T) {
    db, mock, err := sqlmock.New()
    // Check err before deferring Close: on failure db is nil and
    // db.Close() would panic.
    require.NoError(t, err)
    defer db.Close()

    mockRows := sqlmock.NewRows([]string{"id", "first_name"})
    mockRows.AddRow(1, "Tom")
    // Prepare the expected queries (regular expressions).
    mock.ExpectQuery("SELECT id,first_name FROM `user`.*").WillReturnRows(mockRows)
    mock.ExpectQuery("SELECT id FROM `user`.*").WillReturnError(errors.New("mock error"))

    // Execution order must match the order the expectations were registered.
    rows, err := db.QueryContext(context.Background(),
        "SELECT id,first_name FROM `user` WHERE id=?", 1)
    require.NoError(t, err)
    defer rows.Close()
    for rows.Next() {
        tm := TestModel{}
        err = rows.Scan(&tm.Id, &tm.FirstName)
        require.NoError(t, err)
        log.Println(tm)
    }
    require.NoError(t, rows.Err())

    _, err = db.QueryContext(context.Background(),
        "SELECT id FROM `user` WHERE id=?", 1)
    require.Error(t, err)

    // Fail if any registered expectation was never consumed.
    require.NoError(t, mock.ExpectationsWereMet())
}

结果集

12. 结果集处理:Open 与 OpenDB

// db.go
package orm

import "database/sql"

// DBOption customizes a DB while OpenDB is building it.
type DBOption func(db *DB)

// DB decorates a *sql.DB with the ORM metadata registry.
type DB struct {
    r  *registry
    db *sql.DB
}

// Open opens a *sql.DB with sql.Open and wraps it via OpenDB.
func Open(driver string, dataSourceName string, opts ...DBOption) (*DB, error) {
    sqlDB, err := sql.Open(driver, dataSourceName)
    if err != nil {
        return nil, err
    }
    return OpenDB(sqlDB, opts...)
}

// OpenDB wraps an existing *sql.DB (e.g. one created by sqlmock), creates a
// fresh registry and applies the options in order. The returned error is
// currently always nil.
func OpenDB(db *sql.DB, opts ...DBOption) (*DB, error) {
    res := &DB{db: db, r: newRegistry()}
    for i := range opts {
        opts[i](res)
    }
    return res, nil
}

// MustOpen is like Open but panics on error.
func MustOpen(driver string, dataSourceName string, opts ...DBOption) *DB {
    res, err := Open(driver, dataSourceName, opts...)
    if err == nil {
        return res
    }
    panic(err)
}
// select.go
// Get (work in progress) builds the SELECT and sends it to the underlying
// *sql.DB. Result-set handling is not written yet. As written this excerpt
// does not compile: rows and err are unused, and there is no final return.
func (s *Selector[T]) Get(ctx context.Context) (*T, error) {
    q, err := s.Build()
    if err != nil {
        return nil, err
    }

    db := s.db.db
    // Send the query, then process the result set (still to be written).
    rows, err := db.QueryContext(ctx, q.SQL, q.Args...)
}

// GetMulti (work in progress) builds the SELECT and iterates over the
// result rows. The loop body is still empty and there is no return
// statement, so this excerpt does not compile as-is.
// NOTE(review): the QueryContext error is not checked and rows is never
// closed.
func (s *Selector[T]) GetMulti(ctx context.Context) ([]*T, error) {
    q, err := s.Build()
    if err != nil {
        return nil, err
    }

    db := s.db.db
    // Send the query, then process the result set (still to be written).
    rows, err := db.QueryContext(ctx, q.SQL, q.Args...)
    for rows.Next() {

    }
}

13. 结果集处理:发起查询异常情况

// select_test.go
// TestGet drives Selector.Get through sqlmock. Expectations are consumed in
// registration order by the cases that actually reach the database:
//   - "invalid query" fails while building the SQL (unknown field) and
//     sends nothing;
//   - "query error" receives the driver error;
//   - "no rows" receives an empty result set.
func TestGet(t *testing.T) {
    mockDB, mock, err := sqlmock.New()
    require.NoError(t, err)
    db, err := OpenDB(mockDB)
    require.NoError(t, err)

    // query error
    mock.ExpectQuery("SELECT .*").WillReturnError(errors.New("query error"))
    // no rows: columns but no data
    mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows([]string{
        "id", "first_name", "age", "last_name",
    }))

    cases := []struct {
        name string
        s    *Selector[TestModel]

        wantRes *TestModel
        wantErr error
    }{
        {
            name:    "invalid query",
            s:       NewSelector[TestModel](db).Where(C("xxx").Eq(1)),
            wantErr: errs.NewErrUnknownField("xxx"),
        },
        {
            name:    "query error",
            s:       NewSelector[TestModel](db).Where(C("Id").Eq(1)),
            wantErr: errors.New("query error"),
        },
        {
            name:    "no rows",
            s:       NewSelector[TestModel](db).Where(C("Id").Lt(1)),
            wantErr: errors.New("orm: 没有数据"),
        },
    }

    for _, c := range cases {
        t.Run(c.name, func(t *testing.T) {
            got, err := c.s.Get(context.Background())
            assert.Equal(t, c.wantErr, err)
            if err == nil {
                assert.Equal(t, c.wantRes, got)
            }
        })
    }
}
// select.go
// Get (work in progress) runs the built query and handles the query-error
// and no-rows cases. Mapping the row onto T is not implemented yet.
// As written this excerpt does not compile, because cs and tp are unused.
// rows.Scan() is called with no destinations.
// NOTE(review): rows is never closed here.
func (s *Selector[T]) Get(ctx context.Context) (*T, error) {
    q, err := s.Build()
    if err != nil {
        return nil, err
    }

    db := s.db.db
    // Send the query and process the result set.
    rows, err := db.QueryContext(ctx, q.SQL, q.Args...)
    //_, err  = db.QueryContext(ctx, q.SQL, q.Args...)
    // Query error.
    if err != nil {
        return nil, err
    }
    if !rows.Next() {
        return nil, ErrNoRows
    }

    // Column names returned by the SELECT (one scan target is needed per column).
    cs, err := rows.Columns()
    if err != nil {
        return nil, err
    }

    tp := new(T)

    rows.Scan()

    return nil, nil
}

14. 结果集处理:反射处理结果集

// model.go
// Field is the per-field metadata of a registered model.
type Field struct {
    // Go struct field name
    goName string
    // database column name
    colName string
    // Go type of the field; Get passes it to reflect.New to allocate a scan target
    typ reflect.Type
}
// Register (excerpt): the field loop now records each struct field's Go name
// and Go type alongside its column name. Result-set mapping in Get uses all
// three. The elided parts ("// ...") are unchanged from the previous version.
func (r *registry) Register(entity any, opts ...ModelOpt) (*Model, error) {
    // ...

    fieldMap := make(map[string]*Field, numField)
    for i := 0; i < numField; i++ {
        // ...
        fieldMap[fd.Name] = &Field{
            goName: fd.Name,
            //colName: underscoreName(fd.Name),
            colName: colName,
            typ:     fd.Type,
        }
    }
    // ...
}
// select.go
// Get runs the built SELECT and maps the first row onto a new *T.
//
// Columns are matched to struct fields by column name (Field.colName). Each
// matched column is scanned into a freshly allocated value of the field's
// type, then copied into the result through reflection. A column with no
// matching field is scanned into a throwaway value and ignored, so the scan
// destinations always stay aligned with the column list.
//
// Get returns ErrNoRows when the query yields no rows. It propagates errors
// from Build, QueryContext, row iteration (rows.Err), Columns and Scan.
func (s *Selector[T]) Get(ctx context.Context) (*T, error) {
    q, err := s.Build()
    if err != nil {
        return nil, err
    }

    db := s.db.db
    // Send the query and process the result set.
    rows, err := db.QueryContext(ctx, q.SQL, q.Args...)
    if err != nil {
        return nil, err
    }
    // Rows holds its connection until closed or fully iterated. Get reads
    // only one row, so it must close explicitly to release the connection.
    defer rows.Close()

    if !rows.Next() {
        // Next also returns false on an iteration error; report that first.
        if err := rows.Err(); err != nil {
            return nil, err
        }
        return nil, ErrNoRows
    }

    // Column names returned by the SELECT.
    cs, err := rows.Columns()
    if err != nil {
        return nil, err
    }

    // Resolve every column to its field once. Build exactly one scan
    // destination per column, so vals[i] always corresponds to cs[i].
    fds := make([]*Field, len(cs))
    vals := make([]any, len(cs))
    for i, c := range cs {
        for _, fd := range s.model.fields {
            if fd.colName == c {
                fds[i] = fd
                break
            }
        }
        if fds[i] == nil {
            // Unknown column: scan it into a discard value to keep alignment.
            vals[i] = new(any)
            continue
        }
        // reflect.New yields a pointer to a new zero value of the field type.
        vals[i] = reflect.New(fds[i].typ).Interface()
    }

    if err := rows.Scan(vals...); err != nil {
        return nil, err
    }

    // Copy the scanned values into the result.
    tp := new(T)
    tpValue := reflect.ValueOf(tp).Elem()
    for i, fd := range fds {
        if fd == nil {
            continue
        }
        tpValue.FieldByName(fd.goName).Set(reflect.ValueOf(vals[i]).Elem())
    }
    return tp, nil
}
// select_test.go
// TestGet drives Selector.Get through sqlmock. Expectations are consumed in
// registration order by the cases that reach the database:
//   - "query error" receives the driver error;
//   - "no rows" receives an empty result set;
//   - "data" receives one row whose string values Scan converts into the
//     TestModel field types.
// "invalid query" fails while building the SQL and sends nothing.
func TestGet(t *testing.T) {
    mockDB, mock, err := sqlmock.New()
    require.NoError(t, err)
    db, err := OpenDB(mockDB)
    require.NoError(t, err)

    cols := []string{"id", "first_name", "age", "last_name"}
    // query error
    mock.ExpectQuery("SELECT .*").WillReturnError(errors.New("query error"))
    // no rows
    mock.ExpectQuery("SELECT .*").WillReturnRows(sqlmock.NewRows(cols))
    // data
    dataRows := sqlmock.NewRows(cols)
    dataRows.AddRow("1", "Tom", "18", "Jerry")
    mock.ExpectQuery("SELECT .*").WillReturnRows(dataRows)

    cases := []struct {
        name string
        s    *Selector[TestModel]

        wantRes *TestModel
        wantErr error
    }{
        {
            name:    "invalid query",
            s:       NewSelector[TestModel](db).Where(C("xxx").Eq(1)),
            wantErr: errs.NewErrUnknownField("xxx"),
        },
        {
            name:    "query error",
            s:       NewSelector[TestModel](db).Where(C("Id").Eq(1)),
            wantErr: errors.New("query error"),
        },
        {
            name:    "no rows",
            s:       NewSelector[TestModel](db).Where(C("Id").Lt(1)),
            wantErr: errors.New("orm: 没有数据"),
        },
        {
            name: "data",
            s:    NewSelector[TestModel](db).Where(C("Id").Lt(1)),
            wantRes: &TestModel{
                Id:        1,
                FirstName: "Tom",
                Age:       18,
                LastName:  &sql.NullString{Valid: true, String: "Jerry"},
            },
        },
    }

    for _, c := range cases {
        t.Run(c.name, func(t *testing.T) {
            got, err := c.s.Get(context.Background())
            assert.Equal(t, c.wantErr, err)
            if err == nil {
                assert.Equal(t, c.wantRes, got)
            }
        })
    }
}

15. 结果集处理:代码优化与总结

// model.go
// Model is the parsed metadata of an entity, indexed two ways.
type Model struct {
    tableName string
    // Go field name -> field definition
    fieldMap map[string]*Field
    // column name -> field definition (presumably so result-set mapping can
    // look a column up directly instead of scanning every field)
    columnMap map[string]*Field
}

// Register (excerpt; the elided part resolves typ, elemType and numField as
// in the previous version) parses entity into a *Model with two indexes:
//   - fieldMap: Go field name -> Field;
//   - columnMap: column name -> Field.
// Options run after both maps are built. An option may rename a column (for
// example a column-name option editing Field.colName), so columnMap is
// rebuilt from fieldMap once all options have run. This keeps the two
// indexes consistent. An error from tag parsing or from any option aborts
// registration, and nothing is cached.
func (r *registry) Register(entity any, opts ...ModelOpt) (*Model, error) {
    // ...

    fieldMap := make(map[string]*Field, numField)
    columnMap := make(map[string]*Field, numField)
    for i := 0; i < numField; i++ {
        fd := elemType.Field(i)
        // Parse the tag: `orm:"key=value,..."`.
        pair, err := r.parseTag(fd.Tag)
        if err != nil {
            return nil, err
        }
        // Column override from `orm:"column=xxx"`.
        colName := pair[tagKeyColumn]
        if colName == "" {
            // Not set by the user: derive it from the field name.
            colName = underscoreName(fd.Name)
        }
        fdData := &Field{
            goName:  fd.Name,
            colName: colName,
            typ:     fd.Type,
        }
        fieldMap[fd.Name] = fdData
        columnMap[colName] = fdData
    }

    var tableName string
    if tbl, ok := entity.(TableName); ok {
        tableName = tbl.TableName()
    }
    if tableName == "" {
        tableName = underscoreName(elemType.Name())
    }

    res := &Model{
        tableName: tableName,
        fieldMap:  fieldMap,
        columnMap: columnMap,
    }

    for _, opt := range opts {
        if err := opt(res); err != nil {
            return nil, err
        }
    }

    // Re-derive columnMap from fieldMap so a column renamed by an option is
    // indexed under its new name.
    // NOTE(review): assumes options such as ModelWithColumnName update
    // Field.colName through fieldMap (their bodies are not shown here). The
    // rebuild is harmless if they already maintain columnMap.
    res.columnMap = make(map[string]*Field, len(res.fieldMap))
    for _, fd := range res.fieldMap {
        res.columnMap[fd.colName] = fd
    }

    r.models.Store(typ, res)
    return res, nil
}
// model_test.go
package orm

import (
    "database/sql"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
    "my_framework/orm/internal/errs"
    "reflect"
    "testing"
)

// Test_parseModel checks that Register rejects non-pointer entities and
// builds the expected table name and field/column maps for *TestModel.
func Test_parseModel(t *testing.T) {
    tests := []struct {
        name string

        entity    any
        wantModel *Model
        // fields is expanded into wantModel.fieldMap / columnMap below.
        fields  []*Field
        wantErr error
    }{
        {
            // Register only accepts pointers to structs.
            name:    "struct",
            entity:  TestModel{},
            wantErr: errs.ErrPointerOnly,
        },
        {
            name:   "pointer",
            entity: &TestModel{},
            wantModel: &Model{
                tableName: "test_model",
            },
            fields: []*Field{
                {
                    colName: "id",
                    goName:  "Id",
                    typ:     reflect.TypeOf(int64(0)),
                },
                {
                    colName: "first_name",
                    goName:  "FirstName",
                    typ:     reflect.TypeOf(""),
                },
                {
                    colName: "last_name",
                    goName:  "LastName",
                    typ:     reflect.TypeOf(&sql.NullString{}),
                },
                {
                    colName: "age",
                    goName:  "Age",
                    typ:     reflect.TypeOf(int8(0)),
                },
            },
        },
    }

    // Use the constructor, like the other registry tests, so the registry is
    // initialised the same way production code gets it.
    r := newRegistry()
    for _, tc := range tests {
        t.Run(tc.name, func(t *testing.T) {
            m, err := r.Register(tc.entity)
            assert.Equal(t, tc.wantErr, err)
            if err != nil {
                return
            }
            fieldMap := make(map[string]*Field, len(tc.fields))
            columnMap := make(map[string]*Field, len(tc.fields))
            for _, f := range tc.fields {
                fieldMap[f.goName] = f
                columnMap[f.colName] = f
            }
            tc.wantModel.fieldMap = fieldMap
            tc.wantModel.columnMap = columnMap
            assert.Equal(t, tc.wantModel, m)
        })
    }
}

// TestRegistry_Get checks that Get builds the expected Model for each entity
// (tag-based column names, TableName overrides) and caches it in r.models
// under the entity's reflect.Type.
func TestRegistry_Get(t *testing.T) {
    testCases := []struct {
        name string

        entity    any
        wantModel *Model
        // fields is expanded into wantModel.fieldMap/columnMap inside the loop.
        fields    []*Field
        wantErr   error
    }{
        {
            name:   "pointer",
            entity: &TestModel{},
            wantModel: &Model{
                tableName: "test_model",
            },
            fields: []*Field{
                {
                    colName: "id",
                    goName:  "Id",
                    typ:     reflect.TypeOf(int64(0)),
                },
                {
                    colName: "first_name",
                    goName:  "FirstName",
                    typ:     reflect.TypeOf(""),
                },
                {
                    colName: "last_name",
                    goName:  "LastName",
                    typ:     reflect.TypeOf(&sql.NullString{}),
                },
                {
                    colName: "age",
                    goName:  "Age",
                    typ:     reflect.TypeOf(int8(0)),
                },
            },
        },
        {
            // `orm:"column=..."` overrides the derived column name.
            name: "tag",
            entity: func() any {
                type TagTable struct {
                    FirstName string `orm:"column=first_name_t"`
                }
                return &TagTable{}
            }(),
            fields: []*Field{
                {
                    colName: "first_name_t",
                    goName:  "FirstName",
                    typ:     reflect.TypeOf(""),
                },
            },
            wantModel: &Model{
                tableName: "tag_table",
            },
        },
        {
            // An empty column value falls back to the underscored field name.
            name: "empty tag",
            entity: func() any {
                type TagTable struct {
                    FirstName string `orm:"column="`
                }
                return &TagTable{}
            }(),
            fields: []*Field{{
				colName: "first_name",
				goName:  "FirstName",
				typ:     reflect.TypeOf(""),
			},
			},
			wantModel: &Model{
				tableName: "tag_table",
			},
		},
		{
			// A bare "column" key without "=value" is rejected.
			name: "column only",
			entity: func() any {
				type TagTable struct {
					FirstName string `orm:"column"`
				}
				return &TagTable{}
			}(),
			wantErr: errs.NewErrInvalidTagContent("column"),
		},
		{
			// Unrecognised tag keys are ignored; defaults apply.
			name: "invalid tag",
			entity: func() any {
				type TagTable struct {
					FirstName string `orm:"abc=abc"`
				}
				return &TagTable{}
			}(),
			fields: []*Field{
				{
					colName: "first_name",
					goName:  "FirstName",
					typ:     reflect.TypeOf(""),
				},
			},
			wantModel: &Model{
				tableName: "tag_table",
			},
		},
		{
			// TableName() declared on the value receiver.
			name:   "custom table name",
			entity: &CustomTableName{},
			fields: []*Field{
				{
					colName: "first_name",
					goName:  "FirstName",
					typ:     reflect.TypeOf(""),
				},
			},
			wantModel: &Model{
				tableName: "custom_table_name_t",
			},
		},
		{
			// TableName() declared on the pointer receiver.
			name:   "custom table name ptr",
			entity: &CustomTableNamePtr{},
			fields: []*Field{
				{
					colName: "first_name",
					goName:  "FirstName",
					typ:     reflect.TypeOf(""),
				},
			},
			wantModel: &Model{
				tableName: "custom_table_name_ptr_t",
			},
		},
		{
			// TableName() returning "" falls back to the underscored type name.
			name:   "custom table name empty ptr",
			entity: &CustomTableNameEmpty{},
			fields: []*Field{
				{
					colName: "first_name",
					goName:  "FirstName",
					typ:     reflect.TypeOf(""),
				},
			},
			wantModel: &Model{
				tableName: "custom_table_name_empty",
			},
		},
	}

	r := newRegistry()
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			m, err := r.Get(tc.entity)
			assert.Equal(t, tc.wantErr, err)
			if err != nil {
				return
			}

			// Build the expected maps from the flat field list.
			fieldMap := make(map[string]*Field)
			columnMap := make(map[string]*Field)
			for _, f := range tc.fields {
				fieldMap[f.goName] = f
				columnMap[f.colName] = f
			}
			tc.wantModel.fieldMap = fieldMap
			tc.wantModel.columnMap = columnMap

			assert.Equal(t, tc.wantModel, m)

			// Get must also have cached the model under the entity's type.
			typ := reflect.TypeOf(tc.entity)
			cache, ok := r.models.Load(typ)
			assert.True(t, ok)
			assert.Equal(t, tc.wantModel, cache)
		})
	}
}

// CustomTableName overrides its table name through a value-receiver
// TableName method, so both CustomTableName and *CustomTableName have it.
type CustomTableName struct {
	FirstName string
}

// TableName reports the fixed table name for CustomTableName.
func (CustomTableName) TableName() string {
	const name = "custom_table_name_t"
	return name
}

// CustomTableNamePtr overrides its table name through a pointer-receiver
// TableName method, so only *CustomTableNamePtr has it.
type CustomTableNamePtr struct {
	FirstName string
}

// TableName reports the fixed table name for CustomTableNamePtr.
func (*CustomTableNamePtr) TableName() string {
	const name = "custom_table_name_ptr_t"
	return name
}

// CustomTableNameEmpty implements TableName but returns an empty name,
// which exercises the registry's fallback to the derived table name.
type CustomTableNameEmpty struct {
	FirstName string
}

// TableName always reports an empty table name.
func (CustomTableNameEmpty) TableName() string {
	var empty string
	return empty
}

// TestModelWithTableName checks that the ModelWithTableName option replaces
// the table name derived from the struct.
func TestModelWithTableName(t *testing.T) {
	const want = "test_model_ttt"
	m, err := newRegistry().Register(&TestModel{}, ModelWithTableName(want))
	require.NoError(t, err)
	assert.Equal(t, want, m.tableName)
}

// TestModelWithColName checks that ModelWithColumnName renames a known field's
// column and reports an unknown-field error for a field that does not exist.
func TestModelWithColName(t *testing.T) {
	cases := []struct {
		name        string
		field       string
		colName     string
		wantColName string
		wantErr     error
	}{
		{name: "column name", field: "FirstName", colName: "first_name_ccc", wantColName: "first_name_ccc"},
		{name: "invalid column name", field: "XXX", colName: "first_name_ccc", wantErr: errs.NewErrUnknownField("XXX")},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			// A fresh registry per case keeps options from leaking between cases.
			m, err := newRegistry().Register(&TestModel{}, ModelWithColumnName(c.field, c.colName))
			assert.Equal(t, c.wantErr, err)
			if err != nil {
				return
			}
			fd, found := m.fieldMap[c.field]
			require.True(t, found)
			assert.Equal(t, c.wantColName, fd.colName)
		})
	}
}
```

```go
// select.go
// Get (fragment: query execution is elided as "// ...") scans the current row
// into a new T using reflection. Each column is scanned into a freshly
// allocated value of the field's type, then copied into the matching field.
// NOTE(review): cs, rows and err presumably come from the elided part
// (rows.Columns() after QueryContext/rows.Next) — confirm against full source.
func (s *Selector[T]) Get(ctx context.Context) (*T, error) {
	// ...

	tp := new(T)

	vals := make([]any, 0, len(cs))
	valElems := make([]reflect.Value, 0, len(cs))
	for _, c := range cs {
		// c is a column name from the result set.
		fd, ok := s.model.columnMap[c]
		if !ok {
			return nil, errs.NewErrUnknownColumn(c)
		}
		// reflect.New yields a pointer to a new zero value of the field's
		// type, which is what rows.Scan needs as a destination.
		val := reflect.New(fd.typ)
		vals = append(vals, val.Interface())
		valElems = append(valElems, val.Elem())
	}

	err = rows.Scan(vals...)
	if err != nil {
		return nil, err
	}

	// Copy the scanned values into tp's fields.
	tpValue := reflect.ValueOf(tp)
	for i, c := range cs {
		// This lookup repeats the first loop's, which already validated c.
		fd, ok := s.model.columnMap[c]
		if !ok {
			return nil, errs.NewErrUnknownColumn(c)
		}
		tpValue.Elem().FieldByName(fd.goName).
			//Set(reflect.ValueOf(vals[i]).Elem())
			Set(valElems[i])
	}

	return tp, nil
}
```

```go
// select_test.go
package orm

import (
	"context"
	"database/sql"
	"errors"
	"github.com/DATA-DOG/go-sqlmock"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"my_framework/orm/internal/errs"
	"testing"
)

// TestSelector_Build checks the SQL and arguments Build produces for table
// selection, WHERE predicates (Eq/Not/And/Or) and unknown fields.
// Every subtest needs a distinct name so `go test -run` can target it.
func TestSelector_Build(t *testing.T) {
	db, err := Open("mysql", "root:123456@tcp(127.0.0.1:3307)/mybatis")
	require.NoError(t, err)
	testCases := []struct {
		name string

		builder QueryBuilder

		wantQuery *Query
		wantErr   error
	}{
		{
			name:    "not from",
			builder: NewSelector[TestModel](db),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model`;",
				Args: nil,
			},
		},
		{
			name:    "from",
			builder: (NewSelector[TestModel](db)).Table("`test_model`"),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model`;",
				Args: nil,
			},
		},
		{
			name:    "empty from",
			builder: (NewSelector[TestModel](db)),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model`;",
				Args: nil,
			},
		},
		{
			// An empty table name falls back to the model's table.
			name:    "empty table",
			builder: (NewSelector[TestModel](db)).Table(""),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model`;",
				Args: nil,
			},
		},
		{
			name:    "db table",
			builder: (NewSelector[TestModel](db)).Table("`mybatis`.`test`"),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `mybatis`.`test`;",
				Args: nil,
			},
		},
		{
			name:    "where",
			builder: (NewSelector[TestModel](db)).Where(C("Age").Eq(18)),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model` WHERE `age` = ?;",
				Args: []any{18},
			},
		},
		{
			name:    "empty where",
			builder: (NewSelector[TestModel](db)).Where(),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model`;",
				Args: nil,
			},
		},
		{
			name:    "not",
			builder: (NewSelector[TestModel](db)).Where(Not(C("Age").Eq(18))),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model` WHERE  NOT (`age` = ?);",
				Args: []any{18},
			},
		},
		{
			name:    "and",
			builder: (NewSelector[TestModel](db)).Where(C("Age").Eq(18).And(C("FirstName").Eq("Tom"))),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model` WHERE (`age` = ?) AND (`first_name` = ?);",
				Args: []any{18, "Tom"},
			},
		},
		{
			name:    "or",
			builder: (NewSelector[TestModel](db)).Where(C("Age").Eq(18).Or(C("FirstName").Eq("Tom"))),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model` WHERE (`age` = ?) OR (`first_name` = ?);",
				Args: []any{18, "Tom"},
			},
		},
		{
			name:    "invalid column",
			builder: (NewSelector[TestModel](db)).Where(C("Age").Eq(18).Or(C("XXXX").Eq("Tom"))),
			wantErr: errs.NewErrUnknownField("XXXX"),
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			q, err := tc.builder.Build()
			assert.Equal(t, tc.wantErr, err)
			if err != nil {
				return
			}
			assert.Equal(t, tc.wantQuery, q)
		})
	}
}

// TestGet drives Selector.Get against sqlmock. The query expectations below
// are consumed in registration order by the subtests that reach the database.
// NOTE(review): "invalid query" presumably fails in Build before any query is
// sent, so it should consume no expectation — confirm.
func TestGet(t *testing.T) {
	mockDB, mock, err := sqlmock.New()
	require.NoError(t, err)
	db, err := OpenDB(mockDB)
	require.NoError(t, err)

	// query error
	mock.ExpectQuery("SELECT .*").WillReturnError(errors.New("query error"))
	// no rows
	rows := sqlmock.NewRows([]string{
		"id", "first_name", "age", "last_name",
	})
	mock.ExpectQuery("SELECT .*").WillReturnRows(rows)
	// data
	rows = sqlmock.NewRows([]string{
		"id", "first_name", "age", "last_name",
	})
	rows.AddRow("1", "Tom", "18", "Jerry")
	mock.ExpectQuery("SELECT .*").WillReturnRows(rows)
	// scan error: "abc" cannot be scanned into the int64 id field.
	// NOTE(review): the matching "scan error" case below is commented out, so
	// this expectation is never consumed. That case also expects
	// errs.ErrNoRows, which a Scan failure would not return — fix before
	// re-enabling it.
	rows = sqlmock.NewRows([]string{
		"id", "first_name", "age", "last_name",
	})
	rows.AddRow("abc", "Tom", "18", "Jerry")
	mock.ExpectQuery("SELECT .*").WillReturnRows(rows)

	testCases := []struct {
		name string
		s    *Selector[TestModel]

		wantRes *TestModel
		wantErr error
	}{
		{
			name:    "invalid query",
			s:       NewSelector[TestModel](db).Where(C("xxx").Eq(1)),
			wantErr: errs.NewErrUnknownField("xxx"),
		},
		{
			name:    "query error",
			s:       NewSelector[TestModel](db).Where(C("Id").Eq(1)),
			wantErr: errors.New("query error"),
		},
		{
			name:    "no rows",
			s:       NewSelector[TestModel](db).Where(C("Id").Lt(1)),
			wantErr: errors.New("orm: 没有数据"),
		},
		{
			name:    "data",
			s:       NewSelector[TestModel](db).Where(C("Id").Lt(1)),
			wantRes: &TestModel{Id: 1, FirstName: "Tom", Age: 18, LastName: &sql.NullString{Valid: true, String: "Jerry"}},
        },
        //{
        //    name:    "scan error",
        //    s:       NewSelector[TestModel](db).Where(C("Id").Lt(1)),
        //    wantErr: errs.ErrNoRows,
        //},
    }

    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            res, err := tc.s.Get(context.Background())
            assert.Equal(t, tc.wantErr, err)
            if err != nil {
                return
            }
            assert.Equal(t, tc.wantRes, res)
        })
    }
}

// TestModel is the entity used by the selector tests; it maps to table
// `test_model` with columns id, first_name, age and last_name.
type TestModel struct {
    Id        int64
    FirstName string
    Age       int8
    // LastName is a *sql.NullString so a NULL column is distinguishable
    // from an empty string (see Valid).
    LastName  *sql.NullString
}

16. 加餐:Option 设计模式

package other

import "errors"

// MyStructOption mutates a MyStruct under construction and cannot fail.
type MyStructOption func(myStruct *MyStruct)

// MyStructOptionErr is the fallible flavour of MyStructOption: a non-nil
// error aborts construction in NewMyStructV1.
type MyStructOptionErr func(myStruct *MyStruct) error

// MyStruct demonstrates the functional-options pattern: required fields are
// positional constructor arguments, everything else is set via options.
type MyStruct struct {
    // Required: always passed to the constructor.
    id   uint64
    name string

    // Optional: set through options.
    address string
    // ...

    field1 int
    field2 int
}

// WithField1And2 sets field1 and field2 in one option.
func WithField1And2(field1 int, field2 int) MyStructOption {
    return func(s *MyStruct) {
        s.field1, s.field2 = field1, field2
    }
}

// WithAddress sets the optional address.
func WithAddress(address string) MyStructOption {
    setAddress := func(s *MyStruct) {
        s.address = address
    }
    return setAddress
}

// WithField1And2V1 is the error-returning form of WithField1And2; it never fails.
func WithField1And2V1(field1 int, field2 int) MyStructOptionErr {
    return func(s *MyStruct) error {
        s.field1, s.field2 = field1, field2
        return nil
    }
}

// WithAddressV1 sets the address, reporting an error for an empty one.
func WithAddressV1(address string) MyStructOptionErr {
    return func(s *MyStruct) error {
        if address != "" {
            s.address = address
            return nil
        }
        return errors.New("地址不能为空")
    }
}

// WithAddressV2 sets the address and panics on an empty one; an alternative
// to the error-returning variant when the caller has no error path.
func WithAddressV2(address string) MyStructOption {
    return func(s *MyStruct) {
        if len(address) == 0 {
            panic(errors.New("地址不能为空"))
        }
        s.address = address
    }
}

// NewMyStruct builds a MyStruct from its required fields, then applies opts
// in order.
func NewMyStruct(id uint64, name string, opts ...MyStructOption) *MyStruct {
    s := &MyStruct{id: id, name: name}
    for i := range opts {
        opts[i](s)
    }
    return s
}

// NewMyStructV1 is NewMyStruct for fallible options: the first option that
// returns an error stops construction and that error is returned.
func NewMyStructV1(id uint64, name string, opts ...MyStructOptionErr) (*MyStruct, error) {
    s := &MyStruct{id: id, name: name}
    for i := range opts {
        err := opts[i](s)
        if err != nil {
            return nil, err
        }
    }
    return s, nil
}

12 第六周:ORM 框架之结果集处理、SELECT 进阶与 INSERT

结果集处理

1. 结果集处理:unsafe 入门



// accessor.go
package unsafe

import (
    "errors"
    "reflect"
    "unsafe"
)

// UnsafeAccessor reads and writes the fields of one struct instance by
// computing each field's address as start address + field offset, instead of
// going through reflect.Value.Field.
type UnsafeAccessor struct {
    // fields maps a Go field name to its offset and type.
    fields map[string]FieldMeta
    // address is the start address of the struct given to NewUnsafeAccessor.
    address unsafe.Pointer
}

// NewUnsafeAccessor builds an accessor for entity, which must be a pointer
// to a struct; other inputs panic inside reflect (Elem/NumField/UnsafePointer
// on the wrong kind).
func NewUnsafeAccessor(entity any) *UnsafeAccessor {
    typ := reflect.TypeOf(entity).Elem()
    numField := typ.NumField()
    fields := make(map[string]FieldMeta, numField)
    for i := 0; i < numField; i++ {
        fd := typ.Field(i)
        fields[fd.Name] = FieldMeta{
            Offset: fd.Offset,
            typ:    fd.Type,
        }
    }
    return &UnsafeAccessor{
        fields:  fields,
        address: reflect.ValueOf(entity).UnsafePointer(),
    }
}

// Field returns the current value of the named field, or an error when the
// struct has no such field.
func (a *UnsafeAccessor) Field(field string) (any, error) {
    fd, ok := a.fields[field]
    if !ok {
        return nil, errors.New("非法字段")
    }
    // With a statically known type this could be *(*int)(ptr); with only a
    // reflect.Type, reflect.NewAt builds a typed pointer at that address.
    fdAddress := unsafe.Add(a.address, fd.Offset)
    return reflect.NewAt(fd.typ, fdAddress).Elem().Interface(), nil
}

// SetField stores val into the named field. It returns an error instead of
// panicking inside reflect when the field does not exist or val's type is not
// assignable to the field's type. A nil val zeroes fields of nillable kinds
// (pointer, interface, map, slice, chan, func, unsafe.Pointer).
func (a *UnsafeAccessor) SetField(field string, val any) error {
    fd, ok := a.fields[field]
    if !ok {
        return errors.New("非法字段")
    }
    v := reflect.ValueOf(val)
    if !v.IsValid() {
        // reflect.ValueOf(nil) is the zero Value, which Set rejects.
        switch fd.typ.Kind() {
        case reflect.Pointer, reflect.Interface, reflect.Map, reflect.Slice,
            reflect.Chan, reflect.Func, reflect.UnsafePointer:
            v = reflect.Zero(fd.typ)
        default:
            return errors.New("字段类型不匹配")
        }
    } else if !v.Type().AssignableTo(fd.typ) {
        return errors.New("字段类型不匹配")
    }
    fdAddress := unsafe.Add(a.address, fd.Offset)
    reflect.NewAt(fd.typ, fdAddress).Elem().Set(v)
    return nil
}

// FieldMeta describes one struct field: its byte offset from the start of
// the struct and its type.
type FieldMeta struct {
    Offset uintptr
    typ    reflect.Type
}
// accessor_test.go
package unsafe

import (
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
    "testing"
)

// TestUnsafeAccessor_Field reads a field through the accessor, writes a new
// value, and checks the write reached the original struct and is visible to
// a subsequent read. It also checks unknown fields are rejected.
func TestUnsafeAccessor_Field(t *testing.T) {
    type User struct {
        Name string
        Age  int
    }

    u := &User{
        Name: "Tom",
        Age:  18,
    }
    accessor := NewUnsafeAccessor(u)
    field, err := accessor.Field("Age")
    require.NoError(t, err)
    assert.Equal(t, 18, field)

    err = accessor.SetField("Age", 19)
    require.NoError(t, err)
    // The accessor writes in place: the struct and a fresh read both see it.
    assert.Equal(t, 19, u.Age)
    field, err = accessor.Field("Age")
    require.NoError(t, err)
    assert.Equal(t, 19, field)

    _, err = accessor.Field("Unknown")
    assert.Error(t, err)
}
// iterate_fields_test.go
package unsafe

import "testing"

// TestPrintFieldOffset prints the field offsets of the layout samples so they
// can be compared by eye with the comments on User and UserV1.
func TestPrintFieldOffset(t *testing.T) {
    type sample struct {
        name   string
        entity any
    }
    samples := []sample{
        {name: "user", entity: User{}},
        {name: "user v1", entity: UserV1{}},
    }
    for _, s := range samples {
        entity := s.entity
        t.Run(s.name, func(t *testing.T) {
            PrintFieldOffset(entity)
        })
    }
}

// User is a layout sample for TestPrintFieldOffset. The comments give each
// field's byte offset, assuming a 64-bit platform (16-byte string header,
// 24-byte slice header, 8-byte alignment).
type User struct {
    // 0
    Name    string
    // 16
    Age     int32
    // 24: Age ends at 20, so 4 bytes of padding align the slice header.
    Alias   []string
    // 48
    Address string
}

// UserV1 is User with an extra int32 after Age. Per the offsets below,
// AgeV1 occupies the 4 bytes that were padding in User, so Alias and Address
// keep the same offsets (64-bit platform assumed).
type UserV1 struct {
    // 0
    Name    string
    // 16
    Age     int32
    // 20
    AgeV1     int32
    // 24
    Alias   []string
    // 48
    Address string
}
// iterate_fields.go
package unsafe

import "reflect"

// PrintFieldOffset prints the byte offset of each field of entity, one per
// line, via the builtin println (standard error). entity may be a struct or a
// (possibly multi-level) pointer to one; nil and non-struct inputs print
// nothing instead of panicking inside reflect.
func PrintFieldOffset(entity any) {
    typ := reflect.TypeOf(entity)
    // Follow pointers so &User{} works like User{}.
    for typ != nil && typ.Kind() == reflect.Pointer {
        typ = typ.Elem()
    }
    if typ == nil || typ.Kind() != reflect.Struct {
        return
    }
    numField := typ.NumField()
    for i := 0; i < numField; i++ {
        println(typ.Field(i).Offset)
    }
}

2. 结果集处理:unsafe 实现

// select.go
// GetV1 runs the built query and scans the first row into a new T by pointing
// rows.Scan directly at each field's memory (start address + field offset).
// It returns ErrNoRows when the result set is empty, the iteration error when
// reading the first row failed, and never a partially filled *T.
func (s *Selector[T]) GetV1(ctx context.Context) (*T, error) {
    q, err := s.Build()
    if err != nil {
        return nil, err
    }

    db := s.db.db
    rows, err := db.QueryContext(ctx, q.SQL, q.Args...)
    if err != nil {
        return nil, err
    }
    // Always release the connection back to the pool.
    defer rows.Close()
    if !rows.Next() {
        // Next also returns false on iteration errors; surface those rather
        // than masking them as "no rows".
        if err := rows.Err(); err != nil {
            return nil, err
        }
        return nil, ErrNoRows
    }

    cs, err := rows.Columns()
    if err != nil {
        return nil, err
    }

    tp := new(T)
    // Start address of the new entity.
    address := reflect.ValueOf(tp).UnsafePointer()
    vals := make([]any, 0, len(cs))
    for _, c := range cs {
        // c is a column name from the result set.
        fd, ok := s.model.columnMap[c]
        if !ok {
            return nil, errs.NewErrUnknownColumn(c)
        }
        // reflect.NewAt gives a *FieldType aimed at the field itself, so Scan
        // writes straight into tp.
        fdAddress := unsafe.Add(address, fd.offset)
        vals = append(vals, reflect.NewAt(fd.typ, fdAddress).Interface())
    }

    if err := rows.Scan(vals...); err != nil {
        return nil, err
    }
    return tp, nil
}
// model.go
// Register (fragment: validation and the table-name/options tail are elided
// as "// ...") builds the field/column maps for entity. Compared with the
// earlier version it also records each field's byte offset, which the
// unsafe-based result mapping needs.
func (r *registry) Register(entity any, opts ...ModelOpt) (*Model, error) {
    // ...

    fieldMap := make(map[string]*Field, numField)
    columnMap := make(map[string]*Field, numField)
    for i := 0; i < numField; i++ {
        fd := elemType.Field(i)
        // Parse the struct tag, e.g. `orm:"key=value"`.
        pair, err := r.parseTag(fd.Tag)
        if err != nil {
            return nil, err
        }
        // Column name from the tag, e.g. `orm:"column=xxx"`.
        colName := pair[tagKeyColumn]
        if colName == "" {
            // No column set by the user: derive it from the Go field name.
            colName = underscoreName(fd.Name)
        }
        fdData := &Field{
            goName: fd.Name,
            //colName: underscoreName(fd.Name),
            colName: colName,
            typ:     fd.Type,
            // Byte offset of the field within the struct (reflect.StructField.Offset).
            offset:  fd.Offset,
        }
        fieldMap[fd.Name] = fdData
        columnMap[colName] = fdData
    }

    // ...
}

3. 结果集处理:valuer 重构与基准测试

// select.go
// Get runs the built query and maps the first row onto a new T using the
// DB's configured valuer (s.db.creator). It returns ErrNoRows when the result
// set is empty, the iteration error when reading the first row failed, and
// never a partially filled *T.
func (s *Selector[T]) Get(ctx context.Context) (*T, error) {
    q, err := s.Build()
    if err != nil {
        return nil, err
    }

    db := s.db.db
    rows, err := db.QueryContext(ctx, q.SQL, q.Args...)
    if err != nil {
        return nil, err
    }
    // Always release the connection back to the pool.
    defer rows.Close()
    if !rows.Next() {
        // Next also returns false on iteration errors; surface those rather
        // than masking them as "no rows".
        if err := rows.Err(); err != nil {
            return nil, err
        }
        return nil, ErrNoRows
    }

    tp := new(T)
    val := s.db.creator(s.model, tp)
    if err := val.SetColumns(rows); err != nil {
        return nil, err
    }
    return tp, nil
}
// db.go
// DB wraps a *sql.DB together with the ORM's model registry and the strategy
// used to map result rows onto entities.
type DB struct {
    // r parses and caches entity metadata.
    r       model.Registry
    db      *sql.DB
    // creator builds the Value that fills an entity from a row; OpenDB
    // defaults it to valuer.NewUnsafeValue.
    creator valuer.Creator
}

// DBUseReflect is a DBOption that switches result mapping to the
// reflection-based valuer.
func DBUseReflect() DBOption {
    useReflect := func(d *DB) {
        d.creator = valuer.NewReflectValue
    }
    return useReflect
}

// DBUseUnsafe is a DBOption that selects the unsafe (offset-based) valuer,
// which is also what OpenDB installs when no option overrides it.
func DBUseUnsafe() DBOption {
    useUnsafe := func(d *DB) {
        d.creator = valuer.NewUnsafeValue
    }
    return useUnsafe
}

// Open opens a database with the given driver and data source name and wraps
// it in a DB configured by opts. Like sql.Open, it does not establish a
// connection; connectivity errors surface on first use.
func Open(driver string, dataSourceName string, opts ...DBOption) (*DB, error) {
    db, err := sql.Open(driver, dataSourceName)
    if err != nil {
        return nil, err
    }
    return OpenDB(db, opts...)
}

// OpenDB wraps an existing *sql.DB. It starts from a fresh model registry and
// the unsafe-based valuer, then applies opts in order so they can override
// those defaults. The error result is currently always nil.
func OpenDB(db *sql.DB, opts ...DBOption) (*DB, error) {
    res := &DB{
        db:      db,
        r:       model.NewRegistry(),
        creator: valuer.NewUnsafeValue,
    }
    for i := range opts {
        opts[i](res)
    }
    return res, nil
}
// value.go
package valuer

import (
    "database/sql"
    "my_framework/orm/model"
)

// Value fills one entity from the current row of a result set.
type Value interface {
    // SetColumns scans the row rows is positioned on (the caller has already
    // called rows.Next) into the bound entity.
    SetColumns(rows *sql.Rows) error
}

// ValuerV1 is an alternative design in which the entity is passed on every
// call instead of being bound when the Value is created.
type ValuerV1 interface {
    SetColumns(entity any, rows *sql.Rows) error
}

// Creator builds a Value bound to entity, expected to be a pointer to the
// struct that model describes.
type Creator func(model *model.Model, entity any) Value
// unsafe.go
package valuer

import (
    "database/sql"
    "my_framework/orm/internal/errs"
    "my_framework/orm/model"
    "reflect"
    "unsafe"
)

// unsafeValue fills an entity by pointing rows.Scan directly at each field's
// memory (start address + field offset).
type unsafeValue struct {
    model *model.Model
    // val is the *T being filled.
    val any
}

// Compile-time check: breaks the build if Creator's signature changes.
var _ Creator = NewUnsafeValue

// SetColumns scans the current row into the entity. Every result-set column
// must map to a known field; otherwise an unknown-column error is returned
// before anything is scanned.
func (u unsafeValue) SetColumns(rows *sql.Rows) error {
    cs, err := rows.Columns()
    if err != nil {
        return err
    }

    // Start address of the entity.
    address := reflect.ValueOf(u.val).UnsafePointer()
    vals := make([]any, 0, len(cs))
    for _, c := range cs {
        // c is a column name from the result set.
        fd, ok := u.model.ColumnMap[c]
        if !ok {
            return errs.NewErrUnknownColumn(c)
        }
        // reflect.NewAt yields a *FieldType aimed at the field itself, so
        // Scan writes into the entity with no intermediate copy.
        fdAddress := unsafe.Add(address, fd.Offset)
        vals = append(vals, reflect.NewAt(fd.Typ, fdAddress).Interface())
    }

    return rows.Scan(vals...)
}

// NewUnsafeValue returns a Value that fills val, expected to be a pointer to
// the struct described by model.
func NewUnsafeValue(model *model.Model, val any) Value {
    return unsafeValue{
        model: model,
        val:   val,
    }
}
// reflect.go
package valuer

import (
    "database/sql"
    "my_framework/orm/internal/errs"
    "my_framework/orm/model"
    "reflect"
)

// reflectValue fills an entity using plain reflection: each column is scanned
// into a temporary value, which is then copied into the matching field.
type reflectValue struct {
    model *model.Model
    // val is the *T being filled.
    val any
}

// Compile-time check: breaks the build if Creator's signature changes.
var _ Creator = NewReflectValue

// SetColumns scans the current row into the entity. Every result-set column
// must map to a known field; otherwise an unknown-column error is returned.
// On any error the entity is left untouched, because fields are only
// assigned after a successful Scan.
func (r reflectValue) SetColumns(rows *sql.Rows) error {
    cs, err := rows.Columns()
    if err != nil {
        return err
    }

    vals := make([]any, 0, len(cs))
    valElems := make([]reflect.Value, 0, len(cs))
    // goNames remembers each column's target field so the copy loop below
    // does not repeat the ColumnMap lookup.
    goNames := make([]string, 0, len(cs))
    for _, c := range cs {
        fd, ok := r.model.ColumnMap[c]
        if !ok {
            return errs.NewErrUnknownColumn(c)
        }
        // reflect.New gives a pointer to a fresh zero value of the field's
        // type, which is a valid destination for rows.Scan.
        val := reflect.New(fd.Typ)
        vals = append(vals, val.Interface())
        valElems = append(valElems, val.Elem())
        goNames = append(goNames, fd.GoName)
    }

    if err = rows.Scan(vals...); err != nil {
        return err
    }

    entity := reflect.ValueOf(r.val).Elem()
    for i, name := range goNames {
        entity.FieldByName(name).Set(valElems[i])
    }
    return nil
}

// NewReflectValue returns a Value that fills val, expected to be a pointer
// to the struct described by model.
func NewReflectValue(model *model.Model, val any) Value {
    return reflectValue{
        model: model,
        val:   val,
    }
}
// reflect_test.go
package valuer

import (
    "database/sql"
    "github.com/DATA-DOG/go-sqlmock"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
    "my_framework/orm/model"
    "testing"
)

// TestModel is the entity used by the valuer tests and benchmarks; its
// columns are id, first_name, age and last_name.
type TestModel struct {
    Id        int64
    FirstName string
    Age       int8
    // LastName is a *sql.NullString so a NULL column is distinguishable
    // from an empty string (see Valid).
    LastName  *sql.NullString
}

func TestReflectValue_SetColumns(t *testing.T) {
    testSetColumns(t, NewReflectValue)
}

// testSetColumns is the shared table-driven test for every Creator: it feeds
// one mocked row into a Value built by create and checks the error returned
// by SetColumns and the entity it filled.
func testSetColumns(t *testing.T, create Creator) {
    testCases := []struct {
        name string
        // entity must be a pointer.
        entity     any
        rows       func() *sqlmock.Rows
        wantErr    error
        wantEntity any
    }{
        {
            name:   "set columns",
            entity: &TestModel{},
            rows: func() *sqlmock.Rows {
                rows := sqlmock.NewRows([]string{"id", "first_name", "age", "last_name"})
                rows.AddRow("1", "Tom", "18", "Jerry")
                return rows
            },
            wantEntity: &TestModel{
                Id:        1,
                FirstName: "Tom",
                Age:       18,
                LastName:  &sql.NullString{Valid: true, String: "Jerry"},
            },
        },
        {
            // Only some of the columns are selected.
            name:   "partial columns",
            entity: &TestModel{},
            rows: func() *sqlmock.Rows {
                rows := sqlmock.NewRows([]string{"id", "first_name"})
                rows.AddRow("1", "Tom")
                return rows
            },
            wantEntity: &TestModel{
                Id:        1,
                FirstName: "Tom",
            },
        },
        {
            // Columns arrive in a different order than the struct fields.
            name:   "order",
            entity: &TestModel{},
            rows: func() *sqlmock.Rows {
                rows := sqlmock.NewRows([]string{"id", "last_name", "first_name", "age"})
                rows.AddRow("1", "Jerry", "Tom", "18")
                return rows
            },
            wantEntity: &TestModel{
                Id:        1,
                FirstName: "Tom",
                Age:       18,
                LastName:  &sql.NullString{Valid: true, String: "Jerry"},
            },
        },
    }

    r := model.NewRegistry()

    mockDB, mock, err := sqlmock.New()
    require.NoError(t, err)
    defer mockDB.Close()
    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            mock.ExpectQuery("SELECT XX").WillReturnRows(tc.rows())
            rows, err := mockDB.Query("SELECT XX")
            require.NoError(t, err)
            defer rows.Close()
            // SetColumns reads the current row, so position on it first.
            require.True(t, rows.Next())

            // m (not "model") avoids shadowing the model package.
            m, err := r.Get(tc.entity)
            require.NoError(t, err)
            // Assert on SetColumns' own error, not the one left by r.Get.
            err = create(m, tc.entity).SetColumns(rows)
            assert.Equal(t, tc.wantErr, err)
            if err != nil {
                return
            }
            // The entity must now hold the row's data.
            assert.Equal(t, tc.wantEntity, tc.entity)
        })
    }
}
// unsafe_test.go
package valuer

import "testing"

func TestUnsafeValue_SetColumns(t *testing.T) {
    testSetColumns(t, NewUnsafeValue)
}
// value_test.go
package valuer

import (
    "database/sql/driver"
    "github.com/DATA-DOG/go-sqlmock"
    "github.com/stretchr/testify/require"
    "my_framework/orm/model"
    "testing"
)

// BenchmarkSetColumns compares the reflect- and unsafe-based Value
// implementations on the same mocked result set.
// Run with: go test -bench=BenchmarkSetColumns -benchtime=10000x -benchmem
// Correctness is covered by the unit tests; this only measures the calls.
func BenchmarkSetColumns(b *testing.B) {
    fn := func(b *testing.B, creator Creator) {
        mockDB, mock, err := sqlmock.New()
        require.NoError(b, err)
        defer mockDB.Close()

        // One row per iteration: b.N iterations need b.N rows.
        mockRows := sqlmock.NewRows([]string{"id", "first_name", "age", "last_name"})
        row := []driver.Value{"1", "Tom", "18", "Jerry"}
        for i := 0; i < b.N; i++ {
            mockRows.AddRow(row...)
        }
        mock.ExpectQuery("SELECT XX").WillReturnRows(mockRows)

        rows, err := mockDB.Query("SELECT XX")
        require.NoError(b, err)
        defer rows.Close()

        m, err := model.NewRegistry().Get(&TestModel{})
        require.NoError(b, err)

        // Exclude the setup above from the measurement.
        b.ResetTimer()
        for i := 0; i < b.N; i++ {
            if !rows.Next() {
                b.Fatal("ran out of mocked rows")
            }
            val := creator(m, &TestModel{})
            if err := val.SetColumns(rows); err != nil {
                b.Fatal(err)
            }
        }
    }
    b.Run("reflect", func(b *testing.B) {
        fn(b, NewReflectValue)
    })

    b.Run("unsafe", func(b *testing.B) {
        fn(b, NewUnsafeValue)
    })
}

4. 结果集处理:总结与面试要点




// unsafe.go
package valuer

import (
    "database/sql"
    "my_framework/orm/internal/errs"
    "my_framework/orm/model"
    "reflect"
    "unsafe"
)

// unsafeValue fills an entity by pointing rows.Scan directly at each field's
// memory (start address + field offset). The start address is resolved once,
// in NewUnsafeValue.
type unsafeValue struct {
    model *model.Model
    // address is the start address of the *T being filled.
    address unsafe.Pointer
}

// Compile-time check: breaks the build if Creator's signature changes.
var _ Creator = NewUnsafeValue

// SetColumns scans the current row into the entity. Every result-set column
// must map to a known field; otherwise an unknown-column error is returned
// before anything is scanned.
func (u unsafeValue) SetColumns(rows *sql.Rows) error {
    cs, err := rows.Columns()
    if err != nil {
        return err
    }

    vals := make([]any, 0, len(cs))
    for _, c := range cs {
        // c is a column name from the result set.
        fd, ok := u.model.ColumnMap[c]
        if !ok {
            return errs.NewErrUnknownColumn(c)
        }
        // reflect.NewAt yields a *FieldType aimed at the field itself, so
        // Scan writes into the entity with no intermediate copy.
        fdAddress := unsafe.Add(u.address, fd.Offset)
        vals = append(vals, reflect.NewAt(fd.Typ, fdAddress).Interface())
    }

    return rows.Scan(vals...)
}

// NewUnsafeValue returns a Value that fills val, expected to be a pointer to
// the struct described by model. reflect's UnsafePointer panics for kinds
// other than pointer-like ones.
func NewUnsafeValue(model *model.Model, val any) Value {
    return unsafeValue{
        model:   model,
        address: reflect.ValueOf(val).UnsafePointer(),
    }
}

SELECT进阶

5. SELECT 进阶:指定简单列


// select_test.go 
// TestSelector_Select checks explicit SELECT lists: known fields render as
// backquoted column names, unknown fields are reported by Build.
func TestSelector_Select(t *testing.T) {
    db, err := Open("mysql", "root:123456@tcp(127.0.0.1:3307)/mybatis")
    require.NoError(t, err)

    type selectCase struct {
        name      string
        s         QueryBuilder
        wantQuery *Query
        wantErr   error
    }
    cases := []selectCase{
        {
            // Columns are rendered in order, comma-separated.
            name: "multiple columns",
            s:    NewSelector[TestModel](db).Select(C("FirstName"), C("LastName")),
            wantQuery: &Query{
                SQL: "SELECT `first_name`,`last_name` FROM `test_model`;",
            },
        },
        {
            name:    "invalid columns",
            s:       NewSelector[TestModel](db).Select(C("Invalid")),
            wantErr: errs.NewErrUnknownField("Invalid"),
        },
    }

    for _, c := range cases {
        t.Run(c.name, func(t *testing.T) {
            q, err := c.s.Build()
            assert.Equal(t, c.wantErr, err)
            if err == nil {
                assert.Equal(t, c.wantQuery, q)
            }
        })
    }
}
// select.go
// Selectable is a marker interface for anything that may appear in a SELECT
// list: a column, or (per the original notes) an aggregate function and
// similar. The unexported method keeps implementations inside this package.
type Selectable interface {
    selectable()
}

// Selector builds and runs SELECT queries for entity type T.
type Selector[T any] struct {
    // builder carries the shared SQL-building state (s.sb, s.model, args).
    builder
    // table, when non-empty, replaces the table name taken from T's model.
    table   string
    // where holds the predicates rendered into the WHERE clause.
    where   []Predicate
    db      *DB
    // columns is the explicit SELECT list; empty means `*`.
    columns []Selectable
}

// Select sets the SELECT list, replacing any earlier call's list, and
// returns s for chaining. With no arguments Build falls back to `*`.
func (s *Selector[T]) Select(cols ...Selectable) *Selector[T] {
    sel := s
    sel.columns = cols
    return sel
}

// Build assembles the SELECT statement. Only the projection part is shown in
// this excerpt; the elided parts presumably set up s.sb, the model and err.
func (s *Selector[T]) Build() (*Query, error) {
    // ...
    sb := s.sb

    sb.WriteString("SELECT ")
    if len(s.columns) > 0 {
        // Emit the explicitly selected columns, comma-separated. The
        // separator is written before every column but the first, which is
        // equivalent to appending it after each preceding column.
        for i, col := range s.columns {
            if i > 0 {
                sb.WriteByte(',')
            }
            // Only Column is supported at this stage; any other Selectable
            // makes this type assertion panic.
            err = s.buildColumn(col.(Column))
            if err != nil {
                return nil, err
            }
        }
    } else {
        // No explicit projection: select every column.
        sb.WriteString("*")
    }
    sb.WriteString(` FROM `)
    // ...
}
// build.go 
// buildExpression renders expr into the SQL buffer:
//   - nil writes nothing;
//   - Predicate writes "left op right", wrapping a side in parentheses when
//     that side is itself a Predicate so precedence is preserved;
//   - Column is delegated to buildColumn;
//   - value writes a "?" placeholder and records its argument.
// Any other Expression yields an unsupported-expression error.
func (s *builder) buildExpression(expr Expression) error {
    // buildOperand renders one side of a predicate, parenthesising nested
    // predicates.
    buildOperand := func(operand Expression) error {
        _, nested := operand.(Predicate)
        if nested {
            s.sb.WriteByte('(')
        }
        if err := s.buildExpression(operand); err != nil {
            return err
        }
        if nested {
            s.sb.WriteByte(')')
        }
        return nil
    }

    switch e := expr.(type) {
    case nil:
        return nil
    case Predicate:
        if err := buildOperand(e.left); err != nil {
            return err
        }
        s.sb.WriteString(" " + e.op.String() + " ")
        return buildOperand(e.right)
    case Column:
        return s.buildColumn(e)
    case value:
        s.sb.WriteByte('?')
        s.addArg(e.val)
        return nil
    default:
        return errs.NewErrUnsupportedExpression(expr)
    }
}

// buildColumn writes the backtick-quoted column name mapped from the Go field
// name c.name. It returns an unknown-field error when the model has no such
// field.
func (s *builder) buildColumn(c Column) error {
    field, found := s.model.FieldMap[c.name]
    if !found {
        // The caller referenced a field the model does not declare.
        return errs.NewErrUnknownField(c.name)
    }
    s.sb.WriteString("`" + field.ColName + "`")
    return nil
}
// column.go
package orm

// Column references a model field by its Go field name.
type Column struct {
    name string
}

// C returns a Column for the given Go field name.
func C(name string) Column {
    var col Column
    col.name = name
    return col
}

// Eq builds the predicate "column = arg".
// Examples: C("id").Eq(12), sub.C("id").Eq(12).
func (c Column) Eq(arg any) Predicate {
    var p Predicate
    p.left = c
    p.op = opEq
    p.right = value{val: arg}
    return p
}

// Lt builds the predicate "column < arg".
func (c Column) Lt(arg any) Predicate {
    var p Predicate
    p.left = c
    p.op = opLt
    p.right = value{val: arg}
    return p
}

// expr marks Column as an Expression.
func (Column) expr() {}

// selectable marks Column as a Selectable projection item.
func (c Column) selectable() {}

6. SELECT 进阶:指定聚合函数

// aggregate.go
package orm

// Aggregate is an aggregate-function projection such as AVG(`age`).
// fn is the SQL function name and arg is the Go field name it applies to.
// Constructors: Avg("age"), Sum("age"), Count("age"), Max("age"), Min("age").
type Aggregate struct {
    fn  string
    arg string
}

// selectable marks Aggregate as usable in Selector.Select.
func (a Aggregate) selectable() {
}

// newAggregate pairs a SQL aggregate function name with a Go field name.
func newAggregate(fn, col string) Aggregate {
    return Aggregate{fn: fn, arg: col}
}

// Avg returns AVG(col).
func Avg(col string) Aggregate { return newAggregate("AVG", col) }

// Sum returns SUM(col).
func Sum(col string) Aggregate { return newAggregate("SUM", col) }

// Count returns COUNT(col).
func Count(col string) Aggregate { return newAggregate("COUNT", col) }

// Max returns MAX(col).
func Max(col string) Aggregate { return newAggregate("MAX", col) }

// Min returns MIN(col).
func Min(col string) Aggregate { return newAggregate("MIN", col) }
// select_test.go
// TestSelector_Select (aggregate variant) covers single and multiple aggregate
// projections and the unknown-field error for an aggregate's argument.
func TestSelector_Select(t *testing.T) {
    db, err := Open("mysql", "root:123456@tcp(127.0.0.1:3307)/mybatis")
    require.NoError(t, err)
    //db := memoryDB(t)
    tests := []struct {
        name      string
        s         QueryBuilder
        wantQuery *Query
        wantErr   error
    }{
        // ...
        {
            name: "avg",
            s:    NewSelector[TestModel](db).Select(Avg("Age")),
            wantQuery: &Query{
                SQL: "SELECT AVG(`age`) FROM `test_model`;",
            },
        },
        {
            name: "sum",
            s:    NewSelector[TestModel](db).Select(Sum("Age")),
            wantQuery: &Query{
                SQL: "SELECT SUM(`age`) FROM `test_model`;",
            },
        },
        {
            name:    "sum invalid",
            s:       NewSelector[TestModel](db).Select(Sum("Invalid")),
            wantErr: errs.NewErrUnknownField("Invalid"),
        },
        {
            // Distinct name: duplicate subtest names are silently renamed
            // ("sum#01"), which makes failure reports ambiguous.
            name: "sum and avg",
            s:    NewSelector[TestModel](db).Select(Sum("Age"), Avg("Age")),
            wantQuery: &Query{
                SQL: "SELECT SUM(`age`),AVG(`age`) FROM `test_model`;",
            },
        },
    }

    for _, tc := range tests {
        t.Run(tc.name, func(t *testing.T) {
            q, err := tc.s.Build()
            assert.Equal(t, tc.wantErr, err)
            // On error there is no query to compare.
            if err != nil {
                return
            }
            assert.Equal(t, tc.wantQuery, q)
        })
    }
}
// select.go
// Build assembles the SELECT statement. The projection list is delegated to
// buildColumns; the remaining clauses are elided in this excerpt.
func (s *Selector[T]) Build() (*Query, error) {
    // ...
    sb := s.sb

    sb.WriteString("SELECT ")

    // Columns, aggregates, or "*" when nothing was selected.
    if err = s.buildColumns(); err != nil {
        return nil, err
    }
    sb.WriteString(` FROM `)

    // ...
}

// buildColumns writes the SELECT projection list: "*" when no columns were
// requested, otherwise every Selectable separated by commas. A Column renders
// through buildColumn; an Aggregate renders as FN(`column`).
// It returns the error from buildColumn for unknown field names.
func (s *Selector[T]) buildColumns() error {
    if len(s.columns) == 0 {
        s.sb.WriteString("*")
        return nil
    }
    for i, col := range s.columns {
        // The separator goes before every entry except the first, which is
        // equivalent to appending it after each preceding entry.
        if i > 0 {
            s.sb.WriteByte(',')
        }
        switch c := col.(type) {
        case Column:
            // buildColumn takes the Column itself, not just its name.
            if err := s.buildColumn(c); err != nil {
                return err
            }
        case Aggregate:
            s.sb.WriteString(c.fn)
            s.sb.WriteByte('(')
            // Only the field name matters inside the function call.
            if err := s.buildColumn(Column{name: c.arg}); err != nil {
                return err
            }
            s.sb.WriteByte(')')
        }
        // NOTE(review): other Selectable implementations are silently skipped,
        // leaving a dangling separator; consider returning an error here.
    }
    return nil
}

7. SELECT 进阶:原生表达式


// expression.go
package orm

// Expression is a marker interface for anything that can appear in an
// expression tree (e.g. a WHERE condition).
type Expression interface {
    expr()
}

// RawExpr is a raw SQL fragment plus its positional arguments; the raw text
// is written as-is rather than generated from model metadata.
type RawExpr struct {
    raw  string
    args []any
}

// selectable marks RawExpr as a Selectable.
func (r RawExpr) selectable() {}

// expr marks RawExpr as an Expression.
func (r RawExpr) expr() {}

// Raw creates a RawExpr from a SQL fragment and its arguments.
func Raw(expr string, args ...any) RawExpr {
    r := RawExpr{raw: expr}
    r.args = args
    return r
}

// AsPredicate wraps the raw expression in a Predicate whose only populated
// part is the left side, so it can be used where a Predicate is required.
func (r RawExpr) AsPredicate() Predicate {
    var p Predicate
    p.left = r
    return p
}

8. SELECT 进阶:别名

// builder.go
// buildColumn writes the backtick-quoted column for the Go field c.name and,
// when c.alias is set, appends " AS `alias`". It returns an unknown-field
// error when the model has no such field.
func (s *builder) buildColumn(c Column) error {
    fd, ok := s.model.FieldMap[c.name]
    if !ok {
        // The caller referenced a field the model does not declare.
        return errs.NewErrUnknownField(c.name)
    }
    s.sb.WriteString("`" + fd.ColName + "`")
    if c.alias == "" {
        return nil
    }
    s.sb.WriteString(" AS `" + c.alias + "`")
    return nil
}
// column.go
// Column references a model field by its Go field name, optionally with an
// alias used when the column is projected in SELECT.
type Column struct {
    name  string
    alias string
}

// As returns a copy of c carrying the given alias. The receiver is a value,
// so the original Column is left untouched (immutable design).
func (c Column) As(alias string) Column {
    c.alias = alias
    return c
}
// select_test.go
// TestSelector_Select (alias variant) covers AS aliases on plain columns and on
// aggregates, and checks that an alias attached to a column used in WHERE is
// not rendered there.
func TestSelector_Select(t *testing.T) {
    db, err := Open("mysql", "root:123456@tcp(127.0.0.1:3307)/mybatis")
    require.NoError(t, err)
    //db := memoryDB(t)
    tests := []struct {
        name      string
        s         QueryBuilder
        wantQuery *Query
        wantErr   error
    }{
        // ...
        {
            name: "column alias",
            s:    NewSelector[TestModel](db).Select(C("FirstName").As("my_name")),
            wantQuery: &Query{
                SQL: "SELECT `first_name` AS `my_name` FROM `test_model`;",
            },
        },
        {
            // NOTE(review): expects aggregate aliases to be rendered; confirm
            // buildColumns was extended to emit " AS `alias`" for Aggregate.
            name: "column alias agg",
            s:    NewSelector[TestModel](db).Select(Avg("FirstName").As("my_name")),
            wantQuery: &Query{
                SQL: "SELECT AVG(`first_name`) AS `my_name` FROM `test_model`;",
            },
        },
        {
            // An alias must be dropped inside WHERE.
            name: "column alias in where",
            s:    NewSelector[TestModel](db).Where(C("Id").As("my_id").Eq(18)),
            wantQuery: &Query{
                SQL:  "SELECT * FROM `test_model` WHERE `id` = ?;",
                Args: []any{18},
            },
        },
    }

    for _, tc := range tests {
        // ...
    }
}
// aggregate.go
// Aggregate is an aggregate-function projection: fn applied to the Go field
// arg, optionally aliased.
type Aggregate struct {
    fn    string
    arg   string
    alias string
}

// As returns a copy of a carrying the given alias; a itself is unchanged.
func (a Aggregate) As(alias string) Aggregate {
    a.alias = alias
    return a
}

INSERT

9. INSERT:INSERT 语句概览

10. INSERT:最简实现

// model.go
// Model is the metadata registered for an entity type.
type Model struct {
    TableName string
    // Fields lists the fields in a fixed order so callers can iterate them
    // deterministically (map iteration order is random).
    Fields []*Field
    // FieldMap indexes fields by Go field name.
    FieldMap map[string]*Field
    // ColumnMap indexes fields by column name.
    ColumnMap map[string]*Field
}

// Registry (excerpt) builds the Model for entity, now also storing the ordered
// Fields slice next to the two lookup maps. Computing tableName, fieldMap,
// columnMap and fields (and presumably applying opts) is elided here.
func (r *registry) Registry(entity any, opts ...ModelOpt) (*Model, error) {
    // ...

    res := &Model{
        TableName: tableName,
        FieldMap:  fieldMap,
        ColumnMap: columnMap,
        Fields:    fields,
    }

    // ...
    return res, nil
}
// insert
package orm

import (
    "my_framework/orm/internal/errs"
    "reflect"
    "strings"
)

// Inserter builds INSERT statements for rows of type T.
type Inserter[T any] struct {
    // values are the rows to insert; Build requires at least one.
    values []*T
    db     *DB
}

// NewInserter creates an Inserter that resolves model metadata through db.
func NewInserter[T any](db *DB) *Inserter[T] {
    return &Inserter[T]{
        db: db,
    }
}

// Values sets the rows to insert, replacing any previous call, and returns
// the Inserter for chaining. The slice is stored as given (not copied).
func (i *Inserter[T]) Values(vals ...*T) *Inserter[T] {
    i.values = vals
    return i
}

// Build assembles a (possibly multi-row) INSERT statement:
//
//     INSERT INTO `table`(`c1`,`c2`,...) VALUES (?,?,...),(?,?,...);
//
// Columns are emitted in the model's Fields order rather than by ranging over
// FieldMap/ColumnMap, because map iteration order is random. One argument is
// collected per column per row, in the same order.
// It returns errs.ErrInsertZeroRow when no rows were supplied, or the error
// from the model registry.
//
// NOTE(review): a nil *T among the values makes Elem() return an invalid
// reflect.Value, and FieldByName then panics.
func (i *Inserter[T]) Build() (*Query, error) {
    if len(i.values) == 0 {
        return nil, errs.ErrInsertZeroRow
    }
    // Resolve the model metadata for T from the first row.
    m, err := i.db.r.Get(i.values[0])
    if err != nil {
        return nil, err
    }

    var sb strings.Builder
    sb.WriteString("INSERT INTO ")
    sb.WriteByte('`')
    sb.WriteString(m.TableName)
    sb.WriteByte('`')

    // Spell out the column list explicitly so the VALUES tuples have a
    // well-defined order.
    sb.WriteByte('(')
    for idx, field := range m.Fields {
        if idx > 0 {
            sb.WriteByte(',')
        }
        sb.WriteByte('`')
        sb.WriteString(field.ColName)
        sb.WriteByte('`')
    }
    sb.WriteByte(')')

    sb.WriteString(" VALUES ")
    args := make([]any, 0, len(i.values)*len(m.Fields))
    for j, val := range i.values {
        if j > 0 {
            sb.WriteByte(',')
        }
        // Dereference once per row instead of once per column.
        row := reflect.ValueOf(val).Elem()
        sb.WriteByte('(')
        for idx, field := range m.Fields {
            if idx > 0 {
                sb.WriteByte(',')
            }
            sb.WriteByte('?')
            args = append(args, row.FieldByName(field.GoName).Interface())
        }
        sb.WriteByte(')')
    }
    sb.WriteByte(';')

    return &Query{
        SQL:  sb.String(),
        Args: args,
    }, nil
}
// insert_test.go
package orm

import (
    "database/sql"
    "github.com/stretchr/testify/assert"
    "my_framework/orm/internal/errs"
    "testing"
)

// TestInserter_Build checks the SQL and arguments produced by Inserter.Build
// for a single row, several rows, and the zero-row error.
// NOTE(review): assert.NoError does not stop the test, so a failed Open would
// continue with a nil db; require.NoError may be intended.
func TestInserter_Build(t *testing.T) {
    db, err := Open("mysql", "root:123456@tcp(127.0.0.1:3307)/mybatis")
    assert.NoError(t, err)
    tests := []struct {
        name string
        i    QueryBuilder

        wantErr   error
        wantQuery *Query
    }{
        {
            // Insert one row.
            name: "single insert",
            i: NewInserter[TestModel](db).Values(&TestModel{
                Id: 12, FirstName: "Tom", Age: 18, LastName: &sql.NullString{String: "Jerry", Valid: true},
            }),
            wantQuery: &Query{
                SQL:  "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES (?,?,?,?);",
                Args: []any{int64(12), "Tom", int8(18), &sql.NullString{String: "Jerry", Valid: true}},
            },
        },
        {
            // Insert several rows in one statement.
            name: "multiple insert",
            i: NewInserter[TestModel](db).Values(
                &TestModel{
                    Id: 12, FirstName: "Tom", Age: 18, LastName: &sql.NullString{String: "Jerry", Valid: true},
                },
                &TestModel{
                    Id: 13, FirstName: "Dummy", Age: 19, LastName: &sql.NullString{String: "Shan", Valid: true},
                },
            ),
            wantQuery: &Query{
                SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES (?,?,?,?),(?,?,?,?);",
                Args: []any{
                    int64(12), "Tom", int8(18), &sql.NullString{String: "Jerry", Valid: true},
                    int64(13), "Dummy", int8(19), &sql.NullString{String: "Shan", Valid: true},
                },
            },
        },
        {
            // No rows at all: Build must fail with ErrInsertZeroRow.
            name:    "no rowinsert",
            i:       NewInserter[TestModel](db).Values(),
            wantErr: errs.ErrInsertZeroRow,
        },
    }

    for _, tc := range tests {
        t.Run(tc.name, func(t *testing.T) {
            q, err := tc.i.Build()
            assert.Equal(t, tc.wantErr, err)
            // On error there is no query to compare.
            if err != nil {
                return
            }
            assert.Equal(t, tc.wantQuery, q)
        })
    }
}

11. INSERT:指定列

// insert.go
// Columns restricts the INSERT to the named Go fields, in the given order,
// and returns the Inserter for chaining. The slice is stored as given (not
// copied); a later call replaces it.
func (i *Inserter[T]) Columns(cols ...string) *Inserter[T] {
    i.columns = cols
    return i
}

// Build (excerpt) adds support for an explicit column list. When Columns was
// called, only those fields (looked up by Go field name, in the given order)
// are inserted; otherwise every model field is, in m.Fields order.
func (i Inserter[T]) Build() (*Query, error) {
    // ...
    sb.WriteByte('(')

    fields := m.Fields
    // The caller restricted the column list.
    if len(i.columns) > 0 {
        fields = make([]*model.Field, 0, len(i.columns))
        for _, fd := range i.columns {
            fdMeta, ok := m.FieldMap[fd]
            // Unknown field name passed to Columns.
            if !ok {
                return nil, errs.NewErrUnknownField(fd)
            }
            fields = append(fields, fdMeta)
        }
    }

    // Map iteration order is random, so the column list must come from the
    // ordered fields slice rather than FieldMap/ColumnMap.
    for idx, field := range fields {
        if idx > 0 {
            sb.WriteByte(',')
        }
        sb.WriteByte('`')
        sb.WriteString(field.ColName)
        sb.WriteByte('`')
    }
    sb.WriteByte(')')

    sb.WriteString(" VALUES ")
    //sb.WriteByte('(')
    // One argument per selected column per row.
    args := make([]any, 0, len(i.values)*len(fields))
    for j, val := range i.values {
        if j > 0 {
            sb.WriteByte(',')
        }
        sb.WriteByte('(')
        for idx, field := range fields {
            // ...
        }
        sb.WriteByte(')')
    }
    sb.WriteByte(';')

    return &Query{
        SQL:  sb.String(),
        Args: args,
    }, err
}
// insert_test.go
// TestInserter_Build (partial-columns variant): with Columns("Id","FirstName")
// only those two columns and their values are emitted, even though Age is set
// on the rows.
func TestInserter_Build(t *testing.T) {
    db, err := Open("mysql", "root:123456@tcp(127.0.0.1:3307)/mybatis")
    assert.NoError(t, err)
    tests := []struct {
        name string
        i    QueryBuilder

        wantErr   error
        wantQuery *Query
    }{
        // ...
        {
            // Only a subset of the columns.
            name: "partial columns",
            i: NewInserter[TestModel](db).
                Columns("Id","FirstName").
                Values(
                    &TestModel{
                        Id: 12, FirstName: "Tom", Age: 18,
                    },
                    &TestModel{
                        Id: 13, FirstName: "Dummy", Age: 19,
                    },
                ),
            wantQuery: &Query{
                SQL: "INSERT INTO `test_model`(`id`,`first_name`) VALUES (?,?),(?,?);",
                Args: []any{
                    int64(12), "Tom",
                    int64(13), "Dummy",
                },
            },
        },
    }

    for _, tc := range tests {
        // ...
    }
}

12. INSERT:UPSERT API 定义


// insert.go
// OnDuplicateKeyBuilder is the intermediate step of
// Inserter.OnDuplicateKey().Update(...): it holds the Inserter until the
// update assignments are supplied.
type OnDuplicateKeyBuilder[T any] struct {
    i *Inserter[T]
}

// OnDuplicateKey holds the assignments of the upsert clause.
type OnDuplicateKey[T any] struct {
    assigns []Assignable
}

// Assignable is a marker interface for entries allowed in an upsert's update
// list (the later Build in this file handles Assignment and Column).
type Assignable interface {
    assign()
}

// Inserter builds INSERT statements for T.
type Inserter[T any] struct {
    values  []*T
    db      *DB
    columns []string
    // Earlier design stored the assignments directly:
    //onDuplicateKey []Assignable
    onDuplicateKey *OnDuplicateKey[T]
}

// Earlier flat API, superseded by the builder-style API below:
//func (i *Inserter[T]) OnDuplicateKey(assigns ...Assignable) *Inserter[T] {
//    i.onDuplicateKey = assigns
//    return i
//}

// Update records the upsert assignments on the Inserter (replacing any earlier
// ones) and returns the Inserter to continue the chain.
func (o *OnDuplicateKeyBuilder[T]) Update(assigns ...Assignable) *Inserter[T] {
    o.i.onDuplicateKey = &OnDuplicateKey[T]{
        assigns: assigns,
    }
    return o.i
}

// OnDuplicateKey starts configuring an upsert; finish it with Update.
func (i *Inserter[T]) OnDuplicateKey() *OnDuplicateKeyBuilder[T] {
    return &OnDuplicateKeyBuilder[T]{
        i: i,
    }
}

13. INSERT:MySQL UPSERT 基本实现

// insert.go
// Build (excerpt) appends the MySQL upsert clause after the VALUES list.
// An Assignment renders as `col`=? with its value appended to args; a bare
// Column renders as `col`=VALUES(`col`), reusing the value the row tried to
// insert. Unknown fields and other Assignable kinds produce errors.
func (i Inserter[T]) Build() (*Query, error) {
    // ...

    sb.WriteString(" VALUES ")
    //sb.WriteByte('(')
    args := make([]any, 0, len(i.values)*len(fields))
    for j, val := range i.values {
        // ...
    }

    // Upsert clause (MySQL syntax).
    if i.onDuplicateKey != nil {
        sb.WriteString(" ON DUPLICATE KEY UPDATE ")
        for idx, assign := range i.onDuplicateKey.assigns {
            if idx > 0 {
                sb.WriteByte(',')
            }
            switch a := assign.(type) {
            case Assignment:
                // Explicit value: bind it as a placeholder argument.
                fd, ok := m.FieldMap[a.col]
                if !ok {
                    return nil, errs.NewErrUnknownField(a.col)
                }
                sb.WriteByte('`')
                sb.WriteString(fd.ColName)
                sb.WriteByte('`')
                sb.WriteString("=?")
                args = append(args, a.val)
            case Column:
                // Reuse the attempted insert value via VALUES(col).
                fd, ok := m.FieldMap[a.name]
                if !ok {
                    return nil, errs.NewErrUnknownField(a.name)
                }
                sb.WriteByte('`')
                sb.WriteString(fd.ColName)
                sb.WriteByte('`')
                sb.WriteString("=VALUES(")
                sb.WriteByte('`')
                sb.WriteString(fd.ColName)
                sb.WriteByte('`')
                sb.WriteByte(')')
            default:
                return nil, errs.NewErrUnsupportedAssignable(assign)
            }
        }
    }

    sb.WriteByte(';')

    return &Query{
        SQL:  sb.String(),
        Args: args,
    }, err
}
// assignment.go
package orm

// Assignment is a "column = value" pair for an upsert's update list; col is
// the Go field name.
type Assignment struct {
    col string
    val any
}

// assign marks Assignment as Assignable.
func (Assignment) assign() {}

// Assign builds an Assignment that sets the field col to val.
func Assign(col string, val any) Assignment {
    var a Assignment
    a.col, a.val = col, val
    return a
}
// insert_test.go
// TestInserter_Build (MySQL upsert variant) checks ON DUPLICATE KEY UPDATE
// with explicit values (Assign) and with values copied from the insert (C).
func TestInserter_Build(t *testing.T) {
    db, err := Open("mysql", "root:123456@tcp(127.0.0.1:3307)/mybatis")
    assert.NoError(t, err)
    tests := []struct {
        name string
        i    QueryBuilder

        wantErr   error
        wantQuery *Query
    }{
        // ...
        {
            // Explicit values: each Assign adds a placeholder and an argument
            // after the row values.
            name: "upsert",
            i: NewInserter[TestModel](db).Values(&TestModel{
                Id: 12, FirstName: "Tom", Age: 18, LastName: &sql.NullString{String: "Jerry", Valid: true},
            }).OnDuplicateKey().Update(
                Assign("FirstName", "Deng"),
                Assign("Age", 19),
            ),
            wantQuery: &Query{
                SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES (?,?,?,?)" +
                    " ON DUPLICATE KEY UPDATE `first_name`=?,`age`=?;",
                Args: []any{
                    int64(12), "Tom", int8(18), &sql.NullString{String: "Jerry", Valid: true},
                    "Deng", 19,
                },
            },
        },
        {
            // Bare columns: reuse the inserted value via VALUES(col); no extra
            // arguments.
            name: "update insert column",
            i: NewInserter[TestModel](db).Values(
                &TestModel{
                    Id: 12, FirstName: "Tom", Age: 18, LastName: &sql.NullString{String: "Jerry", Valid: true},
                },
                &TestModel{
                    Id: 13, FirstName: "Dummy", Age: 19, LastName: &sql.NullString{String: "Shan", Valid: true},
                },
            ).OnDuplicateKey().Update(C("FirstName"), C("Age")),
            wantQuery: &Query{
                SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES (?,?,?,?),(?,?,?,?)" +
                    " ON DUPLICATE KEY UPDATE `first_name`=VALUES(`first_name`),`age`=VALUES(`age`);",
                Args: []any{
                    int64(12), "Tom", int8(18), &sql.NullString{String: "Jerry", Valid: true},
                    int64(13), "Dummy", int8(19), &sql.NullString{String: "Shan", Valid: true},
                },
            },
        },
    }

    for _, tc := range tests {
        t.Run(tc.name, func(t *testing.T) {
            q, err := tc.i.Build()
            assert.Equal(t, tc.wantErr, err)
            // On error there is no query to compare.
            if err != nil {
                return
            }
            assert.Equal(t, tc.wantQuery, q)
        })
    }
}

14. INSERT:方言抽象 Dialect

// dialect.go
package orm

import (
    "my_framework/orm/internal/errs"
    "strings"
)

// Dialect abstracts the SQL syntax that differs between databases.
type Dialect interface {
    // quoter returns the character used to quote identifiers,
    // e.g. '`' for MySQL.
    quoter() byte

    // buildOnDuplicateKey writes the dialect's upsert clause for odk into sb.
    // sb is a pointer because a strings.Builder must not be copied once it
    // holds data: with a by-value parameter, writes either panic (non-zero
    // builder) or land in a private copy the caller never sees.
    buildOnDuplicateKey(sb *strings.Builder, odk *OnDuplicateKey) error
}

// standardSQL supplies defaults that concrete dialects inherit by embedding
// and override where their syntax differs.
type standardSQL struct {
}

// quoter returns '"', the delimited-identifier quote of standard SQL, so a
// dialect that does not override it (postgreDialect here) does not panic.
func (s standardSQL) quoter() byte {
    return '"'
}

// buildOnDuplicateKey has no standard-SQL form yet; every dialect that
// supports upserts must override it.
func (s standardSQL) buildOnDuplicateKey(sb *strings.Builder, odk *OnDuplicateKey) error {
    //TODO implement me
    panic("implement me")
}

// mysqlDialect generates MySQL syntax.
type mysqlDialect struct {
    standardSQL
}

func (m mysqlDialect) quoter() byte {
    return '`'
}

func (m mysqlDialect) buildOnDuplicateKey(sb *strings.Builder, odk *OnDuplicateKey) error {
    // ... (ON DUPLICATE KEY UPDATE clause; body elided in this excerpt)
}

// sqliteDialect generates SQLite syntax; its overrides are not part of this
// excerpt.
type sqliteDialect struct {
    standardSQL
}

// postgreDialect generates PostgreSQL syntax; in this excerpt it relies
// entirely on the standardSQL defaults.
type postgreDialect struct {
    standardSQL
}

15. INSERT:builder 抽象与重构

// build.go
package orm

import (
    "my_framework/orm/internal/errs"
    "my_framework/orm/model"
    "strings"
)

// builder accumulates the SQL text and positional arguments of a statement.
type builder struct {
    sb    *strings.Builder
    model *model.Model
    args  []any

    // dialect supplies dialect-specific clauses; quoter caches
    // dialect.quoter() for identifier quoting.
    dialect Dialect
    quoter  byte
}

// quote writes name wrapped in the cached identifier quote character.
func (b *builder) quote(name string) {
    sb, q := b.sb, b.quoter
    sb.WriteByte(q)
    sb.WriteString(name)
    sb.WriteByte(q)
}

// buildExpression renders expr into the SQL buffer:
//   - nil writes nothing;
//   - Predicate writes "left op right", wrapping a side in parentheses when it
//     is itself a Predicate; the operator is skipped when op is empty (as for
//     a RawExpr passed through AsPredicate, which only sets left);
//   - Column is delegated to buildColumn after clearing its alias;
//   - value writes a "?" placeholder and records the argument;
//   - RawExpr writes the raw SQL inside parentheses and records its arguments.
// Any other Expression yields an unsupported-expression error.
func (s *builder) buildExpression(expr Expression) error {
    switch exp := expr.(type) {
    case nil:
    case Predicate:
        // Left operand, parenthesised if it is a nested predicate.
        _, ok := exp.left.(Predicate)
        if ok {
            s.sb.WriteByte('(')
        }
        if err := s.buildExpression(exp.left); err != nil {
            return err
        }
        if ok {
            s.sb.WriteByte(')')
        }

        // Unary predicates (empty op) have no operator to write.
        if exp.op != "" {
            s.sb.WriteByte(' ')
            s.sb.WriteString(exp.op.String())
            s.sb.WriteByte(' ')
        }

        // Right operand, parenthesised if it is a nested predicate.
        _, ok = exp.right.(Predicate)
        if ok {
            s.sb.WriteByte('(')
        }
        if err := s.buildExpression(exp.right); err != nil {
            return err
        }
        if ok {
            s.sb.WriteByte(')')
        }
    case Column:
        // Aliases belong in SELECT only; exp is a copy, so clearing it here
        // does not affect the caller's Column.
        exp.alias = ""
        return s.buildColumn(exp)
    case value:
        s.sb.WriteByte('?')
        s.addArg(exp.val)
    case RawExpr:
        s.sb.WriteByte('(')
        s.sb.WriteString(exp.raw)
        s.addArg(exp.args...)
        s.sb.WriteByte(')')
    default:
        return errs.NewErrUnsupportedExpression(expr)
    }
    return nil
}

// ...
// insert.go
// Inserter builds INSERT statements for rows of type T, using the embedded
// builder for SQL text, arguments and dialect-aware quoting.
type Inserter[T any] struct {
    builder
    values         []*T
    db             *DB
    columns        []string
    onDuplicateKey *OnDuplicateKey
}

// NewInserter creates an Inserter bound to db, taking the dialect and its
// identifier quote from db.
func NewInserter[T any](db *DB) *Inserter[T] {
    d := db.dialect
    b := builder{
        sb:      &strings.Builder{},
        dialect: d,
        quoter:  d.quoter(),
    }
    return &Inserter[T]{builder: b, db: db}
}

// Build assembles the INSERT statement with the dialect-aware builder.
// Table and column names are quoted with the dialect's quoter. The column list
// is the model's Fields order, or the fields named via Columns. When
// OnDuplicateKey was configured, the upsert clause is delegated to the dialect.
//
// The buffer and arguments are reset on every call, so Build can be invoked
// repeatedly on the same Inserter without duplicating SQL or arguments.
//
// It returns errs.ErrInsertZeroRow when no rows were supplied, an
// unknown-field error for a bad name passed to Columns, or the registry's or
// dialect's error.
//
// NOTE(review): a nil *T among the values makes Elem() invalid and
// FieldByName panics.
func (i *Inserter[T]) Build() (*Query, error) {
    if len(i.values) == 0 {
        return nil, errs.ErrInsertZeroRow
    }
    // NewInserter allocates the buffer once. Without a reset here, a second
    // Build would append another statement to the first one's text.
    i.sb = &strings.Builder{}
    i.sb.WriteString("INSERT INTO ")

    // Resolve model metadata from the first row.
    var err error
    i.model, err = i.db.r.Get(i.values[0])
    if err != nil {
        return nil, err
    }
    i.quote(i.model.TableName)

    fields := i.model.Fields
    // The caller restricted the column list.
    if len(i.columns) > 0 {
        fields = make([]*model.Field, 0, len(i.columns))
        for _, fd := range i.columns {
            fdMeta, ok := i.model.FieldMap[fd]
            if !ok {
                return nil, errs.NewErrUnknownField(fd)
            }
            fields = append(fields, fdMeta)
        }
    }

    // Explicit column list, taken from the ordered slice because map
    // iteration order is random.
    i.sb.WriteByte('(')
    for idx, field := range fields {
        if idx > 0 {
            i.sb.WriteByte(',')
        }
        i.quote(field.ColName)
    }
    i.sb.WriteByte(')')

    i.sb.WriteString(" VALUES ")
    // Replace, not extend, any arguments left over from an earlier Build.
    i.args = make([]any, 0, len(i.values)*len(fields))
    for j, val := range i.values {
        if j > 0 {
            i.sb.WriteByte(',')
        }
        row := reflect.ValueOf(val).Elem()
        i.sb.WriteByte('(')
        for idx, field := range fields {
            if idx > 0 {
                i.sb.WriteByte(',')
            }
            i.sb.WriteByte('?')
            i.addArg(row.FieldByName(field.GoName).Interface())
        }
        i.sb.WriteByte(')')
    }

    if i.onDuplicateKey != nil {
        if err = i.dialect.buildOnDuplicateKey(&i.builder, i.onDuplicateKey); err != nil {
            return nil, err
        }
    }

    i.sb.WriteByte(';')
    return &Query{
        SQL:  i.sb.String(),
        Args: i.args,
    }, nil
}
// select.go
// Selector builds SELECT statements for the model type T.
type Selector[T any] struct {
    builder
    table   string
    where   []Predicate
    db      *DB
    columns []Selectable
}

// NewSelector creates a Selector bound to db, taking the dialect and its
// identifier quote from db.
func NewSelector[T any](db *DB) *Selector[T] {
    d := db.dialect
    s := &Selector[T]{db: db}
    s.builder = builder{
        sb:      &strings.Builder{},
        dialect: d,
        quoter:  d.quoter(),
    }
    return s
}

// Build assembles "SELECT <columns> FROM <table> [WHERE <predicates>];".
// The table is the model's table name, quoted with the dialect's quoter, unless
// a table was set explicitly, in which case that text is written verbatim.
// The buffer and arguments are reset on every call, so repeated Builds on the
// same Selector produce identical results instead of accumulating arguments.
// It returns the registry error, or any error from building columns or
// predicates.
func (s *Selector[T]) Build() (*Query, error) {
    // Per-build state: a fresh buffer (the field is a pointer and must be
    // initialised) and no arguments from a previous call.
    s.sb = &strings.Builder{}
    s.args = nil

    var err error
    // new(T) yields a *T for the registry lookup.
    s.model, err = s.db.r.Get(new(T))
    if err != nil {
        return nil, err
    }
    sb := s.sb

    sb.WriteString("SELECT ")
    if err = s.buildColumns(); err != nil {
        return nil, err
    }
    sb.WriteString(` FROM `)

    if s.table == "" {
        // Dialect-aware quoting, consistent with Inserter.
        s.quote(s.model.TableName)
    } else {
        sb.WriteString(s.table)
    }

    if len(s.where) > 0 {
        sb.WriteString(" WHERE ")
        if err = s.buildPredicates(s.where); err != nil {
            return nil, err
        }
    }

    sb.WriteByte(';')
    return &Query{
        SQL:  sb.String(),
        Args: s.args,
    }, nil
}
// db.go
// DB decorates *sql.DB with the ORM's model registry, value-accessor factory
// and SQL dialect.
type DB struct {
    r       model.Registry
    db      *sql.DB
    creator valuer.Creator
    dialect Dialect
}

// DBWithDialect returns an option that makes the DB generate SQL for d.
func DBWithDialect(d Dialect) DBOption {
    return func(target *DB) {
        target.dialect = d
    }
}

// OpenDB wraps an existing *sql.DB. By default it uses a new model registry,
// the unsafe-pointer value accessor and the MySQL dialect. opts are then
// applied in order and may override any of these. The returned error is
// always nil in this implementation.
func OpenDB(db *sql.DB, opts ...DBOption) (*DB, error) {
    res := new(DB)
    res.r = model.NewRegistry()
    res.db = db
    res.creator = valuer.NewUnsafeValue
    res.dialect = DialectMysql
    for _, apply := range opts {
        apply(res)
    }
    return res, nil
}
// dialect.go
// Predefined dialects. OpenDB uses DialectMysql by default; select another one
// with DBWithDialect.
var (
    DialectMysql   Dialect = mysqlDialect{}
    DialectSqlLite Dialect = sqliteDialect{}
    DialectPostgre Dialect = postgreDialect{}
)

// buildOnDuplicateKey writes MySQL's " ON DUPLICATE KEY UPDATE ..." clause.
// An Assignment renders as `col`=? and binds its value; a bare Column renders
// as `col`=VALUES(`col`), reusing the value from the attempted insert.
// It returns an unknown-field error for names missing from the model, and an
// unsupported-assignable error for any other Assignable.
func (m mysqlDialect) buildOnDuplicateKey(b *builder, odk *OnDuplicateKey) error {
    sb := b.sb
    sb.WriteString(" ON DUPLICATE KEY UPDATE ")
    for pos, item := range odk.assigns {
        if pos != 0 {
            sb.WriteByte(',')
        }
        switch v := item.(type) {
        case Assignment:
            meta, found := b.model.FieldMap[v.col]
            if !found {
                return errs.NewErrUnknownField(v.col)
            }
            b.quote(meta.ColName)
            sb.WriteString("=?")
            b.addArg(v.val)
        case Column:
            meta, found := b.model.FieldMap[v.name]
            if !found {
                return errs.NewErrUnknownField(v.name)
            }
            b.quote(meta.ColName)
            sb.WriteString("=VALUES(")
            b.quote(meta.ColName)
            sb.WriteByte(')')
        default:
            return errs.NewErrUnsupportedAssignable(item)
        }
    }
    return nil
}

16. INSERT:SQLite UPSERT 实现、方言抽象局限性

// dialect.go
// quoter returns the backtick. SQLite accepts MySQL-style backtick quoting for
// compatibility, which keeps generated identifiers identical to MySQL's.
func (s sqliteDialect) quoter() byte {
    return '`'
}

// buildOnDuplicateKey writes SQLite's upsert clause:
//
//     ON CONFLICT(`c1`,`c2`) DO UPDATE SET `a`=?,`b`=excluded.`b`
//
// odk.conflictColumns (Go field names) form the conflict target. An
// Assignment binds its value as a "?" argument. A bare Column copies the value
// from the row that failed to insert, via SQLite's "excluded." qualifier.
//
// When no conflict columns are configured, the target is omitted
// ("ON CONFLICT DO UPDATE SET ..."), which SQLite accepts since 3.35.0. An
// empty target "ON CONFLICT()" is never valid SQL.
//
// It returns an unknown-field error for names missing from the model, and an
// unsupported-assignable error for any other Assignable.
func (s sqliteDialect) buildOnDuplicateKey(b *builder, odk *OnDuplicateKey) error {
    if len(odk.conflictColumns) == 0 {
        b.sb.WriteString(" ON CONFLICT DO UPDATE SET ")
    } else {
        b.sb.WriteString(" ON CONFLICT(")
        for i, col := range odk.conflictColumns {
            if i > 0 {
                b.sb.WriteByte(',')
            }
            // Quote with the dialect's quoter, like the rest of the clause.
            fd, ok := b.model.FieldMap[col]
            if !ok {
                return errs.NewErrUnknownField(col)
            }
            b.quote(fd.ColName)
        }
        b.sb.WriteString(") DO UPDATE SET ")
    }
    for idx, assign := range odk.assigns {
        if idx > 0 {
            b.sb.WriteByte(',')
        }
        switch a := assign.(type) {
        case Assignment:
            fd, ok := b.model.FieldMap[a.col]
            if !ok {
                return errs.NewErrUnknownField(a.col)
            }
            b.quote(fd.ColName)
            b.sb.WriteString("=?")
            b.addArg(a.val)
        case Column:
            fd, ok := b.model.FieldMap[a.name]
            if !ok {
                return errs.NewErrUnknownField(a.name)
            }
            b.quote(fd.ColName)
            b.sb.WriteString("=excluded.")
            b.quote(fd.ColName)
        default:
            return errs.NewErrUnsupportedAssignable(assign)
        }
    }
    return nil
}
// insert_test.go
// TestSqlite_upsert_Build checks the SQLite upsert form
// "ON CONFLICT(target) DO UPDATE SET ...". Assign values are bound as "?"
// placeholders after the row arguments. Bare columns copy the attempted value
// via "excluded." and add no arguments.
func TestSqlite_upsert_Build(t *testing.T) {
    db, err := Open("sqlite",
        "file:test.db?cache=shared&mode=memory",
        DBWithDialect(DialectSqlLite))
    assert.NoError(t, err)
    tests := []struct {
        name string
        i    QueryBuilder

        wantErr   error
        wantQuery *Query
    }{
        {
            // Explicit values: placeholders, arguments appended after the row.
            name: "upsert",
            i: NewInserter[TestModel](db).Values(&TestModel{
                Id: 12, FirstName: "Tom", Age: 18, LastName: &sql.NullString{String: "Jerry", Valid: true},
            }).OnDuplicateKey().ConflictColumns("FirstName", "Age").Update(
                Assign("FirstName", "Deng"),
                Assign("Age", 19),
            ),
            wantQuery: &Query{
                SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES (?,?,?,?)" +
                    " ON CONFLICT(`first_name`,`age`) DO UPDATE SET `first_name`=?,`age`=?;",
                Args: []any{
                    int64(12), "Tom", int8(18), &sql.NullString{String: "Jerry", Valid: true},
                    "Deng", 19,
                },
            },
        },
        {
            // Bare columns: copy the attempted values through "excluded.".
            name: "update insert column",
            i: NewInserter[TestModel](db).Values(
                &TestModel{
                    Id: 12, FirstName: "Tom", Age: 18, LastName: &sql.NullString{String: "Jerry", Valid: true},
                },
                &TestModel{
                    Id: 13, FirstName: "Dummy", Age: 19, LastName: &sql.NullString{String: "Shan", Valid: true},
                },
            ).OnDuplicateKey().ConflictColumns("Id").Update(C("FirstName"), C("Age")),
            wantQuery: &Query{
                SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES (?,?,?,?),(?,?,?,?)" +
                    " ON CONFLICT(`id`) DO UPDATE SET `first_name`=excluded.`first_name`,`age`=excluded.`age`;",
                Args: []any{
                    int64(12), "Tom", int8(18), &sql.NullString{String: "Jerry", Valid: true},
                    int64(13), "Dummy", int8(19), &sql.NullString{String: "Shan", Valid: true},
                },
            },
        },
    }

    for _, tc := range tests {
        t.Run(tc.name, func(t *testing.T) {
            q, err := tc.i.Build()
            assert.Equal(t, tc.wantErr, err)
            // On error there is no query to compare.
            if err != nil {
                return
            }
            assert.Equal(t, tc.wantQuery, q)
        })
    }
}

17. INSERT:INSERT 执行

// types.go
// Executor is implemented by statements that are run for their effect on the
// database (such as INSERT) rather than for returned rows.
type Executor interface {
    // Exec runs the statement. Failures are carried inside the returned
    // Result (see Result.Err) instead of a separate error value.
    Exec(ctx context.Context) Result
}
// result.go
package orm

import (
    "database/sql"
    "errors"
)

// errEmptyResult is reported by Result's accessors when the Result carries
// neither an error nor an underlying sql.Result, so there is nothing to ask.
var errEmptyResult = errors.New("orm: no sql.Result available")

// Result satisfies sql.Result itself, which allows one Result to wrap another.
var _ sql.Result = Result{}

// Result is the outcome of executing a statement. It bundles the error, if
// any, with the driver's sql.Result so Exec can return a single value.
type Result struct {
    // err is the build or execution error; when set it takes precedence.
    err error
    // res is the driver result; it may be nil when err is set.
    res sql.Result
}

// LastInsertId returns the stored error if there is one. Otherwise it
// delegates to the wrapped sql.Result. It returns errEmptyResult, rather than
// panicking on a nil interface, when no sql.Result was recorded.
func (r Result) LastInsertId() (int64, error) {
    if r.err != nil {
        return 0, r.err
    }
    if r.res == nil {
        return 0, errEmptyResult
    }
    return r.res.LastInsertId()
}

// RowsAffected returns the stored error if there is one. Otherwise it
// delegates to the wrapped sql.Result. It returns errEmptyResult, rather than
// panicking on a nil interface, when no sql.Result was recorded.
func (r Result) RowsAffected() (int64, error) {
    if r.err != nil {
        return 0, r.err
    }
    if r.res == nil {
        return 0, errEmptyResult
    }
    return r.res.RowsAffected()
}

// Err returns the error recorded while building or executing the statement.
func (r Result) Err() error {
    return r.err
}
// insert_test.go
// TestInserter_Exec checks how Inserter.Exec reports build errors, driver
// errors and affected-row counts through Result.
func TestInserter_Exec(t *testing.T) {
    mockDB, mock, err := sqlmock.New()
    require.NoError(t, err)
    db, err := OpenDB(mockDB)
    require.NoError(t, err)

    // Each case's closure runs while the table literal is built, so mock
    // expectations are registered in table order. sqlmock matches them in
    // order by default, which is why the cases below must run in this order.
    tests := []struct {
        name     string
        i        *Inserter[TestModel]
        wantErr  error
        affected int64
    }{
        {
            name: "query error",
            i: func() *Inserter[TestModel] {
                return NewInserter[TestModel](db).Values(&TestModel{}).
                    Columns("Invalid")
            }(),
            wantErr: errs.NewErrUnknownField("Invalid"),
        },
        {
            name: "db error",
            i: func() *Inserter[TestModel] {
                mock.ExpectExec("INSERT INTO .*").
                    WillReturnError(errors.New("db error"))
                return NewInserter[TestModel](db).Values(&TestModel{})
            }(),
            wantErr: errors.New("db error"),
        },
        {
            name: "exec",
            i: func() *Inserter[TestModel] {
                res := driver.RowsAffected(1)
                mock.ExpectExec("INSERT INTO .*").
                    WillReturnResult(res)
                return NewInserter[TestModel](db).Values(&TestModel{})
            }(),
            affected: 1,
        },
    }

    for _, tc := range tests {
        t.Run(tc.name, func(t *testing.T) {
            res := tc.i.Exec(context.Background())
            affected, err := res.RowsAffected()
            assert.Equal(t, tc.wantErr, err)
            if err != nil {
                return
            }
            assert.Equal(t, tc.affected, affected)
        })
    }

    // Every registered expectation must have been consumed; otherwise an Exec
    // silently skipped the database.
    assert.NoError(t, mock.ExpectationsWereMet())
}
// insert.go
// Exec builds the INSERT statement and runs it on the underlying *sql.DB.
//
// ctx is forwarded to ExecContext, so cancellation and deadlines apply to the
// statement; the previous Exec call ignored ctx entirely. Build errors and
// execution errors are both reported through the returned Result.
func (i *Inserter[T]) Exec(ctx context.Context) Result {
    q, err := i.Build()
    if err != nil {
        return Result{err: err}
    }
    res, err := i.db.db.ExecContext(ctx, q.SQL, q.Args...)
    return Result{err: err, res: res}
}

18. INSERT:unsafe 读取字段、总结与面试要点

// insert.go 
// BuildUnsafe renders the INSERT statement for i.values and collects the
// placeholder arguments into i.args.
//
// Generated shape:
//
//	INSERT INTO `table`(`col1`,`col2`,...) VALUES (?,?,...),(?,?,...)[upsert];
//
// The model is re-resolved from i.values[0] on every call. Column values are
// read through the Value returned by i.db.creator, so reflection versus unsafe
// access depends on the Creator the DB was configured with.
//
// Errors:
//   - errs.ErrInsertZeroRow when there are no values.
//   - An unknown-field error when i.columns names a field the model lacks.
//   - Any error from the registry, Value.Filed, or the dialect's upsert
//     rendering.
//
// NOTE(review): i.sb is appended to without being reset here, so calling this
// twice on the same Inserter presumably yields a doubled statement. Confirm.
func (i *Inserter[T]) BuildUnsafe() (*Query, error) {
    if len(i.values) == 0 {
        return nil, errs.ErrInsertZeroRow
    }
    //var sb strings.Builder
    i.sb.WriteString("INSERT INTO ")
    // Resolve the model metadata from the first value.
    var err error
    i.model, err = i.db.r.Get(i.values[0])
    if err != nil {
        return nil, err
    }
    // Append the table name via i.quote.
    //sb.WriteByte('`')
    //sb.WriteString(m.TableName)
    //sb.WriteByte('`')
    i.quote(i.model.TableName)
    // Spell out the column list explicitly so the order of the VALUES
    // placeholders is fixed; otherwise keeping them in sync is hard.
    i.sb.WriteByte('(')

    fields := i.model.Fields
    // The caller chose specific columns: use exactly those, in that order.
    if len(i.columns) > 0 {
        fields = make([]*model.Field, 0, len(i.columns))
        for _, fd := range i.columns {
            fdMeta, ok := i.model.FieldMap[fd]
            // The caller passed a column the model does not know.
            if !ok {
                return nil, errs.NewErrUnknownField(fd)
            }
            fields = append(fields, fdMeta)
        }
    }

    // Iterate the ordered fields slice. Map iteration order is random, so
    // FieldMap/ColumnMap cannot be used to build the (x,x,x,x) column list.
    for idx, field := range fields {
        if idx > 0 {
            i.sb.WriteByte(',')
        }
        //i.sb.WriteByte('`')
        //i.sb.WriteString(field.ColName)
        //i.sb.WriteByte('`')
        i.quote(field.ColName)
    }
    i.sb.WriteByte(')')

    i.sb.WriteString(" VALUES ")
    //sb.WriteByte('(')
    // Fresh argument list, sized for one argument per column per row.
    i.args = make([]any, 0, len(i.values)*len(fields))
    for j, val := range i.values {
        if j > 0 {
            i.sb.WriteByte(',')
        }
        i.sb.WriteByte('(')

        v := i.db.creator(i.model, val)

        for idx, field := range fields {
            if idx > 0 {
                i.sb.WriteByte(',')
            }
            i.sb.WriteByte('?')
            // Read this field's value from the current row object.
            arg, err := v.Filed(field.GoName)
            if err != nil {
                return nil, err
            }
            //args = append(args, arg)
            i.addArg(arg)
        }
        i.sb.WriteByte(')')
    }

    // Optional upsert clause (ON DUPLICATE KEY / ON CONFLICT), rendered by the dialect.
    if i.onDuplicateKey != nil {
        err = i.dialect.buildOnDuplicateKey(&i.builder, i.onDuplicateKey)
        if err != nil {
            return nil, err
        }
    }

    i.sb.WriteByte(';')

    // err is always nil here, because every earlier failure returned early.
    return &Query{
        SQL:  i.sb.String(),
        Args: i.args,
    }, err
}
// unsafe.go
// Filed returns the current value of the Go field called name. The value is
// read straight from memory at the field's recorded offset from u.address.
// Names missing from the model's FieldMap yield an unknown-field error.
func (u unsafeValue) Filed(name string) (any, error) {
    meta, found := u.model.FieldMap[name]
    if !found {
        return nil, errs.NewErrUnknownField(name)
    }
    // Pointer arithmetic is kept in one expression (uintptr -> unsafe.Pointer),
    // assuming u.address is an unsafe.Pointer to the start of the struct.
    addr := unsafe.Pointer(uintptr(u.address) + meta.Offset)
    return reflect.NewAt(meta.Typ, addr).Elem().Interface(), nil
}
// reflect.go
package valuer

import (
    "database/sql"
    "my_framework/orm/internal/errs"
    "my_framework/orm/model"
    "reflect"
)

// reflectValue implements Value using package reflect.
type reflectValue struct {
    model *model.Model
    // val is the struct value itself. NewReflectValue stores
    // reflect.ValueOf(ptr).Elem() here, not the pointer to T.
    // Previous version kept the raw pointer:
    //val any
    val reflect.Value
}

// Filed returns the value of the struct field called name.
//
// When the struct has no such field, FieldByName yields the zero
// reflect.Value, and calling Interface on it would panic. Filed instead
// returns errs.NewErrUnknownField(name), the same error unsafeValue.Filed
// reports. Note that Interface still panics for unexported fields.
func (r reflectValue) Filed(name string) (any, error) {
    fv := r.val.FieldByName(name)
    if !fv.IsValid() {
        return nil, errs.NewErrUnknownField(name)
    }
    return fv.Interface(), nil
}

// Compile-time check that NewReflectValue has the Creator signature.
var _ Creator = NewReflectValue

// SetColumns copies the current row of rows into the struct held in r.val.
// Result columns are matched to fields through r.model.ColumnMap, and an
// unknown column yields errs.NewErrUnknownColumn. (The column-scanning part
// of the body is elided in this excerpt.)
//
// r.val already is the struct value, because NewReflectValue stores
// reflect.ValueOf(ptr).Elem(). Fields are therefore looked up on it directly:
// calling Elem on a struct Value panics. It is addressable, so Set works.
func (r reflectValue) SetColumns(rows *sql.Rows) error {
    // ...

    // Copy the scanned values into the target struct.
    tpValue := r.val
    for i, c := range cs {
        fd, ok := r.model.ColumnMap[c]
        if !ok {
            return errs.NewErrUnknownColumn(c)
        }
        tpValue.FieldByName(fd.GoName).Set(valElems[i])
    }
    return nil
}

// NewReflectValue returns a reflection-based Value for val. val is expected to
// be a pointer; it is dereferenced once here so the Value's methods work on
// the pointed-to struct directly.
func NewReflectValue(model *model.Model, val any) Value {
    elem := reflect.ValueOf(val).Elem()
    return reflectValue{val: elem, model: model}
}
// value.go
// Value gives the ORM field-level access to one model instance.
type Value interface {
    // SetColumns fills the underlying instance from the current row of rows.
    SetColumns(rows *sql.Rows) error
    // Filed returns the value of the Go field called name.
    Filed(name string) (any, error)
}

20.第六周作业:丰富 SELECT 语句[选学]

21.第六周 SELECT 作业讲解[选学]

13 第七周:ORM 框架之事务 API、AOP 方案与集成测试

事务api

1. 事务 API:不同框架设计分析、设计与实现


2. 事务 API:事务闭包 API、总结与面试要点


// orm/transaction.go
package orm

// ...

// Tx wraps a *sql.Tx together with its *DB (presumably the one that began it).
type Tx struct {
    tx *sql.Tx
    db *DB

    // done supports transaction propagation: Commit, Rollback and
    // RollbackIfNotCommit set it, and BeginTxV2 reuses a Tx found in the
    // context only while done is false.
    done bool
}

// ...

// Commit commits the underlying *sql.Tx. The Tx is flagged as done first, so
// BeginTxV2 stops handing it out regardless of whether the commit succeeds.
func (t *Tx) Commit() error {
    t.done = true
    err := t.tx.Commit()
    return err
}

// Rollback aborts the underlying *sql.Tx. The Tx is flagged as done first, so
// BeginTxV2 stops handing it out regardless of whether the rollback succeeds.
func (t *Tx) Rollback() error {
    t.done = true
    err := t.tx.Rollback()
    return err
}

// RollbackIfNotCommit rolls the transaction back unless it has already
// finished. If (*sql.Tx).Rollback reports sql.ErrTxDone, the transaction was
// already committed or rolled back, and nil is returned. Any other rollback
// error is returned as-is. The Tx is marked done in every case.
func (t *Tx) RollbackIfNotCommit() error {
    t.done = true
    if err := t.tx.Rollback(); err != sql.ErrTxDone {
        return err
    }
    return nil
}
// orm/db.go
package orm

// ...

// txKey is the context key under which BeginTxV2 stores the active *Tx. Being
// an unexported type, it cannot collide with keys defined in other packages.
type txKey struct{}

// BeginTxV2 starts a transaction, or reuses one already carried by ctx.
//
// If ctx holds a non-nil *Tx under txKey{} that is not done yet, that Tx is
// returned together with the unchanged ctx. Otherwise a new transaction is
// begun via db.BeginTx and stored in a derived context, which is returned so
// that callees see the same transaction.
//
// On error the original ctx is returned instead of nil, so callers that keep
// using the returned context never end up holding a nil Context.
func (db *DB) BeginTxV2(ctx context.Context, opts *sql.TxOptions) (context.Context, *Tx, error) {
    // Reuse an existing transaction that has been neither committed nor rolled back.
    if tx, ok := ctx.Value(txKey{}).(*Tx); ok && tx != nil && !tx.done {
        return ctx, tx, nil
    }
    tx, err := db.BeginTx(ctx, opts)
    if err != nil {
        return ctx, nil, err
    }
    return context.WithValue(ctx, txKey{}, tx), tx, nil
}

// 必须要已有事务
//func (db *DB) BeginTxV3(ctx context.Context, opts *sql.TxOptions) (  *Tx, error) {
//    val := ctx.Value(txKey{})
//    tx, ok := val.(*Tx)
//    if ok {
//        return tx, nil
//    }
//    return  nil,errors.New("没有事务")
//}

// DoTx runs fn inside a new transaction begun with db.BeginTx(ctx, opts).
//
// Outcomes:
//   - fn returns nil: the transaction is committed and the Commit error, if
//     any, is returned.
//   - fn returns an error or panics: the transaction is rolled back, and the
//     returned error is errs.NewErrFailedToRollbackTx(fnErr, rollbackErr,
//     panicked).
//
// A panic is not recovered: after the rollback it keeps propagating to the
// caller. fn receives ctx as passed in; the transaction is handed over
// explicitly as tx.
//
// NOTE(review): the error is wrapped with NewErrFailedToRollbackTx even when
// the rollback itself succeeded (rollbackErr == nil). Confirm that helper
// handles that case as intended.
func (db *DB) DoTx(
    ctx context.Context,
    fn func(ctx context.Context, tx *Tx) error,
    opts *sql.TxOptions,
) (err error) {
    tx, err := db.BeginTx(ctx, opts)
    if err != nil {
        return err
    }
    // Stays true unless fn returns normally. The deferred func uses it to
    // detect that fn panicked.
    panicked := true
    // err is the named result, so the assignments below replace what DoTx returns.
    defer func() {
        if panicked || err != nil {
            e := tx.Rollback()
            err = errs.NewErrFailedToRollbackTx(err, e, panicked)
        } else {
            err = tx.Commit()
        }
    }()
    err = fn(ctx, tx)
    panicked = false
    return err
}

AOP 方案

3. AOP 方案:不同框架设计分析、方案总结



// orm/middleware.go
package orm

import "context"

// QueryContent describes the query a Handler is asked to run.
type QueryContent struct {
    // Type is the kind of statement: SELECT, INSERT, UPDATE or DELETE.
    Type string
    // Builder is the query itself; Build yields its SQL and arguments.
    Builder QueryBuilder
}

// QueryResult is what a Handler returns.
type QueryResult struct {
    // Result's dynamic type depends on the query:
    // for SELECT it is *T or []*T,
    // for every other statement it is Result.
    Result any
    // Err reports a failure of the query itself.
    Err error
}

// Handler runs the query described by qc and returns its outcome.
type Handler func(ctx context.Context, qc *QueryContent) *QueryResult

// Middleware wraps next with extra behaviour (logging, tracing, metrics, ...)
// and returns the wrapped Handler. It decides whether and how next is called.
type Middleware func(next Handler) Handler

4. AOP 方案:Middleware 接入与 querylog

//orm/middleware/querylog/middleware_test.go
package querylog

import (
    "context"
    "database/sql"
    _ "github.com/go-sql-driver/mysql"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
    "my_framework/orm"
    "testing"
)

// TestMiddlewareBuilder_Build checks that the query-log middleware reports the
// SQL text and arguments for both a SELECT and an INSERT.
func TestMiddlewareBuilder_Build(t *testing.T) {
    var (
        gotQuery string
        gotArgs  []any
    )
    // Capture what the middleware would log instead of printing it.
    mb := (&MiddlewareBuilder{}).LogFunc(func(q string, as []any) {
        gotQuery, gotArgs = q, as
    })

    db, err := orm.Open("mysql",
        "root:123456@tcp(127.0.0.1:3307)/mybatis",
        orm.DBWithMiddlewares(mb.Build()))
    require.NoError(t, err)

    ctx := context.Background()

    // Only the logged statement matters here; the query outcome is ignored.
    _, _ = orm.NewSelector[TestModel](db).Where(orm.C("Id").Eq(10)).Get(ctx)
    assert.Equal(t, "SELECT * FROM `test_model` WHERE `id` = ?;", gotQuery)
    assert.Equal(t, []any{10}, gotArgs)

    _ = orm.NewInserter[TestModel](db).Values(&TestModel{Id: 19}).Exec(ctx)
    assert.Equal(t, "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES (?,?,?,?);", gotQuery)
    assert.Equal(t, []any{int64(19), "", int8(0), (*sql.NullString)(nil)}, gotArgs)
}

// TestModel maps to the test_model table used by the assertions above.
// The expected INSERT lists its columns in this field order, so do not reorder.
type TestModel struct {
    Id        int64
    FirstName string
    Age       int8
    LastName  *sql.NullString
}
// orm/middleware/querylog/middleware.go
package querylog

import (
    "context"
    "log"
    "my_framework/orm"
)

// MiddlewareBuilder builds an orm.Middleware that logs every query.
type MiddlewareBuilder struct {
    // logFunc receives the SQL text and arguments of each query that builds
    // successfully. NewMiddlewareBuilder sets a log.Printf-based default; in a
    // zero-value MiddlewareBuilder it is nil.
    logFunc func(query string, args []any)
    //logFunc func(query string, args ...)
}

// NewMiddlewareBuilder returns a builder whose log function prints each
// statement and its arguments with log.Printf. Replace it with LogFunc to plug
// in your own logger.
func NewMiddlewareBuilder() *MiddlewareBuilder {
    defaultLog := func(query string, args []any) {
        log.Printf("sql: %s, args: %v\n", query, args)
    }
    return &MiddlewareBuilder{logFunc: defaultLog}
}

// LogFunc sets the function used to log each statement's SQL text and
// arguments, and returns m so calls can be chained.
func (m *MiddlewareBuilder) LogFunc(fn func(query string, args []any)) *MiddlewareBuilder {
    m.logFunc = fn
    return m
}

// Build returns a middleware that builds each query, passes its SQL text and
// arguments to the configured log function, and then calls next.
//
// If the query fails to build, the error is returned in a QueryResult and next
// is not called. When no log function was configured (e.g. a zero-value
// MiddlewareBuilder), the same log.Printf-based default as
// NewMiddlewareBuilder is used instead of calling a nil func.
func (m MiddlewareBuilder) Build() orm.Middleware {
    logFunc := m.logFunc
    if logFunc == nil {
        logFunc = func(query string, args []any) {
            log.Printf("sql: %s, args: %v\n", query, args)
        }
    }
    return func(next orm.Handler) orm.Handler {
        return func(ctx context.Context, qc *orm.QueryContent) *orm.QueryResult {
            q, err := qc.Builder.Build()
            if err != nil {
                // Build errors are returned to the caller rather than logged here.
                return &orm.QueryResult{
                    Err: err,
                }
            }
            logFunc(q.SQL, q.Args)
            return next(ctx, qc)
        }
    }
}
//orm/insert.go
package orm

// ...

// Exec runs the INSERT through the DB's middleware chain, with execHandler as
// the innermost handler, and turns the chain's QueryResult into a Result.
func (i *Inserter[T]) Exec(ctx context.Context) Result {
    var h Handler = i.execHandler
    mdls := i.db.mdls
    // Wrap from the last middleware to the first, so mdls[0] ends up outermost.
    for n := range mdls {
        h = mdls[len(mdls)-1-n](h)
    }
    qr := h(ctx, &QueryContent{
        Type:    "INSERT",
        Builder: i,
    })
    out := Result{err: qr.Err}
    if qr.Result != nil {
        out.res = qr.Result.(sql.Result)
    }
    return out
}

// Compile-time check that Inserter.execHandler has the Handler signature.
var _ Handler = (&Inserter[int]{}).execHandler

// execHandler is the innermost Handler for INSERT. It builds the statement
// and executes it on the underlying *sql.DB.
//
// ctx is forwarded to ExecContext, so cancellation and deadlines reach the
// driver; the previous Exec call ignored ctx. The error is reported both in
// QueryResult.Err and inside the Result value stored in QueryResult.Result.
func (i *Inserter[T]) execHandler(ctx context.Context, qc *QueryContent) *QueryResult {
    q, err := i.Build()
    if err != nil {
        return &QueryResult{
            Err:    err,
            Result: Result{err: err},
        }
    }
    res, err := i.db.db.ExecContext(ctx, q.SQL, q.Args...)
    return &QueryResult{
        Err:    err,
        Result: Result{err: err, res: res},
    }
}
//orm/select.go
package orm

// ...

// Get runs the SELECT through the DB's middleware chain, with getHandler as
// the innermost handler, and returns the first row as *T.
func (s *Selector[T]) Get(ctx context.Context) (*T, error) {
    var h Handler = s.getHandler
    mdls := s.db.mdls
    // Wrap from the last middleware to the first, so mdls[0] ends up outermost.
    for n := range mdls {
        h = mdls[len(mdls)-1-n](h)
    }
    qr := h(ctx, &QueryContent{
        Type:    "SELECT",
        Builder: s,
    })
    if qr.Result == nil {
        return nil, qr.Err
    }
    return qr.Result.(*T), qr.Err
}

// Compile-time check that Selector.getHandler has the Handler signature.
var _ Handler = (&Selector[any]{}).getHandler

// getHandler is the innermost Handler for SELECT. It builds the query, runs it
// with QueryContext and scans the first row into a new *T.
//
// The rows are always closed before returning, so the connection goes back to
// the pool even when more rows remain. If there is no first row, rows.Err() is
// checked first, so an iteration or driver error is not misreported as
// ErrNoRows.
func (s *Selector[T]) getHandler(ctx context.Context, qc *QueryContent) *QueryResult {
    q, err := s.Build()
    if err != nil {
        return &QueryResult{
            Err: err,
        }
    }

    // Run the query and process the result set.
    rows, err := s.db.db.QueryContext(ctx, q.SQL, q.Args...)
    if err != nil {
        return &QueryResult{
            Err: err,
        }
    }
    defer func() { _ = rows.Close() }()

    if !rows.Next() {
        if err := rows.Err(); err != nil {
            return &QueryResult{
                Err: err,
            }
        }
        return &QueryResult{
            Err: ErrNoRows,
        }
    }

    tp := new(T)
    val := s.db.creator(s.model, tp)
    err = val.SetColumns(rows)
    return &QueryResult{
        Result: tp,
        Err:    err,
    }
}
// orm/db.go
package orm

// DB wraps a *sql.DB. The ORM's shared configuration (registry, dialect,
// value creator, middlewares) lives in the embedded core.
type DB struct {
    //r       model.Registry
    db *sql.DB
    //creator valuer.Creator
    //dialect Dialect
    core
}

// DBWithMiddlewares sets the middleware chain of the DB, replacing any chain
// configured earlier. mdls[0] becomes the outermost wrapper.
//
// The slice is copied when the option is applied. A caller passing its own
// slice (DBWithMiddlewares(ms...)) therefore cannot change the DB's chain by
// mutating ms afterwards, and several DBs built from one option do not share
// a backing array.
func DBWithMiddlewares(mdls ...Middleware) DBOption {
    return func(db *DB) {
        db.mdls = append([]Middleware(nil), mdls...)
        //db.mdls = append(db.mdls, mdls...)
    }
}

// ...
// orm/core.go
package orm

import (
    "my_framework/orm/internal/valuer"
    "my_framework/orm/model"
)

// core holds state shared by DB and, presumably, the query builders that use it.
type core struct {
    // model is the metadata of the model being operated on, once resolved.
    model   *model.Model
    dialect Dialect
    // creator builds a valuer.Value for reading and writing a model instance.
    creator valuer.Creator
    // r looks up model metadata; Get is called with a model pointer.
    r       model.Registry
    // mdls is the middleware chain; mdls[0] is applied last and so is outermost.
    mdls    []Middleware
}

5. AOP 方案:Middleware 各种实现、总结与面试要点


// orm/middleware/opentelemetry/middleware.go
package opentelemetry

import (
    "context"
    "fmt"
    "go.opentelemetry.io/otel"
    "go.opentelemetry.io/otel/attribute"
    "go.opentelemetry.io/otel/trace"
    "my_framework/orm"
)

// instrumentationName names the tracer obtained from the global
// TracerProvider when MiddlewareBuilder.Tracer is nil.
const instrumentationName = "orm/middleware/opentelemetry/middleware.go"

// MiddlewareBuilder builds an OpenTelemetry tracing orm.Middleware.
type MiddlewareBuilder struct {
    // Tracer creates the spans. When nil, Build uses a tracer named
    // instrumentationName from the global TracerProvider.
    Tracer trace.Tracer
}

// Build returns a middleware that wraps every query in a span.
//
// The span is named "<type>-<table>" (e.g. "SELECT-test_model"), or just
// "<type>" when the table is unknown. It carries "sql", "table" and
// "component" attributes and records the query error, if any.
//
// qc.Model may be nil: not every caller fills it in (e.g. an Exec that only
// sets Type and Builder). In that case the table name is left empty instead of
// dereferencing nil.
func (m MiddlewareBuilder) Build() orm.Middleware {
    if m.Tracer == nil {
        m.Tracer = otel.GetTracerProvider().Tracer(instrumentationName)
    }
    return func(next orm.Handler) orm.Handler {
        return func(ctx context.Context, qc *orm.QueryContent) *orm.QueryResult {
            tbl := ""
            if qc.Model != nil {
                tbl = qc.Model.TableName
            }
            spanName := qc.Type
            if tbl != "" {
                spanName = fmt.Sprintf("%s-%s", qc.Type, tbl)
            }
            spanCtx, span := m.Tracer.Start(ctx, spanName)
            defer span.End()

            // Building here only feeds the "sql" attribute. A build error is
            // ignored; presumably the terminal handler rebuilds and reports it.
            if q, _ := qc.Builder.Build(); q != nil {
                span.SetAttributes(attribute.String("sql", q.SQL))
            }
            span.SetAttributes(
                attribute.String("table", tbl),
                attribute.String("component", "orm"),
            )

            res := next(spanCtx, qc)
            if res != nil && res.Err != nil {
                span.RecordError(res.Err)
            }
            return res
        }
    }
}
// orm/middleware/prometheus/middleware.go
package prometheus

import (
    "context"
    "github.com/prometheus/client_golang/prometheus"
    "my_framework/orm"
    "time"
)

// MiddlewareBuilder builds an orm.Middleware that records query latency in a
// Prometheus summary labelled by query type and table.
type MiddlewareBuilder struct {
    // Namespace, Subsystem, Name and Help are passed through to
    // prometheus.SummaryOpts.
    Namespace string
    Subsystem string
    Name      string
    Help      string
}

// Build returns a middleware that observes each query's execution time, in
// milliseconds, in a SummaryVec labelled by query type and table name.
//
// Registration uses prometheus.Register instead of MustRegister. If Build is
// called again with the same options, the already registered SummaryVec is
// reused instead of panicking. Any other registration error still panics, as
// before.
//
// Durations are observed as fractional milliseconds. Duration.Milliseconds()
// truncates, which would record every sub-millisecond query as 0. qc.Model
// may be nil, in which case the table label is empty.
func (m MiddlewareBuilder) Build() orm.Middleware {
    vector := prometheus.NewSummaryVec(prometheus.SummaryOpts{
        Namespace: m.Namespace,
        Subsystem: m.Subsystem,
        Name:      m.Name,
        Help:      m.Help,
        Objectives: map[float64]float64{
            0.5:   0.01,
            0.75:  0.01,
            0.90:  0.01,
            0.99:  0.001,
            0.999: 0.0001,
        },
    }, []string{"type", "table"})
    if err := prometheus.Register(vector); err != nil {
        are, ok := err.(prometheus.AlreadyRegisteredError)
        if !ok {
            panic(err)
        }
        existing, ok := are.ExistingCollector.(*prometheus.SummaryVec)
        if !ok {
            panic(err)
        }
        vector = existing
    }

    // Possible extensions: an error counter vector, a histogram instead of
    // (or next to) the summary, and an active-query gauge.
    return func(next orm.Handler) orm.Handler {
        return func(ctx context.Context, qc *orm.QueryContent) *orm.QueryResult {
            startTime := time.Now()
            defer func() {
                tbl := ""
                if qc.Model != nil {
                    tbl = qc.Model.TableName
                }
                elapsedMs := float64(time.Since(startTime)) / float64(time.Millisecond)
                vector.WithLabelValues(qc.Type, tbl).Observe(elapsedMs)
            }()
            return next(ctx, qc)
        }
    }
}
// orm/select.go
package orm

// ...

// Build renders the SELECT statement for this Selector. (Excerpt: the rest of
// the body is elided.)
//
// It starts a fresh strings.Builder on every call. The model metadata for T is
// resolved only if s.model is still nil; Get sets it before running the
// middleware chain.
func (s *Selector[T]) Build() (*Query, error) {
    //var sb strings.Builder
    s.sb = &strings.Builder{} // s.sb is a pointer, so it must be initialised before use

    if s.model == nil {
        var err error
        s.model, err = s.db.r.Get(new(T)) // new(T) yields a *T for the registry lookup
        if err != nil {
            return nil, err
        }
    }

    sb := s.sb

    // ...
}

// Get resolves the model metadata for T and runs the SELECT through the DB's
// middleware chain, with getHandler innermost. It returns the first row as *T.
// qc.Model is filled in so middleware can read the table name.
func (s *Selector[T]) Get(ctx context.Context) (*T, error) {
    m, err := s.db.r.Get(new(T)) // new(T) yields a *T for the registry lookup
    s.model = m
    if err != nil {
        return nil, err
    }

    var h Handler = s.getHandler
    mdls := s.db.mdls
    // Wrap from the last middleware to the first, so mdls[0] ends up outermost.
    for n := range mdls {
        h = mdls[len(mdls)-1-n](h)
    }
    qr := h(ctx, &QueryContent{
        Type:    "SELECT",
        Builder: s,
        Model:   s.model,
    })
    if qr.Result == nil {
        return nil, qr.Err
    }
    return qr.Result.(*T), qr.Err
}
// orm/insert.go

package orm
// Build renders the INSERT statement for i.values. (Excerpt: the rest of the
// body is elided.)
//
// Returns errs.ErrInsertZeroRow when no values were given. The model metadata
// is resolved from i.values[0] only if i.model is still nil; Exec resolves it
// up front.
func (i *Inserter[T]) Build() (*Query, error) {
    if len(i.values) == 0 {
        return nil, errs.ErrInsertZeroRow
    }
    //var sb strings.Builder
    i.sb.WriteString("INSERT INTO ")
    var err error
    if i.model == nil {
        // Resolve the model metadata from the first value.
        i.model, err = i.db.r.Get(i.values[0])
        if err != nil {
            return nil, err
        }
    }
    // Append the table name.
    // ...
}

// Exec resolves the model metadata for T and runs the INSERT through the
// middleware chain, with execHandler innermost. (Excerpt: parts of the body
// are elided.)
//
// qc.Model is filled in so middleware such as tracing and metrics can read the
// table name. A registry error is returned in the Result without running the
// chain.
func (i *Inserter[T]) Exec(ctx context.Context) Result {
    var err error
    i.model, err = i.db.r.Get(new(T))
    if err != nil {
        return Result{
            err: err,
        }
    }

    root := i.execHandler
    // ...
    res := root(ctx, &QueryContent{
        Type:    "INSERT",
        Builder: i,
        Model:   i.model,
    })
    // ...
}
// orm/middleware.go
package orm

import (
    "context"
    "my_framework/orm/model"
)

// QueryContent describes the query a Handler is asked to run.
type QueryContent struct {
    // Type is the kind of statement: SELECT, INSERT, UPDATE or DELETE.
    Type string
    // Builder is the query itself; Build yields its SQL and arguments.
    Builder QueryBuilder
    // Model is the metadata of the model the query operates on. It is nil when
    // the caller did not fill it in.
    Model *model.Model
}

// ...

集成测试

6. 集成测试:起步与 MySQL 的增删改查


// … 跳过

原生查询详解


  转载请注明: malred-blog Go 实战训练营

 上一篇
珠峰架构2021 珠峰架构2021
预习课(架构)2021 第一期 Vue3 架构课任务 1:1.vue3 变化介绍 任务2:2.vue3架构组织monorepo 项目结构 任务3:3.根据需要打包的信息进行打包 pnpm i typescript rollup rollup
2023-09-11
下一篇 
安卓项目下载问题解决 安卓项目下载问题解决
在项目的 settings.gradle 里加入 pluginManagement { repositories { // 就是这个 maven { url 'https://maven.aliy
2023-08-29
  目录