352 lines
8.2 KiB
Go
352 lines
8.2 KiB
Go
/*
Copyright 2021 Alibaba Group Holding Limited.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
|
|
|
|
package task
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha1"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
const (
	// taskTableInitStmt creates the task table. IF NOT EXISTS makes the
	// statement safe to run against a database file that already has it.
	// progress and status default to 0 and error_msg to the empty string;
	// both timestamps default to the time the row is inserted.
	taskTableInitStmt = `CREATE TABLE IF NOT EXISTS task (
		id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
		gmt_created TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
		gmt_modified TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
		trace_id VARCHAR(128) NOT NULL UNIQUE,
		operation VARCHAR(64) NOT NULL,
		details TEXT NOT NULL,
		checksum VARCHAR(128) NOT NULL,
		progress INTEGER NOT NULL DEFAULT 0,
		status INTEGER NOT NULL DEFAULT 0,
		error_msg TEXT NOT NULL DEFAULT ''
	)`

	// taskModifiedTimeTriggerInitStmt installs a trigger that refreshes
	// gmt_modified with CURRENT_TIMESTAMP after every UPDATE of a task row.
	// NOTE(review): the trigger body is itself an UPDATE on task; this
	// presumably relies on SQLite's recursive triggers being disabled (the
	// SQLite default) - confirm no connection option turns them on.
	taskModifiedTimeTriggerInitStmt = `CREATE TRIGGER IF NOT EXISTS task_gmt_modified_trigger AFTER UPDATE ON task
		BEGIN UPDATE task SET gmt_modified = CURRENT_TIMESTAMP WHERE id = NEW.id; END`
)
|
|
|
|
// TaskStatus is the state of a Task. It is stored as an integer in the
// status column of the task table.
type TaskStatus int

// Task states. The numeric values are what gets written to the database, so
// they are spelled out explicitly rather than derived with iota.
const (
	// Pending is the status CreateTask assigns to a new task; it also matches
	// the column default of 0.
	Pending TaskStatus = 0
	// Running has the stored value 1.
	Running TaskStatus = 1
	// Complete has the stored value 2.
	Complete TaskStatus = 2
	// Error has the stored value 3.
	Error TaskStatus = 3
	// Canceling has the stored value 4.
	Canceling TaskStatus = 4
	// Cancel has the stored value 5.
	Cancel TaskStatus = 5
)
|
|
|
|
type Task struct {
|
|
Id int64
|
|
GmtCreated time.Time
|
|
GmtModified time.Time
|
|
TraceId string
|
|
Operation string
|
|
Details string
|
|
CheckSum string
|
|
Progress int
|
|
Status TaskStatus
|
|
ErrMsg string
|
|
}
|
|
|
|
func (t *Task) String() string {
|
|
b, _ := json.Marshal(t)
|
|
return string(b)
|
|
}
|
|
|
|
type Manager interface {
|
|
GetCheckSum(s string) string
|
|
CreateTask(task *Task) error
|
|
UpdateTaskStatus(task *Task) error
|
|
CasTaskStatus(task *Task, old TaskStatus) (bool, error)
|
|
GetTaskById(id int64) (*Task, error)
|
|
GetTaskByTraceId(traceId string) (*Task, error)
|
|
GetLastTaskByOperationAndCheckSum(op, checksum string) (*Task, error)
|
|
DeleteTask(task *Task) error
|
|
ListTasks(status ...TaskStatus) ([]Task, error)
|
|
Close() error
|
|
}
|
|
|
|
// taskManager is the SQLite-backed implementation of Manager.
type taskManager struct {
	// dbFile is the data source name handed to sql.Open.
	dbFile string
	// db is the database handle; it stays nil until getDb opens it.
	db *sql.DB
}
|
|
|
|
func (t *taskManager) init() error {
|
|
db, err := t.getDb()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, err = db.Exec(taskTableInitStmt); err != nil {
|
|
return err
|
|
}
|
|
|
|
if _, err = db.Exec(taskModifiedTimeTriggerInitStmt); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *taskManager) getDb() (*sql.DB, error) {
|
|
if t.db == nil {
|
|
db, err := sql.Open("sqlite", t.dbFile)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
t.db = db
|
|
}
|
|
return t.db, nil
|
|
}
|
|
|
|
func (t *taskManager) checksum(s string) string {
|
|
h := sha1.New()
|
|
_, _ = io.WriteString(h, s)
|
|
return fmt.Sprintf("%x", h.Sum(nil))
|
|
}
|
|
|
|
func (t *taskManager) GetCheckSum(s string) string {
|
|
return t.checksum(s)
|
|
}
|
|
|
|
func (t *taskManager) CreateTask(task *Task) error {
|
|
db, err := t.getDb()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(task.TraceId) == 0 {
|
|
task.TraceId = uuid.New().String()
|
|
}
|
|
task.Progress = 0
|
|
task.Status = Pending
|
|
task.CheckSum = t.checksum(task.Details)
|
|
|
|
rs, err := db.Exec(`INSERT INTO task (trace_id, operation, details, checksum, progress, status) VALUES (?, ?, ?, ?, ?, ?)`,
|
|
task.TraceId, task.Operation, task.Details, task.CheckSum, task.Progress, task.Status)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
id, err := rs.LastInsertId()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
task.Id = id
|
|
|
|
return nil
|
|
}
|
|
|
|
func (t *taskManager) UpdateTaskStatus(task *Task) error {
|
|
db, err := t.getDb()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
rs, err := db.Exec(`UPDATE task SET status = ?, progress = ?, error_msg = ? WHERE id = ?`,
|
|
task.Status, task.Progress, task.ErrMsg, task.Id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rowsAffected, err := rs.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if rowsAffected == 0 {
|
|
return errors.New("task not found")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (t *taskManager) CasTaskStatus(task *Task, old TaskStatus) (bool, error) {
|
|
db, err := t.getDb()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
// Begin an RC transaction
|
|
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{
|
|
Isolation: sql.LevelReadCommitted,
|
|
})
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
rs, err := tx.Exec(`UPDATE task SET status = ?, progress = ?, error_msg = ? WHERE id = ? AND status = ?`,
|
|
task.Status, task.Progress, task.ErrMsg, task.Id, old)
|
|
if err != nil {
|
|
_ = tx.Commit()
|
|
return false, err
|
|
}
|
|
rowsAffected, err := rs.RowsAffected()
|
|
if err != nil {
|
|
_ = tx.Commit()
|
|
return false, err
|
|
}
|
|
if rowsAffected == 0 {
|
|
r := tx.QueryRow(`SELECT status FROM task WHERE id = ?`, task.Id)
|
|
if err := r.Scan(&task.Status); err != nil {
|
|
return false, err
|
|
}
|
|
_ = tx.Commit()
|
|
return false, nil
|
|
}
|
|
if err = tx.Commit(); err != nil {
|
|
return false, err
|
|
}
|
|
return true, err
|
|
}
|
|
|
|
func (t *taskManager) GetTaskById(id int64) (*Task, error) {
|
|
db, err := t.getDb()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
task := &Task{}
|
|
err = db.QueryRow(`SELECT id, gmt_created, gmt_modified, trace_id, operation, details, checksum, progress, status, error_msg FROM task
|
|
WHERE id = ?`, id).Scan(&task.Id, &task.GmtCreated, &task.GmtModified, &task.TraceId, &task.Operation, &task.Details, &task.CheckSum, &task.Progress, &task.Status, &task.ErrMsg)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
return task, nil
|
|
}
|
|
|
|
func (t *taskManager) GetTaskByTraceId(traceId string) (*Task, error) {
|
|
db, err := t.getDb()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
task := &Task{}
|
|
err = db.QueryRow(`SELECT id, gmt_created, gmt_modified, trace_id, operation, details, checksum, progress, status, error_msg FROM task
|
|
WHERE trace_id = ?`, traceId).Scan(&task.Id, &task.GmtCreated, &task.GmtModified, &task.TraceId, &task.Operation, &task.Details, &task.CheckSum, &task.Progress, &task.Status, &task.ErrMsg)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
return task, nil
|
|
}
|
|
|
|
func (t *taskManager) GetLastTaskByOperationAndCheckSum(op, checksum string) (*Task, error) {
|
|
db, err := t.getDb()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
task := &Task{}
|
|
err = db.QueryRow(`SELECT id, gmt_created, gmt_modified, trace_id, operation, details, checksum, progress, status, error_msg FROM task
|
|
WHERE operation = ? AND checksum = ? ORDER BY gmt_created DESC LIMIT 1`, op, checksum).
|
|
Scan(&task.Id, &task.GmtCreated, &task.GmtModified, &task.TraceId, &task.Operation, &task.Details, &task.CheckSum, &task.Progress, &task.Status, &task.ErrMsg)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
return task, nil
|
|
}
|
|
|
|
func (t *taskManager) DeleteTask(task *Task) error {
|
|
db, err := t.getDb()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = db.Exec(`DELETE FROM task WHERE id = ?`, task.Id)
|
|
return err
|
|
}
|
|
|
|
func (t *taskManager) ListTasks(status ...TaskStatus) ([]Task, error) {
|
|
if len(status) == 0 {
|
|
return make([]Task, 0), nil
|
|
}
|
|
|
|
db, err := t.getDb()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
values := make([]string, len(status))
|
|
for i := range status {
|
|
values[i] = strconv.FormatInt(int64(status[i]), 10)
|
|
}
|
|
|
|
stmt := fmt.Sprintf(`SELECT id, gmt_created, gmt_modified, trace_id, operation, details, checksum, progress, status, error_msg FROM task WHERE status IN (%s)`,
|
|
strings.Join(values, ","))
|
|
|
|
rs, err := db.Query(stmt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rs.Close()
|
|
|
|
tasks := make([]Task, 0, 0)
|
|
for rs.Next() {
|
|
var task Task
|
|
err = rs.Scan(&task.Id, &task.GmtCreated, &task.GmtModified, &task.TraceId, &task.Operation, &task.Details, &task.CheckSum, &task.Progress, &task.Status, &task.ErrMsg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
tasks = append(tasks, task)
|
|
}
|
|
|
|
return tasks, nil
|
|
}
|
|
|
|
func (t *taskManager) Close() error {
|
|
if t.db != nil {
|
|
return t.db.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func newTaskManager(dbFile string) (*taskManager, error) {
|
|
tm := &taskManager{
|
|
dbFile: dbFile,
|
|
}
|
|
|
|
if err := tm.init(); err != nil {
|
|
defer tm.Close()
|
|
return nil, err
|
|
}
|
|
|
|
return tm, nil
|
|
}
|
|
|
|
func NewTaskManager(dbFile string) (Manager, error) {
|
|
return newTaskManager(dbFile)
|
|
}
|