aboutsummaryrefslogtreecommitdiffstats
blob: 72a4572d9a023ce550e30a2e875d344e90e3f39a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
#!/usr/bin/env python3

# Copyright      2017 Johns Hopkins University (Author: Daniel Povey)
#                2017 Johns Hopkins University (Author: Daniel Garcia-Romero)
#                2017 David Snyder
# Apache 2.0

# This script, which is used in getting training examples, decides
# which examples will come from which recordings, and at what point
# during the training.

# You call it as (e.g.)
#
#  allocate_egs.py --min-frames-per-chunk=50 --max-frames-per-chunk=200 \
#   --frames-per-iter=1000000 --num-repeats=60 --num-archives=169 \
#   --num-jobs=24 --utt2len-filename=exp/xvector_a/egs/temp/utt2len.train \
#   --utt2int-filename=exp/xvector_a/egs/temp/utt2int.train \
#   --egs-dir=exp/xvector_a/egs
#
# The program outputs certain things to the temp directory (e.g.,
# exp/xvector_a/egs/temp) that will enable you to dump the chunks for xvector
# training.  What we'll eventually be doing is invoking the following program
# with something like the following args:
#
#  nnet3-xvector-get-egs [options] exp/xvector_a/egs/temp/ranges.1 \
#    scp:data/train/feats.scp ark:exp/xvector_a/egs/egs_temp.1.ark \
#    ark:exp/xvector_a/egs/egs_temp.2.ark ark:exp/xvector_a/egs/egs_temp.3.ark
#
# where exp/xvector_a/egs/temp/ranges.1 contains something like the following:
#
#   utt1  0  1  0   65 0
#   utt1  6  7  160 50 0
#   utt2  ...
#
# where each line is interpreted as follows:
#  <source-utterance> <relative-archive-index> <absolute-archive-index> \
#    <start-frame-index> <num-frames> <spkr-label>
#
# Note: <relative-archive-index> is the zero-based offset of the archive-index
# within the subset of archives that a particular ranges file corresponds to;
# and <absolute-archive-index> is the 1-based numeric index of the destination
# archive among the entire list of archives, which will form part of the
# archive's filename (e.g. egs/egs.<absolute-archive-index>.ark);
# <absolute-archive-index> is only kept for debug purposes so you can see which
# archive each line corresponds to.
#
# For each line of the ranges file, we specify an eg containing a chunk of data
# from a given utterance, the corresponding speaker label, and the output
# archive.  The list of archives corresponding to ranges.n will be written to
# outputs.n, so in exp/xvector_a/egs/temp/outputs.1 we'd have:
#
#  ark:exp/xvector_a/egs/egs_temp.1.ark ark:exp/xvector_a/egs/egs_temp.2.ark \
#    ark:exp/xvector_a/egs/egs_temp.3.ark
#
# The number of these files will equal 'num-jobs'.  If you add up the
# word-counts of all the outputs.* files you'll get 'num-archives'.  The number
# of frames in each archive will be about the --frames-per-iter.
#
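# This program will also output to the temp directory a file called
# archive_chunk_lengths which tells you the frame-length associated with
# each archive, e.g.,
# 1   60
# 2   120
# the format is:  <archive-index> <num-frames>.  The <num-frames> will always
# be in the range [min-frames-per-chunk, max-frames-per-chunk].
#
# Finally, it writes <egs-dir>/pdf2num: a single line giving the number of
# chunks allocated to each speaker label 0 .. num-pdfs-1.  When --prefix=P is
# given, each output file name above is prefixed with "P_".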


# We're using python 3.x style print but want it to work in python 2.x.
from __future__ import print_function
import re, os, argparse, sys, math, warnings, random

def get_args():
    """Parse and validate the command line.

    The raw command line is echoed to stderr (joined and as a list) before
    parsing; the parsed namespace is then passed through process_args(),
    which raises on invalid option combinations.

    Returns:
        argparse.Namespace holding the validated options.
    """
    parser = argparse.ArgumentParser(description="Writes ranges.*, outputs.* and archive_chunk_lengths files "
                                     "in preparation for dumping egs for xvector training.",
                                     epilog="Called by sid/nnet3/xvector/get_egs.sh")
    parser.add_argument("--prefix", type=str, default="",
                        help="Adds a prefix to the output files. This is used to distinguish between the train "
                        "and diagnostic files.")
    parser.add_argument("--num-repeats", type=int, default=10,
                        help="Number of times each speaker repeats within an archive.")
    parser.add_argument("--min-frames-per-chunk", type=int, default=50,
                        help="Minimum number of frames-per-chunk used for any archive")
    parser.add_argument("--max-frames-per-chunk", type=int, default=300,
                        help="Maximum number of frames-per-chunk used for any archive")
    # Adjacent string literals are concatenated verbatim, so each piece must
    # end with a space to keep the --help text readable.
    parser.add_argument("--randomize-chunk-length", type=str,
                        help="If true, randomly pick a chunk length in [min-frames-per-chunk, max-frames-per-chunk]. "
                        "If false, the chunk length varies from min-frames-per-chunk to max-frames-per-chunk "
                        "according to a geometric sequence.",
                        default="true", choices=["false", "true"])
    parser.add_argument("--frames-per-iter", type=int, default=1000000,
                        help="Target number of frames for each archive")
    parser.add_argument("--num-archives", type=int, default=-1,
                        help="Number of archives to write")
    parser.add_argument("--num-jobs", type=int, default=-1,
                        help="Number of jobs we're going to use to write the archives; the ranges.* "
                        "and outputs.* files are indexed by job.  Must be <= the --num-archives option.")
    parser.add_argument("--seed", type=int, default=123,
                        help="Seed for random number generator")
    # -1 means: derive from the data (main() uses the largest label + 1).
    parser.add_argument("--num-pdfs", type=int, default=-1,
                        help="Num pdfs")

    # Input/output locations.  These are required options, not positional
    # arguments.
    parser.add_argument("--utt2len-filename", type=str, required=True,
                        help="utt2len file of the features to be used as input (format is: "
                        "<utterance-id> <num-frames>)")
    parser.add_argument("--utt2int-filename", type=str, required=True,
                        help="utt2int file of the features to be used as input (format is: "
                        "<utterance-id> <id>)")
    parser.add_argument("--egs-dir", type=str, required=True,
                        help="Name of egs directory, e.g. exp/xvector_a/egs")

    print(' '.join(sys.argv), file=sys.stderr)
    print(sys.argv, file=sys.stderr)
    args = parser.parse_args()
    return process_args(args)

def process_args(args):
    """Validate the options parsed by get_args().

    Args:
        args: argparse.Namespace from get_args().

    Returns:
        The same namespace object, unmodified.

    Raises:
        Exception: if an input file is missing or an option is out of range.
    """
    if args.num_repeats < 1:
        raise Exception("--num-repeats should have a minimum value of 1")
    if not os.path.exists(args.utt2int_filename):
        raise Exception("This script expects --utt2int-filename to exist")
    if not os.path.exists(args.utt2len_filename):
        raise Exception("This script expects --utt2len-filename to exist")
    if args.min_frames_per_chunk <= 1:
        raise Exception("--min-frames-per-chunk is invalid.")
    if args.max_frames_per_chunk < args.min_frames_per_chunk:
        raise Exception("--max-frames-per-chunk is invalid.")
    if args.frames_per_iter < 1000:
        raise Exception("--frames-per-iter is invalid.")
    if args.num_archives < 1:
        raise Exception("--num-archives is invalid")
    # num_jobs is used as a list length and as a modulus when archives are
    # distributed over jobs, so it must be positive (its default is -1).
    if args.num_jobs < 1:
        raise Exception("--num-jobs is invalid (must be at least 1)")
    if args.num_jobs > args.num_archives:
        raise Exception("--num-jobs is invalid (must not exceed num-archives)")
    return args

# Create utt2len
def get_utt2len(utt2len_filename):
    utt2len = {}
    f = open(utt2len_filename, "r")
    if f is None:
        sys.exit("Error opening utt2len file " + str(utt2len_filename))
    utt_ids = []
    lengths = []
    for line in f:
        tokens = line.split()
        if len(tokens) != 2:
            sys.exit("bad line in utt2len file " + line)
        utt2len[tokens[0]] = int(tokens[1])
    f.close()
    return utt2len
    # Done utt2len

def get_labels(utt2int_filename):
    """Read an utt2int file whose lines are "<utterance-id> <integer-label>".

    Args:
        utt2int_filename: path of the file to read.

    Returns:
        (spks, spk2utt, utt2spk):
          spk2utt: dict mapping integer label -> list of utterance ids, in
                   file order;
          utt2spk: dict mapping utterance id -> integer label;
          spks:    spk2utt.keys().

    Exits via sys.exit() if a line does not have exactly two fields.
    """
    spk2utt = {}
    utt2spk = {}
    # The context manager closes the file even when sys.exit() raises.
    with open(utt2int_filename, "r") as f:
        for line in f:
            tokens = line.split()
            if len(tokens) != 2:
                sys.exit("bad line in utt2int file " + line)
            spk = int(tokens[1])
            utt = tokens[0]
            utt2spk[utt] = spk
            spk2utt.setdefault(spk, []).append(utt)
    return spk2utt.keys(), spk2utt, utt2spk


def get_random_utt(spkr, spk2utt, min_length):
    """Return one of speaker `spkr`'s utterance ids, chosen uniformly at random.

    Args:
        spkr: a key of spk2utt.
        spk2utt: dict mapping speaker label -> list of utterance ids.
        min_length: accepted for interface compatibility but NOT used.  There
            is no filtering by length and no weighting by length; every
            utterance of the speaker is equally likely.  In this script,
            main() passes the chosen utterance to get_random_offset(), which
            exits if the utterance is shorter than the chunk.

    Returns:
        The selected utterance id.
    """
    candidates = spk2utt[spkr]
    pick = random.randint(0, len(candidates) - 1)
    return candidates[pick]

def random_chunk_length(min_frames_per_chunk, max_frames_per_chunk):
    """Return a uniformly random chunk length in the closed range
    [min_frames_per_chunk, max_frames_per_chunk]."""
    return random.randint(min_frames_per_chunk, max_frames_per_chunk)

def deterministic_chunk_length(archive_id, num_archives, min_frames_per_chunk, max_frames_per_chunk):
    """Return the chunk length for archive `archive_id` (0-based).

    Lengths follow a geometric sequence from min_frames_per_chunk (first
    archive) to max_frames_per_chunk (last archive), rounded half-up to an
    int.  For example, with min=50, max=200 and 3 archives, archives 0, 1
    and 2 get 50, 100 and 200 frames.

    Special cases: if min == max that value is returned as-is; if there is
    only one archive it gets int(max_frames_per_chunk).
    """
    if min_frames_per_chunk == max_frames_per_chunk:
        return max_frames_per_chunk
    if num_archives == 1:
        return int(max_frames_per_chunk)
    ratio = float(max_frames_per_chunk) / min_frames_per_chunk
    exponent = float(archive_id) / (num_archives - 1)
    return int(math.pow(ratio, exponent) * min_frames_per_chunk + 0.5)



def get_random_offset(utt_length, length):
    """Pick a random start frame for a single chunk inside an utterance.

    Args:
        utt_length: utterance length in frames.
        length: chunk length in frames; must not exceed utt_length.

    Returns:
        An int drawn uniformly from [0, utt_length - length], so the chunk
        [offset, offset + length) lies entirely inside the utterance.

    Exits via sys.exit() if length > utt_length.
    """
    if utt_length < length:
        sys.exit("code error: length > utt-length")
    return random.randint(0, utt_length - length)


def main():
    """Allocate training chunks to archives and write the per-job files.

    1. For each archive, pick a chunk length (random, or from a geometric
       sequence when --randomize-chunk-length=false) and record it in
       <egs-dir>/temp/<prefix>archive_chunk_lengths.  Then fill the archive
       with up to int(frames_per_iter / length) + 1 chunks.  Speakers are
       popped from a shuffled list in which every speaker appears
       num_repeats times; if the list runs out, the archive gets fewer
       chunks.
    2. Split the archives over num_jobs jobs (per-job counts differ by at
       most one; each job gets consecutive archives) and write
       <prefix>ranges.<job> and <prefix>outputs.<job> under <egs-dir>/temp.
    3. Write <egs-dir>/<prefix>pdf2num: the number of chunks allocated to
       each speaker label 0 .. num_pdfs-1, space separated.
    """
    args = get_args()
    temp_dir = args.egs_dir + "/temp"
    if not os.path.exists(temp_dir):
        os.makedirs(temp_dir)
    random.seed(args.seed)
    utt2len = get_utt2len(args.utt2len_filename)
    spks, spk2utt, utt2spk = get_labels(args.utt2int_filename)
    if args.num_pdfs == -1:
        # max() of an empty sequence raises an unhelpful ValueError.
        if len(spk2utt) == 0:
            sys.exit("allocate_egs.py: no utterances found in " + args.utt2int_filename)
        args.num_pdfs = max(spks) + 1

    prefix = ""
    if args.prefix != "":
        prefix = args.prefix + "_"

    # archive_chunk_lengths[a] is the chunk length (in frames) shared by every
    # example of archive a (0-based).
    archive_chunk_lengths = []
    # all_egs[a] is the list of (utt-id, start-frame) pairs for archive a.
    all_egs = []

    with open(temp_dir + "/" + prefix + "archive_chunk_lengths", "w") as info_f:
        for archive_index in range(args.num_archives):
            print("Processing archive {0}".format(archive_index + 1))
            if args.randomize_chunk_length == "true":
                # don't constrain the lengths to be the same
                length = random_chunk_length(args.min_frames_per_chunk,
                                             args.max_frames_per_chunk)
            else:
                length = deterministic_chunk_length(archive_index, args.num_archives,
                                                    args.min_frames_per_chunk,
                                                    args.max_frames_per_chunk)
            print("{0} {1}".format(archive_index + 1, length), file=info_f)
            archive_chunk_lengths.append(length)
            this_num_egs = int((args.frames_per_iter / length) + 1)
            # Each speaker occurs num_repeats times in this list, which caps how
            # often a speaker can appear in one archive.
            spkrs = args.num_repeats * list(spk2utt.keys())
            random.shuffle(spkrs)
            this_egs = []
            for _ in range(this_num_egs):
                if len(spkrs) == 0:
                    print("Ran out of speakers for archive {0}".format(archive_index + 1))
                    break
                spkr = spkrs.pop()
                utt = get_random_utt(spkr, spk2utt, length)
                offset = get_random_offset(utt2len[utt], length)
                this_egs.append((utt, offset))
            all_egs.append(this_egs)

    # Deal archive slots out round-robin so per-job counts differ by at most
    # one; each job then takes the next run of consecutive archives.
    num_archives_per_job = [0] * args.num_jobs
    for i in range(args.num_archives):
        num_archives_per_job[i % args.num_jobs] += 1

    pdf2num = {}  # speaker label -> number of chunks allocated to it
    cur_archive = 0
    for job in range(args.num_jobs):
        # Global (0-based) archive indices written by this job; the position
        # in this list is the archive's relative index within the job.
        this_archives_for_job = list(range(cur_archive,
                                           cur_archive + num_archives_per_job[job]))
        cur_archive += num_archives_per_job[job]
        this_ranges = [(utt, i, offset)
                       for i, archive_index in enumerate(this_archives_for_job)
                       for (utt, offset) in all_egs[archive_index]]

        with open(temp_dir + "/" + prefix + "ranges." + str(job + 1), "w") as f:
            for (utt, i, offset) in sorted(this_ranges):
                archive_index = this_archives_for_job[i]
                spk = utt2spk[utt]
                print("{0} {1} {2} {3} {4} {5}".format(utt,
                                                       i,
                                                       archive_index + 1,
                                                       offset,
                                                       archive_chunk_lengths[archive_index],
                                                       spk),
                      file=f)
                pdf2num[spk] = pdf2num.get(spk, 0) + 1

        with open(temp_dir + "/" + prefix + "outputs." + str(job + 1), "w") as f:
            # The prefix is passed as a format argument rather than spliced
            # into the template, so "{" or "}" in --prefix cannot break it.
            print(" ".join("{0}/{1}egs_temp.{2}.ark".format(args.egs_dir, prefix, n + 1)
                           for n in this_archives_for_job),
                  file=f)

    with open(args.egs_dir + "/" + prefix + "pdf2num", "w") as f:
        # Only labels 0 .. num_pdfs-1 are reported; missing labels count 0.
        print(" ".join(str(pdf2num.get(k, 0)) for k in range(args.num_pdfs)), file=f)

    print("allocate_egs.py: finished generating " + prefix + "ranges.* and " + prefix + "outputs.* files")


if __name__ == "__main__":
    main()