07 第一周:Web 框架之 Server 与路由树
1. Web 框架概览:学习路线
2. Web 框架概览:Beego 框架分析
package beego
import "github.com/beego/beego/v2/server/web"
// UserController is a Beego controller; embedding web.Controller gives it
// access to the per-request context through c.Ctx.
type UserController struct {
	web.Controller
}

// GetUser writes a fixed greeting as the response body.
func (c *UserController) GetUser() {
	const greeting = "你好,我是malred"
	c.Ctx.WriteString(greeting)
}

// CreateUser decodes the JSON request body into a User and echoes it back as
// JSON. When decoding fails, the decoding error text is written instead.
func (c *UserController) CreateUser() {
	var u User
	if err := c.Ctx.BindJSON(&u); err != nil {
		c.Ctx.WriteString(err.Error())
		return
	}
	// The JSONResp error is intentionally discarded in this demo.
	_ = c.Ctx.JSONResp(&u)
}

// User is the payload bound from, and written back as, JSON by CreateUser.
type User struct {
	Name string
}
package beego
import (
"github.com/beego/beego/v2/server/web"
"testing"
)
// TestUserController is a manual demo: it starts a Beego server on :8081 and
// maps GET /user to UserController.GetUser. web.Run presumably blocks while
// serving, so this only returns when the process is stopped — TODO confirm.
func TestUserController(t *testing.T) {
	// This configuration option exists only in Beego.
	web.BConfig.CopyRequestBody = true
	c := &UserController{}
	// "get:GetUser" routes the GET method to the GetUser controller method.
	web.Router("/user", c, "get:GetUser")
	web.Run(":8081")
}
作者认为可以将 response 放入 output,request 放入 input
不同 server 之间是隔离的
3. Web 框架概览:Gin 框架分析
package gin
import (
"github.com/gin-gonic/gin"
)
// UserController groups Gin handlers for the /user resource; it holds no state.
type UserController struct{}

// GetUser replies with status 200 and the plain-text body "hello world".
func (uc *UserController) GetUser(ctx *gin.Context) {
	const (
		status = 200
		body   = "hello world"
	)
	ctx.String(status, body)
}
package gin
import (
"net/http"
"testing"
"github.com/gin-gonic/gin"
)
// TestUserController_GetUser is a manual demo: it registers three routes on a
// Gin engine and serves on :8082 (g.Run presumably blocks until the process
// stops — TODO confirm).
func TestUserController_GetUser(t *testing.T) {
	g := gin.Default()
	ctrl := &UserController{}
	// A method value can be registered directly as a handler.
	g.GET("/user", ctrl.GetUser)
	g.POST("/user", func(ctx *gin.Context) {
		ctx.String(http.StatusOK, "hello %s", "world")
	})
	g.GET("/static", func(ctx *gin.Context) {
		// Read the file
		// Write the response
	})
	// The error returned by Run is discarded.
	_ = g.Run(":8082")
}
handle 接口是核心,作者认为没必要将 static 放入核心接口里,而是可以提供一个实现
4. Web 框架概览:Iris 框架分析
// v12版本需要go1.21
package iris
import (
"github.com/kataras/iris/v12"
"testing"
)
// TestHelloWorld is a manual Iris demo: "/" renders a small HTML greeting and
// the app listens on :8083. Errors from HTML and Listen are discarded.
func TestHelloWorld(t *testing.T) {
	app := iris.New()
	app.Get("/", func(ctx iris.Context) {
		// NOTE(review): the "%s" suggests printf-style formatting with "world";
		// confirm against the Iris Context.HTML docs.
		_, _ = ctx.HTML("hello <strong>%s</strong>!", "world")
	})
	_ = app.Listen(":8083")
}
5. Web 框架概览:Echo 框架分析与对比总结
package echo
import (
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"net/http"
"testing"
)
// TestHelloWorld is a manual Echo demo: it installs the Logger and Recover
// middleware, routes GET "/" to hello and serves on :8084.
func TestHelloWorld(t *testing.T) {
	// echo instance
	e := echo.New()
	// Middleware
	e.Use(middleware.Logger())
	e.Use(middleware.Recover())
	// Routers
	e.GET("/", hello)
	// Start server; the error Start returns is handed to Logger.Fatal.
	e.Logger.Fatal(e.Start(":8084"))
}
// hello is an Echo handler that replies 200 OK with the text "hello world".
func hello(c echo.Context) error {
	msg := "hello world"
	return c.String(http.StatusOK, msg)
}
作者认为要做隔离就再弄一个 echo 对象,而不是在一个 echo 里放 namespace 和普通 route
6. Server 详解与面试要点
既然有路由管理功能,就得有路由注册的功能
// context.go
package web
import "net/http"
// Context bundles the request and the response writer of one HTTP request, so
// handlers receive a single value.
type Context struct {
	Req  *http.Request
	Resp http.ResponseWriter
}
// server_test.go
package web
import (
"fmt"
"net/http"
_ "net/http"
"testing"
)
// TestServer sketches the registration API. At this stage HTTPServer.AddRoute
// is still a stub that panics with "implement me", so this test panics on the
// first AddRoute call.
func TestServer(t *testing.T) {
	h := &HTTPServer{} // NewServer
	h.AddRoute(http.MethodGet, "/user", func(ctx Context) {
		fmt.Println("user")
	})
	handler1 := func(ctx Context) {
		fmt.Println("1")
	}
	handler2 := func(ctx Context) {
		fmt.Println("2")
	}
	// Registering a single handler is enough: it can call several handlers
	// itself, which is why an AddRoutes(handlers...) API is unnecessary.
	h.AddRoute(http.MethodGet, "/user", func(ctx Context) {
		handler1(ctx)
		handler2(ctx)
	})
	// Get exists only on the concrete type (h := &HTTPServer{}), not on the
	// Server interface.
	h.Get("/user", func(ctx Context) {})
	// h.AddRoutes(
	// 	http.MethodGet, "/user", handler1, handler2,
	// )

	// Usage 1: delegate everything to net/http.
	// The second argument (the handler) is where our framework plugs into
	// the http package.
	// http.ListenAndServe(":8081", h)
	// http.ListenAndServeTLS(":443", "", "", h)

	// Usage 2: manage the listener ourselves.
	h.Start(":8081")
}
// server.go
package web
import (
"net"
"net/http"
)
// HandleFunc is the signature of the business logic registered on a route.
type HandleFunc func(ctx Context)

// Compile-time assertion that *HTTPServer implements Server.
var _ Server = &HTTPServer{}

// Server is the framework's core abstraction: an http.Handler that can be
// started and that accepts route registrations.
type Server interface {
	http.Handler
	Start(addr string) error
	// StartHttp() error

	// AddRoute registers a route.
	// method: the HTTP method
	// path: the route path
	// handleFunc: the business logic
	AddRoute(method string, path string, handleFunc HandleFunc)
	// Registering several handlers at once is unnecessary; users can build
	// it on top of AddRoute.
	// AddRoutes(method string, path string, handleFunc ...HandleFunc)
}

// type HTTPSServer struct {
// 	HTTPServer // decorates HTTPServer
// }

// HTTPServer is the plain-HTTP implementation of Server.
type HTTPServer struct {
	// addr string // could also be passed at construction instead of via Start
}
// ServeHTTP implements http.Handler and is the single entry point for every
// request. It wraps the request/response pair in a Context and passes it on
// to serve, which will do the route lookup.
func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	h.serve(&Context{Req: req, Resp: w})
}

// serve will find the route and run the matched business logic; it is still
// an empty placeholder at this stage.
func (h *HTTPServer) serve(ctx *Context) {
}
// Start binds a TCP listener on addr and serves requests with h on it.
// A listen failure is returned immediately; otherwise Start returns whatever
// http.Serve eventually returns.
func (h *HTTPServer) Start(addr string) error {
	// An http.Server could also be constructed here instead.
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}

	// Owning the listener leaves a hook between "port bound" and "serving":
	// lifecycle callbacks (after/start), registering this instance with an
	// admin service, or running business preconditions could go here.
	return http.Serve(listener, h)
}
// AddRoute will register handleFunc in the routing tree. It is not
// implemented yet and always panics.
func (h *HTTPServer) AddRoute(method string, path string, handleFunc HandleFunc) {
	// Register into the routing tree.
	panic("implement me")
}
// Keep the Server interface minimal; convenience methods such as these live
// on the implementation instead.

// Get registers handleFunc for GET requests on path.
func (h *HTTPServer) Get(path string, handleFunc HandleFunc) {
	h.AddRoute(http.MethodGet, path, handleFunc)
}

// Post registers handleFunc for POST requests on path.
func (h *HTTPServer) Post(path string, handleFunc HandleFunc) {
	h.AddRoute(http.MethodPost, path, handleFunc)
}

// Put registers handleFunc for PUT requests on path.
func (h *HTTPServer) Put(path string, handleFunc HandleFunc) {
	h.AddRoute(http.MethodPut, path, handleFunc)
}

// Patch registers handleFunc for PATCH requests on path.
func (h *HTTPServer) Patch(path string, handleFunc HandleFunc) {
	h.AddRoute(http.MethodPatch, path, handleFunc)
}

// Delete registers handleFunc for DELETE requests on path.
func (h *HTTPServer) Delete(path string, handleFunc HandleFunc) {
	h.AddRoute(http.MethodDelete, path, handleFunc)
}
// func (h *HTTPServer) AddRoutes(method string, path string, handleFunc ...HandleFunc) {
// panic("implement me")
// }
// StartHttp is the simpler alternative to Start: listening and serving are
// delegated entirely to http.ListenAndServe.
func (h *HTTPServer) StartHttp(addr string) error {
	var handler http.Handler = h
	return http.ListenAndServe(addr, handler)
}
7. 路由树:Beego、Gin、Echo 实现与设计总结
gin 路由树的复杂在于要找最长公共前缀,然后重构树
8. 路由树:全静态匹配
// router.go
package web
// router supports operations on the routing trees.
// It represents a forest: one tree per HTTP method.
type router struct {
	// As in Beego and Gin, each HTTP method gets its own tree:
	// one for GET, one for POST, ...
	// http method => root node of that tree
	trees map[string]*node
}

// type tree struct {
// 	root *node
// }
// newRouter returns a router whose method->tree map is already allocated, so
// routes can be added right away.
func newRouter() *router {
	r := new(router)
	r.trees = make(map[string]*node)
	return r
}
// AddRoute will register handleFunc in the tree for method. It is not
// implemented yet and always panics.
func (r *router) AddRoute(method string, path string, handleFunc HandleFunc) {
	// Register into the routing tree.
	panic("implement me")
}

// node is one node of a routing tree.
type node struct {
	path string
	// path segment => child node
	children map[string]*node
	// The business logic the user registered for this node.
	handler HandleFunc
}
// server.go
package web
import (
"net"
"net/http"
)
// HandleFunc is the signature of the business logic registered on a route.
type HandleFunc func(ctx Context)

// Compile-time assertion that *HTTPServer implements Server.
var _ Server = &HTTPServer{}

// Server is the framework's core abstraction: an http.Handler that can be
// started and that accepts route registrations.
type Server interface {
	http.Handler
	Start(addr string) error
	// StartHttp() error

	// AddRoute registers a route.
	// method: the HTTP method
	// path: the route path
	// handleFunc: the business logic
	AddRoute(method string, path string, handleFunc HandleFunc)
	// Registering several handlers at once is unnecessary; users can build
	// it on top of AddRoute.
	// AddRoutes(method string, path string, handleFunc ...HandleFunc)
}

// type HTTPSServer struct {
// 	HTTPServer // decorates HTTPServer
// }

// HTTPServer is the plain-HTTP implementation of Server.
type HTTPServer struct {
	// addr string // could also be passed at construction instead of via Start
	// router
	*router // the embedded router provides AddRoute, so HTTPServer still satisfies Server
}
// NewHTTPServer creates an HTTPServer whose embedded router is initialised
// and ready for route registration.
func NewHTTPServer() *HTTPServer {
	r := newRouter()
	return &HTTPServer{router: r}
}
// func (h *HTTPServer) AddRoute(method string, path string, handleFunc HandleFunc) {
// // 注册到路由树
// panic("implement me")
// }
// ...
9. 路由树:TDD 起步
// router.go
package web
import "strings"
// ...
// AddRoute registers handleFunc for method+path in the static routing tree.
//
// path must be non-empty and start with '/'. The root path "/" attaches the
// handler to the tree root itself. Every other '/'-separated segment becomes
// one tree level, and missing nodes are created on the way down.
func (r *router) AddRoute(method string, path string, handleFunc HandleFunc) {
	// Without these checks an empty path makes path[1:] below panic with an
	// index error, and a path like "user" is silently mangled to "ser".
	if path == "" {
		panic("web: 路径不能为空字符串")
	}
	if path[0] != '/' {
		panic("web: 路径必须以 / 开头")
	}
	root, ok := r.trees[method]
	if !ok {
		// First route for this method: create the tree root.
		root = &node{
			path: "/",
		}
		r.trees[method] = root
	}
	// The root path belongs to the root node. Splitting it would create a
	// bogus child keyed by the empty segment instead.
	if path == "/" {
		root.handler = handleFunc
		return
	}
	// "/user/home" -> "user/home" -> ["user", "home"]
	segs := strings.Split(path[1:], "/")
	for _, seg := range segs {
		// Walk one level down, creating the node if it does not exist yet.
		root = root.childOrCreate(seg)
	}
	// root now points at the node for the full path (e.g. /user/home).
	root.handler = handleFunc
}
// childOrCreate returns n's static child keyed by seg. The child, and the
// children map if it is still nil, are created on demand.
func (n *node) childOrCreate(seg string) *node {
	if n.children == nil {
		n.children = make(map[string]*node)
	}
	if child, found := n.children[seg]; found {
		return child
	}
	child := &node{path: seg}
	n.children[seg] = child
	return child
}
// node: fields elided in these notes (presumably unchanged from the previous
// section).
type node struct {
	// ...
}
// router_test.go
package web
import (
"fmt"
"net/http"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
// TestRouter_AddRoute registers routes on a fresh router and compares the
// resulting forest with a hand-built expected forest.
func TestRouter_AddRoute(t *testing.T) {
	// 1. Build the routing tree.
	testRoutes := []struct {
		method string
		path   string
	}{
		{
			method: http.MethodGet,
			path:   "/user/home",
		},
	}
	// A single handler value is shared everywhere so node.equal can compare
	// handlers by identity.
	var mockHandler HandleFunc = func(ctx Context) {}
	r := newRouter()
	for _, route := range testRoutes {
		r.AddRoute(route.method, route.path, mockHandler)
	}
	// 2. Verify the routing tree.
	wantRouter := &router{
		trees: map[string]*node{
			http.MethodGet: &node{
				path: "/",
				children: map[string]*node{
					"user": &node{
						path: "user",
						children: map[string]*node{
							"home": &node{
								path:    "home",
								handler: mockHandler,
							},
						},
					},
				},
			},
		},
	}
	// A custom equal is used, presumably because func-typed fields make
	// deep-equality helpers unusable — verify.
	msg, ok := wantRouter.equal(r)
	assert.True(t, ok, msg)
}
// equal reports whether r and y hold the same set of method trees, and
// whether every pair of trees is structurally equal. On mismatch it returns
// a human-readable reason and false.
func (r *router) equal(y *router) (string, bool) {
	// Iterating only over r's keys would accept a y with extra methods.
	if len(r.trees) != len(y.trees) {
		return "http method 数量不相等", false
	}
	for k, v := range r.trees {
		dst, ok := y.trees[k]
		if !ok {
			return fmt.Sprintf("找不到对应的 http method: %s", k), false
		}
		msg, equal := v.equal(dst)
		if !equal {
			return msg, false
		}
	}
	return "", true
}
// equal reports whether subtree n matches subtree y: same path, same number
// of children, the same handler and, recursively, equal children. On
// mismatch it returns a reason and false.
func (n *node) equal(y *node) (string, bool) {
	switch {
	case n.path != y.path:
		return "节点路径不匹配", false
	case len(n.children) != len(y.children):
		return "子节点数量不相等", false
	// Compares the reflect.Value headers; for func values this compares the
	// function-value pointers, i.e. handler identity.
	case reflect.ValueOf(n.handler) != reflect.ValueOf(y.handler):
		return "handler 不相等", false
	}
	for key, want := range n.children {
		got, found := y.children[key]
		if !found {
			return fmt.Sprintf("子节点 %s 不存在", key), false
		}
		if msg, same := want.equal(got); !same {
			return msg, false
		}
	}
	return "", true
}
10. 路由树:静态匹配测试用例
// router_test.go
package web
import (
"fmt"
"net/http"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
// TestRouter_addRoute checks the forest built by addRoute for a set of static
// routes, then checks that malformed paths and duplicate registrations panic
// with the expected messages.
func TestRouter_addRoute(t *testing.T) {
	// 1. Build the routing tree.
	testRoutes := []struct {
		method string
		path   string
	}{
		{
			method: http.MethodGet,
			path:   "/",
		},
		{
			method: http.MethodGet,
			path:   "/user",
		},
		{
			method: http.MethodGet,
			path:   "/user/home",
		},
		{
			method: http.MethodGet,
			path:   "/order/detail",
		},
		{
			method: http.MethodPost,
			path:   "/order/create",
		},
		{
			method: http.MethodPost,
			path:   "/login",
		},
		// {
		// 	method: http.MethodPost,
		// 	path:   "login",
		// },
		// {
		// 	method: http.MethodPost,
		// 	path:   "login////",
		// },
		// {
		// 	method: http.MethodPost,
		// 	path:   "//login//a//b",
		// },
	}
	var mockHandler HandleFunc = func(ctx Context) {}
	r := newRouter()
	for _, route := range testRoutes {
		r.addRoute(route.method, route.path, mockHandler)
	}
	// 2. Verify the routing tree.
	wantRouter := &router{
		trees: map[string]*node{
			http.MethodGet: &node{
				path:    "/",
				handler: mockHandler,
				children: map[string]*node{
					"user": &node{
						path:    "user",
						handler: mockHandler,
						children: map[string]*node{
							"home": &node{
								path:    "home",
								handler: mockHandler,
							},
						},
					},
					"order": &node{
						path: "order",
						children: map[string]*node{
							"detail": &node{
								path:    "detail",
								handler: mockHandler,
							},
						},
					},
				},
			},
			http.MethodPost: &node{
				path: "/",
				children: map[string]*node{
					"order": &node{
						path: "order",
						children: map[string]*node{
							"create": &node{
								path:    "create",
								handler: mockHandler,
							},
						},
					},
					"login": &node{
						path:    "login",
						handler: mockHandler,
					},
				},
			},
		},
	}
	msg, ok := wantRouter.equal(r)
	assert.True(t, ok, msg)

	// Malformed paths must panic. PanicsWithValue checks the panic message
	// itself; plain assert.Panics only uses its last argument as the failure
	// text, so a panic for the wrong reason would still pass.
	r = newRouter()
	assert.PanicsWithValue(t, "web: 路径不能为空字符串", func() {
		r.addRoute(http.MethodGet, "", mockHandler)
	})
	assert.PanicsWithValue(t, "web: 路径不能以 / 结尾", func() {
		r.addRoute(http.MethodGet, "/a/b/c/", mockHandler)
	})
	assert.PanicsWithValue(t, "web: 路径必须以 / 开头", func() {
		r.addRoute(http.MethodGet, "a/b", mockHandler)
	})
	// The path must start with '/' so that the consecutive-slash check is the
	// one that fires. "a//b" is rejected earlier for the missing leading '/'.
	assert.PanicsWithValue(t, "web: 不能有连续的 /", func() {
		r.addRoute(http.MethodGet, "/a//b", mockHandler)
	})

	// Registering the same route twice must panic.
	r = newRouter()
	r.addRoute(http.MethodGet, "/", mockHandler)
	assert.PanicsWithValue(t, "web: 路由冲突, 重复注册: [/]", func() {
		r.addRoute(http.MethodGet, "/", mockHandler)
	})
	r = newRouter()
	r.addRoute(http.MethodGet, "/user/home", mockHandler)
	assert.PanicsWithValue(t, "web: 路由冲突, 重复注册: [/user/home]", func() {
		r.addRoute(http.MethodGet, "/user/home", mockHandler)
	})
}
// equal: body elided in these notes (see the previous section).
func (r *router) equal(y *router) (string, bool) {
	// ...
}

// equal: body elided in these notes (see the previous section).
func (n *node) equal(y *node) (string, bool) {
	// ...
}
package web
import (
"fmt"
"strings"
)
// router supports operations on the routing trees.
// It represents a forest: one tree per HTTP method.
type router struct {
	// As in Beego and Gin, each HTTP method gets its own tree:
	// one for GET, one for POST, ...
	// http method => root node of that tree
	trees map[string]*node
}

// type tree struct {
// 	root *node
// }

// newRouter: body elided in these notes (see the previous section).
func newRouter() *router {
	// ...
}
// addRoute registers handleFunc for method+path.
//
// path must satisfy the following, otherwise addRoute panics:
//   - it is non-empty and starts with '/';
//   - it does not end with '/', except for the root path "/";
//   - it contains no consecutive slashes ("//").
//
// Registering the same method+path twice panics with a route-conflict error.
func (r *router) addRoute(method string, path string, handleFunc HandleFunc) {
	// Validate the whole path before touching r.trees. Previously a rejected
	// path could leave behind an empty tree for method and, for "/a//b",
	// intermediate nodes created before the panic. Check order and messages
	// are unchanged.
	if path == "" {
		panic("web: 路径不能为空字符串")
	}
	if path[0] != '/' {
		panic("web: 路径必须以 / 开头")
	}
	if path != "/" && path[len(path)-1] == '/' {
		panic("web: 路径不能以 / 结尾")
	}
	if strings.Contains(path, "//") {
		panic("web: 不能有连续的 /")
	}

	root, ok := r.trees[method]
	if !ok {
		// First route for this method: create the tree root.
		root = &node{
			path: "/",
		}
		r.trees[method] = root
	}

	// The root path is attached to the tree root itself.
	if path == "/" {
		if root.handler != nil {
			panic("web: 路由冲突, 重复注册: [/]")
		}
		root.handler = handleFunc
		return
	}

	// "/user/home" -> ["user", "home"]; each segment is one tree level.
	for _, seg := range strings.Split(path[1:], "/") {
		// Walk one level down, creating the node if it does not exist yet.
		root = root.childOrCreate(seg)
	}
	// root now points at the node for the full path, e.g. /user/home.
	if root.handler != nil {
		panic(fmt.Sprintf("web: 路由冲突, 重复注册: [%s]", path))
	}
	root.handler = handleFunc
}
// childOrCreate: body elided in these notes (see the previous section).
func (n *node) childOrCreate(seg string) *node {
	// ...
}

// node: fields elided in these notes (see the previous section).
type node struct {
	// ...
}
将 AddRoute 首字母小写,则变成私有方法,只能用我们暴露的 Get、Post 等方法注册,就不用担心 method 传乱七八糟的字符串
而 handler 如果为 nil,则表示根本没注册,一般用户不会这样
11. 路由树:静态匹配之路由查找
// router_test.go
package web
import (
"fmt"
"net/http"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
// TestRouter_findRoute registers a set of static routes and checks lookups for
// an unknown method, an exact match, an intermediate node without a handler,
// the root node and an unknown path.
func TestRouter_findRoute(t *testing.T) {
	testRoutes := []struct {
		method string
		path   string
	}{
		{
			method: http.MethodGet,
			path:   "/",
		},
		{
			method: http.MethodDelete,
			path:   "/",
		},
		{
			method: http.MethodGet,
			path:   "/user",
		},
		{
			method: http.MethodGet,
			path:   "/user/home",
		},
		{
			method: http.MethodGet,
			path:   "/order/detail",
		},
		{
			method: http.MethodPost,
			path:   "/order/create",
		},
		{
			method: http.MethodPost,
			path:   "/login",
		},
	}
	r := newRouter()
	var mockHandler HandleFunc = func(ctx Context) {}
	for _, route := range testRoutes {
		r.addRoute(route.method, route.path, mockHandler)
	}
	testCases := []struct {
		name      string
		method    string
		path      string
		wantFound bool
		wantNode  *node
	}{
		{
			// No tree exists for this method.
			name:   "method not found",
			method: http.MethodOptions,
			path:   "/order/detail",
		},
		{
			// Exact match.
			name:      "order detail",
			method:    http.MethodGet,
			path:      "/order/detail",
			wantFound: true,
			wantNode: &node{
				handler: mockHandler,
				path:    "detail",
			},
		},
		{
			// The node exists but has no handler.
			name:      "order",
			method:    http.MethodGet,
			path:      "/order",
			wantFound: true,
			wantNode: &node{
				path: "order",
				// child created by registering /order/detail
				children: map[string]*node{
					"detail": &node{
						handler: mockHandler,
						path:    "detail",
					},
				},
			},
		},
		{
			// Root node.
			name:      "root",
			method:    http.MethodDelete,
			path:      "/",
			wantFound: true,
			wantNode: &node{
				path:    "/",
				handler: mockHandler,
			},
		},
		{
			name:   "path not found",
			method: http.MethodGet,
			path:   "/404",
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			n, found := r.findRoute(tc.method, tc.path)
			assert.Equal(t, tc.wantFound, found)
			if !found {
				return
			}
			msg, ok := tc.wantNode.equal(n)
			assert.True(t, ok, msg)
		})
	}
}
// ...
// router.go
// findRoute walks the tree for method one path segment at a time and returns
// the node for path. found is true whenever every segment exists, even if the
// node has no handler, so callers must check the handler themselves.
func (r *router) findRoute(method string, path string) (*node, bool) {
	cur, ok := r.trees[method]
	if !ok {
		return nil, false
	}
	if path == "/" {
		return cur, true
	}
	// Leading and trailing slashes are dropped before splitting, so "/a/b/"
	// and "/a/b" resolve to the same node.
	for _, seg := range strings.Split(strings.Trim(path, "/"), "/") {
		next, hit := cur.childOf(seg)
		if !hit {
			return nil, false
		}
		cur = next
	}
	return cur, true
}
// childOf looks up n's static child whose segment equals path.
func (n *node) childOf(path string) (*node, bool) {
	// Indexing a nil map is safe in Go and yields (nil, false), which is what
	// an explicit nil-map check would return anyway.
	child, ok := n.children[path]
	return child, ok
}
12. 路由树:静态匹配之集成 Server
// server.go
// serve finds the node for the request and runs its handler. When no node
// matches, or the node has no handler, it answers 404 with "NOT FOUND".
func (h *HTTPServer) serve(ctx *Context) {
	n, ok := h.findRoute(ctx.Req.Method, ctx.Req.URL.Path)
	if ok && n.handler != nil {
		n.handler(ctx)
		return
	}
	// Route missed: 404.
	ctx.Resp.WriteHeader(404)
	ctx.Resp.Write([]byte("NOT FOUND"))
}
13. 路由树:通配符匹配之路由注册
// router_test.go
package web
import (
"fmt"
"net/http"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
// TestRouter_addRoute (excerpt): registering "/order/*" must add a wildcard
// child (starChild) to the "order" node. Parts elided with "..." are unchanged.
func TestRouter_addRoute(t *testing.T) {
	// 1. Build the routing tree.
	testRoutes := []struct {
		method string
		path   string
	}{
		// {
		// 	method: http.MethodGet,
		// 	path:   "/*",
		// },
		// {
		// 	method: http.MethodGet,
		// 	path:   "/*/*",
		// },
		// {
		// 	method: http.MethodGet,
		// 	path:   "/*/abc/*",
		// },
		// {
		// 	method: http.MethodGet,
		// 	path:   "/*/abc",
		// },
		// ...
		{
			method: http.MethodGet,
			path:   "/order/*",
		},
	}
	var mockHandler HandleFunc = func(ctx *Context) {}
	r := newRouter()
	for _, route := range testRoutes {
		r.addRoute(route.method, route.path, mockHandler)
	}
	// 2. Verify the routing tree.
	wantRouter := &router{
		trees: map[string]*node{
			http.MethodGet: &node{
				path:    "/",
				handler: mockHandler,
				children: map[string]*node{
					"user": &node{
						path:    "user",
						handler: mockHandler,
						children: map[string]*node{
							"home": &node{
								path:    "home",
								handler: mockHandler,
							},
						},
					},
					"order": &node{
						path: "order",
						children: map[string]*node{
							"detail": &node{
								path:    "detail",
								handler: mockHandler,
							},
						},
						// Wildcard child created by "/order/*".
						starChild: &node{
							path:    "*",
							handler: mockHandler,
						},
					},
				},
			},
			http.MethodPost: &node{
				// ...
			},
		},
	}
	// ...
}
// equal reports whether subtree n is structurally identical to subtree y:
// same path, same static children count, matching wildcard child, same
// handler and, recursively, equal static children. On mismatch it returns a
// reason and false.
func (n *node) equal(y *node) (string, bool) {
	if y == nil {
		return "目标节点为 nil", false
	}
	if n.path != y.path {
		return "节点路径不匹配", false
	}
	if len(n.children) != len(y.children) {
		return "子节点数量不相等", false
	}
	// Check the wildcard child in both directions. Previously a nil
	// y.starChild was passed to equal and dereferenced, and a starChild that
	// existed only on y went unnoticed.
	if (n.starChild == nil) != (y.starChild == nil) {
		return "通配符子节点不匹配", false
	}
	if n.starChild != nil {
		if msg, ok := n.starChild.equal(y.starChild); !ok {
			return msg, ok
		}
	}
	// Compares the reflect.Value headers; for func values this compares the
	// function-value pointers, i.e. handler identity.
	if reflect.ValueOf(n.handler) != reflect.ValueOf(y.handler) {
		return "handler 不相等", false
	}
	for path, c := range n.children {
		dst, ok := y.children[path]
		if !ok {
			return fmt.Sprintf("子节点 %s 不存在", path), false
		}
		if msg, ok := c.equal(dst); !ok {
			return msg, false
		}
	}
	return "", true
}
// router.go
// childOrCreate returns n's child for seg, creating it when missing.
// "*" selects the wildcard child; any other segment selects a static child.
func (n *node) childOrCreate(seg string) *node {
	// Wildcard.
	if seg == "*" {
		// Reuse an existing wildcard child. Replacing it unconditionally would
		// drop the handler and subtree of an earlier registration, e.g.
		// "/order/*" followed by "/order/*/detail".
		if n.starChild == nil {
			n.starChild = &node{
				path: seg,
			}
		}
		return n.starChild
	}
	// Static segment.
	if n.children == nil {
		n.children = map[string]*node{}
	}
	res, ok := n.children[seg]
	if !ok {
		res = &node{
			path: seg,
		}
		n.children[seg] = res
	}
	return res
}
// node is one node of a routing tree.
type node struct {
	path string
	// Static child nodes,
	// keyed by path segment.
	children map[string]*node
	// Wildcard (/*) child node.
	starChild *node
	// The business logic the user registered for this node.
	handler HandleFunc
}
14. 路由树:通配符匹配之路由查找与测试
// router.go
// childOf resolves one path segment. An exact static child wins; otherwise
// the wildcard child matches, if there is one.
func (n *node) childOf(path string) (*node, bool) {
	if child, ok := n.children[path]; ok {
		return child, true
	}
	return n.starChild, n.starChild != nil
}
// router_test.go
// TestRouter_findRoute (excerpt): with both "/order/detail" and "/order/*"
// registered, /order/detail resolves to the static node and /order/abc to
// the wildcard node.
func TestRouter_findRoute(t *testing.T) {
	testRoutes := []struct {
		method string
		path   string
	}{
		// ...
		{
			method: http.MethodGet,
			path:   "/order/detail",
		},
		{
			method: http.MethodGet,
			path:   "/order/*",
		},
	}
	r := newRouter()
	var mockHandler HandleFunc = func(ctx *Context) {}
	for _, route := range testRoutes {
		r.addRoute(route.method, route.path, mockHandler)
	}
	testCases := []struct {
		name      string
		method    string
		path      string
		wantFound bool
		wantNode  *node
	}{
		{
			// Exact static match.
			name:      "order detail",
			method:    http.MethodGet,
			path:      "/order/detail",
			wantFound: true,
			wantNode: &node{
				handler: mockHandler,
				path:    "detail",
			},
		},
		{
			// Matched through the wildcard child.
			name:      "order star",
			method:    http.MethodGet,
			path:      "/order/abc",
			wantFound: true,
			wantNode: &node{
				handler: mockHandler,
				path:    "*",
			},
		},
		// ...
	}
	// ...
}
// server_test.go
package web
import (
"fmt"
_ "net/http"
"testing"
)
// TestServer is a manual demo: it serves on :8081 with a static route and a
// wildcard route under /order. The static /order/detail takes priority over
// /order/*, and the wildcard handler echoes the concrete path.
func TestServer(t *testing.T) {
	// h := &HTTPServer{} // NewServer
	h := NewHTTPServer()
	h.Get("/order/detail", func(ctx *Context) {
		ctx.Resp.Write([]byte("hello order detail"))
	})
	h.Get("/order/*", func(ctx *Context) {
		ctx.Resp.Write([]byte(fmt.Sprintf("hello, %s", ctx.Req.URL.Path)))
	})
	h.Start(":8081")
}
15. 路由树:参数路径之基本注册和查找
// router_test.go
package web
import (
"fmt"
"net/http"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
// TestRouter_addRoute (excerpt): "/order/detail/:id" must appear as the
// paramChild of the "detail" node. Parts elided with "..." are unchanged.
func TestRouter_addRoute(t *testing.T) {
	// 1. Build the routing tree.
	testRoutes := []struct {
		method string
		path   string
	}{
		// ...
		// Path parameter
		{
			method: http.MethodGet,
			path:   "/order/detail/:id",
		},
	}
	var mockHandler HandleFunc = func(ctx *Context) {}
	r := newRouter()
	for _, route := range testRoutes {
		r.addRoute(route.method, route.path, mockHandler)
	}
	// 2. Verify the routing tree.
	wantRouter := &router{
		trees: map[string]*node{
			http.MethodGet: &node{
				path:    "/",
				handler: mockHandler,
				children: map[string]*node{
					// ...
					"order": &node{
						path: "order",
						children: map[string]*node{
							"detail": &node{
								path:    "detail",
								handler: mockHandler,
								// Path-parameter child created by ":id".
								paramChild: &node{
									path:    ":id",
									handler: mockHandler,
								},
							},
						},
						starChild: &node{
							path:    "*",
							handler: mockHandler,
						},
					},
				},
			},
			http.MethodPost: &node{
				// ...
			},
		},
	}
	// ...
}
// TestRouter_findRoute (excerpt): POST /login/wuif must resolve to the
// "/login/:username" parameter node.
func TestRouter_findRoute(t *testing.T) {
	testRoutes := []struct {
		method string
		path   string
	}{
		// ...
		{
			method: http.MethodPost,
			path:   "/login/:username",
		},
	}
	r := newRouter()
	var mockHandler HandleFunc = func(ctx *Context) {}
	for _, route := range testRoutes {
		r.addRoute(route.method, route.path, mockHandler)
	}
	testCases := []struct {
		name      string
		method    string
		path      string
		wantFound bool
		wantNode  *node
	}{
		// ...
		{
			// Path-parameter match.
			name:      "login username",
			method:    http.MethodPost,
			path:      "/login/wuif",
			wantFound: true,
			wantNode:  &node{path: ":username", handler: mockHandler},
		},
	}
	// ...
}
// ...
// router.go
// node is one node of a routing tree.
type node struct {
	path string
	// Static child nodes,
	// keyed by path segment.
	children map[string]*node
	// Wildcard (/*) child node.
	starChild *node
	// Path-parameter (/:name) child node.
	paramChild *node
	// The business logic the user registered for this node.
	handler HandleFunc
}
// childOf resolves one path segment with the priority static child >
// path-parameter child > wildcard child.
func (n *node) childOf(path string) (*node, bool) {
	if child, ok := n.children[path]; ok {
		return child, true
	}
	if n.paramChild != nil {
		return n.paramChild, true
	}
	return n.starChild, n.starChild != nil
}
// childOrCreate returns n's child for seg, creating it when missing:
//   - ":name" selects the path-parameter child;
//   - "*" selects the wildcard child;
//   - anything else selects a static child.
//
// seg must be non-empty (seg[0] is read).
func (n *node) childOrCreate(seg string) *node {
	if seg[0] == ':' {
		// Reuse the existing parameter child so routes registered under it
		// earlier, and its handler, are kept. A different parameter name at
		// the same position is a conflict: silently reusing the old node would
		// store the value under the wrong key.
		if n.paramChild == nil {
			n.paramChild = &node{
				path: seg,
			}
		} else if n.paramChild.path != seg {
			panic(fmt.Sprintf("web: 路由冲突, 参数路由冲突, 已有 %s, 新注册 %s", n.paramChild.path, seg))
		}
		return n.paramChild
	}
	// Wildcard: reuse the existing child for the same reason.
	if seg == "*" {
		if n.starChild == nil {
			n.starChild = &node{
				path: seg,
			}
		}
		return n.starChild
	}
	// Static segment.
	if n.children == nil {
		n.children = map[string]*node{}
	}
	res, ok := n.children[seg]
	if !ok {
		res = &node{
			path: seg,
		}
		n.children[seg] = res
	}
	return res
}
16. 路由树:参数路径之校验
// router_test.go
// A path-parameter child and a wildcard child cannot coexist on the same node:
// registering one after the other must panic, in either order.
r = newRouter()
r.addRoute(http.MethodGet, "/a/*", mockHandler)
assert.Panics(t, func() {
	r.addRoute(http.MethodGet, "/a/:b", mockHandler)
}, "web: 不允许同时注册路径参数和通配符路径, 已经注册通配符路径")
r = newRouter()
r.addRoute(http.MethodGet, "/a/:b", mockHandler)
assert.Panics(t, func() {
	r.addRoute(http.MethodGet, "/a/*", mockHandler)
}, "web: 不允许同时注册路径参数和通配符路径, 已经注册路径参数")
// router.go
// childOrCreate returns n's child for seg, creating it when missing:
//   - ":name" selects the path-parameter child;
//   - "*" selects the wildcard child;
//   - anything else selects a static child.
//
// A path-parameter child and a wildcard child cannot coexist on one node, and
// a node accepts a single parameter name; violations panic. seg must be
// non-empty (seg[0] is read).
func (n *node) childOrCreate(seg string) *node {
	if seg[0] == ':' {
		if n.starChild != nil {
			panic("web: 不允许同时注册路径参数和通配符路径, 已经注册通配符路径")
		}
		// Reuse the existing parameter child so routes registered under it
		// earlier, and its handler, are kept. Only a different name conflicts.
		if n.paramChild == nil {
			n.paramChild = &node{
				path: seg,
			}
		} else if n.paramChild.path != seg {
			panic(fmt.Sprintf("web: 路由冲突, 参数路由冲突, 已有 %s, 新注册 %s", n.paramChild.path, seg))
		}
		return n.paramChild
	}
	// Wildcard.
	if seg == "*" {
		if n.paramChild != nil {
			panic("web: 不允许同时注册路径参数和通配符路径, 已经注册路径参数")
		}
		// Reuse the existing wildcard child for the same reason.
		if n.starChild == nil {
			n.starChild = &node{
				path: seg,
			}
		}
		return n.starChild
	}
	// Static segment.
	if n.children == nil {
		n.children = map[string]*node{}
	}
	res, ok := n.children[seg]
	if !ok {
		res = &node{
			path: seg,
		}
		n.children[seg] = res
	}
	return res
}
17. 路由树:参数路径之参数值
用户希望直接获取路径参数,我们将 node 和路径参数封装到一个对象并返回
// router.go
// matchInfo is the result of a route lookup: the matched node, plus the
// path-parameter values captured on the way (nil when none were matched).
type matchInfo struct {
	n          *node
	pathParams map[string]string
}
// findRoute (excerpt) now returns a matchInfo, so the values of path
// parameters (":id"-style segments) travel together with the matched node.
func (r *router) findRoute(method string, path string) (*matchInfo, bool) {
	// ...
	// Split on '/'.
	segs := strings.Split(path, "/")
	// Allocated lazily: stays nil when no path parameter is matched.
	var pathParams map[string]string
	for _, seg := range segs {
		child, paramChild, found := root.childOf(seg)
		if !found {
			return nil, false
		}
		// Matched a path-parameter node.
		if paramChild {
			if pathParams == nil {
				pathParams = make(map[string]string)
			}
			// child.path has the form ":id"; dropping ':' gives the key.
			pathParams[child.path[1:]] = seg
		}
		root = child
	}
	// The node exists, but it may not have a handler.
	return &matchInfo{n: root, pathParams: pathParams}, true
	// return root, root.handler != nil
}
// childOf resolves one path segment with the priority static child >
// path-parameter child > wildcard child.
//
// Results: the matched child; whether that child is a path-parameter node;
// whether anything matched at all.
func (n *node) childOf(path string) (*node, bool, bool) {
	if child, ok := n.children[path]; ok {
		return child, false, true
	}
	if n.paramChild != nil {
		return n.paramChild, true, true
	}
	return n.starChild, false, n.starChild != nil
}
// router_test.go
// TestRouter_findRoute checks lookups that now return a matchInfo: the matched
// node and any captured path parameters. The route list is elided ("...").
func TestRouter_findRoute(t *testing.T) {
	testRoutes := []struct {
		method string
		path   string
	}{
		// ...
	}
	r := newRouter()
	var mockHandler HandleFunc = func(ctx *Context) {}
	for _, route := range testRoutes {
		r.addRoute(route.method, route.path, mockHandler)
	}
	testCases := []struct {
		name      string
		method    string
		path      string
		wantFound bool
		info      *matchInfo
	}{
		{
			// No tree exists for this method.
			name:   "method not found",
			method: http.MethodOptions,
			path:   "/order/detail",
		},
		{
			// Exact match.
			name:      "order detail",
			method:    http.MethodGet,
			path:      "/order/detail",
			wantFound: true,
			info: &matchInfo{
				n: &node{
					handler: mockHandler,
					path:    "detail",
				},
			},
		},
		{
			// Matched through the wildcard child.
			name:      "order star",
			method:    http.MethodGet,
			path:      "/order/abc",
			wantFound: true,
			info: &matchInfo{
				n: &node{
					handler: mockHandler,
					path:    "*",
				},
			},
		},
		{
			// The node exists but has no handler.
			name:      "order",
			method:    http.MethodGet,
			path:      "/order",
			wantFound: true,
			info: &matchInfo{
				n: &node{
					path: "order",
					// child created by registering /order/detail
					children: map[string]*node{
						"detail": &node{
							handler: mockHandler,
							path:    "detail",
						},
					},
				},
			},
		},
		{
			// Root node.
			name:      "root",
			method:    http.MethodDelete,
			path:      "/",
			wantFound: true,
			info: &matchInfo{
				n: &node{
					path:    "/",
					handler: mockHandler,
				},
			},
		},
		{
			name:   "path not found",
			method: http.MethodGet,
			path:   "/404",
		},
		{
			// Path-parameter match: the segment value is captured under "username".
			name:      "login username",
			method:    http.MethodPost,
			path:      "/login/wuif",
			wantFound: true,
			info: &matchInfo{
				n:          &node{path: ":username", handler: mockHandler},
				pathParams: map[string]string{"username": "wuif"},
			},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			info, found := r.findRoute(tc.method, tc.path)
			assert.Equal(t, tc.wantFound, found)
			if !found {
				return
			}
			assert.Equal(t, tc.info.pathParams, info.pathParams)
			msg, ok := tc.info.n.equal(info.n)
			assert.True(t, ok, msg)
		})
	}
}
// context.go
package web
import "net/http"
// Context carries everything a handler needs for one request.
type Context struct {
	Req  *http.Request
	Resp http.ResponseWriter
	// Path-parameter values captured during routing, keyed by parameter name
	// without the ':' (e.g. "id" for "/:id").
	PathParams map[string]string
}
// server.go
// serve finds the route, exposes any captured path parameters on ctx and runs
// the handler. A miss, or a node without a handler, yields 404 "NOT FOUND".
func (h *HTTPServer) serve(ctx *Context) {
	info, ok := h.findRoute(ctx.Req.Method, ctx.Req.URL.Path)
	if ok && info.n.handler != nil {
		ctx.PathParams = info.pathParams
		info.n.handler(ctx)
		return
	}
	// Route missed: 404.
	ctx.Resp.WriteHeader(404)
	ctx.Resp.Write([]byte("NOT FOUND"))
}
18. 路由树总结与面试要点
如果注册了非常多的路由,最长公共前缀(gin)的实现性能高一点,但是实现很麻烦
20. 第一周作业:实现一棵路由树[选学]
21. 第一周路由树作业讲解[选学]
它设置了 nodeType,如果查找路由时,所有的都找不到,还加一个判断是否有通配符
name(正则表达式)
08 第二周:Web 框架之 Context 与 AOP 方案
context
1. Context 简介
ParseForm 可以多次调用,但只会真正解析一次(已解析过则直接返回)
2. Context:Beego Context 设计分析【更多 IT 资料加微信 djy136928775638】
3. Context:Gin Context 设计分析
4. Context:Echo 和 Iris 的 Context 设计分析
5. Context:处理输入输出总结
6. Context:处理输入之 Body 输入
如果引入小众需求会影响框架的复杂度,或者导致性能等受挫,就不要考虑
7. Context:处理输入之表单输入
ParseForm 会解析 URL 查询参数;对 POST/PUT/PATCH 请求,还会解析 body 中的表单数据。结果都放进 Req.Form,而 Req.PostForm 只包含 body 部分
// context.go
// FormValue parses the request form and returns the first value for key.
// A parse failure is returned as the error. A missing key yields "" with a
// nil error; read c.Req.Form[key] directly to tell "missing" from "empty".
func (c *Context) FormValue(key string) (string, error) {
	if err := c.Req.ParseForm(); err != nil {
		return "", err
	}
	return c.Req.FormValue(key), nil
}
8. Context:处理输入之查询参数、路径参数和 StringValue
// QueryValue returns the first URL query value for key, or an error when the
// key is absent.
func (c *Context) QueryValue(key string) (string, error) {
	// URL.Query re-parses the raw query on every call, while ParseForm parses
	// only once, so the parsed values are cached on the Context.
	if c.queryValues == nil {
		c.queryValues = c.Req.URL.Query()
	}
	if vals, ok := c.queryValues[key]; ok {
		return vals[0], nil
	}
	return "", errors.New("web: key 不存在")
}
// QueryValueV1 is QueryValue wrapped in a StringValue, so a conversion can be
// chained directly, e.g. ctx.QueryValueV1("page").AsInt64().
func (c *Context) QueryValueV1(key string) StringValue {
	val, err := c.QueryValue(key)
	return StringValue{val: val, err: err}
}
// PathValue returns the path-parameter value captured for key during routing,
// or an error when no such parameter was matched.
func (c *Context) PathValue(key string) (string, error) {
	if val, ok := c.PathParams[key]; ok {
		return val, nil
	}
	return "", errors.New("web: key 不存在")
}
// PathValueV1 is PathValue wrapped in a StringValue, so a conversion can be
// chained directly, e.g. ctx.PathValueV1("id").AsInt64().
func (c *Context) PathValueV1(key string) StringValue {
	val, err := c.PathValue(key)
	return StringValue{val: val, err: err}
}
// StringValue carries a raw string value together with the error, if any,
// from fetching it, so a lookup can be chained with a conversion.
type StringValue struct {
	val string
	err error
}

// AsInt64 parses the value as a base-10 int64. A lookup error stored in the
// StringValue takes precedence and is returned unchanged, together with 0.
func (s StringValue) AsInt64() (int64, error) {
	if s.err == nil {
		return strconv.ParseInt(s.val, 10, 64)
	}
	return 0, s.err
}
// server_test.go
// The StringValue style folds lookup and conversion errors into one check.
h.Post("/values/v1/:id", func(ctx *Context) {
	// This form is more convenient.
	id, err := ctx.PathValueV1("id").AsInt64()
	if err != nil {
		ctx.Resp.WriteHeader(400)
		ctx.Resp.Write([]byte("id 输入错误"))
		return
	}
	ctx.Resp.Write([]byte(fmt.Sprintf("hello, %d", id)))
})
// The two-step style needs a separate check for the lookup and for the
// conversion.
h.Post("/values/:id", func(ctx *Context) {
	idStr, err := ctx.PathValue("id")
	if err != nil {
		ctx.Resp.WriteHeader(400)
		ctx.Resp.Write([]byte("id 输入错误"))
		return
	}
	id, err := strconv.ParseInt(idStr, 10, 64)
	if err != nil {
		ctx.Resp.WriteHeader(400)
		ctx.Resp.Write([]byte("id 输入错误"))
		return
	}
	ctx.Resp.Write([]byte(fmt.Sprintf("hello, %d", id)))
})
9. Context:处理输出
10. Context 总结与面试要点
go的泛型功能比较有限
aop
11. AOP 简介与不同框架设计概览
12. AOP 设计方案:Middleware
// middleware.go
package web
// Middleware is a functional chain of responsibility (the "onion" model): it
// receives the next handler and returns a handler that wraps it.
type Middleware func(next HandleFunc) HandleFunc

// The AOP mechanism of web frameworks has different names across languages:
// Middleware, Handler, Chain, Filter, Filter-Chain,
// Interceptor, Wrapper.
//type MiddlewareV1 interface {
//	Invoke(next HandleFunc) HandleFunc
//}
//
//type Interceptor interface {
//	Before(ctx *Context)
//	After(ctx *Context)
//	Surround(ctx *Context)
//}
//type Chain []HandleFunc
//
//type HandleFuncV1 func(ctx *Context) (next bool)
//
//type ChainV1 struct {
//	handlers []HandleFuncV1
//}
//
//func (c ChainV1) Run(ctx *Context) {
//	for _, h := range c.handlers {
//		next := h(ctx)
//		// returning false interrupts the chain
//		if !next {
//			return
//		}
//	}
//}
// server.go
// HTTPServer routes requests through its middleware chain into the router.
type HTTPServer struct {
	// addr string // could also be passed at construction instead of via Start
	// router
	router // embedded, so router's addRoute is promoted onto HTTPServer
	// Middlewares in registration order; mdls[0] is the outermost layer.
	mdls []Middleware
}
// ServeHTTP implements http.Handler and is the single entry point for every
// request: it builds the Context and runs it through the middleware chain,
// whose innermost step is serve (route lookup plus business logic).
func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	ctx := &Context{Req: req, Resp: w}
	// Assemble back to front: each middleware wraps everything after it, so
	// at call time they run in registration order (onion model).
	handler := HandleFunc(h.serve)
	for i := range h.mdls {
		handler = h.mdls[len(h.mdls)-1-i](handler)
	}
	handler(ctx)
}
无侵入式比较体现代码功底,但是有可能牺牲一部分性能
别的方式, 链式扩散成网状
13. Middleware:AccessLog
![go:build](1ea4c724.png)
// web/accesslog/middleware.go
package accesslog
import (
	"log"

	"github.com/goccy/go-json"
	"web_framework/web"
)
// MiddlewareBuilder configures and builds the access-log middleware.
type MiddlewareBuilder struct {
	// Receives each access-log entry as a JSON string.
	logFunc func(log string)
}
// LogFunc sets the sink that receives each JSON-encoded access-log entry and
// returns the builder for chaining.
func (m *MiddlewareBuilder) LogFunc(fn func(log string)) *MiddlewareBuilder {
	m.logFunc = fn
	return m
}
// Build returns the access-log middleware. After the downstream handler has
// run, it logs the request host, matched route, HTTP method and path as one
// JSON object through the configured log function.
//
// If LogFunc was never called, entries go to log.Println. Previously a nil
// logFunc made the deferred logging call panic on every request.
func (m MiddlewareBuilder) Build() web.Middleware {
	logFunc := m.logFunc
	if logFunc == nil {
		logFunc = func(s string) {
			log.Println(s)
		}
	}
	return func(next web.HandleFunc) web.HandleFunc {
		return func(ctx *web.Context) {
			// Record the request in a defer, so the entry is written after
			// next returns (MatchedRoute is only set by routing) and also when
			// next panics.
			defer func() {
				l := accessLog{
					Host:       ctx.Req.Host,
					Route:      ctx.MatchedRoute,
					HTTPMethod: ctx.Req.Method,
					Path:       ctx.Req.URL.Path,
				}
				// accessLog has only string fields, so Marshal is not expected
				// to fail; the error is deliberately ignored.
				data, _ := json.Marshal(l)
				logFunc(string(data))
			}()
			next(ctx)
		}
	}
}
// accessLog is the JSON shape of one access-log entry; empty fields are
// omitted from the output.
type accessLog struct {
	Host string `json:"host,omitempty"`
	// The matched route pattern (e.g. /a/b/*), not the concrete request path.
	Route      string `json:"route,omitempty"`
	HTTPMethod string `json:"http_method,omitempty"`
	Path       string `json:"path,omitempty"`
}
// router.go
// addRoute (excerpt): besides the handler, each registered node now records
// the full route it was registered with, which serve copies into
// ctx.MatchedRoute.
func (r *router) addRoute(method string, path string, handleFunc HandleFunc) {
	// ...
	// Special handling for the root node.
	if path == "/" {
		// ...
		root.handler = handleFunc
		root.route = "/"
		return
	}
	// ...
	root.handler = handleFunc
	// NOTE(review): an earlier version of addRoute strips the leading '/'
	// (path = path[1:]) before walking the tree. If that is still the case
	// here, route is stored as "a/b/*" rather than "/a/b/*" — confirm.
	root.route = path
}
type node struct {
	// Full route this node was registered with: /a/b/* -> /a/b/*
	route string
	// This node's own segment: for /a/* -> *
	path string
	// ...
}
// middleware_e2e_test.go
//go:build e2e
package accesslog
import (
"fmt"
"testing"
"web_framework/web"
)
// TestMiddlewareBuilderE2E (e2e build tag) is a manual demo: it serves on
// :8081 with the access-log middleware printing each entry to stdout.
func TestMiddlewareBuilderE2E(t *testing.T) {
	builder := MiddlewareBuilder{}
	mdls := builder.LogFunc(func(log string) {
		fmt.Println(log)
	}).Build()
	server := web.NewHTTPServer(web.ServerWithMiddleware(mdls))
	server.Get("/a/b/*", func(ctx *web.Context) {
		ctx.Resp.Write([]byte("hello, it's me"))
	})
	server.Start(":8081")
}
// middleware_test.go
package accesslog
import (
"fmt"
"net/http"
"testing"
"web_framework/web"
)
// TestMiddlewareBuilder drives one POST /a/b/c request through ServeHTTP
// directly (no network), so the access-log entry is printed to stdout.
func TestMiddlewareBuilder(t *testing.T) {
	builder := MiddlewareBuilder{}
	mdls := builder.LogFunc(func(log string) {
		fmt.Println(log)
	}).Build()
	server := web.NewHTTPServer(web.ServerWithMiddleware(mdls))
	server.Post("/a/b/*", func(ctx *web.Context) {
		fmt.Println("hello, it's me")
	})
	req, err := http.NewRequest(http.MethodPost, "/a/b/c", nil)
	if err != nil {
		t.Fatal(err)
	}
	// NOTE(review): a nil ResponseWriter only works because this handler never
	// writes a response; httptest.NewRecorder() would be a safer stand-in.
	server.ServeHTTP(nil, req)
}
// server.go
// serve (excerpt): also records the matched route on the Context, so
// middleware can read it once the handler has returned.
func (h *HTTPServer) serve(ctx *Context) {
	info, ok := h.findRoute(ctx.Req.Method, ctx.Req.URL.Path)
	if !ok || info.n.handler == nil {
		// ...
	}
	ctx.PathParams = info.pathParams
	ctx.MatchedRoute = info.n.route
	info.n.handler(ctx)
}
// context.go
type Context struct {
	// ...
	// The route pattern that matched this request; set by serve before the
	// handler runs.
	MatchedRoute string
	//cookieSameSite http.SameSite
}
14. Middleware:Trace 简介和 OpenTelemetry
一种是用户传数据进来,怎么来的无所谓(编程类接口),一种是用户传配置文件的路径,我们去读取和解析,这种是强耦合
// web/middlewares/opentelemetry/middleware.go
package opentelemetry
import (
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/trace"
"web_framework/web"
)
// instrumentationName identifies this instrumentation when a Tracer is taken
// from the global TracerProvider.
const instrumentationName = "web_framework/middlewares/opentelemetry"

// MiddlewareBuilder builds the OpenTelemetry tracing middleware. Tracer is
// optional: when nil, Build uses one from the global TracerProvider.
type MiddlewareBuilder struct {
	Tracer trace.Tracer
}

//func NewMiddlewareBuilder(tracer trace.Tracer) *MiddlewareBuilder {
//	return &MiddlewareBuilder{tracer}
//}
// Build returns a middleware that wraps every request in a span.
//
// The span joins any trace propagated in the request headers. It is stored in
// ctx.Req's context so handlers can start child spans from
// ctx.Req.Context(). Once the handler chain returns, the span is renamed to
// the matched route.
func (m MiddlewareBuilder) Build() web.Middleware {
	if m.Tracer == nil {
		m.Tracer = otel.GetTracerProvider().Tracer(instrumentationName)
	}
	return func(next web.HandleFunc) web.HandleFunc {
		return func(ctx *web.Context) {
			reqCtx := ctx.Req.Context()
			// Join the client's trace, if it propagated one.
			reqCtx = otel.GetTextMapPropagator().Extract(reqCtx, propagation.HeaderCarrier(ctx.Req.Header))
			// The matched route is only known after routing (inside next), so
			// start with a placeholder name and rename the span afterwards.
			reqCtx, span := m.Tracer.Start(reqCtx, "unknown")
			defer span.End()

			// URL.Scheme is normally empty on server-side requests; fall back
			// to the connection's TLS state.
			scheme := ctx.Req.URL.Scheme
			if scheme == "" {
				scheme = "http"
				if ctx.Req.TLS != nil {
					scheme = "https"
				}
			}
			// Add more attributes as the business needs.
			span.SetAttributes(
				attribute.String("http.method", ctx.Req.Method),
				attribute.String("http.url", ctx.Req.URL.String()),
				attribute.String("http.scheme", scheme),
				attribute.String("http.host", ctx.Req.Host),
			)

			// Hand the span's context to the handlers. Copying the request
			// costs an allocation, but it is the only way to pass it on.
			ctx.Req = ctx.Req.WithContext(reqCtx)
			next(ctx)

			// MatchedRoute is only set after next; it stays empty when no
			// route matched, in which case keep the "unknown" name.
			if ctx.MatchedRoute != "" {
				span.SetName(ctx.MatchedRoute)
			}
		}
	}
}
15. Middleware:OpenTelemetry 测试
// middleware_e2e_test.go
//go:build e2e
package opentelemetry
import (
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/exporters/zipkin"
"go.opentelemetry.io/otel/sdk/resource"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
semcove "go.opentelemetry.io/otel/semconv/v1.10.0"
"log"
"os"
"testing"
"time"
"web_framework/web"
)
// TestMiddlewareBuilder_Build serves GET /user behind the tracing middleware.
// Inside the handler it builds a small tree of nested spans separated by
// sleeps, so the resulting trace is easy to read in Zipkin.
func TestMiddlewareBuilder_Build(t *testing.T) {
	tracer := otel.GetTracerProvider().Tracer(instrumentationName)
	server := web.NewHTTPServer(web.ServerWithMiddleware(MiddlewareBuilder{Tracer: tracer}.Build()))

	server.Get("/user", func(ctx *web.Context) {
		layer1Ctx, layer1 := tracer.Start(ctx.Req.Context(), "first_layer")
		defer layer1.End()

		layer2Ctx, layer2 := tracer.Start(layer1Ctx, "second_layer")
		time.Sleep(time.Second)
		_, leafA := tracer.Start(layer2Ctx, "third_layer_1")
		time.Sleep(100 * time.Millisecond)
		leafA.End()
		_, leafB := tracer.Start(layer2Ctx, "third_layer_2")
		time.Sleep(300 * time.Millisecond)
		leafB.End()
		layer2.End()

		// Started from the request context rather than first_layer's, so it
		// becomes a sibling of first_layer.
		_, sibling := tracer.Start(ctx.Req.Context(), "first_layer_1")
		defer sibling.End()
		time.Sleep(100 * time.Millisecond)
		ctx.RespJson(202, User{Name: "Tom"})
	})

	initZipkin(t)
	// The address must include the leading colon.
	server.Start(":8081")
}
// User is the JSON body returned by the demo handler.
type User struct {
	Name string `json:"name"`
}
// initZipkin installs a global TracerProvider that batches spans to a Zipkin
// collector. The collector URL defaults to http://tyy:9411/api/v2/spans. It
// can be overridden with the ZIPKIN_ENDPOINT environment variable, so the
// test also runs against collectors on other hosts. It fails the test if the
// exporter cannot be created.
func initZipkin(t *testing.T) {
	endpoint := os.Getenv("ZIPKIN_ENDPOINT")
	if endpoint == "" {
		endpoint = "http://tyy:9411/api/v2/spans"
	}
	exporter, err := zipkin.New(
		endpoint,
		zipkin.WithLogger(log.New(os.Stderr, "opentelemetry-demo", log.Lmicroseconds|log.Ldate)),
	)
	if err != nil {
		t.Fatal(err)
	}
	tp := sdktrace.NewTracerProvider(
		sdktrace.WithSpanProcessor(sdktrace.NewBatchSpanProcessor(exporter)),
		sdktrace.WithResource(resource.NewWithAttributes(
			semcove.SchemaURL,
			semcove.ServiceNameKey.String("opentelemetry-demo"),
		)),
	)
	otel.SetTracerProvider(tp)
}
// server.go
// ServeHTTP is the http.Handler entry point. The creation of ctx and of the
// middleware chain root is elided with "..." in this excerpt.
func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	// ...
	// The chain executes front to back.
	// The outermost wrapper writes RespData and RespStatusCode to the real
	// response once everything inside it has returned.
	var m Middleware = func(next HandleFunc) HandleFunc {
		return func(ctx *Context) {
			// Handlers and inner middleware set RespData and RespStatusCode here.
			next(ctx)
			h.flushResp(ctx)
		}
	}
	root = m(root)
	root(ctx)
	// Route lookup and the matched handler (h.serve) presumably sit at the
	// innermost end of root, built in the elided setup — hence the direct
	// call below is commented out.
	//h.serve(ctx)
}
// flushResp copies the buffered status code (when one was set) and body from
// ctx into the underlying http.ResponseWriter. It logs when the body could not
// be written completely.
func (h *HTTPServer) flushResp(ctx *Context) {
	if code := ctx.RespStatusCode; code != 0 {
		ctx.Resp.WriteHeader(code)
	}
	written, err := ctx.Resp.Write(ctx.RespData)
	if err == nil && written == len(ctx.RespData) {
		return
	}
	h.log("写入响应失败 %v\n", err)
}
// serve resolves the route for the request and invokes its handler.
//
// When nothing matches, it buffers a 404 "NOT FOUND" response instead of
// writing it directly, so middleware can still inspect or replace it.
func (h *HTTPServer) serve(ctx *Context) {
	info, found := h.findRoute(ctx.Req.Method, ctx.Req.URL.Path)
	if found && info.n.handler != nil {
		ctx.PathParams = info.pathParams
		ctx.MatchedRoute = info.n.route
		info.n.handler(ctx)
		return
	}
	ctx.RespData = []byte("NOT FOUND")
	ctx.RespStatusCode = 404
}
// NewHTTPServer creates a server with a fresh router and a default logger
// that formats through fmt.Printf. It then applies each option in order, so
// options may override those defaults.
func NewHTTPServer(opts ...HTTPServerOption) *HTTPServer {
	s := &HTTPServer{router: newRouter()}
	s.log = func(msg string, args ...any) {
		fmt.Printf(msg, args...)
	}
	for i := range opts {
		opts[i](s)
	}
	return s
}
// HTTPServer is the framework's server.
type HTTPServer struct {
	// addr string // could also be passed at construction instead of via Start
	// router
	router // embedded: router's addRoute is promoted, which is why this compiles
	// mdls are the middleware wrapped around every request.
	mdls []Middleware
	// log is used for internal errors such as failed response writes.
	log func(msg string, args ...any)
}
// RespJson marshals val to JSON and buffers it, together with code, in
// RespData/RespStatusCode. Nothing is written to Resp here; the server
// flushes the buffers later.
//
// On a marshalling error both buffers are left untouched and the error is
// returned.
func (c *Context) RespJson(code int, val any) error {
	data, err := json.Marshal(val)
	if err == nil {
		c.RespData, c.RespStatusCode = data, code
	}
	return err
}
// Context carries per-request state through middleware and handlers.
type Context struct {
	Req *http.Request
	// Writing to Resp directly bypasses RespData and RespStatusCode,
	// so some middleware may not work.
	Resp http.ResponseWriter
	PathParams map[string]string
	// Buffered response, mainly so middleware can read or rewrite it; the
	// server flushes it to Resp at the end of the request.
	RespData []byte
	RespStatusCode int
	queryValues url.Values
	// MatchedRoute is the route pattern that matched this request.
	MatchedRoute string
	//cookieSameSite http.SameSite
}
// web/middlewares/opentelemetry/middleware.go
package opentelemetry
import (
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/trace"
"web_framework/web"
)
// instrumentationName names this package when a Tracer is taken from the
// global TracerProvider.
const instrumentationName = "web_framework/middlewares/opentelemetry"
// MiddlewareBuilder builds an OpenTelemetry tracing middleware.
type MiddlewareBuilder struct {
	// Tracer creates the per-request span; if nil, Build falls back to the
	// global TracerProvider.
	Tracer trace.Tracer
}
//func NewMiddlewareBuilder(tracer trace.Tracer) *MiddlewareBuilder {
// return &MiddlewareBuilder{tracer}
//}
// Build returns a middleware that wraps every request in a span.
//
// The span joins any trace propagated in the request headers. It is stored in
// ctx.Req's context so handlers can start child spans. After the handler
// chain returns, the span is renamed to the matched route and records the
// buffered response status code.
func (m MiddlewareBuilder) Build() web.Middleware {
	if m.Tracer == nil {
		m.Tracer = otel.GetTracerProvider().Tracer(instrumentationName)
	}
	return func(next web.HandleFunc) web.HandleFunc {
		return func(ctx *web.Context) {
			reqCtx := ctx.Req.Context()
			// Join the client's trace, if it propagated one.
			reqCtx = otel.GetTextMapPropagator().Extract(reqCtx, propagation.HeaderCarrier(ctx.Req.Header))
			// The matched route is only known after routing (inside next), so
			// start with a placeholder name and rename the span afterwards.
			reqCtx, span := m.Tracer.Start(reqCtx, "unknown")
			defer span.End()

			// URL.Scheme is normally empty on server-side requests; fall back
			// to the connection's TLS state.
			scheme := ctx.Req.URL.Scheme
			if scheme == "" {
				scheme = "http"
				if ctx.Req.TLS != nil {
					scheme = "https"
				}
			}
			// Add more attributes as the business needs.
			span.SetAttributes(
				attribute.String("http.method", ctx.Req.Method),
				attribute.String("http.url", ctx.Req.URL.String()),
				attribute.String("http.scheme", scheme),
				attribute.String("http.host", ctx.Req.Host),
			)

			// Hand the span's context to the handlers. Copying the request
			// costs an allocation, but it is the only way to pass it on.
			ctx.Req = ctx.Req.WithContext(reqCtx)
			next(ctx)

			// MatchedRoute is only set after next; it stays empty when no
			// route matched, in which case keep the "unknown" name.
			if ctx.MatchedRoute != "" {
				span.SetName(ctx.MatchedRoute)
			}
			span.SetAttributes(attribute.Int("http.status", ctx.RespStatusCode))
		}
	}
}
16. Middleware:OpenTelemetry 总结
17. Prometheus 详解
18. Middleware:Prometheus
// web/middlewares/prometheus/middleware_test.go
//go:build e2e
package prometheus
import (
"github.com/prometheus/client_golang/prometheus/promhttp"
"math/rand"
"net/http"
"testing"
"time"
"web_framework/web"
)
// TestMiddlewareBuilder_Build serves GET /user with a random 1-1000ms delay,
// so the latency summary has data. Prometheus can scrape /metrics on :8082.
func TestMiddlewareBuilder_Build(t *testing.T) {
	builder := MiddlewareBuilder{
		Namespace: "malred",
		Subsystem: "web",
		Name:      "http_response",
	}
	server := web.NewHTTPServer(web.ServerWithMiddleware(builder.Build()))
	server.Get("/user", func(ctx *web.Context) {
		val := rand.Intn(1000) + 1
		time.Sleep(time.Duration(val) * time.Millisecond)
		_ = ctx.RespJson(202, User{Name: "Tom"})
	})
	// Expose the metrics on their own port for Prometheus to scrape. A private
	// mux keeps /metrics off http.DefaultServeMux, which panics on duplicate
	// registration. Report (rather than silently drop) a failure such as the
	// port already being in use.
	go func() {
		mux := http.NewServeMux()
		mux.Handle("/metrics", promhttp.Handler())
		if err := http.ListenAndServe(":8082", mux); err != nil {
			t.Log("metrics server stopped:", err)
		}
	}()
	server.Start(":8081")
}
// User is the JSON body returned by the demo handler.
type User struct {
	Name string `json:"name"`
}
// web/middlewares/prometheus/middleware.go
package prometheus
import (
"github.com/prometheus/client_golang/prometheus"
"strconv"
"time"
"web_framework/web"
)
// MiddlewareBuilder configures a Prometheus latency middleware.
// A builder makes it easy to add options later.
type MiddlewareBuilder struct {
	// Namespace, Subsystem, Name and Help are passed to prometheus.SummaryOpts.
	Namespace string
	Subsystem string
	Name string
	Help string
}
// Build returns a middleware that records each request's latency in
// milliseconds. The values go into a SummaryVec labelled by matched route
// pattern, HTTP method and response status code. Requests that matched no
// route are labelled "unknown".
//
// Calling Build again with the same options reuses the collector that is
// already registered, instead of panicking as MustRegister would. Other
// registration errors (invalid or conflicting options) still panic.
func (m MiddlewareBuilder) Build() web.Middleware {
	vector := prometheus.NewSummaryVec(prometheus.SummaryOpts{
		Namespace: m.Namespace,
		Subsystem: m.Subsystem,
		Name:      m.Name,
		Help:      m.Help,
		Objectives: map[float64]float64{
			0.5:   0.01,
			0.75:  0.01,
			0.90:  0.01,
			0.99:  0.001,
			0.999: 0.0001,
		},
	}, []string{"pattern", "method", "status"})
	if err := prometheus.Register(vector); err != nil {
		are, ok := err.(prometheus.AlreadyRegisteredError)
		if !ok {
			panic(err)
		}
		existing, ok := are.ExistingCollector.(*prometheus.SummaryVec)
		if !ok {
			panic(err)
		}
		vector = existing
	}
	return func(next web.HandleFunc) web.HandleFunc {
		return func(ctx *web.Context) {
			startTime := time.Now()
			// Deferred so the observation is recorded even if next panics.
			defer func() {
				// Keep sub-millisecond precision instead of truncating to whole ms.
				duration := float64(time.Since(startTime).Microseconds()) / 1000
				pattern := ctx.MatchedRoute
				if pattern == "" {
					pattern = "unknown"
				}
				vector.WithLabelValues(pattern, ctx.Req.Method,
					strconv.Itoa(ctx.RespStatusCode)).
					Observe(duration)
			}()
			next(ctx)
		}
	}
}
19. Middleware 例子:错误页面
// middleware_test.go
package errhdl
import (
"net/http"
"testing"
"web_framework/web"
)
// TestNewMiddlewareBuilder starts a server with no routes, so every request
// ends as a 404. The error-page middleware then swaps in its HTML page.
func TestNewMiddlewareBuilder(t *testing.T) {
	notFoundPage := []byte(`
<html>
<body>
<h1>NOT FOUND</h1>
</body>
</html>
`)
	badRequestPage := []byte(`
<html>
<body>
<h1>BAD REQUEST</h1>
</body>
</html>
`)
	builder := NewMiddlewareBuilder()
	builder.AddCode(http.StatusNotFound, notFoundPage)
	builder.AddCode(http.StatusBadRequest, badRequestPage)
	web.NewHTTPServer(web.ServerWithMiddleware(builder.Build())).Start(":8081")
}
// web/middlewares/errhdl/middleware.go
package errhdl
import "web_framework/web"
// MiddlewareBuilder maps HTTP status codes to fixed response bodies.
// The bodies are static bytes, so pages cannot be rendered dynamically.
type MiddlewareBuilder struct {
	resp map[int][]byte
}

// NewMiddlewareBuilder returns a builder with an empty status-to-page table.
func NewMiddlewareBuilder() *MiddlewareBuilder {
	b := &MiddlewareBuilder{}
	b.resp = make(map[int][]byte)
	return b
}

// AddCode registers data as the body to send whenever a response carries
// status. It returns the builder so calls can be chained.
func (m *MiddlewareBuilder) AddCode(status int, data []byte) *MiddlewareBuilder {
	m.resp[status] = data
	return m
}

// Build returns a middleware that runs the rest of the chain first. It then
// overwrites RespData with the page registered for RespStatusCode, if any.
// The status code itself is not changed.
func (m MiddlewareBuilder) Build() web.Middleware {
	pages := m.resp
	return func(next web.HandleFunc) web.HandleFunc {
		return func(ctx *web.Context) {
			next(ctx)
			if page, found := pages[ctx.RespStatusCode]; found {
				ctx.RespData = page
			}
		}
	}
}
20. Middleware 例子:从 panic 中恢复
// middleware_test.go
package recover
import (
"fmt"
"testing"
"web_framework/web"
)
// TestMiddlewareBuilder_Build serves a route that always panics. The recover
// middleware turns that into a 500 with a fixed body and logs the path.
func TestMiddlewareBuilder_Build(t *testing.T) {
	recoverMdl := MiddlewareBuilder{
		StatusCode: 500,
		Data:       []byte("你panic了"),
		Log: func(ctx *web.Context) {
			fmt.Printf("panic 路径: %s", ctx.Req.URL.String())
		},
	}.Build()

	server := web.NewHTTPServer(web.ServerWithMiddleware(recoverMdl))
	server.Get("/user", func(ctx *web.Context) {
		panic("发生panic")
	})
	server.Start(":8081")
}
// web/middlewares/recover/middleware.go
package recover
import "web_framework/web"
// MiddlewareBuilder builds a middleware that recovers from panics raised
// further down the handler chain and replaces the response with a fixed one.
type MiddlewareBuilder struct {
	// StatusCode and Data form the buffered response used after a panic.
	StatusCode int
	Data       []byte
	// Log, when non-nil, is called after a panic has been recovered.
	//Log func(err any)
	Log func(ctx *web.Context)
	//Log func(stack string)
}

// Build returns the recovering middleware. After a panic it sets RespData and
// RespStatusCode from the builder and calls Log if one is configured. A nil
// Log is skipped: calling it inside the deferred recover would panic again,
// and nothing would catch that panic.
func (m MiddlewareBuilder) Build() web.Middleware {
	return func(next web.HandleFunc) web.HandleFunc {
		return func(ctx *web.Context) {
			defer func() {
				if err := recover(); err != nil {
					ctx.RespData = m.Data
					ctx.RespStatusCode = m.StatusCode
					if m.Log != nil {
						m.Log(ctx)
					}
				}
			}()
			next(ctx)
		}
	}
}
21. Middleware 总结和面试
第二周作业
09 第三周:Web 框架之页面渲染、文件处理与 Session
页面渲染
1. 页面渲染:模板引擎接口定义
2. 页面渲染:Template 语法
text/template 常用于代码生成类的中间件和工具,比如 gRPC 的代码生成、Google 的依赖注入工具 wire;由于 Go 不支持在运行时动态生成类型,动态代理一般也通过代码生成来实现。
package template_demo
import (
"bytes"
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"html/template"
"testing"
)
// TestHelloWorld renders a struct field referenced by name.
func TestHelloWorld(t *testing.T) {
	type User struct {
		Name string
	}
	tpl, err := template.New("hello-world").Parse(`Hello, {{.Name}}`)
	require.NoError(t, err)
	var out bytes.Buffer
	require.NoError(t, tpl.Execute(&out, User{Name: "Tom"}))
	assert.Equal(t, `Hello, Tom`, out.String())
}
// TestMapData renders a map; {{.Name}} looks up the "Name" key.
func TestMapData(t *testing.T) {
	tpl, err := template.New("hello-world").Parse(`Hello, {{.Name}}`)
	require.NoError(t, err)
	var out bytes.Buffer
	require.NoError(t, tpl.Execute(&out, map[string]string{"Name": "Tom"}))
	assert.Equal(t, `Hello, Tom`, out.String())
}
// TestSlice renders a slice element through the builtin index function.
func TestSlice(t *testing.T) {
	tpl, err := template.New("hello-world").Parse(`Hello, {{index . 0}}`)
	require.NoError(t, err)
	var out bytes.Buffer
	require.NoError(t, tpl.Execute(&out, []string{"Tom"}))
	assert.Equal(t, `Hello, Tom`, out.String())
}
// TestBasic renders a basic-typed value directly via {{.}}.
func TestBasic(t *testing.T) {
	tpl, err := template.New("hello-world").Parse(`Hello, {{.}}`)
	require.NoError(t, err)
	var out bytes.Buffer
	require.NoError(t, tpl.Execute(&out, 123))
	assert.Equal(t, `Hello, 123`, out.String())
}
// TestInvoke calls builtin functions (printf, len) and a method with
// arguments on the data value. Both strings are raw literals, so each \n is a
// literal backslash-n in the template and in the expected output.
func TestInvoke(t *testing.T) {
	tpl, err := template.New("hello-world").Parse(`打印数字: {{printf "%.2f" 1.2345}}\n切片长度: {{len .Slice}}\n{{.Hello "Tom" "Jerry"}}`)
	require.NoError(t, err)
	var out bytes.Buffer
	require.NoError(t, tpl.Execute(&out, FnCall{Slice: []string{"a", "b"}}))
	assert.Equal(t, `打印数字: 1.23\n切片长度: 2\nHello, Tom-Jerry`, out.String())
}
// FnCall is template data that exposes a slice field and a method templates
// can call with arguments.
type FnCall struct {
	Slice []string
}

// Hello returns "Hello, " followed by first and last joined with a hyphen.
func (f FnCall) Hello(first string, last string) string {
	return fmt.Sprint("Hello, ", first, "-", last)
}
// TestFor ranges over a slice with index and element variables. The "{{-"
// markers trim the preceding newline, so only the line breaks that follow
// "{{- .}}" and "{{$idx}}-{{$ele}}" survive.
func TestFor(t *testing.T) {
	tpl, err := template.New("hello-world").Parse(`
{{- range $idx, $ele := .Slice}}
{{- .}}
{{$idx}}-{{$ele}}
{{end}}`)
	require.NoError(t, err)
	var out bytes.Buffer
	require.NoError(t, tpl.Execute(&out, FnCall{Slice: []string{"a", "b"}}))
	assert.Equal(t, `a
0-a
b
1-b
`, out.String())
}
// TestForLoop emulates a counted for-loop by ranging over a slice of length
// 10 and printing only the index; trim markers remove all whitespace.
func TestForLoop(t *testing.T) {
	tpl, err := template.New("hello-world").Parse(`
{{- range $idx, $ele := .}}
{{- $idx}}
{{- end}}`)
	require.NoError(t, err)
	var out bytes.Buffer
	require.NoError(t, tpl.Execute(&out, make([]int, 10)))
	assert.Equal(t, `0123456789`, out.String())
}
// TestIfElse exercises if / else if / else with and, gt and le; the "{{-" and
// "-}}" markers strip the surrounding whitespace. The interval text now reads
// "(6, 18]" (was the typo "(6. 18]") in both the template and the expectation.
func TestIfElse(t *testing.T) {
	type User struct {
		Age int
	}
	tpl, err := template.New("hello-world").Parse(`
{{- if and (gt .Age 0) (le .Age 6) -}}
儿童: (0, 6]
{{- else if and (gt .Age 6) (le .Age 18) -}}
少年: (6, 18]
{{- else -}}
成人: >18
{{- end -}}
`)
	require.NoError(t, err)
	var out bytes.Buffer
	require.NoError(t, tpl.Execute(&out, User{Age: 14}))
	assert.Equal(t, `少年: (6, 18]`, out.String())
}
3. 页面渲染:GoTemplateEngin 实现、面试要点总结
模板引擎涉及编译原理(解释执行)
// context.go
// Context carries per-request state (other fields elided in this excerpt).
type Context struct {
	// ...
	// tplEngine is the server's template engine, copied in by ServeHTTP and
	// used by Render.
	tplEngine TemplateEngine
	//cookieSameSite http.SameSite
}
// Render renders template tplName with data through the server's
// TemplateEngine and buffers the result in RespData. It sets 200 on success;
// on failure it sets 500 and returns the engine's error. tplEngine must have
// been configured (ServerWithTemplateEngine).
func (c *Context) Render(tplName string, data any) error {
	page, err := c.tplEngine.Render(c.Req.Context(), tplName, data)
	c.RespData = page
	if err != nil {
		c.RespStatusCode = http.StatusInternalServerError
		return err
	}
	c.RespStatusCode = http.StatusOK
	return nil
}
// server.go
// HTTPServer is the framework's server.
type HTTPServer struct {
	// addr string // could also be passed at construction instead of via Start
	// router
	router // embedded: router's addRoute is promoted, which is why this compiles
	mdls []Middleware
	log func(msg string, args ...any)
	// tplEngine is handed to every request Context for ctx.Render.
	tplEngine TemplateEngine
}
// ServerWithTemplateEngine returns an option that installs the
// TemplateEngine used by Context.Render.
func ServerWithTemplateEngine(tplEngin TemplateEngine) HTTPServerOption {
	engine := tplEngin
	return func(s *HTTPServer) { s.tplEngine = engine }
}
// http.Handler的接口
// 核心方法 处理请求的入口
// ServeHTTP implements http.Handler and is the entry point for every request.
// Only the Context construction is shown; the rest is elided.
func (h *HTTPServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	// Framework code: build the per-request Context. The server's template
	// engine is copied in so handlers can call ctx.Render.
	ctx := &Context{
		Req: req,
		Resp: w,
		tplEngine: h.tplEngine,
	}
	// ...
}
// template_e2e_test.go
//go:build e2e
package web
import (
"github.com/stretchr/testify/require"
"html/template"
"log"
"testing"
)
// TestLoginPage renders login.gohtml (parsed from testdata/tpls) on
// GET /login using the Go template engine.
func TestLoginPage(t *testing.T) {
	tpls, err := template.ParseGlob("testdata/tpls/*.gohtml")
	require.NoError(t, err)
	server := NewHTTPServer(ServerWithTemplateEngine(&GoTemplateEngine{T: tpls}))
	server.Get("/login", func(ctx *Context) {
		if renderErr := ctx.Render("login.gohtml", nil); renderErr != nil {
			log.Println(renderErr)
		}
	})
	server.Start(":8081")
}
// template.go
package web
import (
"bytes"
"context"
"html/template"
)
// TemplateEngine renders named templates into bytes for Context.Render.
type TemplateEngine interface {
	// Render renders a page.
	// tplName is the template's name; templates are looked up by name.
	// data is the data used to render the page.
	Render(ctx context.Context, tplName string, data any) ([]byte, error)
	// Alternative: render the page and write it into a writer.
	//Render(ctx context.Context, tplName string, data any,writer io.Writer) ( error)
	// This would also work, but couples the engine to Context.
	//Render(ctx Context)
}
// GoTemplateEngine is a TemplateEngine backed by html/template.
type GoTemplateEngine struct {
	// T holds the parsed template set; Render looks templates up by name.
	T *template.Template
}

// Render executes the template named tplName with data and returns the
// rendered bytes. On error it returns nil instead of whatever had been
// written before the failure, so callers cannot send half a page. ctx is
// currently unused.
func (g *GoTemplateEngine) Render(ctx context.Context, tplName string, data any) ([]byte, error) {
	var buf bytes.Buffer
	if err := g.T.ExecuteTemplate(&buf, tplName, data); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
文件处理
4. 文件处理:文件基本操作
package file_demo
import (
"fmt"
"github.com/stretchr/testify/require"
"os"
"testing"
)
// TestFile demonstrates basic file operations on testdata:
//   - reading from a file opened with os.Open, where writing must fail;
//   - appending to an existing file;
//   - creating a new file.
//
// It appends "hello" to testdata/my_file.txt on every run.
func TestFile(t *testing.T) {
	// Relative paths resolve against the working directory, which under
	// `go test` is the package directory,
	// e.g. D:\go\projs\web_framework\web\file_demo <nil>.
	fmt.Println(os.Getwd())

	f, err := os.Open("testdata/my_file.txt")
	require.NoError(t, err)
	data := make([]byte, 64)
	n, err := f.Read(data)
	require.NoError(t, err)
	// Number of bytes read.
	fmt.Println(n)
	// os.Open is read-only, so the write fails
	// ("bad file descriptor" on Unix, "Access is denied." on Windows).
	n, err = f.WriteString("hello")
	fmt.Println(err)
	require.Error(t, err)
	// Number of bytes written (0).
	fmt.Println(n)
	require.NoError(t, f.Close())

	// perm only matters together with O_CREATE; pass a real file mode rather
	// than os.ModeAppend, which is a mode flag and not a permission.
	f, err = os.OpenFile("testdata/my_file.txt", os.O_APPEND|os.O_WRONLY, 0o644)
	require.NoError(t, err)
	n, err = f.WriteString("hello")
	require.NoError(t, err)
	fmt.Println(n)
	// Close can report write errors, so check it for written files.
	require.NoError(t, f.Close())

	f, err = os.Create("testdata/my_file_copy.txt")
	require.NoError(t, err)
	n, err = f.WriteString("hello, world")
	require.NoError(t, err)
	fmt.Println(n)
	require.NoError(t, f.Close())
}
5. 文件处理:文件上传
// file.go
package web
import (
	"io"
	"mime/multipart"
	"net/http"
	"os"
	"path/filepath"
	"strings"
)
// FileUploader builds a handler that saves one uploaded multipart file to
// local disk.
type FileUploader struct {
	// FileField is the multipart form field carrying the file (the name
	// attribute of the file input). Defaults to "file".
	FileField string
	// DataPathFunc returns the path the file is saved to. Naming is left to
	// the user: a framework-chosen name would have to handle collisions
	// (e.g. with UUIDs). Defaults to testdata/upload/<uploaded file name>.
	DataPathFunc func(file *multipart.FileHeader) string
}

// Handle returns a HandleFunc that:
//   - reads the file from FileField;
//   - creates any missing parent directories of the destination;
//   - copies the upload there.
//
// Any failure yields a 500 with "上传失败 " plus the error; success yields a
// 200 with "上传成功".
func (u FileUploader) Handle() HandleFunc {
	if u.FileField == "" {
		u.FileField = "file"
	}
	if u.DataPathFunc == nil {
		u.DataPathFunc = func(file *multipart.FileHeader) string {
			// Base strips any directory part of the client-supplied name so
			// the default can never write outside testdata/upload.
			return filepath.Join("testdata", "upload", filepath.Base(file.Filename))
		}
	}
	return func(ctx *Context) {
		// 1. Read the uploaded file (FileHeader carries Filename, Header and Size).
		file, fileHeader, err := ctx.Req.FormFile(u.FileField)
		if err != nil {
			ctx.RespStatusCode = http.StatusInternalServerError
			ctx.RespData = []byte("上传失败 " + err.Error())
			return
		}
		defer file.Close()

		// 2. Compute the destination; the user decides the path.
		dst := u.DataPathFunc(fileHeader)
		// Create missing parent directories first. Split on '/' and on the OS
		// separator so this works everywhere. Searching only for '\\' found
		// nothing outside Windows and then sliced dst[:-1], which panics.
		isSep := func(r rune) bool { return r == '/' || r == filepath.Separator }
		if i := strings.LastIndexFunc(dst, isSep); i > 0 {
			if err := os.MkdirAll(dst[:i], os.ModePerm); err != nil {
				ctx.RespStatusCode = http.StatusInternalServerError
				ctx.RespData = []byte("上传失败 " + err.Error())
				return
			}
		}
		// O_WRONLY: write only; O_TRUNC: truncate if it exists;
		// O_CREATE: create if missing.
		dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0o666)
		if err != nil {
			ctx.RespStatusCode = http.StatusInternalServerError
			ctx.RespData = []byte("上传失败 " + err.Error())
			return
		}

		// 3. Save the file. The buffer argument sets the chunk size; nil lets
		// io.CopyBuffer allocate its default (pass a pooled buffer to reuse it).
		_, err = io.CopyBuffer(dstFile, file, nil)
		// Close explicitly: a failed Close can mean the data never reached disk.
		if closeErr := dstFile.Close(); err == nil {
			err = closeErr
		}
		if err != nil {
			ctx.RespStatusCode = http.StatusInternalServerError
			ctx.RespData = []byte("上传失败 " + err.Error())
			return
		}

		// 4. Respond.
		ctx.RespStatusCode = http.StatusOK
		ctx.RespData = []byte("上传成功")
	}
}
// file_e2e_test.go
//go:build e2e
package web
import (
"github.com/stretchr/testify/require"
"html/template"
"log"
"mime/multipart"
"path/filepath"
"testing"
)
// TestFileUpload serves the upload form on GET /upload. Files posted to
// POST /upload are stored under testdata/upload.
func TestFileUpload(t *testing.T) {
	tpls, err := template.ParseGlob("testdata/tpls/*.gohtml")
	require.NoError(t, err)
	h := NewHTTPServer(ServerWithTemplateEngine(&GoTemplateEngine{T: tpls}))
	h.Get("/upload", func(ctx *Context) {
		if renderErr := ctx.Render("upload.gohtml", nil); renderErr != nil {
			log.Println(renderErr)
		}
	})

	uploader := FileUploader{
		// Must match the name attribute of the form's file input.
		FileField: "myfile",
		DataPathFunc: func(fh *multipart.FileHeader) string {
			return filepath.Join("testdata", "upload", fh.Filename)
		},
	}
	h.Post("/upload", uploader.Handle())
	h.Start(":8081")
}
// testdata/tpls/upload.gohtml
<html>
<body>
<!-- Upload form: posts multipart data to /upload; the input's name ("myfile") must match FileUploader.FileField. -->
<form action="/upload" method="post" enctype="multipart/form-data">
<input type="file" name="myfile">
<button type="submit">上传</button>
</form>
</body>
</html>
6. 文件处理:文件下载
// file_e2e_test.go
func TestFileDownload(t *testing.T) {
h := NewHTTPServer()
fd := FileDownloader{
Dir: filepath.Join("testdata", "download"),
}
h.Get("/download", fd.Handle())
h.Start(":8081")
}
// file.go
// FileDownloader serves files below Dir as attachments. The file is selected
// by the "file" query parameter, e.g. /download?file=a.txt.
type FileDownloader struct {
	Dir string
}

// Handle returns the download HandleFunc:
//   - a missing or empty file name gets a 400;
//   - a name that does not resolve to a regular file under Dir gets a 404.
func (d FileDownloader) Handle() HandleFunc {
	return func(ctx *Context) {
		filename, err := ctx.QueryValue("file")
		if err != nil || filename == "" {
			ctx.RespStatusCode = http.StatusBadRequest
			ctx.RespData = []byte("找不到目标文件")
			return
		}
		// Clean the name as if it were rooted at "/": this collapses every
		// ".." element before joining, so dst can never leave Dir. Cleaning
		// the bare name keeps a leading "../" and allowed path traversal.
		dst := filepath.Join(d.Dir, filepath.Clean("/"+filename))
		// Refuse directories (and missing files): http.ServeFile would
		// otherwise answer with a directory listing.
		info, err := os.Stat(dst)
		if err != nil || info.IsDir() {
			ctx.RespStatusCode = http.StatusNotFound
			ctx.RespData = []byte("找不到目标文件")
			return
		}

		fn := filepath.Base(dst)
		header := ctx.Resp.Header()
		// Quote the name so spaces or ';' do not break the header.
		header.Set("Content-Disposition", `attachment; filename="`+fn+`"`)
		header.Set("Content-Description", "File Transfer")
		// Generic binary content.
		header.Set("Content-Type", "application/octet-stream")
		header.Set("Content-Transfer-Encoding", "binary")
		// Disable caching.
		header.Set("Expires", "0")
		header.Set("Cache-Control", "must-revalidate")
		header.Set("Pragma", "public")
		http.ServeFile(ctx.Resp, ctx.Req, dst)
	}
}
现在的文件下载没有缓存(响应头显式禁止了缓存),也没有专门处理大文件(分片上传/下载需要前后端配合);不过 http.ServeFile 本身支持 Range 请求,可用于断点续传。
7. 文件处理:静态资源处理、面试要点总结
// StaticResourceHandlerOption configures a StaticResourceHandler.
type StaticResourceHandlerOption func(handler *StaticResourceHandler)
// StaticResourceHandler serves static files from dir through an LRU cache.
type StaticResourceHandler struct {
	dir string
	// Content-Type to send for each file extension (without the dot); setting
	// it keeps some browsers from failing to interpret the file.
	extensionContextTypeMap map[string]string
	cache *lru.Cache // cache of file contents keyed by path
	maxSize int // files larger than this (bytes) are not cached
}
// NewStaticResourceHandler returns a handler serving files from dir.
//
// Defaults: a 1000-entry LRU cache, files up to 10MB cached, and Content-Type
// mappings for common image and PDF extensions. Options are applied after the
// defaults. It returns an error only if the cache cannot be created.
func NewStaticResourceHandler(dir string, opts ...StaticResourceHandlerOption) (*StaticResourceHandler, error) {
	// lru.New takes the maximum number of cached entries.
	cache, err := lru.New(1000)
	if err != nil {
		return nil, err
	}
	res := &StaticResourceHandler{
		dir:     dir,
		cache:   cache,
		maxSize: 1024 * 1024 * 10, // 10MB
		extensionContextTypeMap: map[string]string{
			// Registered media types only: "image/jpg" and "image/pdf" do
			// not exist, and browsers may refuse to display such responses.
			"jpeg": "image/jpeg",
			"jpg":  "image/jpeg",
			"jpe":  "image/jpeg",
			"png":  "image/png",
			"pdf":  "application/pdf",
		},
	}
	for _, opt := range opts {
		opt(res)
	}
	return res, nil
}
// StaticWithMaxFileSize sets the largest file size, in bytes, that will be
// kept in the cache.
func StaticWithMaxFileSize(maxSize int) StaticResourceHandlerOption {
	return func(h *StaticResourceHandler) { h.maxSize = maxSize }
}

// StaticWithCache replaces the default LRU cache.
func StaticWithCache(c *lru.Cache) StaticResourceHandlerOption {
	return func(h *StaticResourceHandler) { h.cache = c }
}

// StaticWithMoreExtension adds or overrides extension-to-Content-Type
// mappings; extensions are given without the leading dot.
func StaticWithMoreExtension(extMap map[string]string) StaticResourceHandlerOption {
	return func(h *StaticResourceHandler) {
		for extension, contentType := range extMap {
			h.extensionContextTypeMap[extension] = contentType
		}
	}
}
// Handle serves the file named by the ":file" path parameter from s.dir.
//
// Cached content is served from memory and marked with "Web-Msg: from cache".
// Otherwise the file is read from disk and cached when it is at most maxSize
// bytes. Responses:
//   - 400 for a missing path parameter;
//   - 404 for a missing file;
//   - 500 for other read errors.
func (s *StaticResourceHandler) Handle(ctx *Context) {
	// 1. Resolve the target file name.
	file, err := ctx.PathValue("file")
	if err != nil {
		ctx.RespStatusCode = http.StatusBadRequest
		ctx.RespData = []byte("请求路径错误")
		return
	}
	// Cleaning the name as an absolute path collapses ".." elements, so dst
	// cannot point outside s.dir.
	dst := filepath.Join(s.dir, filepath.Clean("/"+file))
	// e.g. jpg/txt/video/audio. Ext is "" for names without an extension;
	// slicing off the dot unconditionally panicked in that case.
	ext := filepath.Ext(dst)
	if ext != "" {
		ext = ext[1:]
	}
	header := ctx.Resp.Header()
	// Content-Type is only set for known extensions; an empty value would be
	// sent as-is and suppress content sniffing.
	setContentHeaders := func(size int) {
		if contentType, ok := s.extensionContextTypeMap[ext]; ok {
			header.Set("Content-Type", contentType)
		}
		header.Set("Content-Length", strconv.Itoa(size))
	}

	if cached, ok := s.cache.Get(dst); ok {
		data := cached.([]byte)
		setContentHeaders(len(data))
		header.Set("Web-Msg", "from cache")
		ctx.RespData = data
		ctx.RespStatusCode = http.StatusOK
		return
	}

	// 2. Read the file from disk.
	data, err := ioutil.ReadFile(dst)
	if err != nil {
		if os.IsNotExist(err) {
			ctx.RespStatusCode = http.StatusNotFound
			ctx.RespData = []byte("NOT FOUND")
			return
		}
		ctx.RespStatusCode = http.StatusInternalServerError
		ctx.RespData = []byte("服务器系统错误")
		return
	}
	// Files larger than maxSize are not cached.
	if len(data) <= s.maxSize {
		s.cache.Add(dst, data)
	}

	// 3. Respond.
	setContentHeaders(len(data))
	ctx.RespData = data
	ctx.RespStatusCode = http.StatusOK
}
// test
func TestStaticResourceHandler(t *testing.T) {
h := NewHTTPServer()
s, err := NewStaticResourceHandler(filepath.Join("testdata", "static"))
require.NoError(t, err)
// localhost:8081/static/xxx.jpg
h.Get("/static/:file", s.Handle)
h.Start(":8081")
}
Session
8. Session:概念与不同框架的 Session 设计分析
9. Session:接口设计
// session/types.go
package session
import (
"context"
"net/http"
)
// 管理session本身
// Store manages the lifecycle of sessions.
type Store interface {
	// Open design questions: who chooses the session id (caller or store)?
	// Should timeouts be part of this interface?
	Generate(ctx context.Context, id string) (Session, error)
	Refresh(ctx context.Context, id string) error
	Remove(ctx context.Context, id string) error
	Get(ctx context.Context, id string) (Session, error)
}

// Session is one user's key/value session data.
type Session interface {
	Get(ctx context.Context, key string) (any, error)
	Set(ctx context.Context, key string, val string) error
	ID() string
}

// Propagator moves the session id between the server and the client
// (e.g. in a cookie).
type Propagator interface {
	Inject(id string, writer http.ResponseWriter) error
	Extract(req *http.Request) (string, error)
	// Remove returns an error like Inject does, so callers such as
	// Manager.RemoveSession can report a failure to clear the id.
	Remove(writer http.ResponseWriter) error
}
10. Session:用户使用示例和 Manager 设计
// types_test.go
package session
import (
"net/http"
"testing"
"web_framework/web"
)
// TestSession sketches how Manager is used:
//   - a middleware requires a valid session on every path except /login;
//   - /login creates the session, /logout removes it;
//   - /user reads the nickname stored at login.
//
// NOTE(review): m is a zero-value Manager with nil Propagator and Store, so
// it must be wired up before this can actually serve requests.
func TestSession(t *testing.T) {
	var m Manager
	// Simple login check.
	server := web.NewHTTPServer(web.ServerWithMiddleware(func(next web.HandleFunc) web.HandleFunc {
		return func(ctx *web.Context) {
			if ctx.Req.URL.Path == "/login" {
				// Let the user log in.
				next(ctx)
				return
			}
			_, err := m.GetSession(ctx)
			if err != nil {
				ctx.RespStatusCode = http.StatusUnauthorized
				ctx.RespData = []byte("请重新登录")
				return
			}
			// Extend the session's expiration on each authenticated request.
			_ = m.RefreshSession(ctx)
			next(ctx)
		}
	}))
	server.Post("/login", func(ctx *web.Context) {
		// Username/password would be verified before creating the session.
		sess, err := m.InitSession(ctx)
		if err != nil {
			ctx.RespStatusCode = http.StatusInternalServerError
			ctx.RespData = []byte("登录失败")
			return
		}
		err = sess.Set(ctx.Req.Context(), "nickname", "xiaoming")
		if err != nil {
			ctx.RespStatusCode = http.StatusInternalServerError
			ctx.RespData = []byte("登录失败")
			return
		}
		ctx.RespStatusCode = http.StatusOK
		ctx.RespData = []byte("登录成功")
	})
	server.Post("/logout", func(ctx *web.Context) {
		// Clear the session data.
		err := m.RemoveSession(ctx)
		if err != nil {
			ctx.RespStatusCode = http.StatusInternalServerError
			ctx.RespData = []byte("退出失败")
			return
		}
		ctx.RespStatusCode = http.StatusOK
		ctx.RespData = []byte("退出成功")
	})
	server.Get("/user", func(ctx *web.Context) {
		sess, err := m.GetSession(ctx)
		if err != nil {
			ctx.RespStatusCode = http.StatusUnauthorized
			ctx.RespData = []byte("请重新登录")
			return
		}
		// Previously the value and error were discarded, producing an empty 200.
		val, err := sess.Get(ctx.Req.Context(), "nickname")
		if err != nil {
			ctx.RespStatusCode = http.StatusInternalServerError
			ctx.RespData = []byte("获取用户信息失败")
			return
		}
		nickname, _ := val.(string)
		ctx.RespStatusCode = http.StatusOK
		ctx.RespData = []byte(nickname)
	})
	server.Start(":8081")
}
// types.go
package session
import (
"context"
"net/http"
)
// 管理session本身
// Store manages the lifecycle of sessions.
type Store interface {
	// Open design questions: who chooses the session id (caller or store)?
	// Should timeouts be part of this interface?
	Generate(ctx context.Context, id string) (Session, error)
	Refresh(ctx context.Context, id string) error
	Remove(ctx context.Context, id string) error
	Get(ctx context.Context, id string) (Session, error)
}
// Session is one user's key/value session data.
type Session interface {
	Get(ctx context.Context, key string) (any, error)
	Set(ctx context.Context, key string, val string) error
	ID() string
}
// Propagator moves the session id between server and client
// (the cookie package implements it with a cookie).
type Propagator interface {
	Inject(id string, writer http.ResponseWriter) error
	Extract(req *http.Request) (string, error)
	Remove(writer http.ResponseWriter) error
}
// manager.go
package session
import (
"github.com/google/uuid"
"web_framework/web"
)
// Manager combines a Propagator (how the session id travels with requests)
// and a Store (where session data lives) into per-request helpers.
type Manager struct {
	Propagator
	Store
}

// GetSession extracts the session id from the request and loads that session
// from the store.
func (m *Manager) GetSession(ctx *web.Context) (Session, error) {
	id, err := m.Extract(ctx.Req)
	if err != nil {
		return nil, err
	}
	return m.Get(ctx.Req.Context(), id)
}

// InitSession creates a session under a fresh UUID and injects its id into
// the response. The session is returned even when injection fails.
func (m *Manager) InitSession(ctx *web.Context) (Session, error) {
	id := uuid.New().String()
	sess, err := m.Generate(ctx.Req.Context(), id)
	if err != nil {
		return nil, err
	}
	return sess, m.Inject(id, ctx.Resp)
}

// RemoveSession deletes the current session from the store and then asks the
// propagator to clear the id on the client. Both methods are named Remove, so
// the embedded fields are selected explicitly.
func (m *Manager) RemoveSession(ctx *web.Context) error {
	sess, err := m.GetSession(ctx)
	if err != nil {
		return err
	}
	if err := m.Store.Remove(ctx.Req.Context(), sess.ID()); err != nil {
		return err
	}
	return m.Propagator.Remove(ctx.Resp)
}

// RefreshSession extends the current session's expiration.
func (m *Manager) RefreshSession(ctx *web.Context) error {
	sess, err := m.GetSession(ctx)
	if err == nil {
		// Assumes the session id does not change on refresh.
		err = m.Refresh(ctx.Req.Context(), sess.ID())
	}
	return err
}
11. Session:web.Context 缓存 Session
// context.go
// Context carries per-request state (other fields elided in this excerpt).
type Context struct {
	// ...
	UserValues map[string]any // per-request cache, e.g. for the loaded session
}
// session/manager.go
// Manager combines a Propagator (how the session id travels) and a Store
// (where session data lives), and caches the loaded session on the Context.
type Manager struct {
	Propagator
	Store
	// CtxSessKey is the key under which the session is cached in ctx.UserValues.
	CtxSessKey string
}

// GetSession returns the current request's session. The first call loads it
// via the Propagator and Store and caches it in ctx.UserValues; later calls
// during the same request return the cached session.
func (m *Manager) GetSession(ctx *web.Context) (Session, error) {
	// Comma-ok assertion: a foreign value stored under CtxSessKey is treated
	// as a cache miss instead of panicking. Reading a nil map is safe.
	if sess, ok := ctx.UserValues[m.CtxSessKey].(Session); ok {
		return sess, nil
	}
	sessID, err := m.Extract(ctx.Req)
	if err != nil {
		return nil, err
	}
	sess, err := m.Get(ctx.Req.Context(), sessID)
	if err != nil {
		return nil, err
	}
	if ctx.UserValues == nil {
		ctx.UserValues = make(map[string]any, 1)
	}
	ctx.UserValues[m.CtxSessKey] = sess
	return sess, nil
}
12. Session:基于内存的实现
// momery/session.go
package momery
import (
"context"
"errors"
"fmt"
cache "github.com/patrickmn/go-cache"
"sync"
"time"
"web_framework/web/session"
)
var (
	// errKeyNotFound is returned by Session.Get when the key is absent.
	errKeyNotFound = errors.New("session: 找不到key")
	// errSessionNotFound is returned by Store.Get when the id is unknown
	// (or its cache entry has expired).
	errSessionNotFound = errors.New("session: 找不到session")
)
// Store is an in-memory session.Store backed by go-cache. Sessions expire
// after expiration unless refreshed.
//
// All methods use pointer receivers. Store contains a sync.RWMutex, and with
// value receivers every call locked its own copy, so calls never excluded
// each other (go vet's copylocks check reports this).
type Store struct {
	// mutex makes Refresh's read-then-write atomic with respect to Generate,
	// Remove and Get.
	mutex      sync.RWMutex
	sessions   *cache.Cache
	expiration time.Duration
}

// NewStore returns a Store whose sessions live for expiration.
func NewStore(expiration time.Duration) *Store {
	return &Store{
		// Default expiration and cleanup interval of the cache.
		sessions:   cache.New(expiration, time.Second),
		expiration: expiration,
	}
}

// Generate creates and stores an empty session under id.
func (s *Store) Generate(ctx context.Context, id string) (session.Session, error) {
	// Lock so Generate cannot interleave with a concurrent Refresh.
	s.mutex.Lock()
	defer s.mutex.Unlock()
	sess := &Session{
		id: id,
	}
	s.sessions.Set(id, sess, s.expiration)
	return sess, nil
}

// Refresh resets the expiration of an existing session; it fails if the id
// is unknown or already expired.
func (s *Store) Refresh(ctx context.Context, id string) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	val, ok := s.sessions.Get(id)
	if !ok {
		return fmt.Errorf("session: 该 id 对应的 session 不存在 %s", id)
	}
	s.sessions.Set(id, val, s.expiration)
	return nil
}

// Remove deletes the session; removing an unknown id is not an error.
func (s *Store) Remove(ctx context.Context, id string) error {
	s.mutex.Lock()
	defer s.mutex.Unlock()
	s.sessions.Delete(id)
	return nil
}

// Get returns the session for id, or errSessionNotFound.
func (s *Store) Get(ctx context.Context, id string) (session.Session, error) {
	s.mutex.RLock()
	defer s.mutex.RUnlock()
	sess, ok := s.sessions.Get(id)
	if !ok {
		return nil, errSessionNotFound
	}
	return sess.(*Session), nil
}
// Session is an in-memory session.Session whose values live in a sync.Map.
//
// Methods use pointer receivers because a sync.Map must not be copied. With
// value receivers, Set stored into a throwaway copy and a later Get never saw
// the value.
type Session struct {
	id         string
	expiration time.Duration
	// A mutex plus a plain map would give finer control:
	//mutex sync.RWMutex
	//values map[string]any
	values sync.Map
}

// Get returns the value stored under key, or errKeyNotFound.
func (s *Session) Get(ctx context.Context, key string) (any, error) {
	val, ok := s.values.Load(key)
	if !ok {
		//return nil, fmt.Errorf("%w, key %s", errKeyNotFound,key)
		return nil, errKeyNotFound
	}
	return val, nil
}

// Set stores val under key. ctx is unused here; it is part of the interface
// for implementations that talk to external systems.
func (s *Session) Set(ctx context.Context, key string, val string) error {
	s.values.Store(key, val)
	return nil
}

// ID returns the session id. It is only read after creation, so no locking
// is needed.
func (s *Session) ID() string {
	return s.id
}
13. Session:基于 Redis 的实现
// session_e2e_test.go
//go:build e2e
package redis
import (
"context"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
)
// TestStore_Generate creates a session in Redis, writes one key, reads it back
// and removes the session at the end.
func TestStore_Generate(t *testing.T) {
	store := newStore()
	ctx := context.Background()
	const id = "sess_test_id"

	sess, err := store.Generate(ctx, id)
	require.NoError(t, err)
	defer store.Remove(ctx, id)

	require.NoError(t, sess.Set(ctx, "key1", "123"))
	got, err := sess.Get(ctx, "key1")
	require.NoError(t, err)
	assert.Equal(t, "123", got)
}
// newStore returns a Store backed by a Redis client on localhost:6379.
func newStore() *Store {
	opts := &redis.Options{Addr: "localhost:6379"}
	// opts.Password = "redis"
	return NewStore(redis.NewClient(opts))
}
// redis/session.go
package redis
import (
"context"
"errors"
"fmt"
"github.com/redis/go-redis/v9"
"time"
"web_framework/web/session"
)
var (
	// errSessionNotFound is returned when the session's Redis key does not exist.
	errSessionNotFound = errors.New("session: id 对应的 session 不存在")
)
// StoreOption configures a Store.
type StoreOption func(store *Store)
// Store is a Redis-backed session.Store. Each session is one hash (HSET):
//
// <prefix>-<id>: key: value
//
// i.e. conceptually map[string]map[string]string.
type Store struct {
	prefix string // prepended to every session key
	client redis.Cmdable
	expiration time.Duration // TTL applied on Generate and Refresh
}
// NewStore returns a Store using client, with a 15-minute expiration and the
// key prefix "sessid". Options are applied afterwards in order.
func NewStore(client redis.Cmdable, opts ...StoreOption) *Store {
	s := &Store{
		client:     client,
		prefix:     "sessid",
		expiration: 15 * time.Minute,
	}
	for _, apply := range opts {
		apply(s)
	}
	return s
}
// StoreWithPrefix lets users choose their own key prefix (default "sessid").
func StoreWithPrefix(prefix string) StoreOption {
	return func(s *Store) { s.prefix = prefix }
}
// Generate creates the session hash for id and sets its TTL.
//
// A placeholder field (id -> id) is written because Redis does not keep
// empty hashes. HSET and EXPIRE run in one MULTI/EXEC transaction, so a
// failure between them can no longer leave a session that never expires.
func (s Store) Generate(ctx context.Context, id string) (session.Session, error) {
	key := redisKey(s.prefix, id)
	_, err := s.client.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
		pipe.HSet(ctx, key, id, id)
		pipe.Expire(ctx, key, s.expiration)
		return nil
	})
	if err != nil {
		return nil, err
	}
	return &Session{
		id:     id,
		key:    key,
		client: s.client,
	}, nil
}
// Refresh resets the TTL of the session hash. It returns errSessionNotFound
// when Redis reports that the key does not exist.
func (s Store) Refresh(ctx context.Context, id string) error {
	refreshed, err := s.client.Expire(ctx, redisKey(s.prefix, id), s.expiration).Result()
	switch {
	case err != nil:
		return err
	case !refreshed:
		return errSessionNotFound
	default:
		return nil
	}
}
// Remove deletes the session hash. The deleted-key count is ignored, so
// removing a session that does not exist is not an error.
func (s Store) Remove(ctx context.Context, id string) error {
	return s.client.Del(ctx, redisKey(s.prefix, id)).Err()
}
// Get returns a handle to the session for id after checking that its hash
// exists. It returns errSessionNotFound otherwise.
//
// Whether to preload session data is a free choice: load none, load only hot
// fields, or load everything. This implementation loads none; fields are
// fetched on demand by Session.Get.
func (s Store) Get(ctx context.Context, id string) (session.Session, error) {
	key := redisKey(s.prefix, id)
	n, err := s.client.Exists(ctx, key).Result()
	switch {
	case err != nil:
		return nil, err
	case n != 1:
		return nil, errSessionNotFound
	}
	return &Session{id: id, key: key, client: s.client}, nil
}
// Session is a handle to one session hash in Redis; fields are read and
// written on demand.
type Session struct {
	// NOTE(review): prefix is never set by Store and appears unused — confirm before removing.
	prefix string
	key string // full Redis key of the hash (<prefix>-<id>)
	id string
	client redis.Cmdable
}
// Get reads field key from the session hash and returns it as a string.
//
// On any error it returns a nil value rather than an empty string. This
// includes redis.Nil, which means the field or the whole session is missing.
func (s Session) Get(ctx context.Context, key string) (any, error) {
	val, err := s.client.HGet(ctx, s.key, key).Result()
	if err != nil {
		return nil, err
	}
	return val, nil
}
// Set stores val under attribute key in the session hash.
//
// The existence check and the HSET run in one Lua script, so they are
// atomic. An expired or removed session is therefore never silently
// re-created; a bare HSET on a missing key would create a new hash without
// a TTL.
//
// redis.call("exists", ...) returns an integer (0 or 1). In Lua every number,
// including 0, is truthy, so the result must be compared with 1 explicitly.
// A bare `if redis.call("exists", KEYS[1])` would always take the HSET branch.
//
// Returns errSessionNotFound when the session key does not exist.
func (s Session) Set(ctx context.Context, key string, val string) error {
	// KEYS[1] is the session's Redis key; ARGV[1]/ARGV[2] are field/value.
	const lua = `
if redis.call("exists", KEYS[1]) == 1
then
	return redis.call("hset", KEYS[1], ARGV[1], ARGV[2])
else
	return -1
end
`
	res, err := s.client.Eval(ctx, lua, []string{s.key}, key, val).Int()
	if err != nil {
		return err
	}
	if res < 0 {
		return errSessionNotFound
	}
	return nil
}
// ID returns the session's identifier.
func (s Session) ID() string { return s.id }
// redisKey builds the Redis key of a session as "<prefix>-<id>".
func redisKey(prefix, id string) string {
	key := fmt.Sprintf("%s-%s", prefix, id)
	return key
}
14. Session:基于 Cookie 的实现
// cookie/propagator
package cookie
import "net/http"
// Option customizes a Propagator.
type Option func(p *Propagator)
// Propagator carries the session ID between server and client in an HTTP cookie.
type Propagator struct {
// cookieName is the name of the cookie that holds the session ID.
cookieName string
// cookieOption lets callers adjust a cookie's attributes before Inject
// writes it.
cookieOption func(c *http.Cookie)
}
func NewPropagator() *Propagator {
return &Propagator{
cookieName: "sessid",
cookieOption: func(c *http.Cookie) {
},
}
}
// WithCookieName sets the name of the cookie that carries the session ID.
func WithCookieName(name string) Option {
	return func(p *Propagator) {
		p.cookieName = name
	}
}

// WithCookieOption installs a hook that adjusts every cookie the Propagator
// writes. Use it to set the recommended attributes Domain, Secure, SameSite
// and HttpOnly, and Path if needed. A nil hook is ignored.
func WithCookieOption(opt func(c *http.Cookie)) Option {
	return func(p *Propagator) {
		if opt != nil {
			p.cookieOption = opt
		}
	}
}
// Inject writes session id to the response as a cookie named p.cookieName.
// The cookie passes through cookieOption first, so callers can set its
// attributes.
func (p *Propagator) Inject(id string, writer http.ResponseWriter) error {
	cookie := &http.Cookie{Value: id, Name: p.cookieName}
	p.cookieOption(cookie)
	http.SetCookie(writer, cookie)
	return nil
}
// Extract reads the session ID from the request cookie named p.cookieName.
// When the cookie is absent, the error from (*http.Request).Cookie is
// returned.
func (p *Propagator) Extract(req *http.Request) (string, error) {
	c, err := req.Cookie(p.cookieName)
	if err == nil {
		return c.Value, nil
	}
	return "", err
}
// Remove tells the client to delete the session cookie by sending a cookie
// with the same name and MaxAge -1 (serialized as "Max-Age=0").
//
// Browsers delete only a cookie whose name, domain and path all match, so
// the same cookieOption hook used by Inject is applied first. Without it, a
// cookie injected with a custom Path or Domain would never be removed.
// Value and MaxAge are set after the hook so it cannot override them.
func (p *Propagator) Remove(writer http.ResponseWriter) error {
	c := &http.Cookie{Name: p.cookieName}
	if p.cookieOption != nil {
		p.cookieOption(c)
	}
	c.Value = ""
	c.MaxAge = -1 // mark as expired
	http.SetCookie(writer, c)
	return nil
}
15. Session:测试与面试要点总结
// test/types_test.go
package test
import (
"net/http"
"testing"
"time"
"web_framework/web"
"web_framework/web/session"
"web_framework/web/session/cookie"
"web_framework/web/session/memory"
)
// TestSession is a manual end-to-end check of session.Manager.
//
// It starts an HTTP server on :8081 and blocks in server.Start. A
// middleware requires a valid session on every path except /login, and
// refreshes the session on each authenticated request.
func TestSession(t *testing.T) {
	// Simple login check: in-memory store, session ID carried in a cookie.
	m := &session.Manager{
		Propagator: cookie.NewPropagator(),
		Store:      memory.NewStore(time.Minute * 15),
		CtxSessKey: "sessKey",
	}
	server := web.NewHTTPServer(web.ServerWithMiddleware(func(next web.HandleFunc) web.HandleFunc {
		return func(ctx *web.Context) {
			if ctx.Req.URL.Path == "/login" {
				// Let the request through so the user can log in.
				next(ctx)
				return
			}
			if _, err := m.GetSession(ctx); err != nil {
				ctx.RespStatusCode = http.StatusUnauthorized
				ctx.RespData = []byte("请重新登录")
				return
			}
			// Extend the session's expiration; a refresh failure is
			// deliberately non-fatal for this request.
			_ = m.RefreshSession(ctx)
			next(ctx)
		}
	}))
	server.Post("/login", func(ctx *web.Context) {
		// Username/password validation would happen before creating the session.
		sess, err := m.InitSession(ctx)
		if err != nil {
			ctx.RespStatusCode = http.StatusInternalServerError
			ctx.RespData = []byte("登录失败")
			return
		}
		if err = sess.Set(ctx.Req.Context(), "nickname", "xiaoming"); err != nil {
			ctx.RespStatusCode = http.StatusInternalServerError
			ctx.RespData = []byte("登录失败")
			return
		}
		ctx.RespStatusCode = http.StatusOK
		ctx.RespData = []byte("登录成功")
	})
	server.Post("/logout", func(ctx *web.Context) {
		// Clean up the session data.
		if err := m.RemoveSession(ctx); err != nil {
			ctx.RespStatusCode = http.StatusInternalServerError
			ctx.RespData = []byte("退出失败")
			return
		}
		ctx.RespStatusCode = http.StatusOK
		ctx.RespData = []byte("退出成功")
	})
	server.Get("/user", func(ctx *web.Context) {
		sess, err := m.GetSession(ctx)
		if err != nil {
			ctx.RespStatusCode = http.StatusUnauthorized
			ctx.RespData = []byte("请重新登录")
			return
		}
		// Check the error and the dynamic type: asserting val.(string) on a
		// nil value (missing attribute) would panic the handler.
		val, err := sess.Get(ctx.Req.Context(), "nickname")
		nickname, ok := val.(string)
		if err != nil || !ok {
			ctx.RespStatusCode = http.StatusInternalServerError
			ctx.RespData = []byte("获取用户信息失败")
			return
		}
		ctx.RespStatusCode = http.StatusOK
		ctx.RespData = []byte(nickname)
	})
	server.Start(":8081")
}
出bug了
10 第四周:ORM 框架之 SELECT 与元数据
ORM概览
1. ORM 学习路线图
2. ORM 框架概览:Beego ORM 分析
3. ORM 框架概览:GORM 和 Ent 分析
4. ORM 框架总结和面试要点
select
5. SELECT:Beego、GORM、Ent 的 SQL构造分析
6. SELECT:核心接口定义
// orm/types.go
package orm
// context 包是Go 1.7 引入的标准库,
// 主要用于在goroutine 之间传递取消信号、超时时间、截止时间以及一些共享的值等。
import (
"context"
"database/sql"
)
// Querier is used for SELECT.
type Querier[T any] interface {
// Results may be *T or T; either works. Pointers are more convenient for
// reflection and AOP-style processing, but may cause values to escape to
// the heap.
//
// Get returns a single *T.
Get(ctx context.Context) (*T, error)
// GetMulti returns a []*T.
GetMulti(ctx context.Context) ([]*T, error)
}
// Executor is used for INSERT, DELETE and UPDATE statements; Exec returns
// the driver's sql.Result.
type Executor interface {
Exec(ctx context.Context) (sql.Result, error)
}
// QueryBuilder renders a statement into a Query (SQL text plus arguments).
type QueryBuilder interface {
Build() (*Query, error)
}
// Query is a built statement. SQL holds the text with '?' placeholders, and
// Args holds the placeholder values in order.
type Query struct {
SQL string
Args []any
}
7. SELECT:SELECT 语句规范、Selector 定义、FROM 语句实现
// select_test.go
package orm
import (
"database/sql"
"github.com/stretchr/testify/assert"
"testing"
)
// TestSelector_Build checks the FROM clause Build produces.
// Without a table, the backtick-quoted Go type name of T is used; an
// explicit table is written verbatim.
func TestSelector_Build(t *testing.T) {
testCases := []struct {
name string
builder QueryBuilder
wantQuery *Query
wantErr error
}{
{
// No Table call: falls back to the type name.
name: "not from",
builder: &Selector[TestModel]{},
wantQuery: &Query{
SQL: "SELECT * FROM `TestModel`;",
Args: nil,
},
},
{
name: "from",
builder: (&Selector[TestModel]{}).Table("`test_model`"),
wantQuery: &Query{
SQL: "SELECT * FROM `test_model`;",
Args: nil,
},
},
{
// Same input as "not from".
name: "empty from",
builder: (&Selector[TestModel]{}),
wantQuery: &Query{
SQL: "SELECT * FROM `TestModel`;",
Args: nil,
},
},
{
// An empty table name also falls back to the type name.
name: "empty table",
builder: (&Selector[TestModel]{}).Table(""),
wantQuery: &Query{
SQL: "SELECT * FROM `TestModel`;",
Args: nil,
},
},
{
// A database-qualified table is passed through verbatim.
name: "db table",
builder: (&Selector[TestModel]{}).Table("`mybatis`.`test`"),
wantQuery: &Query{
SQL: "SELECT * FROM `mybatis`.`test`;",
Args: nil,
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
q, err := tc.builder.Build()
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
assert.Equal(t, tc.wantQuery, q)
})
}
}
// TestModel is the fixture entity used by the Selector tests.
type TestModel struct {
Id int64
FirstName string
Age int8
LastName *sql.NullString
}
// select.go
package orm
import (
"context"
"reflect"
"strings"
)
// Selector builds SELECT statements for the model type T.
type Selector[T any] struct {
// table overrides the FROM target; when empty, Build derives it from T.
table string
}
// Build renders "SELECT * FROM <table>;".
//
// An explicit table (set via Table) is written verbatim, so callers quote
// it themselves (e.g. "`db`.`tbl`"). Otherwise the Go type name of T,
// obtained by reflection, is wrapped in backticks.
func (s *Selector[T]) Build() (*Query, error) {
	var sb strings.Builder
	sb.WriteString("SELECT * FROM ")
	if s.table != "" {
		sb.WriteString(s.table)
	} else {
		var zero T
		sb.WriteString("`" + reflect.TypeOf(zero).Name() + "`")
	}
	sb.WriteByte(';')
	return &Query{SQL: sb.String()}, nil
}
// func (s *Selector[T]) From(table string) *Selector[T] {
// s.table = table
// return s
// }
// Table sets the FROM target. The name is written verbatim; an empty string
// falls back to the name derived from T. It returns s for chaining.
func (s *Selector[T]) Table(table string) *Selector[T] {
	s.table = table
	return s
}
// Get is not implemented yet; calling it panics.
func (s *Selector[T]) Get(ctx context.Context) (*T, error) {
	panic("implement me")
}
// GetMulti is not implemented yet; calling it panics.
func (s *Selector[T]) GetMulti(ctx context.Context) ([]*T, error) {
	panic("implement me")
}
8. SELECT:WHERE 语句、Expression 抽象和面试要点
// select.go
package orm
import (
"context"
"fmt"
"reflect"
"strings"
)
// Selector builds SELECT statements for the model type T.
type Selector[T any] struct {
// table overrides the FROM target; when empty, Build derives it from T.
table string
// where holds the predicates set by Where; Build joins them with AND.
where []Predicate
// sb is the buffer for the statement being built; Build re-creates it.
sb *strings.Builder
// args collects placeholder values in order of appearance.
args []any
}
// Build renders the SELECT statement and its arguments.
//
// FROM uses s.table verbatim when set; otherwise it uses the backtick-quoted
// Go type name of T. Predicates given to Where are combined with AND.
//
// Both the buffer and the argument list are reset on every call. Build can
// therefore be called repeatedly (e.g. to log, then execute) without
// accumulating duplicate arguments.
func (s *Selector[T]) Build() (*Query, error) {
	s.sb = &strings.Builder{} // the pointer field must be initialized
	s.args = nil
	sb := s.sb
	sb.WriteString("SELECT * FROM ")
	if s.table == "" {
		// Derive the table name from T via reflection.
		var t T
		typ := reflect.TypeOf(t)
		sb.WriteByte('`')
		sb.WriteString(typ.Name())
		sb.WriteByte('`')
	} else {
		sb.WriteString(s.table)
	}
	if len(s.where) > 0 {
		sb.WriteString(" WHERE ")
		p := s.where[0]
		for _, next := range s.where[1:] {
			p = p.And(next)
		}
		if err := s.buildExpression(p); err != nil {
			return nil, err
		}
	}
	sb.WriteByte(';')
	return &Query{
		SQL:  sb.String(),
		Args: s.args,
	}, nil
}
// buildExpression appends expr to s.sb. Each value is written as '?' and
// its value is appended to s.args; columns are written backtick-quoted.
//
// Operands that are themselves predicates are parenthesized. A predicate
// without a left operand (the unary NOT built by Not) renders as
// "NOT (...)" with no leading space, so "WHERE " is followed by exactly one
// space.
func (s *Selector[T]) buildExpression(expr Expression) error {
	switch exp := expr.(type) {
	case nil:
		// Nothing to render.
	case Predicate:
		if exp.left != nil {
			_, nested := exp.left.(Predicate)
			if nested {
				s.sb.WriteByte('(')
			}
			if err := s.buildExpression(exp.left); err != nil {
				return err
			}
			if nested {
				s.sb.WriteByte(')')
			}
			s.sb.WriteByte(' ')
		}
		s.sb.WriteString(exp.op.String())
		s.sb.WriteByte(' ')
		_, nested := exp.right.(Predicate)
		if nested {
			s.sb.WriteByte('(')
		}
		if err := s.buildExpression(exp.right); err != nil {
			return err
		}
		if nested {
			s.sb.WriteByte(')')
		}
	case Column:
		s.sb.WriteByte('`')
		s.sb.WriteString(exp.name)
		s.sb.WriteByte('`')
	case value:
		s.sb.WriteByte('?')
		s.addArg(exp.val)
	default:
		return fmt.Errorf("orm: 不支持的表达式类型 %v", expr)
	}
	return nil
}
// addArg appends a placeholder value. On first use the slice is allocated
// with room for 8 arguments. It returns s for chaining.
func (s *Selector[T]) addArg(val any) *Selector[T] {
	args := s.args
	if args == nil {
		args = make([]any, 0, 8)
	}
	s.args = append(args, val)
	return s
}
// Where replaces the WHERE predicates. Build ANDs them together; calling
// Where with no arguments clears them. It returns s for chaining.
func (s *Selector[T]) Where(ps ...Predicate) *Selector[T] {
	s.where = ps
	return s
}
// predicate.go
package orm
// op is a defined type (not an alias) for SQL operators, so it can carry
// methods such as String.
type op string
// The alias form, which could not have methods, would be:
//type op =string
const (
opEq op = "="
opNot op = "NOT"
opAnd op = "AND"
opOr op = "OR"
)
// String returns the operator's SQL text.
func (o op) String() string { return string(o) }
// Predicate is a condition "left op right". Not leaves left nil, making it
// a unary predicate.
type Predicate struct {
left Expression
op op
right Expression
}
// Eq("id",12)
// Eq(sub,"id",12)
// Eq(sub.id,12)
// Eq("sub.id",12)
//func Eq(column string, arg any) Predicate {
// return Predicate{
// Column: column,
// Op: "=",
// Arg: arg,
// }
//}
// Column is a column reference used in predicates; name is the string
// passed to C.
type Column struct {
name string
}
// C creates a Column reference, e.g. C("Age").
func C(name string) Column {
	col := Column{name: name}
	return col
}
// Eq builds the predicate "column = arg". arg becomes a placeholder value.
//
// Usage: C("id").Eq(12), or on a sub-query column: sub.C("id").Eq(12).
func (c Column) Eq(arg any) Predicate {
	rhs := value{val: arg}
	return Predicate{left: c, op: opEq, right: rhs}
}
// expr marks Column as an Expression.
func (Column) expr() {}
// Not negates p, e.g. Not(C("name").Eq("Tom")). The result has no left
// operand.
func Not(p Predicate) Predicate {
	var neg Predicate
	neg.op = opNot
	neg.right = p
	return neg
}
// And combines two predicates with AND,
// e.g. C("id").Eq(12).And(C("name").Eq("Tom")).
func (left Predicate) And(right Predicate) Predicate {
	combined := Predicate{op: opAnd}
	combined.left, combined.right = left, right
	return combined
}
// Or combines two predicates with OR,
// e.g. C("id").Eq(12).Or(C("name").Eq("Tom")).
func (left Predicate) Or(right Predicate) Predicate {
	return Predicate{
		left:  left,
		op:    opOr,
		right: right,
	}
}
// expr marks Predicate as an Expression, so predicates can nest.
func (Predicate) expr() {
}
// Expression is a marker interface for anything that can appear in a
// predicate (Column, value, Predicate). The unexported expr method keeps
// implementations inside this package.
type Expression interface {
expr()
}
// value is a literal operand; builders render it as a '?' placeholder and
// record val as an argument.
type value struct {
val any
}
// expr marks value as an Expression.
func (value) expr() {}
// select_test.go
package orm
import (
"database/sql"
"github.com/stretchr/testify/assert"
"testing"
)
// TestSelector_Build checks the WHERE clauses and arguments Build produces.
// The FROM cases shown earlier are elided from this snippet.
func TestSelector_Build(t *testing.T) {
	testCases := []struct {
		name      string
		builder   QueryBuilder
		wantQuery *Query
		wantErr   error
	}{
		// ...
		{
			name:    "where",
			builder: (&Selector[TestModel]{}).Where(C("Age").Eq(18)),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `TestModel` WHERE `Age` = ?;",
				Args: []any{18},
			},
		},
		{
			name:    "empty where",
			builder: (&Selector[TestModel]{}).Where(),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `TestModel`;",
				Args: nil,
			},
		},
		{
			name:    "not",
			builder: (&Selector[TestModel]{}).Where(Not(C("Age").Eq(18))),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `TestModel` WHERE NOT (`Age` = ?);",
				Args: []any{18},
			},
		},
		{
			name:    "and",
			builder: (&Selector[TestModel]{}).Where(C("Age").Eq(18).And(C("FirstName").Eq("Tom"))),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `TestModel` WHERE (`Age` = ?) AND (`FirstName` = ?);",
				Args: []any{18, "Tom"},
			},
		},
		{
			// Distinct name from "and" so a failure points at the right case.
			name:    "or",
			builder: (&Selector[TestModel]{}).Where(C("Age").Eq(18).Or(C("FirstName").Eq("Tom"))),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `TestModel` WHERE (`Age` = ?) OR (`FirstName` = ?);",
				Args: []any{18, "Tom"},
			},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			q, err := tc.builder.Build()
			assert.Equal(t, tc.wantErr, err)
			if err != nil {
				return
			}
			assert.Equal(t, tc.wantQuery, q)
		})
	}
}
// TestModel is the fixture entity used by the Selector tests.
type TestModel struct {
Id int64
FirstName string
Age int8
LastName *sql.NullString
}
元数据
9. 元数据简介
10. 元数据:反射-读字段
构建中间件必备: 反射 unsafe技术 抽象语法树ast 模板(代码生成等)
// field_test.go
package reflect
import (
"errors"
"github.com/stretchr/testify/assert"
"testing"
)
// TestIterateFields checks IterateFields on structs, pointers (single and
// multi-level), and the rejected inputs nil, non-struct and nil pointer.
func TestIterateFields(t *testing.T) {
	type User struct {
		Name string
		age  int
	}
	testCases := []struct {
		name    string
		entity  any
		wantErr error
		wantRes map[string]any
	}{
		{
			name: "struct",
			entity: User{
				Name: "Tom",
				age:  18,
			},
			wantRes: map[string]any{
				"Name": "Tom",
				// Unexported fields can't be read via reflection; zero value is used.
				"age": 0,
			},
		},
		{
			name: "pointer",
			entity: &User{
				Name: "Tom",
				age:  18,
			},
			wantRes: map[string]any{
				"Name": "Tom",
				// Unexported fields can't be read via reflection; zero value is used.
				"age": 0,
			},
		},
		{
			name:    "non-struct",
			entity:  18,
			wantErr: errors.New("不支持类型"),
		},
		{
			name: "multiple pointer",
			entity: func() **User {
				res := &User{
					Name: "Tom",
					age:  18,
				}
				return &res
			}(),
			wantRes: map[string]any{
				"Name": "Tom",
				// Unexported fields can't be read via reflection; zero value is used.
				"age": 0,
			},
		},
		{
			name:    "nil",
			entity:  nil,
			wantErr: errors.New("不支持 nil"),
		},
		{
			// A typed nil pointer passes the nil-interface check but has no
			// struct behind it.
			name:    "nil pointer",
			entity:  (*User)(nil),
			wantErr: errors.New("不支持零值"),
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			res, err := IterateFields(tc.entity)
			assert.Equal(t, tc.wantErr, err)
			if err != nil {
				return
			}
			assert.Equal(t, tc.wantRes, res)
		})
	}
}
// fields.go
package reflect
import (
"errors"
"reflect"
)
// IterateFields returns a map from each field name of a struct to its value.
//
// entity may be a struct or a pointer, of any depth, to one. Exported fields
// map to their current value. Unexported fields cannot be read through
// reflect.Value.Interface, so they map to their type's zero value. A
// zero-valued struct such as User{} is a valid input.
//
// Errors:
//   - "不支持 nil" (nil not supported) when entity is nil;
//   - "不支持零值" (zero value not supported) when any pointer in the chain is
//     nil, since there is no struct to read;
//   - "不支持类型" (type not supported) when the dereferenced value is not a
//     struct.
func IterateFields(entity any) (map[string]any, error) { // Author's convention: public functions return an error.
	if entity == nil {
		return nil, errors.New("不支持 nil")
	}
	typ := reflect.TypeOf(entity)
	val := reflect.ValueOf(entity)
	// Loop to handle multi-level pointers. Checking IsNil at each level,
	// rather than IsZero on the outer value only, rejects a nil pointer
	// anywhere in the chain (e.g. a **User whose inner *User is nil, which
	// would otherwise panic in val.Field). It also accepts zero-valued
	// structs.
	for typ.Kind() == reflect.Pointer {
		if val.IsNil() {
			return nil, errors.New("不支持零值")
		}
		typ = typ.Elem()
		val = val.Elem()
	}
	if typ.Kind() != reflect.Struct {
		return nil, errors.New("不支持类型")
	}
	numFields := typ.NumField()
	res := make(map[string]any, numFields)
	for i := 0; i < numFields; i++ {
		fieldType := typ.Field(i)
		if !fieldType.IsExported() {
			res[fieldType.Name] = reflect.Zero(fieldType.Type).Interface()
			continue
		}
		// Interface unwraps the reflect.Value into the actual field value.
		res[fieldType.Name] = val.Field(i).Interface()
	}
	return res, nil
}
11. 元数据:反射-写字段
// fields_test.go
// TestSetField checks which fields SetField may modify: only exported
// fields reached through a pointer are settable.
func TestSetField(t *testing.T) {
	testCases := []struct {
		name     string
		entity   any
		field    string
		newValue any
		wantErr  error
		// wantEntity is the entity expected after the call.
		wantEntity any
	}{
		{
			// A struct passed by value is a copy and is not addressable.
			name:       "struct",
			entity:     User{Name: "Tom"},
			field:      "Name",
			newValue:   "Jerry",
			wantErr:    errors.New("不可修改字段"),
			wantEntity: User{Name: "Tom"},
		},
		{
			name:       "pointer",
			entity:     &User{Name: "Tom"},
			field:      "Name",
			newValue:   "Jerry",
			wantEntity: &User{Name: "Jerry"},
		},
		{
			// Unexported fields are never settable via reflection.
			name:       "pointer unexported",
			entity:     &User{age: 18},
			field:      "age",
			newValue:   16,
			wantErr:    errors.New("不可修改字段"),
			wantEntity: &User{age: 18},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			err := SetField(tc.entity, tc.field, tc.newValue)
			assert.Equal(t, tc.wantErr, err)
			if err != nil {
				return
			}
			assert.Equal(t, tc.wantEntity, tc.entity)
		})
	}
}
// TestSet shows that reflection can only set addressable values.
// reflect.ValueOf(num).Set(...) would panic with "unaddressable value",
// because the Value holds a copy. The Elem of a pointer's Value is
// addressable and therefore settable.
func TestSet(t *testing.T) {
	num := 0
	reflect.ValueOf(&num).Elem().Set(reflect.ValueOf(12))
	assert.Equal(t, 12, num)
}
// fields.go
// SetField sets the named field of the struct behind entity to newValue.
//
// entity must be a pointer, of any depth, to a struct. A struct passed by
// value is a copy, so its fields are not settable. Unexported and unknown
// fields are not settable either.
//
// Errors (instead of the reflect panics these inputs used to cause):
//   - "不支持 nil" when entity is nil;
//   - "不支持零值" when a pointer in the chain is nil;
//   - "不支持类型" when the dereferenced value is not a struct;
//   - "不可修改字段" when the field is missing, unexported, or not addressable;
//   - "字段类型不匹配" when newValue is nil or not assignable to the field's type.
func SetField(entity any, field string, newValue any) error {
	if entity == nil {
		return errors.New("不支持 nil")
	}
	val := reflect.ValueOf(entity)
	// Walk down to the struct behind any number of pointers.
	for val.Kind() == reflect.Pointer {
		if val.IsNil() {
			return errors.New("不支持零值")
		}
		val = val.Elem()
	}
	if val.Kind() != reflect.Struct {
		return errors.New("不支持类型")
	}
	fieldVal := val.FieldByName(field)
	// CanSet is false for missing fields (invalid Value), unexported fields,
	// and fields of a non-addressable (by-value) struct.
	if !fieldVal.CanSet() {
		return errors.New("不可修改字段")
	}
	newVal := reflect.ValueOf(newValue)
	if !newVal.IsValid() || !newVal.Type().AssignableTo(fieldVal.Type()) {
		return errors.New("字段类型不匹配")
	}
	fieldVal.Set(newVal)
	return nil
}
12. 元数据:反射-方法
// func_call_test.go
package reflect
import (
"github.com/stretchr/testify/assert"
"my_framework/orm/reflect/types"
"reflect"
"testing"
)
// TestIterateFunc checks method discovery and invocation on a value and on
// a pointer.
func TestIterateFunc(t *testing.T) {
	testCases := []struct {
		name     string
		entity   any
		wantRes  map[string]FuncInfo
		wantErr  error
	}{
		{
			name:   "struct",
			entity: types.NewUser("Tom", 18),
			wantRes: map[string]FuncInfo{
				// A value's method set holds only value-receiver methods.
				"GetAge": {
					Name:        "GetAge",
					InputTypes:  []reflect.Type{reflect.TypeOf(types.User{})},
					OutputTypes: []reflect.Type{reflect.TypeOf(0)},
					Result:      []any{18},
				},
				// ChangeName has a pointer receiver, so it is not in
				// User's method set and is not listed here.
			},
		},
		{
			name:   "pointer",
			entity: types.NewUserPtr("Tom", 18),
			wantRes: map[string]FuncInfo{
				"GetAge": {
					Name:        "GetAge",
					InputTypes:  []reflect.Type{reflect.TypeOf(&types.User{})},
					OutputTypes: []reflect.Type{reflect.TypeOf(0)},
					Result:      []any{18},
				},
				// *User's method set includes the pointer-receiver ChangeName.
				"ChangeName": {
					Name:        "ChangeName",
					InputTypes:  []reflect.Type{reflect.TypeOf(&types.User{}), reflect.TypeOf("")},
					OutputTypes: []reflect.Type{},
					Result:      []any{},
				},
			},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			res, err := IterateFunc(tc.entity)
			assert.Equal(t, tc.wantErr, err)
			if err != nil {
				return
			}
			assert.Equal(t, tc.wantRes, res)
		})
	}
}
// func_call.go
package reflect
import "reflect"
// IterateFunc lists the exported methods in entity's method set and calls
// each one. entity is passed as the receiver and every other parameter
// receives its type's zero value. The signature and results of each call
// are recorded, keyed by method name.
func IterateFunc(entity any) (map[string]FuncInfo, error) {
	typ := reflect.TypeOf(entity)
	recv := reflect.ValueOf(entity)
	infos := make(map[string]FuncInfo, typ.NumMethod())
	for m := 0; m < typ.NumMethod(); m++ {
		method := typ.Method(m)
		fnType := method.Func.Type()

		// Parameter 0 of a method expression is the receiver itself.
		inTypes := make([]reflect.Type, 0, fnType.NumIn())
		args := make([]reflect.Value, 0, fnType.NumIn())
		inTypes = append(inTypes, typ)
		args = append(args, recv)
		for in := 1; in < fnType.NumIn(); in++ {
			paramType := fnType.In(in)
			inTypes = append(inTypes, paramType)
			args = append(args, reflect.Zero(paramType))
		}

		outTypes := make([]reflect.Type, 0, fnType.NumOut())
		for out := 0; out < fnType.NumOut(); out++ {
			outTypes = append(outTypes, fnType.Out(out))
		}

		returned := method.Func.Call(args)
		results := make([]any, 0, len(returned))
		for _, rv := range returned {
			results = append(results, rv.Interface())
		}

		infos[method.Name] = FuncInfo{
			Name:        method.Name,
			InputTypes:  inTypes,
			OutputTypes: outTypes,
			Result:      results,
		}
	}
	return infos, nil
}
// FuncInfo describes one method found and invoked by IterateFunc.
type FuncInfo struct {
// Name is the method name.
Name string
// InputTypes are the parameter types, starting with the receiver.
InputTypes []reflect.Type
// OutputTypes are the result types.
OutputTypes []reflect.Type
// Result holds the values the method returned when IterateFunc called it.
Result []any
}
// orm/types/user.go
package types
import "fmt"
// User is the fixture used by the reflection method tests.
type User struct {
Name string
// age is unexported; it is readable only through GetAge.
age int
}
// NewUser returns a User value with the given name and age.
func NewUser(name string, age int) User {
	var u User
	u.Name, u.age = name, age
	return u
}
// NewUserPtr returns a pointer to a new User with the given name and age.
func NewUserPtr(name string, age int) *User {
	u := new(User)
	u.Name = name
	u.age = age
	return u
}
// GetAge returns the user's age (value receiver, so it is in the method set
// of both User and *User).
func (u User) GetAge() int { return u.age }
// ChangeName replaces the user's name (pointer receiver, so it is only in
// the method set of *User).
func (u *User) ChangeName(newName string) { u.Name = newName }
// private is an unexported method. reflect's Type.NumMethod/Method expose
// only exported methods of non-interface types, so IterateFunc never lists
// it. NOTE(review): it is not called anywhere visible (staticcheck would
// report it as unused); presumably kept to demonstrate that point.
func (u User) private() {
fmt.Println("private")
}
13. 元数据:反射-遍历
// iterate_test.go
package reflect
import (
"github.com/stretchr/testify/assert"
"testing"
)
// TestIterateArray checks that IterateArrayOrSlice returns the elements of
// an array and of a slice, in order.
func TestIterateArray(t *testing.T) {
testCases := []struct {
name string
entity any
wantVals []any
wantErr error
}{
{
name: "arr",
entity: [3]int{1, 2, 3},
wantVals: []any{1, 2, 3},
},
{
name: "slice",
entity: []int{1, 2, 3},
wantVals: []any{1, 2, 3},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
vals, err := IterateArrayOrSlice(tc.entity)
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
assert.Equal(t, tc.wantVals, vals)
})
}
}
// TestIterateMap checks that IterateMap returns every key and every value.
// Go map iteration order is unspecified, so keys and values are compared as
// unordered collections; an ordered comparison would be flaky.
func TestIterateMap(t *testing.T) {
	tests := []struct {
		name       string
		entity     any
		wantKeys   []any
		wantValues []any
		wantErr    error
	}{
		{
			name: "map",
			entity: map[string]string{
				"A": "a",
				"B": "b",
			},
			wantKeys:   []any{"A", "B"},
			wantValues: []any{"a", "b"},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			keys, values, err := IterateMap(tt.entity)
			assert.Equal(t, tt.wantErr, err)
			if err != nil {
				return
			}
			assert.ElementsMatch(t, tt.wantKeys, keys)
			assert.ElementsMatch(t, tt.wantValues, values)
		})
	}
}
// iterate.go
package reflect
import "reflect"
// unsupportedKindError reports an entity whose reflect.Kind an iterator in
// this file cannot handle. It is a local type because iterate.go imports
// only "reflect".
type unsupportedKindError struct {
	kind reflect.Kind
}

func (e unsupportedKindError) Error() string {
	return "不支持类型: " + e.kind.String()
}

// IterateArrayOrSlice returns the elements of an array or slice, in order.
// Any other kind, including nil and string (which reflect would otherwise
// index byte by byte), yields an unsupportedKindError instead of a reflect
// panic.
func IterateArrayOrSlice(entity any) ([]any, error) {
	val := reflect.ValueOf(entity)
	if k := val.Kind(); k != reflect.Array && k != reflect.Slice {
		return nil, unsupportedKindError{kind: k}
	}
	res := make([]any, 0, val.Len())
	for i := 0; i < val.Len(); i++ {
		res = append(res, val.Index(i).Interface())
	}
	return res, nil
}

// IterateMap returns a map's keys and values as two slices. keys[i] pairs
// with values[i]; the overall order is unspecified. Non-map input, including
// nil, yields an unsupportedKindError instead of a reflect panic.
func IterateMap(entity any) ([]any, []any, error) {
	val := reflect.ValueOf(entity)
	if val.Kind() != reflect.Map {
		return nil, nil, unsupportedKindError{kind: val.Kind()}
	}
	resKeys := make([]any, 0, val.Len())
	resValues := make([]any, 0, val.Len())
	// MapRange iterates key/value pairs without a separate MapIndex lookup
	// per key; val.MapKeys + val.MapIndex would also work.
	itr := val.MapRange()
	for itr.Next() {
		resKeys = append(resKeys, itr.Key().Interface())
		resValues = append(resValues, itr.Value().Interface())
	}
	return resKeys, resValues, nil
}
14. 元数据:反射的开源实例、面试要点总结
15. 元数据:反射解析模型
// orm/internal/errs/error.go
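// Packages under internal/ can only be imported by code rooted at internal's parent directory (orm/...); other modules cannot import them.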
package errs
import (
"errors"
"fmt"
)
var (
// ErrPointerOnly is returned when a model is not a pointer to a struct
// (only one level of indirection is allowed).
ErrPointerOnly = errors.New("orm: 只支持指向结构体的一级指针")
)
// NewErrUnsupportedExpression reports an expression type a builder cannot
// render; expr is formatted with %v.
func NewErrUnsupportedExpression(expr any) error {
	err := fmt.Errorf("orm: 不支持的表达式类型 %v", expr)
	return err
}
// model_test.go
package orm
import (
"github.com/stretchr/testify/assert"
"my_framework/orm/internal/errs"
"testing"
)
// Test_parseModel checks that only a pointer to a struct is accepted and
// that table and column names are converted to snake_case.
func Test_parseModel(t *testing.T) {
tests := []struct {
name string
entity any
wantModel *model
wantErr error
}{
{
// A struct value is rejected. The metadata it would otherwise produce
// is kept below, commented out, for reference.
name: "struct",
entity: TestModel{},
//wantModel: &model{
// tableName: "test_model",
// fields: map[string]*field{
// "Id": {
// colName: "id",
// },
// "FirstName": {
// colName: "first_name",
// },
// "LastName": {
// colName: "last_name",
// },
// "Age": {
// colName: "age",
// },
// },
//},
wantErr: errs.ErrPointerOnly,
},
{
name: "pointer",
entity: &TestModel{},
wantModel: &model{
tableName: "test_model",
fields: map[string]*field{
"Id": {
colName: "id",
},
"FirstName": {
colName: "first_name",
},
"LastName": {
colName: "last_name",
},
"Age": {
colName: "age",
},
},
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
m, err := parseModel(tc.entity)
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
assert.Equal(t, tc.wantModel, m)
})
}
}
// model.go
package orm
import (
"my_framework/orm/internal/errs"
"reflect"
"unicode"
)
// model is the parsed metadata of an entity type.
type model struct {
// tableName is the snake_case form of the struct's type name.
tableName string
// fields maps each Go field name to its column metadata.
fields map[string]*field
}
// field is the metadata of one struct field.
type field struct {
// colName is the column name (snake_case of the field name by default).
colName string
}
// parseModel builds the metadata for entity, which must be a pointer to a
// struct with exactly one level of indirection. The table name and every
// column name are the snake_case forms of the Go names.
//
// It returns errs.ErrPointerOnly for any other input, including nil.
func parseModel(entity any) (*model, error) {
	typ := reflect.TypeOf(entity)
	// reflect.TypeOf(nil) returns a nil Type; calling Kind on it would panic.
	if typ == nil || typ.Kind() != reflect.Pointer || typ.Elem().Kind() != reflect.Struct {
		return nil, errs.ErrPointerOnly
	}
	// The check above guarantees exactly one pointer level, so one Elem
	// reaches the struct. A loop would be needed for multi-level pointers.
	typ = typ.Elem()
	numField := typ.NumField()
	fieldMap := make(map[string]*field, numField)
	for i := 0; i < numField; i++ {
		fd := typ.Field(i)
		fieldMap[fd.Name] = &field{
			colName: underscoreName(fd.Name),
		}
	}
	return &model{
		tableName: underscoreName(typ.Name()),
		fields:    fieldMap,
	}, nil
}
// underscoreName converts a CamelCase name to snake_case, e.g.
// "FirstName" -> "first_name". Every upper-case letter except the first
// character starts a new "_"-separated word. Consecutive capitals are
// split, so "ID" becomes "i_d".
//
// The result is built from runes, not bytes. Converting a rune with byte(r)
// would truncate non-ASCII letters (e.g. 'Ä') and produce invalid UTF-8.
func underscoreName(tableName string) string {
	buf := make([]rune, 0, len(tableName)+4)
	for i, r := range tableName {
		if !unicode.IsUpper(r) {
			buf = append(buf, r)
			continue
		}
		if i != 0 {
			buf = append(buf, '_')
		}
		buf = append(buf, unicode.ToLower(r))
	}
	return string(buf)
}
16. 元数据:利用元数据改造 Selector、元数据阶段总结
// select.go
package orm
import (
"context"
"my_framework/orm/internal/errs"
"strings"
)
// Selector builds SELECT statements for the model type T.
type Selector[T any] struct {
// table overrides the FROM target; when empty, model.tableName is used.
table string
// model is T's metadata; Build re-parses it on every call.
model *model
// where holds the predicates set by Where.
where []Predicate
// sb is the buffer for the statement being built.
sb *strings.Builder
// args collects placeholder values in order of appearance.
args []any
}
// Build renders the SELECT statement using T's metadata. The table name
// comes from parseModel (snake_case), unless Table overrides it.
func (s *Selector[T]) Build() (*Query, error) {
//var sb strings.Builder
s.sb = &strings.Builder{} // the pointer field must be initialized before use
var err error
s.model, err = parseModel(new(T)) // new(T) yields a *T, as parseModel requires
if err != nil {
return nil, err
}
sb := s.sb
sb.WriteString("SELECT * FROM ")
if s.table == "" {
sb.WriteByte('`')
sb.WriteString(s.model.tableName)
sb.WriteByte('`')
} else {
//segs := strings.Split(s.table, ".")
//sb.WriteByte('`')
sb.WriteString(s.table)
//sb.WriteByte('`')
}
// NOTE: the WHERE rendering is elided in this snippet ("// ..."); as
// written here, predicates set via Where are not rendered.
if len(s.where) > 0 {
// ...
}
sb.WriteByte(';')
return &Query{
SQL: sb.String(),
Args: s.args,
}, nil
}
// buildExpression appends expr to s.sb. Columns are resolved through
// s.model, so the SQL uses column names rather than Go field names; values
// become '?' placeholders.
func (s *Selector[T]) buildExpression(expr Expression) error {
switch exp := expr.(type) {
case nil:
case Predicate:
// NOTE: predicate rendering is elided in this snippet ("// ...").
// ...
case Column:
fd, ok := s.model.fields[exp.name]
// The column is not a field of the model.
if !ok {
return errs.NewErrUnknownField(exp.name)
}
s.sb.WriteByte('`')
s.sb.WriteString(fd.colName)
s.sb.WriteByte('`')
//
case value:
s.sb.WriteByte('?')
s.addArg(exp.val)
//
default:
return errs.NewErrUnsupportedExpression(expr)
}
return nil
}
// ...
// select_test.go
package orm
import (
"database/sql"
"github.com/stretchr/testify/assert"
"my_framework/orm/internal/errs"
"testing"
)
// TestSelector_Build checks FROM and WHERE rendering with model metadata:
// Go field names in predicates are mapped to snake_case column names, and
// an unknown field is reported as an error.
func TestSelector_Build(t *testing.T) {
	testCases := []struct {
		name      string
		builder   QueryBuilder
		wantQuery *Query
		wantErr   error
	}{
		{
			name:    "not from",
			builder: &Selector[TestModel]{},
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model`;",
				Args: nil,
			},
		},
		{
			name:    "from",
			builder: (&Selector[TestModel]{}).Table("`test_model`"),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model`;",
				Args: nil,
			},
		},
		{
			name:    "empty from",
			builder: (&Selector[TestModel]{}),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model`;",
				Args: nil,
			},
		},
		{
			name:    "empty table",
			builder: (&Selector[TestModel]{}).Table(""),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model`;",
				Args: nil,
			},
		},
		{
			name:    "db table",
			builder: (&Selector[TestModel]{}).Table("`mybatis`.`test`"),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `mybatis`.`test`;",
				Args: nil,
			},
		},
		{
			name:    "where",
			builder: (&Selector[TestModel]{}).Where(C("Age").Eq(18)),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model` WHERE `age` = ?;",
				Args: []any{18},
			},
		},
		{
			name:    "empty where",
			builder: (&Selector[TestModel]{}).Where(),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model`;",
				Args: nil,
			},
		},
		{
			name:    "not",
			builder: (&Selector[TestModel]{}).Where(Not(C("Age").Eq(18))),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model` WHERE NOT (`age` = ?);",
				Args: []any{18},
			},
		},
		{
			name:    "and",
			builder: (&Selector[TestModel]{}).Where(C("Age").Eq(18).And(C("FirstName").Eq("Tom"))),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model` WHERE (`age` = ?) AND (`first_name` = ?);",
				Args: []any{18, "Tom"},
			},
		},
		{
			// Distinct name from "and" so a failure points at the right case.
			name:    "or",
			builder: (&Selector[TestModel]{}).Where(C("Age").Eq(18).Or(C("FirstName").Eq("Tom"))),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model` WHERE (`age` = ?) OR (`first_name` = ?);",
				Args: []any{18, "Tom"},
			},
		},
		{
			name:    "invalid column",
			builder: (&Selector[TestModel]{}).Where(C("Age").Eq(18).Or(C("XXXX").Eq("Tom"))),
			wantErr: errs.NewErrUnknownField("XXXX"),
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			q, err := tc.builder.Build()
			assert.Equal(t, tc.wantErr, err)
			if err != nil {
				return
			}
			assert.Equal(t, tc.wantQuery, q)
		})
	}
}
// TestModel is the fixture entity used by the Selector tests; its metadata
// maps to table test_model.
type TestModel struct {
Id int64
FirstName string
Age int8
LastName *sql.NullString
}
作业
18.第四周作业:DELETE 语句[选学]
19.第四周 DELETE 作业讲解[选学]
// builder.go
package orm
import (
"my_framework/orm/internal/errs"
"strings"
)
// builder holds the state shared by statement builders that embed it (such
// as Deleter): the SQL buffer, the model metadata, and placeholder values.
type builder struct {
sb *strings.Builder
model *model
args []any
}
// buildExpression appends expr to b.sb, writing '?' for each value and
// recording the value in b.args. Columns are resolved through b.model; a
// name that is not a model field yields errs.NewErrUnknownField.
//
// Operands that are themselves predicates are parenthesized. A predicate
// without a left operand (the unary NOT built by Not) renders as
// "NOT (...)" with no leading space.
func (b *builder) buildExpression(expr Expression) error {
	switch exp := expr.(type) {
	case nil:
		// Nothing to render.
	case Predicate:
		if exp.left != nil {
			_, nested := exp.left.(Predicate)
			if nested {
				b.sb.WriteByte('(')
			}
			if err := b.buildExpression(exp.left); err != nil {
				return err
			}
			if nested {
				b.sb.WriteByte(')')
			}
			b.sb.WriteByte(' ')
		}
		b.sb.WriteString(exp.op.String())
		b.sb.WriteByte(' ')
		_, nested := exp.right.(Predicate)
		if nested {
			b.sb.WriteByte('(')
		}
		if err := b.buildExpression(exp.right); err != nil {
			return err
		}
		if nested {
			b.sb.WriteByte(')')
		}
	case Column:
		fd, ok := b.model.fields[exp.name]
		if !ok {
			// The column is not a field of the model.
			return errs.NewErrUnknownField(exp.name)
		}
		b.sb.WriteByte('`')
		b.sb.WriteString(fd.colName)
		b.sb.WriteByte('`')
	case value:
		b.sb.WriteByte('?')
		b.addArg(exp.val)
	default:
		return errs.NewErrUnsupportedExpression(expr)
	}
	return nil
}
// addArg appends a placeholder value. On first use the slice is allocated
// with room for 8 arguments. It returns b for chaining.
func (b *builder) addArg(val any) *builder {
	args := b.args
	if args == nil {
		args = make([]any, 0, 8)
	}
	b.args = append(args, val)
	return b
}
// buildPredicates ANDs ps together and renders the result. An empty ps
// renders nothing instead of panicking on ps[0].
func (b *builder) buildPredicates(ps []Predicate) error {
	if len(ps) == 0 {
		return nil
	}
	p := ps[0]
	for _, next := range ps[1:] {
		p = p.And(next)
	}
	return b.buildExpression(p)
}
// delete.go
package orm
import (
"strings"
)
// Deleter builds DELETE statements for the model type T. It embeds builder
// for the SQL buffer, model metadata and arguments.
type Deleter[T any] struct {
builder
// where holds the predicates that limit which rows are deleted.
where []Predicate
// table overrides the table name; when empty, model.tableName is used.
table string
}
// Build renders the DELETE statement and its arguments.
//
// The table defaults to the snake_case name of T from parseModel; a table
// set via From is written verbatim. Predicates set via Where are ANDed.
// The buffer and arguments are reset on every call, so Build can be called
// repeatedly.
func (d *Deleter[T]) Build() (*Query, error) {
	d.sb = &strings.Builder{} // the pointer field must be initialized
	d.args = nil
	var err error
	d.model, err = parseModel(new(T)) // new(T) yields a *T, as parseModel requires
	if err != nil {
		return nil, err
	}
	sb := d.sb
	// Standard SQL (MySQL, PostgreSQL, SQLite) is "DELETE FROM ...";
	// "DELETE * FROM" is a syntax error in those databases.
	sb.WriteString("DELETE FROM ")
	if d.table == "" {
		sb.WriteByte('`')
		sb.WriteString(d.model.tableName)
		sb.WriteByte('`')
	} else {
		sb.WriteString(d.table)
	}
	if len(d.where) > 0 {
		sb.WriteString(" WHERE ")
		if err := d.buildPredicates(d.where); err != nil {
			return nil, err
		}
	}
	sb.WriteByte(';')
	return &Query{
		SQL:  sb.String(),
		Args: d.args,
	}, nil
}

// From sets the table to delete from. The name is written verbatim, so
// callers quote it themselves; an empty string falls back to T's table.
func (d *Deleter[T]) From(table string) *Deleter[T] {
	d.table = table
	return d
}

// Where sets the predicates (ANDed together) that limit which rows are
// deleted, replacing any set earlier.
func (d *Deleter[T]) Where(predicate ...Predicate) *Deleter[T] {
	d.where = predicate
	return d
}

// delete_test.go
package orm

import (
	"github.com/stretchr/testify/assert"
	"testing"
)

// TestDeleter_Build checks the SQL and arguments produced by Deleter.Build
// for the default table, an explicit table, and a WHERE clause.
func TestDeleter_Build(t *testing.T) {
	testCases := []struct {
		name      string
		builder   QueryBuilder
		wantQuery *Query
		wantErr   error
	}{
		{
			name:    "struct",
			builder: &Deleter[TestModel]{},
			wantQuery: &Query{
				SQL:  "DELETE FROM `test_model`;",
				Args: nil,
			},
		},
		{
			name:    "from",
			builder: (&Deleter[TestModel]{}).From("`test_model`"),
			wantQuery: &Query{
				SQL:  "DELETE FROM `test_model`;",
				Args: nil,
			},
		},
		{
			name:    "where",
			builder: (&Deleter[TestModel]{}).Where(C("Age").Eq(18)),
			wantQuery: &Query{
				SQL:  "DELETE FROM `test_model` WHERE `age` = ?;",
				Args: []any{18},
			},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			q, err := tc.builder.Build()
			assert.Equal(t, tc.wantErr, err)
			if err != nil {
				return
			}
			assert.Equal(t, tc.wantQuery, q)
		})
	}
}
11 第五周:ORM 框架之元数据、SQL 编程与结果集处理
元数据
1. 元数据:注册中心
// select.go
// Selector builds SELECT statements for T. It embeds builder for the SQL
// buffer, model metadata and arguments.
type Selector[T any] struct {
builder
// table overrides the FROM target.
table string
// where holds the predicates set by Where.
where []Predicate
// db is the DB passed to NewSelector.
db *DB
}
// NewSelector creates a Selector for T bound to db.
func NewSelector[T any](db *DB) *Selector[T] {
	s := new(Selector[T])
	s.db = db
	return s
}
// select_test.go
package orm
import (
"database/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"my_framework/orm/internal/errs"
"testing"
)
// TestSelector_Build checks FROM and WHERE rendering for Selectors created
// through NewSelector with a DB.
func TestSelector_Build(t *testing.T) {
	db, err := NewDB()
	require.NoError(t, err)
	testCases := []struct {
		name      string
		builder   QueryBuilder
		wantQuery *Query
		wantErr   error
	}{
		{
			name:    "not from",
			builder: NewSelector[TestModel](db),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model`;",
				Args: nil,
			},
		},
		{
			name:    "from",
			builder: (NewSelector[TestModel](db)).Table("`test_model`"),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model`;",
				Args: nil,
			},
		},
		{
			name:    "empty from",
			builder: (NewSelector[TestModel](db)),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model`;",
				Args: nil,
			},
		},
		{
			name:    "empty table",
			builder: (NewSelector[TestModel](db)).Table(""),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model`;",
				Args: nil,
			},
		},
		{
			name:    "db table",
			builder: (NewSelector[TestModel](db)).Table("`mybatis`.`test`"),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `mybatis`.`test`;",
				Args: nil,
			},
		},
		{
			name:    "where",
			builder: (NewSelector[TestModel](db)).Where(C("Age").Eq(18)),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model` WHERE `age` = ?;",
				Args: []any{18},
			},
		},
		{
			name:    "empty where",
			builder: (NewSelector[TestModel](db)).Where(),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model`;",
				Args: nil,
			},
		},
		{
			name:    "not",
			builder: (NewSelector[TestModel](db)).Where(Not(C("Age").Eq(18))),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model` WHERE NOT (`age` = ?);",
				Args: []any{18},
			},
		},
		{
			name:    "and",
			builder: (NewSelector[TestModel](db)).Where(C("Age").Eq(18).And(C("FirstName").Eq("Tom"))),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model` WHERE (`age` = ?) AND (`first_name` = ?);",
				Args: []any{18, "Tom"},
			},
		},
		{
			// Distinct name from "and" so a failure points at the right case.
			name:    "or",
			builder: (NewSelector[TestModel](db)).Where(C("Age").Eq(18).Or(C("FirstName").Eq("Tom"))),
			wantQuery: &Query{
				SQL:  "SELECT * FROM `test_model` WHERE (`age` = ?) OR (`first_name` = ?);",
				Args: []any{18, "Tom"},
			},
		},
		{
			name:    "invalid column",
			builder: (NewSelector[TestModel](db)).Where(C("Age").Eq(18).Or(C("XXXX").Eq("Tom"))),
			wantErr: errs.NewErrUnknownField("XXXX"),
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			q, err := tc.builder.Build()
			assert.Equal(t, tc.wantErr, err)
			if err != nil {
				return
			}
			assert.Equal(t, tc.wantQuery, q)
		})
	}
}
// TestModel is the fixture entity used by the Selector and registry tests.
type TestModel struct {
Id int64
FirstName string
Age int8
LastName *sql.NullString
}
// db.go
package orm
// DBOption customizes a DB; NewDB applies options in order.
type DBOption func(db *DB)
// DB holds the ORM's model metadata registry.
type DB struct {
// r caches parsed model metadata.
r *registry
}
// NewDB creates a DB with a fresh model registry and applies opts in order.
// The error result is currently always nil; it is reserved for options or
// setup that may fail later.
func NewDB(opts ...DBOption) (*DB, error) {
	db := &DB{r: newRegistry()}
	for i := range opts {
		opts[i](db)
	}
	return db, nil
}
// MustNewDB is like NewDB but panics if NewDB returns an error.
func MustNewDB(opts ...DBOption) *DB {
	db, err := NewDB(opts...)
	if err == nil {
		return db
	}
	panic(err)
}
// model_test.go
package orm
import (
"github.com/stretchr/testify/assert"
"my_framework/orm/internal/errs"
"reflect"
"testing"
)
func Test_parseModel(t *testing.T) {
tests := []struct {
name string
entity any
wantModel *model
wantErr error
}{
{
name: "struct",
entity: TestModel{},
wantErr: errs.ErrPointerOnly,
},
{
name: "pointer",
entity: &TestModel{},
wantModel: &model{
tableName: "test_model",
fields: map[string]*field{
"Id": {
colName: "id",
},
"FirstName": {
colName: "first_name",
},
"LastName": {
colName: "last_name",
},
"Age": {
colName: "age",
},
},
},
},
}
r := ®istry{}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
m, err := r.parseModel(tc.entity)
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
assert.Equal(t, tc.wantModel, m)
})
}
}
// TestRegistry_Get checks that registry.get parses an entity and caches the
// resulting model under the entity's reflect.Type.
func TestRegistry_Get(t *testing.T) {
testCases := []struct {
name string
entity any
wantModel *model
wantErr error
// expected number of entries in r.models after this case; the registry is
// shared by all cases, so this count is cumulative
cacheSize int
}{
{
name: "pointer",
entity: &TestModel{},
wantModel: &model{
tableName: "test_model",
fields: map[string]*field{
"Id": {
colName: "id",
},
"FirstName": {
colName: "first_name",
},
"LastName": {
colName: "last_name",
},
"Age": {
colName: "age",
},
},
},
cacheSize: 1,
},
}
// One registry for all cases, so later cases can observe earlier cache entries.
r := newRegistry()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
m, err := r.get(tc.entity)
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
assert.Equal(t, tc.wantModel, m)
// Only the number of cached entries is checked here.
assert.Equal(t, tc.cacheSize, len(r.models))
// The cache must also contain the same model under the entity's type.
typ := reflect.TypeOf(tc.entity)
m, ok := r.models[typ]
assert.True(t, ok)
assert.Equal(t, tc.wantModel, m)
})
}
}
// model.go
// 元数据注册中心
type registry struct {
models map[reflect.Type]*model
}
func newRegistry() *registry {
return ®istry{
models: make(map[reflect.Type]*model, 64),
}
}
func (r *registry) get(val any) (*model, error) {
typ := reflect.TypeOf(val)
m, ok := r.models[typ]
if !ok {
var err error
m, err = r.parseModel(val)
if err != nil {
return nil, err
}
r.models[typ] = m
}
return m, nil
}
2. 元数据:注册中心并发问题
// model.go
// get returns the cached metadata for val's type, parsing and caching it on a
// miss. r.models is used as a sync.Map here (Load/LoadOrStore).
//
// Two goroutines can miss at the same time and both parse the type.
// LoadOrStore keeps only the first stored model, and every caller gets that
// same *model back, so the cache never hands out diverging copies.
func (r *registry) get(val any) (*model, error) {
	typ := reflect.TypeOf(val)
	if cached, ok := r.models.Load(typ); ok {
		return cached.(*model), nil
	}
	parsed, err := r.parseModel(val)
	if err != nil {
		return nil, err
	}
	actual, _ := r.models.LoadOrStore(typ, parsed)
	return actual.(*model), nil
}
3. 元数据:标签自定义列名
// model_test.go
package orm
import (
"github.com/stretchr/testify/assert"
"my_framework/orm/internal/errs"
"reflect"
"testing"
)
func Test_parseModel(t *testing.T) {
tests := []struct {
name string
entity any
wantModel *model
wantErr error
}{
{
name: "struct",
entity: TestModel{},
wantErr: errs.ErrPointerOnly,
},
{
name: "pointer",
entity: &TestModel{},
wantModel: &model{
tableName: "test_model",
fields: map[string]*field{
"Id": {
colName: "id",
},
"FirstName": {
colName: "first_name",
},
"LastName": {
colName: "last_name",
},
"Age": {
colName: "age",
},
},
},
},
}
r := ®istry{}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
m, err := r.parseModel(tc.entity)
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
assert.Equal(t, tc.wantModel, m)
})
}
}
// TestRegistry_Get checks that registry.get honours the `orm:"column=..."`
// tag, falls back to the snake_case field name for an empty or unrelated tag,
// rejects a malformed tag, and caches every successfully parsed model.
func TestRegistry_Get(t *testing.T) {
testCases := []struct {
name string
entity any
wantModel *model
wantErr error
}{
{
name: "pointer",
entity: &TestModel{},
wantModel: &model{
tableName: "test_model",
fields: map[string]*field{
"Id": {
colName: "id",
},
"FirstName": {
colName: "first_name",
},
"LastName": {
colName: "last_name",
},
"Age": {
colName: "age",
},
},
},
},
{
// Explicit column name from the tag.
name: "tag",
entity: func() any {
// Each function literal declares its own local TagTable, and each is a
// distinct type, so these cases do not share a cache entry.
type TagTable struct {
FirstName string `orm:"column=first_name_t"`
}
return &TagTable{}
}(),
wantModel: &model{
tableName: "tag_table",
fields: map[string]*field{
"FirstName": {
colName: "first_name_t",
},
},
},
},
{
// An empty column value falls back to the field-derived name.
name: "empty tag",
entity: func() any {
type TagTable struct {
FirstName string `orm:"column="`
}
return &TagTable{}
}(),
wantModel: &model{
tableName: "tag_table",
fields: map[string]*field{
"FirstName": {
colName: "first_name",
},
},
},
},
{
// A key without "=value" is malformed.
name: "column only",
entity: func() any {
type TagTable struct {
FirstName string `orm:"column"`
}
return &TagTable{}
}(),
wantErr: errs.NewErrInvalidTagContent("column"),
},
{
// Unknown keys are accepted and ignored.
name: "invalid tag",
entity: func() any {
type TagTable struct {
FirstName string `orm:"abc=abc"`
}
return &TagTable{}
}(),
wantModel: &model {
tableName: "tag_table",
fields: map[string]*field{
"FirstName": {
colName: "first_name",
},
},
},
},
}
r := newRegistry()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
m, err := r.get(tc.entity)
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
assert.Equal(t, tc.wantModel, m)
// The parsed model must also be stored in the cache under the entity type.
typ := reflect.TypeOf(tc.entity)
cache, ok := r.models.Load(typ)
assert.True(t, ok)
assert.Equal(t, tc.wantModel, cache)
})
}
}
// model.go
// parseModel builds model metadata for entity. entity's dynamic type must be
// a pointer to a struct (exactly one level of indirection); anything else,
// including a nil interface, yields errs.ErrPointerOnly.
//
// Each field's column name comes from the `orm:"column=..."` tag when it is
// present and non-empty. Otherwise it is underscoreName(field name). The table
// name is underscoreName(struct name). Errors from parseTag are returned as-is.
func (r *registry) parseModel(entity any) (*model, error) {
	typ := reflect.TypeOf(entity)
	// reflect.TypeOf(nil) returns a nil Type; calling Kind on it would panic.
	if typ == nil || typ.Kind() != reflect.Ptr || typ.Elem().Kind() != reflect.Struct {
		return nil, errs.ErrPointerOnly
	}
	// The guard above guarantees a single pointer to a struct; step through it.
	// (A `for typ.Kind() == reflect.Pointer` loop would be needed to accept
	// multi-level pointers.)
	typ = typ.Elem()

	numField := typ.NumField()
	fieldMap := make(map[string]*field, numField)
	for i := 0; i < numField; i++ {
		fd := typ.Field(i)
		// Tag format: `orm:"key=value,key=value"`.
		pair, err := r.parseTag(fd.Tag)
		if err != nil {
			return nil, err
		}
		colName := pair[tagKeyColumn]
		if colName == "" {
			// No column set by the user: derive it from the field name.
			colName = underscoreName(fd.Name)
		}
		fieldMap[fd.Name] = &field{colName: colName}
	}
	return &model{
		tableName: underscoreName(typ.Name()),
		fields:    fieldMap,
	}, nil
}
// User illustrates a tag with several key=value pairs: parseTag splits it on ","
// and "=" into {"column": "id", "xxx": "bbb"}.
type User struct {
ID uint64 `orm:"column=id,xxx=bbb"`
}
// parseTag returns the key/value pairs of a field's `orm` struct tag, written
// as `orm:"key1=value1,key2=value2"`.
//
// A field with no `orm` tag yields an empty, non-nil map. Spaces around keys
// and values are trimmed. Empty segments are skipped, so `orm:""` and a
// trailing comma are tolerated. A non-empty segment that does not split into
// exactly a key and a value on "=" yields errs.NewErrInvalidTagContent for
// that segment.
func (r *registry) parseTag(tag reflect.StructTag) (map[string]string, error) {
	ormTag, ok := tag.Lookup("orm")
	if !ok {
		return map[string]string{}, nil
	}
	pairs := strings.Split(ormTag, ",")
	res := make(map[string]string, len(pairs))
	for _, pair := range pairs {
		if strings.TrimSpace(pair) == "" {
			continue
		}
		segs := strings.Split(pair, "=")
		if len(segs) != 2 {
			return nil, errs.NewErrInvalidTagContent(pair)
		}
		res[strings.TrimSpace(segs[0])] = strings.TrimSpace(segs[1])
	}
	return res, nil
}
4. 元数据:接口自定义表名
// model_test.go
// TestRegistry_Get (custom table name cases) checks that an entity implementing
// TableName overrides the derived table name, and that an empty TableName result
// falls back to the snake_case struct name.
func TestRegistry_Get(t *testing.T) {
testCases := []struct {
name string
entity any
wantModel *model
wantErr error
}{
// ...
{
// Value-receiver TableName.
name: "custom table name",
entity: &CustomTableName{},
wantModel: &model{
tableName: "custom_table_name_t",
fields: map[string]*field{
"FirstName": {
colName: "first_name",
},
},
},
},
{
// Pointer-receiver TableName; only *CustomTableNamePtr implements the interface.
name: "custom table name ptr",
entity: &CustomTableNamePtr{},
wantModel: &model{
tableName: "custom_table_name_ptr_t",
fields: map[string]*field{
"FirstName": {
colName: "first_name",
},
},
},
},
{
// TableName returns "" so the derived name is used.
name: "custom table name empty ptr",
entity: &CustomTableNameEmpty{},
wantModel: &model{
tableName: "custom_table_name_empty",
fields: map[string]*field{
"FirstName": {
colName: "first_name",
},
},
},
},
}
r := newRegistry()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
m, err := r.get(tc.entity)
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
assert.Equal(t, tc.wantModel, m)
// The model must also be cached under the entity's reflect.Type.
typ := reflect.TypeOf(tc.entity)
cache, ok := r.models.Load(typ)
assert.True(t, ok)
assert.Equal(t, tc.wantModel, cache)
})
}
}
// CustomTableName overrides its table name through a value-receiver TableName,
// so both CustomTableName and *CustomTableName provide it.
type CustomTableName struct {
	FirstName string
}

// TableName returns the explicit table name for CustomTableName.
func (CustomTableName) TableName() string { return "custom_table_name_t" }

// CustomTableNamePtr overrides its table name through a pointer receiver, so
// only *CustomTableNamePtr provides TableName.
type CustomTableNamePtr struct {
	FirstName string
}

// TableName returns the explicit table name for *CustomTableNamePtr.
func (*CustomTableNamePtr) TableName() string { return "custom_table_name_ptr_t" }

// CustomTableNameEmpty implements TableName but returns "", which the registry
// treats as "no override".
type CustomTableNameEmpty struct {
	FirstName string
}

// TableName deliberately returns the empty string.
func (CustomTableNameEmpty) TableName() string { return "" }
// model.go
// parseModel (excerpt): after the field metadata is built, pick the table name.
// The lines elided by "// ..." define typ and fieldMap.
func (r *registry) parseModel(entity any) (*model, error) {
// ...
var tableName string
// entity is the pointer the caller passed in, so this assertion matches both
// value-receiver and pointer-receiver TableName methods.
if tbl, ok := entity.(TableName); ok {
tableName = tbl.TableName()
}
// An unimplemented interface or an empty result falls back to the struct name.
if tableName == "" {
tableName = underscoreName(typ.Name())
}
return &model{
tableName: tableName,
fields: fieldMap,
}, nil
}
// types.go
// TableName lets an entity choose its own table name. Returning "" means
// "no override": the registry then derives the name from the struct name.
type TableName interface {
TableName() string
}
5. 元数据:编程方式自定义表名和列名
// model.go
// Get returns the cached metadata for val's type. On a miss it delegates to
// Register with no options, which parses, caches and returns the model, or
// returns (nil, err).
func (r *registry) Get(val any) (*Model, error) {
	if cached, ok := r.models.Load(reflect.TypeOf(val)); ok {
		return cached.(*Model), nil
	}
	return r.Register(val)
}
// Register parses entity's metadata, applies opts in order, caches the result
// under entity's reflect.Type and returns it.
//
// entity's dynamic type must be a pointer to a struct (one level of pointer).
// Anything else, including a nil interface, yields errs.ErrPointerOnly.
// Column names come from the `orm:"column=..."` tag when non-empty, else from
// underscoreName(field name). The table name comes from the TableName
// interface when it returns a non-empty string, else from
// underscoreName(struct name). The first option error aborts registration, and
// nothing is cached in that case.
func (r *registry) Register(entity any, opts ...ModelOpt) (*Model, error) {
	typ := reflect.TypeOf(entity)
	// reflect.TypeOf(nil) is a nil Type; calling Kind on it would panic.
	if typ == nil || typ.Kind() != reflect.Ptr || typ.Elem().Kind() != reflect.Struct {
		return nil, errs.ErrPointerOnly
	}
	// Only one level of pointer is supported, so Elem is the struct type.
	elemType := typ.Elem()

	numField := elemType.NumField()
	fieldMap := make(map[string]*Field, numField)
	for i := 0; i < numField; i++ {
		fd := elemType.Field(i)
		// Tag format: `orm:"key=value,key=value"`.
		pair, err := r.parseTag(fd.Tag)
		if err != nil {
			return nil, err
		}
		colName := pair[tagKeyColumn]
		if colName == "" {
			// No (or empty) column tag: derive the name from the field name.
			colName = underscoreName(fd.Name)
		}
		fieldMap[fd.Name] = &Field{colName: colName}
	}

	var tableName string
	if tbl, ok := entity.(TableName); ok {
		tableName = tbl.TableName()
	}
	if tableName == "" {
		tableName = underscoreName(elemType.Name())
	}

	res := &Model{
		tableName: tableName,
		fields:    fieldMap,
	}
	for _, opt := range opts {
		if err := opt(res); err != nil {
			return nil, err
		}
	}
	r.models.Store(typ, res)
	return res, nil
}
// model_test.go
// TestModelWithTableName verifies that the ModelWithTableName option replaces
// the table name derived from the struct name.
func TestModelWithTableName(t *testing.T) {
	m, err := newRegistry().Register(&TestModel{}, ModelWithTableName("test_model_ttt"))
	require.NoError(t, err)
	assert.Equal(t, "test_model_ttt", m.tableName)
}
// TestModelWithColName checks that the ModelWithColumnName option renames the
// column of an existing field, and fails with errs.NewErrUnknownField when the
// field does not exist.
func TestModelWithColName(t *testing.T) {
testCases := []struct {
name string
field string
colName string
wantColName string
wantErr error
}{
{
name: "column name",
field: "FirstName",
colName: "first_name_ccc",
wantColName: "first_name_ccc",
},
{
name: "invalid column name",
field: "XXX",
colName: "first_name_ccc",
wantErr: errs.NewErrUnknownField("XXX"),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// A fresh registry per case, so no case sees another's cached model.
r := newRegistry()
m, err := r.Register(&TestModel{}, ModelWithColumnName(tc.field, tc.colName))
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
fd, ok := m.fields[tc.field]
require.True(t, ok)
assert.Equal(t, tc.wantColName, fd.colName)
})
}
}
6. 元数据:总结与面试要点
SQL 编程
7. SQL 编程:增删改查
// crud_test.go
package sql_demo
import (
"context"
"database/sql"
"log"
"time"
//_ "github.com/mattn/go-sqlite3"
_ "github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/require"
"testing"
)
// TestDB walks through basic database/sql CRUD against a local MySQL:
// CREATE TABLE, INSERT, single-row SELECT, not-found SELECT and multi-row
// SELECT. It is skipped when the database is unreachable.
func TestDB(t *testing.T) {
	db, err := sql.Open("mysql", "root:123456@tcp(127.0.0.1:3307)/mybatis")
	require.NoError(t, err)
	defer db.Close()
	// sql.Open may not connect at all; Ping checks that the server is reachable.
	// Skip rather than silently pass when it is not.
	if err = db.Ping(); err != nil {
		t.Skipf("database unavailable: %v", err)
	}
	// sql.OpenDB is the alternative when you already have a driver.Connector.

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	// Cancel only after every statement below has run.
	defer cancel()

	// Everything except SELECT goes through ExecContext.
	_, err = db.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS test_model(
id int AUTO_INCREMENT PRIMARY KEY,
first_name TEXT NOT NULL,
age int,
last_name TEXT NOT NULL
)
`)
	require.NoError(t, err)

	// INSERT with ? placeholders: values are bound, not spliced into the SQL,
	// which prevents SQL injection.
	res, err := db.ExecContext(ctx, "INSERT INTO test_model(`first_name`,`age`,`last_name`) VALUES (?,?,?)",
		"Tom", 18, "Jerry")
	require.NoError(t, err)
	affected, err := res.RowsAffected()
	require.NoError(t, err)
	log.Println("受影响行数: ", affected)
	lastId, err := res.LastInsertId()
	require.NoError(t, err)
	// Only meaningful for an AUTO_INCREMENT key; otherwise (or if the insert
	// failed) it is 0.
	log.Println("最后插入的id: ", lastId)

	// Single-row SELECT.
	row := db.QueryRowContext(ctx, "SELECT * FROM `test_model` WHERE `id`=?", 1)
	require.NoError(t, row.Err())
	tm := TestModel{}
	// SELECT * depends on column order; list the columns explicitly and pass
	// the matching &tm.xxx pointers if that is a concern.
	err = row.Scan(&tm.Id, &tm.FirstName, &tm.Age, &tm.LastName)
	require.NoError(t, err)
	log.Println(tm)

	// A missing row is reported by Scan, not by QueryRowContext or row.Err.
	// AUTO_INCREMENT ids start at 1, so id -1 can never exist.
	row = db.QueryRowContext(ctx, "SELECT * FROM `test_model` WHERE `id`=?", -1)
	err = row.Scan(&tm.Id, &tm.FirstName, &tm.Age, &tm.LastName)
	log.Println(err)
	require.ErrorIs(t, err, sql.ErrNoRows)

	// Multi-row SELECT: rows holds a connection until closed, and iteration
	// errors only surface through rows.Err.
	rows, err := db.QueryContext(ctx, "SELECT * FROM `test_model` WHERE `id`>?", 1)
	require.NoError(t, err)
	defer rows.Close()
	for rows.Next() {
		tm = TestModel{}
		err = rows.Scan(&tm.Id, &tm.FirstName, &tm.Age, &tm.LastName)
		require.NoError(t, err)
		log.Println(tm)
	}
	require.NoError(t, rows.Err())
}
// TestModel mirrors the test_model table created in TestDB. Its fields are
// scanned positionally from `SELECT *`, so their order must match the table's
// column order (id, first_name, age, last_name).
type TestModel struct {
Id int64
FirstName string
Age int8
LastName *sql.NullString
}
8. SQL 编程:Valuer 和 Scanner 接口
// json_test.go
package sql_demo
import (
"github.com/stretchr/testify/assert"
"testing"
)
// User is the payload type stored as JSON by the JsonColumn tests.
type User struct {
Name string
}
// TestJsonColumn_Value checks that a valid JsonColumn encodes its value as JSON
// bytes and that an invalid (zero) one encodes as SQL NULL (nil).
func TestJsonColumn_Value(t *testing.T) {
	valid := JsonColumn[User]{Valid: true, Val: User{Name: "Tom"}}
	got, err := valid.Value()
	assert.Nil(t, err)
	assert.Equal(t, []byte(`{"Name":"Tom"}`), got)

	var invalid JsonColumn[User]
	got, err = invalid.Value()
	assert.Nil(t, err)
	assert.Nil(t, got)
}
// TestJsonColumn_Scan checks that Scan treats nil as NULL (Valid stays false)
// and decodes JSON given as either a string or a []byte.
func TestJsonColumn_Scan(t *testing.T) {
testCases := []struct {
name string
src any
wantErr error
wantVal User
valid bool
}{
{
// nil src: zero value, Valid == false.
name: "nil",
},
{
name: "string",
src: `{"Name":"Tom"}`,
wantVal: User{Name: "Tom"},
valid: true,
},
{
name: "bytes",
src: []byte(`{"Name":"Tom"}`),
wantVal: User{Name: "Tom"},
valid: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// A fresh column per case, so no state carries over between cases.
js := &JsonColumn[User]{}
err := js.Scan(tc.src)
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
assert.Equal(t, tc.wantVal, js.Val)
assert.Equal(t, tc.valid, js.Valid)
})
}
}
// json.go
package sql_demo
import (
"database/sql/driver"
"errors"
"github.com/goccy/go-json"
)
// JsonColumn stores a value of type T as JSON in a single database column.
// It implements driver.Valuer (Value) and sql.Scanner (Scan).
type JsonColumn[T any] struct {
	Val T
	// Valid is false when the column is NULL.
	Valid bool
}

// Value implements driver.Valuer. It returns nil (SQL NULL) when Valid is
// false, and the JSON encoding of Val as []byte otherwise.
func (j JsonColumn[T]) Value() (driver.Value, error) {
	if !j.Valid {
		return nil, nil
	}
	bs, err := json.Marshal(j.Val)
	if err != nil {
		// Return a plain nil, not a nil []byte wrapped in a non-nil driver.Value.
		return nil, err
	}
	return bs, nil
}

// Scan implements sql.Scanner. It accepts JSON as a string or []byte, and nil
// for SQL NULL.
//
// The receiver is reset before decoding, so a reused JsonColumn never keeps
// data from a previous row: json.Unmarshal would otherwise merge into the old
// Val. NULL, an unsupported src type and invalid JSON all leave the zero Val
// with Valid == false.
func (j *JsonColumn[T]) Scan(src any) error {
	var zero T
	j.Val = zero
	j.Valid = false

	var bs []byte
	switch data := src.(type) {
	case string:
		// An empty string could be given special handling here.
		bs = []byte(data)
	case []byte:
		// An empty []byte{} could be given special handling here.
		bs = data
	case nil:
		// NULL in the database.
		return nil
	default:
		return errors.New("不支持类型")
	}
	if err := json.Unmarshal(bs, &j.Val); err != nil {
		j.Val = zero
		return err
	}
	j.Valid = true
	return nil
}
可以用这种方式对字段做加解密:加解密在应用代码层面完成,不会给数据库带来额外压力(如果交给数据库来加解密,就会增加数据库的负担)。
9. SQL 编程:事务与隔离级别
// crud_test.go
// TestTranslation (presumably meant "transaction") inserts one row inside a
// transaction and commits it. It is skipped when the database is unreachable.
func TestTranslation(t *testing.T) {
	db, err := sql.Open("mysql", "root:123456@tcp(127.0.0.1:3307)/mybatis")
	require.NoError(t, err)
	defer db.Close()
	if err = db.Ping(); err != nil {
		t.Skipf("database unavailable: %v", err)
	}
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	tx, err := db.BeginTx(ctx, &sql.TxOptions{})
	require.NoError(t, err)
	res, err := tx.ExecContext(ctx, "INSERT INTO test_model(`first_name`,`age`,`last_name`) VALUES (?,?,?)",
		"Tom", 18, "Jerry")
	if err != nil {
		// Roll back so the transaction releases its connection, then fail.
		// Continuing would commit a failed transaction and dereference a nil res.
		// A rollback failure is only logged; the exec error is the one reported.
		if rbErr := tx.Rollback(); rbErr != nil {
			log.Println(rbErr)
		}
		require.NoError(t, err)
	}
	// Commit the transaction.
	err = tx.Commit()
	require.NoError(t, err)
	affected, err := res.RowsAffected()
	require.NoError(t, err)
	log.Println("受影响行数: ", affected)
	lastId, err := res.LastInsertId()
	require.NoError(t, err)
	// Only meaningful for an AUTO_INCREMENT key; otherwise (or on failure) it is 0.
	log.Println("最后插入的id: ", lastId)
}
10. SQL 编程:Prepare Statement
// crud_test.go
// TestPS demonstrates a prepared statement: prepare once, query with
// arguments, iterate the rows. It is skipped when the database is unreachable.
func TestPS(t *testing.T) {
	db, err := sql.Open("mysql", "root:123456@tcp(127.0.0.1:3307)/mybatis")
	require.NoError(t, err)
	defer db.Close()
	if err = db.Ping(); err != nil {
		t.Skipf("database unavailable: %v", err)
	}
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	stmt, err := db.PrepareContext(ctx, "SELECT * FROM `test_model` WHERE `id`=?")
	require.NoError(t, err)
	// A prepared statement holds server-side resources. In an application it
	// is typically closed at shutdown, or when a reference count drops to zero
	// (it may still be in use elsewhere). Here the test owns it, so close it on
	// exit, after rows (defers run LIFO).
	defer stmt.Close()
	rows, err := stmt.QueryContext(ctx, 1)
	require.NoError(t, err)
	defer rows.Close()
	for rows.Next() {
		tm := TestModel{}
		err = rows.Scan(&tm.Id, &tm.FirstName, &tm.Age, &tm.LastName)
		require.NoError(t, err)
		log.Println(tm)
	}
	require.NoError(t, rows.Err())
	// Caveats:
	// - An IN (...) clause with a varying number of arguments needs one
	//   statement per argument count, e.g. "... `id` in (?,?)" and
	//   "... `id` in (?,?,?)", so it cannot be reused.
	// - Preparing too many statements can hit the database's limit and fail.
}
11. SQL 编程:sqlmock 入门、SQL 编程面试要点
// sqlmock_test.go
package sql_demo
import (
"context"
"errors"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
"log"
"testing"
)
// TestSQLMock shows basic go-sqlmock usage: register expected queries (as
// regular expressions, in execution order), run them, and verify every
// expectation was consumed.
func TestSQLMock(t *testing.T) {
	db, mock, err := sqlmock.New()
	// Check err before deferring Close: db is nil when New fails.
	require.NoError(t, err)
	defer db.Close()

	mockRows := sqlmock.NewRows([]string{"id", "first_name"})
	mockRows.AddRow(1, "Tom")
	// Expectations are matched as regular expressions and must be registered
	// in the same order the queries run, one to one.
	mock.ExpectQuery("SELECT id,first_name FROM `user`.*").WillReturnRows(mockRows)
	mock.ExpectQuery("SELECT id FROM `user`.*").WillReturnError(errors.New("mock error"))

	rows, err := db.QueryContext(context.Background(),
		"SELECT id,first_name FROM `user` WHERE id=?", 1)
	require.NoError(t, err)
	defer rows.Close()
	for rows.Next() {
		tm := TestModel{}
		err = rows.Scan(&tm.Id, &tm.FirstName)
		require.NoError(t, err)
		log.Println(tm)
	}
	require.NoError(t, rows.Err())

	_, err = db.QueryContext(context.Background(),
		"SELECT id FROM `user` WHERE id=?", 1)
	require.Error(t, err)

	// Fail if a registered expectation was never used.
	require.NoError(t, mock.ExpectationsWereMet())
}
结果集
12. 结果集处理:Open 与 OpenDB
// db.go
package orm
import "database/sql"
// DBOption customizes a DB while Open/OpenDB build it.
type DBOption func(db *DB)

// DB decorates a *sql.DB with the ORM's model metadata registry.
type DB struct {
	r  *registry
	db *sql.DB
}

// Open creates a *sql.DB with sql.Open(driver, dataSourceName) and wraps it
// with OpenDB. Like sql.Open, it may only validate its arguments without
// connecting.
func Open(driver string, dataSourceName string, opts ...DBOption) (*DB, error) {
	db, err := sql.Open(driver, dataSourceName)
	if err != nil {
		return nil, err
	}
	return OpenDB(db, opts...)
}

// OpenDB wraps an existing *sql.DB, for example one created by sqlmock in
// tests, with a new registry and applies opts in order. The error result is
// currently always nil.
func OpenDB(db *sql.DB, opts ...DBOption) (*DB, error) {
	res := &DB{
		r:  newRegistry(),
		db: db,
	}
	for _, opt := range opts {
		opt(res)
	}
	return res, nil
}

// MustOpen is Open that panics on error; meant for tests and program setup.
func MustOpen(driver string, dataSourceName string, opts ...DBOption) *DB {
	res, err := Open(driver, dataSourceName, opts...)
	if err != nil {
		panic(err)
	}
	return res
}
// select.go
// Get (first draft) builds the query and sends it to the wrapped *sql.DB.
// NOTE(review): this draft is unfinished. rows/err from QueryContext are unused
// and there is no final return statement, so it does not compile as written;
// the later versions of Get complete it.
func (s *Selector[T]) Get(ctx context.Context) (*T, error) {
q, err := s.Build()
if err != nil {
return nil, err
}
db := s.db.db
// Run the query, then process the result set.
rows, err := db.QueryContext(ctx, q.SQL, q.Args...)
}
// GetMulti (first draft) is meant to return every row of the query result.
// NOTE(review): unfinished. The query error is not checked, the loop body is
// empty, rows is never closed, and there is no return statement.
func (s *Selector[T]) GetMulti(ctx context.Context) ([]*T, error) {
q, err := s.Build()
if err != nil {
return nil, err
}
db := s.db.db
// Run the query, then process the result set.
rows, err := db.QueryContext(ctx, q.SQL, q.Args...)
for rows.Next() {
}
}
13. 结果集处理:发起查询异常情况
// select_test.go
// TestGet covers Get's error paths: a Build error, a query error, and an
// empty result set.
func TestGet(t *testing.T) {
mockDB, mock, err := sqlmock.New()
require.NoError(t, err)
db, err := OpenDB(mockDB)
require.NoError(t, err)
// Expectations are consumed in order by the cases that actually reach the
// database. "invalid query" fails in Build and consumes none.
// query error
mock.ExpectQuery("SELECT .*").WillReturnError(errors.New("query error"))
// no rows
rows := sqlmock.NewRows([]string{
"id", "first_name", "age", "last_name",
})
mock.ExpectQuery("SELECT .*").WillReturnRows(rows)
testCases := []struct {
name string
s *Selector[TestModel]
wantRes *TestModel
wantErr error
}{
{
name: "invalid query",
s: NewSelector[TestModel](db).Where(C("xxx").Eq(1)),
wantErr: errs.NewErrUnknownField("xxx"),
},
{
name: "query error",
s: NewSelector[TestModel](db).Where(C("Id").Eq(1)),
wantErr: errors.New("query error"),
},
{
// assert.Equal compares error values, not identity. This assumes ErrNoRows
// is created as errors.New("orm: 没有数据") — TODO confirm.
name: "no rows",
s: NewSelector[TestModel](db).Where(C("Id").Lt(1)),
wantErr: errors.New("orm: 没有数据"),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
res, err := tc.s.Get(context.Background())
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
assert.Equal(t, tc.wantRes, res)
})
}
}
// select.go
// Get (second draft) builds and runs the query, returns ErrNoRows when the
// result set is empty, and starts reading the column list.
// NOTE(review): work in progress. cs and tp are not used yet (so this does not
// compile as written), rows.Scan is called with no destinations, rows is never
// closed, and the function returns (nil, nil) on success.
func (s *Selector[T]) Get(ctx context.Context) (*T, error) {
q, err := s.Build()
if err != nil {
return nil, err
}
db := s.db.db
// Run the query, then process the result set.
rows, err := db.QueryContext(ctx, q.SQL, q.Args...)
//_, err = db.QueryContext(ctx, q.SQL, q.Args...)
// The query itself failed.
if err != nil {
return nil, err
}
if !rows.Next() {
return nil, ErrNoRows
}
// Names of the columns returned by the SELECT.
cs, err := rows.Columns()
if err != nil {
return nil, err
}
tp := new(T)
rows.Scan()
return nil, nil
}
14. 结果集处理:反射处理结果集
// model.go
// Field is the metadata of one entity struct field.
type Field struct {
// Go struct field name
goName string
// database column name
colName string
// Go type of the struct field; Get allocates scan destinations with reflect.New(typ)
typ reflect.Type
}
// Register (excerpt): each Field now also records the Go field name and type,
// which Get needs to allocate scan destinations and assign struct fields.
// Lines elided by "// ..." are unchanged from the previous version.
func (r *registry) Register(entity any, opts ...ModelOpt) (*Model, error) {
// ...
fieldMap := make(map[string]*Field, numField)
for i := 0; i < numField; i++ {
// ...
fieldMap[fd.Name] = &Field{
goName: fd.Name,
//colName: underscoreName(fd.Name),
colName: colName,
typ: fd.Type,
}
}
// ...
}
// select.go
// Get runs the selector's query and maps the first row into a new *T.
//
// Each returned column is matched to a model field by column name. A
// destination of the field's Go type is allocated with reflect, the row is
// scanned, and the values are copied into the struct by field name.
//
// Returns the Build error, the query error, ErrNoRows for an empty result, any
// error raised while iterating or scanning, and errs.NewErrUnknownColumn for a
// column with no matching field.
func (s *Selector[T]) Get(ctx context.Context) (*T, error) {
	q, err := s.Build()
	if err != nil {
		return nil, err
	}
	rows, err := s.db.db.QueryContext(ctx, q.SQL, q.Args...)
	if err != nil {
		return nil, err
	}
	// Return the connection to the pool however this function exits.
	defer rows.Close()
	if !rows.Next() {
		// Next returns false both for "no rows" and for an iteration error;
		// report the latter instead of masking it as ErrNoRows.
		if err := rows.Err(); err != nil {
			return nil, err
		}
		return nil, ErrNoRows
	}
	cs, err := rows.Columns()
	if err != nil {
		return nil, err
	}

	// Index fields by column once instead of scanning every field per column.
	byCol := make(map[string]*Field, len(s.model.fields))
	for _, fd := range s.model.fields {
		byCol[fd.colName] = fd
	}
	// One destination per column, in column order, so vals[i] always belongs to
	// cs[i]. Before this change an unknown column silently shifted the slice.
	fds := make([]*Field, len(cs))
	vals := make([]any, len(cs))
	for i, c := range cs {
		fd, ok := byCol[c]
		if !ok {
			return nil, errs.NewErrUnknownColumn(c)
		}
		fds[i] = fd
		// reflect.New returns a pointer to a new zero value of the field type.
		vals[i] = reflect.New(fd.typ).Interface()
	}
	if err := rows.Scan(vals...); err != nil {
		return nil, err
	}

	tp := new(T)
	elem := reflect.ValueOf(tp).Elem()
	for i, fd := range fds {
		elem.FieldByName(fd.goName).Set(reflect.ValueOf(vals[i]).Elem())
	}
	return tp, nil
}
// select_test.go
// TestGet covers Get end to end against sqlmock: a Build error, a query error,
// an empty result, and a successful row mapped into a TestModel.
func TestGet(t *testing.T) {
mockDB, mock, err := sqlmock.New()
require.NoError(t, err)
db, err := OpenDB(mockDB)
require.NoError(t, err)
// The expectations below are consumed in order by the cases that reach the
// database. "invalid query" fails in Build and consumes none. Every
// expectation matches "SELECT .*", so a case's WHERE clause does not affect
// which rows it gets.
// query error
mock.ExpectQuery("SELECT .*").WillReturnError(errors.New("query error"))
// no rows
rows := sqlmock.NewRows([]string{
"id", "first_name", "age", "last_name",
})
mock.ExpectQuery("SELECT .*").WillReturnRows(rows)
// data: string values that database/sql converts into the destination field types
rows = sqlmock.NewRows([]string{
"id", "first_name", "age", "last_name",
})
rows.AddRow("1", "Tom", "18", "Jerry")
mock.ExpectQuery("SELECT .*").WillReturnRows(rows)
testCases := []struct {
name string
s *Selector[TestModel]
wantRes *TestModel
wantErr error
}{
{
name: "invalid query",
s: NewSelector[TestModel](db).Where(C("xxx").Eq(1)),
wantErr: errs.NewErrUnknownField("xxx"),
},
{
name: "query error",
s: NewSelector[TestModel](db).Where(C("Id").Eq(1)),
wantErr: errors.New("query error"),
},
{
name: "no rows",
s: NewSelector[TestModel](db).Where(C("Id").Lt(1)),
wantErr: errors.New("orm: 没有数据"),
},
{
// LastName is a *sql.NullString; scanning into the **sql.NullString that
// reflect.New allocates presumably makes database/sql allocate the inner
// value — TODO confirm.
name: "data",
s: NewSelector[TestModel](db).Where(C("Id").Lt(1)),
wantRes: &TestModel{Id: 1, FirstName: "Tom", Age: 18, LastName: &sql.NullString{Valid: true, String: "Jerry"}},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
res, err := tc.s.Get(context.Background())
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
assert.Equal(t, tc.wantRes, res)
})
}
}
15. 结果集处理:代码优化与总结
// model.go
// Model is the parsed metadata of one entity type. Both maps point to the same
// *Field values, so a field can be looked up by Go name or by column name.
type Model struct {
tableName string
// Go field name -> field metadata
fieldMap map[string]*Field
// column name -> field metadata
columnMap map[string]*Field
}
// Register (excerpt; the pointer check and the typ/elemType/numField setup are
// elided by "// ...") builds both the field-name and column-name indexes,
// resolves the table name, applies opts in order, then caches the model under
// typ and returns it.
// NOTE(review): columnMap is filled before opts run. If an option renames a
// column, it must keep columnMap in sync — verify ModelWithColumnName does.
func (r *registry) Register(entity any, opts ...ModelOpt) (*Model, error) {
// ...
fieldMap := make(map[string]*Field, numField)
columnMap := make(map[string]*Field, numField)
for i := 0; i < numField; i++ {
fd := elemType.Field(i)
// Tag format: `orm:"key=value,key=value"`.
pair, err := r.parseTag(fd.Tag)
if err != nil {
return nil, err
}
// Column name from `orm:"column=xxx"`.
colName := pair[tagKeyColumn]
if colName == "" {
// Not set by the user: derive it from the field name.
colName = underscoreName(fd.Name)
}
fdData := &Field{
goName: fd.Name,
//colName: underscoreName(fd.Name),
colName: colName,
typ: fd.Type,
}
// The same *Field goes into both indexes. If two fields share a column name,
// the later field overwrites the earlier one in columnMap.
fieldMap[fd.Name] = fdData
columnMap[colName] = fdData
}
var tableName string
if tbl, ok := entity.(TableName); ok {
tableName = tbl.TableName()
}
if tableName == "" {
tableName = underscoreName(elemType.Name())
}
res := &Model{
tableName: tableName,
fieldMap: fieldMap,
columnMap: columnMap,
}
for _, opt := range opts {
err := opt(res)
if err != nil {
return nil, err
}
}
r.models.Store(typ, res)
return res, nil
}
// model_test.go
package orm
import (
"database/sql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"my_framework/orm/internal/errs"
"reflect"
"testing"
)
func Test_parseModel(t *testing.T) {
tests := []struct {
name string
entity any
wantModel *Model
fields []*Field
wantErr error
}{
{
name: "struct",
entity: TestModel{},
//wantModel: &Model{
// tableName: "test_model",
// fieldMap: map[string]*Field{
// "Id": {
// colName: "id",
// },
// "FirstName": {
// colName: "first_name",
// },
// "LastName": {
// colName: "last_name",
// },
// "Age": {
// colName: "age",
// },
// },
//},
wantErr: errs.ErrPointerOnly,
},
{
name: "pointer",
entity: &TestModel{},
wantModel: &Model{
tableName: "test_model",
fieldMap: map[string]*Field{},
},
fields: []*Field{
{
colName: "id",
goName: "Id",
typ: reflect.TypeOf(int64(0)),
},
{
colName: "first_name",
goName: "FirstName",
typ: reflect.TypeOf(""),
},
{
colName: "last_name",
goName: "LastName",
typ: reflect.TypeOf(&sql.NullString{}),
},
{
colName: "age",
goName: "Age",
typ: reflect.TypeOf(int8(0)),
},
},
},
}
r := ®istry{}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
m, err := r.Register(tc.entity)
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
fieldMap := make(map[string]*Field)
columnMap := make(map[string]*Field)
for _, f := range tc.fields {
fieldMap[f.goName] = f
columnMap[f.colName] = f
}
tc.wantModel.fieldMap = fieldMap
tc.wantModel.columnMap = columnMap
assert.Equal(t, tc.wantModel, m)
})
}
}
// TestRegistry_Get checks registry.Get across plain, tagged, malformed-tag and
// custom-table-name entities, including the fieldMap/columnMap indexes and the
// cache entry stored under the entity's reflect.Type.
func TestRegistry_Get(t *testing.T) {
testCases := []struct {
name string
entity any
wantModel *Model
// expected fields; the loop builds wantModel.fieldMap/columnMap from them
fields []*Field
wantErr error
}{
{
name: "pointer",
entity: &TestModel{},
wantModel: &Model{
tableName: "test_model",
},
fields: []*Field{
{
colName: "id",
goName: "Id",
typ: reflect.TypeOf(int64(0)),
},
{
colName: "first_name",
goName: "FirstName",
typ: reflect.TypeOf(""),
},
{
colName: "last_name",
goName: "LastName",
typ: reflect.TypeOf(&sql.NullString{}),
},
{
colName: "age",
goName: "Age",
typ: reflect.TypeOf(int8(0)),
},
},
},
{
// Explicit column name from the tag. Each literal declares a distinct
// local TagTable type, so the cases do not share a cache entry.
name: "tag",
entity: func() any {
type TagTable struct {
FirstName string `orm:"column=first_name_t"`
}
return &TagTable{}
}(),
fields: []*Field{
{
colName: "first_name_t",
goName: "FirstName",
typ: reflect.TypeOf(""),
},
},
wantModel: &Model{
tableName: "tag_table",
},
},
{
// An empty column value falls back to the field-derived name.
name: "empty tag",
entity: func() any {
type TagTable struct {
FirstName string `orm:"column="`
}
return &TagTable{}
}(),
fields: []*Field{{
colName: "first_name",
goName: "FirstName",
typ: reflect.TypeOf(""),
},
},
wantModel: &Model{
tableName: "tag_table",
},
},
{
// A key without "=value" is rejected.
name: "column only",
entity: func() any {
type TagTable struct {
FirstName string `orm:"column"`
}
return &TagTable{}
}(),
wantErr: errs.NewErrInvalidTagContent("column"),
},
{
// Unknown keys are ignored.
name: "invalid tag",
entity: func() any {
type TagTable struct {
FirstName string `orm:"abc=abc"`
}
return &TagTable{}
}(),
fields: []*Field{
{
colName: "first_name",
goName: "FirstName",
typ: reflect.TypeOf(""),
},
},
wantModel: &Model{
tableName: "tag_table",
},
},
{
name: "custom table name",
entity: &CustomTableName{},
fields: []*Field{
{
colName: "first_name",
goName: "FirstName",
typ: reflect.TypeOf(""),
},
},
wantModel: &Model{
tableName: "custom_table_name_t",
},
},
{
name: "custom table name ptr",
entity: &CustomTableNamePtr{},
fields: []*Field{
{
colName: "first_name",
goName: "FirstName",
typ: reflect.TypeOf(""),
},
},
wantModel: &Model{
tableName: "custom_table_name_ptr_t",
},
},
{
// TableName returns "", so the derived name is used.
name: "custom table name empty ptr",
entity: &CustomTableNameEmpty{},
fields: []*Field{
{
colName: "first_name",
goName: "FirstName",
typ: reflect.TypeOf(""),
},
},
wantModel: &Model{
tableName: "custom_table_name_empty",
},
},
}
r := newRegistry()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
m, err := r.Get(tc.entity)
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
// Build the expected indexes from the flat field list.
fieldMap := make(map[string]*Field)
columnMap := make(map[string]*Field)
for _, f := range tc.fields {
fieldMap[f.goName] = f
columnMap[f.colName] = f
}
tc.wantModel.fieldMap = fieldMap
tc.wantModel.columnMap = columnMap
assert.Equal(t, tc.wantModel, m)
// The model must also be cached under the entity's reflect.Type.
typ := reflect.TypeOf(tc.entity)
cache, ok := r.models.Load(typ)
assert.True(t, ok)
assert.Equal(t, tc.wantModel, cache)
})
}
}
// CustomTableName overrides its table name through a value-receiver TableName,
// so both CustomTableName and *CustomTableName provide it.
type CustomTableName struct {
	FirstName string
}

// TableName returns the explicit table name for CustomTableName.
func (CustomTableName) TableName() string { return "custom_table_name_t" }

// CustomTableNamePtr overrides its table name through a pointer receiver, so
// only *CustomTableNamePtr provides TableName.
type CustomTableNamePtr struct {
	FirstName string
}

// TableName returns the explicit table name for *CustomTableNamePtr.
func (*CustomTableNamePtr) TableName() string { return "custom_table_name_ptr_t" }

// CustomTableNameEmpty implements TableName but returns "", which the registry
// treats as "no override".
type CustomTableNameEmpty struct {
	FirstName string
}

// TableName deliberately returns the empty string.
func (CustomTableNameEmpty) TableName() string { return "" }
// TestModelWithTableName verifies that the ModelWithTableName option replaces
// the table name derived from the struct name.
func TestModelWithTableName(t *testing.T) {
	m, err := newRegistry().Register(&TestModel{}, ModelWithTableName("test_model_ttt"))
	require.NoError(t, err)
	assert.Equal(t, "test_model_ttt", m.tableName)
}
// TestModelWithColName checks that the ModelWithColumnName option renames the
// column of an existing field, and fails with errs.NewErrUnknownField when the
// field does not exist. Only fieldMap is inspected here, not columnMap.
func TestModelWithColName(t *testing.T) {
testCases := []struct {
name string
field string
colName string
wantColName string
wantErr error
}{
{
name: "column name",
field: "FirstName",
colName: "first_name_ccc",
wantColName: "first_name_ccc",
},
{
name: "invalid column name",
field: "XXX",
colName: "first_name_ccc",
wantErr: errs.NewErrUnknownField("XXX"),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// A fresh registry per case, so no cached model leaks between cases.
r := newRegistry()
m, err := r.Register(&TestModel{}, ModelWithColumnName(tc.field, tc.colName))
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
fd, ok := m.fieldMap[tc.field]
require.True(t, ok)
assert.Equal(t, tc.wantColName, fd.colName)
})
}
}
```
```go
// select.go
// Get (excerpt; query execution, the ErrNoRows check and `cs, err :=
// rows.Columns()` are elided by "// ..."). It uses columnMap for O(1)
// column -> field lookup, fails on an unknown column, and keeps valElems so
// the second pass needs no extra reflect.ValueOf(...).Elem() call.
func (s *Selector[T]) Get(ctx context.Context) (*T, error) {
// ...
tp := new(T)
vals := make([]any, 0, len(cs))
valElems := make([]reflect.Value, 0, len(cs))
for _, c := range cs {
// c is a column name
fd, ok := s.model.columnMap[c]
if !ok {
return nil, errs.NewErrUnknownColumn(c)
}
// Allocate a destination with reflect: val is a pointer to a new zero value
// of the field's type.
val := reflect.New(fd.typ)
vals = append(vals, val.Interface())
valElems = append(valElems, val.Elem())
}
err = rows.Scan(vals...)
if err != nil {
return nil, err
}
// Copy the scanned values into tp.
// The columnMap lookup below repeats the one above; every column is already
// known to exist at this point.
tpValue := reflect.ValueOf(tp)
for i, c := range cs {
fd, ok := s.model.columnMap[c]
if !ok {
return nil, errs.NewErrUnknownColumn(c)
}
tpValue.Elem().FieldByName(fd.goName).
//Set(reflect.ValueOf(vals[i]).Elem())
Set(valElems[i])
}
return tp, nil
}
```
```go
// select_test.go
package orm
import (
"context"
"database/sql"
"errors"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"my_framework/orm/internal/errs"
"testing"
)
func TestSelector_Build(t *testing.T) {
db, err := Open("mysql", "root:123456@tcp(127.0.0.1:3307)/mybatis")
require.NoError(t, err)
testCases := []struct {
name string
builder QueryBuilder
wantQuery *Query
wantErr error
}{
{
name: "not from",
builder: NewSelector[TestModel](db),
wantQuery: &Query{
SQL: "SELECT * FROM `test_model`;",
Args: nil,
},
},
{
name: "from",
builder: (NewSelector[TestModel](db)).Table("`test_model`"),
wantQuery: &Query{
SQL: "SELECT * FROM `test_model`;",
Args: nil,
},
},
{
name: "empty from",
builder: (NewSelector[TestModel](db)),
wantQuery: &Query{
SQL: "SELECT * FROM `test_model`;",
Args: nil,
},
},
{
name: "empty table",
builder: (NewSelector[TestModel](db)).Table(""),
wantQuery: &Query{
SQL: "SELECT * FROM `test_model`;",
Args: nil,
},
},
{
name: "db table",
builder: (NewSelector[TestModel](db)).Table("`mybatis`.`test`"),
wantQuery: &Query{
SQL: "SELECT * FROM `mybatis`.`test`;",
Args: nil,
},
},
{
name: "where",
builder: (NewSelector[TestModel](db)).Where(C("Age").Eq(18)),
wantQuery: &Query{
SQL: "SELECT * FROM `test_model` WHERE `age` = ?;",
Args: []any{18},
},
},
{
name: "empty where",
builder: (NewSelector[TestModel](db)).Where(),
wantQuery: &Query{
SQL: "SELECT * FROM `test_model`;",
Args: nil,
},
},
{
name: "not",
builder: (NewSelector[TestModel](db)).Where(Not(C("Age").Eq(18))),
wantQuery: &Query{
SQL: "SELECT * FROM `test_model` WHERE NOT (`age` = ?);",
Args: []any{18},
},
},
{
name: "and",
builder: (NewSelector[TestModel](db)).Where(C("Age").Eq(18).And(C("FirstName").Eq("Tom"))),
wantQuery: &Query{
SQL: "SELECT * FROM `test_model` WHERE (`age` = ?) AND (`first_name` = ?);",
Args: []any{18, "Tom"},
},
},
{
name: "and",
builder: (NewSelector[TestModel](db)).Where(C("Age").Eq(18).Or(C("FirstName").Eq("Tom"))),
wantQuery: &Query{
SQL: "SELECT * FROM `test_model` WHERE (`age` = ?) OR (`first_name` = ?);",
Args: []any{18, "Tom"},
},
},
{
name: "invalid column",
builder: (NewSelector[TestModel](db)).Where(C("Age").Eq(18).Or(C("XXXX").Eq("Tom"))),
wantErr: errs.NewErrUnknownField("XXXX"),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
q, err := tc.builder.Build()
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
assert.Equal(t, tc.wantQuery, q)
})
}
}
// TestGet covers Get end to end against sqlmock: a Build error, a query error,
// an empty result, and a successful row mapped into a TestModel.
func TestGet(t *testing.T) {
mockDB, mock, err := sqlmock.New()
require.NoError(t, err)
db, err := OpenDB(mockDB)
require.NoError(t, err)
// Expectations are consumed in order by the cases that reach the database.
// "invalid query" fails in Build and consumes none. Every expectation matches
// "SELECT .*", so a case's WHERE clause does not affect which rows it gets.
// query error
mock.ExpectQuery("SELECT .*").WillReturnError(errors.New("query error"))
// no rows
rows := sqlmock.NewRows([]string{
"id", "first_name", "age", "last_name",
})
mock.ExpectQuery("SELECT .*").WillReturnRows(rows)
// data
rows = sqlmock.NewRows([]string{
"id", "first_name", "age", "last_name",
})
rows.AddRow("1", "Tom", "18", "Jerry")
mock.ExpectQuery("SELECT .*").WillReturnRows(rows)
// scan error: "abc" cannot be converted to the int64 Id.
// NOTE(review): the matching test case below is commented out, so this
// expectation is never consumed. The commented-out wantErr (errs.ErrNoRows)
// also would not match a conversion error from Scan.
rows = sqlmock.NewRows([]string{
"id", "first_name", "age", "last_name",
})
rows.AddRow("abc", "Tom", "18", "Jerry")
mock.ExpectQuery("SELECT .*").WillReturnRows(rows)
testCases := []struct {
name string
s *Selector[TestModel]
wantRes *TestModel
wantErr error
}{
{
name: "invalid query",
s: NewSelector[TestModel](db).Where(C("xxx").Eq(1)),
wantErr: errs.NewErrUnknownField("xxx"),
},
{
name: "query error",
s: NewSelector[TestModel](db).Where(C("Id").Eq(1)),
wantErr: errors.New("query error"),
},
{
name: "no rows",
s: NewSelector[TestModel](db).Where(C("Id").Lt(1)),
wantErr: errors.New("orm: 没有数据"),
},
{
name: "data",
s: NewSelector[TestModel](db).Where(C("Id").Lt(1)),
wantRes: &TestModel{Id: 1, FirstName: "Tom", Age: 18, LastName: &sql.NullString{Valid: true, String: "Jerry"}},
},
//{
// name: "scan error",
// s: NewSelector[TestModel](db).Where(C("Id").Lt(1)),
// wantErr: errs.ErrNoRows,
//},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
res, err := tc.s.Get(context.Background())
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
assert.Equal(t, tc.wantRes, res)
})
}
}
// TestModel is the entity fixture shared by the orm package tests.
// Field order matters: when no columns are specified, INSERT statements
// list columns in declaration order (id, first_name, age, last_name).
type TestModel struct {
Id int64
FirstName string
Age int8
// LastName is a *sql.NullString; after scanning a non-NULL value the
// tests expect &sql.NullString{Valid: true, String: ...}.
LastName *sql.NullString
}
16. 加餐:Option 设计模式
package other
import "errors"
// MyStructOption mutates a MyStruct under construction; it cannot fail.
type MyStructOption func(myStruct *MyStruct)

// MyStructOptionErr is the fallible form of MyStructOption: a non-nil error
// aborts construction in NewMyStructV1.
type MyStructOptionErr func(myStruct *MyStruct) error

// MyStruct demonstrates the functional-options pattern. Its fields fall
// into two groups: values the caller must always supply (positional
// constructor arguments) and optional values (set through options).
type MyStruct struct {
	// Required.
	id   uint64
	name string

	// Optional.
	address string
	// ...
	field1 int
	field2 int
}

// WithField1And2 sets field1 and field2.
func WithField1And2(field1 int, field2 int) MyStructOption {
	return func(s *MyStruct) {
		s.field1, s.field2 = field1, field2
	}
}

// WithAddress sets the optional address.
func WithAddress(address string) MyStructOption {
	return func(s *MyStruct) {
		s.address = address
	}
}

// WithField1And2V1 is the fallible-option form of WithField1And2; it never
// returns an error.
func WithField1And2V1(field1 int, field2 int) MyStructOptionErr {
	return func(s *MyStruct) error {
		s.field1, s.field2 = field1, field2
		return nil
	}
}

// WithAddressV1 sets the address, rejecting an empty string with an error
// (the struct is left unchanged in that case).
func WithAddressV1(address string) MyStructOptionErr {
	return func(s *MyStruct) error {
		if address != "" {
			s.address = address
			return nil
		}
		return errors.New("地址不能为空")
	}
}

// WithAddressV2 sets the address and panics, when the option is applied,
// if the address is empty. It contrasts panicking options with the
// error-returning WithAddressV1, which is preferable in library code.
func WithAddressV2(address string) MyStructOption {
	return func(s *MyStruct) {
		if address == "" {
			panic(errors.New("地址不能为空"))
		}
		s.address = address
	}
}

// NewMyStruct builds a MyStruct from its required fields, then applies the
// options in order, so a later option overrides an earlier one.
func NewMyStruct(id uint64, name string, opts ...MyStructOption) *MyStruct {
	s := &MyStruct{id: id, name: name}
	for i := range opts {
		opts[i](s)
	}
	return s
}

// NewMyStructV1 is like NewMyStruct but takes fallible options. It stops at
// the first option that fails and returns that error with a nil *MyStruct.
func NewMyStructV1(id uint64, name string, opts ...MyStructOptionErr) (*MyStruct, error) {
	s := &MyStruct{id: id, name: name}
	for _, apply := range opts {
		if err := apply(s); err != nil {
			return nil, err
		}
	}
	return s, nil
}
12 第六周:ORM 框架之结果集处理、SELECT 进阶与 INSERT
结果集处理
1. 结果集处理:unsafe 入门
// accessor.go
package unsafe
import (
"errors"
"reflect"
"unsafe"
)
// UnsafeAccessor reads and writes the fields of one struct instance
// directly through their memory offsets. Reflection is used once, at
// construction, to learn each field's offset and type.
type UnsafeAccessor struct {
	// fields maps a Go field name to its offset and type.
	fields map[string]FieldMeta
	// address is the start address of the accessed struct.
	address unsafe.Pointer
}

// NewUnsafeAccessor builds an accessor for entity, which must be a non-nil
// pointer to a struct; other kinds panic inside reflect (Elem/NumField).
func NewUnsafeAccessor(entity any) *UnsafeAccessor {
	typ := reflect.TypeOf(entity).Elem()
	numField := typ.NumField()
	fields := make(map[string]FieldMeta, numField)
	for i := 0; i < numField; i++ {
		fd := typ.Field(i)
		fields[fd.Name] = FieldMeta{
			Offset: fd.Offset,
			typ:    fd.Type,
		}
	}
	return &UnsafeAccessor{
		fields:  fields,
		address: reflect.ValueOf(entity).UnsafePointer(),
	}
}

// Field returns the current value of the named field, boxed in an any whose
// dynamic type is the field's declared type. It returns an error if the
// struct has no such field.
func (a *UnsafeAccessor) Field(field string) (any, error) {
	fd, ok := a.fields[field]
	if !ok {
		return nil, errors.New("非法字段")
	}
	// Field address = struct start + offset. With a known type one could
	// write *(*int)(fdAddress); reflect.NewAt handles any type by viewing
	// that address as a *fd.typ.
	fdAddress := unsafe.Add(a.address, fd.Offset)
	return reflect.NewAt(fd.typ, fdAddress).Elem().Interface(), nil
}

// SetField overwrites the named field with val. It returns an error when the
// struct has no such field, or when val is nil or its dynamic type is not
// assignable to the field's type; previously both cases panicked inside
// reflect.Value.Set. To store a nil pointer, pass a typed nil such as
// (*T)(nil).
func (a *UnsafeAccessor) SetField(field string, val any) error {
	fd, ok := a.fields[field]
	if !ok {
		return errors.New("非法字段")
	}
	v := reflect.ValueOf(val)
	if !v.IsValid() || !v.Type().AssignableTo(fd.typ) {
		return errors.New("字段类型不匹配")
	}
	fdAddress := unsafe.Add(a.address, fd.Offset)
	reflect.NewAt(fd.typ, fdAddress).Elem().Set(v)
	return nil
}

// FieldMeta records where a field lives inside its struct and its type.
type FieldMeta struct {
	// Offset is the field's byte offset from the start of the struct.
	Offset uintptr
	typ    reflect.Type
}
// accessor_test.go
package unsafe
import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
)
// TestUnsafeAccessor_Field checks that the accessor reads a field through its
// memory offset and that SetField writes through to the original struct.
func TestUnsafeAccessor_Field(t *testing.T) {
	type User struct {
		Name string
		Age  int
	}
	u := &User{
		Name: "Tom",
		Age:  18,
	}
	accessor := NewUnsafeAccessor(u)
	field, err := accessor.Field("Age")
	require.NoError(t, err)
	assert.Equal(t, 18, field)

	err = accessor.SetField("Age", 19)
	require.NoError(t, err)
	// The write must be visible both through the accessor and on the struct
	// itself; previously the test only checked that no error was returned.
	field, err = accessor.Field("Age")
	require.NoError(t, err)
	assert.Equal(t, 19, field)
	assert.Equal(t, 19, u.Age)
}
// iterate_fields_test.go
package unsafe
import "testing"
// TestPrintFieldOffset prints the byte offset of every field of User and
// UserV1 so their memory layout, including alignment padding, can be read
// from the test output. It asserts nothing.
func TestPrintFieldOffset(t *testing.T) {
	for _, tc := range []struct {
		name   string
		entity any
	}{
		{name: "user", entity: User{}},
		{name: "user v1", entity: UserV1{}},
	} {
		entity := tc.entity
		t.Run(tc.name, func(t *testing.T) {
			PrintFieldOffset(entity)
		})
	}
}
// User and UserV1 illustrate struct field alignment. The offsets in the
// comments are the expected values on a 64-bit platform, where a string
// header is 16 bytes and a slice header 24 bytes; compare them with the
// output of TestPrintFieldOffset.
type User struct {
// 0
Name string
// 16
// Age occupies 4 bytes; 4 bytes of padding follow so that Alias starts
// on an 8-byte boundary.
Age int32
// 24
Alias []string
// 48
Address string
}
// UserV1 adds AgeV1, which fits into the padding after Age, so Alias and
// Address keep the same offsets as in User.
type UserV1 struct {
// 0
Name string
// 16
Age int32
// 20
AgeV1 int32
// 24
Alias []string
// 48
Address string
}
// iterate_fields.go
package unsafe
import "reflect"
// PrintFieldOffset writes the byte offset of each field of entity, one per
// line, using the println builtin (standard error). entity must be a
// struct value rather than a pointer; other kinds panic inside reflect.
func PrintFieldOffset(entity any) {
	typ := reflect.TypeOf(entity)
	for i, n := 0, typ.NumField(); i < n; i++ {
		println(typ.Field(i).Offset)
	}
}
2. 结果集处理:unsafe 实现
// select.go
// GetV1 runs the SELECT built by s and scans the first row into a new *T.
// Each column is written straight into the matching field's memory, located
// as the struct's start address plus the field offset; this is the
// precursor of the unsafe valuer.
//
// It returns ErrNoRows when the query yields no rows. It returns
// errs.NewErrUnknownColumn when a result column maps to no field of T. On
// any error the returned *T is nil.
func (s *Selector[T]) GetV1(ctx context.Context) (*T, error) {
	q, err := s.Build()
	if err != nil {
		return nil, err
	}
	db := s.db.db
	rows, err := db.QueryContext(ctx, q.SQL, q.Args...)
	if err != nil {
		return nil, err
	}
	// Return the connection to the pool on every exit path.
	defer rows.Close()
	if !rows.Next() {
		// Next also reports false on iteration errors; surface those rather
		// than masking them as "no rows".
		if err := rows.Err(); err != nil {
			return nil, err
		}
		return nil, ErrNoRows
	}
	cs, err := rows.Columns()
	if err != nil {
		return nil, err
	}
	tp := new(T)
	// Start address of the new entity.
	address := reflect.ValueOf(tp).UnsafePointer()
	vals := make([]any, 0, len(cs))
	for _, c := range cs {
		// c is the column name.
		fd, ok := s.model.columnMap[c]
		if !ok {
			return nil, errs.NewErrUnknownColumn(c)
		}
		// Field address = start + offset, viewed as a *fd.typ so that Scan
		// writes directly into the field.
		fdAddress := unsafe.Pointer(uintptr(address) + fd.offset)
		vals = append(vals, reflect.NewAt(fd.typ, fdAddress).Interface())
	}
	if err := rows.Scan(vals...); err != nil {
		return nil, err
	}
	return tp, nil
}
// model.go
// Register parses entity's struct fields into a *Model (parts elided with
// "// ..."). For every field it records the Go name, the column name, the Go
// type and the byte offset; the offset lets the unsafe valuer write scanned
// values straight into the field.
func (r *registry) Register(entity any, opts ...ModelOpt) (*Model, error) {
// ...
fieldMap := make(map[string]*Field, numField)
columnMap := make(map[string]*Field, numField)
for i := 0; i < numField; i++ {
fd := elemType.Field(i)
// Parse the struct tag (`key:"name=value"`) into key/value pairs.
pair, err := r.parseTag(fd.Tag)
if err != nil {
return nil, err
}
// Column name from the tag, e.g. `orm:"column=xxx"`.
colName := pair[tagKeyColumn]
if colName == "" {
// Not set by the user: derive it from the field name
// (presumably CamelCase -> snake_case, e.g. FirstName -> first_name,
// as the tests expect).
colName = underscoreName(fd.Name)
}
fdData := &Field{
goName: fd.Name,
//colName: underscoreName(fd.Name),
colName: colName,
typ: fd.Type,
offset: fd.Offset,
}
// Index the same *Field by Go field name (used when building SQL from
// C("FieldName")) and by column name (used when mapping result columns).
fieldMap[fd.Name] = fdData
columnMap[colName] = fdData
}
// ...
}
3. 结果集处理:valuer 重构与基准测试
// select.go
// Get runs the SELECT built by s and maps the first row into a new *T. The
// row is written by the Value that the DB's configured Creator produces
// (reflect- or unsafe-based).
//
// It returns ErrNoRows when the query yields no rows. On any error the
// returned *T is nil.
func (s *Selector[T]) Get(ctx context.Context) (*T, error) {
	q, err := s.Build()
	if err != nil {
		return nil, err
	}
	db := s.db.db
	rows, err := db.QueryContext(ctx, q.SQL, q.Args...)
	if err != nil {
		return nil, err
	}
	// Return the connection to the pool on every exit path.
	defer rows.Close()
	if !rows.Next() {
		// Next also reports false on iteration errors; surface those rather
		// than masking them as "no rows".
		if err := rows.Err(); err != nil {
			return nil, err
		}
		return nil, ErrNoRows
	}
	tp := new(T)
	val := s.db.creator(s.model, tp)
	if err := val.SetColumns(rows); err != nil {
		return nil, err
	}
	return tp, nil
}
// db.go
// DB wraps a *sql.DB together with the model registry and the Creator used
// to map result rows onto entities.
type DB struct {
	r       model.Registry
	db      *sql.DB
	creator valuer.Creator
}

// DBUseReflect selects the reflection-based result-set valuer.
func DBUseReflect() DBOption {
	return func(d *DB) { d.creator = valuer.NewReflectValue }
}

// DBUseUnsafe selects the offset-based (unsafe) result-set valuer, which is
// also what OpenDB uses by default.
func DBUseUnsafe() DBOption {
	return func(d *DB) { d.creator = valuer.NewUnsafeValue }
}

// Open opens a database through database/sql and wraps it with OpenDB.
func Open(driver string, dataSourceName string, opts ...DBOption) (*DB, error) {
	sqlDB, err := sql.Open(driver, dataSourceName)
	if err != nil {
		return nil, err
	}
	return OpenDB(sqlDB, opts...)
}

// OpenDB wraps an existing *sql.DB. It starts with a fresh registry and the
// unsafe valuer, then applies opts in order. The error is currently always
// nil.
func OpenDB(db *sql.DB, opts ...DBOption) (*DB, error) {
	res := &DB{
		r:       model.NewRegistry(),
		db:      db,
		creator: valuer.NewUnsafeValue,
	}
	for _, apply := range opts {
		apply(res)
	}
	return res, nil
}
// value.go
package valuer
import (
"database/sql"
"my_framework/orm/model"
)
// Value writes the current row of a *sql.Rows into the entity it was
// created for (see Creator).
type Value interface {
	// SetColumns scans the row that rows is currently positioned on; the
	// caller must already have called rows.Next.
	SetColumns(rows *sql.Rows) error
}

// ValuerV1 is an alternative design in which the target entity is passed on
// each call instead of being bound at construction time.
type ValuerV1 interface {
	SetColumns(entity any, rows *sql.Rows) error
}

// Creator binds a Value to entity, a pointer to the struct described by
// model.
type Creator func(model *model.Model, entity any) Value
// unsafe.go
package valuer
import (
"database/sql"
"my_framework/orm/internal/errs"
"my_framework/orm/model"
"reflect"
"unsafe"
)
// unsafeValue implements Value by scanning each column straight into the
// memory of the matching field, addressed as struct start + field offset.
type unsafeValue struct {
	model *model.Model
	// val is the target entity, a *T.
	val any
}

// Compile-time check that NewUnsafeValue matches Creator: the build breaks
// here if Creator's signature changes.
var _ Creator = NewUnsafeValue

// SetColumns scans the current row into u.val. Every result column must map
// to a field in u.model.ColumnMap; otherwise errs.NewErrUnknownColumn is
// returned before anything is scanned.
func (u unsafeValue) SetColumns(rows *sql.Rows) error {
	cs, err := rows.Columns()
	if err != nil {
		return err
	}
	// Start address of the entity.
	address := reflect.ValueOf(u.val).UnsafePointer()
	// One destination per column, sized up front.
	vals := make([]any, 0, len(cs))
	for _, c := range cs {
		fd, ok := u.model.ColumnMap[c]
		if !ok {
			return errs.NewErrUnknownColumn(c)
		}
		// A *fd.Typ pointing at the field inside the entity, so Scan writes
		// directly into it.
		fdAddress := unsafe.Pointer(uintptr(address) + fd.Offset)
		vals = append(vals, reflect.NewAt(fd.Typ, fdAddress).Interface())
	}
	return rows.Scan(vals...)
}

// NewUnsafeValue is the Creator for unsafeValue. val must be a non-nil
// pointer to the struct described by model.
func NewUnsafeValue(model *model.Model, val any) Value {
	return unsafeValue{
		model: model,
		val:   val,
	}
}
// reflect.go
package valuer
import (
"database/sql"
"my_framework/orm/internal/errs"
"my_framework/orm/model"
"reflect"
)
// reflectValue implements Value using only the reflect package. It scans
// into freshly allocated temporaries, then copies them into the entity.
type reflectValue struct {
	model *model.Model
	// val is a pointer to T.
	val any
}

// Compile-time check that NewReflectValue matches Creator.
var _ Creator = NewReflectValue

// SetColumns scans the current row into r.val. Every result column must map
// to a field in r.model.ColumnMap, otherwise errs.NewErrUnknownColumn is
// returned. The entity is only modified after Scan succeeds, so a failed
// scan leaves it untouched.
func (r reflectValue) SetColumns(rows *sql.Rows) error {
	cs, err := rows.Columns()
	if err != nil {
		return err
	}
	vals := make([]any, 0, len(cs))
	valElems := make([]reflect.Value, 0, len(cs))
	// Remember each column's field so the copy loop below needs no second
	// ColumnMap lookup (whose failure branch could never be reached).
	fds := make([]*model.Field, 0, len(cs))
	for _, c := range cs {
		fd, ok := r.model.ColumnMap[c]
		if !ok {
			return errs.NewErrUnknownColumn(c)
		}
		// reflect.New yields a *fd.Typ: Scan needs the pointer, the copy
		// loop needs the pointee.
		val := reflect.New(fd.Typ)
		vals = append(vals, val.Interface())
		valElems = append(valElems, val.Elem())
		fds = append(fds, fd)
	}
	if err = rows.Scan(vals...); err != nil {
		return err
	}
	entity := reflect.ValueOf(r.val).Elem()
	for i, fd := range fds {
		entity.FieldByName(fd.GoName).Set(valElems[i])
	}
	return nil
}

// NewReflectValue is the Creator for reflectValue. val must be a non-nil
// pointer to the struct described by model.
func NewReflectValue(model *model.Model, val any) Value {
	return reflectValue{
		model: model,
		val:   val,
	}
}
// reflect_test.go
package valuer
import (
"database/sql"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"my_framework/orm/model"
"testing"
)
// TestModel is the entity fixture for the valuer tests; the mocked rows use
// the columns id, first_name, age and last_name for these fields.
type TestModel struct {
Id int64
FirstName string
Age int8
LastName *sql.NullString
}
func TestReflectValue_SetColumns(t *testing.T) {
testSetColumns(t, NewReflectValue)
}
// testSetColumns is the shared table test for every Creator. Each case feeds
// one mocked row through the Value built by create, then compares the
// populated entity with wantEntity.
func testSetColumns(t *testing.T, create Creator) {
	testCases := []struct {
		name string
		// entity is a pointer to the struct being filled.
		entity     any
		rows       func() *sqlmock.Rows
		wantErr    error
		wantEntity any
	}{
		{
			name:   "set columns",
			entity: &TestModel{},
			rows: func() *sqlmock.Rows {
				rows := sqlmock.NewRows([]string{"id", "first_name", "age", "last_name"})
				rows.AddRow("1", "Tom", "18", "Jerry")
				return rows
			},
			wantEntity: &TestModel{
				Id:        1,
				FirstName: "Tom",
				Age:       18,
				LastName:  &sql.NullString{Valid: true, String: "Jerry"},
			},
		},
		{
			// Only some of the columns are returned.
			name:   "partial columns",
			entity: &TestModel{},
			rows: func() *sqlmock.Rows {
				rows := sqlmock.NewRows([]string{"id", "first_name"})
				rows.AddRow("1", "Tom")
				return rows
			},
			wantEntity: &TestModel{
				Id:        1,
				FirstName: "Tom",
			},
		},
		{
			// Columns arrive in a different order from the struct fields.
			name:   "order",
			entity: &TestModel{},
			rows: func() *sqlmock.Rows {
				rows := sqlmock.NewRows([]string{"id", "last_name", "first_name", "age"})
				rows.AddRow("1", "Jerry", "Tom", "18")
				return rows
			},
			wantEntity: &TestModel{
				Id:        1,
				FirstName: "Tom",
				Age:       18,
				LastName:  &sql.NullString{Valid: true, String: "Jerry"},
			},
		},
	}
	r := model.NewRegistry()
	mockDB, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer mockDB.Close()
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			mock.ExpectQuery("SELECT XX").WillReturnRows(tc.rows())
			rows, err := mockDB.Query("SELECT XX")
			require.NoError(t, err)
			defer rows.Close()
			require.True(t, rows.Next())
			// m, not model, so the package name is not shadowed.
			m, err := r.Get(tc.entity)
			require.NoError(t, err)
			val := create(m, tc.entity)
			// Assert on the error SetColumns actually returned; previously it
			// was discarded and the stale err from r.Get was checked instead.
			err = val.SetColumns(rows)
			assert.Equal(t, tc.wantErr, err)
			if err != nil {
				return
			}
			// The entity must now hold the scanned data.
			assert.Equal(t, tc.wantEntity, tc.entity)
		})
	}
}
// unsafe_test.go
package valuer
import "testing"
func TestUnsafeValue_SetColumns(t *testing.T) {
testSetColumns(t, NewUnsafeValue)
}
// value_test.go
package valuer
import (
"database/sql/driver"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
"my_framework/orm/model"
"testing"
)
// BenchmarkSetColumns compares the reflect and unsafe valuers on the same
// mocked result set:
//
//	go test -bench=BenchmarkSetColumns -benchtime=10000x -benchmem
//
// Correctness is covered by the unit tests; the benchmark measures the cost
// of creating a Value and calling SetColumns once per row.
func BenchmarkSetColumns(b *testing.B) {
	fn := func(b *testing.B, creator Creator) {
		mockDB, mock, err := sqlmock.New()
		require.NoError(b, err)
		defer mockDB.Close()
		// Each iteration consumes one row, so prepare b.N rows.
		mockRows := sqlmock.NewRows([]string{"id", "first_name", "age", "last_name"})
		row := []driver.Value{"1", "Tom", "18", "Jerry"}
		for i := 0; i < b.N; i++ {
			mockRows.AddRow(row...)
		}
		mock.ExpectQuery("SELECT XX").WillReturnRows(mockRows)
		rows, err := mockDB.Query("SELECT XX")
		// Previously this error was overwritten by r.Get before being checked.
		require.NoError(b, err)
		defer rows.Close()
		r := model.NewRegistry()
		m, err := r.Get(&TestModel{})
		require.NoError(b, err)
		// Exclude the setup above from the measurement.
		b.ResetTimer()
		for i := 0; i < b.N; i++ {
			rows.Next()
			val := creator(m, &TestModel{})
			if err := val.SetColumns(rows); err != nil {
				b.Fatal(err)
			}
		}
	}
	b.Run("reflect", func(b *testing.B) {
		fn(b, NewReflectValue)
	})
	b.Run("unsafe", func(b *testing.B) {
		fn(b, NewUnsafeValue)
	})
}
4. 结果集处理:总结与面试要点
// unsafe.go
package valuer
import (
"database/sql"
"my_framework/orm/internal/errs"
"my_framework/orm/model"
"reflect"
"unsafe"
)
// unsafeValue implements Value by scanning columns straight into the fields
// of the target entity. The entity's start address is captured once, in
// NewUnsafeValue.
type unsafeValue struct {
	model *model.Model
	// address is the start address of the target struct.
	address unsafe.Pointer
}

// Compile-time check that NewUnsafeValue matches Creator: the build breaks
// here if Creator's signature changes.
var _ Creator = NewUnsafeValue

// SetColumns scans the current row into the entity at u.address. Every
// result column must map to a field in u.model.ColumnMap; otherwise
// errs.NewErrUnknownColumn is returned before anything is scanned.
func (u unsafeValue) SetColumns(rows *sql.Rows) error {
	cs, err := rows.Columns()
	if err != nil {
		return err
	}
	// One destination per column, sized up front.
	vals := make([]any, 0, len(cs))
	for _, c := range cs {
		fd, ok := u.model.ColumnMap[c]
		if !ok {
			return errs.NewErrUnknownColumn(c)
		}
		// A *fd.Typ pointing at start + offset, so Scan writes directly into
		// the field.
		fdAddress := unsafe.Pointer(uintptr(u.address) + fd.Offset)
		vals = append(vals, reflect.NewAt(fd.Typ, fdAddress).Interface())
	}
	return rows.Scan(vals...)
}

// NewUnsafeValue is the Creator for unsafeValue. val must be a non-nil
// pointer to the struct described by model; its address is captured here.
func NewUnsafeValue(model *model.Model, val any) Value {
	address := reflect.ValueOf(val).UnsafePointer()
	return unsafeValue{
		model:   model,
		address: address,
	}
}
SELECT进阶
5. SELECT 进阶:指定简单列
// select_test.go
// TestSelector_Select checks the SELECT list produced for explicitly selected
// columns, and the error for an unknown field.
// NOTE(review): assumes a "mysql" driver is registered (e.g. by a blank
// import). sql.Open does not connect, so no server is needed just to build
// queries.
func TestSelector_Select(t *testing.T) {
db, err := Open("mysql", "root:123456@tcp(127.0.0.1:3307)/mybatis")
require.NoError(t, err)
//db := memoryDB(t)
tests := []struct {
name string
s QueryBuilder
wantQuery *Query
wantErr error
}{
{
name: "multiple columns",
//s: NewSelector[TestModel](db).Select("first_name", "last_name"),
s: NewSelector[TestModel](db).Select(C("FirstName"), C("LastName")),
wantQuery: &Query{
SQL: "SELECT `first_name`,`last_name` FROM `test_model`;",
},
},
{
// Columns are looked up by Go field name, so an unknown name fails.
name: "invalid columns",
//s: NewSelector[TestModel](db).Select("first_name", "last_name"),
s: NewSelector[TestModel](db).Select(C("Invalid")),
wantErr: errs.NewErrUnknownField("Invalid"),
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
q, err := tc.s.Build()
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
assert.Equal(t, tc.wantQuery, q)
})
}
}
// select.go
// Selectable 是一个标记接口
// 代表的是查找的列或者聚合函数等
type Selectable interface {
selectable()
}
type Selector[T any] struct {
builder
table string
where []Predicate
db *DB
columns []Selectable
}
func (s *Selector[T]) Select(cols ...Selectable) *Selector[T] {
s.columns = cols
return s
}
// Build assembles the SELECT statement (parts elided with "// ..."). Columns
// given via Select are emitted comma-separated; without them "*" is used.
func (s *Selector[T]) Build() (*Query, error) {
	// ...
	sb := s.sb
	sb.WriteString("SELECT ")
	if len(s.columns) > 0 {
		// The leftover debug println(i) that used to be here is removed: it
		// wrote to stderr on every Build.
		for i, col := range s.columns {
			// The separator goes before every column except the first.
			if i > 0 {
				sb.WriteByte(',')
			}
			// NOTE(review): this version only supports Column; any other
			// Selectable makes the type assertion panic.
			err = s.buildColumn(col.(Column))
			if err != nil {
				return nil, err
			}
		}
	} else {
		sb.WriteString("*")
	}
	sb.WriteString(` FROM `)
	// ...
}
// build.go
// buildExpression writes expr into s.sb, recursing through Predicate trees.
// - Operands that are themselves Predicates are wrapped in parentheses.
// - A Column is written via buildColumn.
// - A value becomes a "?" placeholder, with its value appended to the arguments.
// - A nil expression writes nothing.
// Any other Expression yields errs.NewErrUnsupportedExpression.
func (s *builder) buildExpression(expr Expression) error {
switch exp := expr.(type) {
case nil:
case Predicate:
// Handle the predicate: left operand, operator, right operand.
_, ok := exp.left.(Predicate)
if ok {
s.sb.WriteByte('(')
}
if err := s.buildExpression(exp.left); err != nil {
return err
}
if ok {
s.sb.WriteByte(')')
}
s.sb.WriteByte(' ')
s.sb.WriteString(exp.op.String())
s.sb.WriteByte(' ')
_, ok = exp.right.(Predicate)
if ok {
s.sb.WriteByte('(')
}
if err := s.buildExpression(exp.right); err != nil {
return err
}
if ok {
s.sb.WriteByte(')')
}
case Column:
return s.buildColumn(exp)
case value:
s.sb.WriteByte('?')
s.addArg(exp.val)
default:
return errs.NewErrUnsupportedExpression(expr)
}
return nil
}
// buildColumn writes the backquoted column name of field c.name. It returns
// errs.NewErrUnknownField when the model has no such field.
func (s *builder) buildColumn(c Column) error {
	fd, ok := s.model.FieldMap[c.name]
	if !ok {
		return errs.NewErrUnknownField(c.name)
	}
	s.sb.WriteString("`" + fd.ColName + "`")
	return nil
}
// column.go
package orm
// Column refers to an entity field by its Go field name.
type Column struct {
	name string
}

// C returns the Column for the Go field name, e.g. C("FirstName").
func C(name string) Column {
	return Column{name: name}
}

// Eq builds the predicate `column = arg`, e.g. C("id").Eq(12) or
// sub.C("id").Eq(12).
func (c Column) Eq(arg any) Predicate {
	return Predicate{left: c, op: opEq, right: value{val: arg}}
}

// Lt builds the predicate `column < arg`.
func (c Column) Lt(arg any) Predicate {
	return Predicate{left: c, op: opLt, right: value{val: arg}}
}

// expr marks Column as an Expression.
func (Column) expr() {}

// selectable marks Column as a Selectable.
func (Column) selectable() {}
6. SELECT 进阶:指定聚合函数
// aggregate.go
package orm
// Aggregate is an aggregate function applied to one field, e.g.
// AVG("age"), SUM("age"), COUNT("age"), MAX("age"), MIN("age").
// arg holds the Go field name (Avg("Age") renders as AVG(`age`)).
type Aggregate struct {
	fn  string
	arg string
}

// selectable marks Aggregate as a Selectable.
func (Aggregate) selectable() {}

// newAggregate builds an Aggregate calling SQL function fn on field col.
func newAggregate(fn, col string) Aggregate {
	return Aggregate{fn: fn, arg: col}
}

// Avg returns AVG(col).
func Avg(col string) Aggregate { return newAggregate("AVG", col) }

// Sum returns SUM(col).
func Sum(col string) Aggregate { return newAggregate("SUM", col) }

// Count returns COUNT(col).
func Count(col string) Aggregate { return newAggregate("COUNT", col) }

// Max returns MAX(col).
func Max(col string) Aggregate { return newAggregate("MAX", col) }

// Min returns MIN(col).
func Min(col string) Aggregate { return newAggregate("MIN", col) }
// select_test.go
// TestSelector_Select checks the SELECT list produced for aggregate
// functions, alone and combined, and the error for an unknown field.
// NOTE(review): assumes a "mysql" driver is registered; sql.Open does not
// connect.
func TestSelector_Select(t *testing.T) {
	db, err := Open("mysql", "root:123456@tcp(127.0.0.1:3307)/mybatis")
	require.NoError(t, err)
	//db := memoryDB(t)
	tests := []struct {
		name      string
		s         QueryBuilder
		wantQuery *Query
		wantErr   error
	}{
		// ...
		{
			name: "avg",
			s:    NewSelector[TestModel](db).Select(Avg("Age")),
			wantQuery: &Query{
				SQL: "SELECT AVG(`age`) FROM `test_model`;",
			},
		},
		{
			name: "sum",
			s:    NewSelector[TestModel](db).Select(Sum("Age")),
			wantQuery: &Query{
				SQL: "SELECT SUM(`age`) FROM `test_model`;",
			},
		},
		{
			name:    "sum invalid",
			s:       NewSelector[TestModel](db).Select(Sum("Invalid")),
			wantErr: errs.NewErrUnknownField("Invalid"),
		},
		{
			// This case used to be named "sum" too, which made the two
			// subtests indistinguishable in test output.
			name: "multiple aggregates",
			s:    NewSelector[TestModel](db).Select(Sum("Age"), Avg("Age")),
			wantQuery: &Query{
				SQL: "SELECT SUM(`age`),AVG(`age`) FROM `test_model`;",
			},
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			q, err := tc.s.Build()
			assert.Equal(t, tc.wantErr, err)
			if err != nil {
				return
			}
			assert.Equal(t, tc.wantQuery, q)
		})
	}
}
// select.go
// Build assembles the SELECT statement (parts elided with "// ..."); the
// SELECT list itself is delegated to buildColumns.
func (s *Selector[T]) Build() (*Query, error) {
// ...
sb := s.sb
sb.WriteString("SELECT ")
if err = s.buildColumns(); err != nil {
return nil, err
}
sb.WriteString(` FROM `)
// ...
}
// buildColumns writes the SELECT list: "*" when nothing was selected,
// otherwise each Column or Aggregate separated by commas.
func (s *Selector[T]) buildColumns() error {
	if len(s.columns) == 0 {
		s.sb.WriteString("*")
		return nil
	}
	for i, col := range s.columns {
		// The separator goes before every column except the first.
		if i > 0 {
			s.sb.WriteByte(',')
		}
		var err error
		switch c := col.(type) {
		case Column:
			err = s.buildColumn(c.name)
		case Aggregate:
			// FN(`column`)
			s.sb.WriteString(c.fn)
			s.sb.WriteByte('(')
			if err = s.buildColumn(c.arg); err == nil {
				s.sb.WriteByte(')')
			}
		}
		// NOTE(review): any other Selectable is silently skipped, which
		// leaves a dangling comma in the SQL; consider returning an error.
		if err != nil {
			return err
		}
	}
	return nil
}
7. SELECT 进阶:原生表达式
// expression.go
package orm
// Expression is a marker interface for anything usable as an expression,
// for example inside WHERE.
type Expression interface {
	expr()
}

// RawExpr is a raw SQL fragment together with its positional arguments.
type RawExpr struct {
	raw  string
	args []any
}

// selectable marks RawExpr as a Selectable.
func (RawExpr) selectable() {}

// expr marks RawExpr as an Expression.
func (RawExpr) expr() {}

// Raw creates a RawExpr from a SQL fragment and its arguments.
func Raw(expr string, args ...any) RawExpr {
	return RawExpr{raw: expr, args: args}
}

// AsPredicate wraps r in a Predicate that has only a left side, so the raw
// fragment can be passed where a Predicate is expected (e.g. Where).
func (r RawExpr) AsPredicate() Predicate {
	return Predicate{left: r}
}
8. SELECT 进阶:别名
// builder.go
// buildColumn writes the backquoted column of field c.name, followed by
// " AS `alias`" when c carries an alias. It returns errs.NewErrUnknownField
// when the model has no such field.
func (s *builder) buildColumn(c Column) error {
	fd, ok := s.model.FieldMap[c.name]
	if !ok {
		return errs.NewErrUnknownField(c.name)
	}
	s.sb.WriteString("`" + fd.ColName + "`")
	if c.alias == "" {
		return nil
	}
	s.sb.WriteString(" AS `" + c.alias + "`")
	return nil
}
// column.go
// Column refers to an entity field by Go field name, optionally carrying an
// alias for the SELECT list.
type Column struct {
	name  string
	alias string
}

// As returns a copy of c with the given alias. c itself is not modified
// (immutable design).
func (c Column) As(alias string) Column {
	c.alias = alias
	return c
}
// select_test.go
// TestSelector_Select checks column aliases: they are written in the SELECT
// list but must be dropped when the column appears in WHERE.
// NOTE(review): the "column alias agg" case needs a buildColumns that also
// writes Aggregate aliases; the buildColumns version shown earlier does not.
func TestSelector_Select(t *testing.T) {
db, err := Open("mysql", "root:123456@tcp(127.0.0.1:3307)/mybatis")
require.NoError(t, err)
//db := memoryDB(t)
tests := []struct {
name string
s QueryBuilder
wantQuery *Query
wantErr error
}{
// ...
{
name: "column alias",
s: NewSelector[TestModel](db).Select(C("FirstName").As("my_name")),
wantQuery: &Query{
SQL: "SELECT `first_name` AS `my_name` FROM `test_model`;",
},
},
{
name: "column alias agg",
s: NewSelector[TestModel](db).Select(Avg("FirstName").As("my_name")),
wantQuery: &Query{
SQL: "SELECT AVG(`first_name`) AS `my_name` FROM `test_model`;",
},
},
{
// Aliases are not allowed in WHERE, so "my_id" must not appear.
name: "column alias in where",
s: NewSelector[TestModel](db).Where(C("Id").As("my_id").Eq(18)),
wantQuery: &Query{
SQL: "SELECT * FROM `test_model` WHERE `id` = ?;",
Args: []any{18},
},
},
}
for _, tc := range tests {
// ...
}
}
// aggregate.go
// Aggregate is an aggregate function over one field (arg is the Go field
// name), optionally aliased in the SELECT list.
type Aggregate struct {
	fn    string
	arg   string
	alias string
}

// As returns a copy of a with the given alias; the receiver is unchanged.
func (a Aggregate) As(alias string) Aggregate {
	a.alias = alias
	return a
}
INSERT
9. INSERT:INSERT 语句概览
10. INSERT:最简实现
// model.go
// Model is the metadata of one entity type: its table name and its fields,
// indexed three ways.
type Model struct {
TableName string
// Fields keeps declaration order, so columns can be iterated in a
// deterministic order (map iteration order is random).
Fields []*Field
// Go field name -> field metadata.
FieldMap map[string]*Field
// Column name -> field metadata.
ColumnMap map[string]*Field
}
// Registry builds the *Model for entity (parts elided with "// ...") and now
// also stores the ordered Fields slice alongside the two lookup maps.
// NOTE(review): the earlier version of this method is named Register; this
// name looks like a typo, so confirm which one callers use.
func (r *registry) Registry(entity any, opts ...ModelOpt) (*Model, error) {
// ...
res := &Model{
TableName: tableName,
FieldMap: fieldMap,
ColumnMap: columnMap,
Fields: fields,
}
// ...
return res, nil
}
// insert
package orm
import (
"my_framework/orm/internal/errs"
"reflect"
"strings"
)
// Inserter builds an INSERT statement for entities of type T.
type Inserter[T any] struct {
	values []*T
	db     *DB
}

// NewInserter returns an Inserter that resolves model metadata through db.
func NewInserter[T any](db *DB) *Inserter[T] {
	return &Inserter[T]{db: db}
}

// Values sets the rows to insert, replacing any values set before.
func (i *Inserter[T]) Values(vals ...*T) *Inserter[T] {
	i.values = vals
	return i
}
// Build renders `INSERT INTO `table`(`c1`,...) VALUES (?,...),(?,...);`,
// with one placeholder row per entity and the arguments in row-major order.
// Columns come from the model's ordered Fields slice, because iterating
// FieldMap/ColumnMap would give a random order. It returns
// errs.ErrInsertZeroRow when no values were set.
func (i Inserter[T]) Build() (*Query, error) {
	if len(i.values) == 0 {
		return nil, errs.ErrInsertZeroRow
	}
	// Model metadata, taken from the first entity.
	m, err := i.db.r.Get(i.values[0])
	if err != nil {
		return nil, err
	}
	var sb strings.Builder
	sb.WriteString("INSERT INTO ")
	sb.WriteByte('`')
	sb.WriteString(m.TableName)
	sb.WriteByte('`')
	// List the columns explicitly so the placeholder order is unambiguous.
	sb.WriteByte('(')
	for idx, field := range m.Fields {
		if idx > 0 {
			sb.WriteByte(',')
		}
		sb.WriteByte('`')
		sb.WriteString(field.ColName)
		sb.WriteByte('`')
	}
	sb.WriteByte(')')
	sb.WriteString(" VALUES ")
	args := make([]any, 0, len(i.values)*len(m.Fields))
	for j, val := range i.values {
		if j > 0 {
			sb.WriteByte(',')
		}
		sb.WriteByte('(')
		// Resolve the struct value once per row, not once per field.
		rowVal := reflect.ValueOf(val).Elem()
		for idx, field := range m.Fields {
			if idx > 0 {
				sb.WriteByte(',')
			}
			sb.WriteByte('?')
			args = append(args, rowVal.FieldByName(field.GoName).Interface())
		}
		sb.WriteByte(')')
	}
	sb.WriteByte(';')
	return &Query{
		SQL:  sb.String(),
		Args: args,
	}, nil
}
// insert_test.go
package orm
import (
"database/sql"
"github.com/stretchr/testify/assert"
"my_framework/orm/internal/errs"
"testing"
)
// TestInserter_Build checks INSERT generation for one row, several rows and
// the error for an empty value list.
func TestInserter_Build(t *testing.T) {
	db, err := Open("mysql", "root:123456@tcp(127.0.0.1:3307)/mybatis")
	// Stop here: continuing with a nil db would only panic later. A bare
	// assert.NoError would let the test keep running.
	if err != nil {
		t.Fatalf("open db: %v", err)
	}
	tests := []struct {
		name      string
		i         QueryBuilder
		wantErr   error
		wantQuery *Query
	}{
		{
			// Insert one row.
			name: "single insert",
			i: NewInserter[TestModel](db).Values(&TestModel{
				Id: 12, FirstName: "Tom", Age: 18, LastName: &sql.NullString{String: "Jerry", Valid: true},
			}),
			wantQuery: &Query{
				SQL:  "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES (?,?,?,?);",
				Args: []any{int64(12), "Tom", int8(18), &sql.NullString{String: "Jerry", Valid: true}},
			},
		},
		{
			// Insert several rows.
			name: "multiple insert",
			i: NewInserter[TestModel](db).Values(
				&TestModel{
					Id: 12, FirstName: "Tom", Age: 18, LastName: &sql.NullString{String: "Jerry", Valid: true},
				},
				&TestModel{
					Id: 13, FirstName: "Dummy", Age: 19, LastName: &sql.NullString{String: "Shan", Valid: true},
				},
			),
			wantQuery: &Query{
				SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES (?,?,?,?),(?,?,?,?);",
				Args: []any{
					int64(12), "Tom", int8(18), &sql.NullString{String: "Jerry", Valid: true},
					int64(13), "Dummy", int8(19), &sql.NullString{String: "Shan", Valid: true},
				},
			},
		},
		{
			// No rows at all: Build must refuse.
			name:    "no row insert",
			i:       NewInserter[TestModel](db).Values(),
			wantErr: errs.ErrInsertZeroRow,
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			q, err := tc.i.Build()
			assert.Equal(t, tc.wantErr, err)
			if err != nil {
				return
			}
			assert.Equal(t, tc.wantQuery, q)
		})
	}
}
11. INSERT:指定列
// insert.go
// Columns restricts the INSERT to the given Go field names, in that order.
// Without a call to Columns, Build uses every field of the model.
func (i *Inserter[T]) Columns(cols ...string) *Inserter[T] {
	i.columns = cols
	return i
}
// Build renders the INSERT statement (parts elided with "// ..."). When
// Columns was called, only those fields are inserted, in the order given,
// and each name is validated against the model; otherwise all model fields
// are used.
func (i Inserter[T]) Build() (*Query, error) {
// ...
sb.WriteByte('(')
fields := m.Fields
// The user specified the columns.
if len(i.columns) > 0 {
fields = make([]*model.Field, 0, len(i.columns))
for _, fd := range i.columns {
fdMeta, ok := m.FieldMap[fd]
// An unknown field name was passed in.
if !ok {
return nil, errs.NewErrUnknownField(fd)
}
fields = append(fields, fdMeta)
}
}
// Map iteration order is random, so the (x,x,x,x) column list must come
// from the ordered fields slice, not from FieldMap or ColumnMap.
for idx, field := range fields {
if idx > 0 {
sb.WriteByte(',')
}
sb.WriteByte('`')
sb.WriteString(field.ColName)
sb.WriteByte('`')
}
sb.WriteByte(')')
sb.WriteString(" VALUES ")
//sb.WriteByte('(')
args := make([]any, 0, len(i.values)*len(fields))
for j, val := range i.values {
if j > 0 {
sb.WriteByte(',')
}
sb.WriteByte('(')
for idx, field := range fields {
// ...
}
sb.WriteByte(')')
}
sb.WriteByte(';')
return &Query{
SQL: sb.String(),
Args: args,
}, err
}
// insert_test.go
// TestInserter_Build checks that Columns limits both the column list and the
// collected arguments to the chosen fields.
func TestInserter_Build(t *testing.T) {
db, err := Open("mysql", "root:123456@tcp(127.0.0.1:3307)/mybatis")
assert.NoError(t, err)
tests := []struct {
name string
i QueryBuilder
wantErr error
wantQuery *Query
}{
// ...
{
// Only some of the columns; Age is set on the entities but must not
// appear in the SQL or the arguments.
name: "partial columns",
i: NewInserter[TestModel](db).
Columns("Id","FirstName").
Values(
&TestModel{
Id: 12, FirstName: "Tom", Age: 18,
},
&TestModel{
Id: 13, FirstName: "Dummy", Age: 19,
},
),
wantQuery: &Query{
SQL: "INSERT INTO `test_model`(`id`,`first_name`) VALUES (?,?),(?,?);",
Args: []any{
int64(12), "Tom",
int64(13), "Dummy",
},
},
},
}
for _, tc := range tests {
// ...
}
}
12. INSERT:UPSERT API 定义
// insert.go
// OnDuplicateKeyBuilder is the intermediate step of
// Inserter.OnDuplicateKey().Update(...); it exists so the upsert API reads
// fluently.
type OnDuplicateKeyBuilder[T any] struct {
	i *Inserter[T]
}

// OnDuplicateKey holds the assignments of an upsert's update part.
type OnDuplicateKey[T any] struct {
	assigns []Assignable
}

// Assignable is a marker interface for items allowed in an upsert's update
// part (for example Assignment, or a Column meaning "use the inserted
// value").
type Assignable interface {
	assign()
}

// Inserter builds an INSERT statement for entities of type T.
type Inserter[T any] struct {
	values  []*T
	db      *DB
	columns []string
	// onDuplicateKey is nil unless OnDuplicateKey().Update(...) was called.
	onDuplicateKey *OnDuplicateKey[T]
}

// An earlier design took the assignments directly on the Inserter:
//
//func (i *Inserter[T]) OnDuplicateKey(assigns ...Assignable) *Inserter[T] {
//	i.onDuplicateKey = assigns
//	return i
//}

// Update records the upsert assignments on the owning Inserter and returns
// it so the chain can continue.
func (o *OnDuplicateKeyBuilder[T]) Update(assigns ...Assignable) *Inserter[T] {
	o.i.onDuplicateKey = &OnDuplicateKey[T]{assigns: assigns}
	return o.i
}

// OnDuplicateKey starts an upsert clause; complete it with Update.
func (i *Inserter[T]) OnDuplicateKey() *OnDuplicateKeyBuilder[T] {
	return &OnDuplicateKeyBuilder[T]{i: i}
}
13. INSERT:MySQL UPSERT 基本实现
// insert.go
// Build renders the INSERT statement (parts elided with "// ...") and, when
// an upsert was configured, appends MySQL's ON DUPLICATE KEY UPDATE clause:
// - An Assignment becomes `col`=? with its value appended after the row
//   arguments.
// - A Column becomes `col`=VALUES(`col`), i.e. it reuses the inserted value.
// Unknown fields and unsupported Assignable kinds are returned as errors.
func (i Inserter[T]) Build() (*Query, error) {
// ...
sb.WriteString(" VALUES ")
//sb.WriteByte('(')
args := make([]any, 0, len(i.values)*len(fields))
for j, val := range i.values {
// ...
}
// on duplicate
if i.onDuplicateKey != nil {
sb.WriteString(" ON DUPLICATE KEY UPDATE ")
for idx, assign := range i.onDuplicateKey.assigns {
if idx > 0 {
sb.WriteByte(',')
}
switch a := assign.(type) {
case Assignment:
fd, ok := m.FieldMap[a.col]
if !ok {
return nil, errs.NewErrUnknownField(a.col)
}
sb.WriteByte('`')
sb.WriteString(fd.ColName)
sb.WriteByte('`')
sb.WriteString("=?")
args = append(args, a.val)
case Column:
fd, ok := m.FieldMap[a.name]
if !ok {
return nil, errs.NewErrUnknownField(a.name)
}
sb.WriteByte('`')
sb.WriteString(fd.ColName)
sb.WriteByte('`')
sb.WriteString("=VALUES(")
sb.WriteByte('`')
sb.WriteString(fd.ColName)
sb.WriteByte('`')
sb.WriteByte(')')
default:
return nil, errs.NewErrUnsupportedAssignable(assign)
}
}
}
sb.WriteByte(';')
return &Query{
SQL: sb.String(),
Args: args,
}, err
}
// assignment.go
package orm
// Assignment is `col = val` in an upsert's update part. col is a Go field
// name; val is bound as a query argument.
type Assignment struct {
	col string
	val any
}

// assign marks Assignment as Assignable.
func (Assignment) assign() {}

// Assign creates an Assignment that sets field col to val.
func Assign(col string, val any) Assignment {
	return Assignment{col: col, val: val}
}
// insert_test.go
// TestInserter_Build checks the MySQL upsert clause. Assign(...) values are
// bound as extra arguments after the row values. C(...) reuses the inserted
// value via VALUES(`col`) and adds no arguments.
func TestInserter_Build(t *testing.T) {
db, err := Open("mysql", "root:123456@tcp(127.0.0.1:3307)/mybatis")
assert.NoError(t, err)
tests := []struct {
name string
i QueryBuilder
wantErr error
wantQuery *Query
}{
// ...
{
name: "upsert",
i: NewInserter[TestModel](db).Values(&TestModel{
Id: 12, FirstName: "Tom", Age: 18, LastName: &sql.NullString{String: "Jerry", Valid: true},
}).OnDuplicateKey().Update(
Assign("FirstName", "Deng"),
Assign("Age", 19),
),
wantQuery: &Query{
SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES (?,?,?,?)" +
" ON DUPLICATE KEY UPDATE `first_name`=?,`age`=?;",
Args: []any{
int64(12), "Tom", int8(18), &sql.NullString{String: "Jerry", Valid: true},
"Deng", 19,
},
},
},
{
name: "update insert column",
i: NewInserter[TestModel](db).Values(
&TestModel{
Id: 12, FirstName: "Tom", Age: 18, LastName: &sql.NullString{String: "Jerry", Valid: true},
},
&TestModel{
Id: 13, FirstName: "Dummy", Age: 19, LastName: &sql.NullString{String: "Shan", Valid: true},
},
).OnDuplicateKey().Update(C("FirstName"), C("Age")),
wantQuery: &Query{
SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES (?,?,?,?),(?,?,?,?)" +
" ON DUPLICATE KEY UPDATE `first_name`=VALUES(`first_name`),`age`=VALUES(`age`);",
Args: []any{
int64(12), "Tom", int8(18), &sql.NullString{String: "Jerry", Valid: true},
int64(13), "Dummy", int8(19), &sql.NullString{String: "Shan", Valid: true},
},
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
q, err := tc.i.Build()
assert.Equal(t, tc.wantErr, err)
if err != nil {
return
}
assert.Equal(t, tc.wantQuery, q)
})
}
}
14. INSERT:方言抽象 Dialect
// dialect.go
package orm
import (
"my_framework/orm/internal/errs"
"strings"
)
type Dialect interface {
// quoter 就是为了解决引号问题
// Mysql
quoter() byte
buildOnDuplicateKey(sb strings.Builder, odk *OnDuplicateKey) error
}
type standardSQL struct {
}
func (s standardSQL) quoter() byte {
//TODO implement me
panic("implement me")
}
func (s standardSQL) buildOnDuplicateKey(sb strings.Builder, odk *OnDuplicateKey) error {
//TODO implement me
panic("implement me")
}
type mysqlDialect struct {
standardSQL
}
func (m mysqlDialect) quoter() byte {
return '`'
}
func (m mysqlDialect) buildOnDuplicateKey(sb strings.Builder, odk *OnDuplicateKey) error {
// ...
}
type sqliteDialect struct {
standardSQL
}
type postgreDialect struct {
standardSQL
}
15. INSERT:builder 抽象与重构
// build.go
package orm
import (
"my_framework/orm/internal/errs"
"my_framework/orm/model"
"strings"
)
// builder holds the state shared by the statement builders: the SQL being
// written, the model of the target entity, the collected arguments, and the
// dialect together with its cached identifier quote character.
type builder struct {
	sb      *strings.Builder
	model   *model.Model
	args    []any
	dialect Dialect
	quoter  byte
}

// quote writes name surrounded by the dialect's quote character.
func (b *builder) quote(name string) {
	q := b.quoter
	b.sb.WriteByte(q)
	b.sb.WriteString(name)
	b.sb.WriteByte(q)
}
// buildExpression writes expr into s.sb, recursing through Predicate trees.
// - Operands that are themselves Predicates are wrapped in parentheses.
// - The operator is only written when set, so a Predicate wrapping a raw
//   expression (RawExpr.AsPredicate) renders just its left side.
// - Columns are written without their alias.
// - A value becomes a "?" placeholder with its value appended to the arguments.
// - A RawExpr is written verbatim inside parentheses, with its arguments appended.
// - A nil expression writes nothing.
// Any other Expression yields errs.NewErrUnsupportedExpression.
func (s *builder) buildExpression(expr Expression) error {
switch exp := expr.(type) {
case nil:
case Predicate:
// Handle the predicate: left operand, operator, right operand.
_, ok := exp.left.(Predicate)
if ok {
s.sb.WriteByte('(')
}
if err := s.buildExpression(exp.left); err != nil {
return err
}
if ok {
s.sb.WriteByte(')')
}
if exp.op != "" {
s.sb.WriteByte(' ')
s.sb.WriteString(exp.op.String())
s.sb.WriteByte(' ')
}
_, ok = exp.right.(Predicate)
if ok {
s.sb.WriteByte('(')
}
if err := s.buildExpression(exp.right); err != nil {
return err
}
if ok {
s.sb.WriteByte(')')
}
case Column:
// Prevent WHERE from emitting "AS alias"; exp is a local copy, so the
// caller's Column is unaffected.
exp.alias = ""
return s.buildColumn(exp)
case value:
s.sb.WriteByte('?')
s.addArg(exp.val)
case RawExpr:
s.sb.WriteByte('(')
s.sb.WriteString(exp.raw)
s.addArg(exp.args...)
s.sb.WriteByte(')')
default:
return errs.NewErrUnsupportedExpression(expr)
}
return nil
}
// ...
// insert.go
type Inserter[T any] struct {
builder
values []*T
db *DB
columns []string
//onDuplicateKey []Assignable
onDuplicateKey *OnDuplicateKey
}
// NewInserter creates an Inserter bound to db. Its builder uses db's dialect
// and caches that dialect's identifier quote character.
func NewInserter[T any](db *DB) *Inserter[T] {
	d := db.dialect
	b := builder{
		sb:      &strings.Builder{},
		dialect: d,
		quoter:  d.quoter(),
	}
	return &Inserter[T]{db: db, builder: b}
}
// Build assembles the INSERT statement for every row passed to Values:
//
//	INSERT INTO <table>(<col>,...) VALUES (?,...),(?,...)[ <upsert clause>];
//
// The column list comes from Columns(...) when it was set, and otherwise from
// the model's Fields slice. A slice is used rather than FieldMap/ColumnMap
// because map iteration order is random, and the column list must line up
// with every VALUES tuple. Row values are read through reflection by Go field
// name.
//
// Build starts from an empty SQL buffer on every call. This matters because a
// logging middleware builds the query before the handler builds it again;
// without the reset, the second call would append a second statement to the
// first.
//
// Build returns:
//   - errs.ErrInsertZeroRow when no rows were given;
//   - the registry's error when the model cannot be parsed;
//   - errs.NewErrUnknownField for a column that is not a model field;
//   - any error from the dialect's upsert builder.
func (i *Inserter[T]) Build() (*Query, error) {
	if len(i.values) == 0 {
		return nil, errs.ErrInsertZeroRow
	}
	// Reset the SQL buffer; otherwise it carries over between Build calls.
	i.sb = &strings.Builder{}
	i.sb.WriteString("INSERT INTO ")

	var err error
	i.model, err = i.db.r.Get(i.values[0])
	if err != nil {
		return nil, err
	}
	i.quote(i.model.TableName)

	// Resolve the columns to insert, in order.
	fields := i.model.Fields
	if len(i.columns) > 0 {
		fields = make([]*model.Field, 0, len(i.columns))
		for _, fd := range i.columns {
			fdMeta, ok := i.model.FieldMap[fd]
			if !ok {
				return nil, errs.NewErrUnknownField(fd)
			}
			fields = append(fields, fdMeta)
		}
	}

	i.sb.WriteByte('(')
	for idx, field := range fields {
		if idx > 0 {
			i.sb.WriteByte(',')
		}
		i.quote(field.ColName)
	}
	i.sb.WriteString(") VALUES ")

	// Emit one placeholder tuple per row and collect the arguments in the same order.
	i.args = make([]any, 0, len(i.values)*len(fields))
	for j, val := range i.values {
		if j > 0 {
			i.sb.WriteByte(',')
		}
		i.sb.WriteByte('(')
		row := reflect.ValueOf(val).Elem()
		for idx, field := range fields {
			if idx > 0 {
				i.sb.WriteByte(',')
			}
			i.sb.WriteByte('?')
			i.addArg(row.FieldByName(field.GoName).Interface())
		}
		i.sb.WriteByte(')')
	}

	if i.onDuplicateKey != nil {
		if err = i.dialect.buildOnDuplicateKey(&i.builder, i.onDuplicateKey); err != nil {
			return nil, err
		}
	}
	i.sb.WriteByte(';')
	return &Query{
		SQL:  i.sb.String(),
		Args: i.args,
	}, nil
}
// select.go
// Selector builds and runs SELECT statements whose rows are scanned into T.
type Selector[T any] struct {
builder
// table, when non-empty, is written verbatim after FROM instead of the
// model's table name.
table string
// where holds the predicates of the WHERE clause.
where []Predicate
db *DB
// columns is the SELECT list; the querylog test in these notes expects
// "SELECT *" when it is empty.
columns []Selectable
}
// NewSelector creates a Selector bound to db. Its builder uses db's dialect
// and caches that dialect's identifier quote character.
func NewSelector[T any](db *DB) *Selector[T] {
	d := db.dialect
	b := builder{
		sb:      &strings.Builder{},
		dialect: d,
		quoter:  d.quoter(),
	}
	return &Selector[T]{db: db, builder: b}
}
// Build assembles the SELECT statement:
//
//	SELECT <columns> FROM <table>[ WHERE <predicates>];
//
// The table defaults to the model's table name, quoted with the dialect's
// quote character. A table set explicitly on the selector is written verbatim,
// with no quoting.
//
// Build resets both the SQL buffer and the argument list first. Building the
// same Selector twice (a logging middleware, then the handler) therefore
// yields identical SQL and arguments, rather than duplicated WHERE arguments.
func (s *Selector[T]) Build() (*Query, error) {
	s.sb = &strings.Builder{} // pointer field: must be (re)initialised before use
	s.args = nil              // otherwise a second Build appends the WHERE arguments again
	var err error
	s.model, err = s.db.r.Get(new(T)) // new(T) yields a *T for the registry
	if err != nil {
		return nil, err
	}
	sb := s.sb
	sb.WriteString("SELECT ")
	if err = s.buildColumns(); err != nil {
		return nil, err
	}
	sb.WriteString(` FROM `)
	if s.table == "" {
		// Use the dialect's quote rather than a hard-coded backtick.
		s.quote(s.model.TableName)
	} else {
		sb.WriteString(s.table)
	}
	if len(s.where) > 0 {
		sb.WriteString(" WHERE ")
		if err = s.buildPredicates(s.where); err != nil {
			return nil, err
		}
	}
	sb.WriteByte(';')
	return &Query{
		SQL:  sb.String(),
		Args: s.args,
	}, nil
}
// db.go
// DB is a decorator around *sql.DB that carries the ORM configuration.
type DB struct {
// r resolves model metadata for entity types (r.Get).
r model.Registry
db *sql.DB
// creator builds the Value used to read and write entity fields.
creator valuer.Creator
// dialect decides how SQL is rendered for the target database.
dialect Dialect
}
// DBWithDialect returns an option that makes the DB render SQL for dialect d.
func DBWithDialect(d Dialect) DBOption {
	setDialect := func(db *DB) {
		db.dialect = d
	}
	return setDialect
}
// OpenDB wraps an already-opened *sql.DB. The defaults are a fresh model
// registry, the unsafe-pointer valuer and the MySQL dialect. opts are then
// applied in order, so a later option overrides an earlier one. The error is
// currently always nil.
func OpenDB(db *sql.DB, opts ...DBOption) (*DB, error) {
	res := &DB{db: db}
	res.r = model.NewRegistry()
	res.creator = valuer.NewUnsafeValue
	res.dialect = DialectMysql
	for _, apply := range opts {
		apply(res)
	}
	return res, nil
}
// dialect.go
// Built-in dialects, meant to be passed to DBWithDialect. OpenDB uses
// DialectMysql by default.
var (
DialectMysql Dialect = mysqlDialect{}
DialectSqlLite Dialect = sqliteDialect{}
DialectPostgre Dialect = postgreDialect{}
)
// buildOnDuplicateKey appends MySQL's upsert clause:
//
//	ON DUPLICATE KEY UPDATE `col`=?,`col2`=VALUES(`col2`)
//
// An Assignment sets the column to a bound argument. A bare Column copies the
// value the INSERT tried to write, via VALUES(col). Names are Go field names
// resolved through the model. An unknown name yields errs.NewErrUnknownField;
// any other Assignable yields errs.NewErrUnsupportedAssignable.
func (m mysqlDialect) buildOnDuplicateKey(b *builder, odk *OnDuplicateKey) error {
	// colOf maps a Go field name to its column name.
	colOf := func(goName string) (string, error) {
		fd, ok := b.model.FieldMap[goName]
		if !ok {
			return "", errs.NewErrUnknownField(goName)
		}
		return fd.ColName, nil
	}

	b.sb.WriteString(" ON DUPLICATE KEY UPDATE ")
	for idx, assign := range odk.assigns {
		if idx > 0 {
			b.sb.WriteByte(',')
		}
		switch a := assign.(type) {
		case Assignment:
			col, err := colOf(a.col)
			if err != nil {
				return err
			}
			b.quote(col)
			b.sb.WriteString("=?")
			b.addArg(a.val)
		case Column:
			col, err := colOf(a.name)
			if err != nil {
				return err
			}
			b.quote(col)
			b.sb.WriteString("=VALUES(")
			b.quote(col)
			b.sb.WriteByte(')')
		default:
			return errs.NewErrUnsupportedAssignable(assign)
		}
	}
	return nil
}
16. INSERT:SQLite UPSERT 实现、方言抽象局限性
// dialect.go
// quoter returns the backtick. This is not standard SQL, but SQLite accepts
// MySQL-style backtick-quoted identifiers for compatibility, so the same SQL
// shape as MySQL can be generated.
func (s sqliteDialect) quoter() byte {
	const backtick = '`'
	return backtick
}
// buildOnDuplicateKey appends SQLite's upsert clause:
//
//	ON CONFLICT(`c1`,`c2`) DO UPDATE SET `col`=?,`col2`=excluded.`col2`
//
// The conflict target lists odk.conflictColumns, which are Go field names
// rendered through buildColumn. When no conflict columns are given, the
// parenthesised target is omitted entirely. "ON CONFLICT()" is a syntax error,
// whereas a target-less final "ON CONFLICT DO UPDATE" is accepted by SQLite
// 3.35.0 and later.
//
// In the SET list, an Assignment binds its value as an argument, and a bare
// Column takes the value the INSERT attempted to write (the "excluded" row).
//
// It returns:
//   - whatever buildColumn returns for a bad conflict column;
//   - errs.NewErrUnknownField for an unknown SET target;
//   - errs.NewErrUnsupportedAssignable for any other Assignable.
func (s sqliteDialect) buildOnDuplicateKey(b *builder, odk *OnDuplicateKey) error {
	b.sb.WriteString(" ON CONFLICT")
	if len(odk.conflictColumns) > 0 {
		b.sb.WriteByte('(')
		for i, col := range odk.conflictColumns {
			if i > 0 {
				b.sb.WriteByte(',')
			}
			if err := b.buildColumn(Column{name: col}); err != nil {
				return err
			}
		}
		b.sb.WriteByte(')')
	}
	b.sb.WriteString(" DO UPDATE SET ")
	for idx, assign := range odk.assigns {
		if idx > 0 {
			b.sb.WriteByte(',')
		}
		switch a := assign.(type) {
		case Assignment:
			fd, ok := b.model.FieldMap[a.col]
			if !ok {
				return errs.NewErrUnknownField(a.col)
			}
			b.quote(fd.ColName)
			b.sb.WriteString("=?")
			b.addArg(a.val)
		case Column:
			fd, ok := b.model.FieldMap[a.name]
			if !ok {
				return errs.NewErrUnknownField(a.name)
			}
			b.quote(fd.ColName)
			b.sb.WriteString("=excluded.")
			b.quote(fd.ColName)
		default:
			return errs.NewErrUnsupportedAssignable(assign)
		}
	}
	return nil
}
// insert_test.go
// TestSqlite_upsert_Build checks the SQLite upsert SQL produced by Inserter.
//
// The expectations follow what sqliteDialect.buildOnDuplicateKey emits:
//   - "ON CONFLICT(<target>) DO UPDATE SET";
//   - an Assignment becomes `col`=? with its value appended to Args;
//   - a bare Column becomes `col`=excluded.`col`.
//
// The previous expectations omitted "DO UPDATE SET", mixed the two assignment
// forms, and expected MySQL's ON DUPLICATE KEY clause under the SQLite
// dialect.
func TestSqlite_upsert_Build(t *testing.T) {
	db, err := Open("sqlite",
		"file:test.db?cache=shared&mode=memory",
		DBWithDialect(DialectSqlLite))
	assert.NoError(t, err)
	tests := []struct {
		name      string
		i         QueryBuilder
		wantErr   error
		wantQuery *Query
	}{
		{
			name: "upsert",
			i: NewInserter[TestModel](db).Values(&TestModel{
				Id: 12, FirstName: "Tom", Age: 18, LastName: &sql.NullString{String: "Jerry", Valid: true},
			}).OnDuplicateKey().ConflictColumns("FirstName", "Age").Update(
				Assign("FirstName", "Deng"),
				Assign("Age", 19),
			),
			wantQuery: &Query{
				SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES (?,?,?,?)" +
					" ON CONFLICT(`first_name`,`age`) DO UPDATE SET `first_name`=?,`age`=?;",
				Args: []any{
					int64(12), "Tom", int8(18), &sql.NullString{String: "Jerry", Valid: true},
					"Deng", 19,
				},
			},
		},
		{
			name: "update insert column",
			i: NewInserter[TestModel](db).Values(
				&TestModel{
					Id: 12, FirstName: "Tom", Age: 18, LastName: &sql.NullString{String: "Jerry", Valid: true},
				},
				&TestModel{
					Id: 13, FirstName: "Dummy", Age: 19, LastName: &sql.NullString{String: "Shan", Valid: true},
				},
			).OnDuplicateKey().ConflictColumns("Id").Update(C("FirstName"), C("Age")),
			wantQuery: &Query{
				SQL: "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES (?,?,?,?),(?,?,?,?)" +
					" ON CONFLICT(`id`) DO UPDATE SET `first_name`=excluded.`first_name`,`age`=excluded.`age`;",
				Args: []any{
					int64(12), "Tom", int8(18), &sql.NullString{String: "Jerry", Valid: true},
					int64(13), "Dummy", int8(19), &sql.NullString{String: "Shan", Valid: true},
				},
			},
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			q, err := tc.i.Build()
			assert.Equal(t, tc.wantErr, err)
			if err != nil {
				return
			}
			assert.Equal(t, tc.wantQuery, q)
		})
	}
}
17. INSERT:INSERT 执行
// types.go
// Executor is implemented by builders whose statement returns no rows.
type Executor interface {
// Earlier design, kept for reference:
//Exec(ctx context.Context) (sql.Result, error)
// Exec builds and executes the statement. Any error is carried inside the
// returned Result rather than returned separately.
Exec(ctx context.Context) Result
}
// result.go
package orm
import "database/sql"
// Result carries the outcome of a write statement together with any error
// hit while building or executing it. This lets callers write
// Exec(ctx).RowsAffected() and still see the error.
type Result struct {
	err error
	res sql.Result
}

// LastInsertId returns the recorded error if there is one; otherwise it
// delegates to the underlying sql.Result.
func (r Result) LastInsertId() (int64, error) {
	if failure := r.Err(); failure != nil {
		return 0, failure
	}
	return r.res.LastInsertId()
}

// RowsAffected returns the recorded error if there is one; otherwise it
// delegates to the underlying sql.Result.
func (r Result) RowsAffected() (int64, error) {
	if failure := r.Err(); failure != nil {
		return 0, failure
	}
	return r.res.RowsAffected()
}

// Err reports the error recorded while building or executing the statement.
func (r Result) Err() error {
	return r.err
}
// insert_test.go
// TestInserter_Exec covers the three outcomes of Inserter.Exec:
//   - a build error, where no statement reaches the database;
//   - a driver error;
//   - a successful exec whose RowsAffected is read back through Result.
//
// The sqlmock expectations are registered while the table is built, in the
// order the cases run. They are verified at the end, so a case that silently
// skipped the database is caught.
func TestInserter_Exec(t *testing.T) {
	mockDB, mock, err := sqlmock.New()
	require.NoError(t, err)
	db, err := OpenDB(mockDB)
	require.NoError(t, err)
	tests := []struct {
		name string
		i    *Inserter[TestModel]
		//wantRes Result
		wantErr  error
		affected int64
	}{
		{
			name: "query error",
			i: func() *Inserter[TestModel] {
				return NewInserter[TestModel](db).Values(&TestModel{}).
					Columns("Invalid")
			}(),
			wantErr: errs.NewErrUnknownField("Invalid"),
		},
		{
			name: "db error",
			i: func() *Inserter[TestModel] {
				mock.ExpectExec("INSERT INTO .*").
					WillReturnError(errors.New("db error"))
				return NewInserter[TestModel](db).Values(&TestModel{})
			}(),
			wantErr: errors.New("db error"),
		},
		{
			name: "exec",
			i: func() *Inserter[TestModel] {
				res := driver.RowsAffected(1)
				mock.ExpectExec("INSERT INTO .*").
					WillReturnResult(res)
				return NewInserter[TestModel](db).Values(&TestModel{})
			}(),
			affected: 1,
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			res := tc.i.Exec(context.Background())
			affected, err := res.RowsAffected()
			assert.Equal(t, tc.wantErr, err)
			if err != nil {
				return
			}
			assert.Equal(t, tc.affected, affected)
		})
	}
	// Every ExpectExec registered above must have been consumed.
	assert.NoError(t, mock.ExpectationsWereMet())
}
// insert.go
// Exec builds the INSERT and executes it under ctx. Build errors and driver
// errors are both carried inside the returned Result (see Result.Err).
func (i *Inserter[T]) Exec(ctx context.Context) Result {
	q, err := i.Build()
	if err != nil {
		return Result{
			err: err,
		}
	}
	// ExecContext propagates ctx (deadline, cancellation) to the driver;
	// plain Exec would ignore it.
	res, err := i.db.db.ExecContext(ctx, q.SQL, q.Args...)
	return Result{
		err: err, res: res,
	}
}
18. INSERT:unsafe 读取字段、总结与面试要点
// insert.go
// BuildUnsafe assembles the same INSERT statement as Build. The difference is
// that row values are read through the DB's configured valuer (db.creator)
// instead of direct reflection.
//
// The column list comes from Columns(...) or, failing that, from the model's
// Fields slice. Map order is random, so FieldMap/ColumnMap cannot be iterated
// for it.
//
// The SQL buffer is reset first, so repeated calls yield one statement, not
// several concatenated ones.
//
// BuildUnsafe returns:
//   - errs.ErrInsertZeroRow when there are no rows;
//   - the registry's error;
//   - errs.NewErrUnknownField for an unknown column;
//   - any error from the valuer;
//   - any error from the dialect's upsert builder.
func (i *Inserter[T]) BuildUnsafe() (*Query, error) {
	if len(i.values) == 0 {
		return nil, errs.ErrInsertZeroRow
	}
	// Reset the SQL buffer; otherwise it carries over between builds.
	i.sb = &strings.Builder{}
	i.sb.WriteString("INSERT INTO ")

	var err error
	i.model, err = i.db.r.Get(i.values[0])
	if err != nil {
		return nil, err
	}
	i.quote(i.model.TableName)

	// Resolve the columns to insert, in order.
	fields := i.model.Fields
	if len(i.columns) > 0 {
		fields = make([]*model.Field, 0, len(i.columns))
		for _, fd := range i.columns {
			fdMeta, ok := i.model.FieldMap[fd]
			if !ok {
				return nil, errs.NewErrUnknownField(fd)
			}
			fields = append(fields, fdMeta)
		}
	}

	i.sb.WriteByte('(')
	for idx, field := range fields {
		if idx > 0 {
			i.sb.WriteByte(',')
		}
		i.quote(field.ColName)
	}
	i.sb.WriteString(") VALUES ")

	i.args = make([]any, 0, len(i.values)*len(fields))
	for j, val := range i.values {
		if j > 0 {
			i.sb.WriteByte(',')
		}
		i.sb.WriteByte('(')
		v := i.db.creator(i.model, val)
		for idx, field := range fields {
			if idx > 0 {
				i.sb.WriteByte(',')
			}
			i.sb.WriteByte('?')
			arg, fieldErr := v.Filed(field.GoName)
			if fieldErr != nil {
				return nil, fieldErr
			}
			i.addArg(arg)
		}
		i.sb.WriteByte(')')
	}

	if i.onDuplicateKey != nil {
		if err = i.dialect.buildOnDuplicateKey(&i.builder, i.onDuplicateKey); err != nil {
			return nil, err
		}
	}
	i.sb.WriteByte(';')
	return &Query{
		SQL:  i.sb.String(),
		Args: i.args,
	}, nil
}
// unsafe.go
// Filed returns the value of the field whose Go name is name, read straight
// from memory. The field's address is u.address (presumably the struct's base
// address) plus the field offset recorded in the model. Unknown names yield
// errs.NewErrUnknownField.
func (u unsafeValue) Filed(name string) (any, error) {
	meta, found := u.model.FieldMap[name]
	if !found {
		return nil, errs.NewErrUnknownField(name)
	}
	// Keep the uintptr arithmetic in one expression so the result is turned
	// back into an unsafe.Pointer immediately.
	ptr := unsafe.Pointer(uintptr(u.address) + meta.Offset)
	return reflect.NewAt(meta.Typ, ptr).Elem().Interface(), nil
}
// reflect.go
package valuer
import (
"database/sql"
"my_framework/orm/internal/errs"
"my_framework/orm/model"
"reflect"
)
// reflectValue is the reflection-based Value implementation.
type reflectValue struct {
model *model.Model
// val is the struct the user's *T points to: NewReflectValue stores
// reflect.ValueOf(ptr).Elem(), not the pointer itself.
//val any
val reflect.Value
}
// Filed returns the value of the struct field whose Go name is name.
//
// reflect.Value.FieldByName returns the zero Value for an unknown name, and
// calling Interface on that panics. Unknown names therefore yield
// errs.NewErrUnknownField, matching the unsafe implementation.
func (r reflectValue) Filed(name string) (any, error) {
	fd := r.val.FieldByName(name)
	if !fd.IsValid() {
		return nil, errs.NewErrUnknownField(name)
	}
	return fd.Interface(), nil
}
// Compile-time check that NewReflectValue has the Creator signature.
var _ Creator = NewReflectValue
// SetColumns copies the current row of rows into the struct wrapped by r.
// The row-scanning part is elided ("// ...") in these notes. It presumably
// yields cs (the result column names) and valElems (one scanned
// reflect.Value per column).
// Returns errs.NewErrUnknownColumn for a result column the model does not map.
func (r reflectValue) SetColumns(rows *sql.Rows) error {
	// ...
	// Copy the scanned values into the target struct.
	// r.val already is the struct (NewReflectValue stores
	// reflect.ValueOf(ptr).Elem()), so it must not be dereferenced again:
	// calling Elem on a struct Value panics.
	tpValue := r.val
	for i, c := range cs {
		fd, ok := r.model.ColumnMap[c]
		if !ok {
			return errs.NewErrUnknownColumn(c)
		}
		tpValue.FieldByName(fd.GoName).Set(valElems[i])
	}
	return nil
}
// NewReflectValue builds a reflection-based Value. val is expected to be a
// pointer (Elem is taken on it); the pointed-to struct is what the returned
// Value reads and writes.
func NewReflectValue(model *model.Model, val any) Value {
	target := reflect.ValueOf(val).Elem()
	return reflectValue{val: target, model: model}
}
// value.go
// Value reads and writes the fields of one entity instance.
type Value interface {
// Filed returns the value of the field with the given Go name
// (the existing method name is a misspelling of "Field").
Filed(name string) (any, error)
// SetColumns copies the current row of rows into the entity.
SetColumns(rows *sql.Rows) error
}
20. 第六周作业:丰富 SELECT 语句[选学]
21. 第六周 SELECT 作业讲解[选学]
13 第七周:ORM 框架之事务 API、AOP 方案与集成测试
事务api
1. 事务 API:不同框架设计分析、设计与实现
2. 事务 API:事务闭包 API、总结与面试要点
// orm/transaction.go
package orm
// ...
// Tx wraps *sql.Tx for the ORM.
type Tx struct {
tx *sql.Tx
db *DB
// done is set once Commit or Rollback has been called. BeginTxV2 uses it to
// decide whether a transaction found in the context may be reused
// (transaction propagation).
done bool
}
// ...
// Commit commits the transaction. The Tx is marked finished before the driver
// call, so BeginTxV2 will not hand it out again even if Commit fails.
func (t *Tx) Commit() error {
	t.done = true
	err := t.tx.Commit()
	return err
}
// Rollback aborts the transaction and marks the Tx finished.
func (t *Tx) Rollback() error {
	t.done = true
	err := t.tx.Rollback()
	return err
}
// RollbackIfNotCommit rolls back unless the transaction is already finished.
// It is meant to be deferred right after beginning a transaction.
func (t *Tx) RollbackIfNotCommit() error {
	t.done = true
	if err := t.tx.Rollback(); err != sql.ErrTxDone {
		return err
	}
	// sql.ErrTxDone: already committed or rolled back, so there is nothing to undo.
	return nil
}
// orm/db.go
package orm
// ...
// txKey is the context key under which BeginTxV2 stores the current *Tx.
type txKey struct{}
// BeginTxV2 implements transaction propagation. If ctx already carries an
// unfinished *Tx, that Tx is returned together with ctx unchanged. Otherwise
// a new transaction is started, stored in a derived context, and both are
// returned.
//
// On error the original ctx is returned rather than nil, so callers that keep
// using the returned context do not end up with a nil context.
func (db *DB) BeginTxV2(ctx context.Context, opts *sql.TxOptions) (context.Context, *Tx, error) {
	// Reuse a transaction that exists and has not been committed or rolled back.
	if tx, ok := ctx.Value(txKey{}).(*Tx); ok && !tx.done {
		return ctx, tx, nil
	}
	tx, err := db.BeginTx(ctx, opts)
	if err != nil {
		return ctx, nil, err
	}
	return context.WithValue(ctx, txKey{}, tx), tx, nil
}
// 必须要已有事务
//func (db *DB) BeginTxV3(ctx context.Context, opts *sql.TxOptions) ( *Tx, error) {
// val := ctx.Value(txKey{})
// tx, ok := val.(*Tx)
// if ok {
// return tx, nil
// }
// return nil,errors.New("没有事务")
//}
// DoTx runs fn inside a transaction.
//
// If fn returns nil, the transaction is committed, and DoTx returns the Commit
// error. If fn returns an error or panics, the transaction is rolled back. A
// failed rollback is reported through errs.NewErrFailedToRollbackTx, which
// carries fn's error, the rollback error and whether fn panicked. A successful
// rollback returns fn's error unchanged, so callers do not see a "rollback
// failed" error when nothing failed.
//
// A panic in fn is not recovered: it propagates after the rollback.
func (db *DB) DoTx(
	ctx context.Context,
	fn func(ctx context.Context, tx *Tx) error,
	opts *sql.TxOptions,
) (err error) {
	tx, err := db.BeginTx(ctx, opts)
	if err != nil {
		return err
	}
	// panicked stays true unless fn returns normally.
	panicked := true
	defer func() {
		if panicked || err != nil {
			if rbErr := tx.Rollback(); rbErr != nil {
				err = errs.NewErrFailedToRollbackTx(err, rbErr, panicked)
			}
			return
		}
		err = tx.Commit()
	}()
	err = fn(ctx, tx)
	panicked = false
	return err
}
AOP 方案
3. AOP 方案:不同框架设计分析、方案总结
// orm/middleware.go
package orm
import "context"
// QueryContent describes the query passing through the middleware chain.
type QueryContent struct {
// Type marks the kind of statement (insert/delete/update/select),
// e.g. "INSERT" or "SELECT".
Type string
// Builder is the query itself; middlewares may call Build on it.
Builder QueryBuilder
}
// QueryResult is what a Handler returns.
type QueryResult struct {
// Result's type depends on the query:
// SELECT yields *T or []*T;
// every other statement yields a Result.
Result any
// Err is the error raised by the query itself.
Err error
}
// Handler executes the query described by qc.
type Handler func(ctx context.Context, qc *QueryContent) *QueryResult
// Middleware wraps a Handler with extra behaviour (logging, tracing,
// metrics, ...) and returns the wrapped Handler.
type Middleware func(next Handler) Handler
4. AOP 方案:Middleware 接入与 querylog
//orm/middleware/querylog/middleware_test.go
package querylog
import (
"context"
"database/sql"
_ "github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"my_framework/orm"
"testing"
)
// TestMiddlewareBuilder_Build checks that the querylog middleware hands the
// SQL and arguments of each statement to the log function before executing.
// NOTE(review): orm.Open presumably wraps sql.Open, which does not connect,
// so the MySQL DSN may be unreachable. The statements then fail after being
// logged, which is why their results are ignored.
// The assertions depend on statement order: each pair reads the values
// captured by the most recent query.
func TestMiddlewareBuilder_Build(t *testing.T) {
var query string
var args []any
m := (&MiddlewareBuilder{}).LogFunc(func(q string, as []any) {
query = q
args = as
})
db, err := orm.Open("mysql",
"root:123456@tcp(127.0.0.1:3307)/mybatis",
orm.DBWithMiddlewares(m.Build()))
require.NoError(t, err)
_, _ = orm.NewSelector[TestModel](db).Where(orm.C("Id").Eq(10)).Get(context.Background())
assert.Equal(t, "SELECT * FROM `test_model` WHERE `id` = ?;", query)
assert.Equal(t, []any{10}, args)
orm.NewInserter[TestModel](db).Values(&TestModel{Id: 19}).Exec(context.Background())
assert.Equal(t, "INSERT INTO `test_model`(`id`,`first_name`,`age`,`last_name`) VALUES (?,?,?,?);", query)
assert.Equal(t, []any{int64(19),"",int8(0), (*sql.NullString)(nil)}, args)
}
// TestModel is the entity used by the querylog tests; it maps to the
// `test_model` table.
type TestModel struct {
Id int64
FirstName string
Age int8
LastName *sql.NullString
}
// orm/middleware/querylog/middleware.go
package querylog
import (
"context"
"log"
"my_framework/orm"
)
// MiddlewareBuilder configures the query-logging middleware.
type MiddlewareBuilder struct {
// logFunc receives the SQL and its arguments for every query.
logFunc func(query string, args []any)
//logFunc func(query string, args ...)
}
// NewMiddlewareBuilder returns a builder whose default log function prints
// each statement with the standard log package. Callers can supply their own
// printer through LogFunc.
func NewMiddlewareBuilder() *MiddlewareBuilder {
	defaultLog := func(query string, args []any) {
		log.Printf("sql: %s, args: %v\n", query, args)
	}
	return &MiddlewareBuilder{logFunc: defaultLog}
}
// LogFunc replaces the log function and returns the builder for chaining.
func (m *MiddlewareBuilder) LogFunc(fn func(query string, args []any)) *MiddlewareBuilder {
	builder := m
	builder.logFunc = fn
	return builder
}
// Build returns the query-logging middleware. For every query it:
//  1. builds the SQL;
//  2. passes the SQL and its arguments to the log function;
//  3. calls the next handler.
//
// If the SQL cannot be built, the error is returned as the result and next is
// not called.
//
// A zero-value builder (&MiddlewareBuilder{}) has no log function. Rather than
// panicking on the first query, Build falls back to the same log.Printf output
// that NewMiddlewareBuilder configures.
func (m MiddlewareBuilder) Build() orm.Middleware {
	logFunc := m.logFunc
	if logFunc == nil {
		logFunc = func(query string, args []any) {
			log.Printf("sql: %s, args: %v\n", query, args)
		}
	}
	return func(next orm.Handler) orm.Handler {
		return func(ctx context.Context, qc *orm.QueryContent) *orm.QueryResult {
			q, err := qc.Builder.Build()
			if err != nil {
				// There is no SQL to log; the caller receives the build error.
				return &orm.QueryResult{
					Err: err,
				}
			}
			logFunc(q.SQL, q.Args)
			return next(ctx, qc)
		}
	}
}
//orm/insert.go
package orm
// ...
// Exec runs the INSERT through the DB's middleware chain. The chain is built
// from the last middleware inward, so db.mdls[0] is the outermost wrapper.
// A nil handler Result yields a Result with no sql.Result, carrying only the
// handler's error.
func (i *Inserter[T]) Exec(ctx context.Context) Result {
	var handler Handler = i.execHandler
	mdls := i.db.mdls
	for j := len(mdls) - 1; j >= 0; j-- {
		handler = mdls[j](handler)
	}
	qr := handler(ctx, &QueryContent{
		Type:    "INSERT",
		Builder: i,
	})
	out := Result{err: qr.Err}
	if qr.Result != nil {
		out.res = qr.Result.(sql.Result)
	}
	return out
}
// Compile-time check that execHandler has the Handler signature.
var _ Handler = (&Inserter[int]{}).execHandler

// execHandler is the innermost Handler for an INSERT: it builds the statement
// and executes it on the underlying *sql.DB. Errors are reported both in
// QueryResult.Err and inside the Result value.
func (i *Inserter[T]) execHandler(ctx context.Context, qc *QueryContent) *QueryResult {
	q, err := i.Build()
	if err != nil {
		return &QueryResult{
			Err: err,
			Result: Result{
				err: err,
			},
		}
	}
	// ExecContext, not Exec: the caller's ctx (deadline, cancellation, and
	// anything middlewares attached to it) must reach the driver.
	res, err := i.db.db.ExecContext(ctx, q.SQL, q.Args...)
	return &QueryResult{
		Err: err,
		Result: Result{
			err: err,
			res: res,
		},
	}
}
//orm/select.go
package orm
// ...
// Get runs the SELECT through the DB's middleware chain and returns the first
// row as *T. The chain is built from the last middleware inward, so
// db.mdls[0] is the outermost wrapper.
func (s *Selector[T]) Get(ctx context.Context) (*T, error) {
	var handler Handler = s.getHandler
	mdls := s.db.mdls
	for idx := len(mdls) - 1; idx >= 0; idx-- {
		handler = mdls[idx](handler)
	}
	qr := handler(ctx, &QueryContent{
		Type:    "SELECT",
		Builder: s,
	})
	if qr.Result == nil {
		return nil, qr.Err
	}
	return qr.Result.(*T), qr.Err
}
// Compile-time check that getHandler has the Handler signature.
var _ Handler = (&Selector[any]{}).getHandler

// getHandler is the innermost Handler for a single-row SELECT. It builds the
// query, runs it, and scans the first row into a new *T.
//
// It returns:
//   - the build error or the query error;
//   - the iteration error reported by rows.Err, when there is one;
//   - ErrNoRows when the result set is empty;
//   - otherwise the *T, together with any SetColumns error.
//
// The rows are always closed, so the connection goes back to the pool.
func (s *Selector[T]) getHandler(ctx context.Context, qc *QueryContent) *QueryResult {
	q, err := s.Build()
	if err != nil {
		return &QueryResult{
			Err: err,
		}
	}
	rows, err := s.db.db.QueryContext(ctx, q.SQL, q.Args...)
	if err != nil {
		return &QueryResult{
			Err: err,
		}
	}
	defer rows.Close()
	if !rows.Next() {
		// Next returns false both for an empty result and for an iteration error.
		if err = rows.Err(); err != nil {
			return &QueryResult{
				Err: err,
			}
		}
		return &QueryResult{
			Err: ErrNoRows,
		}
	}
	tp := new(T)
	val := s.db.creator(s.model, tp)
	err = val.SetColumns(rows)
	return &QueryResult{
		Result: tp,
		Err:    err,
	}
}
// orm/db.go
package orm
// DB is a decorator around *sql.DB. The registry, creator, dialect and
// middlewares now live in the embedded core; the commented-out fields are
// the earlier layout.
type DB struct {
//r model.Registry
db *sql.DB
//creator valuer.Creator
//dialect Dialect
core
}
// DBWithMiddlewares returns an option that installs mdls as the DB's
// middleware chain, replacing any middlewares set earlier. mdls[0] ends up as
// the outermost wrapper, because Exec/Get wrap from the last element backwards.
//
// Each DB receives its own copy of the slice. When the option is called as
// DBWithMiddlewares(s...), later writes to the caller's s therefore do not
// change the DB's chain, and two DBs built from the same option do not share
// a backing array.
func DBWithMiddlewares(mdls ...Middleware) DBOption {
	return func(db *DB) {
		db.mdls = append([]Middleware(nil), mdls...)
		//db.mdls = append(db.mdls, mdls...)
	}
}
// ...
// orm/core.go
package orm
import (
"my_framework/orm/internal/valuer"
"my_framework/orm/model"
)
// core holds the configuration shared by DB and the statement builders.
type core struct {
model *model.Model
dialect Dialect
// creator builds the Value used to read and write entity fields.
creator valuer.Creator
// r resolves model metadata for entity types.
r model.Registry
// mdls is the middleware chain; mdls[0] is the outermost wrapper.
mdls []Middleware
}
5. AOP 方案:Middleware 各种实现、总结与面试要点
// orm/middleware/opentelemetry/middleware.go
package opentelemetry
import (
"context"
"fmt"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"my_framework/orm"
)
// instrumentationName names the tracer obtained from the global provider
// when the builder has no Tracer.
const instrumentationName = "orm/middleware/opentelemetry/middleware.go"
// MiddlewareBuilder configures the OpenTelemetry tracing middleware.
type MiddlewareBuilder struct {
// Tracer creates the spans; if nil, Build uses the global tracer provider.
Tracer trace.Tracer
}
// Build returns a middleware that wraps every query in a span named
// "<Type>-<table>", e.g. "SELECT-test_model". The span carries the SQL text
// (when the query builds), the table name and component "orm", and records
// the query's error, if any. The span's context is passed to the next handler.
//
// QueryContent.Model may be unset (not every caller fills it). In that case
// the table name is empty instead of causing a nil-pointer panic.
func (m MiddlewareBuilder) Build() orm.Middleware {
	if m.Tracer == nil {
		m.Tracer = otel.GetTracerProvider().Tracer(instrumentationName)
	}
	return func(next orm.Handler) orm.Handler {
		return func(ctx context.Context, qc *orm.QueryContent) *orm.QueryResult {
			tbl := ""
			if qc.Model != nil {
				tbl = qc.Model.TableName
			}
			spanCtx, span := m.Tracer.Start(ctx, fmt.Sprintf("%s-%s", qc.Type, tbl))
			defer span.End()
			// Build only to annotate the span. A build failure resurfaces from
			// next and is recorded there, so it is not recorded twice.
			q, _ := qc.Builder.Build()
			if q != nil {
				span.SetAttributes(attribute.String("sql", q.SQL))
			}
			span.SetAttributes(attribute.String("table", tbl))
			span.SetAttributes(attribute.String("component", "orm"))
			res := next(spanCtx, qc)
			if res.Err != nil {
				span.RecordError(res.Err)
			}
			return res
		}
	}
}
// orm/middleware/prometheus/middleware.go
package prometheus
import (
"context"
"github.com/prometheus/client_golang/prometheus"
"my_framework/orm"
"time"
)
// MiddlewareBuilder configures the Prometheus latency middleware. The fields
// are passed to the summary's SummaryOpts.
type MiddlewareBuilder struct {
Namespace string
Subsystem string
Name string
Help string
}
// Build returns a middleware that records each query's latency in
// milliseconds. The values go into a summary labelled by query type and table.
//
// Durations are recorded as fractional milliseconds. Duration.Milliseconds()
// truncates, which would record every sub-millisecond query as 0 and flatten
// the quantiles.
//
// Calling Build twice with the same options used to panic in MustRegister.
// The already-registered summary is now reused instead; any other
// registration error still panics, as MustRegister did.
//
// If QueryContent.Model is unset, the table label is empty.
func (m MiddlewareBuilder) Build() orm.Middleware {
	vector := prometheus.NewSummaryVec(prometheus.SummaryOpts{
		Namespace: m.Namespace,
		Subsystem: m.Subsystem,
		Name:      m.Name,
		Help:      m.Help,
		Objectives: map[float64]float64{
			0.5:   0.01,
			0.75:  0.01,
			0.90:  0.01,
			0.99:  0.001,
			0.999: 0.0001,
		},
	}, []string{"type", "table"})
	if err := prometheus.Register(vector); err != nil {
		are, ok := err.(prometheus.AlreadyRegisteredError)
		if !ok {
			panic(err)
		}
		existing, ok := are.ExistingCollector.(*prometheus.SummaryVec)
		if !ok {
			panic(err)
		}
		vector = existing
	}
	// Possible extensions: an error counter, a histogram, in-flight queries.
	return func(next orm.Handler) orm.Handler {
		return func(ctx context.Context, qc *orm.QueryContent) *orm.QueryResult {
			startTime := time.Now()
			defer func() {
				tbl := ""
				if qc.Model != nil {
					tbl = qc.Model.TableName
				}
				elapsedMs := float64(time.Since(startTime)) / float64(time.Millisecond)
				vector.WithLabelValues(qc.Type, tbl).Observe(elapsedMs)
			}()
			return next(ctx, qc)
		}
	}
}
// orm/select.go
package orm
// ...
// Build assembles the SELECT statement; the column/FROM/WHERE part is elided
// in these notes.
//
// The SQL buffer and the argument list are reset first, so building twice (a
// middleware, then the handler) does not duplicate arguments. The model is
// resolved only when Get has not already done so.
func (s *Selector[T]) Build() (*Query, error) {
	s.sb = &strings.Builder{} // pointer field: must be (re)initialised before use
	s.args = nil              // drop arguments from a previous Build
	if s.model == nil {
		var err error
		s.model, err = s.db.r.Get(new(T)) // new(T) yields a *T for the registry
		if err != nil {
			return nil, err
		}
	}
	sb := s.sb
	// ...
}
// Get resolves the model up front, so middlewares can read it from
// QueryContent.Model. It then runs the SELECT through the DB's middleware
// chain and returns the first row as *T. db.mdls[0] is the outermost wrapper.
func (s *Selector[T]) Get(ctx context.Context) (*T, error) {
	var err error
	s.model, err = s.db.r.Get(new(T))
	if err != nil {
		return nil, err
	}
	var handler Handler = s.getHandler
	mdls := s.db.mdls
	for idx := len(mdls) - 1; idx >= 0; idx-- {
		handler = mdls[idx](handler)
	}
	qr := handler(ctx, &QueryContent{
		Type:    "SELECT",
		Builder: s,
		Model:   s.model,
	})
	if qr.Result == nil {
		return nil, qr.Err
	}
	return qr.Result.(*T), qr.Err
}
// orm/insert.go
package orm
// Build assembles the INSERT statement; the table/column/VALUES/upsert part is
// elided in these notes.
//
// The SQL buffer is reset first. A middleware builds the query before the
// handler does, and without the reset the second call would concatenate two
// statements. The model is resolved only when Exec has not already done so.
func (i *Inserter[T]) Build() (*Query, error) {
	if len(i.values) == 0 {
		return nil, errs.ErrInsertZeroRow
	}
	i.sb = &strings.Builder{}
	i.sb.WriteString("INSERT INTO ")
	var err error
	if i.model == nil {
		i.model, err = i.db.r.Get(i.values[0])
		if err != nil {
			return nil, err
		}
	}
	// Write the table name.
	// ...
}
// Exec runs the INSERT through the DB's middleware chain; the chain-building
// and result-conversion parts are elided in these notes. The model is
// resolved up front so that middlewares can read it from QueryContent.Model.
func (i *Inserter[T]) Exec(ctx context.Context) Result {
var err error
i.model, err = i.db.r.Get(new(T))
if err != nil {
return Result{
err: err,
}
}
root := i.execHandler
// ...
res := root(ctx, &QueryContent{
Type: "INSERT",
Builder: i,
Model: i.model,
})
// ...
}
// orm/middleware.go
package orm
import (
"context"
"my_framework/orm/model"
)
// QueryContent describes the query passing through the middleware chain.
type QueryContent struct {
// Type marks the kind of statement (insert/delete/update/select),
// e.g. "INSERT" or "SELECT".
Type string
// Builder is the query itself.
Builder QueryBuilder
// Model is the metadata of the entity being queried. Get and Exec resolve it
// before the chain runs, so middlewares can use it (e.g. Model.TableName).
Model *model.Model
}
// ...
集成测试
6. 集成测试:起步与 MySQL 的增删改查
// … 跳过