branch: master
testconv.py
2034 bytesRaw
#!/usr/bin/env python3
import time
from ane import ANE, ANETensor

def benchmark(ane):
  """Time 1000 back-to-back runs of the program in ../ops/gemm.hwx.

  Prints every key/value pair returned by ``ane.debug`` for bytes
  0x4000:0x4300 of the file, compiles the whole file, runs it repeatedly
  and prints the total elapsed time plus the derived throughput.

  Args:
    ane: object providing ``debug``, ``compile`` and ``run`` (the script
      passes an ``ANE`` instance).
  """
  tin = ANETensor(512*0x20)
  tout = ANETensor(512*0x20)
  # Context manager so the file handle is closed deterministically.
  with open("../ops/gemm.hwx", "rb") as f:
    dat = f.read()
  for k, v in ane.debug(dat[0x4000:0x4300], 16).items():
    print(k, v)
  comp = ane.compile(dat)

  iters = 1000
  # perf_counter() is monotonic and high resolution; time.time() can jump
  # when the system clock is adjusted, which would corrupt the measurement.
  st = time.perf_counter()
  for _ in range(iters):
    ane.run(comp, tin, tout)
  ts = time.perf_counter() - st
  # NOTE(review): presumably 512*512 multiply-accumulates per run, counted
  # as 2 ops each -- confirm against the gemm.hwx program.
  ops = iters*512*512*2

  print("%.2f ms, %.2f gigaops/sec" % (ts*1000, ops*1e-9/ts))


if __name__ == "__main__":
  # Script entry point: fill three small tensors, patch a few fields of the
  # conv.hwx program, run it, and print the input/output buffers before and
  # after the run.
  ane = ANE()

  # 0x20 per row
  tin = ANETensor(0x60)
  tout = ANETensor(0x60)
  tw = ANETensor(0x60)

  tind = tin.data()
  toutd = tout.data()
  twd = tw.data()

  # Input values: one element at offsets 0, 0x20 and 0x40.
  #tind[0:4] = [-1,1,-2,2]
  tind[0] =  1
  tind[0x20] = -2
  tind[0x40] = 3

  # toutd[0] = \
  #   tind[0] * twd[0] + \
  #   tind[0x20] + twd[1] + \
  #   tind[0x40] + twd[2]
  # NOTE(review): the formula above mixes '*' and '+' between terms; it may
  # have been meant as a sum of products -- confirm against observed output.

  # Values written into the third tensor passed to ane.run().
  twd[0] = 4
  twd[1] = 0x100

  twd[0x20] = 5
  twd[0x21] = 5
  twd[0x22] = 5

  twd[0x40] = 12

  print("** before **")
  print(tind)
  print(toutd)

  #benchmark(ane)
  #exit(0)

  # Disabled alternative: run sum.hwx unmodified.
  # dat = list(open("../ops/sum.hwx", "rb").read())
  # dat = bytes(dat)
  # for k,v in ane.debug(dat[0x4000:0x4300], 16).items():
  #   print(k,v)
  # comp = ane.compile(dat)
  # ret = ane.run(comp, tin, tout, tw)

  # Byte range of the program file that is unpacked, edited and packed back.
  patch_lo, patch_hi = 0x4000, 0x4300

  # Context managers so both file handles are closed deterministically.
  with open("../ops/sum.hwx", "rb") as f:
    datb = f.read()
  with open("../ops/conv.hwx", "rb") as f:
    dat = f.read()

  dd = ane.unpack(dat[patch_lo:patch_hi])
  # use the 3rd arg as the weights
  dd["aneTD.Header[9].KBase0"] = 6
  # NOTE(review): 0x3c00 is the IEEE half-precision encoding of 1.0;
  # presumably a unit post-scale -- confirm.
  dd["aneRegs.NE.PostScale.PostScale"] = 0x3c00
  # Disabled experiments, kept for reference:
  #dd["aneRegs.L2.L2Cfg.InputReLU"] = 1
  #dd["aneRegs.NE.MACCfg.NonlinearMode"] = 1
  #dd["aneRegs.TileDMADst.Fmt.MemFmt"] = 0
  #dd["aneRegs.L2.ResultBase.Addr"] = 0
  #dd["aneRegs.Common.ChCfg.InFmt"] = 1
  #dd["aneRegs.TileDMADst.Fmt.ZeroPadFirst"] = 0
  #dd["aneRegs.TileDMADst.DMAConfig.En"] = 0
  for k, v in dd.items():
    print(k, v)

  # Splice the patched conv.hwx region into the surrounding bytes of sum.hwx.
  patched = ane.pack(dd, dat[patch_lo:patch_hi])
  dat = datb[:patch_lo] + patched + datb[patch_hi:]
  comp = ane.compile(dat)
  ret = ane.run(comp, tin, tout, tw)

  print("** after **")
  print(tind)
  print(toutd)