In train model session have error #2423

Open
necco717 opened this issue Dec 26, 2024 · 1 comment

Comments

@necco717

```
2024-12-26 08:29:24 | INFO | main | Use gpus: 0
2024-12-26 08:29:24 | INFO | main | Execute: "/usr/local/envs/rvc/bin/python" infer/modules/train/train.py -e "Necco" -sr 40k -f0 1 -bs 19 -g 0 -te 64 -se 5 -pg assets/pretrained_v2/f0G40k.pth -pd assets/pretrained_v2/f0D40k.pth -l 0 -c 0 -sw 0 -v v2
INFO:Necco:{'data': {'filter_length': 2048, 'hop_length': 400, 'max_wav_value': 32768.0, 'mel_fmax': None, 'mel_fmin': 0.0, 'n_mel_channels': 125, 'sampling_rate': 40000, 'win_length': 2048, 'training_files': './logs/Necco/filelist.txt'}, 'model': {'filter_channels': 768, 'gin_channels': 256, 'hidden_channels': 192, 'inter_channels': 192, 'kernel_size': 3, 'n_heads': 2, 'n_layers': 6, 'p_dropout': 0, 'resblock': '1', 'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]], 'resblock_kernel_sizes': [3, 7, 11], 'spk_embed_dim': 109, 'upsample_initial_channel': 512, 'upsample_kernel_sizes': [16, 16, 4, 4], 'upsample_rates': [10, 10, 2, 2], 'use_spectral_norm': False}, 'train': {'batch_size': 19, 'betas': [0.8, 0.99], 'c_kl': 1.0, 'c_mel': 45, 'epochs': 20000, 'eps': 1e-09, 'fp16_run': True, 'init_lr_ratio': 1, 'learning_rate': 0.0001, 'log_interval': 200, 'lr_decay': 0.999875, 'seed': 1234, 'segment_size': 12800, 'warmup_epochs': 0}, 'model_dir': './logs/Necco', 'experiment_dir': './logs/Necco', 'save_every_epoch': 5, 'name': 'Necco', 'total_epoch': 64, 'pretrainG': 'assets/pretrained_v2/f0G40k.pth', 'pretrainD': 'assets/pretrained_v2/f0D40k.pth', 'version': 'v2', 'gpus': '0', 'sample_rate': '40k', 'if_f0': 1, 'if_latest': 0, 'save_every_weights': '0', 'if_cache_data_in_gpu': 0}
/usr/local/envs/rvc/lib/python3.10/site-packages/torch/nn/utils/weight_norm.py:143: FutureWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.
  WeightNorm.apply(module, name, dim)
DEBUG:infer.lib.infer_pack.models:gin_channels: 256, self.spk_embed_dim: 109
INFO:Necco:loaded pretrained assets/pretrained_v2/f0G40k.pth
/content/Retrieval-based-Voice-Conversion-WebUI/infer/modules/train/train.py:231: FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  torch.load(hps.pretrainG, map_location="cpu")["model"]
INFO:Necco:
INFO:Necco:loaded pretrained assets/pretrained_v2/f0D40k.pth
/content/Retrieval-based-Voice-Conversion-WebUI/infer/modules/train/train.py:246: FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  torch.load(hps.pretrainD, map_location="cpu")["model"]
INFO:Necco:
/content/Retrieval-based-Voice-Conversion-WebUI/infer/modules/train/train.py:263: FutureWarning: torch.cuda.amp.GradScaler(args...) is deprecated. Please use torch.amp.GradScaler('cuda', args...) instead.
  scaler = GradScaler(enabled=hps.train.fp16_run)
/content/Retrieval-based-Voice-Conversion-WebUI/infer/lib/train/data_utils.py:114: FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  spec = torch.load(spec_filename)
/content/Retrieval-based-Voice-Conversion-WebUI/infer/modules/train/train.py:429: FutureWarning: torch.cuda.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cuda', args...) instead.
  with autocast(enabled=hps.train.fp16_run):
/content/Retrieval-based-Voice-Conversion-WebUI/infer/modules/train/train.py:457: FutureWarning: torch.cuda.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cuda', args...) instead.
  with autocast(enabled=False):
/content/Retrieval-based-Voice-Conversion-WebUI/infer/modules/train/train.py:476: FutureWarning: torch.cuda.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cuda', args...) instead.
  with autocast(enabled=False):
/content/Retrieval-based-Voice-Conversion-WebUI/infer/modules/train/train.py:486: FutureWarning: torch.cuda.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cuda', args...) instead.
  with autocast(enabled=hps.train.fp16_run):
/content/Retrieval-based-Voice-Conversion-WebUI/infer/modules/train/train.py:489: FutureWarning: torch.cuda.amp.autocast(args...) is deprecated. Please use torch.amp.autocast('cuda', args...) instead.
  with autocast(enabled=False):
/usr/local/envs/rvc/lib/python3.10/site-packages/torch/autograd/graph.py:825: UserWarning: Grad strides do not match bucket view strides. This may indicate grad was not created according to the gradient layout contract, or that the param's strides changed since DDP was constructed. This is not an error, but may impair performance.
grad.sizes() = [64, 1, 4], strides() = [4, 1, 1]
bucket_view.sizes() = [64, 1, 4], strides() = [4, 4, 1] (Triggered internally at ../torch/csrc/distributed/c10d/reducer.cpp:327.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
INFO:Necco:Train Epoch: 1 [0%]
INFO:Necco:[0, 0.0001]
INFO:Necco:loss_disc=4.441, loss_gen=2.692, loss_fm=0.571,loss_mel=27.202, loss_kl=9.000
DEBUG:matplotlib:matplotlib data path: /usr/local/envs/rvc/lib/python3.10/site-packages/matplotlib/mpl-data
DEBUG:matplotlib:CONFIGDIR=/root/.config/matplotlib
DEBUG:matplotlib:interactive is False
DEBUG:matplotlib:platform is linux
Process Process-1:
Traceback (most recent call last):
  File "/usr/local/envs/rvc/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/local/envs/rvc/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/content/Retrieval-based-Voice-Conversion-WebUI/infer/modules/train/train.py", line 268, in run
    train_and_evaluate(
  File "/content/Retrieval-based-Voice-Conversion-WebUI/infer/modules/train/train.py", line 545, in train_and_evaluate
    "slice/mel_org": utils.plot_spectrogram_to_numpy(
  File "/content/Retrieval-based-Voice-Conversion-WebUI/infer/lib/train/utils.py", line 239, in plot_spectrogram_to_numpy
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
ValueError: cannot reshape array of size 800000 into shape (200,1000,3)
```
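The numbers in the error message point at the cause. Assuming the spectrogram figure is 10 × 2 inches at Matplotlib's default 100 dpi (which matches the (200, 1000) in the traceback), the canvas is 1000 × 200 pixels. A buffer with four bytes per pixel (RGBA) holds 200 × 1000 × 4 = 800,000 values, while the reshape asks for 200 × 1000 × 3 = 600,000. A minimal NumPy-only sketch of the mismatch, using a zero-filled array as a stand-in for the real canvas pixels:

```python
import numpy as np

# Assumed canvas size: 10 x 2 inches at 100 dpi, matching (200, 1000) in the traceback.
width, height = 1000, 200

# Stand-in for the canvas pixels: 4 bytes (R, G, B, A) per pixel.
rgba_buffer = np.zeros(height * width * 4, dtype=np.uint8)
print(rgba_buffer.size)  # 800000

# Reshaping to 3 channels fails, just like line 239 of utils.py.
try:
    rgba_buffer.reshape((height, width, 3))
    reshape_failed = False
except ValueError as exc:
    reshape_failed = True
    print(exc)

# Reshaping to 4 channels and slicing off alpha gives the expected RGB image.
rgb = rgba_buffer.reshape((height, width, 4))[:, :, :3]
print(rgb.shape)  # (200, 1000, 3)
```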

@necco717
Author

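Fixed by rewriting `plot_spectrogram_to_numpy` in /content/Retrieval-based-Voice-Conversion-WebUI/infer/lib/train/utils.py.

The canvas buffer holds four bytes per pixel (RGBA): 200 × 1000 × 4 = 800,000, whereas the old reshape assumed three (RGB). The revised version reshapes with four channels and then drops the alpha channel: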

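```python
# Drop-in replacement inside infer/lib/train/utils.py. It relies on that
# module's existing `import numpy as np` and module-level `MATPLOTLIB_FLAG`.
def plot_spectrogram_to_numpy(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        matplotlib.use("Agg")  # headless backend, selected once per process
        MATPLOTLIB_FLAG = True
        import matplotlib.pyplot as plt
    else:
        import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # The Agg canvas exposes its pixels as RGBA, i.e. 4 bytes per pixel.
    data = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
    width, height = fig.canvas.get_width_height()  # returned as (width, height)
    data = data.reshape(height, width, 4)
    data = data[:, :, :3]  # drop alpha, keeping the (H, W, 3) RGB layout callers expect
    plt.close(fig)
    return data
```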
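To check the RGBA-then-slice approach outside the training loop, here is a minimal standalone sketch. It is a hypothetical variant (the name `spectrogram_to_rgb_array` is not part of RVC) that builds the figure with Matplotlib's object-oriented `Figure` and `FigureCanvasAgg` classes instead of pyplot, so it needs neither `matplotlib.use("Agg")` nor the `MATPLOTLIB_FLAG` guard. The input is random data shaped like a 125-channel mel spectrogram, matching `n_mel_channels` in the logged config.

```python
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure


def spectrogram_to_rgb_array(spectrogram: np.ndarray) -> np.ndarray:
    """Hypothetical helper: render a 2-D array and return an (H, W, 3) uint8 image."""
    fig = Figure(figsize=(10, 2))
    canvas = FigureCanvasAgg(fig)  # attach an Agg canvas without touching pyplot
    ax = fig.add_subplot()
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    fig.colorbar(im, ax=ax)
    ax.set_xlabel("Frames")
    ax.set_ylabel("Channels")
    fig.tight_layout()

    canvas.draw()
    # buffer_rgba() yields 4 bytes per pixel, so reshape with 4 channels first.
    width, height = canvas.get_width_height()
    rgba = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8).reshape(height, width, 4)
    # Drop alpha and copy so the result no longer references the canvas buffer.
    return rgba[:, :, :3].copy()


# Usage: random stand-in for a mel spectrogram (125 channels x 300 frames).
mel = np.random.rand(125, 300).astype(np.float32)
image = spectrogram_to_rgb_array(mel)
print(image.shape, image.dtype)
```

With the default `figure.dpi` of 100 this should give a 200 × 1000 × 3 image, the same shape the training code expected. Avoiding pyplot also means no figure is registered globally, so there is nothing to `plt.close()` afterwards.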
