Welcome to Tez’s documentation!

Tez is a simple PyTorch trainer that makes your life easier. It comes with some useful dataset classes and callbacks.

Instead of inheriting from torch.nn.Module, you inherit from tez.Model.

import tez

class MyModel(tez.Model):
   def __init__(self):
      super().__init__()
      # do something here

   def forward(self, arg1, arg2):
      # do something here
      return outputs, loss, metrics

In Tez, the dataset class and the model’s forward function are closely related: the keys returned by the dataset class must match the argument names of the model’s forward function.
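
For instance, a dataset like the sketch below pairs with the model shown further down: its __getitem__ returns a dictionary whose keys (ids, mask, token_type_ids, targets) match the forward arguments exactly. The encodings and targets passed to the constructor are hypothetical placeholders; only the key names matter.

import torch

class MyDataset(torch.utils.data.Dataset):
   def __init__(self, encodings, targets):
      # encodings: a dict with "ids", "mask" and "token_type_ids" lists
      # targets: a list of labels (both are hypothetical placeholders)
      self.encodings = encodings
      self.targets = targets

   def __len__(self):
      return len(self.targets)

   def __getitem__(self, idx):
      # the keys returned here must match the arguments of MyModel.forward
      return {
         "ids": torch.tensor(self.encodings["ids"][idx], dtype=torch.long),
         "mask": torch.tensor(self.encodings["mask"][idx], dtype=torch.long),
         "token_type_ids": torch.tensor(self.encodings["token_type_ids"][idx], dtype=torch.long),
         "targets": torch.tensor(self.targets[idx], dtype=torch.float),
      }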

See an example below:

import torch
import torch.nn as nn
from sklearn import metrics

import tez

class MyModel(tez.Model):
   def __init__(self):
      super().__init__()
      # define your layers here, e.g. self.bert, self.bert_drop, self.out
      ...
      # tell tez when to step the scheduler: after every "batch" or every "epoch"
      self.step_scheduler_after = "batch"

   def monitor_metrics(self, outputs, targets):
      outputs = torch.sigmoid(outputs).cpu().detach().numpy() >= 0.5
      targets = targets.cpu().detach().numpy()
      accuracy = metrics.accuracy_score(targets, outputs)
      return {"accuracy": accuracy}

   def fetch_scheduler(self):
      # create and return your own scheduler, for example:
      # (self.optimizer is expected to be set from fetch_optimizer before this is called)
      sch = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1, gamma=0.9)
      return sch

   def fetch_optimizer(self):
      # create and return your own optimizer, for example:
      opt = torch.optim.Adam(self.parameters(), lr=3e-5)
      return opt

   def forward(self, ids, mask, token_type_ids, targets):
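      # note: with recent versions of transformers you may need return_dict=False
      # so that self.bert returns a tuple here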
      _, o_2 = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
      b_o = self.bert_drop(o_2)
      output = self.out(b_o)

      # calculate loss here
      loss = nn.BCEWithLogitsLoss()(output, targets)

      # calculate the metric dictionary here
      metric_dict = self.monitor_metrics(output, targets)
      return output, loss, metric_dict
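
With both pieces in place, training is a single call to the model’s fit method. The snippet below is a minimal sketch: it assumes the fit arguments shown here and the tez.callbacks.EarlyStopping callback (monitor, patience, model_path) behave as in the Tez README, and the encodings/targets variables are hypothetical placeholders.

train_dataset = MyDataset(train_encodings, train_targets)
valid_dataset = MyDataset(valid_encodings, valid_targets)

model = MyModel()

# stop training and keep the best weights when validation loss stops improving
es = tez.callbacks.EarlyStopping(
   monitor="valid_loss", patience=3, model_path="model.bin"
)

model.fit(
   train_dataset,
   valid_dataset=valid_dataset,
   train_bs=32,
   device="cuda",
   epochs=10,
   callbacks=[es],
   fp16=True,
)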