• Najnowsze pytania
  • Bez odpowiedzi
  • Zadaj pytanie
  • Kategorie
  • Tagi
  • Zdobyte punkty
  • Ekipa ninja
  • IRC
  • FAQ
  • Regulamin
  • Książki warte uwagi

question-closed Java - deeplearning4j - EMNIST - słaby rezultat

+1 głos
123 wizyt
pytanie zadane 16 kwietnia w Java przez Tomasz Rogalski Bywalec (2,060 p.)
zamknięte 27 kwietnia przez Tomasz Rogalski
import java.io.File;
import java.io.FileOutputStream;
import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class Learning {

    private static Logger log = LoggerFactory.getLogger(Learning.class);

    public static void main(String[] args) throws Exception {
        int outputNum = 10; // number of output classes
        int batchSize = 128; // batch size for each epoch
        int rngSeed = 123; // random number seed for reproducibility
        int numEpochs = 15; // number of epochs to perform
//        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
//        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);
        DataSetIterator mnistTrain = new EmnistDataSetIterator(EmnistDataSetIterator.Set.DIGITS, batchSize, true);
        DataSetIterator mnistTest = new EmnistDataSetIterator(EmnistDataSetIterator.Set.DIGITS, batchSize, false);
//
        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(rngSeed)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .iterations(5)
                .learningRate(0.006)
                .updater(Updater.NESTEROVS)
                .momentum(0.9)
                .regularization(true).l2(1e-4)
                .list()
                .layer(0, new DenseLayer.Builder()
                        .nIn(28 * 28)
                        .nOut(1800)
                        .activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .layer(1, new DenseLayer.Builder()
                        .nIn(1800)
                        .nOut(300)
                        .activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .layer(2, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(300)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                .pretrain(false).backprop(true) //use backpropagation to adjust weights
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        //print the score with every 1 iteration
        model.setListeners(new ScoreIterationListener(100));

        log.info("Train model....");
        for (int i = 0; i < numEpochs; i++) {
            model.fit(mnistTrain);
        }

        File tempFile = new File("C:/Users/MaFiGat/Desktop/test/EmnistDigits");
        tempFile.deleteOnExit();
        FileOutputStream fos = new FileOutputStream(tempFile);

        ModelSerializer.writeModel(model, fos, true);
        MultiLayerNetwork xxmodel = ModelSerializer.restoreMultiLayerNetwork(tempFile,false);
        System.out.println("Evaluate model....");
        Evaluation eval = new Evaluation(outputNum); //create an evaluation object with 10 possible classes
        while (mnistTest.hasNext()) {
            DataSet next = mnistTest.next();
            INDArray output = xxmodel.output(next.getFeatures()); //get the networks prediction
            eval.eval(next.getLabels(), output); //check the prediction against the true class
        }
         System.out.println(eval.stats());
        log.info("****************Example finished********************");

    }
}
//22:24:33,986 INFO  ~ Score at iteration 140500 is 2.302849665167757
//22:24:39,534 INFO  ~ Score at iteration 140600 is 2.311589049670205
//Evaluate model....
//
//Examples labeled as 0 classified by model as 2: 4 times
//Examples labeled as 0 classified by model as 3: 3996 times
//Examples labeled as 1 classified by model as 2: 19 times
//Examples labeled as 1 classified by model as 3: 3940 times
//Examples labeled as 1 classified by model as 8: 38 times
//Examples labeled as 1 classified by model as 9: 3 times
//Examples labeled as 2 classified by model as 2: 5 times
//Examples labeled as 2 classified by model as 3: 3995 times
//Examples labeled as 3 classified by model as 3: 3990 times
//Examples labeled as 3 classified by model as 8: 10 times
//Examples labeled as 4 classified by model as 3: 3985 times
//Examples labeled as 4 classified by model as 8: 15 times
//Examples labeled as 5 classified by model as 2: 15 times
//Examples labeled as 5 classified by model as 3: 3925 times
//Examples labeled as 5 classified by model as 5: 8 times
//Examples labeled as 5 classified by model as 8: 52 times
//Examples labeled as 6 classified by model as 2: 1 times
//Examples labeled as 6 classified by model as 3: 3994 times
//Examples labeled as 6 classified by model as 4: 1 times
//Examples labeled as 6 classified by model as 8: 4 times
//Examples labeled as 7 classified by model as 2: 1 times
//Examples labeled as 7 classified by model as 3: 3997 times
//Examples labeled as 7 classified by model as 8: 2 times
//Examples labeled as 8 classified by model as 3: 3997 times
//Examples labeled as 8 classified by model as 8: 3 times
//Examples labeled as 9 classified by model as 2: 2 times
//Examples labeled as 9 classified by model as 3: 3996 times
//Examples labeled as 9 classified by model as 8: 2 times
//
//Warning: 4 classes were never predicted by the model and were excluded from average precision
//Classes excluded from average precision: [0, 1, 6, 7]
//
//==========================Scores========================================
// # of classes:    10
// Accuracy:        0,1002
// Precision:       0,2051	(4 classes excluded from average)
// Recall:          0,1002
// F1 Score:        0,0317	(4 classes excluded from average)
//Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)
//========================================================================
//------------------------------------------------------------------------
//BUILD SUCCESS
//------------------------------------------------------------------------
//Total time: 2:00:00.329s
//Finished at: Mon Apr 16 22:24:45 CEST 2018
//Final Memory: 12M/221M

Skąd taki słaby rezultat? Na bazie MNIST po 5 minutach mam bardzo dużą trafność a tu nic, lipa.

komentarz zamknięcia: Sam znalazłem rozwiązanie

2 odpowiedzi

+1 głos
odpowiedź 27 kwietnia przez Tomasz Rogalski Bywalec (2,060 p.)
Odpowiedzią na problem jest użycie innego iteratora. Próbowałem po prostu przerobić Iterator, który był tylko do tego jednego przykładu na pokaz. Rozwiązanie pojawiło się wraz ze znalezieniem tego filmu(serii):

https://www.youtube.com/watch?v=2lwsHKUrXMk&list=PL9iheGibFMtrJZJFrTIdd-QNQyjDX5v2X

I to rozwiązuje problem:)
0 głosów
odpowiedź 17 kwietnia przez Wiciorny Maniak (53,000 p.)
Nie znam się na Deep Lerningu ani na  wyciąganiu danych z baz MNIST, ale byc może to wina " lokalnych ustawień maszyny" stąd względnie będą inne rezultaty?

Tez zależy gdzie ten rezultat jest inny: czy ogólnie z cmd, czy na bazie twojego środowiska np Intellij - wtedy wpływ na to bedzie mialo srodowisko jego cashe i zapisy, logi.... które często śmiecą
komentarz 17 kwietnia przez Tomasz Rogalski Bywalec (2,060 p.)

Problemem jest albo baza EMNIST abo złe parametry nauczania. Ewentualnie czas.

wtedy wpływ na to bedzie mialo srodowisko jego cashe i zapisy, logi.... które często śmiecą

Nie, to wykluczone. 

komentarz 17 kwietnia przez Wiciorny Maniak (53,000 p.)
zaciekawił mnie ten temat w sumie poszperam co nie co
komentarz 17 kwietnia przez Tomasz Rogalski Bywalec (2,060 p.)
To - https://deeplearning4j.org/mnist-for-beginners

próbuje przerobic na EmnistDataSetIterator.Set.DIGITS

linie w moim zadaniu 32-35. W domu zamierzam zredukować baze i ilość klas wyjściowych aby liczenie szło szybciej bo aktualnie pracowanie nad tym jest ciężkie z racji ilości czasu potrzebnej na nauczanie.

Podobne pytania

0 głosów
1 odpowiedź 45 wizyt
pytanie zadane 4 maja 2017 w PHP, Symfony, Zend przez Maticzpl Użytkownik (590 p.)
0 głosów
2 odpowiedzi 85 wizyt
–1 głos
3 odpowiedzi 91 wizyt
Porady nie od parady
Wynikowy wygląd pytania, odpowiedzi czy komentarza, różni się od tego zaprezentowanego w edytorze postów. Stosuj więc funkcję Podgląd posta znajdującą się pod edytorem, aby upewnić się, czy na pewno ostateczny rezultat ci odpowiada.Podgląd posta

51,917 zapytań

94,689 odpowiedzi

193,022 komentarzy

25,307 pasjonatów

Przeglądających: 282
Pasjonatów: 21 Gości: 261

Motyw:

Akcja Pajacyk

Pajacyk od wielu lat dożywia dzieci. Pomóż klikając w zielony brzuszek na stronie. Dziękujemy! ♡

Oto dwie polecane książki warte uwagi. Pełną listę znajdziesz tutaj.

...