polardbxoperator/pkg/binlogtool/cmd/flashback.go

468 lines
13 KiB
Go

/*
Copyright 2022 Alibaba Group Holding Limited.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cmd
import (
"bufio"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"regexp"
"strconv"
"strings"
"time"
"github.com/shopspring/decimal"
"github.com/spf13/cobra"
"github.com/alibaba/polardbx-operator/pkg/binlogtool/bitmap"
"github.com/alibaba/polardbx-operator/pkg/binlogtool/utils"
"github.com/alibaba/polardbx-operator/pkg/binlogtool/binlog"
"github.com/alibaba/polardbx-operator/pkg/binlogtool/binlog/event"
"github.com/alibaba/polardbx-operator/pkg/binlogtool/binlog/meta"
"github.com/alibaba/polardbx-operator/pkg/binlogtool/binlog/rows"
"github.com/alibaba/polardbx-operator/pkg/binlogtool/binlog/spec"
"github.com/alibaba/polardbx-operator/pkg/binlogtool/binlog/str"
)
// Raw command-line inputs of the flashback command. Except for
// flashbackFiles, each one is bound to a flag in init.
var (
// flashbackChecksum is the value of --checksum; it is handed to the binlog
// scanners as the checksum algorithm.
flashbackChecksum string
// flashbackFiles holds the positional arguments (binlog file paths); it is
// assigned in the command's PreRun.
flashbackFiles []string
// flashbackDuration is the value of --back/-b: how far before "now" the
// flashback point lies.
flashbackDuration time.Duration
// flashbackSchema is the value of --schema, a regular expression; an empty
// string disables schema filtering.
flashbackSchema string
// flashbackTable is the value of --table, a regular expression; an empty
// string disables table filtering.
flashbackTable string
)
// init declares the flags of the flashback command and attaches the command
// to rootCmd.
func init() {
	flags := flashbackCmd.Flags()

	flags.DurationVarP(&flashbackDuration, "back", "b", 5*time.Minute,
		"duration before now to flashback to")
	flags.StringVar(&flashbackSchema, "schema", "",
		"target schema (regex)")
	flags.StringVar(&flashbackTable, "table", "",
		"target table (regex)")
	flags.StringVar(&flashbackChecksum, "checksum", "crc32",
		"binlog checksum algorithm (ignored for binary log version v1, v3 and v4 after 3.6.1)")

	rootCmd.AddCommand(flashbackCmd)
}
// State derived from the raw inputs by the command's PreRun and consumed by
// its Run.
var (
// flashbackTime is time.Now() minus flashbackDuration, taken in PreRun.
flashbackTime time.Time
// flashbackBinlogFiles is the result of
// meta.ParseBinlogFilesAndSortByIndex applied to flashbackFiles.
flashbackBinlogFiles []meta.BinlogFile
// flashbackSchemaRe is compiled from wrapRegex(flashbackSchema); it stays
// nil when --schema is empty.
flashbackSchemaRe *regexp.Regexp
// flashbackTableRe is compiled from wrapRegex(flashbackTable); it stays
// nil when --table is empty.
flashbackTableRe *regexp.Regexp
)
// flashbackCmd is the "flashback" sub-command. It scans the given binlog
// files, decodes the delete-rows events of the selected tables and prints one
// INSERT statement per deleted row to stdout, last deleted row first.
//
// PreRun validates the arguments and fills the package-level flashback state
// (file list, flashback time, schema/table regexes); Run does the scan and
// the output.
var flashbackCmd = &cobra.Command{
	Use:   "flashback [flags] file [files...]",
	Short: "Generate SQL statements to rollback delete operations. (experimental)",
	Long:  "Generate SQL statements to rollback delete operations. (experimental)",
	PreRun: wrap(func(cmd *cobra.Command, args []string) error {
		if len(args) < 1 {
			return errors.New("please specify a binlog file")
		}
		flashbackFiles = args

		var err error
		flashbackBinlogFiles, err = meta.ParseBinlogFilesAndSortByIndex(flashbackFiles...)
		if err != nil {
			return err
		}

		flashbackTime = time.Now().Add(-flashbackDuration)

		// The regexes stay nil (no filtering) when the flag is empty.
		if flashbackSchema != "" {
			flashbackSchemaRe, err = regexp.Compile(wrapRegex(flashbackSchema))
			if err != nil {
				return fmt.Errorf("unable to parse regex: %w", err)
			}
		}
		if flashbackTable != "" {
			flashbackTableRe, err = regexp.Compile(wrapRegex(flashbackTable))
			if err != nil {
				return fmt.Errorf("unable to parse regex: %w", err)
			}
		}
		return nil
	}),
	Run: wrap(func(cmd *cobra.Command, args []string) error {
		flashbackTimestamp := uint32(flashbackTime.Unix())

		// Build binlog file scanners.
		scanners := make([]binlog.LogEventScanCloser, 0, len(flashbackBinlogFiles))
		for i := range flashbackBinlogFiles {
			// Per-iteration copy so that the open callback below captures
			// its own file.
			bf := flashbackBinlogFiles[i]
			s := binlog.NewLazyLogEventScanCloser(
				func() (io.ReadCloser, error) {
					f, err := os.Open(bf.File)
					if err != nil {
						return nil, err
					}
					return utils.NewSeekableBufferReader(f), nil
				},
				0,
				binlog.WithChecksumAlgorithm(flashbackChecksum),
				binlog.WithBinlogFile(bf.File),
				// The header filter accepts events stamped at or after the
				// flashback point.
				binlog.WithLogEventHeaderFilter(func(header event.LogEventHeader) bool {
					return header.EventTimestamp() >= flashbackTimestamp
				}),
				binlog.WithInterestedLogEventTypes(utils.ConvertAndConcatenateSlice([]string{"rows", "tablemap"}, guessEventType)...),
			)
			scanners = append(scanners, s)
		}
		sc := binlog.NewMultiLogEventScanCloser(scanners...)
		defer sc.Close()

		// Table context and row event mutate.
		//
		// tables maps a table ID to its table map event. A nil entry marks a
		// table that was seen but rejected by the schema/table regexes.
		tables := make(map[uint64]*event.TableMapEvent)
		var results []string

		s := binlog.NewFilterLogEventScanner(
			binlog.NewMutateLogEventScanner(
				binlog.NewFilterLogEventScanner(sc,
					// Filter and record table.
					func(off binlog.EventOffset, ev event.LogEvent) bool {
						if ev.EventHeader().EventTypeCode() != spec.TABLE_MAP_EVENT {
							return true
						}
						tableMapEvent, ok := ev.EventData().(*event.TableMapEvent)
						if !ok {
							return true
						}
						switch {
						case flashbackSchemaRe != nil && !flashbackSchemaRe.MatchString(tableMapEvent.Schema.String()),
							flashbackTableRe != nil && !flashbackTableRe.MatchString(tableMapEvent.Table.String()):
							tables[tableMapEvent.TableID] = nil
							return false
						default:
							tables[tableMapEvent.TableID] = tableMapEvent
							return true
						}
					},
					// Filter rows event.
					func(off binlog.EventOffset, ev event.LogEvent) bool {
						if rowsEvent, ok := ev.EventData().(event.RowsEventVariants); ok {
							table, ok := tables[rowsEvent.AsRowsEvent().TableID]
							// Table not recorded or table not filtered.
							return !ok || table != nil
						}
						return true
					},
				),
				// Decode delete rows events.
				func(off binlog.EventOffset, ev event.LogEvent) (event.LogEvent, error) {
					// errTableFilteredOut should never happen
					switch ev.EventHeader().EventTypeCode() {
					case spec.PRE_GA_DELETE_ROWS_EVENT, spec.DELETE_ROWS_EVENT_V1, spec.DELETE_ROWS_EVENT_V2:
						return decodeRowsEvents(tables, ev)
					default:
						return ev, nil
					}
				}),
			// Clean table if it's an end statement. Also generate result for rows event.
			func(off binlog.EventOffset, ev event.LogEvent) bool {
				if rowsEvent, ok := ev.EventData().(event.RowsEventVariants); ok {
					if deleteRowsEvent, ok := rowsEvent.(*rows.DecodedDeleteRowsEvent); ok {
						table := tables[rowsEvent.AsRowsEvent().TableID]
						results = append(results, newRollbackStatement(table, deleteRowsEvent)...)
					}
					if rowsEvent.AsRowsEvent().Flags&spec.STMT_END_F > 0 {
						tables = make(map[uint64]*event.TableMapEvent)
					}
				}
				return true
			},
		)

		// Run to end. The statements are collected as a side effect of the
		// last filter, so the returned events themselves are not needed.
		for {
			_, _, err := s.Next()
			if err != nil {
				if errors.Is(err, binlog.EOF) {
					break
				}
				return err
			}
		}

		// Print result (reverse). Write and flush failures are reported so
		// that a truncated script is never mistaken for a complete one.
		w := bufio.NewWriter(os.Stdout)
		for i := len(results) - 1; i >= 0; i-- {
			if _, err := fmt.Fprintln(w, results[i]); err != nil {
				return fmt.Errorf("unable to write result: %w", err)
			}
		}
		if err := w.Flush(); err != nil {
			return fmt.Errorf("unable to write result: %w", err)
		}
		return nil
	}),
}
// Copied from go-mysql driver

// digits01 and digits10 are lookup tables for two-digit numbers: for v in
// [0, 99], digits10[v] is the ASCII tens digit and digits01[v] the ASCII ones
// digit.
const (
	digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
	digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999"
)

// appendDateTime appends t to buf in the textual form
// "YYYY-MM-DD[ hh:mm:ss[.fffffffff]]" and returns the extended slice.
//
// The time-of-day part is left out when hour, minute, second and nanosecond
// are all zero; the fraction is left out when the nanosecond is zero, and
// otherwise written without trailing zeros. Years outside [1, 9999] yield an
// error together with the unmodified buf.
func appendDateTime(buf []byte, t time.Time) ([]byte, error) {
	year, month, day := t.Date()
	if year < 1 || year > 9999 {
		return buf, errors.New("year is not in the range [1, 9999]: " + strconv.Itoa(year))
	}
	hour, minute, second := t.Clock()
	nanos := t.Nanosecond()

	var out [len("2006-01-02T15:04:05.999999999")]byte
	// twoDigits writes v (0..99) as two ASCII digits at out[at], out[at+1].
	twoDigits := func(at, v int) {
		out[at], out[at+1] = digits10[v], digits01[v]
	}

	twoDigits(0, year/100)
	twoDigits(2, year%100)
	out[4] = '-'
	twoDigits(5, int(month))
	out[7] = '-'
	twoDigits(8, day)
	if hour == 0 && minute == 0 && second == 0 && nanos == 0 {
		return append(buf, out[:10]...), nil
	}

	out[10] = ' '
	twoDigits(11, hour)
	out[13] = ':'
	twoDigits(14, minute)
	out[16] = ':'
	twoDigits(17, second)
	if nanos == 0 {
		return append(buf, out[:19]...), nil
	}

	// Nine fractional digits, filled from the least significant one.
	out[19] = '.'
	for at := len(out) - 1; at > 19; at-- {
		out[at] = byte('0' + nanos%10)
		nanos /= 10
	}

	// Drop trailing zeros of the fraction.
	end := len(out)
	for end > 0 && out[end-1] == '0' {
		end--
	}
	return append(buf, out[:end]...), nil
}
// reserveBuffer returns buf extended by appendSize bytes. The existing
// content is kept; the added bytes are meant to be overwritten by the caller.
//
// When buf already has enough capacity it is resliced in place; otherwise a
// new backing array of len(buf)*2+appendSize bytes is allocated and the old
// content copied into it.
func reserveBuffer(buf []byte, appendSize int) []byte {
	wanted := len(buf) + appendSize
	if wanted <= cap(buf) {
		return buf[:wanted]
	}

	// Grow buffer exponentially
	grown := make([]byte, len(buf)*2+appendSize)
	copy(grown, buf)
	return grown[:wanted]
}
// escapeBytesQuotes appends v to buf, writing every single quote twice, and
// returns the extended slice. No other byte is altered.
func escapeBytesQuotes(buf, v []byte) []byte {
	w := len(buf)
	// Worst case: every byte is a quote and takes two bytes in the output.
	buf = reserveBuffer(buf, 2*len(v))

	for i := range v {
		switch b := v[i]; b {
		case '\'':
			buf[w], buf[w+1] = '\'', '\''
			w += 2
		default:
			buf[w] = b
			w++
		}
	}
	return buf[:w]
}
// escapeStringQuotes is the string counterpart of escapeBytesQuotes: it
// appends v to buf byte by byte, writing every single quote twice, and
// returns the extended slice.
func escapeStringQuotes(buf []byte, v string) []byte {
	w := len(buf)
	// Worst case: every byte is a quote and takes two bytes in the output.
	buf = reserveBuffer(buf, 2*len(v))

	for i := 0; i < len(v); i++ {
		switch b := v[i]; b {
		case '\'':
			buf[w], buf[w+1] = '\'', '\''
			w += 2
		default:
			buf[w] = b
			w++
		}
	}
	return buf[:w]
}
// interpolateParams replaces every '?' in query with the SQL literal of the
// corresponding element of args and returns the resulting statement text.
//
// Every '?' byte in query counts as a placeholder, so query must not contain
// question marks for any other purpose. driver.ErrSkip is returned when the
// number of placeholders differs from len(args) or when an argument has a
// type that is not handled below; an error from appendDateTime is returned
// as is.
//
// Literal forms:
//   - nil, and nil str.Str / str.Blob / []byte: NULL
//   - integers, floats, decimal.Decimal: bare number
//   - bool: 1 or 0
//   - time.Time: quoted date/datetime, '0000-00-00' for the zero time
//   - time.Duration: quoted TIME literal '[-]HH:MM:SS[.ffffff]'
//   - string, str.Str, json.RawMessage, bitmap.Bitmap: quoted text
//   - str.Blob, []byte: _binary'...'
//
// NOTE(review): text is escaped by doubling single quotes only; backslashes
// are emitted verbatim. That looks correct only when the statements are run
// with NO_BACKSLASH_ESCAPES in sql_mode — confirm against the target server.
func interpolateParams(query string, args []driver.Value) (string, error) {
	// Number of ? should be same to len(args)
	if strings.Count(query, "?") != len(args) {
		return "", driver.ErrSkip
	}

	buf := make([]byte, 0, 1024)
	argPos := 0
	var err error

	for i := 0; i < len(query); i++ {
		q := strings.IndexByte(query[i:], '?')
		if q == -1 {
			// No placeholder left: copy the tail and stop.
			buf = append(buf, query[i:]...)
			break
		}
		buf = append(buf, query[i:i+q]...)
		// i now sits on the '?'; the loop increment steps over it.
		i += q

		arg := args[argPos]
		argPos++

		if arg == nil {
			buf = append(buf, "NULL"...)
			continue
		}

		switch v := arg.(type) {
		case int8:
			buf = strconv.AppendInt(buf, int64(v), 10)
		case int16:
			buf = strconv.AppendInt(buf, int64(v), 10)
		case int32:
			buf = strconv.AppendInt(buf, int64(v), 10)
		case int64:
			buf = strconv.AppendInt(buf, v, 10)
		case uint8:
			buf = strconv.AppendUint(buf, uint64(v), 10)
		case uint16:
			buf = strconv.AppendUint(buf, uint64(v), 10)
		case uint32:
			buf = strconv.AppendUint(buf, uint64(v), 10)
		case uint64:
			// Handle uint64 explicitly because our custom ConvertValue emits unsigned values
			buf = strconv.AppendUint(buf, v, 10)
		case float32:
			buf = strconv.AppendFloat(buf, float64(v), 'g', -1, 32)
		case float64:
			buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
		case bool:
			if v {
				buf = append(buf, '1')
			} else {
				buf = append(buf, '0')
			}
		case time.Time:
			if v.IsZero() {
				buf = append(buf, "'0000-00-00'"...)
			} else {
				buf = append(buf, '\'')
				buf, err = appendDateTime(buf, v)
				if err != nil {
					return "", err
				}
				buf = append(buf, '\'')
			}
		case json.RawMessage:
			buf = append(buf, '\'')
			buf = escapeBytesQuotes(buf, v)
			buf = append(buf, '\'')
		case str.Str:
			if v == nil {
				buf = append(buf, "NULL"...)
			} else {
				buf = append(buf, '\'')
				buf = escapeBytesQuotes(buf, v)
				buf = append(buf, '\'')
			}
		case str.Blob:
			if v == nil {
				buf = append(buf, "NULL"...)
			} else {
				buf = append(buf, "_binary'"...)
				buf = escapeBytesQuotes(buf, v)
				buf = append(buf, '\'')
			}
		case []byte:
			if v == nil {
				buf = append(buf, "NULL"...)
			} else {
				buf = append(buf, "_binary'"...)
				buf = escapeBytesQuotes(buf, v)
				buf = append(buf, '\'')
			}
		case string:
			buf = append(buf, '\'')
			buf = escapeStringQuotes(buf, v)
			buf = append(buf, '\'')
		case bitmap.Bitmap:
			buf = append(buf, '\'')
			buf = append(buf, v.String()...)
			buf = append(buf, '\'')
		case decimal.Decimal:
			buf = append(buf, v.String()...)
		case time.Duration:
			// TIME literal: sign, then hours (may exceed two digits),
			// minutes and seconds reduced modulo 60, and an optional
			// six-digit microsecond fraction.
			d := v
			buf = append(buf, '\'')
			if d < 0 {
				buf = append(buf, '-')
				d = -d
			}
			hours := int64(d / time.Hour)
			minutes := int64(d % time.Hour / time.Minute)
			seconds := int64(d % time.Minute / time.Second)
			micros := int64(d % time.Second / time.Microsecond)
			buf = append(buf, fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds)...)
			if micros != 0 {
				buf = append(buf, fmt.Sprintf(".%06d", micros)...)
			}
			buf = append(buf, '\'')
		default:
			return "", driver.ErrSkip
		}
	}

	if argPos != len(args) {
		return "", driver.ErrSkip
	}
	return string(buf), nil
}
// repeat returns a slice of length n in which every element is s.
func repeat(s string, n int) []string {
	out := make([]string, n)
	for i := range out {
		out[i] = s
	}
	return out
}
// newRollbackStatement builds one INSERT statement per row in
// deleteRows.DecodedRows, targeting the schema and table named by table. The
// statements are returned in row order.
//
// Only the value list is passed through interpolateParams. The identifiers
// are added afterwards, so a '?' in a schema or table name cannot be taken
// for a placeholder, and backticks inside a name are doubled to keep the
// quoted identifier well-formed.
//
// It panics when a column value cannot be rendered by interpolateParams.
func newRollbackStatement(table *event.TableMapEvent, deleteRows *rows.DecodedDeleteRowsEvent) []string {
	statements := make([]string, 0, len(deleteRows.DecodedRows))
	if len(deleteRows.DecodedRows) == 0 {
		// Nothing to emit; return before table is dereferenced.
		return statements
	}

	// Identical for every row of the event, so build it once.
	header := "INSERT INTO `" + strings.ReplaceAll(table.Schema.String(), "`", "``") +
		"`.`" + strings.ReplaceAll(table.Table.String(), "`", "``") + "` VALUES (\n"

	for _, row := range deleteRows.DecodedRows {
		placeholders := strings.Join(repeat("\t?", len(row)), ", \n")
		values := utils.ConvertSlice(row, func(t rows.Column) driver.Value {
			return t.Value
		})

		tuple, err := interpolateParams(placeholders, values)
		if err != nil {
			panic("unable to interpolate param: " + err.Error())
		}
		statements = append(statements, header+tuple+"\n);")
	}
	return statements
}