Dex: NN Layers
Neural Networks
import plot
NN Prelude
def relu (input : Float) : Float =
  select (input > 0.0) input 0.0
instance [Add a, Add b] Add (a & b)
  add = \(a, b) (c, d). ( (a + c), (b + d))
  sub = \(a, b) (c, d). ( (a - c), (b - d))
  zero = (zero, zero)
instance [VSpace a, VSpace b] VSpace (a & b)
  scaleVec = \ s (a, b) . (scaleVec s a, scaleVec s b)
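These instances let whole parameter tuples be added, subtracted, and scaled componentwise, which the SGD update in trainClass below relies on. A quick sanity check (the values are arbitrary):
-- Componentwise arithmetic on pairs via the instances above.
pair_sum = (1.0, 2.0) + (3.0, 4.0)
pair_scaled = scaleVec 0.5 pair_sum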
data Layer inp:Type out:Type params:Type =
  AsLayer {forward:(params -> inp -> out) & init:(Key -> params)}
def forward (l:Layer i o p) (params : p) (x : i) : o =
  (AsLayer l') = l
  (getAt #forward l') params x
def init (l:Layer i o p) (k:Key) : p =
  (AsLayer l') = l
  (getAt #init l') k
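To make the record concrete, here is a minimal illustrative layer (not used below) whose single Float parameter scales its input; forward applies the parameter and init draws it with arb:
-- A toy layer: one scalar parameter that multiplies the input.
scale_layer : Layer Float Float Float =
  AsLayer {
    forward = (\s x. s * x),
    init = arb
  }
scaled = forward scale_layer (init scale_layer $ newKey 0) 3.0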
Layers
Dense layer
def DenseParams (a:Type) (b:Type) : Type =
   ((a=>b=>Float) & (b=>Float))
def dense (a:Type) (b:Type) : Layer (a=>Float) (b=>Float) (DenseParams a b) =
  AsLayer {
    forward = (\ ((weight, bias)) x .
               for j. (bias.j + sum for i. weight.i.j * x.i)),
    init = arb
   }
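For example (the sizes here are arbitrary), a 3-to-2 dense layer can be built, initialized, and applied to a length-3 input:
dense_example = dense (Fin 3) (Fin 2)
dense_out = forward dense_example (init dense_example $ newKey 0) [1.0, 2.0, 3.0]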
CNN layer
def CNNParams (inc:Type) (outc:Type) (kw:Int) (kh:Int) : Type =
  ((outc=>inc=>Fin kh=>Fin kw=>Float) &
   (outc=>Float))
def conv2d (x:inc=>(Fin h)=>(Fin w)=>Float)
           (kernel:outc=>inc=>(Fin kh)=>(Fin kw)=>Float) :
     outc=>(Fin h)=>(Fin w)=>Float =
     for o i j.
         (i', j') = (ordinal i, ordinal j)
         case (i' + kh) <= h && (j' + kw) <= w of
          True ->
              sum for (ki, kj, inp).
                  (di, dj) = (fromOrdinal (Fin h) (i' + (ordinal ki)),
                              fromOrdinal (Fin w) (j' + (ordinal kj)))
                  x.inp.di.dj * kernel.o.inp.ki.kj
          False -> zero
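Output positions where the kernel would run past the bottom or right edge are simply zeroed. As a quick check (shapes chosen arbitrarily), one channel of a 4x4 input convolved with a single 3x3 kernel:
conv_x : (Fin 1)=>(Fin 4)=>(Fin 4)=>Float = arb $ newKey 0
conv_k : (Fin 1)=>(Fin 1)=>(Fin 3)=>(Fin 3)=>Float = arb $ newKey 1
conv_out = conv2d conv_x conv_k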
def cnn (h:Int) ?-> (w:Int) ?-> (inc:Type) (outc:Type) (kw:Int) (kh:Int) :
    Layer (inc=>(Fin h)=>(Fin w)=>Float)
          (outc=>(Fin h)=>(Fin w)=>Float)
          (CNNParams inc outc kw kh) =
  AsLayer {
    forward = (\ (weight, bias) x. for o i j . (conv2d x weight).o.i.j + bias.o),
    init = arb
  }
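Correspondingly (again with arbitrary sizes), a 3x3 convolutional layer from 1 to 4 channels over 8x8 inputs; the annotation pins down the implicit height and width:
cnn_example : Layer ((Fin 1)=>(Fin 8)=>(Fin 8)=>Float)
                    ((Fin 4)=>(Fin 8)=>(Fin 8)=>Float)
                    (CNNParams (Fin 1) (Fin 4) 3 3) =
  cnn (Fin 1) (Fin 4) 3 3
cnn_in : (Fin 1)=>(Fin 8)=>(Fin 8)=>Float = arb $ newKey 0
cnn_out = forward cnn_example (init cnn_example $ newKey 0) cnn_in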
Pooling
def split (x: m=>v) : n=>o=>v =
    for i j. x.((ordinal (i,j))@m)
def imtile (x: a=>b=>v) : n=>o=>p=>q=>v =
    for kw kh w h. (split (split x).w.kw).h.kh
def meanpool (kh: Type) (kw: Type) (x : m=>n=>Float) : (h=>w=>Float) =
    out : (kh => kw => h => w => Float) = imtile x
    mean for (i,j). out.i.j
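For instance (an arbitrary example mirroring its use in LeNet below), a 28x28 map pooled down to 7x7 with 4x4 windows:
pool_in : (Fin 28)=>(Fin 28)=>Float = arb $ newKey 0
pool_out : (Fin 7)=>(Fin 7)=>Float = meanpool (Fin 4) (Fin 4) pool_in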
Simple point classifier
[k1, k2] = splitKey $ newKey 1
x1 : Fin 100 => Float = arb k1
x2 : Fin 100 => Float = arb k2
y = for i. case ((x1.i > 0.0) && (x2.i > 0.0)) || ((x1.i < 0.0) && (x2.i < 0.0)) of
  True -> 1
  False -> 0
xs = for i. [x1.i, x2.i]
:html showPlot $ xycPlot x1 x2 $ for i. IToF y.i
simple = \h1.
  ndense1 = dense (Fin 2) h1
  ndense2 = dense h1 (Fin 2)
  AsLayer {
    forward = (\ (dense1, dense2) x.
         x1' = forward ndense1 dense1 x
         x1 = for i. relu x1'.i
         logsoftmax $ forward ndense2 dense2 x1),
    init = (\key.
         [k1, k2] = splitKey key
         (init ndense1 k1, init ndense2 k2))
  }
:t simple
((h1:Type)
 -> Layer
      ((Fin 2) => Float32)
      ((Fin 2) => Float32)
      ( (((Fin 2) => h1 => Float32) & (h1 => Float32))
      & ((h1 => (Fin 2) => Float32) & ((Fin 2) => Float32))))
Train a multiclass classifier with minibatch SGD
The full batch is split into minibatches, so that minibatch * minibatches = batch.
def trainClass [VSpace p] (model: Layer a (b=>Float) p)
                           (x: batch=>a)
                           (y: batch=>b)
                           (epochs : Type)
                           (minibatch : Type)
                           (minibatches : Type) :
    (epochs => p & epochs => Float ) =
  xs : minibatches => minibatch => a = split x
  ys : minibatches => minibatch => b = split y
  -- Hold the parameters in mutable state; take one SGD step per minibatch and
  -- record the parameters and the summed loss after each epoch.
  unzip $ withState (init model $ newKey 0) $ \params.
     for _ : epochs.
       loss = sum $ for b : minibatches. 
              -- Loss (negative log-likelihood of the target class) and its
              -- gradient with respect to the current parameters.
              (loss, gradfn) = vjp (\params.
                            -sum for j.
                                       result = forward model params xs.b.j
                                       result.(ys.b.j)) (get params)
              gparams = gradfn 1.0
              -- One SGD step; 100 is the (hard-coded) minibatch size.
              params := (get params) - scaleVec (0.05 / (IToF 100)) gparams
              loss
       (get params, loss)
-- todo : Do I have to give minibatches as a param?
simple_model = simple (Fin 10)
(all_params,losses) = trainClass simple_model xs (for i. (y.i @ (Fin 2))) (Fin 500) (Fin 100) (Fin 1)
span = linspace (Fin 10) (-1.0) (1.0)
tests = for h : (Fin 50). for i . for j.
        r = forward simple_model all_params.((ordinal h * 10)@_) [span.i, span.j]
        [exp r.(1@_), exp r.(0@_), 0.0]
:html imseqshow tests
LeNet for image classification
H = 28
W = 28
Image = Fin 1 => Fin H => Fin W => Float 
Class = Fin 10
lenet = \h1 h2 h3 .
  ncnn1 = cnn (Fin 1) h1 3 3
  ncnn2 = cnn h1 h2 3 3
  Pooled = (h2 & Fin 7 & Fin 7)
  ndense1 = dense Pooled h3
  ndense2 = dense h3 Class
  AsLayer {
    forward = (\ (cnn1, cnn2, dense1, dense2) inp.
         x:Image = inp
         x1' = forward ncnn1 cnn1 x
         x1 = for i j k. relu x1'.i.j.k
         x2' = forward ncnn2 cnn2 x1
         x2 = for i j k. relu x2'.i.j.k
         x3 : (h2 => Fin 7 => Fin 7 => Float) = for c. meanpool (Fin 4) (Fin 4) x2.c
         x4' = forward ndense1 dense1 for (i,j,k). x3.i.j.k     
         x4 = for i. relu x4'.i
         logsoftmax $ forward ndense2 dense2 x4),
    init = (\key.
         [k1, k2, k3, k4] = splitKey key
         (init ncnn1 k1, init ncnn2 k2,
         init ndense1 k3, init ndense2 k4))
  }
:t lenet
((h1:Type)
 -> (h2:Type)
 -> (h3:Type)
 -> Layer
      ((Fin 1) => (Fin 28) => (Fin 28) => Float32)
      ((Fin 10) => Float32)
      ( ((h1 => (Fin 1) => (Fin 3) => (Fin 3) => Float32) & (h1 => Float32))
      & ( ((h2 => h1 => (Fin 3) => (Fin 3) => Float32) & (h2 => Float32))
        & ( (((h2 & (Fin 7 & Fin 7)) => h3 => Float32) & (h3 => Float32))
          & ((h3 => (Fin 10) => Float32) & ((Fin 10) => Float32))))))
Data Loading
Batch = Fin 5000
Full = Fin ((size Batch) * H * W)
def pixel (x:Char) : Float32 =
     r = W8ToI x
     -- Bytes are read back as signed; map negative values into the positive range.
     IToF case r < 0 of
             True -> (abs r) + 128
             False -> r
def getIm : Batch => Image = 
    (AsList _ im) = unsafeIO do readFile "examples/mnist.bin"
    raw = unsafeCastTable Full im
    for b: Batch c: (Fin 1) i: (Fin H) j: (Fin W).
        pixel raw.((ordinal (b, i, j)) @ Full)
def getLabel : Batch => Class =
    (AsList _ im2) = unsafeIO do readFile "examples/labels.bin"
    r = unsafeCastTable Batch im2
    for i. (W8ToI r.i @ Class)
Training loop
Get binary files from:
wget https://github.com/srush/learns-dex/raw/main/mnist.bin
wget https://github.com/srush/learns-dex/raw/main/labels.bin
Comment out the following lines if you have not downloaded the data files.
ims = getIm
labels = getLabel
small_ims = for i: (Fin 10). ims.((ordinal i)@_)
small_labels = for i: (Fin 10). labels.((ordinal i)@_)
:p small_labels
[ (5@Fin 10)
, (0@Fin 10)
, (4@Fin 10)
, (1@Fin 10)
, (9@Fin 10)
, (2@Fin 10)
, (1@Fin 10)
, (3@Fin 10)
, (1@Fin 10)
, (4@Fin 10) ]
Epochs = (Fin 5)
Minibatches = (Fin 1)
Minibatch = (Fin 10)
:t ims.(2@_)
((Fin 1) => (Fin 28) => (Fin 28) => Float32)
model = lenet (Fin 1) (Fin 1) (Fin 20) 
init_param = (init model  $ newKey 0)
:p forward model init_param (ims.(2@Batch))
[ -6193.269
, -2458.7317
, -1936.6416
, -3916.4082
, -4915.504
, 0.
, -2748.0789
, -3704.304
, -3521.0889
, -1654.1918 ]
Sanity check
:t (grad ((\x param. sum (forward model param x)) (ims.(2@_)))) init_param
( (((Fin 1) => (Fin 1) => (Fin 3) => (Fin 3) => Float32) & ((Fin 1) => Float32))
& ( ( ((Fin 1) => (Fin 1) => (Fin 3) => (Fin 3) => Float32)
    & ((Fin 1) => Float32))
  & ( ( ((Fin 1 & (Fin 7 & Fin 7)) => (Fin 20) => Float32)
      & ((Fin 20) => Float32))
    & (((Fin 20) => (Fin 10) => Float32) & ((Fin 10) => Float32)))))
(all_params', losses') = trainClass model small_ims small_labels Epochs Minibatch Minibatches
:p losses'
[18824.857, 100.18297, 84.9009, 81.476204, 79.5075]
