Could you provide an example recipe? #51
Comments
Is the following example right?
Looks like it's missing the optimizer; I'd suggest following a basic PyTorch tutorial on how to set up the training loop.
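To make that concrete, a minimal loop with this library could look roughly like the sketch below. The tiny model configuration and the dummy random-audio dataset are illustrative placeholders only, not a recommended setup:

    import torch
    from torch.utils.data import DataLoader, Dataset
    from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler

    class RandomAudio(Dataset):
        """Dummy dataset of random mono clips; stand-in for real audio data."""
        def __len__(self):
            return 64
        def __getitem__(self, idx):
            return torch.randn(1, 2**15)  # [channels, samples]

    # Small illustrative model (see the full script later in this thread for a larger one).
    model = DiffusionModel(
        net_t=UNetV0,
        in_channels=1,
        channels=[8, 32, 64, 128, 256],
        factors=[1, 4, 4, 4, 2],
        items=[1, 2, 2, 2, 2],
        attentions=[0, 0, 0, 1, 1],
        attention_heads=8,
        attention_features=64,
        diffusion_t=VDiffusion,
        sampler_t=VSampler,
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # the missing optimizer
    loader = DataLoader(RandomAudio(), batch_size=4, shuffle=True)

    for epoch in range(10):
        for audio in loader:            # audio: [batch, 1, samples]
            optimizer.zero_grad()
            loss = model(audio)         # the model returns the diffusion loss directly
            loss.backward()
            optimizer.step()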
@flavioschneider, can you please provide an example script showing how to train the model, and also how to get a dataset to train it on?
@gandolfxu Did you get a working script to train the model? If so, please share it, and also let me know which dataset you are using.
@flavioschneider Please let me know what kind of dataset can be used to train the model and how it should be structured.
Hi @deepak-newzera, I've written a simple training script here: https://github.com/jameshball/audio-diffusion/blob/master/train.py It uses the LibriSpeech dataset and downloads it when you start the script. You might need to change the data path defined at the top and set up or remove the Weights & Biases (wandb) logging. I'm currently having an issue with training where I get NaN losses (#52), but at least this code should give you something to start with.
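Regarding dataset structure: the model just needs batches of fixed-length waveforms shaped [batch, channels, samples], so any folder of audio files can work. A rough sketch of such a dataset using plain torchaudio (the folder layout and class name are hypothetical; the wrappers in audio_data_pytorch, like the LibriSpeechDataset used below, do something similar for you):

    import os
    import torch
    import torchaudio
    from torch.utils.data import Dataset

    class FolderAudioDataset(Dataset):
        """Fixed-length mono crops from a directory of .wav files (hypothetical layout)."""

        def __init__(self, root, num_samples=2**18, sample_rate=16000):
            self.paths = [
                os.path.join(root, f) for f in os.listdir(root) if f.endswith(".wav")
            ]
            self.num_samples = num_samples
            self.sample_rate = sample_rate

        def __len__(self):
            return len(self.paths)

        def __getitem__(self, idx):
            audio, sr = torchaudio.load(self.paths[idx])
            if sr != self.sample_rate:
                audio = torchaudio.functional.resample(audio, sr, self.sample_rate)
            audio = audio.mean(dim=0, keepdim=True)  # downmix to mono: [1, samples]
            if audio.shape[-1] < self.num_samples:   # pad clips that are too short
                audio = torch.nn.functional.pad(audio, (0, self.num_samples - audio.shape[-1]))
            start = torch.randint(0, audio.shape[-1] - self.num_samples + 1, (1,)).item()
            return audio[:, start:start + self.num_samples]  # random fixed-length crop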
@jameshball Thanks for your help.
Hi @jameshball, it seems like the training script you linked goes to a 404. Do you have an updated link to this training loop somewhere you could share?
I've made that repo private now, but this is the version of the file I linked:

import torch
import torchaudio
import gc
import argparse
import os
from tqdm import tqdm
import wandb

from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler
from audio_data_pytorch import LibriSpeechDataset, AllTransform

SAMPLE_RATE = 16000
BATCH_SIZE = 12
NUM_SAMPLES = 2**18


def create_model():
    return DiffusionModel(
        net_t=UNetV0,  # The model type used for diffusion (U-Net V0 in this case)
        in_channels=1,  # U-Net: number of input/output (audio) channels
        channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024],  # U-Net: channels at each layer
        factors=[1, 4, 4, 4, 2, 2, 2, 2, 2],  # U-Net: downsampling and upsampling factors at each layer
        items=[1, 2, 2, 2, 2, 2, 2, 4, 4],  # U-Net: number of repeating items at each layer
        attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1],  # U-Net: attention enabled/disabled at each layer
        attention_heads=8,  # U-Net: number of attention heads per attention item
        attention_features=64,  # U-Net: number of attention features per attention item
        diffusion_t=VDiffusion,  # The diffusion method used
        sampler_t=VSampler,  # The diffusion sampler used
    )


def main():
    args = parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    dataset = LibriSpeechDataset(
        root="E:/librispeech",
        transforms=AllTransform(
            random_crop_size=NUM_SAMPLES,
            mono=True,
        ),
    )
    print(f"Dataset length: {len(dataset)}")
    torchaudio.save("test.wav", dataset[0], SAMPLE_RATE)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    )

    model = create_model().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    run_id = wandb.util.generate_id()
    if args.run_id is not None:
        run_id = args.run_id
    print(f"Run ID: {run_id}")

    wandb.init(project="audio-diffusion", resume=args.resume, id=run_id)

    epoch = 0
    step = 0

    if args.checkpoint is not None:
        checkpoint_path = args.checkpoint
    else:
        checkpoint_path = f"checkpoint-{run_id}.pt"

    if wandb.run.resumed:
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path)
        else:
            checkpoint = torch.load(wandb.restore(checkpoint_path))
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        step = epoch * len(dataloader)

    scaler = torch.cuda.amp.GradScaler()

    model.train()
    while epoch < 100:
        avg_loss = 0
        avg_loss_step = 0
        progress = tqdm(dataloader)
        for i, audio in enumerate(progress):
            optimizer.zero_grad()
            audio = audio.to(device)
            with torch.cuda.amp.autocast():
                loss = model(audio)
                avg_loss += loss.item()
                avg_loss_step += 1
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            progress.set_postfix(
                loss=loss.item(),
                epoch=epoch + i / len(dataloader),
            )

            if step % 500 == 0:
                # Turn noise into new audio sample with diffusion
                noise = torch.randn(1, 1, NUM_SAMPLES, device=device)
                with torch.cuda.amp.autocast():
                    sample = model.sample(noise, num_steps=100)
                torchaudio.save(f'test_generated_sound_{step}.wav', sample[0].cpu(), SAMPLE_RATE)
                del sample
                gc.collect()
                torch.cuda.empty_cache()
                wandb.log({
                    "step": step,
                    "epoch": epoch + i / len(dataloader),
                    "loss": avg_loss / avg_loss_step,
                    "generated_audio": wandb.Audio(f'test_generated_sound_{step}.wav', caption="Generated audio", sample_rate=SAMPLE_RATE),
                })

            if step % 100 == 0:
                wandb.log({
                    "step": step,
                    "epoch": epoch + i / len(dataloader),
                    "loss": avg_loss / avg_loss_step,
                })
                avg_loss = 0
                avg_loss_step = 0

            step += 1

        epoch += 1
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, checkpoint_path)
        wandb.save(checkpoint_path)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--run_id", type=str, default=None)
    return parser.parse_args()


if __name__ == "__main__":
    main()
It has some hardcoded paths and wandb code that you might need to remove, but it worked nicely.
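In case it helps anyone picking this up: once a checkpoint like the one saved above exists, generation is just loading the state dict and calling model.sample on noise, the same call the training script uses for its previews. A minimal sketch, assuming the same create_model() configuration and checkpoint format as above (the checkpoint filename is a placeholder):

    import torch
    import torchaudio
    from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler

    SAMPLE_RATE = 16000
    NUM_SAMPLES = 2**18

    # Must match the training configuration exactly, otherwise the state dict won't load.
    model = DiffusionModel(
        net_t=UNetV0,
        in_channels=1,
        channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024],
        factors=[1, 4, 4, 4, 2, 2, 2, 2, 2],
        items=[1, 2, 2, 2, 2, 2, 2, 4, 4],
        attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1],
        attention_heads=8,
        attention_features=64,
        diffusion_t=VDiffusion,
        sampler_t=VSampler,
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # "checkpoint-myrun.pt" is a placeholder for whatever file the training script saved.
    checkpoint = torch.load("checkpoint-myrun.pt", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    with torch.no_grad():
        noise = torch.randn(1, 1, NUM_SAMPLES, device=device)
        sample = model.sample(noise, num_steps=100)  # denoise pure noise into audio

    torchaudio.save("generated.wav", sample[0].cpu(), SAMPLE_RATE)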
@jameshball Did you succeed in training the model? Is it producing sensible outputs?
@jameshball Did you successfully train the model? If so, can you share what you did and learned?