polardbxoperator/pkg/hpfs/task/task.go

352 lines
8.2 KiB
Go

/*
Copyright 2021 Alibaba Group Holding Limited.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package task
import (
"context"
"crypto/sha1"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"strconv"
"strings"
"time"
"github.com/google/uuid"
_ "modernc.org/sqlite"
)
// DDL statements executed by taskManager.init, in this order, to prepare the
// task database.
const (
// taskTableInitStmt creates the task table if it is missing. trace_id is
// UNIQUE; gmt_created and gmt_modified default to CURRENT_TIMESTAMP, and
// progress/status/error_msg default to 0/0/''.
taskTableInitStmt = `CREATE TABLE IF NOT EXISTS task (
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
gmt_created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
gmt_modified TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
trace_id VARCHAR(128) NOT NULL UNIQUE,
operation VARCHAR(64) NOT NULL,
details TEXT NOT NULL,
checksum VARCHAR(128) NOT NULL,
progress INTEGER NOT NULL DEFAULT 0,
status INTEGER NOT NULL DEFAULT 0,
error_msg TEXT NOT NULL DEFAULT ''
)`
// taskModifiedTimeTriggerInitStmt installs a trigger that refreshes
// gmt_modified on the row that was just updated.
// NOTE(review): the trigger's own UPDATE does not re-fire it only because
// SQLite's recursive_triggers pragma is off by default — confirm nothing
// enables it on this database.
taskModifiedTimeTriggerInitStmt = `CREATE TRIGGER IF NOT EXISTS task_gmt_modified_trigger AFTER UPDATE ON task
BEGIN UPDATE task SET gmt_modified = CURRENT_TIMESTAMP WHERE id = NEW.id; END`
)
// TaskStatus is the lifecycle state of a Task. It is stored as an integer in
// the task.status column, so the numeric values of existing states must never
// change.
type TaskStatus int

// Task states, numbered consecutively from zero.
const (
	Pending   TaskStatus = iota // 0
	Running                     // 1
	Complete                    // 2
	Error                       // 3
	Canceling                   // 4
	Cancel                      // 5
)
// Task mirrors one row of the task table. It carries no json tags, so String
// renders it with the Go field names as JSON keys, in declaration order.
type Task struct {
	// Id is the row's AUTOINCREMENT primary key.
	Id int64
	// GmtCreated is filled by the column default when the row is inserted.
	GmtCreated time.Time
	// GmtModified is refreshed by the update trigger.
	GmtModified time.Time
	// TraceId is unique across tasks; CreateTask generates one when empty.
	TraceId string
	// Operation names what the task does.
	Operation string
	// Details is stored verbatim; this package only hashes it.
	Details string
	// CheckSum is the hex SHA-1 of Details, computed by CreateTask.
	CheckSum string
	// Progress mirrors the progress column.
	Progress int
	// Status mirrors the status column.
	Status TaskStatus
	// ErrMsg mirrors the error_msg column.
	ErrMsg string
}

// String returns the JSON encoding of t, or an empty string if encoding fails.
func (t *Task) String() string {
	encoded, err := json.Marshal(t)
	if err != nil {
		return ""
	}
	return string(encoded)
}
// Manager stores tasks and tracks their progress. NewTaskManager returns the
// SQLite-backed implementation defined alongside this interface.
type Manager interface {
	// GetCheckSum returns the checksum that CreateTask records for a task's
	// Details, so callers can look up tasks with identical details.
	GetCheckSum(s string) string

	// CreateTask persists a new task and fills in its generated fields.
	CreateTask(task *Task) error
	// DeleteTask removes the task identified by task.Id.
	DeleteTask(task *Task) error

	// GetTaskById, GetTaskByTraceId and GetLastTaskByOperationAndCheckSum each
	// look up a single task; the implementation here returns (nil, nil) when
	// nothing matches.
	GetTaskById(id int64) (*Task, error)
	GetTaskByTraceId(traceId string) (*Task, error)
	GetLastTaskByOperationAndCheckSum(op, checksum string) (*Task, error)
	// ListTasks returns every task whose status is one of the given values.
	ListTasks(status ...TaskStatus) ([]Task, error)

	// UpdateTaskStatus unconditionally writes task's Status, Progress and ErrMsg.
	UpdateTaskStatus(task *Task) error
	// CasTaskStatus writes the same fields only when the stored status equals
	// old, and reports whether the swap happened.
	CasTaskStatus(task *Task, old TaskStatus) (bool, error)

	// Close releases the underlying storage.
	Close() error
}
// taskManager is the Manager implementation backed by a database/sql handle
// opened with the "sqlite" driver.
type taskManager struct {
// dbFile is the data source name handed to sql.Open.
dbFile string
// db is opened lazily by getDb and closed by Close; nil until first use.
db *sql.DB
}
// init opens the database handle if needed and ensures the task table and its
// gmt_modified trigger exist. The DDL statements run in order and the first
// error is returned as is.
func (t *taskManager) init() error {
	db, err := t.getDb()
	if err != nil {
		return err
	}
	for _, ddl := range []string{taskTableInitStmt, taskModifiedTimeTriggerInitStmt} {
		if _, execErr := db.Exec(ddl); execErr != nil {
			return execErr
		}
	}
	return nil
}
// getDb returns the cached database handle, opening it with the "sqlite"
// driver on first call and caching it on t.
//
// NOTE: the lazy open is not synchronized. newTaskManager opens the handle
// eagerly (via init) before the manager is handed out, which is what keeps
// later concurrent calls from racing here.
func (t *taskManager) getDb() (*sql.DB, error) {
	if t.db != nil {
		return t.db, nil
	}
	opened, err := sql.Open("sqlite", t.dbFile)
	if err != nil {
		return nil, err
	}
	t.db = opened
	return opened, nil
}
// checksum returns the lowercase hexadecimal SHA-1 digest of s.
func (t *taskManager) checksum(s string) string {
	digest := sha1.New()
	// hash.Hash writes never fail, so the write result carries no information.
	_, _ = io.WriteString(digest, s)
	sum := digest.Sum(nil)
	return fmt.Sprintf("%x", sum)
}

// GetCheckSum returns the lowercase hexadecimal SHA-1 digest of s, the same
// value CreateTask stores in Task.CheckSum for a task's Details.
func (t *taskManager) GetCheckSum(s string) string {
	return t.checksum(s)
}
// CreateTask inserts task as a new row and stores the generated row id in
// task.Id.
//
// Before inserting, it assigns a fresh UUID to task.TraceId when that field is
// empty, resets Progress to 0 and Status to Pending, and sets CheckSum to the
// SHA-1 of Details. These fields stay modified even if the insert fails.
func (t *taskManager) CreateTask(task *Task) error {
	db, err := t.getDb()
	if err != nil {
		return err
	}

	if task.TraceId == "" {
		task.TraceId = uuid.New().String()
	}
	task.Progress, task.Status = 0, Pending
	task.CheckSum = t.checksum(task.Details)

	result, err := db.Exec(`INSERT INTO task (trace_id, operation, details, checksum, progress, status) VALUES (?, ?, ?, ?, ?, ?)`,
		task.TraceId, task.Operation, task.Details, task.CheckSum, task.Progress, task.Status)
	if err != nil {
		return err
	}
	newId, err := result.LastInsertId()
	if err != nil {
		return err
	}
	task.Id = newId
	return nil
}
// UpdateTaskStatus unconditionally writes task.Status, task.Progress and
// task.ErrMsg to the row whose id is task.Id. It returns an error when no row
// was affected.
func (t *taskManager) UpdateTaskStatus(task *Task) error {
	db, err := t.getDb()
	if err != nil {
		return err
	}
	result, err := db.Exec(`UPDATE task SET status = ?, progress = ?, error_msg = ? WHERE id = ?`,
		task.Status, task.Progress, task.ErrMsg, task.Id)
	if err != nil {
		return err
	}
	affected, err := result.RowsAffected()
	switch {
	case err != nil:
		return err
	case affected == 0:
		return errors.New("task not found")
	default:
		return nil
	}
}
// CasTaskStatus atomically replaces the stored status old with task.Status,
// writing task.Progress and task.ErrMsg in the same statement.
//
// It returns (true, nil) when the row with id task.Id had status old and was
// updated. When no row matched, the status actually stored is read back into
// task.Status and (false, nil) is returned. If that read-back fails (for
// example sql.ErrNoRows because the task does not exist), the read error is
// returned instead.
//
// The transaction is committed only after a successful swap. Every other path
// rolls it back, so an error never leaves the transaction (and its pooled
// connection / SQLite lock) open, and never commits a write that is reported
// as failed.
func (t *taskManager) CasTaskStatus(task *Task, old TaskStatus) (bool, error) {
	db, err := t.getDb()
	if err != nil {
		return false, err
	}

	// Begin an RC transaction.
	tx, err := db.BeginTx(context.Background(), &sql.TxOptions{
		Isolation: sql.LevelReadCommitted,
	})
	if err != nil {
		return false, err
	}
	// Rollback after a successful Commit just returns sql.ErrTxDone, so it is
	// safe to defer unconditionally; it ends the transaction on every early
	// return below.
	defer func() { _ = tx.Rollback() }()

	rs, err := tx.Exec(`UPDATE task SET status = ?, progress = ?, error_msg = ? WHERE id = ? AND status = ?`,
		task.Status, task.Progress, task.ErrMsg, task.Id, old)
	if err != nil {
		return false, err
	}
	rowsAffected, err := rs.RowsAffected()
	if err != nil {
		return false, err
	}
	if rowsAffected == 0 {
		// Swap lost: nothing was written, so report the status that is stored.
		if err := tx.QueryRow(`SELECT status FROM task WHERE id = ?`, task.Id).Scan(&task.Status); err != nil {
			return false, err
		}
		return false, nil
	}

	if err := tx.Commit(); err != nil {
		return false, err
	}
	return true, nil
}
func (t *taskManager) GetTaskById(id int64) (*Task, error) {
db, err := t.getDb()
if err != nil {
return nil, err
}
task := &Task{}
err = db.QueryRow(`SELECT id, gmt_created, gmt_modified, trace_id, operation, details, checksum, progress, status, error_msg FROM task
WHERE id = ?`, id).Scan(&task.Id, &task.GmtCreated, &task.GmtModified, &task.TraceId, &task.Operation, &task.Details, &task.CheckSum, &task.Progress, &task.Status, &task.ErrMsg)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
return task, nil
}
func (t *taskManager) GetTaskByTraceId(traceId string) (*Task, error) {
db, err := t.getDb()
if err != nil {
return nil, err
}
task := &Task{}
err = db.QueryRow(`SELECT id, gmt_created, gmt_modified, trace_id, operation, details, checksum, progress, status, error_msg FROM task
WHERE trace_id = ?`, traceId).Scan(&task.Id, &task.GmtCreated, &task.GmtModified, &task.TraceId, &task.Operation, &task.Details, &task.CheckSum, &task.Progress, &task.Status, &task.ErrMsg)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
return task, nil
}
// GetLastTaskByOperationAndCheckSum returns the most recently created task
// with the given operation and Details checksum, or (nil, nil) if none exists.
//
// gmt_created is filled from CURRENT_TIMESTAMP, which only has one-second
// resolution, so tasks created within the same second tie on it. The
// AUTOINCREMENT id only ever increases, so it is used as a tie-breaker. That
// makes the task inserted last win deterministically instead of an arbitrary
// row among the ties.
func (t *taskManager) GetLastTaskByOperationAndCheckSum(op, checksum string) (*Task, error) {
	db, err := t.getDb()
	if err != nil {
		return nil, err
	}
	task := &Task{}
	err = db.QueryRow(`SELECT id, gmt_created, gmt_modified, trace_id, operation, details, checksum, progress, status, error_msg FROM task
WHERE operation = ? AND checksum = ? ORDER BY gmt_created DESC, id DESC LIMIT 1`, op, checksum).
		Scan(&task.Id, &task.GmtCreated, &task.GmtModified, &task.TraceId, &task.Operation, &task.Details, &task.CheckSum, &task.Progress, &task.Status, &task.ErrMsg)
	if err != nil {
		if err == sql.ErrNoRows {
			return nil, nil
		}
		return nil, err
	}
	return task, nil
}
// DeleteTask removes the row whose id is task.Id; no other field of task is
// read. The affected-row count is not checked, so deleting a task that does
// not exist succeeds.
func (t *taskManager) DeleteTask(task *Task) error {
	db, err := t.getDb()
	if err == nil {
		_, err = db.Exec(`DELETE FROM task WHERE id = ?`, task.Id)
	}
	return err
}
// ListTasks returns every task whose status is one of the given statuses.
// Called with no statuses, it returns an empty, non-nil slice without touching
// the database.
//
// Errors from the query, from scanning a row, or from iterating the result set
// are returned with a nil slice. An iteration failure is never reported as a
// shorter, successful result.
func (t *taskManager) ListTasks(status ...TaskStatus) ([]Task, error) {
	if len(status) == 0 {
		return make([]Task, 0), nil
	}
	db, err := t.getDb()
	if err != nil {
		return nil, err
	}

	// The IN list is built only from integer literals (TaskStatus is an int),
	// so formatting it into the statement cannot inject SQL.
	values := make([]string, len(status))
	for i, s := range status {
		values[i] = strconv.FormatInt(int64(s), 10)
	}
	stmt := fmt.Sprintf(`SELECT id, gmt_created, gmt_modified, trace_id, operation, details, checksum, progress, status, error_msg FROM task WHERE status IN (%s)`,
		strings.Join(values, ","))

	rs, err := db.Query(stmt)
	if err != nil {
		return nil, err
	}
	defer rs.Close()

	tasks := make([]Task, 0)
	for rs.Next() {
		var task Task
		if err := rs.Scan(&task.Id, &task.GmtCreated, &task.GmtModified, &task.TraceId, &task.Operation, &task.Details, &task.CheckSum, &task.Progress, &task.Status, &task.ErrMsg); err != nil {
			return nil, err
		}
		tasks = append(tasks, task)
	}
	// Next returns false both at the end of the rows and on failure; Err tells
	// the two apart.
	if err := rs.Err(); err != nil {
		return nil, err
	}
	return tasks, nil
}
// Close closes the database handle if getDb ever opened one; otherwise it is
// a no-op that returns nil. The handle field is left set after closing.
func (t *taskManager) Close() error {
	if t.db == nil {
		return nil
	}
	return t.db.Close()
}
// newTaskManager creates a taskManager for dbFile and runs init on it. If init
// fails, any handle opened so far is closed (its close error is discarded) and
// only the init error is returned.
func newTaskManager(dbFile string) (*taskManager, error) {
	manager := &taskManager{dbFile: dbFile}
	initErr := manager.init()
	if initErr == nil {
		return manager, nil
	}
	_ = manager.Close()
	return nil, initErr
}
// NewTaskManager opens the task database at dbFile, ensures the task table and
// its trigger exist, and returns a Manager backed by it.
//
// On failure it returns a nil Manager together with the error. The result of
// newTaskManager is not forwarded directly: a nil *taskManager stored in the
// Manager interface would produce a non-nil interface value, and callers
// comparing the Manager against nil would be misled.
func NewTaskManager(dbFile string) (Manager, error) {
	tm, err := newTaskManager(dbFile)
	if err != nil {
		return nil, err
	}
	return tm, nil
}