Dex: NN Layers
Neural Networks
import plot
NN Prelude
def relu (input : Float) : Float =
  select (input > 0.0) input 0.0
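A minimal usage sketch (the relu_demo bindings are illustrative, not part of the original prelude): relu is applied pointwise over a table with for.

-- Illustrative only: map relu over a small table.
relu_demo_in : Fin 3 => Float = [(-1.0), 0.0, 2.0]
relu_demo_out : Fin 3 => Float = for i. relu relu_demo_in.i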
instance [Add a, Add b] Add (a & b)
  add = \(a, b) (c, d). ((a + c), (b + d))
  sub = \(a, b) (c, d). ((a - c), (b - d))
  zero = (zero, zero)
instance [VSpace a, VSpace b] VSpace (a & b)
  scaleVec = \s (a, b). (scaleVec s a, scaleVec s b)
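A minimal sketch of what these instances buy us (the pair_demo bindings are illustrative): nested tuples of parameters can be added and scaled directly, which the SGD update in trainClass below relies on.

-- Illustrative only: a nested tuple of floats is itself a vector space.
pair_demo : (Float & (Float & Float)) = (1.0, (2.0, 3.0))
pair_sum = pair_demo + pair_demo
pair_scaled = scaleVec 0.5 pair_demo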
data Layer inp:Type out:Type params:Type =
  AsLayer {forward:(params -> inp -> out) & init:(Key -> params)}
def forward (l:Layer i o p) (p : p) (x : i) : o =
  (AsLayer l') = l
  (getAt #forward l') p x
def init (l:Layer i o p) (k:Key) : p =
  (AsLayer l') = l
  (getAt #init l') k
Layers
Dense layer
def DenseParams (a:Type) (b:Type) : Type =
  ((a=>b=>Float) & (b=>Float))
def dense (a:Type) (b:Type) : Layer (a=>Float) (b=>Float) (DenseParams a b) =
  AsLayer {
    forward = (\(weight, bias) x.
      for j. (bias.j + sum for i. weight.i.j * x.i)),
    init = arb
  }
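A small usage sketch for the Layer interface (the dense_demo bindings are illustrative): init draws parameters from a key and forward applies them to an input table.

-- Illustrative only: a 2 -> 3 dense layer applied to one input.
dense_demo = dense (Fin 2) (Fin 3)
dense_demo_params = init dense_demo (newKey 0)
dense_demo_out : Fin 3 => Float = forward dense_demo dense_demo_params [1.0, 2.0]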
CNN layer
def CNNParams (inc:Type) (outc:Type) (kw:Int) (kh:Int) : Type =
  ((outc=>inc=>Fin kh=>Fin kw=>Float) &
   (outc=>Float))
def conv2d (x:inc=>(Fin h)=>(Fin w)=>Float)
           (kernel:outc=>inc=>(Fin kh)=>(Fin kw)=>Float) :
           outc=>(Fin h)=>(Fin w)=>Float =
  for o i j.
    (i', j') = (ordinal i, ordinal j)
    -- Only positions where the kernel fits inside the image contribute;
    -- everything else is zero, so the output keeps the input's spatial size.
    case (i' + kh) <= h && (j' + kw) <= w of
      True ->
        sum for (ki, kj, inp).
          (di, dj) = (fromOrdinal (Fin h) (i' + (ordinal ki)),
                      fromOrdinal (Fin w) (j' + (ordinal kj)))
          x.inp.di.dj * kernel.o.inp.ki.kj
      False -> zero
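An illustrative shape check for conv2d (sizes chosen arbitrarily here): a 1-channel 4x4 image convolved with a 1-to-1-channel 2x2 kernel keeps its spatial size, with zeros wherever the kernel runs off the image.

-- Illustrative only: output has the same height and width as the input.
conv_demo_x : (Fin 1)=>(Fin 4)=>(Fin 4)=>Float = arb (newKey 0)
conv_demo_k : (Fin 1)=>(Fin 1)=>(Fin 2)=>(Fin 2)=>Float = arb (newKey 1)
conv_demo_out : (Fin 1)=>(Fin 4)=>(Fin 4)=>Float = conv2d conv_demo_x conv_demo_k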
def cnn (h:Int) ?-> (w:Int) ?-> (inc:Type) (outc:Type) (kw:Int) (kh:Int) :
    Layer (inc=>(Fin h)=>(Fin w)=>Float)
          (outc=>(Fin h)=>(Fin w)=>Float)
          (CNNParams inc outc kw kh) =
  AsLayer {
    forward = (\(weight, bias) x. for o i j. (conv2d x weight).o.i.j + bias.o),
    init = arb
  }
Pooling
def split (x: m=>v) : n=>o=>v =
  for i j. x.((ordinal (i,j))@m)
def imtile (x: a=>b=>v) : n=>o=>p=>q=>v =
  for kw kh w h. (split (split x).w.kw).h.kh
def meanpool (kh: Type) (kw: Type) (x : m=>n=>Float) : (h=>w=>Float) =
  out : (kh => kw => h => w => Float) = imtile x
  mean for (i,j). out.i.j
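An illustrative use of meanpool (sizes arbitrary): pooling an 8x8 grid with 2x2 windows; the output size is fixed by the annotation.

-- Illustrative only: 8x8 -> 4x4 mean pooling.
pool_demo_in : (Fin 8)=>(Fin 8)=>Float = arb (newKey 0)
pool_demo_out : (Fin 4)=>(Fin 4)=>Float = meanpool (Fin 2) (Fin 2) pool_demo_in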
Simple point classifier
[k1, k2] = splitKey $ newKey 1
x1 : Fin 100 => Float = arb k1
x2 : Fin 100 => Float = arb k2
y = for i. case ((x1.i > 0.0) && (x2.i > 0.0)) || ((x1.i < 0.0) && (x2.i < 0.0)) of
  True -> 1
  False -> 0
xs = for i. [x1.i, x2.i]
:html showPlot $ xycPlot x1 x2 $ for i. IToF y.i
simple = \h1.
  ndense1 = dense (Fin 2) h1
  ndense2 = dense h1 (Fin 2)
  AsLayer {
    forward = (\(dense1, dense2) x.
      x1' = forward ndense1 dense1 x
      x1 = for i. relu x1'.i
      logsoftmax $ forward ndense2 dense2 x1),
    init = (\key.
      [k1, k2] = splitKey key
      (init ndense1 k1, init ndense2 k2))
  }
:t simple
((h1:Type)
-> Layer
((Fin 2) => Float32)
((Fin 2) => Float32)
( (((Fin 2) => h1 => Float32) & (h1 => Float32))
& ((h1 => (Fin 2) => Float32) & ((Fin 2) => Float32))))
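A quick sketch of using the model directly (the simple_demo bindings are illustrative): pick a hidden size, draw parameters, and score one 2-d point; the result is a table of two log-probabilities.

-- Illustrative only: score a single point with a freshly initialized model.
simple_demo = simple (Fin 4)
simple_demo_params = init simple_demo (newKey 2)
simple_demo_scores : Fin 2 => Float = forward simple_demo simple_demo_params [0.5, (-0.5)]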
Train a multiclass classifier with minibatch SGD
The batch index is split into minibatches, so the sizes must satisfy minibatch * minibatches = batch (below, a batch of 100 points is used as one minibatch of 100).
def trainClass [VSpace p] (model: Layer a (b=>Float) p)
               (x: batch=>a)
               (y: batch=>b)
               (epochs : Type)
               (minibatch : Type)
               (minibatches : Type) :
               (epochs => p & epochs => Float) =
  xs : minibatches => minibatch => a = split x
  ys : minibatches => minibatch => b = split y
  unzip $ withState (init model $ newKey 0) $ \params.
    for _ : epochs.
      loss = sum $ for b : minibatches.
        -- negative log-likelihood of the true class, summed over the minibatch
        (loss, gradfn) = vjp (\params.
          -sum for j.
            result = forward model params xs.b.j
            result.(ys.b.j)) (get params)
        gparams = gradfn 1.0
        -- SGD step with a fixed learning rate
        params := (get params) - scaleVec (0.05 / (IToF 100)) gparams
        loss
      (get params, loss)
-- todo : Do I have to give minibatches as a param?
simple_model = simple (Fin 10)
(all_params, losses) = trainClass simple_model xs (for i. (y.i @ (Fin 2))) (Fin 500) (Fin 100) (Fin 1)
span = linspace (Fin 10) (-1.0) (1.0)
tests = for h : (Fin 50). for i. for j.
  r = forward simple_model all_params.((ordinal h * 10)@_) [span.i, span.j]
  [exp r.(1@_), exp r.(0@_), 0.0]
:html imseqshow tests
LeNet for image classification
H = 28
W = 28
Image = Fin 1 => Fin H => Fin W => Float
Class = Fin 10
lenet = \h1 h2 h3.
  ncnn1 = cnn (Fin 1) h1 3 3
  ncnn2 = cnn h1 h2 3 3
  Pooled = (h2 & Fin 7 & Fin 7)
  ndense1 = dense Pooled h3
  ndense2 = dense h3 Class
  AsLayer {
    forward = (\(cnn1, cnn2, dense1, dense2) inp.
      x:Image = inp
      x1' = forward ncnn1 cnn1 x
      x1 = for i j k. relu x1'.i.j.k
      x2' = forward ncnn2 cnn2 x1
      x2 = for i j k. relu x2'.i.j.k
      x3 : (h2 => Fin 7 => Fin 7 => Float) = for c. meanpool (Fin 4) (Fin 4) x2.c
      x4' = forward ndense1 dense1 $ for (i,j,k). x3.i.j.k
      x4 = for i. relu x4'.i
      logsoftmax $ forward ndense2 dense2 x4),
    init = (\key.
      [k1, k2, k3, k4] = splitKey key
      (init ncnn1 k1, init ncnn2 k2,
       init ndense1 k3, init ndense2 k4))
  }
:t lenet
((h1:Type)
-> (h2:Type)
-> (h3:Type)
-> Layer
((Fin 1) => (Fin 28) => (Fin 28) => Float32)
((Fin 10) => Float32)
( ((h1 => (Fin 1) => (Fin 3) => (Fin 3) => Float32) & (h1 => Float32))
& ( ((h2 => h1 => (Fin 3) => (Fin 3) => Float32) & (h2 => Float32))
& ( (((h2 & (Fin 7 & Fin 7)) => h3 => Float32) & (h3 => Float32))
& ((h3 => (Fin 10) => Float32) & ((Fin 10) => Float32))))))
Data Loading
Batch = Fin 5000
Full = Fin ((size Batch) * H * W)
def pixel (x:Char) : Float32 =
  r = W8ToI x
  -- bytes come back signed; fold negative values back into the positive range
  IToF case r < 0 of
    True -> (abs r) + 128
    False -> r
def getIm : Batch => Image =
  (AsList _ im) = unsafeIO do readFile "examples/mnist.bin"
  raw = unsafeCastTable Full im
  for b: Batch c: (Fin 1) i:(Fin W) j:(Fin H).
    pixel raw.((ordinal (b, i, j)) @ Full)
def getLabel : Batch => Class =
  (AsList _ im2) = unsafeIO do readFile "examples/labels.bin"
  r = unsafeCastTable Batch im2
  for i. (W8ToI r.i @ Class)
Training loop
Get binary files from:
wget https://github.com/srush/learns-dex/raw/main/mnist.bin
wget https://github.com/srush/learns-dex/raw/main/labels.bin
Comment out the following lines if you do not have the binary files.
ims = getIm
labels = getLabel
small_ims = for i: (Fin 10). ims.((ordinal i)@_)
small_labels = for i: (Fin 10). labels.((ordinal i)@_)
:p small_labels
[ (5@Fin 10)
, (0@Fin 10)
, (4@Fin 10)
, (1@Fin 10)
, (9@Fin 10)
, (2@Fin 10)
, (1@Fin 10)
, (3@Fin 10)
, (1@Fin 10)
, (4@Fin 10) ]
Epochs = (Fin 5)
Minibatches = (Fin 1)
Minibatch = (Fin 10)
:t ims.(2@_)
((Fin 1) => (Fin 28) => (Fin 28) => Float32)
model = lenet (Fin 1) (Fin 1) (Fin 20)
init_param = (init model $ newKey 0)
:p forward model init_param (ims.(2@Batch))
[ -6193.269
, -2458.7317
, -1936.6416
, -3916.4082
, -4915.504
, 0.
, -2748.0789
, -3704.304
, -3521.0889
, -1654.1918 ]
Sanity check
:t (grad ((\x param. sum (forward model param x)) (ims.(2@_)))) init_param
( (((Fin 1) => (Fin 1) => (Fin 3) => (Fin 3) => Float32) & ((Fin 1) => Float32))
& ( ( ((Fin 1) => (Fin 1) => (Fin 3) => (Fin 3) => Float32)
& ((Fin 1) => Float32))
& ( ( ((Fin 1 & (Fin 7 & Fin 7)) => (Fin 20) => Float32)
& ((Fin 20) => Float32))
& (((Fin 20) => (Fin 10) => Float32) & ((Fin 10) => Float32)))))
(all_params', losses') = trainClass model small_ims small_labels Epochs Minibatch Minibatches
:p losses'
[18824.857, 100.18297, 84.9009, 81.476204, 79.5075]