
Network training
First, we create a MultiLayerNetwork using the preceding MultiLayerConfiguration. Then we initialize the network and start training it on the training set:
MultiLayerNetwork model = new MultiLayerNetwork(MLPconf);
model.init(); // initialize weights and internal state

log.info("Train model....");
for (int i = 0; i < numEpochs; i++) {
    model.fit(trainingDataIt); // one full pass over the training iterator per epoch
}
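If you want to follow the loss as training proceeds, DL4J lets you attach listeners to the network before calling fit(). Here is a minimal sketch using the built-in ScoreIterationListener; the reporting interval of 100 iterations is just an illustrative choice and is not part of the original example:
// Optional: log the training score (loss) every 100 iterations
// (requires org.deeplearning4j.optimize.listeners.ScoreIterationListener)
model.setListeners(new ScoreIterationListener(100));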
In the preceding code block, we start training the model by invoking model.fit() on the training set (trainingDataIt in our case). Now let's discuss how we prepared the training and test sets. To read a training or test set that is in the appropriate format (numeric features followed by an integer label), I have created a method called readCSVDataset():
private static DataSetIterator readCSVDataset(String csvFileClasspath, int batchSize,
        int labelIndex, int numClasses) throws IOException, InterruptedException {
    // Read the CSV file record by record
    RecordReader rr = new CSVRecordReader();
    File input = new File(csvFileClasspath);
    rr.initialize(new FileSplit(input));
    // Wrap the record reader in an iterator that emits mini-batches with one-hot labels
    DataSetIterator iterator = new RecordReaderDataSetIterator(rr, batchSize, labelIndex, numClasses);
    return iterator;
}
As you can see, readCSVDataset() is essentially a wrapper: it reads the data in CSV format, and the RecordReaderDataSetIterator() constructor then converts the record reader into a dataset iterator. Technically, RecordReaderDataSetIterator() is the main constructor for classification. It takes the following parameters:
- RecordReader: This is the RecordReader that provides the source of the data
- batchSize: Batch size (that is, number of examples) for the output DataSet objects
- labelIndex: The index of the label writable (usually an IntWritable) as obtained by recordReader.next()
- numPossibleLabels: The number of classes (possible labels) for classification
This then converts the input class index (at position labelIndex, with integer values 0 to numPossibleLabels-1, inclusive) to the appropriate one-hot output/label representation; for example, with two classes, a label of 1 becomes the one-hot vector [0, 1]. So let's see how to proceed. First, we define the paths of the training and test sets:
String trainPath = "data/Titanic_Train.csv";
String testPath = "data/Titanic_Test.csv";
int labelIndex = 7; // The first 7 columns are features; the integer label is at index 7
int numClasses = 2; // Number of classes to be predicted, i.e., survived or not survived
int numEpochs = 1000; // Number of training epochs
int seed = 123; // Random seed for reproducibility
int numInputs = labelIndex; // Number of inputs in the input layer
int numOutputs = numClasses; // Number of classes to be predicted by the network
int batchSizeTraining = 128;
Now let's prepare the data we want to use for training:
DataSetIterator trainingDataIt = readCSVDataset(trainPath, batchSizeTraining, labelIndex, numClasses);
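Optionally, you can peek at the first mini-batch to confirm that the features and one-hot labels have the expected shapes before training. This is a small sanity check added here for illustration; the shapes in the comments assume batchSizeTraining = 128, 7 features, and 2 classes:
// Illustrative sanity check (requires org.nd4j.linalg.dataset.DataSet)
DataSet ds = trainingDataIt.next();
System.out.println("Features: " + java.util.Arrays.toString(ds.getFeatures().shape())); // expected: [128, 7]
System.out.println("Labels: " + java.util.Arrays.toString(ds.getLabels().shape())); // expected: [128, 2] (one-hot)
trainingDataIt.reset(); // rewind the iterator before training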
Next, let's prepare the data we want to classify:
int batchSizeTest = 128;
DataSetIterator testDataIt = readCSVDataset(testPath, batchSizeTest, labelIndex, numClasses);
Fantastic! We have managed to prepare the training and test DataSetIterator. Remember, we will be following nearly the same approach to prepare the training and test sets for other problems too.
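With both iterators in place, a natural next step after training is to measure accuracy on the held-out set. The following is a minimal sketch of how the prepared testDataIt could be used for that purpose, assuming the test CSV contains labels (as the call to readCSVDataset() above suggests); it relies on DL4J's built-in Evaluation class rather than anything specific to this example:
// Illustrative evaluation sketch (requires DL4J's Evaluation class)
Evaluation eval = model.evaluate(testDataIt); // run the trained network over the test iterator
log.info(eval.stats()); // prints accuracy, precision, recall, F1, and the confusion matrix
testDataIt.reset(); // rewind in case the iterator is needed again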