Java - deeplearning4j - EMNIST - poor results

+1 vote
89 views
Question asked 16 April in Java by Tomasz Rogalski Obywatel (1,690 p.)
import java.io.File;
import java.io.FileOutputStream;
import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class Learning {

    private static Logger log = LoggerFactory.getLogger(Learning.class);

    public static void main(String[] args) throws Exception {
        int outputNum = 10; // number of output classes
        int batchSize = 128; // batch size for each epoch
        int rngSeed = 123; // random number seed for reproducibility
        int numEpochs = 15; // number of epochs to perform
//        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
//        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);
        DataSetIterator mnistTrain = new EmnistDataSetIterator(EmnistDataSetIterator.Set.DIGITS, batchSize, true);
        DataSetIterator mnistTest = new EmnistDataSetIterator(EmnistDataSetIterator.Set.DIGITS, batchSize, false);
//
        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(rngSeed)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .iterations(5)
                .learningRate(0.006)
                .updater(Updater.NESTEROVS)
                .momentum(0.9)
                .regularization(true).l2(1e-4)
                .list()
                .layer(0, new DenseLayer.Builder()
                        .nIn(28 * 28)
                        .nOut(1800)
                        .activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .layer(1, new DenseLayer.Builder()
                        .nIn(1800)
                        .nOut(300)
                        .activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .layer(2, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(300)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                .pretrain(false).backprop(true) //use backpropagation to adjust weights
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        // print the training score every 100 iterations
        model.setListeners(new ScoreIterationListener(100));

        log.info("Train model....");
        for (int i = 0; i < numEpochs; i++) {
            model.fit(mnistTrain);
        }

        File tempFile = new File("C:/Users/MaFiGat/Desktop/test/EmnistDigits");
        tempFile.deleteOnExit();
        FileOutputStream fos = new FileOutputStream(tempFile);

        ModelSerializer.writeModel(model, fos, true);
        fos.close(); // close the stream so the file is fully written before reloading
        MultiLayerNetwork xxmodel = ModelSerializer.restoreMultiLayerNetwork(tempFile, false);
        System.out.println("Evaluate model....");
        Evaluation eval = new Evaluation(outputNum); //create an evaluation object with 10 possible classes
        while (mnistTest.hasNext()) {
            DataSet next = mnistTest.next();
            INDArray output = xxmodel.output(next.getFeatures()); // get the network's prediction
            eval.eval(next.getLabels(), output); //check the prediction against the true class
        }
        System.out.println(eval.stats());
        log.info("****************Example finished********************");

    }
}
//22:24:33,986 INFO  ~ Score at iteration 140500 is 2.302849665167757
//22:24:39,534 INFO  ~ Score at iteration 140600 is 2.311589049670205
//Evaluate model....
//
//Examples labeled as 0 classified by model as 2: 4 times
//Examples labeled as 0 classified by model as 3: 3996 times
//Examples labeled as 1 classified by model as 2: 19 times
//Examples labeled as 1 classified by model as 3: 3940 times
//Examples labeled as 1 classified by model as 8: 38 times
//Examples labeled as 1 classified by model as 9: 3 times
//Examples labeled as 2 classified by model as 2: 5 times
//Examples labeled as 2 classified by model as 3: 3995 times
//Examples labeled as 3 classified by model as 3: 3990 times
//Examples labeled as 3 classified by model as 8: 10 times
//Examples labeled as 4 classified by model as 3: 3985 times
//Examples labeled as 4 classified by model as 8: 15 times
//Examples labeled as 5 classified by model as 2: 15 times
//Examples labeled as 5 classified by model as 3: 3925 times
//Examples labeled as 5 classified by model as 5: 8 times
//Examples labeled as 5 classified by model as 8: 52 times
//Examples labeled as 6 classified by model as 2: 1 times
//Examples labeled as 6 classified by model as 3: 3994 times
//Examples labeled as 6 classified by model as 4: 1 times
//Examples labeled as 6 classified by model as 8: 4 times
//Examples labeled as 7 classified by model as 2: 1 times
//Examples labeled as 7 classified by model as 3: 3997 times
//Examples labeled as 7 classified by model as 8: 2 times
//Examples labeled as 8 classified by model as 3: 3997 times
//Examples labeled as 8 classified by model as 8: 3 times
//Examples labeled as 9 classified by model as 2: 2 times
//Examples labeled as 9 classified by model as 3: 3996 times
//Examples labeled as 9 classified by model as 8: 2 times
//
//Warning: 4 classes were never predicted by the model and were excluded from average precision
//Classes excluded from average precision: [0, 1, 6, 7]
//
//==========================Scores========================================
// # of classes:    10
// Accuracy:        0,1002
// Precision:       0,2051	(4 classes excluded from average)
// Recall:          0,1002
// F1 Score:        0,0317	(4 classes excluded from average)
//Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)
//========================================================================
//------------------------------------------------------------------------
//BUILD SUCCESS
//------------------------------------------------------------------------
//Total time: 2:00:00.329s
//Finished at: Mon Apr 16 22:24:45 CEST 2018
//Final Memory: 12M/221M

Why is the result so poor? With MNIST I get very high accuracy after just 5 minutes, but here nothing works, it's hopeless.
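One way to narrow down whether the gap between MNIST and EMNIST DIGITS comes from the data pipeline rather than the hyperparameters is to compare what the two iterators report and what a single minibatch looks like. Below is a minimal sketch of such a check, assuming the same DL4J version as the code above; MnistDataSetIterator is the class from the commented-out lines, and the expected values in the comments (784 inputs, 10 classes, pixels in [0, 1]) are the usual MNIST/EMNIST conventions, not something verified here.

import java.util.Arrays;
import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class IteratorSanityCheck {

    public static void main(String[] args) throws Exception {
        int batchSize = 128;
        DataSetIterator mnist = new MnistDataSetIterator(batchSize, true, 123);
        DataSetIterator emnist = new EmnistDataSetIterator(EmnistDataSetIterator.Set.DIGITS, batchSize, true);

        for (DataSetIterator it : new DataSetIterator[]{mnist, emnist}) {
            DataSet first = it.next();
            System.out.println(it.getClass().getSimpleName()
                    + " inputColumns=" + it.inputColumns()   // expect 784 for 28x28 images
                    + " totalOutcomes=" + it.totalOutcomes() // expect 10 for digit labels
                    + " featuresShape=" + Arrays.toString(first.getFeatures().shape())
                    + " pixelMin=" + first.getFeatures().minNumber()
                    + " pixelMax=" + first.getFeatures().maxNumber());
            it.reset();
        }
    }
}

If both iterators report the same input size, class count and pixel range, the data side is probably fine and the difference lies in training dynamics (EMNIST DIGITS has about four times as many training examples as MNIST, so each epoch takes much longer).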

1 answer

0 votes
Answered 17 April by Wiciorny Nałogowiec (44,360 p.)
I don't really know deep learning or extracting data from the MNIST datasets, but maybe it's down to the "local machine settings", so the results will just be somewhat different?

It also depends on where the result differs: in general from cmd, or from within your IDE, e.g. IntelliJ - in that case the environment, its cache, saved state and logs... which often leave clutter behind, would have an influence.
Comment 17 April by Tomasz Rogalski Obywatel (1,690 p.)

The problem is either the EMNIST dataset or bad training parameters. Possibly just training time (see the sketch below this comment).

"in that case the environment, its cache, saved state and logs... which often leave clutter behind, would have an influence"

No, that's ruled out.
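To separate "bad training parameters" from "just needs more time", one option is to evaluate after every epoch instead of only once at the end: if accuracy stays near 10% (chance level for 10 digits) epoch after epoch, the configuration is not learning at all rather than learning slowly. A minimal sketch, reusing model, mnistTrain, mnistTest and numEpochs from the question; the per-epoch evaluation loop is just an illustration, not a claim about what the original tutorial does.

for (int epoch = 0; epoch < numEpochs; epoch++) {
    model.fit(mnistTrain);                        // one full pass over the training data
    mnistTest.reset();
    Evaluation eval = model.evaluate(mnistTest);  // accuracy on the held-out DIGITS test set
    log.info("Epoch {} accuracy: {}", epoch, eval.accuracy());
    mnistTrain.reset();
}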

Comment 17 April by Wiciorny Nałogowiec (44,360 p.)
This topic actually got me curious; I'll dig into it a bit.
Comment 17 April by Tomasz Rogalski Obywatel (1,690 p.)
I'm trying to adapt this - https://deeplearning4j.org/mnist-for-beginners - to EmnistDataSetIterator.Set.DIGITS

(lines 32-35 in my code). At home I plan to cut down the dataset and the number of output classes so that the computation runs faster, because right now working on this is hard given how much time the training needs.
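On the idea of shrinking the dataset to speed up experiments: instead of modifying the dataset itself, a cheap alternative is to cap how many minibatches each epoch consumes. A minimal sketch, reusing model, mnistTrain and numEpochs from the code in the question; maxBatchesPerEpoch is a made-up illustrative value, not a recommendation.

int maxBatchesPerEpoch = 200;                     // illustrative cap, not a tuned value
for (int epoch = 0; epoch < numEpochs; epoch++) {
    mnistTrain.reset();
    int batches = 0;
    while (mnistTrain.hasNext() && batches < maxBatchesPerEpoch) {
        model.fit(mnistTrain.next());             // train on a single minibatch
        batches++;
    }
}

This keeps the 10 output classes and the rest of the configuration intact, so a short run stays comparable to a full one.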
