diff --git a/commands/fetch-awesomepoc.go b/commands/fetch-awesomepoc.go index 16a625f..2700b62 100644 --- a/commands/fetch-awesomepoc.go +++ b/commands/fetch-awesomepoc.go @@ -30,7 +30,7 @@ func fetchAwesomePoc(_ *cobra.Command, _ []string) (err error) { return xerrors.Errorf("Failed to SetLogger. err: %w", err) } - driver, err := db.NewDB(viper.GetString("dbtype"), viper.GetString("dbpath"), viper.GetBool("debug-sql"), db.Option{}) + driver, err := db.NewDB(viper.GetString("dbtype"), viper.GetString("dbpath"), viper.GetBool("debug-sql"), db.Option{BatchSize: viper.GetInt("batch-size")}) if err != nil { if xerrors.Is(err, db.ErrDBLocked) { return xerrors.Errorf("Failed to open DB. Close DB connection before fetching. err: %w", err) diff --git a/commands/fetch-exploitdb.go b/commands/fetch-exploitdb.go index 8a84864..73fa0c8 100644 --- a/commands/fetch-exploitdb.go +++ b/commands/fetch-exploitdb.go @@ -30,7 +30,7 @@ func fetchExploitDB(_ *cobra.Command, _ []string) (err error) { return xerrors.Errorf("Failed to SetLogger. err: %w", err) } - driver, err := db.NewDB(viper.GetString("dbtype"), viper.GetString("dbpath"), viper.GetBool("debug-sql"), db.Option{}) + driver, err := db.NewDB(viper.GetString("dbtype"), viper.GetString("dbpath"), viper.GetBool("debug-sql"), db.Option{BatchSize: viper.GetInt("batch-size")}) if err != nil { if xerrors.Is(err, db.ErrDBLocked) { return xerrors.Errorf("Failed to open DB. Close DB connection before fetching. err: %w", err) diff --git a/commands/fetch-githubrepos.go b/commands/fetch-githubrepos.go index 6abbfab..a5a9b85 100644 --- a/commands/fetch-githubrepos.go +++ b/commands/fetch-githubrepos.go @@ -36,7 +36,7 @@ func fetchGitHubRepos(_ *cobra.Command, _ []string) (err error) { return xerrors.Errorf("Failed to SetLogger. err: %w", err) } - driver, err := db.NewDB(viper.GetString("dbtype"), viper.GetString("dbpath"), viper.GetBool("debug-sql"), db.Option{}) + driver, err := db.NewDB(viper.GetString("dbtype"), viper.GetString("dbpath"), viper.GetBool("debug-sql"), db.Option{BatchSize: viper.GetInt("batch-size")}) if err != nil { if xerrors.Is(err, db.ErrDBLocked) { return xerrors.Errorf("Failed to open DB. Close DB connection before fetching. err: %w", err) diff --git a/commands/fetch-inthewild.go b/commands/fetch-inthewild.go index da57e79..504371f 100644 --- a/commands/fetch-inthewild.go +++ b/commands/fetch-inthewild.go @@ -30,7 +30,7 @@ func fetchInTheWild(_ *cobra.Command, _ []string) (err error) { return xerrors.Errorf("Failed to SetLogger. err: %w", err) } - driver, err := db.NewDB(viper.GetString("dbtype"), viper.GetString("dbpath"), viper.GetBool("debug-sql"), db.Option{}) + driver, err := db.NewDB(viper.GetString("dbtype"), viper.GetString("dbpath"), viper.GetBool("debug-sql"), db.Option{BatchSize: viper.GetInt("batch-size")}) if err != nil { if xerrors.Is(err, db.ErrDBLocked) { return xerrors.Errorf("Failed to open DB. Close DB connection before fetching. err: %w", err) diff --git a/db/db.go b/db/db.go index e772a3c..3f64dac 100644 --- a/db/db.go +++ b/db/db.go @@ -30,6 +30,7 @@ type DB interface { // Option : type Option struct { RedisTimeout time.Duration + BatchSize int } // NewDB : diff --git a/db/rdb.go b/db/rdb.go index c78abea..fbfc5b9 100644 --- a/db/rdb.go +++ b/db/rdb.go @@ -12,7 +12,6 @@ import ( "github.com/cheggaaa/pb/v3" "github.com/glebarez/sqlite" "github.com/inconshreveable/log15" - "github.com/spf13/viper" "golang.org/x/xerrors" "gorm.io/driver/mysql" "gorm.io/driver/postgres" @@ -32,8 +31,9 @@ const ( // RDBDriver : type RDBDriver struct { - name string - conn *gorm.DB + name string + conn *gorm.DB + batchSize int } // https://github.com/mattn/go-sqlite3/blob/edc3bb69551dcfff02651f083b21f3366ea2f5ab/error.go#L18-L66 @@ -58,7 +58,9 @@ func (r *RDBDriver) Name() string { } // OpenDB opens Database -func (r *RDBDriver) OpenDB(dbType, dbPath string, debugSQL bool, _ Option) (err error) { +func (r *RDBDriver) OpenDB(dbType, dbPath string, debugSQL bool, option Option) (err error) { + r.batchSize = option.BatchSize + gormConfig := gorm.Config{ DisableForeignKeyConstraintWhenMigrating: true, Logger: logger.New( @@ -197,8 +199,7 @@ func (r *RDBDriver) deleteAndInsertExploit(exploitType models.ExploitType, explo tx.Commit() }() - batchSize := viper.GetInt("batch-size") - if batchSize < 1 { + if r.batchSize < 1 { return xerrors.New("Failed to set batch-size. err: batch-size option is not set properly") } @@ -210,7 +211,7 @@ func (r *RDBDriver) deleteAndInsertExploit(exploitType models.ExploitType, explo if len(oldIDs) > 0 { log15.Info("Deleting old Exploits") bar := pb.StartNew(len(oldIDs)) - for idx := range chunkSlice(len(oldIDs), batchSize) { + for idx := range chunkSlice(len(oldIDs), r.batchSize) { osIDs := []models.OffensiveSecurity{} if err := tx.Model(&models.OffensiveSecurity{}).Select("id").Where("exploit_id IN ?", oldIDs[idx.From:idx.To]).Find(&osIDs).Error; err != nil { return xerrors.Errorf("Failed to select old OffensiveSecurity: %w", err) @@ -252,7 +253,7 @@ func (r *RDBDriver) deleteAndInsertExploit(exploitType models.ExploitType, explo log15.Info("Inserting new Exploits") bar := pb.StartNew(len(exploits)) - for idx := range chunkSlice(len(exploits), batchSize) { + for idx := range chunkSlice(len(exploits), r.batchSize) { if err = tx.Create(exploits[idx.From:idx.To]).Error; err != nil { return xerrors.Errorf("Failed to insert. err: %w", err) } diff --git a/db/redis.go b/db/redis.go index 990f38e..35d7200 100644 --- a/db/redis.go +++ b/db/redis.go @@ -11,7 +11,6 @@ import ( "github.com/cheggaaa/pb/v3" "github.com/go-redis/redis/v8" "github.com/inconshreveable/log15" - "github.com/spf13/viper" "golang.org/x/xerrors" "github.com/vulsio/go-exploitdb/config" @@ -55,8 +54,9 @@ const ( // RedisDriver is Driver for Redis type RedisDriver struct { - name string - conn *redis.Client + name string + conn *redis.Client + batchSize int } // Name return db name @@ -66,6 +66,7 @@ func (r *RedisDriver) Name() string { // OpenDB opens Database func (r *RedisDriver) OpenDB(_, dbPath string, _ bool, option Option) error { + r.batchSize = option.BatchSize if err := r.connectRedis(dbPath, option); err != nil { return xerrors.Errorf("Failed to open DB. dbtype: %s, dbpath: %s, err: %w", dialectRedis, dbPath, err) } @@ -305,8 +306,7 @@ func (r *RedisDriver) GetExploitMultiByCveID(cveIDs []string) (map[string][]mode // InsertExploit : func (r *RedisDriver) InsertExploit(exploitType models.ExploitType, exploits []models.Exploit) (err error) { ctx := context.Background() - batchSize := viper.GetInt("batch-size") - if batchSize < 1 { + if r.batchSize < 1 { return xerrors.Errorf("Failed to set batch-size. err: batch-size option is not set properly") } @@ -326,7 +326,7 @@ func (r *RedisDriver) InsertExploit(exploitType models.ExploitType, exploits []m bar := pb.StartNew(len(exploits)) var noCveIDExploitCount, cveIDExploitCount int - for idx := range chunkSlice(len(exploits), batchSize) { + for idx := range chunkSlice(len(exploits), r.batchSize) { pipe := r.conn.Pipeline() for _, exploit := range exploits[idx.From:idx.To] { j, err := json.Marshal(exploit)