
Tuple-indexed dataframe #2

@luistelmocosta

Hey, Olga!
First, I would love to thank you for your amazing work. The article is very well written and explained, and your code looks super clean. I have been trying to apply your work to a project of mine, but I am running into some issues adapting what you did.

I am trying to apply active learning (AL) to my test_dataset only, so in the query_the_oracle function my unlabeled_idx contains all the indices of the test_dataset. However, each index is a tuple (it refers to coordinates in a huge image), so I adapted the code to something like this:

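import random

from torch.utils.data import DataLoader, SubsetRandomSampler

# Excerpt from my query_the_oracle adaptation: dataset, pool_size and batch_size
# are function arguments. Each entry of dataset.ids is an (x, y) coordinate tuple.
unlabeled_idx = dataset.ids[0:200]

# Collect the x and y coordinates of the first 1000 ids to find their maxima
idx_x = []
idx_y = []
for idx in dataset.ids[0:1000]:
    idx_x.append(idx[0])
    idx_y.append(idx[1])

idx_xmax = max(idx_x)
idx_ymax = max(idx_y)
num_workers = 4

# Select a pool of samples to query from
if pool_size > 0:
    # Draw x and y coordinates separately, then pair them into (x, y) tuples
    pool_idx_x = random.sample(range(1, idx_xmax), pool_size)
    pool_idx_y = random.sample(range(1, idx_ymax), pool_size)
    pool_idx = list(zip(pool_idx_x, pool_idx_y))

    pool_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
                             # sampler=SubsetRandomSampler(unlabeled_idx[pool_idx]),
                             sampler=SubsetRandomSampler(pool_idx))
else:
    pool_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers,
                             sampler=SubsetRandomSampler(unlabeled_idx))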

However, when I run it, I get the following error in the margin_query function:
TypeError: list indices must be integers or slices, not tuple
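For reference, the message itself is easy to reproduce in plain Python: a list only accepts integer (or slice) indices, so any some_list[idx] where idx is one of my (x, y) coordinate tuples fails this way. The names below (ids, lookup, position_of) are made up for illustration, and I am assuming, not confirming, that something similar happens inside margin_query:

```python
# Hypothetical stand-in for dataset.ids: a plain list of (x, y) coordinates.
ids = [(0, 0), (0, 256), (256, 0), (256, 256)]


def lookup(seq, idx):
    """Return seq[idx], or the TypeError message if idx is not a valid list index."""
    try:
        return seq[idx]
    except TypeError as err:
        return str(err)


# An integer position works and returns the coordinate tuple stored there.
print(lookup(ids, 2))  # (256, 0)

# Using a coordinate tuple itself as the list index raises the same TypeError.
print(lookup(ids, (256, 0)))  # list indices must be integers or slices, not tuple

# Mapping each coordinate back to its integer position avoids the error.
position_of = {coord: pos for pos, coord in enumerate(ids)}
print(lookup(ids, position_of[(256, 0)]))  # (256, 0)
```

Mapping coordinates to integer positions like this would let the sampler work with plain integers, but I am not sure how that fits with the rest of the oracle code.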

Do you have any idea how I can fix this?
