From a5eb781faea1da06979f77269712a9f5c609606b Mon Sep 17 00:00:00 2001 From: lrita Date: Fri, 25 Aug 2017 19:53:26 +0800 Subject: [PATCH 1/5] add MySQL backend implementant and support all methods Signed-off-by: lrita --- .travis.yml | 4 + README.md | 29 +- script/libkv.sql | 21 + store/mysql/mysql.go | 879 ++++++++++++++++++++++++++++++++++++++ store/mysql/mysql_test.go | 83 ++++ store/store.go | 4 + 6 files changed, 1006 insertions(+), 14 deletions(-) create mode 100644 script/libkv.sql create mode 100644 store/mysql/mysql.go create mode 100644 store/mysql/mysql_test.go diff --git a/.travis.yml b/.travis.yml index a7a3bcff..076c2c4b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,6 +3,9 @@ language: go go: - 1.7.1 +services: + - mysql + # let us have speedy Docker-based Travis workers sudo: false @@ -18,6 +21,7 @@ before_script: - script/travis_consul.sh 0.6.3 - script/travis_etcd.sh 3.0.0 - script/travis_zk.sh 3.5.1-alpha + - mysql -u root < script/libkv.sql script: - ./consul agent -server -bootstrap -advertise=127.0.0.1 -data-dir /tmp/consul -config-file=./config.json 1>/dev/null & diff --git a/README.md b/README.md index ff2cc446..f40eb382 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ For example, you can use it to store your metadata or for service discovery to r You can also easily implement a generic *Leader Election* on top of it (see the [docker/leadership](https://github.com/docker/leadership) repository). -As of now, `libkv` offers support for `Consul`, `Etcd`, `Zookeeper` (**Distributed** store) and `BoltDB` (**Local** store). +As of now, `libkv` offers support for `Consul`, `Etcd`, `Zookeeper` (**Distributed** store), `MySQL` and `BoltDB` (**Local** store). ## Usage @@ -34,6 +34,7 @@ You can find examples of usage for `libkv` under in `docs/examples.go`. Optional - Etcd versions >= `2.0` because it uses the new `coreos/etcd/client`, this might change in the future as the support for `APIv3` comes along and adds more capabilities. - Zookeeper versions >= `3.4.5`. Although this might work with previous version but this remains untested as of now. - Boltdb, which shouldn't be subject to any version dependencies. +- MySQL versions >= `5.1.73`. ## Interface @@ -62,19 +63,19 @@ Backend drivers in `libkv` are generally divided between **local drivers** and * Local drivers are usually used in complement to the distributed drivers to store informations that only needs to be available locally. -| Calls | Consul | Etcd | Zookeeper | BoltDB | -|-----------------------|:----------:|:------:|:-----------:|:--------:| -| Put | X | X | X | X | -| Get | X | X | X | X | -| Delete | X | X | X | X | -| Exists | X | X | X | X | -| Watch | X | X | X | | -| WatchTree | X | X | X | | -| NewLock (Lock/Unlock) | X | X | X | | -| List | X | X | X | X | -| DeleteTree | X | X | X | X | -| AtomicPut | X | X | X | X | -| Close | X | X | X | X | +| Calls | Consul | Etcd | Zookeeper | BoltDB | MySQL | +|-----------------------|:----------:|:------:|:-----------:|:--------:|:-------:| +| Put | X | X | X | X | X | +| Get | X | X | X | X | X | +| Delete | X | X | X | X | X | +| Exists | X | X | X | X | X | +| Watch | X | X | X | | X | +| WatchTree | X | X | X | | X | +| NewLock (Lock/Unlock) | X | X | X | | X | +| List | X | X | X | X | X | +| DeleteTree | X | X | X | X | X | +| AtomicPut | X | X | X | X | X | +| Close | X | X | X | X | X | ## Limitations diff --git a/script/libkv.sql b/script/libkv.sql new file mode 100644 index 00000000..3f3e1977 --- /dev/null +++ b/script/libkv.sql @@ -0,0 +1,21 @@ +CREATE DATABASE IF NOT EXISTS `libkv`; +USE `libkv`; +CREATE TABLE IF NOT EXISTS `libkv` ( + `id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT, + `field0` VARCHAR(127) NOT NULL, + `field1` VARCHAR(127) NOT NULL, + `field2` VARCHAR(127) NOT NULL, + `field3` VARCHAR(127) NOT NULL, + `field4` VARCHAR(127) NOT NULL, + `field5` VARCHAR(127) NOT NULL, + `field6` VARCHAR(127) NOT NULL, + `field7` VARCHAR(127) NOT NULL, + `lock_session` VARCHAR(64) NOT NULL DEFAULT '', + `last_index` BIGINT UNSIGNED NOT NULL, + `value` LONGTEXT NOT NULL, + `create_at` DATETIME NOT NULL, + `update_at` DATETIME NOT NULL, + PRIMARY KEY (`id`), + UNIQUE KEY `by_field` (`field0`,`field1`,`field2`, + `field3`,`field4`,`field5`,`field6`,`field7`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin; diff --git a/store/mysql/mysql.go b/store/mysql/mysql.go new file mode 100644 index 00000000..5f6d1895 --- /dev/null +++ b/store/mysql/mysql.go @@ -0,0 +1,879 @@ +package mysql + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "math/rand" + "os" + "path" + "strings" + "sync" + "time" + + "github.com/docker/libkv" + "github.com/docker/libkv/store" + + "github.com/go-sql-driver/mysql" +) + +var ( + // ErrInvalidCountEndpoint is thrown when there are + // multiple endpoints or no enough endpoint specified for MySQL. + ErrInvalidCountEndpoint = errors.New("mysql only support one endpoint") + + // ErrPutMissing is thrown when put failed. + ErrPutMissing = errors.New("mysql put missing") + + // DefaultWatchWaitTime is how long we block for at a + // time to check if the watched key has changed. This + // affects the minimum time it takes to cancel a watch. + // You can modify this at the init stage. + DefaultWatchWaitTime = time.Second + + // ErrAbortTryLock is thrown when a user stops trying to seek the lock + // by sending a signal to the stop chan, this is used to verify if the + // operation succeeded + ErrAbortTryLock = errors.New("lock operation aborted") + + // ErrLockLost is thrown when the lock has been lost + ErrLockLost = errors.New("lock lost") + + // MaxFields represents the count of key fields does the implementant supported. + // 'a' has 1 field, 'a/b' has 2 fields, 'a/b/c' has 3 fields, and so on. Each field + // has 127 bytes at mostly. + // When you want to expend the max fields, you can change the mysql table and change + // this constants, and no more code need to be changed. + // Caveat: It's use UNIQUE KEY indexing, but the innodb default max index key length + // is 3072. Hence, if you want more fields, you should reduce each field's length. + MaxFields = 8 + + // rnd is used to generate session string + rnd = rand.New(rand.NewSource(time.Now().UnixNano())) +) + +const ( + defaultTimeout = time.Second * 10 + defaultLockTTL = time.Second * 15 + defaultTTLPeriod = 3 +) + +// MySQL is a implementant of store.Store. At first, it's need to create the mysql +// database and table manually. There is a model "table.sql" in this package. The +// database's name and table's name is modifiable. +// You can use store.Config.Database, store.Config.Table, store.Config.Username to indicate, +// and the store.Config.Password is optional. Secondly, there only support 10 fields +// for the stored key(MaxFields), and it's scalability. +type MySQL struct { + db *sql.DB + table string + timeout time.Duration +} + +// Register registers mysql to libkv +func Register() { + libkv.AddStore(store.MYSQL, New) +} + +// New creates a new MySQL client given a list of endpoints +func New(endpoints []string, opts *store.Config) (store.Store, error) { + if len(endpoints) != 1 { + return nil, ErrInvalidCountEndpoint + } + + var passward string + if opts.Password != "" { + passward = ":" + opts.Password + } + + timeout := defaultTimeout + if opts.ConnectionTimeout != timeout { + timeout = opts.ConnectionTimeout + } + + db, err := sql.Open("mysql", + fmt.Sprintf("%s%s@tcp(%s)/%s?charset=utf8&interpolateParams=true&parseTime=True&loc=Local&timeout=%s", + opts.Username, passward, endpoints[0], opts.Database, timeout)) + if err != nil { + return nil, err + } + + return &MySQL{ + db: db, + table: opts.Table, + timeout: timeout, + }, nil +} + +func normalize(key string) string { + return strings.TrimPrefix(path.Clean(key), "/") +} + +func whereCond(key string, exact bool) string { + fields := strings.SplitN(normalize(key), "/", MaxFields) + tokens := make([]string, 0, MaxFields) + for i := range fields { + tokens = append(tokens, fmt.Sprintf("`field%d`=?", i)) + } + if l := len(tokens); exact && l < MaxFields { + for i := l; i < MaxFields; i++ { + tokens = append(tokens, fmt.Sprintf("`field%d`=?", i)) + } + } + return strings.Join(tokens, " AND ") +} + +func eachFields() string { + fields := make([]string, 0, MaxFields) + for i := 0; i < MaxFields; i++ { + fields = append(fields, fmt.Sprintf("`field%d`", i)) + } + return strings.Join(fields, ", ") +} + +func splitFields(key string, exact bool) []string { + tokens := make([]string, 0, MaxFields) + fields := strings.SplitN(normalize(key), "/", MaxFields) + tokens = append(tokens, fields...) + if l := len(fields); exact && l < MaxFields { + for i := l; i < MaxFields; i++ { + tokens = append(tokens, "") + } + } + return tokens +} + +// session generate a 64 bytes string to represent current session. +func session() string { + hn, _ := os.Hostname() + if l := len(hn); l > 30 { + hn = hn[l-30:] + } + return fmt.Sprintf("%s@%016X@%016X", hn, time.Now().UnixNano(), rnd.Int63()) +} + +// Close the MySQL connection +func (m *MySQL) Close() { + m.db.Close() +} + +// Get the value at "key", returns the last modified index +// to use in conjunction to CAS calls +func (m *MySQL) Get(key string) (*store.KVPair, error) { + args := make([]interface{}, 0, MaxFields) + for _, field := range splitFields(key, true) { + args = append(args, field) + } + + row := m.db.QueryRow( + fmt.Sprintf("SELECT `last_index`, `value` FROM `%s` WHERE %s LIMIT 1;", + m.table, whereCond(key, true)), args...) + + var ( + value []byte + index uint64 + ) + + if err := row.Scan(&index, &value); err != nil { + if err == sql.ErrNoRows { + err = store.ErrKeyNotFound + } + return nil, err + } + + return &store.KVPair{Key: key, Value: value, LastIndex: index}, nil +} + +// Put a value at "key". We cannot guarantee the systime on each machine are +// the same, hence it not support TTL. +func (m *MySQL) Put(key string, value []byte, _ *store.WriteOptions) error { + now := time.Now() + args := make([]interface{}, 0, MaxFields+6) + for _, field := range splitFields(key, true) { + args = append(args, field) + } + args = append(args, uint64(1), value, now, now, value, now) + + result, err := m.db.Exec( + fmt.Sprintf("INSERT INTO `%s` (%s, `last_index`, `value`, `create_at`, `update_at`) VALUES(%s?,?,?,?) ON DUPLICATE KEY UPDATE `last_index`=`last_index`+1, `value`=?, `update_at`=?;", + m.table, eachFields(), strings.Repeat("?,", MaxFields)), args...) + if err != nil { + return err + } + + affected, err := result.RowsAffected() + if err != nil { + return err + } + + if affected == 0 { + return ErrPutMissing + } + return nil +} + +// Exists checks that the key exists inside the store +func (m *MySQL) Exists(key string) (bool, error) { + _, err := m.Get(key) + if err != nil { + if err == store.ErrKeyNotFound { + return false, nil + } + return false, err + } + return true, nil +} + +// Delete a value at "key" +func (m *MySQL) Delete(key string) error { + args := make([]interface{}, 0, MaxFields) + for _, field := range splitFields(key, true) { + args = append(args, field) + } + + result, err := m.db.Exec( + fmt.Sprintf("DELETE FROM `%s` WHERE %s;", m.table, whereCond(key, true)), + args...) + if err != nil { + return err + } + + affected, err := result.RowsAffected() + if err != nil { + return err + } + + if affected == 0 { + return store.ErrKeyNotFound + } + return nil +} + +// List child nodes of a given directory +func (m *MySQL) List(directory string) ([]*store.KVPair, error) { + args := make([]interface{}, 0, MaxFields) + for _, field := range splitFields(directory, false) { + args = append(args, field) + } + + rows, err := m.db.Query( + fmt.Sprintf("SELECT %s, `last_index`, `value` FROM `%s` WHERE %s;", + eachFields(), m.table, whereCond(directory, false)), + args...) + if err != nil { + return nil, err + } + + defer rows.Close() + + var pairs []*store.KVPair + + for rows.Next() { + var ( + fields = make([]string, MaxFields) + inf = make([]interface{}, 0, MaxFields+2) + value []byte + index uint64 + ) + + for i := 0; i < len(fields); i++ { + inf = append(inf, &fields[i]) + } + inf = append(inf, &index, &value) + + if err := rows.Scan(inf...); err != nil { + return pairs, err + } + + pairs = append(pairs, &store.KVPair{ + Key: path.Join(fields...), + Value: value, + LastIndex: index, + }) + } + + if err := rows.Err(); err != nil { + return pairs, err + } + + if len(pairs) == 0 { + return nil, store.ErrKeyNotFound + } + + return pairs, nil +} + +// DeleteTree deletes a range of keys under a given directory +func (m *MySQL) DeleteTree(directory string) error { + args := make([]interface{}, 0, MaxFields) + for _, field := range splitFields(directory, false) { + args = append(args, field) + } + + result, err := m.db.Exec( + fmt.Sprintf("DELETE FROM `%s` WHERE %s;", m.table, whereCond(directory, false)), + args...) + if err != nil { + return err + } + + affected, err := result.RowsAffected() + if err != nil { + return err + } + + if affected == 0 { + return store.ErrKeyNotFound + } + return nil +} + +// Watch for changes on a "key" +// It returns a channel that will receive changes or pass +// on errors. Upon creation, the current value will first +// be sent to the channel. Providing a non-nil stopCh can +// be used to stop watching. +func (m *MySQL) Watch(key string, stopCh <-chan struct{}) (<-chan *store.KVPair, error) { + // Get the key first, and check the key is exist. + pair, err := m.Get(key) + if err != nil { + return nil, err + } + + watchCh := make(chan *store.KVPair, 1) + lastIndex := pair.LastIndex + watchCh <- pair + + go func() { + defer close(watchCh) + + tick := time.NewTicker(DefaultWatchWaitTime) + defer tick.Stop() + + for { + // Check if we should quit + select { + case <-tick.C: + case <-stopCh: + return + } + + // Get the key + pair, err := m.Get(key) + if err != nil && err != store.ErrKeyNotFound { + // keep the same behavior with other backend implementant. + return + } else if err == nil { + // If LastIndex didn't change then it means `Get` returned + // because of the WaitTime and the key didn't changed. + if lastIndex == pair.LastIndex { + continue + } + lastIndex = pair.LastIndex + select { + case watchCh <- pair: + case <-stopCh: + return + } + } + // else { + // // the key has been deleted. + // // Nothing to do with this, keep the + // // same behaivor with other backend. + // } + } + }() + + return watchCh, nil +} + +// WatchTree watches for changes on a "directory" +// It returns a channel that will receive changes or pass +// on errors. Upon creating a watch, the current childs values +// will be sent to the channel. Providing a non-nil stopCh can +// be used to stop watching. +func (m *MySQL) WatchTree(directory string, stopCh <-chan struct{}) (<-chan []*store.KVPair, error) { + directory = normalize(directory) + list, err := m.List(directory) + if err != nil { + return nil, err + } + + // record each children's index + indice := make(map[string]uint64) + watchCh := make(chan []*store.KVPair, 1) + for _, p := range list { + indice[p.Key] = p.LastIndex + } + + watchCh <- list + + go func() { + tick := time.NewTicker(DefaultWatchWaitTime) + defer func() { + tick.Stop() + close(watchCh) + }() + + for { + // Check if we should quit + select { + case <-tick.C: + case <-stopCh: + return + } + + // Get all the childrens + list, err := m.List(directory) + if err != nil { + return + } + + var ( + changed bool + pairs []*store.KVPair + exists = make(map[string]struct{}) + ) + + for _, p := range list { + exists[p.Key] = struct{}{} + if p.Key == directory { + continue + } + pairs = append(pairs, p) + // If LastIndex didn't change then it means `Get` returned + // because of the WaitTime and the child keys didn't change. + if p.LastIndex != indice[p.Key] { + changed = true + } + indice[p.Key] = p.LastIndex + } + + // find someone has been deleted + for key := range indice { + if _, ok := exists[key]; !ok { + changed = true + delete(indice, key) + } + } + + if changed { + select { + case watchCh <- pairs: + case <-stopCh: + return + } + } + } + }() + + return watchCh, nil +} + +// AtomicPut put a value at "key" if the key has not been +// modified in the meantime, throws an error if this is the case +func (m *MySQL) AtomicPut(key string, value []byte, previous *store.KVPair, + _ *store.WriteOptions) (bool, *store.KVPair, error) { + var ( + result sql.Result + err error + ) + + now := time.Now() + if previous == nil { + args := make([]interface{}, 0, MaxFields+4) + for _, field := range splitFields(key, true) { + args = append(args, field) + } + args = append(args, uint64(1), value, now, now) + result, err = m.db.Exec( + fmt.Sprintf("INSERT INTO `%s` (%s, `last_index`, `value`, `create_at`, `update_at`) VALUES(%s?,?,?,?);", + m.table, eachFields(), strings.Repeat("?,", MaxFields)), + args...) + } else { + args := make([]interface{}, 0, MaxFields+3) + args = append(args, value, now) + for _, field := range splitFields(key, true) { + args = append(args, field) + } + args = append(args, previous.LastIndex) + result, err = m.db.Exec( + fmt.Sprintf("UPDATE `%s` SET `value`=?, `last_index`=`last_index`+1, `update_at`=? WHERE %s AND `last_index`=?;", + m.table, whereCond(key, true)), + args...) + } + + if err != nil { + if merr, ok := err.(*mysql.MySQLError); ok && merr.Number == 1062 { + return false, nil, store.ErrKeyExists + } + return false, nil, err + } + + affected, err := result.RowsAffected() + if err != nil { + return false, nil, err + } + + if affected == 0 { + return false, nil, store.ErrKeyModified + } + + pair, err := m.Get(key) + if err != nil { + return false, nil, err + } + + return true, pair, nil +} + +// AtomicDelete deletes a value at "key" if the key has not +// been modified in the meantime, throws an error if this is the case +func (m *MySQL) AtomicDelete(key string, previous *store.KVPair) (ok bool, err error) { + if previous == nil { + return false, store.ErrPreviousNotSpecified + } + + // Extra Get operation to check on the key + if _, err := m.Get(key); err != nil { + return false, err + } + + args := make([]interface{}, 0, MaxFields+1) + for _, field := range splitFields(key, true) { + args = append(args, field) + } + args = append(args, previous.LastIndex) + + result, err := m.db.Exec( + fmt.Sprintf("DELETE FROM `%s` WHERE %s AND `last_index`=?;", + m.table, whereCond(key, true)), + args...) + if err != nil { + return false, err + } + + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + + if affected == 0 { + return false, store.ErrKeyModified + } + + return true, nil +} + +// NewLock creates a lock for a given key. +// The returned Locker is not held and must be acquired +// with `.Lock`. The Value is optional. +func (m *MySQL) NewLock(key string, options *store.LockOptions) (store.Locker, error) { + var ( + value []byte + ttl = defaultLockTTL + renewCh = make(chan struct{}) + ) + + if options != nil { + if options.TTL != 0 { + ttl = options.TTL + } + if options.Value != nil { + value = options.Value + } + if options.RenewLock != nil { + renewCh = options.RenewLock + } + } + + return &mysqlLock{ + key: normalize(key), + session: session(), + value: value, + ttl: ttl, + renewCh: renewCh, + m: m, + }, nil +} + +type mysqlLock struct { + mu sync.Mutex + wg sync.WaitGroup + key string + session string + value []byte + ttl time.Duration + m *MySQL + unlockCh chan struct{} + renewCh chan struct{} +} + +func (l *mysqlLock) acquire(lastIndex uint64, expired bool) (index uint64, ok bool, err error) { + var ( + tx *sql.Tx + result sql.Result + affected int64 + session string + ) + + if tx, err = l.m.db.Begin(); err != nil { + return + } + + defer func() { + if err == nil { + err = tx.Commit() + } else { + tx.Rollback() + } + }() + + args := make([]interface{}, 0, MaxFields) + for _, field := range splitFields(l.key, true) { + args = append(args, field) + } + + // lock the mysql row + row := tx.QueryRow( + fmt.Sprintf("SELECT `lock_session`, `last_index` FROM `%s` WHERE %s LIMIT 1 FOR UPDATE;", + l.m.table, whereCond(l.key, true)), + args...) + + if err = row.Scan(&session, &index); err != nil && err != sql.ErrNoRows { + return + } + + now := time.Now() + + switch session { + case "": + // the lock key is not exist, so create one + args := make([]interface{}, 0, MaxFields+5) + for _, field := range splitFields(l.key, true) { + args = append(args, field) + } + args = append(args, l.session, uint64(1), l.value, now, now) + result, err = tx.Exec( + fmt.Sprintf("INSERT INTO `%s` (%s,`lock_session`, `last_index`, `value`, `create_at`, `update_at`) VALUES(%s?,?,?,?,?);", + l.m.table, eachFields(), strings.Repeat("?,", MaxFields)), + args...) + case l.session: + // the current session is same with us, we fetch this lock directly + args := make([]interface{}, 0, MaxFields+2) + args = append(args, l.value, now) + for _, field := range splitFields(l.key, true) { + args = append(args, field) + } + result, err = tx.Exec( + fmt.Sprintf("UPDATE `%s` SET `last_index`=`last_index`+1, `value`=?, `update_at`=? WHERE %s;", + l.m.table, whereCond(l.key, true)), + args...) + default: + if !expired || lastIndex != index { + // someone is helding the lock + return + } + // someone lost this lock + args := make([]interface{}, 0, MaxFields+3) + args = append(args, l.session, l.value, now) + for _, field := range splitFields(l.key, true) { + args = append(args, field) + } + result, err = tx.Exec( + fmt.Sprintf("UPDATE `%s` SET `lock_session`=?, `last_index`=`last_index`+1, `value`=?, `update_at`=? WHERE %s;", + l.m.table, whereCond(l.key, true)), + args...) + } + + if err != nil { + if merr, ok := err.(*mysql.MySQLError); ok && merr.Number == 1062 { + err = nil // ignore insert failed on duplicate entry + } + return + } + + affected, err = result.RowsAffected() + if err != nil { + return + } + + if affected != 0 { + ok = true + index++ + } + return +} + +func (l *mysqlLock) renewLock() (err error) { + var ( + tx *sql.Tx + affected int64 + session string + result sql.Result + ) + + tx, err = l.m.db.Begin() + if err != nil { + return + } + + defer func() { + if err == nil { + err = tx.Commit() + } else { + tx.Rollback() + } + }() + + args := make([]interface{}, 0, MaxFields) + for _, field := range splitFields(l.key, true) { + args = append(args, field) + } + + row := tx.QueryRow( + fmt.Sprintf("SELECT `lock_session` FROM `%s` WHERE %s LIMIT 1 FOR UPDATE;", + l.m.table, whereCond(l.key, true)), + args...) + if err = row.Scan(&session); err != nil { + return + } + + if session != l.session { + return ErrLockLost + } + + args = args[:0] + args = append(args, time.Now()) + for _, field := range splitFields(l.key, true) { + args = append(args, field) + } + + result, err = tx.Exec( + fmt.Sprintf("UPDATE `%s` SET `last_index`=`last_index`+1, `update_at`=? WHERE %s;", + l.m.table, whereCond(l.key, true)), + args...) + if err != nil { + return + } + affected, err = result.RowsAffected() + if err != nil { + return + } + if affected == 0 { + err = ErrLockLost + } + return +} + +func (l *mysqlLock) holdLock(lockHeld chan struct{}, stopLocking, unlock <-chan struct{}) { + tick := time.NewTicker(l.ttl / defaultTTLPeriod) + + defer func() { + close(lockHeld) + tick.Stop() + l.wg.Done() + }() + + for { + select { + case <-tick.C: + if err := l.renewLock(); err != nil && err != driver.ErrBadConn { + return + } + case <-unlock: + return + case <-stopLocking: + return + } + } +} + +// Lock attempts to acquire the lock and blocks while +// doing so. It returns a channel that is closed if our +// lock is lost or if an error occurs +func (l *mysqlLock) Lock(stopChan chan struct{}) (<-chan struct{}, error) { + l.mu.Lock() + defer l.mu.Unlock() + + var ( + expired bool + count int + lastIndex uint64 + ttl = l.ttl / defaultTTLPeriod + ) + + // this lock.Lock is already invoked + if l.unlockCh != nil { + return nil, store.ErrCannotLock + } + + lockHeld := make(chan struct{}) + unlockCh := make(chan struct{}) + tick := time.NewTimer(0) + + defer tick.Stop() + + for { + select { + case <-tick.C: + case <-stopChan: + return nil, ErrAbortTryLock + } + + index, ok, err := l.acquire(lastIndex, expired) + if err == nil { + if ok { + l.wg.Add(1) + l.unlockCh = unlockCh + go l.holdLock(lockHeld, l.renewCh, unlockCh) + return lockHeld, nil + } else if index == lastIndex { + count++ + if count >= defaultTTLPeriod-1 { + expired = true + } + } else { + lastIndex = index + expired = false + } + } else if err != nil && err != driver.ErrBadConn { + return nil, err + } + tick.Reset(ttl) + } +} + +// Unlock the "key". +func (l *mysqlLock) Unlock() error { + l.mu.Lock() + defer l.mu.Unlock() + if l.unlockCh != nil { + close(l.unlockCh) + l.wg.Wait() + l.unlockCh = nil + + // delete the lock key + args := make([]interface{}, 0, MaxFields+1) + for _, field := range splitFields(l.key, true) { + args = append(args, field) + } + args = append(args, l.session) + result, err := l.m.db.Exec( + fmt.Sprintf("DELETE FROM `%s` WHERE %s AND `lock_session`=?;", + l.m.table, whereCond(l.key, true)), + args...) + if err != nil { + return err + } + affected, err := result.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + return ErrLockLost + } + } + return nil +} diff --git a/store/mysql/mysql_test.go b/store/mysql/mysql_test.go new file mode 100644 index 00000000..852d25b4 --- /dev/null +++ b/store/mysql/mysql_test.go @@ -0,0 +1,83 @@ +package mysql + +import ( + "testing" + "time" + + "github.com/docker/libkv" + "github.com/docker/libkv/store" + "github.com/docker/libkv/testutils" + "github.com/stretchr/testify/assert" +) + +var ( + testEndpoint = "localhost:3306" + testUser = "root" + testPassword = "" + testDatabase = "libkv" + testTable = "libkv" +) + +func init() { + DefaultWatchWaitTime = 100 * time.Millisecond +} + +func makeMySQLClient(t *testing.T) store.Store { + kv, err := New([]string{testEndpoint}, &store.Config{ + Username: testUser, + Password: testPassword, + Database: testDatabase, + Table: testTable, + }) + + if err != nil { + t.Fatalf("cannot create store: %v", err) + } + + return kv +} + +func TestRegister(t *testing.T) { + Register() + ss, err := libkv.NewStore(store.MYSQL, []string{testEndpoint}, &store.Config{ + Username: testUser, + Password: testPassword, + Database: testDatabase, + Table: testTable, + }) + assert.NoError(t, err) + assert.NotNil(t, ss) + + if _, ok := ss.(*MySQL); !ok { + t.Fatal("Error registering and initializing mysql") + } + + ss.Close() +} + +func TestMySQLStore(t *testing.T) { + kv := makeMySQLClient(t) + defer kv.Close() + lockKV := makeMySQLClient(t) + defer lockKV.Close() + + testutils.RunTestCommon(t, kv) + testutils.RunTestAtomic(t, kv) + testutils.RunTestWatch(t, kv) + testutils.RunTestLock(t, kv) + testutils.RunTestLockTTL(t, kv, lockKV) + testutils.RunCleanup(t, kv) +} + +func TestMySQLStoreExtra(t *testing.T) { + kv := makeMySQLClient(t) + defer kv.Close() + + ss, err := New([]string{"xx", "yy"}, nil) + assert.Equal(t, ErrInvalidCountEndpoint, err) + assert.Nil(t, ss) + + ok, err := kv.AtomicDelete("a/b/c", nil) + assert.Equal(t, store.ErrPreviousNotSpecified, err) + assert.False(t, ok) +} diff --git a/store/store.go b/store/store.go index 7a4850c0..aceffbc7 100644 --- a/store/store.go +++ b/store/store.go @@ -18,6 +18,8 @@ const ( ZK Backend = "zk" // BOLTDB backend BOLTDB Backend = "boltdb" + // MYSQL backend + MYSQL Backend = "mysql" ) var ( @@ -48,6 +50,8 @@ type Config struct { PersistConnection bool Username string Password string + Database string + Table string } // ClientTLSConfig contains data for a Client TLS configuration in the form From d542bf834636d8fbacf5621ccb0ddeb1e8ca3094 Mon Sep 17 00:00:00 2001 From: lrita Date: Thu, 31 Aug 2017 20:45:04 +0800 Subject: [PATCH 2/5] make Watch/WatchTree method can watch non-exist key We will get nil from the watching channel if the key isn't exist or has been deleted. Signed-off-by: lrita --- store/mysql/mysql.go | 45 ++++++++------- store/mysql/mysql_test.go | 112 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 21 deletions(-) diff --git a/store/mysql/mysql.go b/store/mysql/mysql.go index 5f6d1895..d9c24401 100644 --- a/store/mysql/mysql.go +++ b/store/mysql/mysql.go @@ -337,12 +337,15 @@ func (m *MySQL) DeleteTree(directory string) error { func (m *MySQL) Watch(key string, stopCh <-chan struct{}) (<-chan *store.KVPair, error) { // Get the key first, and check the key is exist. pair, err := m.Get(key) - if err != nil { + if err != nil && err != store.ErrKeyNotFound { return nil, err } watchCh := make(chan *store.KVPair, 1) - lastIndex := pair.LastIndex + lastIndex := uint64(0) + if pair != nil { + lastIndex = pair.LastIndex + } watchCh <- pair go func() { @@ -364,24 +367,24 @@ func (m *MySQL) Watch(key string, stopCh <-chan struct{}) (<-chan *store.KVPair, if err != nil && err != store.ErrKeyNotFound { // keep the same behavior with other backend implementant. return - } else if err == nil { - // If LastIndex didn't change then it means `Get` returned - // because of the WaitTime and the key didn't changed. - if lastIndex == pair.LastIndex { - continue - } - lastIndex = pair.LastIndex - select { - case watchCh <- pair: - case <-stopCh: - return - } } - // else { - // // the key has been deleted. - // // Nothing to do with this, keep the - // // same behaivor with other backend. - // } + + index := uint64(0) + if pair != nil { + index = pair.LastIndex + } + + // If index didn't change then it means `Get` returned + // because of the WaitTime and the key didn't changed. + if lastIndex == index { + continue + } + lastIndex = index + select { + case watchCh <- pair: + case <-stopCh: + return + } } }() @@ -396,7 +399,7 @@ func (m *MySQL) Watch(key string, stopCh <-chan struct{}) (<-chan *store.KVPair, func (m *MySQL) WatchTree(directory string, stopCh <-chan struct{}) (<-chan []*store.KVPair, error) { directory = normalize(directory) list, err := m.List(directory) - if err != nil { + if err != nil && err != store.ErrKeyNotFound { return nil, err } @@ -426,7 +429,7 @@ func (m *MySQL) WatchTree(directory string, stopCh <-chan struct{}) (<-chan []*s // Get all the childrens list, err := m.List(directory) - if err != nil { + if err != nil && err != store.ErrKeyNotFound { return } diff --git a/store/mysql/mysql_test.go b/store/mysql/mysql_test.go index 852d25b4..7b9dce02 100644 --- a/store/mysql/mysql_test.go +++ b/store/mysql/mysql_test.go @@ -80,4 +80,116 @@ func TestMySQLStoreExtra(t *testing.T) { ok, err := kv.AtomicDelete("a/b/c", nil) assert.Equal(t, store.ErrPreviousNotSpecified, err) assert.False(t, ok) + + testWatchNonExistKey(t, kv) + testWatchTreeNonExistKey(t, kv) +} + +func testWatchNonExistKey(t *testing.T, kv store.Store) { + nonexist := "test/watch/nonexist" + value0 := []byte("hello world") + value1 := []byte("hello world!!!") + stopCh := make(chan struct{}) + + watchCh, err := kv.Watch(nonexist, stopCh) + if !assert.NoError(t, err) { + return + } + + go func() { + for i := 0; i < 3; i++ { + var err error + time.Sleep(250 * time.Millisecond) + switch i { + case 0: + err = kv.Put(nonexist, value0, nil) + case 1: + err = kv.Put(nonexist, value1, nil) + case 2: + err = kv.Delete(nonexist) + } + assert.NoError(t, err) + } + }() + + eventCount := 0 + for { + select { + case pair := <-watchCh: + switch eventCount { + case 0: + assert.Nil(t, pair, "first must be nil, because of we watching a non-exit key") + case 1: + if assert.NotNil(t, pair) { + assert.Equal(t, value0, pair.Value) + } + case 2: + if assert.NotNil(t, pair) { + assert.Equal(t, value1, pair.Value) + } + case 3: + assert.Nil(t, pair, "last must be nil, because of the key has been deleted") + close(stopCh) + return + } + eventCount++ + case <-time.After(4 * time.Second): + t.Fatal("Timeout reached") + } + } +} + +func testWatchTreeNonExistKey(t *testing.T, kv store.Store) { + nonexist := "test/watchtree/nonexist" + key := "test/watchtree/nonexist/testkey" + value0 := []byte("hello world") + value1 := []byte("hello world!!!") + stopCh := make(chan struct{}) + + watchCh, err := kv.WatchTree(nonexist, stopCh) + if !assert.NoError(t, err) { + return + } + + go func() { + for i := 0; i < 3; i++ { + var err error + time.Sleep(250 * time.Millisecond) + switch i { + case 0: + err = kv.Put(key, value0, nil) + case 1: + err = kv.Put(key, value1, nil) + case 2: + err = kv.Delete(key) + } + assert.NoError(t, err) + } + }() + + eventCount := 0 + for { + select { + case pairs := <-watchCh: + switch eventCount { + case 0: + assert.Nil(t, pairs, "first must be nil, because of we watching a non-exit key") + case 1: + if assert.NotNil(t, pairs) { + assert.Equal(t, value0, pairs[0].Value) + } + case 2: + if assert.NotNil(t, pairs) { + assert.Equal(t, value1, pairs[0].Value) + } + case 3: + assert.Nil(t, pairs, "last must be nil, because of the key has been deleted") + close(stopCh) + return + } + eventCount++ + case <-time.After(4 * time.Second): + t.Fatal("Timeout reached") + } + } } From 3cf22708d5e566ed73e88a6b28e3c588b3488eae Mon Sep 17 00:00:00 2001 From: lrita Date: Mon, 11 Nov 2019 15:10:30 +0800 Subject: [PATCH 3/5] Add Read/Write-Timeout --- store/mysql/mysql.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/store/mysql/mysql.go b/store/mysql/mysql.go index d9c24401..7dc75524 100644 --- a/store/mysql/mysql.go +++ b/store/mysql/mysql.go @@ -93,8 +93,8 @@ func New(endpoints []string, opts *store.Config) (store.Store, error) { } db, err := sql.Open("mysql", - fmt.Sprintf("%s%s@tcp(%s)/%s?charset=utf8&interpolateParams=true&parseTime=True&loc=Local&timeout=%s", - opts.Username, passward, endpoints[0], opts.Database, timeout)) + fmt.Sprintf("%s%s@tcp(%s)/%s?charset=utf8&interpolateParams=true&parseTime=True&loc=Local&timeout=%s&readTimeout=%s&writeTimeout=%s", + opts.Username, passward, endpoints[0], opts.Database, timeout, timeout, timeout)) if err != nil { return nil, err } From ff1871aadf40a442516b3cdf06803f0c40572d18 Mon Sep 17 00:00:00 2001 From: lrita Date: Tue, 12 Nov 2019 15:36:30 +0800 Subject: [PATCH 4/5] ignore mysql.ErrInvalidConn at lock acquring --- store/mysql/mysql.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/store/mysql/mysql.go b/store/mysql/mysql.go index 7dc75524..33c445f4 100644 --- a/store/mysql/mysql.go +++ b/store/mysql/mysql.go @@ -572,6 +572,7 @@ func (m *MySQL) AtomicDelete(key string, previous *store.KVPair) (ok bool, err e // NewLock creates a lock for a given key. // The returned Locker is not held and must be acquired // with `.Lock`. The Value is optional. +// https://dev.mysql.com/doc/refman/8.0/en/innodb-transaction-isolation-levels.html func (m *MySQL) NewLock(key string, options *store.LockOptions) (store.Locker, error) { var ( value []byte @@ -841,7 +842,7 @@ func (l *mysqlLock) Lock(stopChan chan struct{}) (<-chan struct{}, error) { lastIndex = index expired = false } - } else if err != nil && err != driver.ErrBadConn { + } else if err != nil && err != driver.ErrBadConn && err != mysql.ErrInvalidConn { return nil, err } tick.Reset(ttl) From 12a10fd9661c9bf282f799c31e48d71148da3bdd Mon Sep 17 00:00:00 2001 From: lrita Date: Wed, 19 Aug 2020 11:20:29 +0800 Subject: [PATCH 5/5] add rejectReadOnly=true in db's dsn --- store/mysql/mysql.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/store/mysql/mysql.go b/store/mysql/mysql.go index 33c445f4..dfe19fd1 100644 --- a/store/mysql/mysql.go +++ b/store/mysql/mysql.go @@ -93,7 +93,7 @@ func New(endpoints []string, opts *store.Config) (store.Store, error) { } db, err := sql.Open("mysql", - fmt.Sprintf("%s%s@tcp(%s)/%s?charset=utf8&interpolateParams=true&parseTime=True&loc=Local&timeout=%s&readTimeout=%s&writeTimeout=%s", + fmt.Sprintf("%s%s@tcp(%s)/%s?charset=utf8&interpolateParams=true&parseTime=True&loc=Local&timeout=%s&readTimeout=%s&writeTimeout=%s&rejectReadOnly=true", opts.Username, passward, endpoints[0], opts.Database, timeout, timeout, timeout)) if err != nil { return nil, err