Progress Bar

Due to project requirements, we need to build a CLI tool that lets users upload large files for model training (see the flowchart above). The first step is to authenticate the user against the API Server. After verification, the CLI starts uploading the data to AWS S3 or another storage backend. During the upload, the current progress (speed, percentage and remaining time) must also be reported to the API Server, so that the user can follow it in real time in the Web UI through a GraphQL Subscription.

For the CLI progress display we use the open source package cheggaaa/pb, which I believe anyone who has written Go is familiar with. Although this package renders a progress bar in the terminal, it does not expose some of the data we need, such as the real-time speed, the upload percentage and the remaining time. This article shows how to compute these values, and shares the problems you may run into along the way.

Reading the upload progress

The example provided by cheggaaa/pb is as follows:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
package main

import (
    "crypto/rand"
    "io"
    "io/ioutil"
    "log"

    "github.com/cheggaaa/pb/v3"
)

func main() {
    // Stream 10 GB of random bytes into a sink that throws them away
    // (the moral equivalent of /dev/rand -> /dev/null), giving the bar a
    // long-running copy to track.
    var size int64 = 1024 * 1024 * 10000

    bar := pb.Full.Start64(size)
    // Every byte pulled through the proxy reader advances the bar.
    src := bar.NewProxyReader(io.LimitReader(rand.Reader, size))

    _, err := io.Copy(ioutil.Discard, src)
    if err != nil {
        log.Fatal(err)
    }
    bar.Finish()
}

As you can see, io.Copy drives the simulated upload. Next, we need a goroutine that reads the current progress and uploads it to the API Server.

Calculating upload progress and time remaining

pb v3 exposes only a few public accessors, such as the time the bar was started and how many bytes have been transferred so far.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
package main

import (
    "crypto/rand"
    "fmt"
    "io"
    "io/ioutil"
    "log"
    "time"

    "github.com/cheggaaa/pb/v3"
)

func main() {
    var limit int64 = 1024 * 1024 * 10000
    // we will copy 10 Gb from /dev/rand to /dev/null
    reader := io.LimitReader(rand.Reader, limit)
    writer := ioutil.Discard

    // start new bar
    bar := pb.Full.Start64(limit)
    // Report progress every two seconds.
    // NOTE: this goroutine never exits and its ticker is never stopped;
    // the next section adds a channel to shut it down.
    go func(bar *pb.ProgressBar) {
        d := time.NewTicker(2 * time.Second)
        startTime := bar.StartTime()
        for range d.C {
            if !bar.IsStarted() {
                continue
            }
            // Take one snapshot so progress, speed and remaining time
            // in this report all describe the same moment.
            current := bar.Current()
            // Nothing transferred yet: the speed is 0 and remain/0 would
            // be +Inf, whose conversion to time.Duration is
            // implementation-defined. Wait for the next tick.
            if current == 0 {
                continue
            }
            // Average speed (bytes/s) since the bar started.
            lastSpeed := float64(current) / time.Since(startTime).Seconds()
            remain := float64(bar.Total() - current)
            remainDur := time.Duration(remain/lastSpeed) * time.Second
            fmt.Println("Progress:", float32(current)/float32(bar.Total())*100)
            fmt.Println("last speed:", lastSpeed/1024/1024)
            fmt.Println("remain duration:", remainDur)

            // TODO: upload progress and remain duration to api server
        }
    }(bar)
    // create proxy reader
    barReader := bar.NewProxyReader(reader)
    // copy from proxy reader
    if _, err := io.Copy(writer, barReader); err != nil {
        log.Fatal(err)
    }
    // finish bar
    bar.Finish()
}

Using time.NewTicker, we compute the current progress every two seconds and upload it to the API Server. From the amount of data uploaded and the time elapsed, we can derive the speed. Of course, this is not very accurate: it is the average speed since the upload started (total bytes uploaded / total time elapsed), not the instantaneous speed.

Using a channel to stop the progress goroutine

After implementing the above, one problem quickly becomes apparent: the goroutine never stops. It keeps computing the progress every two seconds, even after the upload has finished. We need to notify the goroutine to exit through a channel.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
package main

import (
    "crypto/rand"
    "fmt"
    "io"
    "io/ioutil"
    "log"
    "time"

    "github.com/cheggaaa/pb/v3"
)

func main() {
    var limit int64 = 1024 * 1024 * 10000
    // we will copy 10 Gb from /dev/rand to /dev/null
    reader := io.LimitReader(rand.Reader, limit)
    writer := ioutil.Discard

    // start new bar
    bar := pb.Full.Start64(limit)
    // finishCh is closed by main once the copy is over. Closing (rather
    // than sending) never blocks and is seen by every receiver.
    finishCh := make(chan struct{})
    // done is closed by the reporter goroutine when it returns, so main
    // can wait for it instead of exiting while it is still logging.
    done := make(chan struct{})
    go func(bar *pb.ProgressBar) {
        defer close(done)
        d := time.NewTicker(2 * time.Second)
        defer d.Stop()
        startTime := bar.StartTime()
        for {
            select {
            case <-finishCh:
                log.Println("finished")
                return
            case <-d.C:
                if !bar.IsStarted() {
                    continue
                }
                // One snapshot per report keeps the three values consistent.
                current := bar.Current()
                // No bytes yet: remain/0 would be +Inf, and converting
                // that to time.Duration is implementation-defined.
                if current == 0 {
                    continue
                }
                // Average speed (bytes/s) since the bar started.
                lastSpeed := float64(current) / time.Since(startTime).Seconds()
                remain := float64(bar.Total() - current)
                remainDur := time.Duration(remain/lastSpeed) * time.Second
                fmt.Println("Progress:", float32(current)/float32(bar.Total())*100)
                fmt.Println("last speed:", lastSpeed/1024/1024)
                fmt.Println("remain duration:", remainDur)
            }
        }
    }(bar)
    // create proxy reader
    barReader := bar.NewProxyReader(reader)
    // copy from proxy reader
    if _, err := io.Copy(writer, barReader); err != nil {
        log.Fatal(err)
    }
    // finish bar
    bar.Finish()
    close(finishCh)
    // Without this wait, main can return before the goroutine logs "finished".
    <-done
}

First, declare finishCh := make(chan struct{}) to tell the goroutine to leave the loop. Pay attention to the end of main: the channel is closed, instead of a value being sent on it as below:

1
finishCh <- struct{}{}

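With a send, main blocks until the goroutine's select happens to pick the finishCh case, and the goroutine may be busy handling a tick at that moment. If the goroutine has already returned for some other reason, the send blocks forever. Closing the channel never blocks. Once it is closed, case <-finishCh immediately receives the zero value on every iteration, so the goroutine is guaranteed to leave the loop.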

Integrating Graceful Shutdown

Finally, let's integrate graceful shutdown. When the user presses Ctrl+C, the upload should stop and its status should be changed to stopped. Here is how to add graceful shutdown:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
package main

import (
    "context"
    "crypto/rand"
    "fmt"
    "io"
    "io/ioutil"
    "log"
    "os"
    "os/signal"
    "syscall"
    "time"

    "github.com/cheggaaa/pb/v3"
)

// withContextFunc returns a child of ctx that is cancelled when the process
// receives SIGINT or SIGTERM. On the first such signal it calls f and then
// cancels the returned context. If ctx itself is cancelled first, the
// watcher goroutine exits without calling f.
func withContextFunc(ctx context.Context, f func()) context.Context {
    ctx, cancel := context.WithCancel(ctx)

    // Register the handler before returning rather than inside the
    // goroutine. Otherwise a signal that arrives before the goroutine is
    // scheduled gets the default action and kills the process without
    // running f.
    c := make(chan os.Signal, 1)
    signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)

    go func() {
        // Restore default signal handling once we are done, so a second
        // Ctrl+C terminates the process immediately.
        defer signal.Stop(c)

        select {
        case <-ctx.Done():
        case <-c:
            f()
            cancel()
        }
    }()

    return ctx
}

func main() {

    ctx := withContextFunc(
        context.Background(),
        func() {
            // clear machine field
            log.Println("interrupt received, terminating process")
        },
    )

    var limit int64 = 1024 * 1024 * 10000
    // we will copy 10 Gb from /dev/rand to /dev/null
    reader := io.LimitReader(rand.Reader, limit)
    writer := ioutil.Discard

    // start new bar
    bar := pb.Full.Start64(limit)
    finishCh := make(chan struct{})
    // done is closed when the reporter goroutine returns, so main can wait for it.
    done := make(chan struct{})
    go func(ctx context.Context, bar *pb.ProgressBar) {
        defer close(done)
        d := time.NewTicker(2 * time.Second)
        // Stop the ticker on every exit path, including the ctx.Err() one below.
        defer d.Stop()
        startTime := bar.StartTime()
        for {
            select {
            case <-ctx.Done():
                log.Println("interrupt received")
                return
            case <-finishCh:
                log.Println("finished")
                return
            case <-d.C:
                // select picks randomly among ready cases, so a tick can
                // win even after ctx is cancelled; bail out here too.
                if ctx.Err() != nil {
                    return
                }
                if !bar.IsStarted() {
                    continue
                }
                // One snapshot per report keeps the three values consistent.
                current := bar.Current()
                // No bytes yet: remain/0 would be +Inf, and converting
                // that to time.Duration is implementation-defined.
                if current == 0 {
                    continue
                }
                // Average speed (bytes/s) since the bar started.
                lastSpeed := float64(current) / time.Since(startTime).Seconds()
                remain := float64(bar.Total() - current)
                remainDur := time.Duration(remain/lastSpeed) * time.Second
                fmt.Println("Progress:", float32(current)/float32(bar.Total())*100)
                fmt.Println("last speed:", lastSpeed/1024/1024)
                fmt.Println("remain duration:", remainDur)
            }
        }
    }(ctx, bar)
    // create proxy reader
    barReader := bar.NewProxyReader(reader)
    // copy from proxy reader. io.Copy does not observe ctx, so it keeps
    // copying after Ctrl+C until the reader is exhausted or the process dies.
    if _, err := io.Copy(writer, barReader); err != nil {
        log.Fatal(err)
    }
    // finish bar
    bar.Finish()
    close(finishCh)
    <-done
}

signal.Notify, combined with a context, lets the CLI detect the system signals that would terminate it, so it can do the necessary cleanup at that point. The goroutine now has to watch one more channel, ctx.Done(). Several cases of a select can be ready at the same time, and Go picks one of them at random. For that reason the ticker case also checks the context's Err(): if it is not nil, a signal has been received and the goroutine returns, so it does not keep running in the background. After running the program above, pressing Ctrl+C should print the following:

1
2
3
4
5
^C
2021/05/21 12:29:25 interrupt received, terminating process
2021/05/21 12:29:25 interrupt received
^C
signal: interrupt

You can see that Ctrl+C has to be pressed a second time to end the program. The reason is that io.Copy is still reading and has not stopped, while the first interrupt signal has already been consumed by our handler. The solution is to modify the code below:

1
2
3
4
5
// Excerpt from main above: io.Copy takes no context, so cancelling ctx
// does not interrupt it; it keeps reading until the reader is exhausted.
barReader := bar.NewProxyReader(reader)
// copy from proxy reader
if _, err := io.Copy(writer, barReader); err != nil {
    log.Fatal(err)
}

Making io.Copy support context cancellation

io.Copy itself does not accept a context, so cancellation has to be added on the reader side. First, look at the original Reader interface:

1
2
3
// Reader is the standard library io.Reader interface, quoted for reference.
// Its single Read method is the hook that will be wrapped to check a
// context before each read.
type Reader interface {
    Read(p []byte) (n int, err error)
}

Now write a function type that implements io.Reader, and use it to check the context before every read:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
// readerFunc adapts a plain function to the io.Reader interface.
type readerFunc func(p []byte) (n int, err error)

// Read implements io.Reader by calling the function itself.
func (rf readerFunc) Read(p []byte) (n int, err error) { return rf(p) }

// copy works like io.Copy, but checks ctx before every Read on src, so
// cancelling ctx stops the transfer at the next chunk. After cancellation
// it returns ctx.Err(); otherwise it returns io.Copy's error (nil on
// success). Note: it shadows the builtin copy within the package.
func copy(ctx context.Context, dst io.Writer, src io.Reader) error {
    _, err := io.Copy(dst, readerFunc(func(p []byte) (int, error) {
        select {
        case <-ctx.Done():
            return 0, ctx.Err()
        default:
            return src.Read(p)
        }
    }))
    return err
}

io.Copy reads the source in chunks through a fixed-size buffer instead of loading a large file into memory at once. Checking the context before each chunk is read is therefore enough to stop the upload. The complete code is as follows:

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
package main

import (
    "context"
    "crypto/rand"
    "fmt"
    "io"
    "io/ioutil"
    "log"
    "os"
    "os/signal"
    "syscall"
    "time"

    "github.com/cheggaaa/pb/v3"
)

// readerFunc lets an ordinary function stand in for an io.Reader,
// in the same spirit as http.HandlerFunc for http.Handler.
type readerFunc func(p []byte) (n int, err error)

// Read satisfies io.Reader by invoking the wrapped function.
func (f readerFunc) Read(buf []byte) (int, error) { return f(buf) }

// copy streams src into dst like io.Copy, but consults ctx before each
// Read so a cancelled context halts the transfer at the next chunk. It
// yields ctx.Err() once ctx is done, or whatever io.Copy reported.
// It shadows the builtin copy within this package.
func copy(ctx context.Context, dst io.Writer, src io.Reader) error {
    guarded := func(buf []byte) (int, error) {
        if err := ctx.Err(); err != nil {
            return 0, err
        }
        return src.Read(buf)
    }
    _, err := io.Copy(dst, readerFunc(guarded))
    return err
}

// withContextFunc returns a child of ctx that is cancelled when the process
// receives SIGINT or SIGTERM. On the first such signal it calls f and then
// cancels the returned context. If ctx itself is cancelled first, the
// watcher goroutine exits without calling f.
func withContextFunc(ctx context.Context, f func()) context.Context {
    ctx, cancel := context.WithCancel(ctx)

    // Register the handler before returning rather than inside the
    // goroutine. Otherwise a signal that arrives before the goroutine is
    // scheduled gets the default action and kills the process without
    // running f.
    c := make(chan os.Signal, 1)
    signal.Notify(c, syscall.SIGINT, syscall.SIGTERM)

    go func() {
        // Restore default signal handling once we are done, so a second
        // Ctrl+C terminates the process immediately.
        defer signal.Stop(c)

        select {
        case <-ctx.Done():
        case <-c:
            f()
            cancel()
        }
    }()

    return ctx
}

func main() {

    ctx := withContextFunc(
        context.Background(),
        func() {
            // clear machine field
            log.Println("interrupt received, terminating process")
        },
    )

    var limit int64 = 1024 * 1024 * 10000
    // we will copy 10 Gb from /dev/rand to /dev/null
    reader := io.LimitReader(rand.Reader, limit)
    writer := ioutil.Discard

    // start new bar
    bar := pb.Full.Start64(limit)
    // finishCh is closed by main when the copy ends (successfully or not).
    finishCh := make(chan struct{})
    // done is closed when the reporter goroutine returns; main waits on it
    // instead of sleeping for an arbitrary amount of time.
    done := make(chan struct{})
    go func(bar *pb.ProgressBar) {
        defer close(done)
        d := time.NewTicker(2 * time.Second)
        // Stop the ticker on every exit path, including ctx cancellation.
        defer d.Stop()
        startTime := bar.StartTime()
        for {
            select {
            case <-ctx.Done():
                log.Println("stop to get current process")
                return
            case <-finishCh:
                log.Println("finished")
                return
            case <-d.C:
                if !bar.IsStarted() {
                    continue
                }
                // One snapshot per report keeps the three values consistent.
                current := bar.Current()
                // No bytes yet: remain/0 would be +Inf, and converting
                // that to time.Duration is implementation-defined.
                if current == 0 {
                    continue
                }
                // Average speed (bytes/s) since the bar started.
                lastSpeed := float64(current) / time.Since(startTime).Seconds()
                remain := float64(bar.Total() - current)
                remainDur := time.Duration(remain/lastSpeed) * time.Second
                fmt.Println("Progress:", float32(current)/float32(bar.Total())*100)
                fmt.Println("last speed:", lastSpeed/1024/1024)
                fmt.Println("remain duration:", remainDur)
            }
        }
    }(bar)
    // create proxy reader
    barReader := bar.NewProxyReader(reader)
    // copy from proxy reader; the context-aware copy stops at the next
    // chunk once ctx is cancelled by a signal.
    if err := copy(ctx, writer, barReader); err != nil {
        log.Println("cancel upload data:", err.Error())
    }
    // finish bar
    bar.Finish()
    close(finishCh)
    <-done
}