# Uni Rostock PyTorch Framework

The `uros_pf` framework aims to reduce the setup time for new machine
learning projects based on PyTorch.
It unifies standard procedures that recur in every project, such as
configuration, input/output handling, training, and logging.

## Basic principles

- Files in the `uros_pf` package are of general interest. They must
not be adapted to a specific project.
- Project-specific source code goes into the `scenario` folder.
- Data / configuration files / model checkpoints etc. are stored in a 
separate workdir folder.
- The input processor, the model, and all hyperparameters are configured
in the config file (a `yaml` file). It is usually stored in the workdir and
passed to the trainer with `-cn config_name` (`config_name` without the `.yaml` extension).

## Sample project

A working configuration file (named `ag_linear_config.yaml`) is:
```yaml
builder:
    input: "scenario.ag_news_corpus.ag_ip.AGInputProcessor"
    model: "scenario.ag_news_corpus.ag_simple_model.SimpleModel"
trainer:
    epochs: 10
input:
    feature_size: 1000
    train_file: "data/ag_news_corpus/train.csv"
    val_file: "data/ag_news_corpus/test.csv"
    batch_size: 10
    samples_per_epoch: 12000
model:
    num_of_classes: 4
    loss_fn: "torch.nn.MSELoss"
    metric_fns: "uros_pf.metrics.accuracy.Accuracy"
    module_cls: "scenario.ag_news_corpus.ag_least_square_module.AGLeastSquareModule"
    optimizer: "torch.optim.SGD"
    lr: 0.01
    module:
        feature_size: ${input.feature_size}
        num_of_classes: ${model.num_of_classes}

```
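Several entries in the configuration (`builder.input`, `model.loss_fn`, `model.optimizer`, and others) are dotted paths to Python classes. How `uros_pf` resolves them is not shown here; the following is only a minimal sketch of one common way to turn such a string into a class, using the standard library. The helper name `load_class` is hypothetical and not part of the framework.

```python
import importlib


def load_class(dotted_path: str):
    """Import and return the object named by a dotted path.

    Hypothetical helper for illustration; not taken from uros_pf.
    Example input: "torch.nn.MSELoss" (module "torch.nn", attribute "MSELoss").
    """
    module_name, _, attr_name = dotted_path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'package.module.Name', got {dotted_path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


# Demonstrated with a standard-library class so the snippet runs without PyTorch.
ordered_dict_cls = load_class("collections.OrderedDict")
print(ordered_dict_cls.__name__)
```

With PyTorch installed, the same call pattern would apply to strings such as `"torch.optim.SGD"` from the sample config.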
The structure of the workdir is: 
```commandline
├── data
│   └── ag_news_corpus
│       ├── test.csv
│       └── train.csv
└── config
    └── ag_linear_config.yaml
```
The data is taken from
[mhjabreel's github account](https://github.com/mhjabreel/CharCnn_Keras/tree/master/data/ag_news_csv).
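Before starting a training run, it can help to confirm that the workdir matches the layout above. The following sketch is not part of `uros_pf`; the function name `missing_files` is hypothetical, and the file list is copied from the sample config and the workdir listing.

```python
from pathlib import Path

# Relative paths taken from the sample config and the workdir listing above.
EXPECTED_FILES = (
    "data/ag_news_corpus/train.csv",
    "data/ag_news_corpus/test.csv",
    "config/ag_linear_config.yaml",
)


def missing_files(workdir, expected=EXPECTED_FILES):
    """Return the expected relative paths that do not exist below workdir.

    Hypothetical helper for illustration; not taken from uros_pf.
    """
    root = Path(workdir)
    return [rel for rel in expected if not (root / rel).is_file()]


if __name__ == "__main__":
    # Check the current directory, which is assumed to be the workdir.
    missing = missing_files(Path.cwd())
    if missing:
        print("Missing files:", ", ".join(missing))
    else:
        print("Workdir layout looks complete.")
```

Run the script from the workdir; an empty result means all three files from the sample project are in place.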

To run the example project, execute
`python3 path/to/src/uros_pf/trainer/trainer.py -cn ag_linear_config`
from the workdir.