-
Notifications
You must be signed in to change notification settings - Fork 4
/
create_configs.py
77 lines (65 loc) · 2.78 KB
/
create_configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import argparse
import os
import random
from pathlib import Path
import yaml
# dropout=.5, width=32, height=10
# dropout=.2, width=128, height=10
# NOTE(review): architecture_type 2 below uses hid=64, not 128 — confirm which width is intended.
# dropout=.2, width=32, height=20
# dropout=.2, width=32, height=2
def _select_hyperparameters(architecture_type):
    """Return the model hyperparameters for ``architecture_type``.

    ``0`` draws each of ``num_deep``, ``hid`` and ``dropout_p`` independently
    with ``random.choice`` from a small grid; ``-1`` and ``1``..``5`` return
    fixed presets.

    Raises:
        ValueError: if ``architecture_type`` is not one of -1, 0, 1, 2, 3, 4, 5.
    """
    if architecture_type == 0:
        # The random draws happen in this key order, so keep it stable to
        # stay reproducible under a fixed random seed.
        grid = {
            'num_deep': [2, 5, 10, 15],
            'hid': [16, 32, 64, 128],
            'dropout_p': [0.1, 0.2, 0.3, 0.5],
        }
        return {k: random.choice(v) for k, v in grid.items()}

    presets = {
        -1: {'num_deep': 10, 'hid': 32, 'dropout_p': 0.2},
        1: {'num_deep': 10, 'hid': 32, 'dropout_p': 0.5},
        2: {'num_deep': 10, 'hid': 64, 'dropout_p': 0.2},
        3: {'num_deep': 10, 'hid': 16, 'dropout_p': 0.2},
        4: {'num_deep': 20, 'hid': 32, 'dropout_p': 0.2},
        5: {'num_deep': 2, 'hid': 32, 'dropout_p': 0.2},
    }
    if architecture_type not in presets:
        raise ValueError(
            f'unknown architecture_type {architecture_type!r}; '
            'expected one of -1, 0, 1, 2, 3, 4, 5'
        )
    return presets[architecture_type]


def _build_baselines_config(args, run_idx, hyperparam):
    """Return the baselines config dict for run number ``run_idx``.

    Expects ``main`` to have already converted ``args.protected`` and
    ``args.architecture_type`` to ``int``, so the YAML stores numbers.
    """
    return {
        'experiment_name': f'{args.dataset}_{args.metric}_{args.protected}_{run_idx}_baselines',
        'dataset': args.dataset,
        'protected': args.protected,
        'modelpath': f'models/{args.dataset}_{args.architecture_type}_{run_idx}_model.pt',
        'metric': args.metric,
        'models': [
            'default',
            'ROC',
            'EqOdds',
            'CalibEqOdds',
            'random',
            'adversarial',
        ],
        'CalibEqOdds': {'cost_constraint': 'fpr'},
        'random': {'num_trials': 201},
        'adversarial': {
            'epochs': 16,
            'critic_steps': 201,
            'actor_steps': 101,
            'batch_size': 64,
            'lambda': 0.75,
        },
        'hyperparameters': hyperparam,
    }


def main(args):
    """Write ``num_runs`` baselines YAML configs for one dataset/metric/protected combo.

    Each run ``i`` in ``range(num_runs)`` is written to
    ``<dataset>_<metric>_<protected>/config_<dataset>_<metric>_<protected>_<i>_baselines.yaml``.
    The directory is created (relative to the current working directory) if
    it does not exist. All runs share one hyperparameter dict, so with
    ``architecture_type == 0`` the random draw happens once per call.

    Args:
        args: namespace with ``dataset``, ``metric``, ``protected``,
            ``num_runs`` and ``architecture_type``. The last three are
            converted to ``int`` in place.

    Raises:
        ValueError: if a numeric argument is not an integer, or
            ``architecture_type`` is unknown.
    """
    # Normalise numeric arguments first, so the directory name and the config
    # file paths are built from the same value. Previously, protected="01"
    # created "..._01/" but the configs were written into "..._1/".
    args.num_runs = int(args.num_runs)
    args.protected = int(args.protected)
    args.architecture_type = int(args.architecture_type)

    # Resolve hyperparameters before touching the filesystem, so an invalid
    # architecture_type fails fast and leaves no empty directory behind.
    hyperparam = _select_hyperparameters(args.architecture_type)

    prefix = f'{args.dataset}_{args.metric}_{args.protected}'
    out_dir = Path(prefix)
    out_dir.mkdir(exist_ok=True)

    for i in range(args.num_runs):
        config = _build_baselines_config(args, i, hyperparam)
        out_path = out_dir / f'config_{prefix}_{i}_baselines.yaml'
        with open(out_path, 'w', encoding='utf-8') as fh:
            yaml.dump(config, fh)
def build_parser():
    """Return the command-line parser for this script.

    All five arguments are positional. ``protected``, ``num_runs`` and
    ``architecture_type`` are parsed as integers, so malformed values are
    rejected with a usage message instead of a traceback from ``main``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset", help="Which dataset")
    parser.add_argument("metric", help="which metric")
    parser.add_argument("protected", type=int, help="which protected")
    parser.add_argument("num_runs", type=int, help="Number of runs")
    parser.add_argument("architecture_type", type=int, help="Type of Architecture")
    return parser


if __name__ == "__main__":
    # Executed only when the file is run from the command line.
    main(build_parser().parse_args())