diff options
Diffstat (limited to 'modules/nosql/manager_redis.go')
-rw-r--r-- | modules/nosql/manager_redis.go | 205 |
1 files changed, 205 insertions, 0 deletions
diff --git a/modules/nosql/manager_redis.go b/modules/nosql/manager_redis.go new file mode 100644 index 0000000000..7792a90112 --- /dev/null +++ b/modules/nosql/manager_redis.go @@ -0,0 +1,205 @@ +// Copyright 2020 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package nosql + +import ( + "crypto/tls" + "path" + "strconv" + "strings" + + "github.com/go-redis/redis/v7" +) + +var replacer = strings.NewReplacer("_", "", "-", "") + +// CloseRedisClient closes a redis client +func (m *Manager) CloseRedisClient(connection string) error { + m.mutex.Lock() + defer m.mutex.Unlock() + client, ok := m.RedisConnections[connection] + if !ok { + connection = ToRedisURI(connection).String() + client, ok = m.RedisConnections[connection] + } + if !ok { + return nil + } + + client.count-- + if client.count > 0 { + return nil + } + + for _, name := range client.name { + delete(m.RedisConnections, name) + } + return client.UniversalClient.Close() +} + +// GetRedisClient gets a redis client for a particular connection +func (m *Manager) GetRedisClient(connection string) redis.UniversalClient { + m.mutex.Lock() + defer m.mutex.Unlock() + client, ok := m.RedisConnections[connection] + if ok { + client.count++ + return client + } + + uri := ToRedisURI(connection) + client, ok = m.RedisConnections[uri.String()] + if ok { + client.count++ + return client + } + client = &redisClientHolder{ + name: []string{connection, uri.String()}, + } + + opts := &redis.UniversalOptions{} + tlsConfig := &tls.Config{} + + // Handle username/password + if password, ok := uri.User.Password(); ok { + opts.Password = password + // Username does not appear to be handled by redis.Options + opts.Username = uri.User.Username() + } else if uri.User.Username() != "" { + // assume this is the password + opts.Password = uri.User.Username() + } + + // Now handle the uri query sets + for k, v := range uri.Query() { + switch replacer.Replace(strings.ToLower(k)) { + case "addr": + opts.Addrs = append(opts.Addrs, v...) + case "addrs": + opts.Addrs = append(opts.Addrs, strings.Split(v[0], ",")...) + case "username": + opts.Username = v[0] + case "password": + opts.Password = v[0] + case "database": + fallthrough + case "db": + opts.DB, _ = strconv.Atoi(v[0]) + case "maxretries": + opts.MaxRetries, _ = strconv.Atoi(v[0]) + case "minretrybackoff": + opts.MinRetryBackoff = valToTimeDuration(v) + case "maxretrybackoff": + opts.MaxRetryBackoff = valToTimeDuration(v) + case "timeout": + timeout := valToTimeDuration(v) + if timeout != 0 { + if opts.DialTimeout == 0 { + opts.DialTimeout = timeout + } + if opts.ReadTimeout == 0 { + opts.ReadTimeout = timeout + } + } + case "dialtimeout": + opts.DialTimeout = valToTimeDuration(v) + case "readtimeout": + opts.ReadTimeout = valToTimeDuration(v) + case "writetimeout": + opts.WriteTimeout = valToTimeDuration(v) + case "poolsize": + opts.PoolSize, _ = strconv.Atoi(v[0]) + case "minidleconns": + opts.MinIdleConns, _ = strconv.Atoi(v[0]) + case "pooltimeout": + opts.PoolTimeout = valToTimeDuration(v) + case "idletimeout": + opts.IdleTimeout = valToTimeDuration(v) + case "idlecheckfrequency": + opts.IdleCheckFrequency = valToTimeDuration(v) + case "maxredirects": + opts.MaxRedirects, _ = strconv.Atoi(v[0]) + case "readonly": + opts.ReadOnly, _ = strconv.ParseBool(v[0]) + case "routebylatency": + opts.RouteByLatency, _ = strconv.ParseBool(v[0]) + case "routerandomly": + opts.RouteRandomly, _ = strconv.ParseBool(v[0]) + case "sentinelmasterid": + fallthrough + case "mastername": + opts.MasterName = v[0] + case "skipverify": + fallthrough + case "insecureskipverify": + insecureSkipVerify, _ := strconv.ParseBool(v[0]) + tlsConfig.InsecureSkipVerify = insecureSkipVerify + case "clientname": + client.name = append(client.name, v[0]) + } + } + + switch uri.Scheme { + case "redis+sentinels": + fallthrough + case "rediss+sentinel": + opts.TLSConfig = tlsConfig + fallthrough + case "redis+sentinel": + if uri.Host != "" { + opts.Addrs = append(opts.Addrs, strings.Split(uri.Host, ",")...) + } + if uri.Path != "" { + if db, err := strconv.Atoi(uri.Path); err == nil { + opts.DB = db + } + } + + client.UniversalClient = redis.NewFailoverClient(opts.Failover()) + case "redis+clusters": + fallthrough + case "rediss+cluster": + opts.TLSConfig = tlsConfig + fallthrough + case "redis+cluster": + if uri.Host != "" { + opts.Addrs = append(opts.Addrs, strings.Split(uri.Host, ",")...) + } + if uri.Path != "" { + if db, err := strconv.Atoi(uri.Path); err == nil { + opts.DB = db + } + } + client.UniversalClient = redis.NewClusterClient(opts.Cluster()) + case "redis+socket": + simpleOpts := opts.Simple() + simpleOpts.Network = "unix" + simpleOpts.Addr = path.Join(uri.Host, uri.Path) + client.UniversalClient = redis.NewClient(simpleOpts) + case "rediss": + opts.TLSConfig = tlsConfig + fallthrough + case "redis": + if uri.Host != "" { + opts.Addrs = append(opts.Addrs, strings.Split(uri.Host, ",")...) + } + if uri.Path != "" { + if db, err := strconv.Atoi(uri.Path); err == nil { + opts.DB = db + } + } + client.UniversalClient = redis.NewClient(opts.Simple()) + default: + return nil + } + + for _, name := range client.name { + m.RedisConnections[name] = client + } + + client.count++ + + return client +} |