import java.io.File;
import java.io.FileOutputStream;
import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class Learning {
    private static final Logger log = LoggerFactory.getLogger(Learning.class);

    /**
     * Trains a three-layer dense network on the EMNIST DIGITS split, serializes the
     * trained model to disk, restores it from the file, and evaluates the restored
     * model against the EMNIST DIGITS test set, printing the evaluation stats.
     *
     * @param args unused
     * @throws Exception on download/IO failure of the EMNIST data or the model file
     */
    public static void main(String[] args) throws Exception {
        int outputNum = 10;  // number of output classes (EMNIST DIGITS has 10 digit classes)
        int batchSize = 128; // minibatch size for both training and evaluation
        int rngSeed = 123;   // random number seed for reproducibility
        int numEpochs = 15;  // number of full passes over the training set

        // Earlier runs used plain MNIST:
        // DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
        // DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);
        DataSetIterator mnistTrain = new EmnistDataSetIterator(EmnistDataSetIterator.Set.DIGITS, batchSize, true);
        DataSetIterator mnistTest = new EmnistDataSetIterator(EmnistDataSetIterator.Set.DIGITS, batchSize, false);

        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(rngSeed)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                // NOTE(review): the pasted run log shows the score stuck near 2.30
                // (= ln 10, i.e. uniform predictions over 10 classes — the net is not
                // learning). Running 5 optimizer iterations per minibatch together
                // with Nesterov momentum 0.9 may be overshooting; try iterations(1)
                // — TODO confirm against fresh training logs.
                .iterations(5)
                .learningRate(0.006)
                .updater(Updater.NESTEROVS)
                .momentum(0.9)
                .regularization(true).l2(1e-4)
                .list()
                .layer(0, new DenseLayer.Builder()
                        .nIn(28 * 28) // flattened 28x28 input image
                        .nOut(1800)
                        .activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .layer(1, new DenseLayer.Builder()
                        .nIn(1800)
                        .nOut(300)
                        .activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .layer(2, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(300)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                .pretrain(false).backprop(true) // use backpropagation to adjust weights
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        // Log the training score every 100 iterations (the original comment
        // claimed "every 1 iteration", which did not match the argument).
        model.setListeners(new ScoreIterationListener(100));

        log.info("Train model....");
        for (int i = 0; i < numEpochs; i++) {
            model.fit(mnistTrain);
        }

        File tempFile = new File("C:/Users/MaFiGat/Desktop/test/EmnistDigits");
        // NOTE(review): deleteOnExit() discards the serialized model when the JVM
        // exits — after a multi-hour training run that is almost certainly
        // unwanted; remove this call if the model file should be kept.
        tempFile.deleteOnExit();
        // try-with-resources: the original code never closed this stream (leak).
        try (FileOutputStream fos = new FileOutputStream(tempFile)) {
            ModelSerializer.writeModel(model, fos, true); // true = save updater state
        }
        // Round-trip check: restore the model (without updater) and evaluate it.
        MultiLayerNetwork xxmodel = ModelSerializer.restoreMultiLayerNetwork(tempFile, false);

        System.out.println("Evaluate model....");
        Evaluation eval = new Evaluation(outputNum); // evaluation object with 10 possible classes
        while (mnistTest.hasNext()) {
            DataSet next = mnistTest.next();
            INDArray output = xxmodel.output(next.getFeatures()); // the network's prediction
            eval.eval(next.getLabels(), output);                  // check prediction against the true class
        }
        System.out.println(eval.stats());
        log.info("****************Example finished********************");
    }
}
//22:24:33,986 INFO ~ Score at iteration 140500 is 2.302849665167757
//22:24:39,534 INFO ~ Score at iteration 140600 is 2.311589049670205
//Evaluate model....
//
//Examples labeled as 0 classified by model as 2: 4 times
//Examples labeled as 0 classified by model as 3: 3996 times
//Examples labeled as 1 classified by model as 2: 19 times
//Examples labeled as 1 classified by model as 3: 3940 times
//Examples labeled as 1 classified by model as 8: 38 times
//Examples labeled as 1 classified by model as 9: 3 times
//Examples labeled as 2 classified by model as 2: 5 times
//Examples labeled as 2 classified by model as 3: 3995 times
//Examples labeled as 3 classified by model as 3: 3990 times
//Examples labeled as 3 classified by model as 8: 10 times
//Examples labeled as 4 classified by model as 3: 3985 times
//Examples labeled as 4 classified by model as 8: 15 times
//Examples labeled as 5 classified by model as 2: 15 times
//Examples labeled as 5 classified by model as 3: 3925 times
//Examples labeled as 5 classified by model as 5: 8 times
//Examples labeled as 5 classified by model as 8: 52 times
//Examples labeled as 6 classified by model as 2: 1 times
//Examples labeled as 6 classified by model as 3: 3994 times
//Examples labeled as 6 classified by model as 4: 1 times
//Examples labeled as 6 classified by model as 8: 4 times
//Examples labeled as 7 classified by model as 2: 1 times
//Examples labeled as 7 classified by model as 3: 3997 times
//Examples labeled as 7 classified by model as 8: 2 times
//Examples labeled as 8 classified by model as 3: 3997 times
//Examples labeled as 8 classified by model as 8: 3 times
//Examples labeled as 9 classified by model as 2: 2 times
//Examples labeled as 9 classified by model as 3: 3996 times
//Examples labeled as 9 classified by model as 8: 2 times
//
//Warning: 4 classes were never predicted by the model and were excluded from average precision
//Classes excluded from average precision: [0, 1, 6, 7]
//
//==========================Scores========================================
// # of classes: 10
// Accuracy: 0,1002
// Precision: 0,2051 (4 classes excluded from average)
// Recall: 0,1002
// F1 Score: 0,0317 (4 classes excluded from average)
//Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)
//========================================================================
//------------------------------------------------------------------------
//BUILD SUCCESS
//------------------------------------------------------------------------
//Total time: 2:00:00.329s
//Finished at: Mon Apr 16 22:24:45 CEST 2018
//Final Memory: 12M/221M
// Why is the result so poor? Training on plain MNIST reaches very high accuracy
// within about five minutes, but here the model learns nothing.