While reading the TiDB source code recently, I was intrigued to see that it uses failpoint for fault injection. failpoint is implemented through code generation: it parses source code into an AST, replaces marker nodes, and writes the code back out. I wanted to dig into it to learn how to manipulate an AST to generate code, so this article looks in detail at how failpoint is used and how it is implemented.

Preface

failpoint is a tool for injecting errors during testing; it is a Golang implementation of FreeBSD Failpoints. We usually build various test scenarios to improve the stability of our systems, but some scenarios are very difficult to simulate, e.g. random delays or outages of a service in a microservice architecture, or, in game development, unstable player networks, dropped frames and excessive latency.

failpoint was developed to make such problems easy to test. It greatly simplifies the testing process and lets us simulate many kinds of errors in many scenarios, so that we can find and debug bugs in our code.

The main design principles of failpoint are:

  • Failpoint-related code must not add any extra overhead.
  • It must not interfere with the normal functional logic or intrude into the functional code in any way.
  • Failpoint code must be easy to read and write, and must be checkable by the compiler.
  • The final generated code must be readable.
  • The line numbers of the functional code must not change in the generated code (to make debugging easier).

Use

First, build failpoint-ctl from source:

git clone https://github.com/pingcap/failpoint.git
cd failpoint
make
ls bin/failpoint-ctl

This produces the failpoint-ctl binary, which performs the code conversion.

The failpoint can then be used inside the code to inject faults.

package main

import (
    "fmt"
    "github.com/pingcap/failpoint"
)

func test() {
    failpoint.Inject("testValue", func(v failpoint.Value) {
        fmt.Println(v)
    })
}

func main() {
    for i := 0; i < 100; i++ {
        test()
    }
}

Jumping to the definition of Inject, we find:

func Inject(fpname string, fpbody interface{}) {}

When failpoint is not enabled, Inject is just an empty function and has no impact on the performance of the business logic. When the service is compiled, the call to this empty function is inlined and optimized away, which is how failpoint achieves zero-cost fault injection.

Next, we convert the test function above into working fault injection code:

$ failpoint/bin/failpoint-ctl enable .

This runs the compiled failpoint-ctl on the current directory, which rewrites the code above into:

package main

import (
    "fmt"
    "github.com/pingcap/failpoint"
)

func test() {
    if v, _err_ := failpoint.Eval(_curpkg_("testValue")); _err_ == nil {
        fmt.Println(v)
    }
}

func main() {
    for i := 0; i < 100; i++ {
        test()
    }
}

Now we run the program with the failpoint enabled:

$ GO_FAILPOINTS='main/testValue=2*return("abc")' go run main.go binding__failpoint_binding__.go 
abc
abc

In this example, 2* means the failpoint fires only twice, and the argument in return("abc") is the value passed to v in the injected function.

We can also set the probability that it fires:

$ GO_FAILPOINTS='main/testValue=5%return("abc")' go run main.go binding__failpoint_binding__.go 
abc
abc
abc
abc

Here 5% means that each evaluation has a 5% chance of firing and returning abc.

Beyond this simple example, failpoint can also express more complex scenarios, such as control-flow jumps:

package main

import (
    "fmt"
    "github.com/pingcap/failpoint"
    "math/rand"
)

func main() {
    failpoint.Label("outer")
    for i := 0; i < 100; i++ {
        failpoint.Label("inner")
        for j := 1; j < 1000; j++ {
            switch rand.Intn(j) + i {
            case j / 5:
                failpoint.Break()
            case j / 7:
                failpoint.Continue("outer")
            case j / 9:
                failpoint.Fallthrough()
            case j / 10:
                failpoint.Goto("outer")
            default:
                failpoint.Inject("failpoint-name", func(val failpoint.Value) {
                    fmt.Println("unit-test", val.(int))
                    if val == j/11 {
                        failpoint.Break("inner")
                    } else {
                        failpoint.Goto("outer")
                    }
                })
            }
        }
    } 
}

This example uses failpoint.Label, failpoint.Break, failpoint.Continue, failpoint.Fallthrough and failpoint.Goto to express jumps in the code. The generated code is:

func main() {
outer:
    for i := 0; i < 100; i++ {
    inner:
        for j := 1; j < 1000; j++ {
            switch rand.Intn(j) + i {
            case j / 5:
                break
            case j / 7:
                continue outer
            case j / 9:
                fallthrough
            case j / 10:
                goto outer
            default:
                if val, _err_ := failpoint.Eval(_curpkg_("failpoint-name")); _err_ == nil {
                    fmt.Println("unit-test", val.(int))
                    if val == j/11 {
                        break inner
                    } else {
                        goto outer
                    }
                }
            }
        }
    }
}

As you can see, each failpoint marker above is translated into the corresponding Go jump statement (break, continue, fallthrough, goto) or label.
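The translation of the jump markers is mechanical: each marker maps to one Go branch token, optionally carrying a label. The following is a minimal sketch of that mapping built with the standard go/ast, go/token and go/format packages; branchFor and render are hypothetical helpers for illustration, not failpoint's actual rewriteBreak or rewriteGoto code.

```go
package main

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/token"
	"strings"
)

// branchFor is a hypothetical helper (not failpoint's API) that maps a
// jump marker name to the Go branch statement it stands for.
func branchFor(marker, label string) (ast.Stmt, error) {
	var tok token.Token
	switch marker {
	case "Break":
		tok = token.BREAK
	case "Continue":
		tok = token.CONTINUE
	case "Goto":
		tok = token.GOTO
	case "Fallthrough":
		tok = token.FALLTHROUGH
	default:
		return nil, fmt.Errorf("unknown jump marker %q", marker)
	}
	stmt := &ast.BranchStmt{Tok: tok}
	if label != "" {
		// e.g. failpoint.Break("inner") -> break inner
		stmt.Label = ast.NewIdent(label)
	}
	return stmt, nil
}

// render prints an AST node back to gofmt-style Go source.
func render(n ast.Node) string {
	var buf bytes.Buffer
	if err := format.Node(&buf, token.NewFileSet(), n); err != nil {
		panic(err)
	}
	return strings.TrimSpace(buf.String())
}

func main() {
	cases := []struct{ marker, label string }{
		{"Break", ""}, {"Break", "inner"}, {"Continue", "outer"},
		{"Goto", "outer"}, {"Fallthrough", ""},
	}
	for _, c := range cases {
		stmt, err := branchFor(c.marker, c.label)
		if err != nil {
			panic(err)
		}
		fmt.Println(render(stmt))
	}
}
```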

After testing, restore the original code with disable:

$ failpoint/bin/failpoint-ctl disable .

For more usage, see the official documentation:

https://github.com/pingcap/failpoint

How it works

Code injection

An example

When using failpoint, we describe the fault to inject with a set of marker functions provided by the package:

func Inject(fpname string, fpblock func(val Value)) {}
func InjectContext(fpname string, ctx context.Context, fpblock func(val Value)) {}
func Break(label ...string) {}
func Goto(label string) {}
func Continue(label ...string) {}
func Fallthrough() {}
func Return(results ...interface{}) {}
func Label(label string) {}

failpoint-ctl then parses the source into an AST, replaces each marker statement, and writes out the final injection code. For example, take the following program:

package main

import (
    "fmt"
    "github.com/pingcap/failpoint"
)

func test() {
    failpoint.Inject("testPanic", func(val failpoint.Value){
        fmt.Println(val)
    })
}

func main() {
    for i := 0; i < 100; i++ {
        test()
    }
}

After conversion, it becomes:

package main

import (
    "fmt"
    "github.com/pingcap/failpoint"
)

func test() {
    if val, _err_ := failpoint.Eval(_curpkg_("testPanic")); _err_ == nil {
        fmt.Println(val)
    }
}

func main() {
    for i := 0; i < 100; i++ {
        test()
    }
}

Besides rewriting the source, failpoint-ctl generates a binding__failpoint_binding__.go file that defines _curpkg_, which prefixes a failpoint name with the current package path:

package main

import "reflect"

type __failpointBindingType struct {pkgpath string}
var __failpointBindingCache = &__failpointBindingType{}

func init() {
    __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath()
}
func _curpkg_(name string) string {
    return __failpointBindingCache.pkgpath + "/" + name
}

Getting the code AST tree

failpoint-ctl performs the conversion with Rewriter, a struct that traverses the code's AST, detects calls to marker functions, and rewrites them:

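type Rewriter struct {
    rewriteDir    string         // directory to rewrite
    currentPath   string         // path of the current file
    currentFile   *ast.File      // AST of the current file
    currsetFset   *token.FileSet // FileSet of the current file
    failpointName string         // local name of the failpoint import (handles aliases)
    rewritten     bool           // whether anything in the file was rewritten

    output io.Writer // optional output redirection
}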

When failpoint-ctl runs, it calls the RewriteFile method to rewrite each file:

func (r *Rewriter) RewriteFile(path string) (err error) {
    defer func() {
        if e := recover(); e != nil {
            err = fmt.Errorf("%s %v\n%s", r.currentPath, e, debug.Stack())
        }
    }()
    fset := token.NewFileSet()
    // Parse the Go file into an AST
    file, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
    if err != nil {
        return err
    }
    if len(file.Decls) < 1 {
        return nil
    }
    // Path of the file being rewritten
    r.currentPath = path
    // AST of the file
    r.currentFile = file
    // FileSet of the file
    r.currsetFset = fset
    // Reset the "rewritten" flag
    r.rewritten = false
    // Find the import spec for the failpoint package
    var failpointImport *ast.ImportSpec
    for _, imp := range file.Imports {
        if strings.Trim(imp.Path.Value, "`\"") == packagePath {
            failpointImport = imp
            break
        }
    }
    if failpointImport == nil {
        panic("import path should be check before rewrite")
    }
    if failpointImport.Name != nil {
        r.failpointName = failpointImport.Name.Name
    } else {
        r.failpointName = packageName
    }
    // Iterate over the file's top-level declarations: types, functions, imports, global constants, etc.
    for _, decl := range file.Decls {
        fn, ok := decl.(*ast.FuncDecl)
        if !ok {
            continue
        }
        // Walk each function declaration and replace failpoint marker calls
        if err := r.rewriteFuncDecl(fn); err != nil {
            return err
        }
    }

    if !r.rewritten {
        return nil
    }

    if r.output != nil {
        return format.Node(r.output, fset, file)
    }
    // Generate the binding__failpoint_binding__ code
    found, err := isBindingFileExists(path)
    if err != nil {
        return err
    }
    // If binding__failpoint_binding__.go does not exist yet, create it
    if !found {
        err := writeBindingFile(path, file.Name.Name)
        if err != nil {
            return err
        }
    }
    // Rename the original file, e.g. main.go to main.go__failpoint_stash__,
    // so that it can be restored later
    targetPath := path + failpointStashFileSuffix
    if err := os.Rename(path, targetPath); err != nil {
        return err
    }

    newFile, err := os.OpenFile(path, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, os.ModePerm)
    if err != nil {
        return err
    }
    defer newFile.Close()
    // Print the rewritten AST back to the source file
    return format.Node(newFile, fset, file)
}

This method first calls parser.ParseFile from Go's standard library to obtain the file's AST, a tree that represents the syntactic structure of the source, with each node standing for a construct in the code. It then walks the top-level declarations in the Decls slice and, for every function declaration, descends recursively into its body, which amounts to a depth-first traversal from the top of the tree.

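Once the traversal is done, and only if something was rewritten, RewriteFile makes sure the binding__failpoint_binding__.go file exists, backs up the original source by renaming it with the __failpoint_stash__ suffix, and writes the modified AST back to the original path with format.Node.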

Traversing the AST and replacing nodes

func (r *Rewriter) rewriteStmts(stmts []ast.Stmt) error {
    // Iterate over the statements in the function body
    for i, block := range stmts {
        switch v := block.(type) {
        case *ast.DeclStmt:
            ... 
        // A standalone expression statement
        case *ast.ExprStmt:
            call, ok := v.X.(*ast.CallExpr)
            if !ok {
                break
            }
            switch expr := call.Fun.(type) {
            // A function literal called directly: func(){...}()
            case *ast.FuncLit:
                // Recurse into the function literal
                err := r.rewriteFuncLit(expr)
                if err != nil {
                    return err
                }
            // A selector expression such as a.b
            case *ast.SelectorExpr:
                // Get the package name of the call
                packageName, ok := expr.X.(*ast.Ident)
                // Is it the failpoint package (or its alias)?
                if !ok || packageName.Name != r.failpointName {
                    break
                }
                // Look up the rewrite function for this marker name
                exprRewriter, found := exprRewriters[expr.Sel.Name]
                if !found {
                    break
                }
                // Rewrite the marker call
                rewritten, stmt, err := exprRewriter(r, call)
                if err != nil {
                    return err
                }
                if !rewritten {
                    continue
                }
                // If the result is an if statement, rewrite its body recursively too
                if ifStmt, ok := stmt.(*ast.IfStmt); ok {
                    err := r.rewriteIfStmt(ifStmt)
                    if err != nil {
                        return err
                    }
                }
                // Replace the original statement with the generated one
                stmts[i] = stmt
                r.rewritten = true
            }

        case *ast.AssignStmt:
            ... 
        case *ast.GoStmt:
            ...
        case *ast.DeferStmt:
            ...
        case *ast.ReturnStmt: 
        ... 
        default:
            fmt.Printf("unsupported statement: %T in %s\n", v, r.pos(v.Pos()))
        }
    } 
    return nil
}

Each statement in the function body is examined in turn. When a call to a failpoint marker is found, the corresponding rewrite function is looked up in exprRewriters by the marker's name:

var exprRewriters = map[string]exprRewriter{
    "Inject":        (*Rewriter).rewriteInject,
    "InjectContext": (*Rewriter).rewriteInjectContext,
    "Break":         (*Rewriter).rewriteBreak,
    "Continue":      (*Rewriter).rewriteContinue,
    "Label":         (*Rewriter).rewriteLabel,
    "Goto":          (*Rewriter).rewriteGoto,
    "Fallthrough":   (*Rewriter).rewriteFallthrough,
    "Return":        (*Rewriter).rewriteReturn,
}

Rewriting the markers

Since our example uses failpoint.Inject, we will walk through rewriteInject.

This method will eventually convert:

    failpoint.Inject("testPanic", func(val failpoint.Value){
        fmt.Println(val)
    })

to:

    if val, _err_ := failpoint.Eval(_curpkg_("testPanic")); _err_ == nil {
        fmt.Println(val)
    }

Here is how the new AST is constructed:

func (r *Rewriter) rewriteInject(call *ast.CallExpr) (bool, ast.Stmt, error) {
    // Validate the failpoint.Inject call: it must have exactly two arguments
    if len(call.Args) != 2 {
        return false, nil, fmt.Errorf("failpoint.Inject: expect 2 arguments but got %v in %s", len(call.Args), r.pos(call.Pos()))
    } 
    // First argument: the failpoint name, e.g. "testPanic"
    fpname, ok := call.Args[0].(ast.Expr)
    if !ok {
        return false, nil, fmt.Errorf("failpoint.Inject: first argument expect a valid expression in %s", r.pos(call.Pos()))
    }

    // Second argument: either nil or a closure such as func(val failpoint.Value){}
    ident, ok := call.Args[1].(*ast.Ident)
    // Is the second argument the identifier nil?
    isNilFunc := ok && ident.Name == "nil"

    // If it is not nil, it must be a function literal in one of these forms:
    // failpoint.Inject("failpoint-name", func(){...})
    // failpoint.Inject("failpoint-name", func(val failpoint.Value){...})
    fpbody, isFuncLit := call.Args[1].(*ast.FuncLit)
    if !isNilFunc && !isFuncLit {
        return false, nil, fmt.Errorf("failpoint.Inject: second argument expect closure in %s", r.pos(call.Pos()))
    }

    // For a closure, allow at most one parameter with a single name
    if isFuncLit {
        if len(fpbody.Type.Params.List) > 1 {
            return false, nil, fmt.Errorf("failpoint.Inject: closure signature illegal in %s", r.pos(call.Pos()))
        }

        if len(fpbody.Type.Params.List) == 1 && len(fpbody.Type.Params.List[0].Names) > 1 {
            return false, nil, fmt.Errorf("failpoint.Inject: closure signature illegal in %s", r.pos(call.Pos()))
        }
    }
    // Build the call _curpkg_("testPanic")
    fpnameExtendCall := &ast.CallExpr{
        Fun:  ast.NewIdent(extendPkgName),
        Args: []ast.Expr{fpname},
    }
    // Build the call failpoint.Eval(_curpkg_(...))
    checkCall := &ast.CallExpr{
        Fun: &ast.SelectorExpr{
            X:   &ast.Ident{NamePos: call.Pos(), Name: r.failpointName},
            Sel: ast.NewIdent(evalFunction),
        },
        Args: []ast.Expr{fpnameExtendCall},
    }
    if isNilFunc || len(fpbody.Body.List) < 1 {
        return true, &ast.ExprStmt{X: checkCall}, nil
    }
    // Build the if body from the closure's statements
    ifBody := &ast.BlockStmt{
        Lbrace: call.Pos(),
        List:   fpbody.Body.List,
        Rbrace: call.End(),
    }

    // Check whether the closure declares a parameter:
    // func(val failpoint.Value) {...}
    // func() {...}
    var argName *ast.Ident
    if len(fpbody.Type.Params.List) > 0 {
        arg := fpbody.Type.Params.List[0]
        selector, ok := arg.Type.(*ast.SelectorExpr)
        if !ok || selector.Sel.Name != "Value" || selector.X.(*ast.Ident).Name != r.failpointName {
            return false, nil, fmt.Errorf("failpoint.Inject: invalid signature in %s", r.pos(call.Pos()))
        }
        argName = arg.Names[0]
    } else {
        argName = ast.NewIdent("_")
    }
    // Build the init statement: <arg>, _err_ := failpoint.Eval(...)
    err := ast.NewIdent("_err_")
    init := &ast.AssignStmt{
        Lhs: []ast.Expr{argName, err},
        Rhs: []ast.Expr{checkCall},
        Tok: token.DEFINE,
    }
    // Build the condition _err_ == nil
    cond := &ast.BinaryExpr{
        X:  err,
        Op: token.EQL,
        Y:  ast.NewIdent("nil"),
    }
    // Assemble the complete if statement
    stmt := &ast.IfStmt{
        If:   call.Pos(),
        Init: init,
        Cond: cond,
        Body: ifBody,
    }
    return true, stmt, nil
}

The comments above explain each step of the construction.

failpoint execution

Building a failure scenario

Suppose we want this fault to be triggered with a 5% probability:

$ GO_FAILPOINTS='main/testValue=5%return("abc")' go run main.go binding__failpoint_binding__.go

The failpoint package reads the GO_FAILPOINTS variable in its init function, registers each failpoint with its policy, and at run time decides whether the fault fires according to the registered policy:

func init() {
    failpoints.reg = make(map[string]*Failpoint)
    // Read the GO_FAILPOINTS environment variable
    if s := os.Getenv("GO_FAILPOINTS"); len(s) > 0 { 
        // Multiple failpoints are separated by ';'
        for _, fp := range strings.Split(s, ";") {
            fpTerms := strings.Split(fp, "=")
            if len(fpTerms) != 2 {
                fmt.Printf("bad failpoint %q\n", fp)
                os.Exit(1)
            }
            // Register the failpoint with its policy
            err := Enable(fpTerms[0], fpTerms[1])
            if err != nil {
                fmt.Printf("bad failpoint %s\n", err)
                os.Exit(1)
            }
        }
    }
    if s := os.Getenv("GO_FAILPOINTS_HTTP"); len(s) > 0 {
        if err := serve(s); err != nil {
            fmt.Println(err)
            os.Exit(1)
        }
    }
}

The package-level Enable function ends up calling the Enable method of the Failpoints struct. The relevant types are:

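type Failpoints struct {
    mu  sync.RWMutex          // guards reg
    reg map[string]*Failpoint // registered failpoints, keyed by path
}

type Failpoint struct {
    mu       sync.RWMutex  // guards t and waitChan
    t        *terms        // the parsed policy
    waitChan chan struct{} // used by the pause action
}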

Enable stores main/testValue=5%return("abc") in the reg map as a key-value pair: the key is the failpoint path main/testValue, and the policy 5%return("abc") is parsed into the Failpoint value.

Within Failpoint, the fault-control policy is stored mainly in term structs:

type term struct {
    desc string // policy description, here 5%return("abc")

    mods mod         // modifier: probability or count limit, here 5%
    act  actFunc     // action, here return
    val  interface{} // value injected by the action, here "abc"

    parent *terms
    fp     *Failpoint
}

We used the return action above. The full set of actions is:

  • off: Take no action (does not trigger failpoint code)
  • return: Trigger failpoint with specified argument
  • sleep: Sleep the specified number of milliseconds
  • panic: Panic
  • break: Execute gdb and break into debugger
  • print: Print failpoint path for inject variable
  • pause: Pause until the failpoint is disabled

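The overall hierarchy is Failpoints → Failpoint → terms → term: Failpoints.reg maps each failpoint path to a *Failpoint, whose t field points to a *terms, whose chain field holds the []*term.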

Now let's look at Failpoint.Enable:

func (fp *Failpoint) Enable(inTerms string) error {
    t, err := newTerms(inTerms, fp)
    if err != nil {
        return err
    }
    fp.mu.Lock()
    fp.t = t
    fp.waitChan = make(chan struct{})
    fp.mu.Unlock()
    return nil
}

Enable mainly calls newTerms to build the terms structure:

func newTerms(desc string, fp *Failpoint) (*terms, error) {
    // Parse the policy string into a chain of terms
    chain, err := parse(desc, fp)
    if err != nil {
        return nil, err
    }
    t := &terms{chain: chain, desc: desc}
    for _, c := range chain {
        c.parent = t
    }
    return t, nil
}

newTerms parses the policy string with parse, links each resulting term back to its parent terms, and returns the terms.
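failpoint's own parse function is not shown here. As a rough illustration, the sketch below parses a single term under an assumed grammar [<p>%][<n>*]<action>[(<arg>)], which covers the 5%return("abc") and 2*return("abc") examples from earlier. Chaining several terms and the full argument syntax are left out, and parseTerm, its helpers and its term type are hypothetical.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// term is a hypothetical, simplified stand-in for failpoint's term struct.
type term struct {
	desc    string      // original policy text
	percent float64     // chance of firing in percent; 100 if no "%" prefix
	count   int         // maximum number of firings; -1 means unlimited
	action  string      // e.g. "return", "sleep", "panic", "off"
	val     interface{} // parsed argument: string, int, bool or nil
}

// leadingNumber splits s into its numeric prefix (digits and '.') and the rest.
func leadingNumber(s string) (string, string) {
	i := 0
	for i < len(s) && (s[i] >= '0' && s[i] <= '9' || s[i] == '.') {
		i++
	}
	return s[:i], s[i:]
}

// parseValue converts an action argument such as "abc" (quoted), 100 or true.
func parseValue(arg string) (interface{}, error) {
	switch {
	case arg == "":
		return nil, nil
	case arg == "true" || arg == "false":
		return arg == "true", nil
	}
	if s, err := strconv.Unquote(arg); err == nil {
		return s, nil
	}
	if n, err := strconv.Atoi(arg); err == nil {
		return n, nil
	}
	return nil, fmt.Errorf("unsupported argument %q", arg)
}

// parseTerm parses one policy of the assumed form [<p>%][<n>*]<action>[(<arg>)].
func parseTerm(desc string) (*term, error) {
	t := &term{desc: desc, percent: 100, count: -1}
	s := desc
	if num, rest := leadingNumber(s); num != "" && strings.HasPrefix(rest, "%") {
		p, err := strconv.ParseFloat(num, 64)
		if err != nil {
			return nil, fmt.Errorf("bad percentage in %q: %v", desc, err)
		}
		t.percent, s = p, rest[1:]
	}
	if num, rest := leadingNumber(s); num != "" && strings.HasPrefix(rest, "*") {
		n, err := strconv.Atoi(num)
		if err != nil {
			return nil, fmt.Errorf("bad count in %q: %v", desc, err)
		}
		t.count, s = n, rest[1:]
	}
	open := strings.Index(s, "(")
	if open < 0 {
		t.action = s // no argument, e.g. "off" or "panic"
	} else {
		if !strings.HasSuffix(s, ")") {
			return nil, fmt.Errorf("missing ')' in %q", desc)
		}
		t.action = s[:open]
		v, err := parseValue(s[open+1 : len(s)-1])
		if err != nil {
			return nil, err
		}
		t.val = v
	}
	if t.action == "" {
		return nil, fmt.Errorf("missing action in %q", desc)
	}
	return t, nil
}

func main() {
	for _, desc := range []string{`5%return("abc")`, `2*return("abc")`, "sleep(100)", "off"} {
		t, err := parseTerm(desc)
		if err != nil {
			panic(err)
		}
		fmt.Printf("%-16s percent=%v count=%d action=%s val=%v\n", desc, t.percent, t.count, t.action, t.val)
	}
}
```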

Fault execution

Eval is called when the injected code runs, and whether the fault body executes depends on whether Eval returns an error.

The package-level Eval function calls the Eval method of Failpoints:

func (fps *Failpoints) Eval(failpath string) (Value, error) {
    fps.mu.RLock()
    // Look up the registered Failpoint
    fp, found := fps.reg[failpath]
    fps.mu.RUnlock()
    if !found {
        return nil, errors.Wrapf(ErrNotExist, "error on %s", failpath)
    }
    // Evaluate its policy
    val, err := fp.Eval()
    if err != nil {
        return nil, errors.Wrapf(err, "error on %s", failpath)
    }
    return val, nil
}

The reg map consulted by Eval is the one populated by the init function shown earlier; once the Failpoint is found, its own Eval method is called.

Failpoint.Eval calls the eval method of terms, which iterates over its chain []*term field. For each term it calls allow to check whether the term's modifier (such as 5% or 2*) lets it fire, and if so calls do to perform the corresponding action.

Summary

In this article we first saw how to use failpoint in our code, and then how failpoint implements fault injection through code rewriting, which involves traversing and modifying Go ASTs and generating code. The same technique also suggests a way to add extra functionality to our own code.