diff --git a/driver/driver_options_test.go b/driver/driver_options_test.go index e2839a83d..a2874a353 100644 --- a/driver/driver_options_test.go +++ b/driver/driver_options_test.go @@ -266,8 +266,8 @@ func TestDriverOptions_namedValueChecker(t *testing.T) { } func createMockServer(t *testing.T) *testServer { - inMemProvider := server.NewInMemoryProvider() - require.NoError(t, inMemProvider.AddUser(*testUser, *testPassword)) + authHandler := server.NewInMemoryAuthenticationHandler() + require.NoError(t, authHandler.AddUser(*testUser, *testPassword)) defaultServer := server.NewDefaultServer() l, err := net.Listen("tcp", "127.0.0.1:3307") @@ -285,7 +285,7 @@ func createMockServer(t *testing.T) *testServer { } go func() { - co, err := s.NewCustomizedConn(conn, inMemProvider, handler) + co, err := s.NewCustomizedConn(conn, authHandler, handler) if err != nil { return } diff --git a/mysql/util.go b/mysql/util.go index e9837bf60..548899375 100644 --- a/mysql/util.go +++ b/mysql/util.go @@ -56,13 +56,13 @@ func CalcNativePassword(scramble, password []byte) []byte { return Xor(scrambleHash, stage1) } -// Xor modifies hash1 in-place with XOR against hash2 +// Xor returns a new slice with hash1 XOR hash2, wrapping hash2 if hash1 is longer. func Xor(hash1 []byte, hash2 []byte) []byte { - l := min(len(hash1), len(hash2)) - for i := range l { - hash1[i] ^= hash2[i] + result := make([]byte, len(hash1)) + for i := range hash1 { + result[i] = hash1[i] ^ hash2[i%len(hash2)] } - return hash1 + return result } // hash_stage1 = xor(reply, sha1(public_seed, hash_stage2)) diff --git a/server/auth.go b/server/auth.go index e675ed7ef..b0dc97552 100644 --- a/server/auth.go +++ b/server/auth.go @@ -30,15 +30,15 @@ func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) err return c.serverConf.authProvider.Authenticate(c, authPluginName, clientAuthData) } -func (c *Conn) acquirePassword() error { - if c.credential.Password != "" { +func (c *Conn) acquireCredential() error { + if len(c.credential.Passwords) > 0 { return nil } - credential, found, err := c.credentialProvider.GetCredential(c.user) + credential, found, err := c.authHandler.GetCredential(c.user) if err != nil { return err } - if !found { + if !found || len(credential.Passwords) == 0 { return mysql.NewDefaultError(mysql.ER_NO_SUCH_USER, c.user, c.RemoteAddr().String()) } c.credential = credential @@ -67,18 +67,24 @@ func scrambleValidation(cached, nonce, scramble []byte) bool { func (c *Conn) compareNativePasswordAuthData(clientAuthData []byte, credential Credential) error { if len(clientAuthData) == 0 { - if credential.Password == "" { + if credential.hasEmptyPassword() { return nil } return ErrAccessDeniedNoPassword } - password, err := mysql.DecodePasswordHex(c.credential.Password) - if err != nil { - return ErrAccessDenied - } - if mysql.CompareNativePassword(clientAuthData, password, c.salt) { - return nil + for _, password := range credential.Passwords { + hash, err := credential.hashPassword(password) + if err != nil { + continue + } + decoded, err := mysql.DecodePasswordHex(hash) + if err != nil { + continue + } + if mysql.CompareNativePassword(clientAuthData, decoded, c.salt) { + return nil + } } return ErrAccessDenied } @@ -86,7 +92,7 @@ func (c *Conn) compareNativePasswordAuthData(clientAuthData []byte, credential C func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, credential Credential) error { // Empty passwords are not hashed, but sent as empty string if len(clientAuthData) == 0 { - if credential.Password == "" { + if credential.hasEmptyPassword() { return nil } return ErrAccessDeniedNoPassword @@ -112,12 +118,18 @@ func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, credential C clientAuthData = clientAuthData[:l-1] } } - check, err := mysql.Check256HashingPassword([]byte(credential.Password), string(clientAuthData)) - if err != nil { - return err - } - if check { - return nil + for _, password := range credential.Passwords { + hash, err := credential.hashPassword(password) + if err != nil { + continue + } + check, err := mysql.Check256HashingPassword([]byte(hash), string(clientAuthData)) + if err != nil { + continue + } + if check { + return nil + } } return ErrAccessDenied } @@ -125,7 +137,7 @@ func (c *Conn) compareSha256PasswordAuthData(clientAuthData []byte, credential C func (c *Conn) compareCacheSha2PasswordAuthData(clientAuthData []byte) error { // Empty passwords are not hashed, but sent as empty string if len(clientAuthData) == 0 { - if c.credential.Password == "" { + if c.credential.hasEmptyPassword() { return nil } return ErrAccessDeniedNoPassword @@ -139,10 +151,8 @@ func (c *Conn) compareCacheSha2PasswordAuthData(clientAuthData []byte) error { // 'fast' auth: write "More data" packet (first byte == 0x01) with the second byte = 0x03 return c.writeAuthMoreDataFastAuth() } - - return ErrAccessDenied } - // cache miss, do full auth + // cache miss or validation failed, do full auth if err := c.writeAuthMoreDataFullAuth(); err != nil { return err } diff --git a/server/auth_switch_response.go b/server/auth_switch_response.go index 79ed4c34c..98b00260b 100644 --- a/server/auth_switch_response.go +++ b/server/auth_switch_response.go @@ -24,7 +24,7 @@ func (c *Conn) handleAuthSwitchResponse() error { } func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error { - if err := c.acquirePassword(); err != nil { + if err := c.acquireCredential(); err != nil { return err } if tlsConn, ok := c.Conn.Conn.(*tls.Conn); ok { @@ -72,15 +72,21 @@ func (c *Conn) handleCachingSha2PasswordFullAuth(authData []byte) error { func (c *Conn) checkSha2CacheCredentials(clientAuthData []byte, credential Credential) error { if len(clientAuthData) == 0 { - if credential.Password == "" { + if credential.hasEmptyPassword() { return nil } return ErrAccessDeniedNoPassword } - match, err := auth.CheckHashingPassword([]byte(credential.Password), string(clientAuthData), mysql.AUTH_CACHING_SHA2_PASSWORD) - if match && err == nil { - return nil + for _, password := range credential.Passwords { + hash, err := credential.hashPassword(password) + if err != nil { + continue + } + match, err := auth.CheckHashingPassword([]byte(hash), string(clientAuthData), mysql.AUTH_CACHING_SHA2_PASSWORD) + if match && err == nil { + return nil + } } return ErrAccessDenied } diff --git a/server/auth_switch_response_test.go b/server/auth_switch_response_test.go index 9b77ef260..10f505f94 100644 --- a/server/auth_switch_response_test.go +++ b/server/auth_switch_response_test.go @@ -30,7 +30,7 @@ func TestCheckSha2CacheCredentials_EmptyPassword(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &Conn{ - credential: Credential{Password: tt.serverPassword}, + credential: Credential{Passwords: []string{tt.serverPassword}}, } err := c.checkSha2CacheCredentials(tt.clientAuthData, c.credential) if tt.wantErr == nil { diff --git a/server/auth_test.go b/server/auth_test.go index a7f24227b..c86f9078a 100644 --- a/server/auth_test.go +++ b/server/auth_test.go @@ -37,7 +37,7 @@ func TestCompareNativePasswordAuthData_EmptyPassword(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &Conn{ - credential: Credential{Password: tt.serverPassword}, + credential: Credential{Passwords: []string{tt.serverPassword}}, } err := c.compareNativePasswordAuthData(tt.clientAuthData, c.credential) if tt.wantErr == nil { @@ -73,7 +73,7 @@ func TestCompareSha256PasswordAuthData_EmptyPassword(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &Conn{ - credential: Credential{Password: tt.serverPassword}, + credential: Credential{Passwords: []string{tt.serverPassword}}, } err := c.compareSha256PasswordAuthData(tt.clientAuthData, c.credential) if tt.wantErr == nil { @@ -109,7 +109,7 @@ func TestCompareCacheSha2PasswordAuthData_EmptyPassword(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &Conn{ - credential: Credential{Password: tt.serverPassword}, + credential: Credential{Passwords: []string{tt.serverPassword}}, } err := c.compareCacheSha2PasswordAuthData(tt.clientAuthData) if tt.wantErr == nil { diff --git a/server/authentication_handler.go b/server/authentication_handler.go new file mode 100644 index 000000000..941b76bce --- /dev/null +++ b/server/authentication_handler.go @@ -0,0 +1,127 @@ +package server + +import ( + "slices" + "sync" + + "github.com/go-mysql-org/go-mysql/mysql" + "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/parser/auth" +) + +// AuthenticationHandler provides user credentials and authentication lifecycle hooks. +// +// # Important Note +// +// if the password in a third-party auth handler could be updated at runtime, we have to invalidate the caching +// for 'caching_sha2_password' by calling 'func (s *Server)InvalidateCache(string, string)'. +type AuthenticationHandler interface { + // GetCredential returns the user credential (supports multiple valid passwords per user). + // Implementations must be safe for concurrent use. + GetCredential(username string) (credential Credential, found bool, err error) + + // OnAuthSuccess is called after successful authentication, before the OK packet. + // Return an error to reject the connection (error will be sent to client instead of OK). + // Return nil to proceed with sending the OK packet. + OnAuthSuccess(conn *Conn) error + + // OnAuthFailure is called after authentication fails, before the error packet. + // This is informational only - the connection will be closed regardless. + OnAuthFailure(conn *Conn, err error) +} + +func NewInMemoryAuthenticationHandler(defaultAuthMethod ...string) *InMemoryAuthenticationHandler { + d := mysql.AUTH_CACHING_SHA2_PASSWORD + if len(defaultAuthMethod) > 0 { + d = defaultAuthMethod[0] + } + return &InMemoryAuthenticationHandler{ + userPool: sync.Map{}, + defaultAuthMethod: d, + } +} + +// Credential holds authentication settings for a user. +// Passwords contains all valid raw passwords for the user. They are hashed on demand during comparison. +// If empty password authentication is allowed, Passwords must contain an empty string (e.g., []string{""}) +// rather than being a zero-length slice. A zero-length slice means no valid passwords are configured. +type Credential struct { + Passwords []string + AuthPluginName string +} + +// hashPassword computes the password hash for a given password using the credential's auth plugin. +func (c Credential) hashPassword(password string) (string, error) { + if password == "" { + return "", nil + } + + switch c.AuthPluginName { + case mysql.AUTH_NATIVE_PASSWORD: + return mysql.EncodePasswordHex(mysql.NativePasswordHash([]byte(password))), nil + + case mysql.AUTH_CACHING_SHA2_PASSWORD: + return auth.NewHashPassword(password, mysql.AUTH_CACHING_SHA2_PASSWORD), nil + + case mysql.AUTH_SHA256_PASSWORD: + return mysql.NewSha256PasswordHash(password) + + case mysql.AUTH_CLEAR_PASSWORD: + return password, nil + + default: + return "", errors.Errorf("unknown authentication plugin name '%s'", c.AuthPluginName) + } +} + +// hasEmptyPassword returns true if any password in the credential is empty. +func (c Credential) hasEmptyPassword() bool { + return slices.Contains(c.Passwords, "") +} + +// InMemoryAuthenticationHandler implements AuthenticationHandler with in-memory credential storage. +type InMemoryAuthenticationHandler struct { + userPool sync.Map // username -> Credential + defaultAuthMethod string +} + +func (h *InMemoryAuthenticationHandler) CheckUsername(username string) (found bool, err error) { + _, ok := h.userPool.Load(username) + return ok, nil +} + +func (h *InMemoryAuthenticationHandler) GetCredential(username string) (credential Credential, found bool, err error) { + v, ok := h.userPool.Load(username) + if !ok { + return Credential{}, false, nil + } + c, valid := v.(Credential) + if !valid { + return Credential{}, true, errors.Errorf("invalid credential") + } + return c, true, nil +} + +func (h *InMemoryAuthenticationHandler) AddUser(username, password string, optionalAuthPluginName ...string) error { + authPluginName := h.defaultAuthMethod + if len(optionalAuthPluginName) > 0 { + authPluginName = optionalAuthPluginName[0] + } + + if !isAuthMethodSupported(authPluginName) { + return errors.Errorf("unknown authentication plugin name '%s'", authPluginName) + } + + h.userPool.Store(username, Credential{ + Passwords: []string{password}, + AuthPluginName: authPluginName, + }) + return nil +} + +func (h *InMemoryAuthenticationHandler) OnAuthSuccess(conn *Conn) error { + return nil +} + +func (h *InMemoryAuthenticationHandler) OnAuthFailure(conn *Conn, err error) { +} diff --git a/server/authentication_handler_test.go b/server/authentication_handler_test.go new file mode 100644 index 000000000..8ac20ad21 --- /dev/null +++ b/server/authentication_handler_test.go @@ -0,0 +1,122 @@ +package server + +import ( + "database/sql" + "net" + "sync/atomic" + "testing" + "time" + + "github.com/go-mysql-org/go-mysql/mysql" + _ "github.com/go-sql-driver/mysql" + "github.com/pingcap/errors" + "github.com/stretchr/testify/require" +) + +type hookTrackingAuthenticationHandler struct { + *InMemoryAuthenticationHandler + onSuccessCalled atomic.Int32 + onFailureCalled atomic.Int32 + rejectOnSuccess bool +} + +func (h *hookTrackingAuthenticationHandler) OnAuthSuccess(conn *Conn) error { + h.onSuccessCalled.Add(1) + if h.rejectOnSuccess { + return errors.New("connection rejected by policy") + } + return nil +} + +func (h *hookTrackingAuthenticationHandler) OnAuthFailure(conn *Conn, err error) { + h.onFailureCalled.Add(1) +} + +func TestOnAuthSuccessCalled(t *testing.T) { + handler := &hookTrackingAuthenticationHandler{ + InMemoryAuthenticationHandler: NewInMemoryAuthenticationHandler(mysql.AUTH_NATIVE_PASSWORD), + } + require.NoError(t, handler.AddUser("testuser", "testpass")) + + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer l.Close() + + go func() { + conn, _ := l.Accept() + co, _ := NewDefaultServer().NewCustomizedConn(conn, handler, &EmptyHandler{}) + if co != nil { + for co.HandleCommand() == nil { + } + } + }() + + db, err := sql.Open("mysql", "testuser:testpass@tcp("+l.Addr().String()+")/test") + require.NoError(t, err) + defer db.Close() + db.SetConnMaxLifetime(time.Second) + + require.NoError(t, db.Ping()) + require.Equal(t, int32(1), handler.onSuccessCalled.Load()) + require.Equal(t, int32(0), handler.onFailureCalled.Load()) +} + +func TestOnAuthSuccessCanReject(t *testing.T) { + handler := &hookTrackingAuthenticationHandler{ + InMemoryAuthenticationHandler: NewInMemoryAuthenticationHandler(mysql.AUTH_NATIVE_PASSWORD), + rejectOnSuccess: true, + } + require.NoError(t, handler.AddUser("testuser", "testpass")) + + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer l.Close() + + go func() { + conn, _ := l.Accept() + co, _ := NewDefaultServer().NewCustomizedConn(conn, handler, &EmptyHandler{}) + if co != nil { + for co.HandleCommand() == nil { + } + } + }() + + db, err := sql.Open("mysql", "testuser:testpass@tcp("+l.Addr().String()+")/test") + require.NoError(t, err) + defer db.Close() + db.SetConnMaxLifetime(time.Second) + + err = db.Ping() + require.Error(t, err) + require.Contains(t, err.Error(), "connection rejected by policy") + require.Equal(t, int32(1), handler.onSuccessCalled.Load()) +} + +func TestOnAuthFailureCalled(t *testing.T) { + handler := &hookTrackingAuthenticationHandler{ + InMemoryAuthenticationHandler: NewInMemoryAuthenticationHandler(mysql.AUTH_NATIVE_PASSWORD), + } + require.NoError(t, handler.AddUser("testuser", "testpass")) + + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer l.Close() + + go func() { + conn, _ := l.Accept() + co, _ := NewDefaultServer().NewCustomizedConn(conn, handler, &EmptyHandler{}) + if co != nil { + for co.HandleCommand() == nil { + } + } + }() + + db, err := sql.Open("mysql", "testuser:wrongpass@tcp("+l.Addr().String()+")/test") + require.NoError(t, err) + defer db.Close() + db.SetConnMaxLifetime(time.Second) + + require.Error(t, db.Ping()) + require.Equal(t, int32(0), handler.onSuccessCalled.Load()) + require.Equal(t, int32(1), handler.onFailureCalled.Load()) +} diff --git a/server/caching_sha2_cache_test.go b/server/caching_sha2_cache_test.go index c43736ae7..b13ce9341 100644 --- a/server/caching_sha2_cache_test.go +++ b/server/caching_sha2_cache_test.go @@ -25,50 +25,50 @@ import ( // than the second connection (cache hit). Remember to set the password for MySQL user otherwise it won't cache empty password. func TestCachingSha2Cache(t *testing.T) { remoteProvider := &RemoteThrottleProvider{ - InMemoryProvider: NewInMemoryProvider(), + InMemoryAuthenticationHandler: NewInMemoryAuthenticationHandler(), } require.NoError(t, remoteProvider.AddUser(*testUser, *testPassword)) cacheServer := NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_CACHING_SHA2_PASSWORD, test_keys.PubPem, tlsConf) // no TLS suite.Run(t, &cacheTestSuite{ - server: cacheServer, - credProvider: remoteProvider, - tlsPara: "false", + server: cacheServer, + authHandler: remoteProvider, + tlsPara: "false", }) } func TestCachingSha2CacheTLS(t *testing.T) { remoteProvider := &RemoteThrottleProvider{ - InMemoryProvider: NewInMemoryProvider(), + InMemoryAuthenticationHandler: NewInMemoryAuthenticationHandler(), } require.NoError(t, remoteProvider.AddUser(*testUser, *testPassword)) cacheServer := NewServer("8.0.12", mysql.DEFAULT_COLLATION_ID, mysql.AUTH_CACHING_SHA2_PASSWORD, test_keys.PubPem, tlsConf) // TLS suite.Run(t, &cacheTestSuite{ - server: cacheServer, - credProvider: remoteProvider, - tlsPara: "skip-verify", + server: cacheServer, + authHandler: remoteProvider, + tlsPara: "skip-verify", }) } type RemoteThrottleProvider struct { - *InMemoryProvider + *InMemoryAuthenticationHandler getCredCallCount atomic.Int64 } func (m *RemoteThrottleProvider) GetCredential(username string) (credential Credential, found bool, err error) { m.getCredCallCount.Add(1) - return m.InMemoryProvider.GetCredential(username) + return m.InMemoryAuthenticationHandler.GetCredential(username) } type cacheTestSuite struct { suite.Suite - server *Server - serverAddr string - credProvider CredentialProvider - tlsPara string + server *Server + serverAddr string + authHandler AuthenticationHandler + tlsPara string db *sql.DB @@ -107,7 +107,7 @@ func (s *cacheTestSuite) onAccept() { func (s *cacheTestSuite) onConn(conn net.Conn) { // co, err := NewConn(conn, *testUser, *testPassword, &testHandler{s}) - co, err := s.server.NewCustomizedConn(conn, s.credProvider, &testCacheHandler{s}) + co, err := s.server.NewCustomizedConn(conn, s.authHandler, &testCacheHandler{s}) require.NoError(s.T(), err) for { err = co.HandleCommand() @@ -134,7 +134,7 @@ func (s *cacheTestSuite) TestCache() { require.NoError(s.T(), err) s.db.SetMaxIdleConns(4) s.runSelect() - got := s.credProvider.(*RemoteThrottleProvider).getCredCallCount.Load() + got := s.authHandler.(*RemoteThrottleProvider).getCredCallCount.Load() require.Equal(s.T(), int64(1), got) if s.db != nil { @@ -146,7 +146,7 @@ func (s *cacheTestSuite) TestCache() { require.NoError(s.T(), err) s.db.SetMaxIdleConns(4) s.runSelect() - got = s.credProvider.(*RemoteThrottleProvider).getCredCallCount.Load() + got = s.authHandler.(*RemoteThrottleProvider).getCredCallCount.Load() require.Equal(s.T(), int64(2), got) if s.db != nil { diff --git a/server/conn.go b/server/conn.go index 12e160bcb..1dd8cb77d 100644 --- a/server/conn.go +++ b/server/conn.go @@ -24,7 +24,7 @@ type Conn struct { warnings uint16 salt []byte // should be 8 + 12 for auth-plugin-data-part-1 and auth-plugin-data-part-2 - credentialProvider CredentialProvider + authHandler AuthenticationHandler user string credential Credential cachingSha2FullAuth bool @@ -54,21 +54,21 @@ func NewConn(conn net.Conn, user string, password string, h Handler) (*Conn, err // NewCustomizedConn: create connection with customized server settings // // Deprecated: Use [Server.NewCustomizedConn] instead. -func NewCustomizedConn(conn net.Conn, serverConf *Server, p CredentialProvider, h Handler) (*Conn, error) { - return serverConf.NewCustomizedConn(conn, p, h) +func NewCustomizedConn(conn net.Conn, serverConf *Server, authHandler AuthenticationHandler, h Handler) (*Conn, error) { + return serverConf.NewCustomizedConn(conn, authHandler, h) } // NewConn: create connection with default server settings func (s *Server) NewConn(conn net.Conn, user string, password string, h Handler) (*Conn, error) { - p := NewInMemoryProvider() - if err := p.AddUser(user, password); err != nil { + authHandler := NewInMemoryAuthenticationHandler() + if err := authHandler.AddUser(user, password); err != nil { return nil, err } - return s.NewCustomizedConn(conn, p, h) + return s.NewCustomizedConn(conn, authHandler, h) } -func (s *Server) NewCustomizedConn(conn net.Conn, p CredentialProvider, h Handler) (*Conn, error) { +func (s *Server) NewCustomizedConn(conn net.Conn, authHandler AuthenticationHandler, h Handler) (*Conn, error) { var packetConn *packet.Conn if s.tlsConfig != nil { packetConn = packet.NewTLSConn(conn) @@ -77,13 +77,13 @@ func (s *Server) NewCustomizedConn(conn net.Conn, p CredentialProvider, h Handle } c := &Conn{ - Conn: packetConn, - serverConf: s, - credentialProvider: p, - h: h, - connectionID: atomic.AddUint32(&baseConnID, 1), - stmts: make(map[uint32]*Stmt), - salt: mysql.RandomBuf(20), + Conn: packetConn, + serverConf: s, + authHandler: authHandler, + h: h, + connectionID: atomic.AddUint32(&baseConnID, 1), + stmts: make(map[uint32]*Stmt), + salt: mysql.RandomBuf(20), } c.closed.Store(false) @@ -109,6 +109,12 @@ func (c *Conn) handshake() error { err = mysql.NewDefaultError(mysql.ER_ACCESS_DENIED_ERROR, c.user, c.RemoteAddr().String(), mysql.MySQLErrName[usingPasswd]) } + c.authHandler.OnAuthFailure(c, err) + _ = c.writeError(err) + return err + } + + if err := c.authHandler.OnAuthSuccess(c); err != nil { _ = c.writeError(err) return err } diff --git a/server/credential_provider.go b/server/credential_provider.go deleted file mode 100644 index 552b62b9a..000000000 --- a/server/credential_provider.go +++ /dev/null @@ -1,112 +0,0 @@ -package server - -import ( - "sync" - - "github.com/go-mysql-org/go-mysql/mysql" - "github.com/pingcap/errors" - "github.com/pingcap/tidb/pkg/parser/auth" -) - -// interface for user credential provider -// hint: can be extended for more functionality -// -// # Important Note -// -// if the password in a third-party credential provider could be updated at runtime, we have to invalidate the caching -// for 'caching_sha2_password' by calling 'func (s *Server)InvalidateCache(string, string)'. -type CredentialProvider interface { - // check if the user exists - CheckUsername(username string) (bool, error) - // get user credential - GetCredential(username string) (credential Credential, found bool, err error) -} - -func NewInMemoryProvider(defaultAuthMethod ...string) *InMemoryProvider { - d := mysql.AUTH_CACHING_SHA2_PASSWORD - if len(defaultAuthMethod) > 0 { - d = defaultAuthMethod[0] - } - return &InMemoryProvider{ - userPool: sync.Map{}, - defaultAuthMethod: d, - } -} - -type Credential struct { - Password string - AuthPluginName string -} - -func NewCredential(password string, authPluginName string) (Credential, error) { - c := Credential{ - AuthPluginName: authPluginName, - } - - if password == "" { - c.Password = "" - return c, nil - } - - switch c.AuthPluginName { - case mysql.AUTH_NATIVE_PASSWORD: - c.Password = mysql.EncodePasswordHex(mysql.NativePasswordHash([]byte(password))) - - case mysql.AUTH_CACHING_SHA2_PASSWORD: - c.Password = auth.NewHashPassword(password, mysql.AUTH_CACHING_SHA2_PASSWORD) - - case mysql.AUTH_SHA256_PASSWORD: - hash, err := mysql.NewSha256PasswordHash(password) - if err != nil { - return c, err - } - c.Password = hash - - case mysql.AUTH_CLEAR_PASSWORD: - c.Password = password - - default: - return c, errors.Errorf("unknown authentication plugin name '%s'", c.AuthPluginName) - } - return c, nil -} - -// implements an in memory credential provider -type InMemoryProvider struct { - userPool sync.Map // username -> password - defaultAuthMethod string -} - -func (m *InMemoryProvider) CheckUsername(username string) (found bool, err error) { - _, ok := m.userPool.Load(username) - return ok, nil -} - -func (m *InMemoryProvider) GetCredential(username string) (credential Credential, found bool, err error) { - v, ok := m.userPool.Load(username) - if !ok { - return Credential{}, false, nil - } - c, valid := v.(Credential) - if !valid { - return Credential{}, true, errors.Errorf("invalid credential") - } - return c, true, nil -} - -func (m *InMemoryProvider) AddUser(username, password string, optionalAuthPluginName ...string) error { - authPluginName := m.defaultAuthMethod - if len(optionalAuthPluginName) > 0 { - authPluginName = optionalAuthPluginName[0] - } - - c, err := NewCredential(password, authPluginName) - if err != nil { - return err - } - - m.userPool.Store(username, c) - return nil -} - -type Provider InMemoryProvider diff --git a/server/handshake_resp.go b/server/handshake_resp.go index c50e2a612..7430ff1f7 100644 --- a/server/handshake_resp.go +++ b/server/handshake_resp.go @@ -200,7 +200,7 @@ func (c *Conn) handlePublicKeyRetrieval(authData []byte) (bool, error) { func (c *Conn) handleAuthMatch() (bool, error) { // if the client responds the handshake with a different auth method, the server will send the AuthSwitchRequest packet // to the client to ask the client to switch. - if err := c.acquirePassword(); err != nil { + if err := c.acquireCredential(); err != nil { return false, err } diff --git a/server/server_test.go b/server/server_test.go index c1e827f05..9be2d3eef 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -60,32 +60,32 @@ func prepareServerConf() []*Server { func Test(t *testing.T) { // general tests - inMemProvider := NewInMemoryProvider() + authHandler := NewInMemoryAuthenticationHandler() servers := prepareServerConf() // no TLS for _, svr := range servers { - inMemProvider.userPool.Clear() - err := inMemProvider.AddUser(*testUser, *testPassword, svr.defaultAuthMethod) + authHandler.userPool.Clear() + err := authHandler.AddUser(*testUser, *testPassword, svr.defaultAuthMethod) require.NoError(t, err) suite.Run(t, &serverTestSuite{ - server: svr, - credProvider: inMemProvider, - tlsPara: "false", + server: svr, + authHandler: authHandler, + tlsPara: "false", }) } // TLS if server supports for _, svr := range servers { if svr.tlsConfig != nil { - inMemProvider.userPool.Clear() - err := inMemProvider.AddUser(*testUser, *testPassword, svr.defaultAuthMethod) + authHandler.userPool.Clear() + err := authHandler.AddUser(*testUser, *testPassword, svr.defaultAuthMethod) require.NoError(t, err) suite.Run(t, &serverTestSuite{ - server: svr, - credProvider: inMemProvider, - tlsPara: "skip-verify", + server: svr, + authHandler: authHandler, + tlsPara: "skip-verify", }) } } @@ -93,8 +93,8 @@ func Test(t *testing.T) { type serverTestSuite struct { suite.Suite - server *Server - credProvider CredentialProvider + server *Server + authHandler AuthenticationHandler tlsPara string @@ -144,7 +144,7 @@ func (s *serverTestSuite) onAccept() { func (s *serverTestSuite) onConn(conn net.Conn) { // co, err := NewConn(conn, *testUser, *testPassword, &testHandler{s}) - co, err := s.server.NewCustomizedConn(conn, s.credProvider, &testHandler{s}) + co, err := s.server.NewCustomizedConn(conn, s.authHandler, &testHandler{s}) require.NoError(s.T(), err) // set SSL if defined for {