# beautiful_cifar.py (branch: master, 8164 bytes)
import time
start_tm = time.perf_counter()
import math
from typing import Tuple, cast
import numpy as np
from tinygrad import Tensor, nn, GlobalCounters, TinyJit, dtypes
from tinygrad.helpers import partition, trange, getenv, Context
from extra.lr_scheduler import OneCycleLR

# Make half precision the default float dtype for everything this script builds.
dtypes.default_float = dtypes.half

# Hyperparameters ported from https://github.com/tysam-code/hlb-CIFAR10/blob/main/main.py
batchsize = getenv("BS", 1024)  # training batch size, looked up via getenv("BS") with 1024 as the fallback
bias_scaler = 64  # 'bias' params get bias_scaler times the lr and 1/bias_scaler of the weight decay
hyp = dict(
  opt=dict(
    # TODO: Is there maybe a better way to express the bias and batchnorm scaling?
    bias_lr=1.525 * bias_scaler/512,
    non_bias_lr=1.525 / 512,
    bias_decay=6.687e-4 * batchsize/bias_scaler,
    non_bias_decay=6.687e-4 * batchsize,
    scaling_factor=1./9,  # multiplier applied to the network's final linear output
    percent_start=.23,  # handed to the OneCycleLR schedules as pct_start
    # Regularizer inside the loss summing (range: ~1/512 - 16+).
    # FP8 should help with this somewhat too, whenever it comes out.
    loss_scale_scaler=1./32,
  ),
  net=dict(
    whitening=dict(
      kernel_size=2,  # kernel size of the first ("whiten") conv; also sets whiten_conv_depth
      num_examples=50000,  # not read anywhere in this script as shown
    ),
    # Don't forget momentum is 1 - momentum here (due to a quirk in the original paper).
    batch_norm_momentum=.4,
    # cutmix_size, cutmix_epochs and pad_amount are not read anywhere in this script as shown.
    cutmix_size=3,
    cutmix_epochs=6,
    pad_amount=2,
    # Original author's note: this should be a factor of 8 in some way to stay tensor core friendly.
    base_depth=64,
  ),
  misc=dict(
    # The EMA settings are not read anywhere in this script as shown.
    ema=dict(
      # Original author's note: slight bug in that this counts only full epochs and then
      # additionally runs the EMA for any fractional epochs at the end too.
      epochs=10,
      decay_base=.95,
      decay_pow=3.,
      every_n_steps=5,
    ),
    train_epochs=12,
    # train_epochs=12.1,
    # device and data_location are not read anywhere in this script as shown.
    device='cuda',
    data_location='data.pt',
  ),
)

# Width multiplier. Kept out of the hyperparameter dict on purpose by the original author;
# feel free to experiment with it.
scaler = 2.
# Channel width of each stage: round(scaler**exponent * base_depth).
# With scaler == 2 and base_depth == 64 this yields init=32, block1=64, block2=256, block3=512.
depths = {
  stage: round(scaler**exponent * hyp['net']['base_depth'])
  for stage, exponent in (('init', -1), ('block1', 0), ('block2', 2), ('block3', 3))
}
depths['num_classes'] = 10
# 3 input channels times kernel_size squared; the first conv outputs twice this many channels.
whiten_conv_depth = hyp['net']['whitening']['kernel_size']**2 * 3

class ConvGroup:
  """Two 3x3 convolutions, each followed by BatchNorm and quick_gelu.

  The first convolution's output goes through max_pool2d() (default arguments)
  before normalization. Both BatchNorm weights are marked requires_grad=False.
  """
  def __init__(self, channels_in, channels_out):
    bn_momentum = hyp['net']['batch_norm_momentum']
    # Layers are created in the same order as before: conv1, conv2, norm1, norm2.
    self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1, bias=False)
    self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1, bias=False)
    self.norm1 = nn.BatchNorm(channels_out, track_running_stats=False, eps=1e-12, momentum=bn_momentum)
    self.norm2 = nn.BatchNorm(channels_out, track_running_stats=False, eps=1e-12, momentum=bn_momentum)
    # Freeze the BatchNorm scale parameters; cast() is only there to satisfy the type checker.
    for norm in (self.norm1, self.norm2):
      cast(Tensor, norm.weight).requires_grad = False

  def __call__(self, x:Tensor) -> Tensor:
    """Apply conv -> pool -> norm -> gelu, then conv -> norm -> gelu."""
    pooled = self.conv1(x).max_pool2d()
    # BatchNorm is fed a .float() tensor; the result is cast back to the default float dtype.
    hidden = self.norm1(pooled.float()).cast(dtypes.default_float).quick_gelu()
    second = self.conv2(hidden)
    return self.norm2(second.float()).cast(dtypes.default_float).quick_gelu()

class SpeedyConvNet:
  """Small conv net: a "whiten" conv, three ConvGroups, max over axes (2, 3), then a linear head.

  NOTE(review): in this script `whiten` is an ordinary bias-free Conv2d; no
  data-dependent whitening initialization is visible here.
  """
  def __init__(self):
    whiten_kernel = hyp['net']['whitening']['kernel_size']
    whiten_channels = 2*whiten_conv_depth
    self.whiten = nn.Conv2d(3, whiten_channels, kernel_size=whiten_kernel, padding=0, bias=False)
    self.conv_group_1 = ConvGroup(whiten_channels, depths['block1'])
    self.conv_group_2 = ConvGroup(depths['block1'], depths['block2'])
    self.conv_group_3 = ConvGroup(depths['block2'], depths['block3'])
    self.linear = nn.Linear(depths['block3'], depths['num_classes'], bias=False)

  def __call__(self, x:Tensor) -> Tensor:
    """Return the scaled class scores for input `x` (presumably an image batch with 3 channels on axis 1)."""
    h = self.whiten(x).quick_gelu()
    for group in (self.conv_group_1, self.conv_group_2, self.conv_group_3):
      h = group(h)
    # Reduce axes 2 and 3 with a max before the linear head.
    pooled = h.max(axis=(2,3))
    return self.linear(pooled) * hyp['opt']['scaling_factor']

if __name__ == "__main__":
  # *** dataset ***
  X_train, Y_train, X_test, Y_test = nn.datasets.cifar()
  # TODO: without this line indexing doesn't fuse!
  X_train, Y_train, X_test, Y_test = tuple(t.contiguous() for t in (X_train, Y_train, X_test, Y_test))
  # Statistics over axes (0, 2, 3) of the training inputs, i.e. one value per entry of axis 1.
  cifar10_std, cifar10_mean = X_train.float().std_mean(axis=(0, 2, 3))

  def preprocess(X:Tensor, Y:Tensor) -> Tuple[Tensor, Tensor]:
    """Normalize `X` with the training-set mean/std and one-hot encode `Y`.

    Returns the normalized inputs cast to the default float dtype and the
    labels expanded to depths['num_classes'] one-hot columns.
    """
    mean = cifar10_mean.view(1, -1, 1, 1)
    std = cifar10_std.view(1, -1, 1, 1)
    normalized = (X - mean) / std
    return normalized.cast(dtypes.default_float), Y.one_hot(depths['num_classes'])

  # *** model ***
  model = SpeedyConvNet()
  state_dict = nn.state.get_state_dict(model)

  #for k,v in nn.state.torch_load("/tmp/cifar_net.pt").items(): print(k)

  # *** optimizers ***
  # Tensors whose state-dict name contains 'bias' get their own SGD instance (different lr and
  # weight decay); everything else goes to the second one. The lr given here is only a
  # placeholder: the schedulers created further down are what drive each optimizer's lr.
  params_bias, params_non_bias = partition(state_dict.items(), lambda named: 'bias' in named[0])
  sgd_common = dict(lr=0.01, momentum=.85, nesterov=True)
  opt_bias     = nn.optim.SGD([p for _, p in params_bias],     weight_decay=hyp['opt']['bias_decay'],     **sgd_common)
  opt_non_bias = nn.optim.SGD([p for _, p in params_non_bias], weight_decay=hyp['opt']['non_bias_decay'], **sgd_common)
  opt = nn.optim.OptimizerGroup(opt_bias, opt_non_bias)

  # *** step counts ***
  # Only whole batches are used, so any remainder of the training set is dropped each epoch.
  num_steps_per_epoch      = X_train.size(0) // batchsize
  total_train_steps        = math.ceil(num_steps_per_epoch * hyp['misc']['train_epochs'])
  loss_batchsize_scaler = 512/batchsize

  # *** learning-rate schedules ***
  pct_start = hyp['opt']['percent_start']
  initial_div_factor = 1e16 # basically to make the initial lr ~0 or so :D
  final_lr_ratio = .07 # Actually pretty important, apparently!
  sched_common = dict(
    pct_start=pct_start,
    div_factor=initial_div_factor,
    final_div_factor=1./(initial_div_factor*final_lr_ratio),
    total_steps=total_train_steps,
  )
  lr_sched_bias     = OneCycleLR(opt_bias,     max_lr=hyp['opt']['bias_lr'],     **sched_common)
  lr_sched_non_bias = OneCycleLR(opt_non_bias, max_lr=hyp['opt']['non_bias_lr'], **sched_common)

  def loss_fn(out, Y):
    """Return the summed label-smoothed cross-entropy of `out` against `Y`, scaled by loss_batchsize_scaler.

    The un-reduced loss is multiplied by loss_scale_scaler before the sum and
    divided by it afterwards, so that factor cancels out of the returned value
    and only changes the magnitude of the intermediate sum.
    """
    inner_scale = hyp['opt']['loss_scale_scaler']
    unreduced = out.cross_entropy(Y, reduction='none', label_smoothing=0.2)
    scaled_sum = unreduced.mul(inner_scale*loss_batchsize_scaler).sum()
    return scaled_sum.div(inner_scale)

  @TinyJit
  @Tensor.train()
  def train_step(idxs:Tensor) -> Tensor:
    """Run one optimization step on the training examples selected by `idxs`.

    Args:
      idxs: integer indices into X_train / Y_train; the caller passes one row
        of `batchsize` shuffled indices.

    Returns:
      The loss from loss_fn divided by batchsize*loss_batchsize_scaler, which
      undoes the batch sum and scaling applied in loss_fn.
    """
    # Gather the batch under these context flags (presumably so the index lookup fuses into
    # the gather; see the contiguous() TODO near the dataset load, and confirm against tinygrad).
    with Context(SPLIT_REDUCEOP=0, FUSE_ARANGE=1):
      X = X_train[idxs]
      # Realize the gathered inputs and labels together.
      Y = Y_train[idxs].realize(X)
    X, Y = preprocess(X, Y)
    out = model(X)
    loss = loss_fn(out, Y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    # Advance both schedules once per step, one for each optimizer in the group.
    lr_sched_bias.step()
    lr_sched_non_bias.step()
    return loss / (batchsize*loss_batchsize_scaler)

  eval_batchsize = 2500
  @TinyJit
  @Tensor.test()
  def val_step() -> Tuple[Tensor, Tensor]:
    """Evaluate the model on the whole test set in chunks of eval_batchsize.

    Returns:
      (loss, acc): the loss and accuracy, each averaged over the eval chunks.
      The loss is normalized per example so it is comparable to the value
      train_step reports.

    NOTE(review): both averages assume X_test.size(0) is a multiple of
    eval_batchsize, so that every chunk holds exactly eval_batchsize examples.
    """
    # TODO with Tensor.no_grad()
    Tensor.no_grad = True
    try:
      loss, acc = [], []
      for i in range(0, X_test.size(0), eval_batchsize):
        X, Y = preprocess(X_test[i:i+eval_batchsize], Y_test[i:i+eval_batchsize])
        out = model(X)
        loss.append(loss_fn(out, Y))
        # Y is one-hot here, so the product keeps a 1 only where the predicted class matches.
        acc.append((out.argmax(-1).one_hot(depths['num_classes']) * Y).sum() / eval_batchsize)
      # loss_fn sums over the eval_batchsize examples of each chunk (times loss_batchsize_scaler),
      # so the per-example loss must be recovered with eval_batchsize, not the training batchsize.
      return Tensor.stack(*loss).mean() / (eval_batchsize*loss_batchsize_scaler), Tensor.stack(*acc).mean()
    finally:
      # Always switch gradients back on, even if evaluation raised.
      Tensor.no_grad = False

  # Seed NumPy's global RNG so the per-epoch shuffles below are repeatable.
  np.random.seed(1337)
  # One iteration per epoch; a fractional train_epochs is rounded up to a whole epoch here.
  for epoch in range(math.ceil(hyp['misc']['train_epochs'])):
    # TODO: move to tinygrad
    gst = time.perf_counter()  # epoch start
    # Fresh random ordering of all training indices for this epoch.
    idxs = np.arange(X_train.shape[0])
    np.random.shuffle(idxs)
    # Keep only as many indices as fill whole batches, laid out as one row of `batchsize` indices per step.
    tidxs = Tensor(idxs, dtype='int')[:num_steps_per_epoch*batchsize].reshape(num_steps_per_epoch, batchsize)  # NOTE: long doesn't fold
    train_loss:float = 0  # running sum of the per-step losses; averaged in the epoch summary
    for epoch_step in (t:=trange(num_steps_per_epoch)):
      st = time.perf_counter()  # step start
      GlobalCounters.reset()  # so global_ops read below covers only this step
      loss = train_step(tidxs[epoch_step].contiguous()).float().item()
      # `et` is taken inside the f-string, after the lr has been read, so the step time includes that read.
      t.set_description(f"*** loss: {loss:5.3f}   lr: {opt_non_bias.lr.item():.6f}"
                         f"   tm: {(et:=(time.perf_counter()-st))*1000:6.2f} ms {GlobalCounters.global_ops/(1e9*et):7.0f} GFLOPS")
      train_loss += loss
    gmt = time.perf_counter()  # end of training / start of validation for this epoch
    GlobalCounters.reset()
    val_loss, acc = [x.float().item() for x in val_step()]
    get = time.perf_counter()  # end of validation
    # "\033[F" is the ANSI cursor-up sequence; presumably used so this summary overwrites the progress-bar line.
    # The final field is the wall time since start_tm, which was taken at the very top of the file.
    print(f"\033[F*** epoch {epoch:3d}       tm: {(gmt-gst):5.2f} s    val_tm: {(get-gmt):5.2f} s   train_loss: {train_loss/num_steps_per_epoch:5.3f}   val_loss: {val_loss:5.3f}   eval acc: {acc*100:5.2f}%  @ {get-start_tm:6.2f} s  ")