Implementation into openpi #8

@rohan-bansal

Hey there, this is great work!

What would it take to implement this in the openpi repository?

I've created a new model file in `src/openpi/models` based on `BaseModel`. The difference is that its `sample_actions()` calls `Pi0Inference.forward` after preprocessing the observation, instead of running the denoising steps:

    import torch
    from typing_extensions import override

    from openpi.models import model as _model


    @override
    @torch.no_grad()
    def sample_actions(self, device, observation: _model.Observation,
                       noise=None, num_steps=10) -> torch.Tensor:
        # NOTE: num_steps is not forwarded to the inference engine below.
        bsize = observation.state.shape[0]
        if noise is None:
            actions_shape = (bsize, self.config.action_horizon, self.config.action_dim)
            noise = self.sample_noise(actions_shape, device)

        # Use the pi0_pytorch preprocessing. `images` is a list with one tensor
        # per camera view, assumed here to be (B, C, H, W). img_masks,
        # lang_tokens and lang_masks are computed but not passed to the engine.
        images, img_masks, lang_tokens, lang_masks, state = self._preprocess_observation(observation, train=False)

        # Keep the first num_views cameras, take batch element 0 of each
        # (so this assumes bsize == 1), and convert (C, H, W) -> (H, W, C).
        images = images[:self.num_views]
        images_converted = []
        for img in images:
            images_converted.append(img[0].permute(1, 2, 0))

        images_stacked = torch.stack(images_converted, dim=0)  # (num_views, H, W, C)
        images_bf16 = images_stacked.to(dtype=torch.bfloat16, device=device)

        state_bf16 = state[0].to(dtype=torch.bfloat16, device=device)
        noise_input = noise[0].to(dtype=torch.bfloat16, device=device)

        # Run the fast inference engine in place of the denoising loop
        output_actions = self.inference_engine.forward(
            images_bf16,
            state_bf16,
            noise_input,
        )

        # Restore the batch dimension: (1, action_horizon, action_dim)
        return output_actions.unsqueeze(0).to(dtype=torch.float32)
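
For reference, the denoising steps that `sample_actions()` would otherwise run are, in openpi's pi0, a fixed-step Euler integration of the flow-matching velocity from t = 1 (pure noise) down to t = 0 (actions). The sketch below assumes that convention (dt = -1/num_steps, with the model predicting a velocity of roughly noise - actions) and uses plain Python lists plus a hypothetical `velocity_fn` standing in for the action-expert forward pass; it is an illustration, not openpi's code. Whatever replaces this loop presumably has to honor the same contract: take noise of shape (B, action_horizon, action_dim) and return actions of the same shape, in the same normalized action space that the policy's output transforms expect.

```python
from typing import Callable, List

Vector = List[float]


def euler_denoise(
    velocity_fn: Callable[[Vector, float], Vector],
    noise: Vector,
    num_steps: int = 10,
) -> Vector:
    """Integrate x from t=1 (noise) to t=0 (actions) with fixed Euler steps.

    Assumes the pi0-style path x_t = t * noise + (1 - t) * actions, so the
    velocity dx_t/dt equals noise - actions. `velocity_fn` is a hypothetical
    stand-in for the model's forward pass.
    """
    dt = -1.0 / num_steps
    x = list(noise)
    t = 1.0
    for _ in range(num_steps):
        v = velocity_fn(x, t)
        # Step backwards in time: x <- x + dt * v, with dt < 0
        x = [xi + dt * vi for xi, vi in zip(x, v)]
        t += dt
    return x


if __name__ == "__main__":
    target = [0.5, -1.0, 2.0]
    noise = [1.0, 0.0, -1.0]
    # Ideal constant velocity for the straight-line path from target to noise
    ideal_v = [n - a for n, a in zip(noise, target)]
    out = euler_denoise(lambda x, t: ideal_v, noise, num_steps=10)
    print([round(v, 6) for v in out])  # [0.5, -1.0, 2.0]
```

With the ideal straight-line velocity the loop recovers the target (up to floating-point rounding) for any step count; with a learned velocity, fewer steps trade accuracy for speed, which is why the engine's step count matters when comparing its output with the stock model's.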

When I roll this out in simulation (LIBERO), the robot's motion is incoherent. What am I doing wrong here?
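
One way to narrow this down is to check the fast path against the stock openpi `sample_actions` on a single observation before running any rollouts, passing the same `noise` tensor to both. A minimal sketch, assuming both outputs are first flattened to Python lists (e.g. with `.float().flatten().tolist()`); `max_abs_diff`, `check_parity`, and the `atol` default are hypothetical names and a placeholder tolerance, kept loose because the snippet runs in bfloat16.

```python
from typing import Sequence, Tuple


def max_abs_diff(reference: Sequence[float], candidate: Sequence[float]) -> float:
    """Largest element-wise |reference - candidate| between two flat outputs."""
    if len(reference) != len(candidate):
        # A length mismatch already means the two paths disagree on shape
        # (batch size, action_horizon, or action_dim).
        raise ValueError(f"length mismatch: {len(reference)} vs {len(candidate)}")
    return max((abs(r - c) for r, c in zip(reference, candidate)), default=0.0)


def check_parity(
    reference: Sequence[float], candidate: Sequence[float], atol: float = 5e-2
) -> Tuple[bool, float]:
    """Return (within_tolerance, max_diff); `atol` is a placeholder value."""
    diff = max_abs_diff(reference, candidate)
    return diff <= atol, diff


if __name__ == "__main__":
    # Hypothetical flattened action chunks from the two code paths
    ref = [0.10, -0.20, 0.30, 0.05]
    fast = [0.11, -0.21, 0.29, 0.05]
    ok, diff = check_parity(ref, fast)
    print(ok, round(diff, 4))  # True 0.01
```

If the outputs already disagree at this stage, the problem is in how inputs reach the engine rather than in the LIBERO loop. Candidates visible in the snippet: only batch element 0 and the first `num_views` images are used, the image layout and value range must match what the engine expects, and `lang_tokens`, `lang_masks`, and `img_masks` are computed but never passed to `forward`, so if the engine needs the task prompt, it is not getting it from this call.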