diff --git a/cmd/reclaim_rent/main.go b/cmd/reclaim_rent/main.go index 5ffff268..496907ad 100644 --- a/cmd/reclaim_rent/main.go +++ b/cmd/reclaim_rent/main.go @@ -37,6 +37,7 @@ func init() { reclaimRentCmd.Flags().StringP("keypair", "k", "~/.config/solana/id.json", "The wallet to use as fee payer for transactions") reclaimRentCmd.Flags().StringP("destination", "d", "", "The recipient of reclaimed rent (defaults to fee payer)") reclaimRentCmd.Flags().StringP("program", "p", claimable_tokens.ProgramID.String(), "The claimable tokens program ID") + reclaimRentCmd.Flags().StringP("created-after", "", "", "Filter accounts created after this date (RFC3339 format, e.g., 2024-01-01T00:00:00Z)") } func reclaimRent(cmd *cobra.Command, args []string) error { @@ -98,9 +99,23 @@ func reclaimRent(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to derive authority: %w", err) } + createdAfterFlag, err := cmd.Flags().GetString("created-after") + if err != nil { + return fmt.Errorf("failed to get created-after flag: %w", err) + } + var createdAfter *time.Time + if createdAfterFlag != "" { + parsed, err := time.Parse(time.RFC3339, createdAfterFlag) + if err != nil { + return fmt.Errorf("failed to parse created-after date (must be RFC3339 format, e.g., 2024-01-01T00:00:00Z): %w", err) + } + createdAfter = &parsed + fmt.Printf("Filtering accounts created after: %s\n", createdAfter.Format(time.RFC3339)) + } + offset := 0 - totalCount, err := getTokenAccountsCountFromDatabase(ctx, pool, mint) + totalCount, err := getTokenAccountsCountFromDatabase(ctx, pool, mint, createdAfter) if err != nil { return fmt.Errorf("failed to get token accounts count from database: %w", err) } @@ -108,7 +123,7 @@ func reclaimRent(cmd *cobra.Command, args []string) error { limit := 1000 for { - accounts, err := getTokenAccountsFromDatabase(ctx, pool, mint, limit, offset) + accounts, err := getTokenAccountsFromDatabase(ctx, pool, mint, limit, offset, createdAfter) if err != nil { return fmt.Errorf("failed to get token accounts from database: %w", err) } @@ -286,32 +301,62 @@ type DatabaseAccount struct { EthereumAddress string } -func getTokenAccountsFromDatabase(ctx context.Context, pool *pgxpool.Pool, mint solana.PublicKey, limit, offset int) ([]DatabaseAccount, error) { - sql := ` - SELECT bank_account AS account, ethereum_address - FROM user_bank_accounts - LIMIT $1 OFFSET $2 - ` - rows, err := pool.Query(ctx, sql, limit, offset) - if err != nil { - return nil, fmt.Errorf("failed to query token accounts: %w", err) - } - - accounts, err := pgx.CollectRows(rows, pgx.RowToStructByName[DatabaseAccount]) - if err != nil { - return nil, fmt.Errorf("failed to collect token accounts: %w", err) +func getTokenAccountsFromDatabase(ctx context.Context, pool *pgxpool.Pool, mint solana.PublicKey, limit, offset int, createdAfter *time.Time) ([]DatabaseAccount, error) { + var sql string + + if createdAfter != nil { + sql = ` + SELECT bank_account AS account, ethereum_address + FROM user_bank_accounts + WHERE created_at > $1 + LIMIT $2 OFFSET $3 + ` + rows, err := pool.Query(ctx, sql, *createdAfter, limit, offset) + if err != nil { + return nil, fmt.Errorf("failed to query token accounts: %w", err) + } + accounts, err := pgx.CollectRows(rows, pgx.RowToStructByName[DatabaseAccount]) + if err != nil { + return nil, fmt.Errorf("failed to collect token accounts: %w", err) + } + return accounts, nil + } else { + sql = ` + SELECT bank_account AS account, ethereum_address + FROM user_bank_accounts + LIMIT $1 OFFSET $2 + ` + rows, err := pool.Query(ctx, sql, limit, offset) + if err != nil { + return nil, fmt.Errorf("failed to query token accounts: %w", err) + } + accounts, err := pgx.CollectRows(rows, pgx.RowToStructByName[DatabaseAccount]) + if err != nil { + return nil, fmt.Errorf("failed to collect token accounts: %w", err) + } + return accounts, nil } - - return accounts, nil } -func getTokenAccountsCountFromDatabase(ctx context.Context, pool *pgxpool.Pool, mint solana.PublicKey) (int, error) { - sql := ` - SELECT COUNT(*) - FROM user_bank_accounts - ` +func getTokenAccountsCountFromDatabase(ctx context.Context, pool *pgxpool.Pool, mint solana.PublicKey, createdAfter *time.Time) (int, error) { var count int - err := pool.QueryRow(ctx, sql).Scan(&count) + var err error + + if createdAfter != nil { + sql := ` + SELECT COUNT(*) + FROM user_bank_accounts + WHERE created_at > $1 + ` + err = pool.QueryRow(ctx, sql, *createdAfter).Scan(&count) + } else { + sql := ` + SELECT COUNT(*) + FROM user_bank_accounts + ` + err = pool.QueryRow(ctx, sql).Scan(&count) + } + if err != nil { return 0, fmt.Errorf("failed to query token accounts count: %w", err) }