Dex: NN Layers


Neural Networks

import plot

NN Prelude

-- Rectified linear unit: returns the input when it is strictly positive
-- and 0.0 otherwise.
def relu (input : Float) : Float =
  case input > 0.0 of
    True -> input
    False -> 0.0
-- Componentwise additive structure on pairs: add, sub and zero act on the
-- first and second components independently.
instance [Add a, Add b] Add (a & b)
  add = \(x1, y1) (x2, y2). (x1 + x2, y1 + y2)
  sub = \(x1, y1) (x2, y2). (x1 - x2, y1 - y2)
  zero = (zero, zero)
-- Scaling a pair scales both components by the same factor.
instance [VSpace a, VSpace b] VSpace (a & b)
  scaleVec = \factor (x, y). (scaleVec factor x, scaleVec factor y)
-- A layer is a record of two functions:
--   forward : maps parameters and an input to an output
--   init    : builds parameters from a Key
-- The three type arguments are the input, output and parameter types.
data Layer inp:Type out:Type params:Type = AsLayer {forward:(params -> inp -> out) & init:(Key -> params)}
-- Run a layer: look up its `forward` field and apply it to the
-- parameters `p` and the input `x`.
def forward (l:Layer i o p) (p : p) (x : i): o =
  (AsLayer fields) = l
  run = getAt #forward fields
  run p x
-- Create parameters for a layer: look up its `init` field and apply it
-- to the key `k`.
def init (l:Layer i o p) (k:Key) : p =
  (AsLayer fields) = l
  make = getAt #init fields
  make k

Layers

Dense layer

-- Parameter type of a dense layer from `a` to `b`: a weight table indexed
-- by input then output, paired with one bias per output.
def DenseParams (a:Type) (b:Type) : Type =
  (a=>b=>Float) & (b=>Float)
-- Fully connected layer from `a=>Float` to `b=>Float`.
--   forward : for each output j, the bias at j plus the sum over inputs i
--             of weight.i.j * x.i
--   init    : `arb`, applied to the key supplied by the caller
def dense (a:Type) (b:Type) : Layer (a=>Float) (b=>Float) (DenseParams a b) =
  AsLayer {
    forward = (\(wts, offsets) x.
         for j. offsets.j + sum (for i. wts.i.j * x.i)),
    init = arb
  }

CNN layer

-- Parameter type of a convolution layer: a `Fin kh => Fin kw` kernel for
-- every (output channel, input channel) pair, plus one bias per output
-- channel. Note the argument order: `kw` is passed before `kh`.
def CNNParams (inc:Type) (outc:Type) (kw:Int) (kh:Int) : Type =
  (outc=>inc=>Fin kh=>Fin kw=>Float) & (outc=>Float)
-- Slide `kernel` over `x`, summing over input channels and kernel offsets.
-- The output keeps the input's `Fin h => Fin w` index space. Position (i, j)
-- is computed only when the kernel window starting there fits inside the
-- image (i + kh <= h and j + kw <= w); every other position is zero.
def conv2d (x:inc=>(Fin h)=>(Fin w)=>Float)
           (kernel:outc=>inc=>(Fin kh)=>(Fin kw)=>Float) :
    outc=>(Fin h)=>(Fin w)=>Float =
  for o i j.
    row = ordinal i
    col = ordinal j
    windowFits = ((row + kh) <= h) && ((col + kw) <= w)
    -- A `case` is used so the in-bounds index casts below are only reached
    -- when the window fits.
    case windowFits of
      True ->
        sum for (ki, kj, inp).
          di = fromOrdinal (Fin h) (row + (ordinal ki))
          dj = fromOrdinal (Fin w) (col + (ordinal kj))
          x.inp.di.dj * kernel.o.inp.ki.kj
      False -> zero
-- Convolution layer over `inc` input channels and `outc` output channels
-- with a `Fin kh => Fin kw` kernel; the image size (h, w) is implicit.
--   forward : conv2d of the input with the weights, plus a per-output-channel
--             bias added at every spatial position
--   init    : `arb`, applied to the key supplied by the caller
-- Note the argument order: `kw` is passed before `kh`.
def cnn (h:Int) ?-> (w:Int) ?-> (inc:Type) (outc:Type) (kw:Int) (kh:Int) :
    Layer (inc=>(Fin h)=>(Fin w)=>Float)
          (outc=>(Fin h)=>(Fin w)=>Float)
          (CNNParams inc outc kw kh) =
  AsLayer {
    forward = (\(weight, bias) x.
         -- The convolution does not depend on o, i or j, so compute it once
         -- here rather than inside the loop body.
         conv = conv2d x weight
         for o i j. conv.o.i.j + bias.o),
    init = arb
  }

Pooling

-- View a table indexed by `m` as a two-level table indexed by `n` then `o`:
-- entry (outer, inner) is read from position `ordinal (outer, inner)` of `x`.
-- Assumes the pair index set (n & o) has the same size as m; no size check
-- is visible here.
def split (x: m=>v) : n=>o=>v =
  for outer inner. x.((ordinal (outer, inner))@m)
-- Cut a 2D table into tiles by applying `split` to each axis.
-- The result is indexed by (offset inside a tile along the first axis,
-- offset inside a tile along the second axis, tile number along the first
-- axis, tile number along the second axis). All four index sets come from
-- the type expected by the caller.
def imtile (x: a=>b=>v) : n=>o=>p=>q=>v =
  for di dj ti tj. (split (split x).ti.di).tj.dj
-- Mean pooling: tile `x` with `imtile` into kh-by-kw tiles, then average
-- over the two within-tile offsets. The output index sets h and w come from
-- the type expected by the caller.
def meanpool (kh: Type) (kw: Type) (x : m=>n=> Float) : ( h=>w=> Float) =
  tiles : (kh => kw => h => w => Float) = imtile x
  mean for (di, dj). tiles.di.dj

Simple point classifier

-- Toy dataset: 100 two-dimensional points drawn with `arb` from two keys
-- split off a fixed seed.
[k1, k2] = splitKey $ newKey 1
x1 : Fin 100 => Float = arb k1
x2 : Fin 100 => Float = arb k2
-- Label is 1 when both coordinates are strictly positive or both are
-- strictly negative, and 0 otherwise.
y = for i.
  case ((x1.i > 0.0) && (x2.i > 0.0)) || ((x1.i < 0.0) && (x2.i < 0.0)) of
    True -> 1
    False -> 0
-- Each example is the two coordinates packed into one table.
xs = for i. [x1.i, x2.i]
-- Plot the points, using the label (converted to Float) as the third
-- argument of xycPlot.
:html showPlot $ xycPlot x1 x2 $ for i. IToF y.i
-- Two-layer classifier for 2D points with a hidden layer indexed by `h1`:
-- dense -> relu -> dense -> logsoftmax.
-- Parameters are the pair (hidden-layer params, output-layer params); init
-- splits the key in two and initialises each dense layer with its own half.
simple = \h1.
  hiddenLayer = dense (Fin 2) h1
  outputLayer = dense h1 (Fin 2)
  AsLayer {
    forward = (\(hiddenParams, outputParams) x.
         preact = forward hiddenLayer hiddenParams x
         hidden = for i. relu preact.i
         logsoftmax $ forward outputLayer outputParams hidden),
    init = (\key.
         [hiddenKey, outputKey] = splitKey key
         (init hiddenLayer hiddenKey, init outputLayer outputKey))
  }
:t simple
((h1:Type) -> Layer ((Fin 2) => Float32) ((Fin 2) => Float32) ( (((Fin 2) => h1 => Float32) & (h1 => Float32)) & ((h1 => (Fin 2) => Float32) & ((Fin 2) => Float32))))

Train a multiclass classifier with minibatch SGD. The batch is split so that minibatch * minibatches = batch.

-- Minibatch SGD for a classifier whose output is a table of per-class scores.
--   model                 : the layer to train
--   x, y                  : inputs and class labels, both indexed by `batch`
--   epochs                : index set with one entry per pass over the data
--   minibatch/minibatches : index sets used to split the batch with `split`
-- Parameters start from `init model` with the fixed key `newKey 0`.
-- The minibatch loss is minus the sum, over its examples, of the model's
-- score at the true label. Each minibatch takes one gradient step.
-- Returns, per epoch, the parameters at the end of that epoch and the loss
-- summed over its minibatches.
def trainClass [VSpace p] (model: Layer a (b=>Float) p)
      (x: batch=>a) (y: batch=>b)
      (epochs : Type) (minibatch : Type) (minibatches : Type) :
      (epochs => p & epochs => Float ) =
  xs : minibatches => minibatch => a = split x
  ys : minibatches => minibatch => b = split y
  -- NOTE(review): the divisor 100 is a constant and does not follow the
  -- size of `minibatch` -- confirm whether that is intended.
  stepSize = 0.05 / (IToF 100)
  unzip $ withState (init model $ newKey 0) $ \params.
    for _ : epochs.
      epochLoss = sum $ for mb : minibatches.
        (batchLoss, pullback) = vjp (\theta.
                                  -sum for j.
                                     scores = forward model theta xs.mb.j
                                     scores.(ys.mb.j)) (get params)
        gradient = pullback 1.0
        params := (get params) - scaleVec stepSize gradient
        batchLoss
      (get params, epochLoss)
-- todo : Do I have to give minibatches as a param?
-- Instantiate the point classifier with a hidden layer indexed by Fin 10.
simple_model = simple (Fin 10)
-- Train for 500 epochs with a single minibatch of 100 examples; the integer
-- labels are cast to Fin 2 class indices.
(all_params,losses) = trainClass simple_model xs (for i. (y.i @ (Fin 2))) (Fin 500) (Fin 100) (Fin 1)
-- 10 evenly spaced coordinates between -1.0 and 1.0 for each axis.
span = linspace (Fin 10) (-1.0) (1.0)
-- 50 frames: frame h uses the parameters saved at epoch index h * 10 and
-- evaluates the model on the 10x10 grid. Each cell holds
-- [exp of output 1, exp of output 0, 0.0].
tests = for h : (Fin 50). for i . for j.
  r = forward simple_model all_params.((ordinal h * 10)@_) [span.i, span.j]
  [exp r.(1@_), exp r.(0@_), 0.0]
:html imseqshow tests

LeNet for image classification

-- Image height and width.
H = 28
W = 28
-- An image is indexed by one channel, then Fin H, then Fin W.
Image = Fin 1 => Fin H => Fin W => Float
-- Ten output classes.
Class = Fin 10
-- LeNet-style image classifier.
--   h1, h2 : channel index sets of the two 3x3 convolution layers
--   h3     : index set of the hidden dense layer
-- Pipeline: cnn -> relu -> cnn -> relu -> meanpool with (Fin 4, Fin 4) tiles
-- down to Fin 7 x Fin 7 per channel -> flatten -> dense -> relu -> dense
-- -> logsoftmax.
-- Parameters are the 4-tuple of the layers' parameters in that order; init
-- splits the key four ways, one part per layer.
lenet = \h1 h2 h3.
  convLayer1 = cnn (Fin 1) h1 3 3
  convLayer2 = cnn h1 h2 3 3
  Pooled = (h2 & Fin 7 & Fin 7)
  denseLayer1 = dense Pooled h3
  denseLayer2 = dense h3 Class
  AsLayer {
    forward = (\(convParams1, convParams2, denseParams1, denseParams2) inp.
         -- The annotation pins the input to the Image type.
         img:Image = inp
         conv1 = forward convLayer1 convParams1 img
         act1 = for i j k. relu conv1.i.j.k
         conv2 = forward convLayer2 convParams2 act1
         act2 = for i j k. relu conv2.i.j.k
         pooled : (h2 => Fin 7 => Fin 7 => Float) = for c. meanpool (Fin 4) (Fin 4) act2.c
         -- Flatten (channel, row, column) into the single Pooled index.
         hiddenPre = forward denseLayer1 denseParams1 (for (i,j,k). pooled.i.j.k)
         hidden = for i. relu hiddenPre.i
         logsoftmax $ forward denseLayer2 denseParams2 hidden),
    init = (\key.
         [k1, k2, k3, k4] = splitKey key
         (init convLayer1 k1, init convLayer2 k2, init denseLayer1 k3, init denseLayer2 k4))
  }
:t lenet
((h1:Type) -> (h2:Type) -> (h3:Type) -> Layer ((Fin 1) => (Fin 28) => (Fin 28) => Float32) ((Fin 10) => Float32) ( ((h1 => (Fin 1) => (Fin 3) => (Fin 3) => Float32) & (h1 => Float32)) & ( ((h2 => h1 => (Fin 3) => (Fin 3) => Float32) & (h2 => Float32)) & ( (((h2 & (Fin 7 & Fin 7)) => h3 => Float32) & (h3 => Float32)) & ((h3 => (Fin 10) => Float32) & ((Fin 10) => Float32))))))

Data Loading

-- Number of images loaded from the data file.
Batch = Fin 5000
-- Flat index over every pixel of every image: (size Batch) * H * W entries.
Full = Fin ((size Batch) * H * W)
-- Convert one raw byte to a Float intensity.
-- If `W8ToI` returns the byte as a signed value, bytes 128..255 arrive as
-- -128..-1; adding 256 recovers the unsigned value. The earlier
-- `(abs r) + 128` mapped -1 (byte 255) to 129 and -128 (byte 128) to 256.
-- When `W8ToI` already yields 0..255 the negative branch is never taken.
def pixel (x:Char) : Float32 =
  r = W8ToI x
  IToF case r < 0 of
    True -> r + 256
    False -> r
-- Load the images from "examples/mnist.bin".
-- The file contents are reinterpreted as a table of `Full` bytes, and each
-- byte is converted with `pixel`.
-- The loop indices follow the Image layout (Fin 1 => Fin H => Fin W):
-- i ranges over Fin H and j over Fin W. They were previously declared the
-- other way round, which only type-checked because H and W are equal.
def getIm : Batch => Image =
  (AsList _ im) = unsafeIO do readFile "examples/mnist.bin"
  raw = unsafeCastTable Full im
  for b: Batch c: (Fin 1) i:(Fin H) j:(Fin W).
    pixel raw.((ordinal (b, i, j)) @ Full)
-- Load the labels from "examples/labels.bin": the file contents are
-- reinterpreted as one byte per `Batch` entry and each byte is cast to a
-- `Class` index.
def getLabel : Batch => Class =
  (AsList _ contents) = unsafeIO do readFile "examples/labels.bin"
  labelBytes = unsafeCastTable Batch contents
  for i. (W8ToI labelBytes.i @ Class)

Training loop

Get binary files from:

wget https://github.com/srush/learns-dex/raw/main/mnist.bin

wget https://github.com/srush/learns-dex/raw/main/labels.bin

Comment out the lines below if you have not downloaded these files; they read examples/mnist.bin and examples/labels.bin.

-- Load the full image and label tables from disk.
ims = getIm
labels = getLabel
-- Keep only the first 10 examples for a quick experiment.
small_ims = for i: (Fin 10). ims.((ordinal i)@_)
small_labels = for i: (Fin 10). labels.((ordinal i)@_)
:p small_labels
[ (5@Fin 10) , (0@Fin 10) , (4@Fin 10) , (1@Fin 10) , (9@Fin 10) , (2@Fin 10) , (1@Fin 10) , (3@Fin 10) , (1@Fin 10) , (4@Fin 10) ]
-- Training schedule for the small run: 5 epochs, one minibatch of 10 examples.
Epochs = (Fin 5)
Minibatches = (Fin 1)
Minibatch = (Fin 10)
-- Check the type of a single image.
:t ims.(2@_)
((Fin 1) => (Fin 28) => (Fin 28) => Float32)
-- LeNet with one channel in each convolution layer and a hidden dense layer
-- indexed by Fin 20.
model = lenet (Fin 1) (Fin 1) (Fin 20)
-- Initial parameters from a fixed key.
init_param = (init model $ newKey 0)
-- Output of the untrained model on the image at index 2.
:p forward model init_param (ims.(2@Batch))
[ -6193.269 , -2458.7317 , -1936.6416 , -3916.4082 , -4915.504 , 0. , -2748.0789 , -3704.304 , -3521.0889 , -1654.1918 ]

Sanity check

-- Type of the gradient, with respect to the parameters, of the summed model
-- output on the image at index 2.
:t (grad ((\x param. sum (forward model param x)) (ims.(2@_)))) init_param
( (((Fin 1) => (Fin 1) => (Fin 3) => (Fin 3) => Float32) & ((Fin 1) => Float32)) & ( ( ((Fin 1) => (Fin 1) => (Fin 3) => (Fin 3) => Float32) & ((Fin 1) => Float32)) & ( ( ((Fin 1 & (Fin 7 & Fin 7)) => (Fin 20) => Float32) & ((Fin 20) => Float32)) & (((Fin 20) => (Fin 10) => Float32) & ((Fin 10) => Float32)))))
-- Train LeNet on the 10-example subset with the schedule defined by
-- Epochs, Minibatch and Minibatches, then print the loss of each epoch.
(all_params', losses') = trainClass model small_ims small_labels Epochs Minibatch Minibatches
:p losses'
[18824.857, 100.18297, 84.9009, 81.476204, 79.5075]