Feature/vla 2 #583

Open

mshukor wants to merge 33 commits into user/rcadene/2024_10_07_vla from mshukor:feature/vla_2
Conversation


@mshukor commented Dec 16, 2024

What this does

  • Fix the reward-0 bug (caused by not normalizing the targets; see the sketch after this list)
  • Support an encoder-decoder variant of ACT
  • Support robot states as input to the action decoder
  • Some features related to loading HF models

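For the first fix, a minimal sketch of what normalizing the targets means here, assuming lerobot-style mean/std dataset stats keyed by "action" (the helper name is hypothetical):

import torch

def normalize_targets(batch: dict[str, torch.Tensor], stats: dict[str, dict[str, torch.Tensor]]) -> dict:
    # Hypothetical helper: normalize the ground-truth actions with the same
    # dataset statistics already applied to the inputs, so the loss is
    # computed in normalized space instead of raw joint space.
    mean, std = stats["action"]["mean"], stats["action"]["std"]
    batch["action"] = (batch["action"] - mean) / (std + 1e-8)
    return batch
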
How it was tested


ENV=aloha
ENV_TASK=AlohaTransferCube-v0
dataset_repo_id=lerobot/aloha_sim_transfer_cube_human


policy=vla
LR=3e-5 #1e-5
LR_SCHEDULER=
USE_AMP=true
PRECISION=fp16

ASYNC_ENV=false

FEAT_SELECT=all_generated


VLM=google/paligemma2-3b-pt-224
VLM_NAME=paligemma2_3b
VLM_DIM=2304
NUM_IMG_TOKENS=598

USE_PROMPT_TEMPLATE=false

ACTION_DECODER=act_decoder

DIM_MODEL=512
LORA_R=4

PEFT_METHOD=lora


USE_ACTION_CONNECTOR=true

TASK_NAME=lerobot_${ENV}_transfer_cube_${policy}_${ACTION_DECODER}_${VLM_NAME}_${PEFT_METHOD}_feat_select_${FEAT_SELECT}

GPUS=1
EVAL_FREQ=5000 #51000 #10000 51000
OFFLINE_STEPS=100000 #25000 17000 12500 50000
TRAIN_BATCH_SIZE=8
EVAL_BATCH_SIZE=8

SAVE_FREQ=5000


MUJOCO_GL=egl python lerobot/scripts/train.py \
 hydra.job.name=base_distributed_aloha_transfer_cube \
 hydra.run.dir=$WORK/logs/lerobot/${TASK_NAME} \
 dataset_repo_id=$dataset_repo_id \
 policy=$policy \
 env=$ENV env.task=$ENV_TASK \
 training.offline_steps=$OFFLINE_STEPS training.batch_size=$TRAIN_BATCH_SIZE training.save_freq=$SAVE_FREQ \
 training.eval_freq=$EVAL_FREQ eval.n_episodes=50 eval.use_async_envs=$ASYNC_ENV eval.batch_size=$EVAL_BATCH_SIZE \
 training.lr=$LR training.lr_backbone=$LR \
 wandb.enable=false use_amp=$USE_AMP precision=$PRECISION \
 policy.vlm_backbone.feature_selection=$FEAT_SELECT policy.vlm_backbone.name=$VLM policy.action_decoder.dim_model=$DIM_MODEL \
 policy.use_prompt_template=$USE_PROMPT_TEMPLATE policy.num_img_tokens=$NUM_IMG_TOKENS policy.peft_config.r=$LORA_R policy.peft_method=$PEFT_METHOD \
 policy.use_action_connector=$USE_ACTION_CONNECTOR policy.vlm_backbone.hidden_size=$VLM_DIM  policy.action_decoder.name=$ACTION_DECODER 



How to check out & try? (for the reviewer)
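
Presumably (based on the fork URL cited later in this thread: mshukor/lerobot, branch feature/vla_2), something like:

git clone https://github.com/mshukor/lerobot.git
cd lerobot
git checkout feature/vla_2
pip install -e .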

@danaaubakirova self-requested a review December 17, 2024 18:08
@danaaubakirova (Collaborator) left a comment:
Thanks for this PR! Good catch on the bug, and nicely done adding observation.state as an input! If you could explain some of the design choices, and the code works after testing/inference, this will be approved 👍
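
For context, a rough sketch of the robot-state input pattern, assuming the state is projected into one extra token for the action decoder (the class and names are illustrative, not this PR's actual code):

import torch
from torch import nn

class StateToToken(nn.Module):
    """Hypothetical module: project the robot state into one extra decoder token."""

    def __init__(self, state_dim: int, dim_model: int):
        super().__init__()
        self.proj = nn.Linear(state_dim, dim_model)

    def forward(self, vlm_tokens: torch.Tensor, robot_state: torch.Tensor) -> torch.Tensor:
        # vlm_tokens: (B, N, D); robot_state: (B, state_dim) from batch["observation.state"].
        state_token = self.proj(robot_state).unsqueeze(1)  # (B, 1, D)
        return torch.cat([state_token, vlm_tokens], dim=1)  # prepend the state token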

@@ -513,6 +514,49 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso
return actions, (mu, log_sigma_x2)


class ACTEncoderDecoder(nn.Module):
@danaaubakirova (Collaborator):

Could you please share the motivation behind introducing the ACTEncoderDecoder module?

@mshukor (Author):

At first I wanted to check whether the 0 reward was coming from the action decoder being too small. I think we should compare different design choices for the action decoder, and this makes it a bit easier to do for ACT (ACT decoder only vs. ACT encoder-decoder).

self.action_decoder = ACTDecoder(action_decoder_config)
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.action_decoder["dim_model"])
self.use_robot_state = "observation.state" in config.input_shapes
if "act" in self.action_decoder_name:
@danaaubakirova (Collaborator):

Is this related to the new ActEncoderDecoder?

To me this if-else block related to ACT looks a bit confusing; we could either rename it completely or structure it this way:

if "act" in self.action_decoder_name:
   action_decoder_config = OmegaConf.create(config.action_decoder)       
   if self.action_decoder_name == "act_decoder":
          # Use standalone ACTDecoder
          self.action_decoder = ACTDecoder(action_decoder_config)
          self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.action_decoder["dim_model"])
      else:
          # Use ACTEncoderDecoder, decide whether to include the encoder
          use_encoder = "decoder" not in self.action_decoder_name
          self.action_decoder = ACTEncoderDecoder(action_decoder_config, use_encoder=use_encoder)

Or even better: if ACTDecoder is equivalent to ACTEncoderDecoder with use_encoder=False, we can remove ACTDecoder completely and keep only the use_encoder option for ACT, to avoid confusion and redundancy.
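
A sketch of what that single entry point could look like, with stock PyTorch transformer layers standing in for lerobot's ACT blocks (layer types and names are assumptions, not the PR's implementation):

import torch
from torch import nn

class ACTEncoderDecoder(nn.Module):
    """Sketch of a single ACT entry point with an optional encoder."""

    def __init__(self, dim_model: int, n_heads: int = 8, n_layers: int = 4, use_encoder: bool = True):
        super().__init__()
        self.encoder = None
        if use_encoder:
            enc_layer = nn.TransformerEncoderLayer(dim_model, n_heads, batch_first=True)
            self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        dec_layer = nn.TransformerDecoderLayer(dim_model, n_heads, batch_first=True)
        self.decoder = nn.TransformerDecoder(dec_layer, num_layers=n_layers)

    def forward(self, queries: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        # queries: (B, chunk_size, D) action queries; memory: (B, N, D) VLM features.
        if self.encoder is not None:
            memory = self.encoder(memory)  # optionally refine features before cross-attention
        return self.decoder(queries, memory)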

@mshukor (Author):

Yes, ACTEncoderDecoder is more general, but I kept ACTDecoder to be able to load your old checkpoints without changing their keys. If this is not needed anymore (e.g. once we have trained a good VLA with this branch), we should keep only ACTEncoderDecoder.

lerobot/configs/policy/vla.yaml (review thread resolved)
lerobot/common/policies/vla/modeling_vla.py (review thread resolved)
@IlIllllll

Hi, really great job! Is this branch already trainable and able to run inference? Can it use dataset v2.0, or how would it need to be modified? I'm interested in it :)

@danaaubakirova (Collaborator)

Hi @IlIllllll! Thanks for your interest. Both this PR and the branch are work in progress. It is trainable. For now, we haven't tested compatibility with dataset v2.0. For the final PR we will make sure it is up to date with main. Meanwhile, if you have any ideas or suggestions, feel free to drop them here or in our Discord channel ^^

@@ -56,6 +56,7 @@ def __init__(
self,
config: ACTConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
**kwargs: Any,
@danaaubakirova (Collaborator):

**kwargs prevents loading older ACT checkpoints that were trained on previous branches, since those checkpoints were saved without **kwargs support. It does work, however, when training and loading checkpoints within the same setup.

It throws the following error when I try to evaluate the old ACT checkpoints:

File "/raid/dana/miniconda3/envs/lerobot/lib/python3.10/site-packages/huggingface_hub/hub_mixin.py", line 562, in from_pretrained
    for key, value in config.items():
AttributeError: 'ACTConfig' object has no attribute 'items'

@danaaubakirova (Collaborator):

I guess the same will happen for other policies, so we need to address the backward-compatibility issues.

I tried adding this to handle the missing kwargs, but it doesn't work:

for key, value in kwargs.items():
    setattr(self, key, value)

@mshukor (Author):

I had this error when loading checkpoints, so I fixed it by adding __iter__() and items() to VLAConfig (https://github.com/mshukor/lerobot/blob/feature/vla_2/lerobot/common/policies/vla/configuration_vla.py#L190).

Are you sure this error is coming from **kwargs?
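
A sketch of that shim, assuming a dataclass-based config (the field shown is a placeholder):

from dataclasses import dataclass, fields

@dataclass
class VLAConfig:
    chunk_size: int = 100  # placeholder field; the real config has many more

    def __iter__(self):
        # Let hub_mixin's from_pretrained iterate the config like a dict.
        for f in fields(self):
            yield f.name, getattr(self, f.name)

    def items(self):
        # from_pretrained calls `config.items()`; mirror dict semantics.
        return iter(self)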

@@ -95,23 +98,23 @@ def make_policy(
raise ValueError(
"Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
)

precision = torch.float16 if "fp16" in hydra_cfg.precision else torch.float32
@danaaubakirova (Collaborator) commented Dec 27, 2024:

Here, if you are loading older checkpoints, it throws an error:

omegaconf.errors.ConfigAttributeError: Key 'precision' is not in struct
    full_key: precision
    object_type=dict

So I force it to be precision=None if the older checkpoints didn't have this arg. Again, this is a backward-compatibility issue. Do you think abstracting precision out of the training config and leaving it only inside the VLA-related config would solve the issue?
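
One way to make the lookup tolerant of configs saved before the key existed (a sketch; OmegaConf.select returns a default instead of raising on a struct config, and hydra_cfg is as in the diff above):

import torch
from omegaconf import OmegaConf

# Tolerate hydra configs from checkpoints that predate `precision`.
precision_str = OmegaConf.select(hydra_cfg, "precision", default=None)
precision = torch.float16 if precision_str is not None and "fp16" in precision_str else torch.float32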
