Feature/vla 2 #583
base: user/rcadene/2024_10_07_vla
Conversation
Thanks for this PR! Good catch on the bug, and nicely done adding observation.states as an input! If you could explain some of the design choices and confirm the code works after testing/inference, this will be approved 👍
@@ -513,6 +514,49 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | None]:
        return actions, (mu, log_sigma_x2)


class ACTEncoderDecoder(nn.Module):
Could you please share the motivation behind introducing the ACTEncoderDecoder module?
At first I wanted to see whether the 0 reward was coming from the action decoder being too small. I think we should compare different design choices for the action decoder, and this makes it a bit easier to do for ACT (ACT decoder only vs. ACT encoder-decoder).
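For readers following along, here is a rough sketch of what a combined encoder-decoder action head along these lines could look like. It assumes the ACTEncoder/ACTDecoder building blocks and call signatures from the existing ACT policy (lerobot/common/policies/act/modeling_act.py); the actual implementation in this PR may differ:

```python
from torch import Tensor, nn

# Assumed to exist under these names in the ACT policy module; signatures may differ in this PR.
from lerobot.common.policies.act.modeling_act import ACTDecoder, ACTEncoder


class ACTEncoderDecoderSketch(nn.Module):
    """Action head that can run decoder-only or full encoder-decoder attention (sketch)."""

    def __init__(self, config, use_encoder: bool = True):
        super().__init__()
        self.use_encoder = use_encoder
        if use_encoder:
            # Optionally re-encode the conditioning features before the decoder cross-attends to them.
            self.encoder = ACTEncoder(config)
        self.decoder = ACTDecoder(config)

    def forward(
        self,
        decoder_in: Tensor,
        encoder_out: Tensor,
        decoder_pos_embed: Tensor | None = None,
        encoder_pos_embed: Tensor | None = None,
    ) -> Tensor:
        if self.use_encoder:
            encoder_out = self.encoder(encoder_out, pos_embed=encoder_pos_embed)
        # With use_encoder=False this reduces to the plain ACT decoder path.
        return self.decoder(
            decoder_in,
            encoder_out,
            decoder_pos_embed=decoder_pos_embed,
            encoder_pos_embed=encoder_pos_embed,
        )
```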
            self.action_decoder = ACTDecoder(action_decoder_config)
            self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.action_decoder["dim_model"])
        self.use_robot_state = "observation.state" in config.input_shapes
        if "act" in self.action_decoder_name:
Is this related to the new ACTEncoderDecoder?
To me this if-else block related to act looks a bit confusing; we could either rename it completely or structure it this way:
if "act" in self.action_decoder_name:
action_decoder_config = OmegaConf.create(config.action_decoder)
if self.action_decoder_name == "act_decoder":
# Use standalone ACTDecoder
self.action_decoder = ACTDecoder(action_decoder_config)
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.action_decoder["dim_model"])
else:
# Use ACTEncoderDecoder, decide whether to include the encoder
use_encoder = "decoder" not in self.action_decoder_name
self.action_decoder = ACTEncoderDecoder(action_decoder_config, use_encoder=use_encoder)
Or even better: if ACTDecoder is equivalent to ACTEncoderDecoder with use_encoder=False, then we can remove the use of ACTDecoder completely and keep only the use_encoder option for ACT, to avoid confusion and redundancy.
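In that case, a hypothetical simplified version of the branch above could look like the fragment below (the name-based use_encoder heuristic is illustrative, not taken from the PR):

```python
if "act" in self.action_decoder_name:
    action_decoder_config = OmegaConf.create(config.action_decoder)
    # Single module: the name only decides whether the extra encoder stack is used (illustrative).
    use_encoder = "encoder" in self.action_decoder_name
    self.action_decoder = ACTEncoderDecoder(action_decoder_config, use_encoder=use_encoder)
    self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.action_decoder["dim_model"])
```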
Yes, ACTEncoderDecoder is more general, but I kept ACTDecoder to be able to load your old checkpoints without changing their keys. If this is not needed anymore (e.g. once we have trained a good VLA with this branch), we should keep only ACTEncoderDecoder.
Hi, really great job! Is this branch already trainable and able to run inference? Can it use dataset v2.0, or how would it need to be modified? I'm interested in it :)

Hi @IlIllllll! Thanks for your interest. Both this PR and the branch are a work in progress. It is trainable. For now, we haven't tested compatibility with dataset v2.0. For the final PR we will make sure it is up to date with main. Meanwhile, if you have any ideas or suggestions, feel free to drop them here or in our Discord channel ^^
@@ -56,6 +56,7 @@ def __init__(
        self,
        config: ACTConfig | None = None,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
        **kwargs: Any,
**kwargs prevents loading older ACT checkpoints that were trained on previous branches, as those checkpoints do not include support for **kwargs. However, this works when training and loading checkpoints within the same setup.
It throws the following error when I try to evaluate the old ACT checkpoints:

File "/raid/dana/miniconda3/envs/lerobot/lib/python3.10/site-packages/huggingface_hub/hub_mixin.py", line 562, in from_pretrained
    for key, value in config.items():
AttributeError: 'ACTConfig' object has no attribute 'items'
I guess the same will happen for other policies, so we need to address the backward compatibility issues.
I tried adding this to handle the missing kwargs, but it doesn't work:

for key, value in kwargs.items():
    setattr(self, key, value)
I had this error when loading checkpoints, so I fixed it with the iter() and items() methods for VLAConfig (https://github.com/mshukor/lerobot/blob/feature/vla_2/lerobot/common/policies/vla/configuration_vla.py#L190).
Are you sure this error is coming from **kwargs?
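For context, a minimal standalone sketch of that kind of fix, i.e. giving a dataclass config dict-like iteration so the `for key, value in config.items()` loop in huggingface_hub's hub mixin can consume it (the fields below are placeholders, not the real VLAConfig):

```python
from dataclasses import dataclass, fields


@dataclass
class VLAConfigSketch:
    # Placeholder fields; the real VLAConfig has many more.
    chunk_size: int = 100
    dim_model: int = 512

    def __iter__(self):
        # Yield (key, value) pairs so dict(config) and iteration both work.
        for f in fields(self):
            yield f.name, getattr(self, f.name)

    def items(self):
        # Dict-like accessor matching the call made in hub_mixin.from_pretrained.
        return iter(self)


if __name__ == "__main__":
    cfg = VLAConfigSketch()
    for key, value in cfg.items():  # mirrors the failing call in hub_mixin.py
        print(key, value)
```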
@@ -95,23 +98,23 @@ def make_policy(
        raise ValueError(
            "Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
        )

    precision = torch.float16 if "fp16" in hydra_cfg.precision else torch.float32
Here, if you are loading older checkpoints, it throws an error:

omegaconf.errors.ConfigAttributeError: Key 'precision' is not in struct
    full_key: precision
    object_type=dict

So I force it to be precision=None if the older checkpoint didn't have this arg. Again, a backward compatibility issue. Do you think abstracting precision away from the training config and leaving it only inside the VLA-related config would solve the issue?
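As one possible direction, a small sketch of reading the key defensively so older configs without precision fall back to full precision (OmegaConf.select returns a default instead of raising on struct configs; this is just an option, not necessarily the fix chosen for this PR):

```python
import torch
from omegaconf import OmegaConf

# Stand-in for an older training config that has no `precision` key.
hydra_cfg = OmegaConf.create({"policy": {"name": "act"}})
OmegaConf.set_struct(hydra_cfg, True)

# Read `precision` defensively instead of assuming the key exists.
precision_str = OmegaConf.select(hydra_cfg, "precision", default=None)
precision = torch.float16 if precision_str is not None and "fp16" in precision_str else torch.float32
print(precision)  # torch.float32 for older configs without the key
```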
What this does
How it was tested
How to checkout & try? (for the reviewer)