• Najnowsze pytania
  • Bez odpowiedzi
  • Zadaj pytanie
  • Kategorie
  • Tagi
  • Zdobyte punkty
  • Ekipa ninja
  • IRC
  • FAQ
  • Regulamin
  • Książki warte uwagi

question-closed Java - deeplearning4j - EMNIST - słaby rezultat

VPS Starter Arubacloud
+1 głos
399 wizyt
pytanie zadane 16 kwietnia 2018 w Java przez Tomasz Rogalski Bywalec (2,800 p.)
zamknięte 27 kwietnia 2018 przez Tomasz Rogalski
import java.io.File;
import java.io.FileOutputStream;
import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class Learning {

    private static Logger log = LoggerFactory.getLogger(Learning.class);

    public static void main(String[] args) throws Exception {
        int outputNum = 10; // number of output classes
        int batchSize = 128; // batch size for each epoch
        int rngSeed = 123; // random number seed for reproducibility
        int numEpochs = 15; // number of epochs to perform
//        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
//        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);
        DataSetIterator mnistTrain = new EmnistDataSetIterator(EmnistDataSetIterator.Set.DIGITS, batchSize, true);
        DataSetIterator mnistTest = new EmnistDataSetIterator(EmnistDataSetIterator.Set.DIGITS, batchSize, false);
//
        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(rngSeed)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .iterations(5)
                .learningRate(0.006)
                .updater(Updater.NESTEROVS)
                .momentum(0.9)
                .regularization(true).l2(1e-4)
                .list()
                .layer(0, new DenseLayer.Builder()
                        .nIn(28 * 28)
                        .nOut(1800)
                        .activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .layer(1, new DenseLayer.Builder()
                        .nIn(1800)
                        .nOut(300)
                        .activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .layer(2, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nIn(300)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                .pretrain(false).backprop(true) //use backpropagation to adjust weights
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        //print the score with every 1 iteration
        model.setListeners(new ScoreIterationListener(100));

        log.info("Train model....");
        for (int i = 0; i < numEpochs; i++) {
            model.fit(mnistTrain);
        }

        File tempFile = new File("C:/Users/MaFiGat/Desktop/test/EmnistDigits");
        tempFile.deleteOnExit();
        FileOutputStream fos = new FileOutputStream(tempFile);

        ModelSerializer.writeModel(model, fos, true);
        MultiLayerNetwork xxmodel = ModelSerializer.restoreMultiLayerNetwork(tempFile,false);
        System.out.println("Evaluate model....");
        Evaluation eval = new Evaluation(outputNum); //create an evaluation object with 10 possible classes
        while (mnistTest.hasNext()) {
            DataSet next = mnistTest.next();
            INDArray output = xxmodel.output(next.getFeatures()); //get the networks prediction
            eval.eval(next.getLabels(), output); //check the prediction against the true class
        }
         System.out.println(eval.stats());
        log.info("****************Example finished********************");

    }
}
//22:24:33,986 INFO  ~ Score at iteration 140500 is 2.302849665167757
//22:24:39,534 INFO  ~ Score at iteration 140600 is 2.311589049670205
//Evaluate model....
//
//Examples labeled as 0 classified by model as 2: 4 times
//Examples labeled as 0 classified by model as 3: 3996 times
//Examples labeled as 1 classified by model as 2: 19 times
//Examples labeled as 1 classified by model as 3: 3940 times
//Examples labeled as 1 classified by model as 8: 38 times
//Examples labeled as 1 classified by model as 9: 3 times
//Examples labeled as 2 classified by model as 2: 5 times
//Examples labeled as 2 classified by model as 3: 3995 times
//Examples labeled as 3 classified by model as 3: 3990 times
//Examples labeled as 3 classified by model as 8: 10 times
//Examples labeled as 4 classified by model as 3: 3985 times
//Examples labeled as 4 classified by model as 8: 15 times
//Examples labeled as 5 classified by model as 2: 15 times
//Examples labeled as 5 classified by model as 3: 3925 times
//Examples labeled as 5 classified by model as 5: 8 times
//Examples labeled as 5 classified by model as 8: 52 times
//Examples labeled as 6 classified by model as 2: 1 times
//Examples labeled as 6 classified by model as 3: 3994 times
//Examples labeled as 6 classified by model as 4: 1 times
//Examples labeled as 6 classified by model as 8: 4 times
//Examples labeled as 7 classified by model as 2: 1 times
//Examples labeled as 7 classified by model as 3: 3997 times
//Examples labeled as 7 classified by model as 8: 2 times
//Examples labeled as 8 classified by model as 3: 3997 times
//Examples labeled as 8 classified by model as 8: 3 times
//Examples labeled as 9 classified by model as 2: 2 times
//Examples labeled as 9 classified by model as 3: 3996 times
//Examples labeled as 9 classified by model as 8: 2 times
//
//Warning: 4 classes were never predicted by the model and were excluded from average precision
//Classes excluded from average precision: [0, 1, 6, 7]
//
//==========================Scores========================================
// # of classes:    10
// Accuracy:        0,1002
// Precision:       0,2051	(4 classes excluded from average)
// Recall:          0,1002
// F1 Score:        0,0317	(4 classes excluded from average)
//Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)
//========================================================================
//------------------------------------------------------------------------
//BUILD SUCCESS
//------------------------------------------------------------------------
//Total time: 2:00:00.329s
//Finished at: Mon Apr 16 22:24:45 CEST 2018
//Final Memory: 12M/221M

Skąd taki słaby rezultat? Na bazie MNIST po 5 minutach mam bardzo dużą trafność a tu nic, lipa.

komentarz zamknięcia: Sam znalazłem rozwiązanie

2 odpowiedzi

+1 głos
odpowiedź 27 kwietnia 2018 przez Tomasz Rogalski Bywalec (2,800 p.)
Odpowiedzią na problem jest użycie innego iteratora. Próbowałem po prostu przerobić Iterator, który był tylko do tego jednego przykładu na pokaz. Rozwiązanie pojawiło się wraz ze znalezieniem tego filmu(serii):

https://www.youtube.com/watch?v=2lwsHKUrXMk&list=PL9iheGibFMtrJZJFrTIdd-QNQyjDX5v2X

I to rozwiązuje problem:)
0 głosów
odpowiedź 17 kwietnia 2018 przez Wiciorny Ekspert (269,120 p.)
Nie znam się na Deep Lerningu ani na  wyciąganiu danych z baz MNIST, ale byc może to wina " lokalnych ustawień maszyny" stąd względnie będą inne rezultaty?

Tez zależy gdzie ten rezultat jest inny: czy ogólnie z cmd, czy na bazie twojego środowiska np Intellij - wtedy wpływ na to bedzie mialo srodowisko jego cashe i zapisy, logi.... które często śmiecą
komentarz 17 kwietnia 2018 przez Tomasz Rogalski Bywalec (2,800 p.)

Problemem jest albo baza EMNIST abo złe parametry nauczania. Ewentualnie czas.

wtedy wpływ na to bedzie mialo srodowisko jego cashe i zapisy, logi.... które często śmiecą

Nie, to wykluczone. 

komentarz 17 kwietnia 2018 przez Wiciorny Ekspert (269,120 p.)
zaciekawił mnie ten temat w sumie poszperam co nie co
komentarz 17 kwietnia 2018 przez Tomasz Rogalski Bywalec (2,800 p.)
To - https://deeplearning4j.org/mnist-for-beginners

próbuje przerobic na EmnistDataSetIterator.Set.DIGITS

linie w moim zadaniu 32-35. W domu zamierzam zredukować baze i ilość klas wyjściowych aby liczenie szło szybciej bo aktualnie pracowanie nad tym jest ciężkie z racji ilości czasu potrzebnej na nauczanie.

Podobne pytania

0 głosów
1 odpowiedź 398 wizyt
0 głosów
1 odpowiedź 118 wizyt
pytanie zadane 4 maja 2017 w PHP przez niezalogowany
0 głosów
2 odpowiedzi 191 wizyt

92,453 zapytań

141,262 odpowiedzi

319,085 komentarzy

61,854 pasjonatów

Motyw:

Akcja Pajacyk

Pajacyk od wielu lat dożywia dzieci. Pomóż klikając w zielony brzuszek na stronie. Dziękujemy! ♡

Oto polecana książka warta uwagi.
Pełną listę książek znajdziesz tutaj.

Akademia Sekuraka

Akademia Sekuraka 2024 zapewnia dostęp do minimum 15 szkoleń online z bezpieczeństwa IT oraz dostęp także do materiałów z edycji Sekurak Academy z roku 2023!

Przy zakupie możecie skorzystać z kodu: pasja-akademia - użyjcie go w koszyku, a uzyskacie rabat -30% na bilety w wersji "Standard"! Więcej informacji na temat akademii 2024 znajdziecie tutaj. Dziękujemy ekipie Sekuraka za taką fajną zniżkę dla wszystkich Pasjonatów!

Akademia Sekuraka

Niedawno wystartował dodruk tej świetnej, rozchwytywanej książki (około 940 stron). Mamy dla Was kod: pasja (wpiszcie go w koszyku), dzięki któremu otrzymujemy 10% zniżki - dziękujemy zaprzyjaźnionej ekipie Sekuraka za taki bonus dla Pasjonatów! Książka to pierwszy tom z serii o ITsec, który łagodnie wprowadzi w świat bezpieczeństwa IT każdą osobę - warto, polecamy!

...