From 4aeed26a201eeba72826acd142db9c7ef20c4ab9 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 4 Jul 2011 11:41:54 -0700 Subject: [PATCH 1/6] If db isn't setup, show helpful tips and don't crash. --- mysql_test.go | 75 ++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 65 insertions(+), 10 deletions(-) diff --git a/mysql_test.go b/mysql_test.go index 99c78f3..b7e0337 100644 --- a/mysql_test.go +++ b/mysql_test.go @@ -10,18 +10,21 @@ import ( "os" "rand" "strconv" + "sync" "testing" ) -const ( - // Testing credentials, run the following on server client prior to running: - // create database gomysql_test; - // create database gomysql_test2; - // create database gomysql_test3; - // create user gomysql_test@localhost identified by 'abc123'; - // grant all privileges on gomysql_test.* to gomysql_test@localhost; - // grant all privileges on gomysql_test2.* to gomysql_test@localhost; +const instructions = `To run the GoMySQL tests, run the following on the server first: + + create database gomysql_test; + create database gomysql_test2; + create database gomysql_test3; + create user gomysql_test@localhost identified by 'abc123'; + grant all privileges on gomysql_test.* to gomysql_test@localhost; + grant all privileges on gomysql_test2.* to gomysql_test@localhost; +` +const ( // Testing settings TEST_HOST = "localhost" TEST_PORT = "3306" @@ -49,8 +52,10 @@ const ( ) var ( - db *Client - err os.Error + db *Client + err os.Error + checkOnce sync.Once + skipTests bool ) type SimpleRow struct { @@ -61,8 +66,40 @@ type SimpleRow struct { Date string } +func verifyConnections() { + db, err = DialTCP(TEST_HOST, TEST_USER, TEST_PASSWD, TEST_DBNAME) + if db != nil { + db.Close() + } + if err != nil { + skipTests = true + os.Stderr.Write([]byte(instructions)) + return + } + db, err = DialUnix(TEST_SOCK, TEST_USER, TEST_PASSWD, TEST_DBNAME) + if db != nil { + db.Close() + } + if err != nil { + skipTests = true + os.Stderr.Write([]byte(instructions)) + return + } +} + +func skipTest(t *testing.T) bool { + checkOnce.Do(verifyConnections) + if skipTests { + t.Logf("skipping test; see instructions") + } + return skipTests +} + // Test connect to server via TCP func TestDialTCP(t *testing.T) { + if skipTest(t) { + return + } t.Logf("Running DialTCP test to %s:%s", TEST_HOST, TEST_PORT) db, err = DialTCP(TEST_HOST, TEST_USER, TEST_PASSWD, TEST_DBNAME) if err != nil { @@ -78,6 +115,9 @@ func TestDialTCP(t *testing.T) { // Test connect to server via Unix socket func TestDialUnix(t *testing.T) { + if skipTest(t) { + return + } t.Logf("Running DialUnix test to %s", TEST_SOCK) db, err = DialUnix(TEST_SOCK, TEST_USER, TEST_PASSWD, TEST_DBNAME) if err != nil { @@ -93,6 +133,9 @@ func TestDialUnix(t *testing.T) { // Test connect to server with unprivileged database func TestDialUnixUnpriv(t *testing.T) { + if skipTest(t) { + return + } t.Logf("Running DialUnix test to unprivileged database %s", TEST_DBNAMEUP) db, err = DialUnix(TEST_SOCK, TEST_USER, TEST_PASSWD, TEST_DBNAMEUP) if err != nil { @@ -108,6 +151,9 @@ func TestDialUnixUnpriv(t *testing.T) { // Test connect to server with nonexistant database func TestDialUnixNonex(t *testing.T) { + if skipTest(t) { + return + } t.Logf("Running DialUnix test to nonexistant database %s", TEST_DBNAMEBAD) db, err = DialUnix(TEST_SOCK, TEST_USER, TEST_PASSWD, TEST_DBNAMEBAD) if err != nil { @@ -123,6 +169,9 @@ func TestDialUnixNonex(t *testing.T) { // Test connect with bad password func TestDialUnixBadPass(t *testing.T) { + if skipTest(t) { + return + } t.Logf("Running DialUnix test with bad password") db, err = DialUnix(TEST_SOCK, TEST_USER, TEST_BAD_PASSWD, TEST_DBNAME) if err != nil { @@ -138,6 +187,9 @@ func TestDialUnixBadPass(t *testing.T) { // Test queries on a simple table (create database, select, insert, update, drop database) func TestSimple(t *testing.T) { + if skipTest(t) { + return + } t.Logf("Running simple table tests") db, err = DialUnix(TEST_SOCK, TEST_USER, TEST_PASSWD, TEST_DBNAME) if err != nil { @@ -267,6 +319,9 @@ func TestSimple(t *testing.T) { // Test queries on a simple table (create database, select, insert, update, drop database) using a statement func TestSimpleStatement(t *testing.T) { + if skipTest(t) { + return + } t.Logf("Running simple table statement tests") db, err = DialUnix(TEST_SOCK, TEST_USER, TEST_PASSWD, TEST_DBNAME) if err != nil { From bb62395891c37ed9810acc7b565b11bc83335d80 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 4 Jul 2011 11:42:41 -0700 Subject: [PATCH 2/6] Remove some unnecessary type assertions in type switch. --- statement.go | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/statement.go b/statement.go index d155f78..d14a18e 100644 --- a/statement.go +++ b/statement.go @@ -123,7 +123,7 @@ func (s *Statement) BindParams(params ...interface{}) (err os.Error) { var t FieldType var d []byte // Switch on type - switch param.(type) { + switch p := param.(type) { // Nil case nil: t = FIELD_TYPE_NULL @@ -134,7 +134,7 @@ func (s *Statement) BindParams(params ...interface{}) (err os.Error) { } else { t = FIELD_TYPE_LONGLONG } - d = itob(param.(int)) + d = itob(p) // Uint case uint: if strconv.IntSize == 32 { @@ -142,57 +142,57 @@ func (s *Statement) BindParams(params ...interface{}) (err os.Error) { } else { t = FIELD_TYPE_LONGLONG } - d = uitob(param.(uint)) + d = uitob(p) // Int8 case int8: t = FIELD_TYPE_TINY - d = []byte{byte(param.(int8))} + d = []byte{byte(p)} // Uint8 case uint8: t = FIELD_TYPE_TINY - d = []byte{param.(uint8)} + d = []byte{p} // Int16 case int16: t = FIELD_TYPE_SHORT - d = i16tob(param.(int16)) + d = i16tob(p) // Uint16 case uint16: t = FIELD_TYPE_SHORT - d = ui16tob(param.(uint16)) + d = ui16tob(p) // Int32 case int32: t = FIELD_TYPE_LONG - d = i32tob(param.(int32)) + d = i32tob(p) // Uint32 case uint32: t = FIELD_TYPE_LONG - d = ui32tob(param.(uint32)) + d = ui32tob(p) // Int64 case int64: t = FIELD_TYPE_LONGLONG - d = i64tob(param.(int64)) + d = i64tob(p) // Uint64 case uint64: t = FIELD_TYPE_LONGLONG - d = ui64tob(param.(uint64)) + d = ui64tob(p) // Float32 case float32: t = FIELD_TYPE_FLOAT - d = f32tob(param.(float32)) + d = f32tob(p) // Float64 case float64: t = FIELD_TYPE_DOUBLE - d = f64tob(param.(float64)) + d = f64tob(p) // String case string: t = FIELD_TYPE_STRING - d = lcbtob(uint64(len(param.(string)))) - d = append(d, []byte(param.(string))...) + d = lcbtob(uint64(len(p))) + d = append(d, []byte(p)...) // Byte array case []byte: t = FIELD_TYPE_BLOB - d = lcbtob(uint64(len(param.([]byte)))) - d = append(d, param.([]byte)...) + d = lcbtob(uint64(len(p))) + d = append(d, p...) // Other types default: return &ClientError{CR_UNSUPPORTED_PARAM_TYPE, s.c.fmtError(CR_UNSUPPORTED_PARAM_TYPE_STR, reflect.ValueOf(param).Type(), k)} From 63776d20cdef31ec35079673803f89d5a44cd09d Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 4 Jul 2011 12:15:30 -0700 Subject: [PATCH 3/6] Remove more type assertions. --- statement.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/statement.go b/statement.go index d14a18e..12b1c47 100644 --- a/statement.go +++ b/statement.go @@ -694,32 +694,32 @@ func (s *Statement) getResult(types packetType) (eof bool, err os.Error) { // Log read result s.c.log(1, "Reading result packet from server") // Get result packet - p, err := s.c.r.readPacket(types) + pr, err := s.c.r.readPacket(types) if err != nil { return } // Process result packet - switch p.(type) { + switch p := pr.(type) { default: err = &ClientError{CR_UNKNOWN_ERROR, CR_UNKNOWN_ERROR_STR} case *packetOK: - err = handleOK(p.(*packetOK), s.c, &s.AffectedRows, &s.LastInsertId, &s.Warnings) + err = handleOK(p, s.c, &s.AffectedRows, &s.LastInsertId, &s.Warnings) case *packetError: - err = handleError(p.(*packetError), s.c) + err = handleError(p, s.c) case *packetEOF: eof = true - err = handleEOF(p.(*packetEOF), s.c) + err = handleEOF(p, s.c) case *packetPrepareOK: - err = handlePrepareOK(p.(*packetPrepareOK), s.c, s) + err = handlePrepareOK(p, s.c, s) case *packetParameter: - err = handleParam(p.(*packetParameter), s.c) + err = handleParam(p, s.c) case *packetField: - err = handleField(p.(*packetField), s.c, s.result) + err = handleField(p, s.c, s.result) case *packetResultSet: s.result = &Result{c: s.c} - err = handleResultSet(p.(*packetResultSet), s.c, s.result) + err = handleResultSet(p, s.c, s.result) case *packetRowBinary: - err = handleBinaryRow(p.(*packetRowBinary), s.c, s.result) + err = handleBinaryRow(p, s.c, s.result) } return } From e8fe63895b8de1aa313a1c871615507f3822aafe Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 4 Jul 2011 12:59:20 -0700 Subject: [PATCH 4/6] Add .gitignore. --- .gitignore | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..43c07ec --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +*~ +*.out +_test* +_obj* +_go* From f016c76bb28cd177220fef27b7be7d30445262c1 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 4 Jul 2011 12:59:37 -0700 Subject: [PATCH 5/6] Add UseResult method on Statement. --- mysql_test.go | 182 ++++++++++++++++++++++++++++++++------------------ result.go | 19 +++++- statement.go | 17 +++++ 3 files changed, 150 insertions(+), 68 deletions(-) diff --git a/mysql_test.go b/mysql_test.go index b7e0337..a0d862b 100644 --- a/mysql_test.go +++ b/mysql_test.go @@ -45,6 +45,7 @@ const ( UPDATE_SIMPLE = "UPDATE simple SET `text` = '%s', `datetime` = NOW() WHERE id = %d" UPDATE_SIMPLE_STMT = "UPDATE simple SET `text` = ?, `datetime` = NOW() WHERE id = ?" DROP_SIMPLE = "DROP TABLE `simple`" + DROP_SIMPLE_MAYBE = "DROP TABLE IF EXISTS `simple`" // All types table queries CREATE_ALLTYPES = "CREATE TABLE `all_types` (`id` SERIAL NOT NULL, `tiny_int` TINYINT NOT NULL, `tiny_uint` TINYINT UNSIGNED NOT NULL, `small_int` SMALLINT NOT NULL, `small_uint` SMALLINT UNSIGNED NOT NULL, `medium_int` MEDIUMINT NOT NULL, `medium_uint` MEDIUMINT UNSIGNED NOT NULL, `int` INT NOT NULL, `uint` INT UNSIGNED NOT NULL, `big_int` BIGINT NOT NULL, `big_uint` BIGINT UNSIGNED NOT NULL, `decimal` DECIMAL(10,4) NOT NULL, `float` FLOAT NOT NULL, `double` DOUBLE NOT NULL, `real` REAL NOT NULL, `bit` BIT(32) NOT NULL, `boolean` BOOLEAN NOT NULL, `date` DATE NOT NULL, `datetime` DATETIME NOT NULL, `timestamp` TIMESTAMP NOT NULL, `time` TIME NOT NULL, `year` YEAR NOT NULL, `char` CHAR(32) NOT NULL, `varchar` VARCHAR(32) NOT NULL, `tiny_text` TINYTEXT NOT NULL, `text` TEXT NOT NULL, `medium_text` MEDIUMTEXT NOT NULL, `long_text` LONGTEXT NOT NULL, `binary` BINARY(32) NOT NULL, `var_binary` VARBINARY(32) NOT NULL, `tiny_blob` TINYBLOB NOT NULL, `medium_blob` MEDIUMBLOB NOT NULL, `blob` BLOB NOT NULL, `long_blob` LONGBLOB NOT NULL, `enum` ENUM('a','b','c','d','e') NOT NULL, `set` SET('a','b','c','d','e') NOT NULL, `geometry` GEOMETRY NOT NULL) ENGINE = InnoDB CHARACTER SET utf8 COLLATE utf8_unicode_ci COMMENT = 'GoMySQL Test Suite All Types Table'" @@ -196,14 +197,15 @@ func TestSimple(t *testing.T) { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Create table") + db.Query(DROP_SIMPLE_MAYBE) err = db.Query(CREATE_SIMPLE) if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Insert 1000 records") rowMap := make(map[uint64][]string) for i := 0; i < 1000; i++ { @@ -216,21 +218,21 @@ func TestSimple(t *testing.T) { row := []string{fmt.Sprintf("%d", num), str1, str2} rowMap[db.LastInsertId] = row } - + t.Logf("Select inserted data") err = db.Query(SELECT_SIMPLE) if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Use result") res, err := db.UseResult() if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Validate inserted data") for { row := res.FetchRow() @@ -239,19 +241,23 @@ func TestSimple(t *testing.T) { } id := row[0].(uint64) num, str1, str2 := strconv.Itoa64(row[1].(int64)), row[2].(string), string(row[3].([]byte)) - if rowMap[id][0] != num || rowMap[id][1] != str1 || rowMap[id][2] != str2 { + expectRow, ok := rowMap[id] + if !ok { + t.Fatalf("read unexpected row number %d", id) + } + if expectRow[0] != num || expectRow[1] != str1 || expectRow[2] != str2 { t.Logf("String from database doesn't match local string") t.Fail() } } - + t.Logf("Free result") err = res.Free() if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Update some records") for i := uint64(0); i < 1000; i += 5 { rowMap[i+1][2] = randString(256) @@ -265,21 +271,21 @@ func TestSimple(t *testing.T) { t.Fail() } } - + t.Logf("Select updated data") err = db.Query(SELECT_SIMPLE) if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Store result") res, err = db.StoreResult() if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Validate updated data") for { row := res.FetchRow() @@ -294,7 +300,7 @@ func TestSimple(t *testing.T) { t.Fail() } } - + t.Logf("Free result") err = res.Free() if err != nil { @@ -308,7 +314,7 @@ func TestSimple(t *testing.T) { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Close connection") err = db.Close() if err != nil { @@ -317,82 +323,79 @@ func TestSimple(t *testing.T) { } } -// Test queries on a simple table (create database, select, insert, update, drop database) using a statement -func TestSimpleStatement(t *testing.T) { - if skipTest(t) { - return - } - t.Logf("Running simple table statement tests") - db, err = DialUnix(TEST_SOCK, TEST_USER, TEST_PASSWD, TEST_DBNAME) - if err != nil { - t.Logf("Error %s", err) - t.Fail() - } - - t.Logf("Init statement") +func insert1000Records(t *testing.T, db *Client) map[uint64][]string { stmt, err := db.InitStmt() if err != nil { - t.Logf("Error %s", err) - t.Fail() + t.Fatalf("InitStmt: %v", err) } - - t.Logf("Prepare create table") - err = stmt.Prepare(CREATE_SIMPLE) - if err != nil { - t.Logf("Error %s", err) - t.Fail() - } - - t.Logf("Execute create table") - err = stmt.Execute() - if err != nil { - t.Logf("Error %s", err) - t.Fail() - } - - t.Logf("Prepare insert") + err = stmt.Prepare(INSERT_SIMPLE_STMT) if err != nil { - t.Logf("Error %s", err) - t.Fail() + t.Logf("Prepare insert: %v", err) } - + t.Logf("Insert 1000 records") rowMap := make(map[uint64][]string) for i := 0; i < 1000; i++ { num, str1, str2 := rand.Int(), randString(32), randString(128) err = stmt.BindParams(num, str1, str2) if err != nil { - t.Logf("Error %s", err) - t.Fail() + t.Fatalf("Error %s", err) } err = stmt.Execute() if err != nil { - t.Logf("Error %s", err) - t.Fail() + t.Fatalf("Error %s", err) } row := []string{fmt.Sprintf("%d", num), str1, str2} rowMap[stmt.LastInsertId] = row } - + return rowMap +} + +// Test queries on a simple table (create database, select, insert, update, drop database) using a statement +func TestSimpleStatement(t *testing.T) { + if skipTest(t) { + return + } + t.Logf("Running simple table statement tests") + db, err = DialUnix(TEST_SOCK, TEST_USER, TEST_PASSWD, TEST_DBNAME) + if err != nil { + t.Logf("Error %s", err) + t.Fail() + } + + db.Query(DROP_SIMPLE_MAYBE) + err := db.Query(CREATE_SIMPLE) + if err != nil { + t.Fatalf("create table: %v", err) + } + defer db.Query(DROP_SIMPLE) + + rowMap := insert1000Records(t, db) + + stmt, err := db.InitStmt() + if err != nil { + t.Fatalf("InitStmt: %v", err) + } + t.Logf("Prepare select") err = stmt.Prepare(SELECT_SIMPLE) if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Execute select") err = stmt.Execute() if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Bind result") row := SimpleRow{} stmt.BindResult(&row.Id, &row.Number, &row.String, &row.Text, &row.Date) - + t.Logf("Validate inserted data") for { eof, err := stmt.Fetch() @@ -408,21 +411,21 @@ func TestSimpleStatement(t *testing.T) { t.Fail() } } - + t.Logf("Reset statement") err = stmt.Reset() if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Prepare update") err = stmt.Prepare(UPDATE_SIMPLE_STMT) if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Update some records") for i := uint64(0); i < 1000; i += 5 { rowMap[i+1][2] = randString(256) @@ -437,21 +440,21 @@ func TestSimpleStatement(t *testing.T) { t.Fail() } } - + t.Logf("Prepare select updated") err = stmt.Prepare(SELECT_SIMPLE) if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Execute select updated") err = stmt.Execute() if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Validate updated data") for { eof, err := stmt.Fetch() @@ -467,35 +470,35 @@ func TestSimpleStatement(t *testing.T) { t.Fail() } } - + t.Logf("Free result") err = stmt.FreeResult() if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Prepare drop") err = stmt.Prepare(DROP_SIMPLE) if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Execute drop") err = stmt.Execute() if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Close statement") err = stmt.Close() if err != nil { t.Logf("Error %s", err) t.Fail() } - + t.Logf("Close connection") err = db.Close() if err != nil { @@ -504,6 +507,53 @@ func TestSimpleStatement(t *testing.T) { } } +func TestStatementUseResult(t *testing.T) { + if skipTest(t) { + return + } + db, err := DialUnix(TEST_SOCK, TEST_USER, TEST_PASSWD, TEST_DBNAME) + if err != nil { + t.Fatalf("dial error: %v", err) + } + defer db.Close() + + db.Query(DROP_SIMPLE_MAYBE) + err = db.Query(CREATE_SIMPLE) + if err != nil { + t.Fatalf("create table: %v", err) + } + defer db.Query(DROP_SIMPLE) + + insert1000Records(t, db) + stmt, err := db.Prepare(SELECT_SIMPLE) + if err != nil { + t.Fatalf("Prepare select: %v", err) + } + err = stmt.Execute() + if err != nil { + t.Fatalf("Execute: %v", err) + } + res, err := stmt.UseResult() + if err != nil { + t.Fatalf("UseResult: %v", err) + } + nRows := 0 + for { + row := res.FetchRow() + if row == nil { + break + } + nRows++ + } + if nRows != 1000 { + t.Errorf("expected 1000 rows; got %d", nRows) + } + err = res.Free() + if err != nil { + t.Logf("Free result: %s", err) + } +} + // Benchmark connect/handshake via TCP func BenchmarkDialTCP(b *testing.B) { for i := 0; i < b.N; i++ { diff --git a/result.go b/result.go index ed8a758..31281cf 100644 --- a/result.go +++ b/result.go @@ -12,6 +12,10 @@ type Result struct { // Pointer to the client c *Client + // if non-nil, the Result came from a Statement, not + // via Client.Query. + s *Statement + // Fields fieldCount uint64 fieldPos uint64 @@ -71,6 +75,13 @@ func (r *Result) RowCount() uint64 { return 0 } +func (r *Result) getRow() (eof bool, err os.Error) { + if r.s != nil { + return r.s.getRow() + } + return r.c.getRow() +} + // Fetch a row func (r *Result) FetchRow() Row { // Stored result @@ -85,7 +96,7 @@ func (r *Result) FetchRow() Row { // Used result if r.mode == RESULT_USED { if r.allRead == false { - eof, err := r.c.getRow() + eof, err := r.getRow() if err != nil { return nil } @@ -123,6 +134,10 @@ func (r *Result) FetchRows() []Row { // Free the result func (r *Result) Free() (err os.Error) { - err = r.c.FreeResult() + if r.s != nil { + err = r.s.FreeResult() + } else { + err = r.c.FreeResult() + } return } diff --git a/statement.go b/statement.go index 12b1c47..58aa33e 100644 --- a/statement.go +++ b/statement.go @@ -472,6 +472,23 @@ func (s *Statement) Fetch() (eof bool, err os.Error) { return } +// Use result +func (s *Statement) UseResult() (*Result, os.Error) { + // Log use result + s.c.log(1, "=== Begin use result ===") + // Check prepared + if !s.prepared { + return nil, &ClientError{CR_NO_PREPARE_STMT, CR_NO_PREPARE_STMT_STR} + } + // Check if result already used/stored + if s.result.mode != RESULT_UNUSED { + return nil, &ClientError{CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR} + } + s.result.mode = RESULT_USED + s.result.s = s // tell the result that we own it + return s.result, nil +} + // Store result func (s *Statement) StoreResult() (err os.Error) { // Auto reconnect From 0002d16577e29e1e99b7f0bf72e8347db7b7af71 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 4 Jul 2011 14:54:09 -0700 Subject: [PATCH 6/6] Return an error in UseResult if there's no result set. --- statement.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/statement.go b/statement.go index 58aa33e..209d0af 100644 --- a/statement.go +++ b/statement.go @@ -480,6 +480,10 @@ func (s *Statement) UseResult() (*Result, os.Error) { if !s.prepared { return nil, &ClientError{CR_NO_PREPARE_STMT, CR_NO_PREPARE_STMT_STR} } + // Check result + if !s.checkResult() { + return nil, &ClientError{CR_NO_RESULT_SET, CR_NO_RESULT_SET_STR} + } // Check if result already used/stored if s.result.mode != RESULT_UNUSED { return nil, &ClientError{CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR}