online mix noise audio data in training step #2622

Open · wants to merge 32 commits into base: master
Changes shown are from 6 commits.

Commits (32):
681f470  Remove comments check from alphabet (carlfm01, Jun 5, 2019)
421243d  Remove sort from feeding (carlfm01, Jun 5, 2019)
d08efad  Remove sort from evaluate tools (carlfm01, Jun 5, 2019)
b0a14b5  Merge pull request #1 from carlfm01/master (carlfm01, Jun 29, 2019)
ba1a587  Remove TF dependency (carlfm01, Jun 29, 2019)
aebd08d  [ADD] mix noise audio (mychiux413, Dec 30, 2019)
d255c3f  [FIX] add missing file decoded_augmentation.py (mychiux413, Dec 30, 2019)
ec25136  mix noise works, but performance is bad (mychiux413, Dec 31, 2019)
484134e  [MOD] use tf.Dataset to cache noise audio (mychiux413, Dec 31, 2019)
4f24f08  rename decoded -> audio (mychiux413, Dec 31, 2019)
1f57ece  [FIX] don't create tf.Dataset in other tf.Dataset's pipeline (mychiux413, Jan 2, 2020)
66cc7c4  limit audio signal between +-1.0 (mychiux413, Jan 13, 2020)
b7eb0f4  [FIX] switch shuffle/map for memory cost, replace cache with prefetch… (mychiux413, Feb 11, 2020)
ccae7cc  [MOD] limit the buffer size of .shuffle() to protect memory usage (mychiux413, Feb 17, 2020)
8cc95f9  [ADD] bin/normalize_noise_audio.py (mychiux413, Feb 19, 2020)
9e2648a  [MOD] mix noise into complete audio (mychiux413, Feb 21, 2020)
2269514  [ADD] dev/test dataset can also mix noise [MOD] use SNR to balance no… (mychiux413, Mar 6, 2020)
0b8147c  [ADD] use dbfs and SNR to determine the balance of audio/noise, add o… (mychiux413, Mar 16, 2020)
42bc45b  [FIX] audiofile_to_features & samples_to_mfccs return 3 values now, a… (mychiux413, Mar 19, 2020)
289722d  Fix issues. (Mar 29, 2020)
9334e79  Save invalid files. (Mar 29, 2020)
25736e0  Merge remote-tracking branch 'noiseaug/more-augment-options' into noi… (Mar 29, 2020)
40b431b  Fix merging errors. (Mar 29, 2020)
f7d1279  [FIX] replace tqdm with prograssbar [ADD] separate speech/noise mixin… (mychiux413, Mar 31, 2020)
7792226  Merge branch 'no-sort' into more-augment-options (carlfm01, Apr 2, 2020)
c4c3ced  Merge #f7d1279. (Apr 12, 2020)
c151b1d  Merge branch 'master' into noisetest (Apr 17, 2020)
c089b7f  Fix merge not detecting moved scripts. (Apr 17, 2020)
491a4b0  Undo personal changes. (Apr 17, 2020)
735cbbb  Merge branch 'master' of https://github.com/mozilla/DeepSpeech into n… (Apr 23, 2020)
2fa91e8  To recover the incorrect merge (mychiux413, May 12, 2020)
6b820bb  Merge pull request #1 from DanBmh/noiseaugmaster (mychiux413, May 14, 2020)
63 changes: 63 additions & 0 deletions util/audio_augmentation.py
@@ -0,0 +1,63 @@
import tensorflow as tf
import tensorflow.compat.v1 as tfv1
from tensorflow.python.ops import gen_audio_ops as contrib_audio
import os

def collect_noise_filenames(walk_dirs):
assert isinstance(walk_dirs, list)

for d in walk_dirs:
for dirpath, _, filenames in os.walk(d):
for filename in filenames:
if filename.endswith('.wav'):
yield os.path.join(dirpath, filename)

def noise_file_to_audio(noise_file):
samples = tf.io.read_file(noise_file)
decoded = contrib_audio.decode_wav(samples, desired_channels=1)
return decoded.audio

def augment_noise(audio,
noise_audio,
change_audio_db_max=0,
change_audio_db_min=-10,
change_noise_db_max=-15,
change_noise_db_min=-25):

decoded_audio_len = tf.shape(audio)[0]
noise_decoded_audio_len = tf.shape(noise_audio)[0]

multiply = tf.math.floordiv(decoded_audio_len, noise_decoded_audio_len) + 1
noise_audio_tile = tf.tile(noise_audio, [multiply, 1])

# after tiling, the noise is always longer than the speech audio
noise_decoded_audio_len = tf.shape(noise_audio_tile)[0]

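# pick two random points in the speech; their min and max define the span where noise will be overlaid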
mix_decoded_start_end_points = tfv1.random_uniform(
[2], minval=0, maxval=decoded_audio_len-1, dtype=tf.int32)
mix_decoded_start_point = tf.math.reduce_min(mix_decoded_start_end_points)
mix_decoded_end_point = tf.math.reduce_max(
mix_decoded_start_end_points) + 1
mix_decoded_width = mix_decoded_end_point - mix_decoded_start_point

left_zeros = tf.zeros(shape=[mix_decoded_start_point, 1])

mix_noise_decoded_start_point = tfv1.random_uniform(
[], minval=0, maxval=noise_decoded_audio_len - mix_decoded_width, dtype=tf.int32)
mix_noise_decoded_end_point = mix_noise_decoded_start_point + mix_decoded_width
extract_noise_decoded = noise_audio_tile[mix_noise_decoded_start_point:mix_noise_decoded_end_point, :]

right_zeros = tf.zeros(
shape=[decoded_audio_len - mix_decoded_end_point, 1])

mixed_noise = tf.concat(
[left_zeros, extract_noise_decoded, right_zeros], axis=0)

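# draw random gains in dB for the speech and the noise and turn them into scaling factors before mixing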
choosen_audio_db = tfv1.random_uniform(
[], minval=change_audio_db_min, maxval=change_audio_db_max)
audio_ratio = tf.math.pow(10.0, choosen_audio_db / 10)

choosen_noise_db = tfv1.random_uniform(
[], minval=change_noise_db_min, maxval=change_noise_db_max)
noise_ratio = tf.math.pow(10.0, choosen_noise_db / 10)
return tf.multiply(audio, audio_ratio) + tf.multiply(mixed_noise, noise_ratio)
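For reference, a minimal sketch of how augment_noise could be exercised on its own, outside the feeding pipeline. The file paths are placeholders and the eager-execution assumption is mine; under TF 1.x graph mode the tensors would need to be evaluated in a tfv1.Session. This is not part of the patch:

import tensorflow as tf
from util.audio_augmentation import augment_noise, noise_file_to_audio

# decode one speech clip and one noise clip (paths are hypothetical;
# noise_file_to_audio simply decodes a mono wav file)
speech = noise_file_to_audio(tf.constant('data/speech/sample.wav'))
noise = noise_file_to_audio(tf.constant('data/noise/cafe.wav'))

# overlay a random segment of the noise onto a random span of the speech,
# with the speech gain drawn from [-10, 0] dB and the noise gain from [-25, -15] dB
mixed = augment_noise(speech, noise,
                      change_audio_db_max=0, change_audio_db_min=-10,
                      change_noise_db_max=-15, change_noise_db_min=-25)
# `mixed` has the same shape as `speech`: [num_samples, 1]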
41 changes: 36 additions & 5 deletions util/feeding.py
@@ -16,6 +16,7 @@
from util.flags import FLAGS
from util.spectrogram_augmentations import augment_freq_time_mask, augment_dropout, augment_pitch_and_tempo, augment_speed_up
from util.audio import read_frames_from_file, vad_split, DEFAULT_FORMAT
from util.audio_augmentation import augment_noise, noise_file_to_audio, collect_noise_filenames


def read_csvs(csv_files):
@@ -64,11 +65,26 @@ def samples_to_mfccs(samples, sample_rate, train_phase=False):
return mfccs, tf.shape(input=mfccs)[0]


def audiofile_to_features(wav_filename, train_phase=False):
def audiofile_to_features(wav_filename, train_phase=False, noise_iterator=None):
samples = tf.io.read_file(wav_filename)
decoded = contrib_audio.decode_wav(samples, desired_channels=1)
features, features_len = samples_to_mfccs(decoded.audio, decoded.sample_rate, train_phase=train_phase)
audio = decoded.audio

# augment audio
if train_phase and noise_iterator:
audio = augment_noise(
audio,
noise_iterator.get_next(),
change_audio_db_max=FLAGS.audio_aug_mix_noise_max_audio_db,
change_audio_db_min=FLAGS.audio_aug_mix_noise_min_audio_db,
change_noise_db_max=FLAGS.audio_aug_mix_noise_max_noise_db,
change_noise_db_min=FLAGS.audio_aug_mix_noise_min_noise_db,
)


features, features_len = samples_to_mfccs(audio, decoded.sample_rate, train_phase=train_phase)

# augment features
if train_phase:
if FLAGS.data_aug_features_multiplicative > 0:
features = features*tf.random.normal(mean=1, stddev=FLAGS.data_aug_features_multiplicative, shape=tf.shape(features))
@@ -79,9 +95,9 @@ def audiofile_to_features(wav_filename, train_phase=False):
return features, features_len


def entry_to_features(wav_filename, transcript, train_phase):
def entry_to_features(wav_filename, transcript, train_phase, noise_iterator=None):
# https://bugs.python.org/issue32117
features, features_len = audiofile_to_features(wav_filename, train_phase=train_phase)
features, features_len = audiofile_to_features(wav_filename, train_phase=train_phase, noise_iterator=noise_iterator)
return wav_filename, features, features_len, tf.SparseTensor(*transcript)


@@ -120,7 +136,22 @@ def batch_fn(wav_filenames, features, features_len, transcripts):
return tf.data.Dataset.zip((wav_filenames, features, transcripts))

num_gpus = len(Config.available_devices)
process_fn = partial(entry_to_features, train_phase=train_phase)

if train_phase and FLAGS.audio_aug_mix_noise_walk_dirs:
# we need the number of noise files up front to set the shuffle buffer size, so we cannot use a generator here
noise_filenames = tf.convert_to_tensor(
list(collect_noise_filenames(FLAGS.audio_aug_mix_noise_walk_dirs.split(','))),
dtype=tf.string)
print(">>> Collect {} noise files for mixing audio".format(noise_filenames.shape[0]))
noise_dataset = (tf.data.Dataset.from_tensor_slices(noise_filenames)
.map(noise_file_to_audio, num_parallel_calls=tf.data.experimental.AUTOTUNE)
.shuffle(noise_filenames.shape[0])
.cache(FLAGS.audio_aug_mix_noise_cache)
.repeat())
noise_iterator = tf.compat.v1.data.make_one_shot_iterator(noise_dataset)
else:
noise_iterator = None
process_fn = partial(entry_to_features, train_phase=train_phase, noise_iterator=noise_iterator)

Contributor:
This approach has a disadvantage:
Results of the augmentation now become part of the cache, so we miss the opportunity to re-augment them differently on every training epoch (which could help avoid over-fitting).

Contributor Author:
The noise dataset is already .shuffle() + .repeat(), so as long as dataset = dataset.cache(cache_path) is not executed, a different noise clip gets mixed into the same speech file on every loop.
I double-checked this with tf.Print, inspecting the speech/noise filenames for each epoch, and confirmed that the noise clips are selected randomly.
If the audio reviewed on TensorBoard always shows the same noise file being picked, that is due to the random seed; try changing random_seed to get a different result.

Contributor:
Yes, but how about caching training samples and noise separately?

Contributor Author:
I didn't plan to cache the mixed audio: the mixing happens in entry_to_features() before the audio is converted to features, and caching raw audio could cost roughly 10 times more memory than caching features, which is less practical and would not speed up training much. So, like the other augmentations, I set enable_cache=False whenever FLAGS.train_augmentation_files is not null.


dataset = (tf.data.Dataset.from_generator(generate_values,
output_types=(tf.string, (tf.int64, tf.int32, tf.int64)))
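As a sketch of the reviewer's suggestion above (caching speech and noise separately so the mix is still re-drawn every epoch), one possible arrangement is shown below. The speech_filenames/noise_filenames tensors and the FLAGS.feature_cache path are assumptions for illustration, not part of this patch:

# hypothetical sketch: cache the decoded speech before mixing, keep the noise
# pipeline shuffled + repeated, and apply augment_noise after the cache so
# every epoch draws a fresh noise segment for the same speech clip
speech_ds = (tf.data.Dataset.from_tensor_slices(speech_filenames)
             .map(noise_file_to_audio,  # simply decodes a wav file
                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
             .cache(FLAGS.feature_cache))
noise_ds = (tf.data.Dataset.from_tensor_slices(noise_filenames)
            .map(noise_file_to_audio,
                 num_parallel_calls=tf.data.experimental.AUTOTUNE)
            .shuffle(1024)
            .repeat())
# zip pairs each cached speech clip with a freshly shuffled noise clip
mixed_ds = tf.data.Dataset.zip((speech_ds, noise_ds)).map(augment_noise)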
8 changes: 7 additions & 1 deletion util/flags.py
@@ -24,6 +24,13 @@ def create_flags():
# Data Augmentation
# ================

f.DEFINE_string('audio_aug_mix_noise_walk_dirs', '', 'comma separated list of directories to walk for noise .wav files, which are then mixed into the decoded audio')
f.DEFINE_string('audio_aug_mix_noise_cache', '', 'path for caching the noise audio data; without caching, the noise files would be read from disk at every training step')
f.DEFINE_float('audio_aug_mix_noise_max_noise_db', -25, 'upper limit (dB) of the noise volume')
f.DEFINE_float('audio_aug_mix_noise_min_noise_db', -50, 'lower limit (dB) of the noise volume')
f.DEFINE_float('audio_aug_mix_noise_max_audio_db', 0, 'upper limit (dB) of the speech volume')
f.DEFINE_float('audio_aug_mix_noise_min_audio_db', -10, 'lower limit (dB) of the speech volume')

f.DEFINE_float('data_aug_features_additive', 0, 'std of the Gaussian additive noise')
f.DEFINE_float('data_aug_features_multiplicative', 0, 'std of normal distribution around 1 for multiplicative noise')

@@ -42,7 +49,6 @@ def create_flags():
f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_pitch', 1.2, 'max value of pitch scaling')
f.DEFINE_float('augmentation_pitch_and_tempo_scaling_max_tempo', 1.2, 'max value of tempo scaling')


# Global Constants
# ================

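For illustration, the new flags could be combined in a training run roughly as follows; the CSV paths and noise directories are placeholders, and the surrounding DeepSpeech.py arguments are only assumed to match the usual training invocation:

python3 DeepSpeech.py \
  --train_files=data/train.csv \
  --dev_files=data/dev.csv \
  --test_files=data/test.csv \
  --audio_aug_mix_noise_walk_dirs=/data/noise/freesound,/data/noise/rir \
  --audio_aug_mix_noise_cache=/tmp/noise_cache \
  --audio_aug_mix_noise_min_audio_db=-10 \
  --audio_aug_mix_noise_max_audio_db=0 \
  --audio_aug_mix_noise_min_noise_db=-50 \
  --audio_aug_mix_noise_max_noise_db=-25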