branch: master
bert.py
14726 bytesRaw
import re, os
from pathlib import Path
from tinygrad.tensor import Tensor, cast
from tinygrad import nn, dtypes
from tinygrad.helpers import fetch, get_child
from tinygrad.nn.state import get_parameters

# Layer classes are looked up through these module-level names when the model classes below are
# constructed, so reassigning them (monkeypatching) swaps the implementation used by every model.
Embedding, Linear, LayerNorm = nn.Embedding, nn.Linear, nn.LayerNorm

class BertForQuestionAnswering:
  """BERT encoder followed by a 2-output linear head that yields start/end span logits.

  The default arguments match the BERT-large configuration used by BertForPretraining.
  """
  def __init__(self, hidden_size=1024, intermediate_size=4096, max_position_embeddings=512, num_attention_heads=16, num_hidden_layers=24, type_vocab_size=2, vocab_size=30522, attention_probs_dropout_prob=0.1, hidden_dropout_prob=0.1):
    self.bert = Bert(hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob)
    self.qa_outputs = Linear(hidden_size, 2)

  def load_from_pretrained(self):
    """Download the PyTorch checkpoint and vocab (Zenodo record 3733896) and copy the weights into this model.

    Files are cached under `<parent of this file's directory>/weights/`.
    Returns self, so the call can be chained like BertForPretraining.load_from_pretrained.
    """
    fn = Path(__file__).parents[1] / "weights/bert_for_qa.pt"
    fetch("https://zenodo.org/record/3733896/files/model.pytorch?download=1", fn)
    fn_vocab = Path(__file__).parents[1] / "weights/bert_vocab.txt"
    fetch("https://zenodo.org/record/3733896/files/vocab.txt?download=1", fn_vocab)

    import torch
    # NOTE(review): torch.load unpickles the file and can execute arbitrary code. The URL is fixed,
    # but consider torch.load(..., weights_only=True) if the installed torch version supports it.
    with open(fn, "rb") as f:
      state_dict = torch.load(f, map_location="cpu")

    for k, v in state_dict.items():
      # Dropout is stored as a plain float probability here (no weights), and this model has no
      # pooler attribute, so neither kind of key can be resolved with get_child.
      if "dropout" in k or "pooler" in k: continue
      get_child(self, k).assign(v.numpy()).realize()
    return self

  def __call__(self, input_ids:Tensor, attention_mask:Tensor, token_type_ids:Tensor):
    """Run the encoder and QA head.

    Returns the start and end logits, each flattened to a (-1, 1) column, stacked as [start, end].
    """
    sequence_output = self.bert(input_ids, attention_mask, token_type_ids)
    logits = self.qa_outputs(sequence_output)
    # the head has 2 outputs on the last axis: index 0 -> start, index 1 -> end
    start_logits, end_logits = logits.chunk(2, dim=-1)
    start_logits = start_logits.reshape(-1, 1)
    end_logits = end_logits.reshape(-1, 1)

    return Tensor.stack(start_logits, end_logits)

class BertForPretraining:
  """BERT encoder with masked-LM and next-sentence (sequence relationship) pretraining heads."""
  def __init__(self, hidden_size:int=1024, intermediate_size:int=4096, max_position_embeddings:int=512, num_attention_heads:int=16, num_hidden_layers:int=24, type_vocab_size:int=2, vocab_size:int=30522, attention_probs_dropout_prob:float=0.1, hidden_dropout_prob:float=0.1):
    """Default is BERT-large"""
    self.bert = Bert(hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob)
    # The masked-LM head projects onto the vocabulary with the word-embedding matrix itself (tied weights).
    self.cls = BertPreTrainingHeads(hidden_size, vocab_size, self.bert.embeddings.word_embeddings.weight)

  def __call__(self, input_ids:Tensor, attention_mask:Tensor, masked_lm_positions:Tensor, token_type_ids:Tensor):
    """Return (prediction_logits, seq_relationship_logits) from the pretraining heads."""
    output = self.bert(input_ids, attention_mask, token_type_ids)
    return self.cls(output, masked_lm_positions)

  # Reference has residual on denominator: https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/run_pretraining.py#L315
  def sparse_categorical_crossentropy(self, predictions:Tensor, labels:Tensor, ignore_index=-1):
    """Average negative log-likelihood of integer `labels` under `predictions` (logits on the last axis).

    Entries where `labels == ignore_index` (elementwise; `ignore_index` may be a tensor) contribute
    neither to the numerator nor to the count in the denominator.
    """
    log_probs, loss_mask = predictions.log_softmax(dtype=dtypes.float), (labels != ignore_index)
    # one-hot targets built by comparing an arange over the class axis with the flattened labels
    y_counter = Tensor.arange(predictions.shape[-1], requires_grad=False, device=predictions.device).unsqueeze(0).expand(labels.numel(), predictions.shape[-1])
    y = ((y_counter == labels.flatten().reshape(-1, 1)) * loss_mask.reshape(-1, 1)).reshape(*labels.shape, predictions.shape[-1])
    return -((log_probs * y).sum()) / (loss_mask.sum() + 1e-5) # Small constant to avoid division by zero

  def loss(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
    """Masked-LM cross-entropy plus next-sentence binary cross-entropy on logits."""
    # NOTE(review): masked_lm_weights is used as `ignore_index`, so an entry is dropped wherever
    # label == weight. That drops slots with id 0 and weight 0, but it would also drop a real label
    # equal to its weight value (e.g. id 1 with weight 1). Confirm this is intended.
    masked_lm_loss = self.sparse_categorical_crossentropy(prediction_logits, masked_lm_ids, ignore_index=masked_lm_weights)
    next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)
    return masked_lm_loss + next_sentence_loss

  def accuracy(self, prediction_logits:Tensor, seq_relationship_logits:Tensor, masked_lm_ids:Tensor, masked_lm_weights:Tensor, next_sentence_labels:Tensor):
    """Return (masked-LM accuracy, next-sentence accuracy, masked-LM loss, next-sentence loss as float).

    Masked-LM accuracy counts only positions whose id is nonzero.
    """
    valid = masked_lm_ids != 0
    masked_lm_predictions = prediction_logits.argmax(-1)
    masked_lm_correct = (masked_lm_predictions == masked_lm_ids) * valid
    masked_lm_loss = self.sparse_categorical_crossentropy(prediction_logits, masked_lm_ids, ignore_index=masked_lm_weights)

    seq_relationship_predictions = seq_relationship_logits.argmax(-1)
    seq_relationship_correct = (seq_relationship_predictions == next_sentence_labels)
    next_sentence_loss = seq_relationship_logits.binary_crossentropy_logits(next_sentence_labels)

    # TODO: is it okay that next_sentence_loss is half here?
    # NOTE(review): divides by valid.sum(); a batch with no nonzero masked_lm_ids divides by zero.
    return masked_lm_correct.sum().float() / valid.sum(), seq_relationship_correct.mean(), masked_lm_loss, next_sentence_loss.float()

  def load_from_pretrained(self, tf_weight_path:str=Path(__file__).parent.parent / "datasets" / "wiki"):
    """Load weights from a TensorFlow checkpoint into this model and return self.

    `tf_weight_path` may be a str or Path; it is passed to TensorFlow as str(tf_weight_path).
    Each variable name is a '/'-separated path that is walked attribute-by-attribute from `self`.
    A "<name>_<i>" component (such as "layer_3") indexes element i of the list attribute <name>.
    Optimizer and bookkeeping variables are skipped. The total parameter count is printed.
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Mute tf flag info
    # load from tensorflow
    import tensorflow as tf
    import numpy as np

    # Path components that mark optimizer / bookkeeping state rather than model weights.
    skip_parts = {"adam_v", "adam_m", "global_step", "LAMB", "LAMB_1", "beta1_power", "beta2_power"}
    indexed = re.compile(r'([A-Za-z]+)_(\d+)')  # "layer_11" -> ("layer", "11")
    ckpt = str(tf_weight_path)

    for k, _ in tf.train.list_variables(ckpt):
      m = k.split("/")
      # Filter by name *before* reading, so optimizer slot variables are never loaded into memory.
      if any(n in skip_parts for n in m):
        continue
      v = tf.train.load_variable(ckpt, k)

      pointer = self
      for n in m:
        mt = indexed.fullmatch(n)
        name = mt.group(1) if mt else n
        if name in ("kernel", "gamma", "output_weights"):
          pointer = getattr(pointer, "weight")
        elif name in ("output_bias", "beta"):
          pointer = getattr(pointer, "bias")
        elif name == "pooler":
          # the pooler lives on the pretraining heads (self.cls), not on self.bert
          pointer = getattr(getattr(self, "cls"), "pooler")
        else:
          pointer = getattr(pointer, name)
        if mt: # layers
          pointer = pointer[int(mt.group(2))]

      leaf = m[-1]
      if leaf.endswith("_embeddings"):
        pointer = getattr(pointer, "weight")
      elif leaf == "kernel":
        v = np.transpose(v)  # checkpoint kernels are stored transposed relative to Linear.weight
      cast(Tensor, pointer).assign(v).realize()

    count = sum(p.numel() for p in get_parameters(self))
    print(f"Total parameters: {count / 1000 / 1000}M")
    return self

class BertPreTrainingHeads:
  """Groups the masked-LM head, the pooler and the 2-way sequence-relationship classifier."""
  def __init__(self, hidden_size:int, vocab_size:int, embeddings_weight:Tensor):
    self.predictions = BertLMPredictionHead(hidden_size, vocab_size, embeddings_weight)
    self.pooler = BertPooler(hidden_size)
    self.seq_relationship = Linear(hidden_size, 2)

  def __call__(self, sequence_output:Tensor, masked_lm_positions:Tensor):
    """Return (prediction_logits, seq_relationship_logits).

    Vocabulary logits are produced only for the hidden states at `masked_lm_positions`;
    the relationship logits come from the pooled sequence output.
    """
    masked_states = gather(sequence_output, masked_lm_positions)
    mlm_logits = self.predictions(masked_states)
    pooled = self.pooler(sequence_output)
    nsp_logits = self.seq_relationship(pooled)
    return mlm_logits, nsp_logits

class BertLMPredictionHead:
  """Masked-LM output layer: transform, then project onto the vocabulary with a shared embedding matrix."""
  def __init__(self, hidden_size:int, vocab_size:int, embeddings_weight:Tensor):
    self.transform = BertPredictionHeadTransform(hidden_size)
    self.embedding_weight = embeddings_weight  # kept by reference (shared), not copied
    self.bias = Tensor.zeros(vocab_size, dtype=dtypes.float32)  # one bias per vocabulary entry

  def __call__(self, hidden_states:Tensor):
    transformed = self.transform(hidden_states)
    vocab_scores = transformed @ self.embedding_weight.T
    return vocab_scores + self.bias

class BertPredictionHeadTransform:
  """Dense -> erf-based GELU -> LayerNorm, applied before the masked-LM vocabulary projection."""
  def __init__(self, hidden_size:int):
    self.dense = Linear(hidden_size, hidden_size)
    self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)

  def __call__(self, hidden_states:Tensor):
    projected = self.dense(hidden_states)
    activated = gelu(projected)
    return self.LayerNorm(activated)

class BertPooler:
  """Pools a sequence by passing the hidden state at position 0 through dense + tanh."""
  def __init__(self, hidden_size:int):
    self.dense = Linear(hidden_size, hidden_size)

  def __call__(self, hidden_states:Tensor):
    first_token = hidden_states[:, 0]
    return self.dense(first_token).tanh()

def gather(prediction_logits:Tensor, masked_lm_positions:Tensor):
  """Pick rows along axis 1 of `prediction_logits` at `masked_lm_positions` using a one-hot matmul.

  Despite its name, the first argument is the encoder's sequence output at the call site in this file.
  The reshape to (1, 1, seq_len) implies `masked_lm_positions` is 2-D. Presumably that is
  (batch, num_masked), giving a (batch, num_masked, hidden) result — confirm with callers.
  """
  seq_len = prediction_logits.shape[1]
  onehot_shape = (*masked_lm_positions.shape, seq_len)
  counter = Tensor.arange(seq_len, device=prediction_logits.device, requires_grad=False).reshape(1, 1, seq_len).expand(*onehot_shape)
  selector = counter == masked_lm_positions.unsqueeze(2).expand(*onehot_shape)
  return selector @ prediction_logits

class Bert:
  """Embedding layer followed by the transformer encoder stack; returns the final hidden states."""
  def __init__(self, hidden_size, intermediate_size, max_position_embeddings, num_attention_heads, num_hidden_layers, type_vocab_size, vocab_size, attention_probs_dropout_prob, hidden_dropout_prob):
    self.embeddings = BertEmbeddings(hidden_size, max_position_embeddings, type_vocab_size, vocab_size, hidden_dropout_prob)
    self.encoder = BertEncoder(hidden_size, intermediate_size, num_attention_heads, num_hidden_layers, attention_probs_dropout_prob, hidden_dropout_prob)

  def __call__(self, input_ids, attention_mask, token_type_ids):
    # Convert the mask into an additive bias: 0 where the mask is 1, -10000 where it is 0.
    # Two singleton axes are inserted at dims 1 and 2. Assumes attention_mask is (batch, seq),
    # giving (batch, 1, 1, seq) — TODO confirm.
    additive_mask = (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -10000.0
    embedded = self.embeddings(input_ids, token_type_ids)
    return self.encoder(embedded, additive_mask)

class BertEmbeddings:
  """Sum of word, position and token-type embeddings, followed by LayerNorm and dropout."""
  def __init__(self, hidden_size, max_position_embeddings, type_vocab_size, vocab_size,  hidden_dropout_prob):
    self.word_embeddings = Embedding(vocab_size, hidden_size)
    self.position_embeddings = Embedding(max_position_embeddings, hidden_size)
    self.token_type_embeddings = Embedding(type_vocab_size, hidden_size)
    self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
    self.dropout = hidden_dropout_prob  # dropout probability (a float, not a module)

  def __call__(self, input_ids, token_type_ids):
    ids_shape = input_ids.shape
    # positions 0..seq_len-1, broadcast across every leading (batch) row of input_ids
    positions = Tensor.arange(ids_shape[1], requires_grad=False, device=input_ids.device).unsqueeze(0).expand(*ids_shape)
    total = self.word_embeddings(input_ids) + self.position_embeddings(positions) + self.token_type_embeddings(token_type_ids)
    return self.LayerNorm(total).dropout(self.dropout)

class BertEncoder:
  """A stack of `num_hidden_layers` BertLayer blocks applied one after another."""
  def __init__(self, hidden_size, intermediate_size, num_attention_heads, num_hidden_layers, attention_probs_dropout_prob, hidden_dropout_prob):
    layer_args = (hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob)
    self.layer = [BertLayer(*layer_args) for _ in range(num_hidden_layers)]

  def __call__(self, hidden_states, attention_mask):
    """Feed `hidden_states` through every layer in order, giving each the same attention mask."""
    x = hidden_states
    for block in self.layer:
      x = block(x, attention_mask)
    return x

class BertLayer:
  """One transformer block: self-attention, then the feed-forward (intermediate + output) sublayer."""
  def __init__(self, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
    self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob)
    self.intermediate = BertIntermediate(hidden_size, intermediate_size)
    self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob)

  def __call__(self, hidden_states, attention_mask):
    attended = self.attention(hidden_states, attention_mask)
    # the attention result is both the feed-forward input and its residual
    return self.output(self.intermediate(attended), attended)

class BertOutput:
  """Feed-forward output: project intermediate_size -> hidden_size, apply dropout, add the residual, then LayerNorm."""
  def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob):
    self.dense = Linear(intermediate_size, hidden_size)
    self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
    self.dropout = hidden_dropout_prob  # dropout probability (a float)

  def __call__(self, hidden_states, input_tensor):
    projected = self.dense(hidden_states).dropout(self.dropout)
    return self.LayerNorm(projected + input_tensor)

def gelu(x):
  """Exact (erf-based) GELU, as in the original BERT: x * 0.5 * (1 + erf(x / sqrt(2))).

  `x` must support `*`, `/` and `.erf()`; in this file it is the output of a Linear layer.
  """
  # Full-precision sqrt(2). The previous literal 1.41421 was truncated after five decimals,
  # which perturbed the result by up to ~1e-6.
  return x * 0.5 * (1.0 + (x / (2 ** 0.5)).erf())

class BertIntermediate:
  """Feed-forward expansion: hidden_size -> intermediate_size, followed by GELU."""
  def __init__(self, hidden_size, intermediate_size):
    self.dense = Linear(hidden_size, intermediate_size)

  def __call__(self, hidden_states):
    # Use the module-level erf-based gelu rather than tinygrad's built-in gelu. The built-in is the
    # OpenAI (approximate) variant, while BERT needs the original form.
    return gelu(self.dense(hidden_states))

class BertAttention:
  """Self-attention followed by its output projection, residual connection and LayerNorm."""
  def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
    self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
    self.output = BertSelfOutput(hidden_size, hidden_dropout_prob)

  def __call__(self, hidden_states, attention_mask):
    # the un-attended input is passed along as the residual
    return self.output(self.self(hidden_states, attention_mask), hidden_states)

class BertSelfAttention:
  """Multi-head scaled dot-product self-attention (Q/K/V projections; the output projection is in BertSelfOutput)."""
  def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
    # Fail fast on an invalid config. Otherwise head_size truncates, all_head_size != hidden_size, and
    # the mismatch only surfaces later as a shape error in the next layer's dense projection.
    if hidden_size % num_attention_heads != 0:
      raise ValueError(f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})")
    self.num_attention_heads = num_attention_heads
    self.attention_head_size = int(hidden_size / num_attention_heads)
    self.all_head_size = self.num_attention_heads * self.attention_head_size

    self.query = Linear(hidden_size, self.all_head_size)
    self.key = Linear(hidden_size, self.all_head_size)
    self.value = Linear(hidden_size, self.all_head_size)

    # passed to scaled_dot_product_attention as its dropout argument
    self.dropout = attention_probs_dropout_prob

  def __call__(self, hidden_states, attention_mask):
    """Attend over `hidden_states` with the additive `attention_mask`; returns (dim0, dim1, all_head_size)."""
    mixed_query_layer = self.query(hidden_states)
    mixed_key_layer = self.key(hidden_states)
    mixed_value_layer = self.value(hidden_states)

    query_layer = self.transpose_for_scores(mixed_query_layer)
    key_layer = self.transpose_for_scores(mixed_key_layer)
    value_layer = self.transpose_for_scores(mixed_value_layer)

    context_layer = Tensor.scaled_dot_product_attention(query_layer, key_layer, value_layer, attention_mask, self.dropout)

    # undo transpose_for_scores: swap axes 1 and 2 back, then merge the heads
    context_layer = context_layer.transpose(1, 2)
    context_layer = context_layer.reshape(context_layer.shape[0], context_layer.shape[1], self.all_head_size)

    return context_layer

  def transpose_for_scores(self, x):
    """Split the last axis into (num_heads, head_size) and move the head axis to position 1.

    Assumes x is (batch, seq, all_head_size); the result is then (batch, heads, seq, head_size). TODO confirm.
    """
    x = x.reshape(x.shape[0], x.shape[1], self.num_attention_heads, self.attention_head_size)
    return x.transpose(1, 2)

class BertSelfOutput:
  """Attention output: dense projection, dropout, residual add, then LayerNorm."""
  def __init__(self, hidden_size, hidden_dropout_prob):
    self.dense = Linear(hidden_size, hidden_size)
    self.LayerNorm = LayerNorm(hidden_size, eps=1e-12)
    self.dropout = hidden_dropout_prob  # dropout probability (a float)

  def __call__(self, hidden_states, input_tensor):
    dropped = self.dense(hidden_states).dropout(self.dropout)
    with_residual = dropped + input_tensor
    return self.LayerNorm(with_residual)