Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

could provide a example recipe? #51

Open
gandolfxu opened this issue Feb 13, 2023 · 12 comments
Open

could provide a example recipe? #51

gandolfxu opened this issue Feb 13, 2023 · 12 comments

Comments

@gandolfxu
Copy link

  1. Based on an open source dataset
  2. Detailed training parameters
@gandolfxu
Copy link
Author

gandolfxu commented Feb 13, 2023

Is the following example right?

import librosa
import numpy as np
import torch
import torchaudio
from torch.utils.data import Dataset
from tqdm import tqdm
from loguru import logger

from audio_diffusion_pytorch import DiffusionModel
from audio_diffusion_pytorch import UNetV0
from audio_diffusion_pytorch import VDiffusion
from audio_diffusion_pytorch import VSampler

LEN = 2 ** 18

class AudioDataset(Dataset):
    def __init__(self, fpath):
        self.file_list =[]
        for line in open(fpath):
            line = line.strip()
            if not line:
                continue
            self.file_list.append(line)

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        audio_file = self.file_list[idx]
        audio, fs = torchaudio.load(audio_file)
        transform = torchaudio.transforms.Resample(fs, 48000)
        audio = transform(audio)
        if audio.shape[1] > LEN:
            offset = np.random.randint(0, audio.shape[1] - LEN)
        else:
            offset = 0
        return audio[:, offset:offset + LEN]


def collate_fn(batch):
    bsz = len(batch)
    out = torch.zeros(bsz, 2, LEN)
    for i, x in enumerate(batch):
        out[i, :, :x.shape[1]] = x # torch.from_numpy(x)
    return out

model = DiffusionModel(
    net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
    in_channels=2, # U-Net: number of input/output (audio) channels
    channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
    factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
    items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
    attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
    attention_heads=8, # U-Net: number of attention heads per attention item
    attention_features=64, # U-Net: number of attention features per attention item
    diffusion_t=VDiffusion, # The diffusion method used
    sampler_t=VSampler, # The diffusion sampler used
)
model.to('cuda:0')


train_dataset = AudioDataset('data/train.list')
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=8,
    pin_memory=True
)

logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""))

for i in range(6):
    for audio in tqdm(train_dataloader, desc='epoch %d' % i):
        audio = audio.to('cuda:0')
        loss = model(audio)
        logger.info('loss = %f' % loss.item())
        loss.backward()
    torch.save(model.state_dict(), 'model_%d.pt' % i)

@flavioschneider
Copy link
Member

Looks like it's missing the optimizer, I'd suggest to follow a basic pytorch tutorial on how to setup the training loop.

@deepak-newzera
Copy link

deepak-newzera commented Feb 15, 2023

@flavioschneider, Can you please provide an example script on how to train the model? And also how to get a dataset to train the model?

@deepak-newzera
Copy link

Is the following example right?

import librosa
import numpy as np
import torch
import torchaudio
from torch.utils.data import Dataset
from tqdm import tqdm
from loguru import logger

from audio_diffusion_pytorch import DiffusionModel
from audio_diffusion_pytorch import UNetV0
from audio_diffusion_pytorch import VDiffusion
from audio_diffusion_pytorch import VSampler

LEN = 2 ** 18

class AudioDataset(Dataset):
    def __init__(self, fpath):
        self.file_list =[]
        for line in open(fpath):
            line = line.strip()
            if not line:
                continue
            self.file_list.append(line)

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        audio_file = self.file_list[idx]
        audio, fs = torchaudio.load(audio_file)
        transform = torchaudio.transforms.Resample(fs, 48000)
        audio = transform(audio)
        if audio.shape[1] > LEN:
            offset = np.random.randint(0, audio.shape[1] - LEN)
        else:
            offset = 0
        return audio[:, offset:offset + LEN]


def collate_fn(batch):
    bsz = len(batch)
    out = torch.zeros(bsz, 2, LEN)
    for i, x in enumerate(batch):
        out[i, :, :x.shape[1]] = x # torch.from_numpy(x)
    return out

model = DiffusionModel(
    net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
    in_channels=2, # U-Net: number of input/output (audio) channels
    channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
    factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
    items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
    attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
    attention_heads=8, # U-Net: number of attention heads per attention item
    attention_features=64, # U-Net: number of attention features per attention item
    diffusion_t=VDiffusion, # The diffusion method used
    sampler_t=VSampler, # The diffusion sampler used
)
model.to('cuda:0')


train_dataset = AudioDataset('data/train.list')
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=8,
    pin_memory=True
)

logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""))

for i in range(6):
    for audio in tqdm(train_dataloader, desc='epoch %d' % i):
        audio = audio.to('cuda:0')
        loss = model(audio)
        logger.info('loss = %f' % loss.item())
        loss.backward()
    torch.save(model.state_dict(), 'model_%d.pt' % i)

@gandolfxu Did you get the correct script to train the model? If yes please help me out with that. And also let me know about the dataset that you are using.

@deepak-newzera
Copy link

@flavioschneider Please let me know what kind of dataset can be used to train the model and how it should be structured.

@jameshball
Copy link

Hi @deepak-newzera I've written a simple training script here: https://github.com/jameshball/audio-diffusion/blob/master/train.py

It uses the LibriSpeech dataset and downloads it when you start the script.

You might need to change the data path defined at the top, and setup or remove the weights and biases (wandb) logging.

I'm currently having an issue with training where I get NaN: #52 but at least this code should give you something to start with.

@deepak-newzera
Copy link

@jameshball Thanks for your help.
The dataset you provided seems to be a speech dataset. But I suppose it must be a music dataset, right?

@kitchWWW
Copy link

kitchWWW commented Mar 1, 2023

Hi @jameshball , seems like the training script you linked goes to a 404, do you have an updated link for this training loop somewhere you could share?

@jameshball
Copy link

I've made that repo private now but this is the version of the file I linked:

import torch
import torchaudio
import gc
import argparse
import os
from tqdm import tqdm
import wandb
from audio_diffusion_pytorch import DiffusionModel, UNetV0, VDiffusion, VSampler
from audio_data_pytorch import LibriSpeechDataset, AllTransform

SAMPLE_RATE = 16000
BATCH_SIZE = 12
NUM_SAMPLES = 2**18


def create_model():
    return DiffusionModel(
        net_t=UNetV0, # The model type used for diffusion (U-Net V0 in this case)
        in_channels=1, # U-Net: number of input/output (audio) channels
        channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
        factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
        items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
        attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
        attention_heads=8, # U-Net: number of attention heads per attention item
        attention_features=64, # U-Net: number of attention features per attention item
        diffusion_t=VDiffusion, # The diffusion method used
        sampler_t=VSampler, # The diffusion sampler used
    )


def main():
    args = parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    dataset = LibriSpeechDataset(
        root="E:/librispeech",
        transforms=AllTransform(
            random_crop_size=NUM_SAMPLES,
            mono=True,
        ),
    )

    print(f"Dataset length: {len(dataset)}")

    torchaudio.save("test.wav", dataset[0], SAMPLE_RATE)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    )

    model = create_model().to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    run_id = wandb.util.generate_id()
    if args.run_id is not None:
        run_id = args.run_id
    print(f"Run ID: {run_id}")

    wandb.init(project="audio-diffusion", resume=args.resume, id=run_id)

    epoch = 0
    step = 0

    if args.checkpoint is not None:
        checkpoint_path = args.checkpoint
    else:
        checkpoint_path = f"checkpoint-{run_id}.pt"

    if wandb.run.resumed:
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path)
        else:
            checkpoint = torch.load(wandb.restore(checkpoint_path))
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        step = epoch * len(dataloader)
    
    scaler = torch.cuda.amp.GradScaler()

    model.train()
    while epoch < 100:
        avg_loss = 0
        avg_loss_step = 0
        progress = tqdm(dataloader)
        for i, audio in enumerate(progress):
            optimizer.zero_grad()
            audio = audio.to(device)
            with torch.cuda.amp.autocast():
                loss = model(audio)
                avg_loss += loss.item()
                avg_loss_step += 1
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            progress.set_postfix(
                loss=loss.item(),
                epoch=epoch + i / len(dataloader),
            )

            if step % 500 == 0:
                # Turn noise into new audio sample with diffusion
                noise = torch.randn(1, 1, NUM_SAMPLES, device=device)
                with torch.cuda.amp.autocast():
                    sample = model.sample(noise, num_steps=100)

                torchaudio.save(f'test_generated_sound_{step}.wav', sample[0].cpu(), SAMPLE_RATE)
                del sample
                gc.collect()
                torch.cuda.empty_cache()

                wandb.log({
                    "step": step,
                    "epoch": epoch + i / len(dataloader),
                    "loss": avg_loss / avg_loss_step,
                    "generated_audio": wandb.Audio(f'test_generated_sound_{step}.wav', caption="Generated audio", sample_rate=SAMPLE_RATE),
                })
            
            if step % 100 == 0:
                wandb.log({
                    "step": step,
                    "epoch": epoch + i / len(dataloader),
                    "loss": avg_loss / avg_loss_step,
                })
                avg_loss = 0
                avg_loss_step = 0
            
            step += 1

        epoch += 1
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, checkpoint_path)
        wandb.save(checkpoint_path)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--run_id", type=str, default=None)
    return parser.parse_args()


if __name__ == "__main__":
    main()

@jameshball
Copy link

It has some hardcoded paths and wandb code that you might need to remove but worked nicely

@deepak-newzera
Copy link

@jameshball Did you succeed in training the model? Is it producing some sensible outputs?
I am also willing to train using your script. But instead of LibriSpeechDataset, I would like to train the model on a set of wav files. If possible, can you guide how it can be done?

@dustyatx
Copy link

@jameshball did you successfully train? if so can you share what you did and learned?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

6 participants