# Uni Rostock PyTorch Framework
The `uros_pf` framework intends to reduce the setup time for new machine
learning projects based on PyTorch.
It unifies the standard procedures used in all projects, such as configuration,
input/output handling, training and logging.
## Basic principles
- Files in the `uros_pf` package are of general interest. They must
  not be adapted to a specific project.
- Project-specific source code goes into the `scenario` folder.
- Data, configuration files, model checkpoints, etc. are stored in a
  separate workdir folder.
- The input processor, the model and all hyperparameters shall be configured
  in the config file (a `yaml` file). It is usually stored in the workdir and
  passed to the trainer via `-cn config_name` (`config_name` without the `.yaml` extension).
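Because classes are named in the config file by their fully qualified dotted path (for example `torch.optim.SGD` in the sample project), something has to turn those strings back into Python objects. How `uros_pf` does this is not described here; the following is a minimal sketch assuming a plain `importlib` lookup, and the helper name `locate_class` is hypothetical.

```python
import importlib


def locate_class(dotted_path: str) -> type:
    """Resolve a string like "torch.optim.SGD" to the class it names.

    Hypothetical helper: one common way to map the dotted paths used in a
    config file to Python classes; not taken from uros_pf itself.
    """
    module_path, _, attr_name = dotted_path.rpartition(".")
    if not module_path:
        raise ValueError(f"expected a dotted path, got {dotted_path!r}")
    # Import the containing module, e.g. "torch.optim".
    module = importlib.import_module(module_path)
    try:
        # Fetch the attribute (the class) from that module, e.g. "SGD".
        return getattr(module, attr_name)
    except AttributeError as exc:
        raise ImportError(f"{module_path!r} has no attribute {attr_name!r}") from exc
```

For instance, `locate_class("collections.OrderedDict")` returns the `OrderedDict` class, which can then be instantiated with keyword arguments taken from the corresponding config section.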
## Sample project
A working configuration file, `ag_linear_config.yaml`, looks like this:
```yaml
builder:
  input: "scenario.ag_news_corpus.ag_ip.AGInputProcessor"
  model: "scenario.ag_news_corpus.ag_simple_model.SimpleModel"
trainer:
  epochs: 10
input:
  feature_size: 1000
  train_file: "data/ag_news_corpus/train.csv"
  val_file: "data/ag_news_corpus/test.csv"
  batch_size: 10
  samples_per_epoch: 12000
model:
  num_of_classes: 4
  loss_fn: "torch.nn.MSELoss"
  metric_fns: "uros_pf.metrics.accuracy.Accuracy"
  module_cls: "scenario.ag_news_corpus.ag_least_square_module.AGLeastSquareModule"
  optimizer: "torch.optim.SGD"
  lr: 0.01
  module:
    feature_size: ${input.feature_size}
    num_of_classes: ${model.num_of_classes}
```
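The `${input.feature_size}` and `${model.num_of_classes}` entries are interpolations: the module settings reuse values defined elsewhere in the file instead of repeating them. The `-cn` flag and this `${...}` syntax match the conventions of Hydra and OmegaConf, which resolve such references automatically. The stdlib-only sketch below merely illustrates the lookup on a nested dict; `resolve` and `lookup` are hypothetical helpers, they handle only values that consist of a single whole reference, and they do not detect reference cycles.

```python
import re
from typing import Any

# Matches a value that is exactly one reference, e.g. "${input.feature_size}".
_REF = re.compile(r"^\$\{([A-Za-z0-9_.]+)\}$")


def lookup(cfg: dict, dotted_key: str) -> Any:
    """Walk the nested dict along a dotted key such as "model.num_of_classes"."""
    node: Any = cfg
    for part in dotted_key.split("."):
        node = node[part]
    return node


def resolve(cfg: dict) -> dict:
    """Return a new config in which whole-value references are replaced."""

    def walk(node: Any) -> Any:
        if isinstance(node, dict):
            return {key: walk(value) for key, value in node.items()}
        if isinstance(node, list):
            return [walk(item) for item in node]
        if isinstance(node, str):
            match = _REF.match(node)
            if match:
                # The target may itself be a reference, so resolve it too.
                return walk(lookup(cfg, match.group(1)))
        return node

    return walk(cfg)


config = {
    "input": {"feature_size": 1000},
    "model": {
        "num_of_classes": 4,
        "module": {
            "feature_size": "${input.feature_size}",
            "num_of_classes": "${model.num_of_classes}",
        },
    },
}
resolved = resolve(config)
print(resolved["model"]["module"])  # {'feature_size': 1000, 'num_of_classes': 4}
```

The referenced value keeps its type (an `int` here), and the original dict is left unchanged because `walk` builds new containers.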
The structure of the workdir is:
```commandline
├── data
│ ├── ag_news_corpus
│ │ ├── test.csv
│ │ ├── train.csv
├── config
│ ├── ag_linear_config.yaml
```
The data is taken from
[mhjabreel's GitHub account](https://github.com/mhjabreel/CharCnn_Keras/tree/master/data/ag_news_csv).
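How `AGInputProcessor` reads these files is not shown here. As an illustration only, the sketch below assumes the usual AG News CSV layout: no header line, and each row holding a class index from 1 to 4, a title and a description. The function name `read_ag_news` is hypothetical. Labels are shifted to 0 to 3 so they can index a vector of `num_of_classes` entries.

```python
import csv
import io
from typing import Iterable, List, Tuple


def read_ag_news(lines: Iterable[str]) -> List[Tuple[int, str]]:
    """Parse AG News CSV rows into (label, text) pairs.

    Assumed row format: class index (1..4), title, description; no header.
    """
    samples = []
    for row in csv.reader(lines):
        if not row:
            continue  # skip blank lines
        label = int(row[0]) - 1           # 1-based class index -> 0-based
        text = " ".join(row[1:]).strip()  # join title and description
        samples.append((label, text))
    return samples


# Made-up row in the assumed format (not taken from the dataset).
demo = io.StringIO('"2","Team wins final","The home side won 3-1."\n')
print(read_ag_news(demo))  # [(1, 'Team wins final The home side won 3-1.')]
```

With the real files, pass an open handle instead, e.g. `open("data/ag_news_corpus/train.csv", newline="", encoding="utf-8")`.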
To run the example project, execute
`python3 path/to/src/uros_pf/trainer/trainer.py -cn ag_linear_config`
from the workdir.
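The data paths in the sample config (`data/ag_news_corpus/train.csv`, ...) are relative, which is why the command is run from the workdir. A small pre-flight check can catch a missing file before training starts. This is a hypothetical helper (`missing_files` is not part of `uros_pf`) that assumes the workdir layout shown above, with the config file under `config/`.

```python
from pathlib import Path
from typing import Iterable, List


def missing_files(workdir: Path, config_name: str, data_files: Iterable[str]) -> List[Path]:
    """Return the expected workdir files that do not exist.

    Checks config/<config_name>.yaml plus the data files named in the
    config, whose relative paths are resolved against the workdir.
    """
    expected = [workdir / "config" / f"{config_name}.yaml"]
    expected += [workdir / rel for rel in data_files]
    return [path for path in expected if not path.is_file()]
```

For the sample project, `missing_files(Path.cwd(), "ag_linear_config", ["data/ag_news_corpus/train.csv", "data/ag_news_corpus/test.csv"])` should return an empty list when called from a correctly prepared workdir.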