Using Julia Flux to build a simple neural network

I have a dataset of images (<a href="https://www.kaggle.com/iarunava/cell-images-for-detecting-malaria" rel="nofollow noreferrer">https://www.kaggle.com/iarunava/cell-images-for-detecting-malaria</a>), and I want to use a neural network to determine whether a picture shows an uninfected cell or not.
I arranged my data into four variables:

Code:
X_tests, Y_tests, X_training, Y_training

Each of these variables is of type
Code:
Array{Array{Float64,1},1}
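
For illustration, data of this type is a vector of per-sample vectors rather than a single matrix. The sketch below uses small made-up sizes (5 samples of 12 values each, not the real dataset) and shows one way to stack such a nested array into a matrix with one sample per column, using only Base Julia. The names `X_nested` and `X_matrix` are hypothetical and do not appear in the code further down.

```julia
# Hypothetical stand-in data: 5 samples, each a flattened vector of 12 values.
# This has the same type as the variables above: Array{Array{Float64,1},1}.
X_nested = [rand(Float64, 12) for _ in 1:5]

# Stack the sample vectors side by side as columns.
# The result is a 12×5 Matrix{Float64} (features × samples).
X_matrix = reduce(hcat, X_nested)

println(size(X_matrix))  # (12, 5)
```

Column `j` of `X_matrix` holds the same values as `X_nested[j]`, so no information is lost in the conversion.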

I also have a function that builds a simple neural network, taken from this example (<a href="https://smist08.wordpress.com/2018/09/24/julia-flux-for-machine-learning/" rel="nofollow noreferrer">https://smist08.wordpress.com/2018/09/24/julia-flux-for-machine-learning/</a>):

Code:
function simple_nn(X_tests, Y_tests, X_training, Y_training)
    input = 100*100*3
    hl1 = 32
    m = Chain(
      Dense(input, 32, relu),
      Dense(32, 2),
      softmax) |> gpu

    loss(x, y) = crossentropy(m(x), y)

    accuracy(x, y) = mean(onecold(m(x)) .== onecold(y))

    dataset = [(X_training,Y_training)]
    evalcb = () -> @show(loss(X_training, Y_training))
    opt = ADAM(params(m))

    Flux.train!(loss, dataset, opt, cb = throttle(evalcb, 10))

    println("acc X,Y ", accuracy(X_training, Y_training))

    println("acc tX, tY ", accuracy(X_tests, Y_tests))
end

Calling
Code:
simple_nn(X_tests, Y_tests, X_training, Y_training)
produces this error:

Code:
ERROR: DimensionMismatch("matrix A has dimensions (32,30000), vector B has length 2668")
...
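
The message suggests that a 32×30000 weight matrix is being multiplied by something of length 2668, which is presumably the number of samples in the outer training array rather than the 30000 values of a single image. A minimal sketch of the same kind of failure in plain Julia follows, with hypothetical small sizes (a 4×12 matrix `W` and 5 samples) and no Flux involved. All names here are illustrative assumptions, not part of the function above.

```julia
# Hypothetical small stand-ins: W plays the role of a dense layer's weights.
W = zeros(4, 12)                      # 4 outputs, 12 input features
X_nested = [zeros(12) for _ in 1:5]   # 5 samples stored as a vector of vectors

# Multiplying by the outer vector compares W's 12 columns against length 5,
# so the product is expected to fail with a DimensionMismatch.
err = try
    W * X_nested
    nothing
catch e
    e
end
println(typeof(err))

# With the samples stacked as columns, the shapes line up: (4×12) * (12×5).
X_matrix = reduce(hcat, X_nested)
println(size(W * X_matrix))           # (4, 5)
```

If the same reasoning applies to the real data, the inputs would need to form a 30000×2668 matrix before being passed to the model, but that is an inference from the error text and not something confirmed here.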

The error is raised on this line:
Code:
Flux.train!(loss, dataset, opt, cb = throttle(evalcb, 10))

I don't know what these functions do, what arguments they take, or what they return, and I can't find any reference documentation online. I can only find examples.
So I have two questions: How can I make this work for my dataset? And is there documentation for Flux functions, like there is for sklearn (for example: <a href="https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html" rel="nofollow noreferrer">https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html</a>)?